1 """Background thread to operate L{qubx.fit}.
2
3 Copyright 2008-2013 Research Foundation State University of New York
4 This file is part of QUB Express.
5
6 QUB Express is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation, either version 3 of the License, or
9 (at your option) any later version.
10
11 QUB Express is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License,
17 named LICENSE.txt, in the QUB Express program directory. If not, see
18 <http://www.gnu.org/licenses/>.
19
20 """
21
22 import gobject
23 import os
24 import re
25 import traceback
26 import multiprocessing
27 import platform
28 import __main__
29 import qubx.pyenv
30 import qubx.task
31 from qubx.util_types import WeakEvent, Reffer, Anon, memoize
32 from qubx.fit import *
33
35 """
36 Background thread with multiprocessing.pool to operate L{qubx.fit}.
37 Use via FitSession; Call L{InitFitRobots}() before starting any sessions.
38 """
46 traceback.print_exception(typ, val, trace)
48 if self.fitter:
49 self.fitter.stop(cancel)
50
51 robots = None
52
54 try:
55 exec(globals_script, __main__.__dict__)
56 exec(globals_script, qubx.global_namespace.__dict__)
57 except:
58 return traceback.format_exc()
59 return str(dir(qubx.global_namespace))
60
63
71
72
73
75 """
76 Background thread to operate L{qubx.fit}.
77 Instead of reading instance variables, respond to events.
78 Events are fired in the gobject thread when something changes, or when you request_something().
79 Stats are given after OnEndFit, or when requested.
80
81 Stats:
82 - correlation: numpy.array(shape=(npar,npar)) of cross-correlation between -1 and 1
83 - is_pseudo: True if the covariance wasn't positive-definite; correlation is then strictly undefined
84 - std_err_est: list of the "standard error of the estimate" of each param_val
85 - ssr: sum-squared residual
86 - r2: R-squared
87 - runs_prob: wald-wolfowitz runs probability
88
89 @ivar OnData: L{WeakEvent}(xx, yy, vvv, v_names)
90 @ivar OnWeight: L{WeakEvent}(weight_expr)
91 @ivar OnExpr: L{WeakEvent}(curve_name, curve_expr, params, param_vals, lo, hi, can_fit)
92 @ivar OnParam: L{WeakEvent}(index, name, value, lo, hi, can_fit)
93 @ivar OnMaxIter: L{WeakEvent}(max_iter)
94 @ivar OnToler: L{WeakEvent}(toler)
95 @ivar OnStrategy: L{WeakEvent}(strategy_script)
96 @ivar OnStrategyWarn: L{WeakEvent}(L{Strategy}) when curve has changed and script may reference invalid vars
97 @ivar OnOutput: L{WeakEvent}(str) duplication of stdout during strategy
98 @ivar OnFitter: L{WeakEvent}(name_of_fitter)
99 @ivar OnStats: L{WeakEvent}(correlation, is_pseudo, std_err_est[], ssr, r2, runs_prob)
100 @ivar OnStartFit: L{WeakEvent}()
101 @ivar OnIteration: L{WeakEvent}(param_vals, iteration)
102 @ivar OnStatus: L{WeakEvent}(str)
103 @ivar OnEndFit: L{WeakEvent}()
104 @ivar OnStartFit_Robot: L{WeakEvent}() called in robot thread
105 @ivar OnEndFit_Robot: L{WeakEvent}() called in robot thread
106 @ivar fitter: L{Simplex_LM_Fitter} instance; replace it if you want
107 """
108 - def __init__(self, label='Fit', custom_robots=None):
164 curve = property(lambda self: self.__curve)
169
176 - def do(self, *args, **kw):
179 self.__curve.pool = self.robots.pool
def set_data(self, xx, yy, vvv=None, v_names=None):
    """Makes available the x, y and variable series. Doesn't auto-update-stats.

    Nothing is computed in the calling thread: the request is queued with
    C{self.robots.do}, which receives C{self.robot_set_data} and the four
    series arguments.

    @param xx: numpy.array(dtype='float32') of x coords
    @param yy: numpy.array(dtype='float32') of y coords
    @param vvv: list of numpy.array(dtype='float32') per extra variable;
                None (the default) means no extra variables
    @param v_names: list of names corresponding to the vvv; can be used in expr and weight;
                    None (the default) means no names
    """
    # Build fresh lists on every call.  With list literals as defaults, one
    # shared list object would be handed to the robot thread by every call
    # that omits these arguments, so a mutation anywhere downstream would
    # leak into later calls.
    if vvv is None:
        vvv = []
    if v_names is None:
        v_names = []
    self.robots.do(self.robot_set_data, xx, yy, vvv, v_names)
190 """Tells how to calculate each sample's weight; ssr = sum( (weight*dy)**2 ).
191 @param expr: python expression (str) in terms of x and v_names
192 """
193 self.robots.do(self.robot_set_weight, expr)
def set_curve(self, curve_class, expr=None):
    """Changes the curve function.

    The change is not applied here; it is queued on C{self.robots} for
    C{self.robot_set_curve} to carry out.

    @param curve_class: L{qubx.fit.BaseCurve} class, or any f(expr=default) -> subclass instance
    @param expr: forwarded unchanged together with curve_class; None by default
    """
    request = (self.robot_set_curve, curve_class, expr)
    self.robots.do(*request)
201 """Changes the curve function.
202 @param expr: either the right-hand side of "y=something+f(other)", or a system of ODEs; see L{acceptODEs}.
203 """
204 self.robots.do(self.robot_set_expr, expr)
def set_param(self, nm, value=None, lo=None, hi=None, can_fit=None):
    """Changes the value, low bound, high bound, and/or can_fit of a curve param.

    All five arguments are queued, unchanged, on C{self.robots} for
    C{self.robot_set_param}.

    @param nm: name of curve param
    @param value: new value, or None to leave unchanged
    @param lo: new lower bound, or UNSET_VALUE, or None to leave unchanged
    @param hi: new upper bound, or UNSET_VALUE, or None to leave unchanged
    @param can_fit: True if the fitter can change it
    """
    changes = (nm, value, lo, hi, can_fit)
    self.robots.do(self.robot_set_param, *changes)
227 """Replaces the fitter with fitter_class(max_iter, toler) ."""
228 self.robots.do(self.robot_set_fitter, fitter_class)
245 """Triggers OnStats. If multiple stats requests are pending, only the newest one is honored."""
246 self.__serial_stats += 1
247 self.robots.do(self.robot_request_stats, self.__serial_stats)
def fit(self, grab_initial=True):
    """
    Improves param values; Triggers OnStartFit, OnIteration, ..., OnEndFit, OnParam, ..., OnStats.
    If multiple fit requests are pending, only the newest one is honored.

    Each call bumps the fit serial number by one and queues C{self.robot_fit}
    on C{self.robots} with the new serial and grab_initial.

    @param grab_initial: forwarded unchanged to C{self.robot_fit}
    """
    ticket = self.__serial_fit + 1
    self.__serial_fit = ticket
    self.robots.do(self.robot_fit, ticket, grab_initial)
255
257 gobject.idle_add(self.OnData, self.__xx, self.__yy, self.__vvv, self.__v_names)
258 gobject.idle_add(self.OnWeight, self.__weight)
260 gobject.idle_add(self.OnExpr, self.__curve.name, self.__expr, self.__curve.params[:], self.__param_vals[:],
261 self.__curve.lo[:], self.__curve.hi[:], self.__curve.can_fit[:])
263 gobject.idle_add(self.OnParam, i, self.__curve.params[i], self.__param_vals[i],
264 self.__curve.lo[i], self.__curve.hi[i], self.__curve.can_fit[i])
266 gobject.idle_add(self.OnMaxIter, self.fitter.max_iter)
267 gobject.idle_add(self.OnToler, self.fitter.toler)
269 gobject.idle_add(self.OnStrategy, self.__strategy.script)
271 gobject.idle_add(self.OnFitter, self.fitter.name)
273 if serial < self.__serial_stats: return
274 correlation, is_pseudo, std_err_est, ssr, r2, runs_prob = self.robot_do_stats()
275 gobject.idle_add(self.OnStats, correlation, is_pseudo, std_err_est, ssr, r2, runs_prob)
277 correlation, is_pseudo, std_err_est, ssr, r2 = Correlation(self.__curve, self.__param_vals, self.__xx, self.__yy, self.__vvv, self.__ww)
278 ff = self.__curve.eval(self.__param_vals, self.__xx, self.__vvv)
279 runs_prob = RunsProb(self.__yy, ff)
280 ss = self.__strategy.locals.stats
281 ss.correlation = correlation
282 ss.is_pseudo = is_pseudo
283 ss.std_err_est = std_err_est
284 ss.ssr = ssr
285 ss.r2 = r2
286 ss.runs_prob = runs_prob
287 return correlation, is_pseudo, std_err_est, ssr, r2, runs_prob
289 try:
290 name_low = name.lower()
291 if name_low == 'x':
292 return numpy.min(self.__xx)
293 elif name_low == 'y':
294 return numpy.min(self.__yy)
295 else:
296 return numpy.min(self.__vvv[self.__v_names.index(name)])
297 except:
298 return 0.0
300 try:
301 name_low = name.lower()
302 if name_low == 'x':
303 return numpy.max(self.__xx)
304 elif name_low == 'y':
305 return numpy.max(self.__yy)
306 elif self.__v_names:
307 return numpy.max(self.__vvv[self.__v_names.index(name)])
308 except:
309 print 'Check the variable name: "%s" not in %s' % (name, self.__v_names)
310 return 0.0
312 self.__xx = numpy.array(xx, dtype='float32', copy=True)
313 self.__yy = numpy.array(yy, dtype='float32', copy=True)
314 self.__vvv = [numpy.array(vv, dtype='float32', copy=True) for vv in vvv]
315 self.__v_names = v_names
316 old_params = self.robot_save_params()
317 self.__curve.set_vars(v_names)
318 if len(xx):
319 self.robot_re_param(old_params)
320 self.__weight_f.set_vars(v_names+['y'])
321 try:
322 self.__ww = self.__weight_f.eval([], xx, vvv+[yy])
323 except:
324 traceback.print_exc()
325 print v_names
326 print
327 self.__ww = numpy.array([1.0]*len(xx), dtype='float32')
328 else:
329 self.__ww = numpy.array([], dtype='float32')
330 self.locals.lo.reset()
331 self.locals.hi.reset()
332 sd = self.__strategy.locals.data
333 sd.X = xx
334 sd.Y = yy
335 sd.Series = vvv
336 sd.SeriesName = v_names
337 self.robot_request_data()
362 return dict([(self.__curve.params[i], (self.__param_vals[i], self.__curve.lo[i], self.__curve.hi[i], self.__curve.can_fit[i]))
363 for i in xrange(len(self.__param_vals))])
365 self.__param_vals = []
366 for i,p in enumerate(self.__curve.params):
367 if p in old_params:
368 val,lo,hi,can_fit = old_params[p]
369 self.__param_vals.append(val)
370 self.__curve.lo[i] = lo
371 self.__curve.hi[i] = hi
372 self.__curve.can_fit[i] = can_fit
373 elif self.__curve.param_defaults:
374 self.__param_vals.append(self.__curve.param_defaults[i])
375 else:
376 self.__param_vals.append(1.0)
377 self.robot_request_curve()
378 if self.__strategy.renit(self.__curve):
379 self.robot_request_strategy()
380 self.__strategy.locals.param_vals = self.__param_vals
413 """Fits by running strategy script; "fit" in script refers to robot_strategy_fit, below."""
414 if serial < self.__serial_fit: return
415 try:
416 self.robots.OnInterrupt += self.OnInterrupt
417 self.OnStartFit_Robot()
418 self.robots.fitter = self.fitter
419 gobject.idle_add(self.OnStartFit)
420 if grab_initial:
421 self.__init_vals = self.__param_vals[:]
422 qubx.pyenv.env.OnOutput += self.__ref(self.__onOutput)
423 self.__strategy.run()
424 qubx.pyenv.env.OnOutput -= self.__ref(self.__onOutput)
425 finally:
426 gobject.idle_add(self.OnEndFit)
427 self.OnEndFit_Robot()
428 self.robots.OnInterrupt -= self.OnInterrupt
429 self.robots.fitter = None
430 for i in xrange(len(self.__param_vals)):
431 if self.__curve.can_fit[i]:
432 self.robot_request_param(i)
433 self.request_stats()
435 """Implementation of strategy.locals.fit"""
436 save_vars = self.__curve.can_fit[:]
437 try:
438 if not (variables is None):
439 self.__curve.can_fit[:] = [(pname in variables) for pname in self.__curve.params]
440 for param, val in params.iteritems():
441 try:
442 ix = self.__curve.params.index(param)
443 v = val
444 if v == self.__strategy.locals.Initial:
445 v = self.__init_vals[ix]
446 if not (v is None):
447 self.__param_vals[ix] = v
448 except ValueError:
449 print 'Warning: strategy variable "%s" is not a curve parameter.' % param
450 self.__param_vals, ssr, iterations, ff = self.fitter(self.__curve, self.__param_vals, self.__xx, self.__yy, self.__vvv, self.__ww,
451 self.robot_on_iter, self.robot_on_status)
452 self.robot_do_stats()
453 finally:
454 self.__curve.can_fit[:] = save_vars
456 try:
457 gobject.idle_add(self.OnIteration, param_vals, iter)
458 return 1
459 except:
460 return 0
462 gobject.idle_add(self.OnStatus, msg)
467
468
470 """Strategy
471 module_name: for pyenv
472 run_later: facility for deferred execution such as gobject.idle_add
473
474 Manages and executes the fitting script. locals include Initial='Initial', Last=None. Required but not provided:
475 fit(variables=[par_names] or Last, {par_name}=value or Initial or Last, ...other par_names...)
476
477 Properties:
478 locals L{Anon} of names available in script; you add "fit"
479 script text of the script; empty for default
480 initial text of default script for this curve
481
482 Events:
483 OnChange(Strategy, script) script has changed
484 OnWarn(Strategy) Curve param(s) removed; possibly invalid script
485 """
486 - def __init__(self, module_name, run_later=lambda f:f()):
496 locals = property(lambda self: self.__locals)
497 script = property(lambda self: self.__script, lambda self, x: self.set_script(x))
498 initial = property(lambda self: self.__initial)
500 self.__script = (x.strip() and (x.strip()+'\n')) or self.__initial
501 self.OnChange(self, x)
def init(self, curve, rewrite=True):
    """Builds initial script from Curve; if rewrite: current = initial.

    The generated text is one line of the form
    C{fit(variables=[fitted names], name=Initial|Last, ...)} followed by a
    newline: params whose can_fit flag is true are listed in C{variables}
    and start from C{Initial}; all others start from C{Last}.

    Also records a copy of the curve's param names in C{self.names}.

    @param curve: object with parallel sequences C{params} (names) and C{can_fit} (flags)
    @param rewrite: if true, the generated text is also assigned to C{self.script}
    """
    self.names = curve.params[:]
    # Index can_fit by position (rather than zip it) so that a can_fit
    # shorter than params still raises IndexError instead of being
    # silently truncated.
    fit_flags = [bool(curve.can_fit[i]) for i, _ in enumerate(self.names)]
    fitted = [name for name, flag in zip(self.names, fit_flags) if flag]
    starts = ["%s=%s" % (name, "Initial" if flag else "Last")
              for name, flag in zip(self.names, fit_flags)]
    args = ["variables=%s" % (fitted,)] + starts
    self.__initial = "fit(%s)\n" % ', '.join(args)
    if rewrite:
        self.script = self.__initial
def renit(self, curve, attn=True):
    """Builds initial script, rewriting if not modified, OnWarn if pars dropped.

    The current script is replaced by the new initial script only when it
    was still equal to the previous initial script.  If the script still
    differs from the initial one afterwards, attn is true, and some
    previously known param name is absent from the curve, C{self.OnWarn}
    is scheduled through C{self.run_later}.

    No value is returned.

    @param curve: object whose C{params} lists the current param names
    @param attn: if false, never schedule OnWarn
    """
    surviving = curve.params[:]
    lost_any = any(old not in surviving for old in self.names)
    was_default = self.__script == self.__initial
    self.init(curve, was_default)
    if self.__script == self.__initial:
        # Script matches the regenerated default; nothing stale to warn about.
        return
    if attn and lost_any:
        self.run_later(self.OnWarn, self)
521 - def run(self, expr=None):
529 """Appends a copy of the last line or initial script to the current script."""
530 match = re.search(r"(^fit.*\) *(\n?)$)", self.script)
531 strat = self.script
532 if match:
533 if not len(match.group(2)):
534 strat = strat + '\n'
535 strat = strat + match.group(1)
536 else:
537 strat = strat + self.__initial[:-1]
538 self.script = strat
539
540
542 import time
543 import gtk
544 gobject.threads_init()
545 qubx.pyenv.Init(os.path.join(os.path.expanduser('~'), '.qubx-test'), gobject.idle_add, gobject.timeout_add, gobject.MainLoop, lambda: gtk.main_iteration(False), '.')
546
547 InitFitRobots()
548
549 def on_data(xx, yy, vvv, v_names):
550 print 'data changed'
551 def on_weight(weight):
552 print 'new weight: %s' % weight
553 def on_expr(name, expr, params, vals, lo, hi, can_fit):
554 print 'new expr: %s' % expr
555 print 'params:',params
556 print 'vals:',vals
557 def on_param(i, nm, val, lo, hi, can_fit):
558 print 'param %s = %f\t[%f..%f]%s' % (nm, val, lo, hi, can_fit and ' FITTED' or '')
559 def on_max_iter(max_iter):
560 print 'new max iter:',max_iter
561 def on_toler(toler):
562 print 'new toler:', toler
563 def on_stats(correlation, is_pseudo, std_err_est, ssr, r2, runs_prob):
564 print 'SSR:',ssr
565 print '%sCorrelation:' % (is_pseudo and 'pseudo-' or '')
566 print correlation
567 print 'Errors:',std_err_est
568 print 'r**2:',r2
569 print 'P(runs):',runs_prob
570 print
571 def on_start_fit():
572 print 'starting to fit...'
573 def on_iteration(pvals, iter):
574 print '%s\t'%iter,
575 print pvals
576 def on_status(msg):
577 print ']]',msg
578 def on_quit(*args):
579 gtk.main_quit()
580 def on_end_fit():
581 print 'done fit.'
582 def on_exception(task, typ, val, tb):
583 traceback.print_exception(typ, val, tb)
584
585 robots.OnException += on_exception
586 robot = FitSession()
587 robot.OnData += on_data
588 robot.OnWeight += on_weight
589 robot.OnExpr += on_expr
590 robot.OnParam += on_param
591 robot.OnMaxIter += on_max_iter
592 robot.OnToler += on_toler
593 robot.OnStats += on_stats
594 robot.OnStartFit += on_start_fit
595 robot.OnIteration += on_iteration
596 robot.OnStatus += on_status
597 robot.OnEndFit += on_end_fit
598
599 xx = numpy.arange(11, dtype='float32') / 10.0
600 yy = numpy.array(xx**3, dtype='float32')
601 yy += .1 * numpy.random.randn(len(xx))
602 robot.set_data(xx, yy)
603 robot.set_weight('1.0')
604 robot.set_expr('a * x**k + b')
605 robot.set_param('a', 0.5)
606 robot.set_param('k', 0.5)
607 robot.set_param('b', -0.2)
608 robot.set_max_iter(300)
609 robot.set_toler(1e-12)
610
611 robot.fit()
612
613 robot.request_data()
614 robot.request_curve()
615 robot.request_fitparams()
616
617 yy = numpy.zeros(shape=xx.shape, dtype='float32')
618 yy[:5] = xx[:5] * 2 + 5
619 yy[5:] = xx[5:] * 3 + 5
620 yy += .1 * numpy.random.randn(len(xx))
621 mm = numpy.zeros(shape=xx.shape, dtype='float32')
622 mm[:5] = 2
623 mm[5:] = 3
624 robot.set_expr('m * x + b')
625 robot.set_data(xx, yy, [mm], ['m'])
626 robot.fit()
627 robot.request_curve()
628
629 robots.do(gobject.idle_add, on_quit)
630
631 gtk.main()
632 robots.stop()
633
634
635 if __name__ == '__main__':
636 main()
637