1 """Compiled routines for curve fitting.
2
3 Copyright 2008-2013 Research Foundation State University of New York
4 This file is part of QUB Express.
5
6 QUB Express is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation, either version 3 of the License, or
9 (at your option) any later version.
10
11 QUB Express is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License,
17 named LICENSE.txt, in the QUB Express program directory. If not, see
18 <http://www.gnu.org/licenses/>.
19
20 """
21
22 from qubx.fast.fast_utils import *
23
24
# ctypes prototypes for the Python callbacks handed to qub_lmmin_fit.
# Curve evaluation callback: f(obj, params, ffptr) -> void
CurveFunc = CFUNCTYPE(
    None,        # no return value
    c_int_p,     # obj
    c_double_p,  # params: current parameter values
    c_float_p,   # ffptr: float32 curve samples
)
# Per-iteration callback: f(obj, params, iter) -> int
IterFunc = CFUNCTYPE(
    c_int,       # value returned by the Python hook (presumably nonzero = keep going)
    c_int_p,     # obj
    c_double_p,  # params: current parameter values
    c_int,       # iter: iteration number
)

# Argument order mirrors the qub_lmmin_fit call in lmmin_fit.
qubfast.qub_lmmin_fit.argtypes = (
    CurveFunc,   # curve evaluation callback
    c_int,       # Nparam
    c_double_p,  # params (read back by lmmin_fit after the call)
    c_double_p,  # lo
    c_double_p,  # hi
    c_int_p,     # can_fit
    c_int,       # Ndata
    c_float_p,   # yy
    c_float_p,   # ww (lmmin_fit passes None when there are no weights)
    IterFunc,    # per-iteration callback
    ReportFunc,  # status-message callback
    c_int_p,     # lmmin_fit always passes None here
    c_int,       # max_iter
    c_double,    # toler
    c_float_p,   # ff: returned to the caller as the fit-curve samples
    c_double_p,  # out: sum of squared residuals
)
# The int result is returned to lmmin_fit's caller as the iteration count.
qubfast.qub_lmmin_fit.restype = c_int
32
def lmmin_fit(curve, param_vals, xx, yy, ww, vvv, on_iter, on_status, max_iter, toler):
    """Fit a curve model to sampled data with the compiled qub_lmmin_fit routine.

    The routine name suggests a Levenberg-Marquardt ("lmmin") minimizer --
    TODO confirm against the C implementation.

    @param curve: model object. This function calls curve.eval(params, xx, vvv)
        and copies its result into the fit-curve buffer. It also reads the
        per-parameter sequences curve.lo, curve.hi and curve.can_fit.
    @param param_vals: starting parameter values (length Nparam).
    @param xx: numpy array of x values. Its shape sizes the fit-curve buffer,
        and len(xx) is passed to C as the data count.
    @param yy: data samples. Converted to a contiguous float32 array before
        its pointer is handed to C.
    @param ww: weights (presumably one per sample), converted like yy, or None.
    @param vvv: passed through unchanged to curve.eval.
    @param on_iter: on_iter(params_array, iteration) is called from the C
        iteration callback. Its return value is handed back to C (presumably
        truthy means continue).
    @param on_status: on_status(msg) receives messages from the C report callback.
    @param max_iter: passed to C (iteration limit).
    @param toler: passed to C (tolerance).
    @return: (final param_vals, sum_sqr_residual, iterations, fit_curve_samples)
    """
    Nparam = len(param_vals)
    Ndata = len(xx)
    ff = numpy.zeros(shape=xx.shape, dtype='float32')
    stop = [False]  # set when Ctrl-C arrives inside curve evaluation

    def as_c_floats(arr):
        # qub_lmmin_fit declares these arguments as float*. Without this, a
        # float64 or strided array would be reinterpreted byte-for-byte.
        # ascontiguousarray does not copy arrays that already qualify.
        if arr is None:
            return None
        return numpy.ascontiguousarray(arr, dtype='float32')

    def cdata(x, typ):
        if x is None:
            return x
        return x.ctypes.data_as(typ)

    def param_array(ptr):
        # Copy the first Nparam doubles behind the C pointer into a new ndarray.
        return numpy.array(ptr[:Nparam])

    def curve_eval(obj, params, ffptr):
        try:
            ff[:] = curve.eval(param_array(params), xx, vvv)
        except KeyboardInterrupt:
            # Exceptions cannot propagate through the C frame. Remember the
            # interrupt so the next iteration callback stops the fit.
            stop[0] = True

    def do_iter(obj, params, iteration):
        if stop[0]:
            return False
        return on_iter(param_array(params), iteration)

    def do_report(msg, obj):
        on_status(msg)
        return 0

    params = numpy.array(param_vals, dtype='float64')
    lo = numpy.array(curve.lo, dtype='float64')
    hi = numpy.array(curve.hi, dtype='float64')
    can_fit = numpy.array(curve.can_fit, dtype='int32')
    # Bound to locals so the (possibly converted) buffers outlive the C call.
    yy_c = as_c_floats(yy)
    ww_c = as_c_floats(ww)
    ssr = c_double(0.0)
    iterations = qubfast.qub_lmmin_fit(
        CurveFunc(curve_eval), Nparam, cdata(params, c_double_p),
        cdata(lo, c_double_p), cdata(hi, c_double_p), cdata(can_fit, c_int_p),
        Ndata, cdata(yy_c, c_float_p), cdata(ww_c, c_float_p),
        IterFunc(do_iter), ReportFunc(do_report), None,
        max_iter, toler, cdata(ff, c_float_p), byref(ssr))
    # params was passed by pointer; return whatever values C left in it.
    return list(params), ssr.value, iterations, ff
67
68
# ctypes prototypes for the Python callbacks handed to qub_dfpmin.
# Function evaluation: f(obj, pars, grads, p_ll, p_nf, p_ndf, do_grads) -> int
DFP_Func = CFUNCTYPE(
    c_int,       # status (dfpmin's wrapper returns 0 on success, -1 on error)
    c_void_p,    # obj
    c_double_p,  # pars
    c_double_p,  # grads
    c_double_p,  # p_ll: the function value is stored in p_ll[0]
    c_int_p,     # p_nf
    c_int_p,     # p_ndf
    c_int,       # do_grads
)
# Progress check: f(obj, iters, nf, ndf, ll, pars, grads) -> int
DFP_Check = CFUNCTYPE(
    c_int,       # status (dfpmin's wrapper returns -23 on interrupt, -1 on error)
    c_void_p,    # obj
    c_int,       # iters
    c_int,       # nf
    c_int,       # ndf
    c_double,    # ll
    c_double_p,  # pars
    c_double_p,  # grads
)

# Argument order mirrors the qub_dfpmin call in dfpmin.
qubfast.qub_dfpmin.argtypes = (
    c_void_p,    # obj (dfpmin passes a NULL c_void_p)
    c_int,       # number of parameters
    c_double_p,  # parameter vector (float64 copy made by dfpmin)
    c_double_p,  # H: n-by-n float64 matrix returned to the caller
    DFP_Func,    # function evaluation callback
    DFP_Check,   # progress check callback
    c_int,       # max_iter
    c_double,    # conv
    c_double,    # conv_grad
    c_double,    # step_max
    c_int_p,     # out: error code
)
# The double result is returned to dfpmin's caller as minval.
qubfast.qub_dfpmin.restype = c_double
75
76
def dfpmin(pars, func, on_iter, max_iter=100, conv=1e-5, conv_grad=1e-5, step_max=1.0):
    """Minimize a function with the compiled qub_dfpmin routine.

    The routine name suggests the Davidon-Fletcher-Powell quasi-Newton method
    -- TODO confirm against the C implementation.

    @param pars: starting parameters (array-like). They are copied into a
        C-ordered float64 array, and its total size nx is the problem
        dimension given to C.
    @param func: func(pars, grads, do_grads) -> value. It is called from C
        with ctypes double pointers, and its return value is stored into C's
        p_ll[0]. A KeyboardInterrupt is recorded and reported to C as -1; any
        other exception is printed and reported as -1.
    @param on_iter: on_iter(iters, nf, ndf, ll, pars, grads) is called from the
        C check callback, and its return value is handed back to C. A pending
        or raised KeyboardInterrupt returns -23 instead; other exceptions are
        printed and return -1.
    @param max_iter: passed to C.
    @param conv: passed to C.
    @param conv_grad: passed to C.
    @param step_max: passed to C.
    @return: (minval, err, grads, H), where:
        - minval is the C return value;
        - err is the code C wrote to err_out;
        - grads holds the first nx gradient entries seen by the most recent
          check callback;
        - H is the nx-by-nx float64 array passed to C (presumably a Hessian or
          inverse-Hessian estimate -- confirm).

    NOTE(review): the parameter copy handed to C is not returned. Callers only
    see parameter values through func/on_iter.
    """
    import traceback

    # Copy (never alias the caller's array) in C order, so the flat pointer
    # given to C matches nx.
    z = numpy.array(pars, dtype='float64', order='C')
    nx = z.size
    H = numpy.zeros(shape=(nx, nx), dtype='float64')
    out_grads = numpy.zeros(shape=(nx,), dtype='float64')
    err_out = c_int(0)
    interrupted = [False]

    def dfp_func(obj, pars, grads, p_ll, p_nf, p_ndf, do_grads):
        try:
            p_ll[0] = func(pars, grads, do_grads)
            return 0
        except KeyboardInterrupt:
            interrupted[0] = True
        except BaseException:
            # Nothing may escape a ctypes callback; report and signal failure.
            traceback.print_exc()
        return -1

    def dfp_check(obj, iters, nf, ndf, ll, pars, grads):
        if interrupted[0]:
            return -23
        try:
            out_grads[:] = grads[:nx]
            return on_iter(iters, nf, ndf, ll, pars, grads)
        except KeyboardInterrupt:
            return -23
        except BaseException:
            traceback.print_exc()
            return -1

    # Pass nx (not len(pars)) so C's dimension agrees with H, out_grads and grads[:nx].
    minval = qubfast.qub_dfpmin(
        c_void_p(), nx, z.ctypes.data_as(c_double_p),
        H.ctypes.data_as(c_double_p), DFP_Func(dfp_func), DFP_Check(dfp_check),
        max_iter, conv, conv_grad, step_max, byref(err_out))
    return minval, err_out.value, out_grads, H
113