Package qubx :: Package fast :: Module fit
[hide private]
[frames | no frames]

Source Code for Module qubx.fast.fit

  1  """Compiled routines for curve fitting. 
  2   
  3  Copyright 2008-2013 Research Foundation State University of New York  
  4  This file is part of QUB Express.                                           
  5   
  6  QUB Express is free software; you can redistribute it and/or modify           
  7  it under the terms of the GNU General Public License as published by  
  8  the Free Software Foundation, either version 3 of the License, or     
  9  (at your option) any later version.                                   
 10   
 11  QUB Express is distributed in the hope that it will be useful,                
 12  but WITHOUT ANY WARRANTY; without even the implied warranty of        
 13  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the         
 14  GNU General Public License for more details.                          
 15   
 16  You should have received a copy of the GNU General Public License,    
 17  named LICENSE.txt, in the QUB Express program directory.  If not, see         
 18  <http://www.gnu.org/licenses/>.                                       
 19   
 20  """ 
 21   
 22  from qubx.fast.fast_utils import * 
 23   
 24   
 25  CurveFunc = CFUNCTYPE(None, c_int_p, c_double_p, c_float_p) 
 26  IterFunc = CFUNCTYPE(c_int, c_int_p, c_double_p, c_int) 
 27   
 28  qubfast.qub_lmmin_fit.argtypes = (CurveFunc, c_int, c_double_p, c_double_p, c_double_p, c_int_p, 
 29                                    c_int, c_float_p, c_float_p, IterFunc, ReportFunc, c_int_p, 
 30                                    c_int, c_double, c_float_p, c_double_p) 
 31  qubfast.qub_lmmin_fit.restype = c_int 
 32   
def lmmin_fit(curve, param_vals, xx, yy, ww, vvv, on_iter, on_status, max_iter, toler):
    """Fit a curve to data by driving the compiled qub_lmmin_fit routine.

    @param curve: object providing eval(params, xx, vvv) plus the sequences lo, hi and can_fit
    @param param_vals: starting parameter values
    @param xx: numpy array of x values; its shape sizes the fit-curve buffer and its length is the data count
    @param yy: numpy array (or None) whose buffer is handed over as a float pointer -- assumed float32; TODO confirm
    @param ww: numpy array (or None) whose buffer is handed over as a float pointer -- assumed float32; TODO confirm
    @param vvv: passed through unchanged to curve.eval
    @param on_iter: on_iter(param_array, iteration); its result is returned to the compiled code
    @param on_status: on_status(msg), called with each message the compiled code reports
    @param max_iter: forwarded to qub_lmmin_fit
    @param toler: forwarded to qub_lmmin_fit
    @return: (final param_vals, sum_sqr_residual, iterations, fit_curve_samples)
    """
    n_params = len(param_vals)
    fit_samples = numpy.zeros(shape=xx.shape, dtype='float32')
    # One-element list so the nested callbacks can flip the flag in place.
    interrupted = [False]

    def as_pointer(arr, ptr_type):
        # None is forwarded untouched (a NULL pointer for ctypes).
        return None if arr is None else arr.ctypes.data_as(ptr_type)

    def current_params(ptr):
        # Copy the first n_params doubles behind the C pointer into an array.
        return numpy.array(ptr[:n_params])

    def evaluate_curve(obj, ptr, ffptr):
        # Writes into fit_samples, the buffer whose address is passed to the fitter below.
        try:
            fit_samples[:] = curve.eval(current_params(ptr), xx, vvv)
        except KeyboardInterrupt:
            interrupted[0] = True

    def iterate(obj, ptr, iteration):
        if interrupted[0]:
            return False
        return on_iter(current_params(ptr), iteration)

    def report(msg, obj):
        on_status(msg)
        return 0

    params = numpy.array(param_vals, dtype='float64')
    lower = numpy.array(curve.lo, dtype='float64')
    upper = numpy.array(curve.hi, dtype='float64')
    fit_mask = numpy.array(curve.can_fit, dtype='int32')
    n_data = len(xx)
    sum_sqr = c_double(0.0)

    iterations = qubfast.qub_lmmin_fit(
        CurveFunc(evaluate_curve),
        n_params,
        as_pointer(params, c_double_p),
        as_pointer(lower, c_double_p),
        as_pointer(upper, c_double_p),
        as_pointer(fit_mask, c_int_p),
        n_data,
        as_pointer(yy, c_float_p),
        as_pointer(ww, c_float_p),
        IterFunc(iterate),
        ReportFunc(report),
        None,
        max_iter,
        toler,
        as_pointer(fit_samples, c_float_p),
        byref(sum_sqr),
    )
    return list(params), sum_sqr.value, iterations, fit_samples


# Callback prototypes for the compiled DFP minimizer.
# DFP_Func wraps dfpmin's dfp_func(obj, pars, grads, p_ll, p_nf, p_ndf, do_grads).
DFP_Func = CFUNCTYPE(c_int, c_void_p, c_double_p, c_double_p, c_double_p, c_int_p, c_int_p, c_int)
# DFP_Check wraps dfpmin's dfp_check(obj, iters, nf, ndf, ll, pars, grads).
DFP_Check = CFUNCTYPE(c_int, c_void_p, c_int, c_int, c_int, c_double, c_double_p, c_double_p)

qubfast.qub_dfpmin.argtypes = (c_void_p, c_int, c_double_p, c_double_p, DFP_Func, DFP_Check,
                               c_int, c_double, c_double, c_double, c_int_p)
qubfast.qub_dfpmin.restype = c_double


def dfpmin(pars, func, on_iter, max_iter=100, conv=1e-5, conv_grad=1e-5, step_max=1.0):
    """Minimize a function with the compiled qub_dfpmin routine.

    The parameters are optimized in a private float64 copy of pars; that copy
    is not returned, so callers observe the parameters only through the
    pointers passed to func and on_iter.

    @param pars: starting parameter values; a sequence, or a numpy array/matrix
                 whose largest dimension is taken as the parameter count
    @param func: func(pars, grads, do_grads) -> function value; pars and grads
                 are the pointers supplied by the compiled code
    @param on_iter: on_iter(iters, nf, ndf, ll, pars, grads) -> int, handed back
                    to the compiled code as the result of its check callback
    @param max_iter: forwarded to qub_dfpmin
    @param conv: forwarded to qub_dfpmin
    @param conv_grad: forwarded to qub_dfpmin
    @param step_max: forwarded to qub_dfpmin
    @return: (value returned by qub_dfpmin, error code it stored,
             gradients seen at the most recent check callback,
             the nx-by-nx work matrix H)
    """
    try:
        nx = max(pars.shape)
    except Exception:
        # No usable .shape (e.g. a plain list): fall back to the length.
        nx = len(pars)
    z = numpy.array(pars, dtype='float64')
    H = numpy.zeros(shape=(nx, nx), dtype='float64')
    out_grads = numpy.zeros(shape=(nx,), dtype='float64')
    err_out = c_int(0)
    # One-element list so the nested callbacks can flip the flag in place.
    interrupted = [False]

    def dfp_func(obj, pars, grads, p_ll, p_nf, p_ndf, do_grads):
        # Exceptions must not escape into the compiled caller; report failure
        # through the return code instead.
        try:
            p_ll[0] = func(pars, grads, do_grads)
            return 0
        except KeyboardInterrupt:
            # Remember the interrupt; dfp_check turns it into a -23 status.
            interrupted[0] = True
        except:
            traceback.print_exc()
        return -1

    def dfp_check(obj, iters, nf, ndf, ll, pars, grads):
        try:
            if interrupted[0]:
                raise KeyboardInterrupt
            # Keep a copy of the latest gradients for the return value.
            out_grads[:] = grads[:nx]
            return on_iter(iters, nf, ndf, ll, pars, grads)
        except KeyboardInterrupt:
            return -23
        except:
            traceback.print_exc()
            return -1

    # Pass nx, not len(pars): H, out_grads and the grads[:nx] copy are all
    # sized by nx, and len(pars) is 1 for a 1 x n row matrix.
    minval = qubfast.qub_dfpmin(c_void_p(), nx, z.ctypes.data_as(c_double_p),
                                H.ctypes.data_as(c_double_p), DFP_Func(dfp_func), DFP_Check(dfp_check),
                                max_iter, conv, conv_grad, step_max, byref(err_out))
    return minval, err_out.value, out_grads, H