1 """Background thread to operate L{qubx.fit}.
2
3 Copyright 2008-2013 Research Foundation State University of New York
4 This file is part of QUB Express.
5
6 QUB Express is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation, either version 3 of the License, or
9 (at your option) any later version.
10
11 QUB Express is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License,
17 named LICENSE.txt, in the QUB Express program directory. If not, see
18 <http://www.gnu.org/licenses/>.
19
20 """
21
22 import gobject
23 import re
24 import traceback
25 import qubx.pyenv
26 import qubx.task
27 from qubx.util_types import WeakEvent, Reffer, Anon, memoize
28 from qubx.fit import *
29
31 """
32 Background thread to operate L{qubx.fit}.
33 Instead of reading instance variables, respond to events.
34 Events are fired in the gobject thread when something changes, or when you request_something().
35 Stats are given after OnEndFit, or when requested.
36
37 Stats:
38 - correlation: numpy.array(shape=(npar,npar)) of cross-correlation between -1 and 1
- is_pseudo: True if the covariance wasn't positive-definite; correlation is then strictly undefined
40 - std_err_est: list of the "standard error of the estimate" of each param_val
41 - ssr: sum-squared residual
42 - r2: R-squared
43 - runs_prob: wald-wolfowitz runs probability
44
45 @ivar OnData: L{WeakEvent}(xx, yy, vvv, v_names)
46 @ivar OnWeight: L{WeakEvent}(weight_expr)
47 @ivar OnExpr: L{WeakEvent}(curve_name, curve_expr, params, param_vals, lo, hi, can_fit)
48 @ivar OnParam: L{WeakEvent}(index, name, value, lo, hi, can_fit)
49 @ivar OnMaxIter: L{WeakEvent}(max_iter)
50 @ivar OnToler: L{WeakEvent}(toler)
51 @ivar OnStrategy: L{WeakEvent}(strategy_script)
52 @ivar OnStrategyWarn: L{WeakEvent}(L{Strategy}) when curve has changed and script may reference invalid vars
53 @ivar OnOutput: L{WeakEvent}(str) duplication of stdout during strategy
54 @ivar OnFitter: L{WeakEvent}(name_of_fitter)
55 @ivar OnStats: L{WeakEvent}(correlation, is_pseudo, std_err_est[], ssr, r2, runs_prob)
56 @ivar OnStartFit: L{WeakEvent}()
57 @ivar OnIteration: L{WeakEvent}(param_vals, iteration)
58 @ivar OnStatus: L{WeakEvent}(str)
59 @ivar OnEndFit: L{WeakEvent}()
60 @ivar OnStartFit_Robot: L{WeakEvent}() called in robot thread
61 @ivar OnEndFit_Robot: L{WeakEvent}() called in robot thread
62 @ivar fitter: L{Simplex_LM_Fitter} instance; replace it if you want
63 """
65 self.__ref = Reffer()
66 qubx.task.Robot.__init__(self, label, self.__ref(lambda: qubx.task.Tasks.add_task(self)),
67 self.__ref(lambda: qubx.task.Tasks.remove_task(self)))
68 self.OnData = WeakEvent()
69 self.OnWeight = WeakEvent()
70 self.OnExpr = WeakEvent()
71 self.OnParam = WeakEvent()
72 self.OnMaxIter = WeakEvent()
73 self.OnToler = WeakEvent()
74 self.OnStrategy = WeakEvent()
75 self.OnOutput = WeakEvent()
76 self.OnFitter = WeakEvent()
77 self.OnStats = WeakEvent()
78 self.OnStartFit = WeakEvent()
79 self.OnIteration = WeakEvent()
80 self.OnStatus = WeakEvent()
81 self.OnEndFit = WeakEvent()
82 self.OnStartFit_Robot = WeakEvent()
83 self.OnEndFit_Robot = WeakEvent()
84 self.locals = Anon()
85 self.locals.lo = memoize(self.robot_locals_lo)
86 self.locals.hi = memoize(self.robot_locals_hi)
87 self.fitter = Simplex_LM_Fitter()
88 self.__xx = numpy.array([])
89 self.__yy = numpy.array([])
90 self.__ff = self.__yy
91 self.__vvv = []
92 self.__v_names = []
93 self.__weight = '1.0'
94 self.__weight_f = Curve(self.__weight, locals=self.locals.__dict__)
95 self.__weight_f.set_vars(self.__v_names)
96 self.__ww = numpy.array([])
97
98 self.__expr = 'a * x**k'
99 self.__curve = Curve(self.__expr, allow_params=True, allow_ode=True, locals=self.locals.__dict__)
100 self.__curve.set_vars(self.__v_names)
101 self.__param_vals = [0.0] * len(self.__curve.params)
102 self.__serial_stats = self.__serial_fit = 0
103 self.__strategy = Strategy(label, gobject.idle_add)
104 self.__strategy.OnChange += self.__ref(self.__onChangeStrategy)
105 self.__strategy.locals.fit = self.robot_strategy_fit
106 self.__strategy.locals.lo = self.locals.lo
107 self.__strategy.locals.hi = self.locals.hi
108 self.__strategy.locals.set_curve_expr = self.robot_set_expr
109 self.__strategy.locals.set_weight_expr = self.robot_set_weight
110 self.__strategy.locals.set_max_iter = self.robot_set_max_iter
111 self.__strategy.locals.set_toler = self.robot_set_toler
112 self.__strategy.locals.set_fitter = self.robot_set_fitter
113 self.__strategy.locals.data = Anon()
114 self.__strategy.locals.stats = Anon()
115 self.__strategy.init(self.__curve)
116 self.OnStrategyWarn = self.__strategy.OnWarn
117 self.OnException += self.__ref(self.__onException)
118 self.OnInterrupt += self.__ref(self.__onInterrupt)
119 curve = property(lambda self: self.__curve)
def set_data(self, xx, yy, vvv=None, v_names=None):
    """Makes available the x, y and variable series.  Doesn't auto-update-stats.

    The request is handed to self.do() with self.robot_set_data as the target.

    @param xx: numpy.array(dtype='float32') of x coords
    @param yy: numpy.array(dtype='float32') of y coords
    @param vvv: list of numpy.array(dtype='float32') per extra variable;
                None (the default) means no extra variables
    @param v_names: list of names corresponding to the vvv; can be used in expr and weight;
                    None (the default) means no names
    """
    # Build a fresh list for every call.  A literal [] default would be one
    # shared object handed to every call, so an in-place change made by any
    # receiver would leak into later calls.
    if vvv is None:
        vvv = []
    if v_names is None:
        v_names = []
    self.do(self.robot_set_data, xx, yy, vvv, v_names)
134 """Tells how to calculate each sample's weight; ssr = sum( (weight*dy)**2 ).
135 @param expr: python expression (str) in terms of x and v_names
136 """
137 self.do(self.robot_set_weight, expr)
def set_curve(self, curve_class, expr=None):
    """Requests a change of the curve function.

    Nothing is changed here directly; the request is forwarded through
    self.do() to self.robot_set_curve.

    @param curve_class: L{qubx.fit.BaseCurve} class, or any f(expr=default) -> subclass instance
    @param expr: forwarded unchanged after curve_class; defaults to None
    """
    request = (self.robot_set_curve, curve_class, expr)
    self.do(*request)
145 """Changes the curve function.
146 @param expr: either the right-hand side of "y=something+f(other)", or a system of ODEs; see L{acceptODEs}.
147 """
148 self.do(self.robot_set_expr, expr)
def set_param(self, nm, value=None, lo=None, hi=None, can_fit=None):
    """Requests a change to the value, bounds and/or can_fit flag of one curve param.

    The arguments are forwarded, in order, through self.do() to
    self.robot_set_param.

    @param nm: name of curve param
    @param value: new value, or None to leave unchanged
    @param lo: new lower bound, or UNSET_VALUE, or None to leave unchanged
    @param hi: new upper bound, or UNSET_VALUE, or None to leave unchanged
    @param can_fit: True if the fitter can change it
    """
    changes = (value, lo, hi, can_fit)
    self.do(self.robot_set_param, nm, *changes)
159 """Changes the max number of param improvements per fit()."""
160 self.do(self.robot_set_max_iter, x)
171 """Replaces the fitter with fitter_class(max_iter, toler) ."""
172 self.do(self.robot_set_fitter, fitter_class)
189 """Triggers OnStats. If multiple stats requests are pending, only the newest one is honored."""
190 self.__serial_stats += 1
191 self.do(self.robot_request_stats, self.__serial_stats)
def fit(self, grab_initial=True):
    """
    Improves param values; Triggers OnStartFit, OnIteration, ..., OnEndFit, OnParam, ..., OnStats.
    If multiple fit requests are pending, only the newest one is honored.

    Each call takes the next serial number and forwards it, together with
    grab_initial, through self.do() to self.robot_fit.

    @param grab_initial: forwarded unchanged to robot_fit; defaults to True
    """
    ticket = self.__serial_fit + 1
    self.__serial_fit = ticket
    self.do(self.robot_fit, ticket, grab_initial)
199
201 gobject.idle_add(self.OnData, self.__xx, self.__yy, self.__vvv, self.__v_names)
202 gobject.idle_add(self.OnWeight, self.__weight)
204 gobject.idle_add(self.OnExpr, self.__curve.name, self.__expr, self.__curve.params[:], self.__param_vals[:],
205 self.__curve.lo[:], self.__curve.hi[:], self.__curve.can_fit[:])
207 gobject.idle_add(self.OnParam, i, self.__curve.params[i], self.__param_vals[i],
208 self.__curve.lo[i], self.__curve.hi[i], self.__curve.can_fit[i])
210 gobject.idle_add(self.OnMaxIter, self.fitter.max_iter)
211 gobject.idle_add(self.OnToler, self.fitter.toler)
213 gobject.idle_add(self.OnStrategy, self.__strategy.script)
215 gobject.idle_add(self.OnFitter, self.fitter.name)
217 if serial < self.__serial_stats: return
218 correlation, is_pseudo, std_err_est, ssr, r2, runs_prob = self.robot_do_stats()
219 gobject.idle_add(self.OnStats, correlation, is_pseudo, std_err_est, ssr, r2, runs_prob)
221 correlation, is_pseudo, std_err_est, ssr, r2 = Correlation(self.__curve, self.__param_vals, self.__xx, self.__yy, self.__vvv, self.__ww)
222 ff = self.__curve.eval(self.__param_vals, self.__xx, self.__vvv)
223 runs_prob = RunsProb(self.__yy, ff)
224 ss = self.__strategy.locals.stats
225 ss.correlation = correlation
226 ss.is_pseudo = is_pseudo
227 ss.std_err_est = std_err_est
228 ss.ssr = ssr
229 ss.r2 = r2
230 ss.runs_prob = runs_prob
231 return correlation, is_pseudo, std_err_est, ssr, r2, runs_prob
233 try:
234 name_low = name.lower()
235 if name_low == 'x':
236 return numpy.min(self.__xx)
237 elif name_low == 'y':
238 return numpy.min(self.__yy)
239 else:
240 return numpy.min(self.__vvv[self.__v_names.index(name)])
241 except:
242 return 0.0
244 try:
245 name_low = name.lower()
246 if name_low == 'x':
247 return numpy.max(self.__xx)
248 elif name_low == 'y':
249 return numpy.max(self.__yy)
250 else:
251 return numpy.max(self.__vvv[self.__v_names.index(name)])
252 except:
253 return 0.0
255 self.__xx = numpy.array(xx, dtype='float32', copy=True)
256 self.__yy = numpy.array(yy, dtype='float32', copy=True)
257 self.__vvv = [numpy.array(vv, dtype='float32', copy=True) for vv in vvv]
258 self.__v_names = v_names
259 old_params = self.robot_save_params()
260 self.__curve.set_vars(v_names)
261 self.robot_re_param(old_params)
262 self.__weight_f.set_vars(v_names)
263 if len(xx):
264 try:
265 self.__ww = self.__weight_f.eval([], xx, vvv)
266 except:
267 traceback.print_exc()
268 print v_names
269 print
270 self.__ww = numpy.array([1.0]*len(xx), dtype='float32')
271 else:
272 self.__ww = numpy.array([], dtype='float32')
273 self.locals.lo.reset()
274 self.locals.hi.reset()
275 sd = self.__strategy.locals.data
276 sd.X = xx
277 sd.Y = yy
278 sd.Series = vvv
279 sd.SeriesName = v_names
280 self.robot_request_data()
304 return dict([(self.__curve.params[i], (self.__param_vals[i], self.__curve.lo[i], self.__curve.hi[i], self.__curve.can_fit[i]))
305 for i in xrange(len(self.__param_vals))])
307 self.__param_vals = []
308 for i,p in enumerate(self.__curve.params):
309 if p in old_params:
310 val,lo,hi,can_fit = old_params[p]
311 self.__param_vals.append(val)
312 self.__curve.lo[i] = lo
313 self.__curve.hi[i] = hi
314 self.__curve.can_fit[i] = can_fit
315 elif self.__curve.param_defaults:
316 self.__param_vals.append(self.__curve.param_defaults[i])
317 else:
318 self.__param_vals.append(1.0)
319 self.robot_request_curve()
320 if self.__strategy.renit(self.__curve):
321 self.robot_request_strategy()
322 self.__strategy.locals.param_vals = self.__param_vals
355 """Fits by running strategy script; "fit" in script refers to robot_strategy_fit, below."""
356 if serial < self.__serial_fit: return
357 try:
358 self.OnStartFit_Robot()
359 gobject.idle_add(self.OnStartFit)
360 if grab_initial:
361 self.__init_vals = self.__param_vals[:]
362 qubx.pyenv.env.OnOutput += self.__ref(self.__onOutput)
363 self.__strategy.run()
364 qubx.pyenv.env.OnOutput -= self.__ref(self.__onOutput)
365 finally:
366 gobject.idle_add(self.OnEndFit)
367 self.OnEndFit_Robot()
368 for i in xrange(len(self.__param_vals)):
369 if self.__curve.can_fit[i]:
370 self.robot_request_param(i)
371 self.request_stats()
373 """Implementation of strategy.locals.fit"""
374 save_vars = self.__curve.can_fit[:]
375 try:
376 if not (variables is None):
377 self.__curve.can_fit[:] = [(pname in variables) for pname in self.__curve.params]
378 for param, val in params.iteritems():
379 ix = self.__curve.params.index(param)
380 v = val
381 if v == self.__strategy.locals.Initial:
382 v = self.__init_vals[ix]
383 if not (v is None):
384 self.__param_vals[ix] = v
385 self.__param_vals, ssr, iterations, ff = self.fitter(self.__curve, self.__param_vals, self.__xx, self.__yy, self.__vvv, self.__ww,
386 self.robot_on_iter, self.robot_on_status)
387 self.robot_do_stats()
388 finally:
389 self.__curve.can_fit[:] = save_vars
391 try:
392 gobject.idle_add(self.OnIteration, param_vals, iter)
393 return 1
394 except:
395 return 0
397 gobject.idle_add(self.OnStatus, msg)
399 traceback.print_exception(typ, val, trace)
401 self.fitter.stop(cancel)
406
407
409 """Strategy
410 module_name: for pyenv
411 run_later: facility for deferred execution such as gobject.idle_add
412
413 Manages and executes the fitting script. locals include Initial='Initial', Last=None. Required but not provided:
414 fit(variables=[par_names] or Last, {par_name}=value or Initial or Last, ...other par_names...)
415
416 Properties:
417 locals L{Anon} of names available in script; you add "fit"
418 script text of the script; empty for default
419 initial text of default script for this curve
420
421 Events:
422 OnChange(Strategy, script) script has changed
423 OnWarn(Strategy) Curve param(s) removed; possibly invalid script
424 """
425 - def __init__(self, module_name, run_later=lambda f:f()):
435 locals = property(lambda self: self.__locals)
436 script = property(lambda self: self.__script, lambda self, x: self.set_script(x))
437 initial = property(lambda self: self.__initial)
439 self.__script = (x.strip() and (x.strip()+'\n')) or self.__initial
440 self.OnChange(self, x)
def init(self, curve, rewrite=True):
    """Builds initial script from Curve; if rewrite: current = initial.

    The initial script is one line of the form
    "fit(variables=[names with can_fit], name=Initial or Last, ...)":
    params flagged can_fit start from Initial, all others from Last.

    @param curve: object with parallel sequences params (names) and can_fit (flags)
    @param rewrite: if true, the new initial script is assigned to self.script
    """
    self.names = curve.params[:]
    fit_names = []
    starts = []
    # One pass over the params; can_fit is read by index, as before.
    for i, name in enumerate(self.names):
        fitted = bool(curve.can_fit[i])
        if fitted:
            fit_names.append(name)
        starts.append("%s=%s" % (name, "Initial" if fitted else "Last"))
    args = ["variables=%s" % (fit_names,)] + starts
    self.__initial = "fit(%s)\n" % ', '.join(args)
    if rewrite:
        self.script = self.__initial
def renit(self, curve, attn=True):
    """Builds initial script, rewriting if not modified, OnWarn if pars dropped.

    @param curve: curve whose params list is compared against self.names and passed to init()
    @param attn: if false, OnWarn is never scheduled
    """
    # NOTE(review): nothing is returned (None), yet a caller in this file tests
    # the result of renit() in an `if` -- confirm whether a value was intended.
    current = curve.params[:]
    # Must be decided before init(), which is given the new curve.
    dropped = any(par not in current for par in self.names)
    was_modified = self.__script != self.__initial
    self.init(curve, not was_modified)
    # Compare again: init() has rebuilt the initial script.
    still_modified = self.__script != self.__initial
    if not still_modified:
        return
    if attn and dropped:
        self.run_later(self.OnWarn, self)
460 - def run(self, expr=None):
468 """Appends a copy of the last line or initial script to the current script."""
469 match = re.search(r"(^fit.*\) *(\n?)$)", self.script)
470 strat = self.script
471 if match:
472 if not len(match.group(2)):
473 strat = strat + '\n'
474 strat = strat + match.group(1)
475 else:
476 strat = strat + self.__initial[:-1]
477 self.script = strat
478
479
481 import time
482 import gtk
483 gobject.threads_init()
484
def on_data(xx, yy, vvv, v_names):
    """Demo OnData listener: reports that the data changed; the arguments are not used."""
    # Single-argument print(...) gives the same output on Python 2 and also runs on Python 3.
    print('data changed')
def on_weight(weight):
    """Demo OnWeight listener: prints the new weight expression."""
    # Single-argument print(...) gives the same output on Python 2 and also runs on Python 3.
    print('new weight: %s' % weight)
def on_expr(*args):
    """Demo OnExpr listener: prints the new expression, its param names and values.

    OnExpr is fired with seven arguments
    (curve_name, curve_expr, params, param_vals, lo, hi, can_fit).  The older
    six-argument form without the leading curve name is accepted as well.
    """
    if len(args) not in (6, 7):
        raise TypeError('on_expr() takes 6 or 7 arguments (%d given)' % len(args))
    # Counting from the end skips the optional leading curve name.
    expr, params, vals = args[-6:-3]
    # Single-argument print(...) gives the same output on Python 2 and also runs on Python 3.
    print('new expr: %s' % expr)
    print('params: %s' % (params,))
    print('vals: %s' % (vals,))
def on_param(i, nm, val, lo, hi, can_fit):
    """Demo OnParam listener: prints one param's name, value, bounds and whether it is fitted.

    The index i is not used.
    """
    suffix = ' FITTED' if can_fit else ''
    # Single-argument print(...) gives the same output on Python 2 and also runs on Python 3.
    print('param %s = %f\t[%f..%f]%s' % (nm, val, lo, hi, suffix))
def on_max_iter(max_iter):
    """Demo OnMaxIter listener: prints the new iteration limit."""
    # Single-argument print(...) gives the same output on Python 2 and also runs on Python 3.
    print('new max iter: %s' % (max_iter,))
def on_toler(toler):
    """Demo OnToler listener: prints the new tolerance."""
    # Single-argument print(...) gives the same output on Python 2 and also runs on Python 3.
    print('new toler: %s' % (toler,))
def on_stats(correlation, is_pseudo, std_err_est, ssr, r2, runs_prob):
    """Demo OnStats listener: prints every statistic, followed by a blank line.

    The correlation heading is prefixed with "pseudo-" when is_pseudo is true.
    """
    prefix = 'pseudo-' if is_pseudo else ''
    # Single-argument print(...) gives the same output on Python 2 and also runs on Python 3.
    print('SSR: %s' % (ssr,))
    print('%sCorrelation:' % prefix)
    print(correlation)
    print('Errors: %s' % (std_err_est,))
    print('r**2: %s' % (r2,))
    print('P(runs): %s' % (runs_prob,))
    print('')
def on_start_fit():
    """Demo OnStartFit listener: reports that a fit is starting."""
    # Single-argument print(...) gives the same output on Python 2 and also runs on Python 3.
    print('starting to fit...')
def on_iteration(pvals, iter):
    """Demo OnIteration listener: prints the iteration number, a tab, then the param values."""
    # The original used two print statements, the first ending in a trailing
    # comma.  Python 2 inserts no separating space after text that ends in a
    # tab, so one formatted line reproduces that output and also runs on
    # Python 3.
    print('%s\t%s' % (iter, pvals))
def on_status(msg):
    """Demo OnStatus listener: prints the status message after a "]]" marker."""
    # Single-argument print(...) gives the same output on Python 2 and also runs on Python 3.
    print(']] %s' % (msg,))
def on_quit(*args):
    """Leaves the gtk main loop; any arguments are accepted and ignored."""
    gtk.main_quit()
def on_end_fit():
    """Demo OnEndFit listener: reports that the fit has finished."""
    # Single-argument print(...) gives the same output on Python 2 and also runs on Python 3.
    print('done fit.')
def on_exception(task, typ, val, tb):
    """Demo OnException listener: prints the traceback of the reported exception.

    The task argument is not used.
    """
    exc_info = (typ, val, tb)
    traceback.print_exception(*exc_info)
520
521 robot = FitRobot()
522 robot.OnException += on_exception
523 robot.OnData += on_data
524 robot.OnWeight += on_weight
525 robot.OnExpr += on_expr
526 robot.OnParam += on_param
527 robot.OnMaxIter += on_max_iter
528 robot.OnToler += on_toler
529 robot.OnStats += on_stats
530 robot.OnStartFit += on_start_fit
531 robot.OnIteration += on_iteration
532 robot.OnStatus += on_status
533 robot.OnEndFit += on_end_fit
534
535 xx = numpy.arange(11, dtype='float32') / 10.0
536 yy = numpy.array(xx**3, dtype='float32')
537 yy += .1 * numpy.random.randn(len(xx))
538 robot.set_data(xx, yy)
539 robot.set_weight('1.0')
540 robot.set_expr('a * x**k + b')
541 robot.set_param('a', 0.5)
542 robot.set_param('k', 0.5)
543 robot.set_param('b', -0.2)
544 robot.set_max_iter(300)
545 robot.set_toler(1e-12)
546
547 robot.fit()
548
549 robot.request_data()
550 robot.request_curve()
551 robot.request_fitparams()
552
553 yy = numpy.zeros(shape=xx.shape, dtype='float32')
554 yy[:5] = xx[:5] * 2 + 5
555 yy[5:] = xx[5:] * 3 + 5
556 yy += .1 * numpy.random.randn(len(xx))
557 mm = numpy.zeros(shape=xx.shape, dtype='float32')
558 mm[:5] = 2
559 mm[5:] = 3
560 robot.set_expr('m * x + b')
561 robot.set_data(xx, yy, [mm], ['m'])
562 robot.fit()
563 robot.request_curve()
564
565 robot.do(gobject.idle_add, on_quit)
566
567 gtk.main()
568 robot.stop()
569
570
571 if __name__ == '__main__':
572 main()
573