#/* Copyright 2008-2014 Research Foundation State University of New York   */

#/* This file is part of QUB Express.                                      */

#/* QUB Express is free software; you can redistribute it and/or modify    */
#/* it under the terms of the GNU General Public License as published by   */
#/* the Free Software Foundation, either version 3 of the License, or      */
#/* (at your option) any later version.                                    */

#/* QUB Express is distributed in the hope that it will be useful,         */
#/* but WITHOUT ANY WARRANTY; without even the implied warranty of         */
#/* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the          */
#/* GNU General Public License for more details.                           */

#/* You should have received a copy of the GNU General Public License,     */
#/* named LICENSE.txt, in the QUB Express program directory.  If not, see  */
#/* <http://www.gnu.org/licenses/>.                                        */

import collections
import itertools
from numpy import *
import scipy
import scipy.linalg
import scipy.optimize
import sys
import traceback
import time
from import *
from import EigenFunc, Eigen_Scipy_Ptr
from qubx.model_constraints import *

# ============= The Data =========================

# The data consist of one or more parallel signals.  The first (index 0) is the Markovian one to be analyzed.
# Each additional signal describes a ligand or voltage variable.
# A signal is a list of segments.  A segment consists of "events" ("dwells").
# An event has duration in milliseconds, and class -- index of its measured amplitude.
# For each segment there's a list of class amplitude values.
# For each signal you provide a list of segments: [(classes[], durations[], amp_of_cls[]), ...]

# If the model depends on a variable which is held constant, provide a constant signal.
# This function copies the segmentation of a template signal:

def ConstantSignal(template, value):
    """Return a signal holding *value* for the whole of each segment of *template*.

    template: [(classes, durations, amp_of_cls), ...], one tuple per segment.
    Each output segment is a single event of class 0, lasting the segment's total
    duration, with amp_of_cls == [value].
    """
    constant_segments = []
    for _classes, seg_durations, _amps in template:
        constant_segments.append(([0], [sum(seg_durations)], [value]))
    return constant_segments

# We will multiplex the signals into a single idealized stream, whose classes are multi-classes (mcls):
#    mcls[i] = (model_class, stimulus_class)
# The stimulus classes (scls) are tuples with the value of each signal (signal 0 is ignored).

def MultiplexSignals(signals):
    """MultiplexSignals([(classes_0_seg0, durations_0_seg0, amp_of_cls_0_seg0), ...seg1...], ..._1_...)

Returns segments, multi_classes, stimulus_classes
where   segments         = [(mclasses_seg0, durations_seg0), (mclasses_seg1, durations_seg1), ...]
        multi_classes    = [(model_class_of_0, stimulus_class_of_0), (model_class_of_1, stimulus_class_of_1), ...],
        stimulus_classes = [(signal_0_scls_0, signal_1_scls_0, ...), (signal_0_scls_1, signal_1_scls_1, ...), ...]

Each output event lasts until the next event boundary on any signal.  Signals
whose segment count differs from signal 0's are ignored.  The signal-0 entry of
every stimulus class tuple is 0.  A class < 0 is a gap (excluded region); a gap
in a stimulus signal keeps the previous event's amplitude, and a stimulus signal
may not start a segment with a gap (raises Exception).  A segment that has no
events on some signal yields ([], []).
"""
    # assuming each signal has the same segmentation, collect all the signals for each segment together
    stacked_segs = [[seg] for seg in signals[0]]
    nseg = len(stacked_segs)
    raw_segs = [segs for segs in signals if len(segs) == nseg]
    for segs in raw_segs[1:]:
        for j, seg in enumerate(segs):
            stacked_segs[j].append(seg)

    # We use the same multi- and stimulus-classes for all segments, numbered in
    # the order of appearance.  These dicts hand out unique indices from 0 on first lookup.
    # At the end we'll reverse them into lists, and the unique index will be the location.
    # (next(counter) works on Python 2.6+ and 3, unlike counter.next)
    # sclass keys are tuples of float (amp_of_signal_i)
    sclass_counter = itertools.count()
    sclass = collections.defaultdict(lambda: next(sclass_counter))
    # mclass keys are pairs of integers (model_class, sclass)
    mclass_counter = itertools.count()
    mclass = collections.defaultdict(lambda: next(mclass_counter))

    def MplexEvents(stack):  # Called upon each stacked segment
        """MplexEvents([(classes_sig_0, durations_sig_0, amps_sig_0), ...])
        Finds multiplex and stimulus classes (populates sclass and mclass).
        Returns parallel lists (multiclasses, durations)."""
        for i, cda in enumerate(stack):
            # len() guard: an empty stimulus segment is handled by the IndexError path below
            if (i > 0) and len(cda[0]) and (cda[0][0] < 0):
                raise Exception("Can't process events: one or more segments starts with a gap (excluded region).")
        try:
            N = [len(x[0]) for x in stack]                                  #  number of events on signal i
            finger = [0 for x in stack]                                     #  index of current event on signal i
            clss = [classes[0] for classes, durations, amp in stack]        #  class of current event on signal i
            amps = [amp[classes[0]] if (classes[0] >= 0) else float(classes[0]) for classes, durations, amp in stack]   #  amp of current event on signal i
            remain = [durations[0] for classes, durations, amp in stack]    #  time remaining in event on signal i
        except IndexError:
            # some signal has no events in this segment
            return [], []
        cl = []                            # outputs
        du = []
        while True:
            # stop as soon as any signal runs out of events
            exhausted = False
            for i, n in enumerate(N):
                if finger[i] >= n:
                    exhausted = True
            if exhausted:
                break
            # find the signal with shortest remaining
            changing = argmin(remain)
            tm = remain[changing]
            # look up the current stimulus- and multi-classes
            sc = sclass[tuple([0] + amps[1:])]
            mc = mclass[(clss[0], sc)]
            # append the event
            cl.append(mc)
            du.append(tm)
            # shorten all signals' remain time, and move the finger for any that are 0
            for i, t in enumerate(remain):
                remain[i] = t - tm
                if remain[i] <= 0.0:
                    finger[i] += 1
                    if finger[i] < N[i]:
                        remain[i] = stack[i][1][finger[i]]
                        clss[i] = stack[i][0][finger[i]]
                        if clss[i] >= 0:
                            amps[i] = stack[i][2][clss[i]]
                        # else: gap: hold over last event's amp
        return cl, du

    mplex_segs = [MplexEvents(stack) for stack in stacked_segs]

    # convert the class dictionaries to lists, indexed by their unique index:
    amps_of_sc = [None] * len(sclass)
    for amps, sc in sclass.items():
        amps_of_sc[sc] = amps
    mclass_lst = [None] * len(mclass)
    for tup, mc in mclass.items():
        mclass_lst[mc] = tup
    return mplex_segs, mclass_lst, amps_of_sc # [(classes, durations)] * nseg, [(I_cls, S_cls)] * nmcls, [S_amps_tuple] * nscls

# Then we apply the dead time td, joining any too-short event (and any repeat of the same class) into the previous one:

def ApplyDeadTime(classes, durations, tdead):
    """ApplyDeadTime(classes, durations, tdead)

    classes:      list of each event's class index
    durations:    list of each event's duration, in milliseconds
    tdead:        length of shortest perceptible event
    Returns: new arrays (classes as int32, durations) with
                 * leading events shorter than tdead dropped
                 * later events shorter than tdead, and events repeating the
                   previous kept event's class, joined to that prior event
                 * remaining events shortened by tdead (floored at 0)
                   (because the forward likelihood is calculated separately for
                    the first (duration-tdead) "dwell time" and the final
                    tdead of "transition time.")
    An event counts as shorter than tdead only if it is also more than
    MAX_ULP_DIFF float ULPs away from tdead, so rounding noise does not drop it."""
    # copy the arrays and modify them in-place
    cl = array(classes, copy=True, dtype='int32')
    du = array(durations, copy=True)
    L = len(du)
    if L == 0:
        return cl, du

    def too_short(t):
        # short-circuit: the ULP comparison only runs for candidates below tdead
        return (t < tdead) and (fulpdiff(t, tdead) > MAX_ULP_DIFF)

    # skip initial too-shorts: there is no prior event to join them to, so they are dropped
    first_alive = 0
    while first_alive < L and too_short(du[first_alive]):
        first_alive += 1

    # compact in place: i_rd scans the input, i_wr is the next output slot (i_wr <= i_rd)
    i_rd, i_wr = first_alive, 0
    while i_rd < L:
        cls, tm = cl[i_rd], du[i_rd]
        if i_wr and ((cl[i_wr-1] == cls) or too_short(tm)):
            # join onto the previous kept event (whose duration is already net of tdead)
            du[i_wr-1] += tm
        else:
            cl[i_wr] = cls
            du[i_wr] = max(0.0, tm - tdead)
            i_wr += 1
        i_rd += 1
    # drop the stale tail left behind by merging
    return cl[:i_wr], du[:i_wr]

# the total data processing is to multiplex, apply dead time,
# and convert durations to seconds, for compatibility with Q-math

def ProcessSignals(tdead, signals):
    """Multiplex *signals*, apply the dead time *tdead* (ms), and convert durations to seconds.

    Returns (segments, multi_classes, stimulus_classes), where each segment is
    (classes, float32 durations in seconds); see MultiplexSignals for the class lists.
    """
    mplexed, mcls, scls = MultiplexSignals(signals)
    processed = []
    for seg_classes, seg_durations in mplexed:
        dead_classes, dead_durations = ApplyDeadTime(seg_classes, seg_durations, tdead)
        processed.append((dead_classes, 1e-3*array(dead_durations, dtype='float32')))
    return processed, mcls, scls

# ============= The Model ====================

# Our Markov model is a graph with colored vertices.  A vertex is called a "state,"
# and its color is a nonnegative integer called its "class."  States in the same
# class are indistinguishable (same amp).  To describe the vertices, provide the array
#    clazz = array([class of state s for s in len(states)])
#    Ns    = len(clazz)

# Each edge is labeled with its transition rate (probability per second).  These form the matrix
#    Q, a Ns x Ns matrix with
#    Q[a,b] = rate from state a to state b
#    Q[a,a] = - sum(Q[a,:])

# Each Q[a,b] is computed by q = k0 * ligand * e**(k1 * voltage + k2 * pressure).  You provide the Ns x Ns matrices
#    K0, K1, K2   of kinetic parameters
#    L, V, P      index of the ligand, voltage, or pressure signal influencing each rate, or 0
# with the diagonals undefined.

# A pair of states is either connected in both directions or neither.  To indicate un-connectedness, set
#    K0[a,b] = K0[b,a] = 0.0

def BuildQ(K0, K1, K2, L, V, P, signal_values):
    """Build the Ns x Ns rate matrix Q (numpy matrix) for one set of signal values.

    Off-diagonal entries:
        Q[a,b] = K0[a,b] * exp(K1[a,b]*signal_values[V[a,b]] + K2[a,b]*signal_values[P[a,b]])
                 * signal_values[L[a,b]]
    where each signal term is used only if its index in L, V or P is nonzero.
    Diagonal entries: Q[a,a] = -(sum of row a's off-diagonal entries).
    """
    nstate = K0.shape[0]
    Q = matrix(zeros(shape=K0.shape))
    for src in range(nstate):
        for dst in range(nstate):
            if src == dst:
                continue
            exponent = 0.0
            volt_sig = V[src, dst]
            if volt_sig:
                exponent += K1[src, dst] * signal_values[volt_sig]
            pres_sig = P[src, dst]
            if pres_sig:
                exponent += K2[src, dst] * signal_values[pres_sig]
            rate = K0[src, dst] * exp(exponent)
            lig_sig = L[src, dst]
            if lig_sig:
                rate *= signal_values[lig_sig]
            Q[src, dst] = rate
        # the diagonal is still 0 here, so summing the whole row sums the off-diagonals
        Q[src, src] = -sum(Q[src])
    return Q

# For the initial probability vector we follow the lead of (Milescu ...) and
# recompute equilibrium probability each iteration.
# Actually, we'll use QtoPe(eQ) [eQ is the dead-time-corrected Q matrix, below].

def QtoPe(Q):
    """Equilibrium state probabilities of rate matrix Q, as a 1-D array.

    Solves p * [Q | 1] = [0 | 1] in the least-squares form  p = 1 * (S * S.T)^-1
    with S = [Q | 1]  (Q augmented by a column of ones).
    """
    nstate = Q.shape[0]
    S = matrix(ones(shape=(nstate, Q.shape[1] + 1)))
    S[:, :-1] = Q
    SSt_inv = (S * S.T).I
    row_of_ones = matrix(ones(shape=(1, nstate)))
    return array(row_of_ones * SSt_inv).reshape((nstate,))

# The core calculation is the log-likelihood (LL) of the model producing the data.
# We follow the method of (Qin 1996) but revise the missed event correction
# to allow the LL of one event to be computed piece-wise.  In MIL, the LL of a dwell
# and its subsequent transition are computed together, with no provision for a
# "transition" from a class to itself.  We re-use the corrected eQaa for the initial (t-td) of an event:

def QtoQe(Q, clazz, td): # (3,4b)
    """Dead-time-corrected ("effective") rate matrix eQ for missed-event correction.

    Q:      Ns x Ns rate matrix (numpy matrix)
    clazz:  numpy array of each state's class (compared elementwise with ==)
    td:     dead time, in the time units of 1/Q
    Returns the Ns x Ns numpy matrix eQ:
      eQ[a,a] = Qaa - Qaz (I - e^(td Qzz)) Qzz^-1 Qza      (staying, incl. missed excursions)
      eQ[a,b] = (Qab - Qac (I - e^(td Qcc)) Qcc^-1 Qcb) e^(td Qbb)
    where z = every class but a, and c = every class but a and b.
    Raises numpy LinAlgError if a Qzz or Qcc block is singular.
    """
    def expm(M):
        # scipy.linalg.expm (matfuncs.expm is deprecated); keep numpy-matrix semantics for '*'
        return matrix(scipy.linalg.expm(asarray(M)))

    class_set = set(clazz)
    eQ = matrix(zeros(shape=Q.shape))
    for a in class_set:
        # partitions (z = a-bar, not a):
        A = clazz == a
        Z = clazz != a
        Qaa = Q[ix_(A, A)]
        if Z.any():
            Qaz = Q[ix_(A, Z)]
            Qza = Q[ix_(Z, A)]
            Qzz = Q[ix_(Z, Z)]
            Iz = identity(Qzz.shape[0])
            # staying and returning:
            missed_returning = ((Qaz * (Iz - expm(td*Qzz))) * Qzz.I) * Qza
        else:
            # single-class model: nowhere to escape to
            missed_returning = matrix(zeros(shape=Qaa.shape))
        eQ[ix_(A, A)] = Qaa - missed_returning
        for b in (class_set - set([a])):
            # partitions
            B = clazz == b
            C = (clazz != a) * (clazz != b)
            Qab = Q[ix_(A, B)]
            Qbb = Q[ix_(B, B)]
            if C.any():
                Qac = Q[ix_(A, C)]
                Qcc = Q[ix_(C, C)]
                Qcb = Q[ix_(C, B)]
                Ic = identity(Qcc.shape[0])
                # switching class via unseen excursions through c:
                escaped_indirect = ((Qac * (Ic - expm(td * Qcc))) * Qcc.I) * Qcb
            else:
                # no third class to pass through
                escaped_indirect = matrix(zeros(shape=Qab.shape))
            # as published:
            # eQ[ix_(A,B)] = expm(td * missed_returning) * (Qab - escaped_indirect)
            # as coded, since before 2000:
            eQ[ix_(A, B)] = (Qab - escaped_indirect) * expm(td * Qbb)
    return eQ

# Actually, we'll precompute part of exp(eQ*(t-td)) (7) so we can quickly change t:

def Spectrum(M):
    """Spectral decomposition of M: returns (U, eigenvalues, U^-1) with U a numpy matrix of eigenvectors,
    so that M = U * diag(eigenvalues) * U^-1."""
    eigvals, eigvecs = linalg.eig(M)
    U = matrix(eigvecs)
    return U, eigvals, U.I
def SpectrumExp(spectrum, t):
    """exp(M*t) from M's spectrum (U, eigenvalues, U^-1) as returned by Spectrum."""
    U, lamb, Ui = spectrum
    scaled = diag(exp(lamb * t))
    return U * scaled * Ui

# and compute the forward probability over the final td using the sampled transition matrix
#    A = e**(Q*td)
# with zeros for impossible transitions (into a different class than the data says).
# These are memorized to avoid duplicate calculations, with this general utility:

def memoize(func, lbl=""):
    """Returns wrapped func that caches return values for specific args.

    All args must be positional and immutable/hashable (e.g. tuples not lists).
    lbl is a label for optional debug tracing (currently unused).

    Example:

        def factorial(n):
            if n > 1:
                return n * fast_fac(n-1)
            return 1
        fast_fac = memoize(factorial)
    """
    results = {}
    def wrapper(*args):
        if args in results:
            return results[args]
        result = func(*args)
        results[args] = result
        # print lbl, args, ':', result
        return result
    return wrapper

def subi_ll_or_0(args): # note the only arg is a tuple for MSL args; this helps with multiprocessing.Pool.imap
    """Pool-friendly log-likelihood wrapper.

    args = (use_lib, clazz, P0, K0, K1, K2, L, V, P, tdead_sec, segs, mcls, scls[, printout]).
    Uses the compiled subi_ll_lib when args[0] is true and the library loaded
    (HAVE_LIB), otherwise the pure-Python subi_ll.  Returns -1e20 (a hopeless LL)
    if interrupted with KeyboardInterrupt.
    """
    try:
        if args[0] and HAVE_LIB:
            return subi_ll_lib(*args[1:])
        return subi_ll(*args[1:])
    except KeyboardInterrupt:
        return -1e20

def subi_ll(clazz, P0, K0, K1, K2, L, V, P, tdead_sec, segs, mcls, scls, printout=False):
    expm = scipy.linalg.matfuncs.expm
    td = tdead_sec
    if P0 != None:
        P0_norm = array(P0, copy=True)
        P0_norm /= sum(P0_norm) or 1.0
    if printout:
        print 'Nd:',[len(classes) for classes,durations in segs]
    # Q is different for each scls:
    QQ = [BuildQ(K0, K1, K2, L, V, P, amps) for amps in scls]
    if printout:
        print 'Q:',QQ[-1]
        eQQ = [QtoQe(Q, clazz, td) for Q in QQ] # (3,4b)
        if printout:
            print 'eQ:',eQQ[-1]
    except linalg.linalg.LinAlgError, err:
        print err,'; falling back to raw Q'
        eQQ = QQ
    # the submatrices eQaa are used for dwell probability (3)
    eQQaa = [ [(a, eQ[ix_(clazz==a, clazz==a)]) for a in set(clazz)]
              for eQ in eQQ ]
    # we take their spectral decomposition for quick exponentiation:
    eQQaaSpectrum = [ dict([(a, Spectrum(eQaa)) for a, eQaa in Qaas])
                      for Qaas in eQQaa ]
    # class < 0 is a gap
    eQQgapSpectrum = [Spectrum(eQ) for eQ in eQQ]
    if printout:
        print 'eQcc spectrum:'
        print eQQaaSpectrum[-1][0]
    def DwellProbMatrix(multiclass, t): # (7)
        G_full = matrix(zeros(shape=(len(clazz),len(clazz))))
        mc, sc = mcls[multiclass]
        if mc >= 0:
            A = clazz == mc
            U, lamb, Ui = eQQaaSpectrum[sc][mc]
            G_full[ix_(A,A)] = U*diag(exp(lamb*t))*Ui
        else: # gap: exponentiate full eQ matrix
            U, lamb, Ui = eQQgapSpectrum[sc]
            G_full[:,:] = U*diag(exp(lamb*t))*Ui
        return G_full
    # for each scls's Q, there's an A(td):
    AA = [expm(Q*td) for Q in QQ] # (8)
    if printout:
        print 'Atd:', AA[-1]
    # and we zero out different columns depending on the next event.
    # To specify 'any class but a', give destination=-(a+1)
    # destination=None: don't zero anything
    def _Adead(sc, destination=None): # (9)
        if destination == None:
            return AA[sc]  # gap: full A matrix
            Adead = matrix(AA[sc], copy=True)
            if destination >= 0:
                cols = clazz == destination
                cols = clazz != - (destination + 1)
            all_rows = [True] * len(clazz)
            Adead[ix_(all_rows, ~cols)] = 0.0
            return Adead
    # presumably the same transitions happen again and again, so we compute and memorize on first need:
    Adead = memoize(_Adead, 'Adead') # (9)
\(LL = log(\alpha_N 1)\), where \(\alpha_k\) is the column vector of state probabilities at time k. As in (Qin ...), we find \(scale_k = \frac{1}{\sum \alpha_k}\) and reset each \(\alpha_k = \alpha_k * scale_k\), so that \[LL = - \sum_k log(scale_k)\]
    def Scale(ak):
        mag = sum(ak)
        if mag == 0.0:
            raise Exception('Impossible!')
        scale = 1.0 / mag
        ak *= scale
        return scale
    LL = 0.0 # sum of segment LL

    for classes, durations in segs:
        Nd = len(durations)
        scale = zeros(shape=(Nd+1,))
        #alpha = matrix(zeros(shape=(Nd+1, len(clazz))))
        ak = matrix(zeros(shape=(1, len(clazz))))
        mc = classes[0]
        b, sc = mcls[mc]
        if P0 != None:
            P0_seg = array(P0, copy=True)
            P0_seg = QtoPe(eQQ[sc])
        if b >= 0: #no gap on start: zero impossible states
            P0_seg *= (clazz == b)
        # P0 /= sum(P0) or 1.0 ??
        ak[0,:] = P0_seg
        scale[0] = Scale(ak)
        #alpha[0] = ak

        for k in xrange(Nd-1):
            ak *= DwellProbMatrix(mc, durations[k]) # (7) modified for gaps

            mc = classes[k+1]
            a, b = b, mcls[mc][0]
            #print ak
            if b >= 0:
                #print Adead(sc, b)
                ak *= Adead(sc, b) # (9)
            else: # transition into gap: full Adead matrix
                ak *= Adead(sc, None)
            #print ak
            scale[k+1] = Scale(ak) # (10)
            #alpha[k+1] = ak
            #print ak
            #print '-'

            sc = mcls[mc][1]
        a = b
        #print ak
        #print DwellProbMatrix(mc, durations[Nd-1])
        ak *= DwellProbMatrix(mc, durations[Nd-1])
        print ak
        if a >= 0:
            #print Adead(sc, -(a+1))
            ak *= Adead(sc, -(a+1))
        else: # end in gap; full Adead
            ak *= Adead(sc, None)
        #print ak
        scale[Nd] = Scale(ak)
        #alpha[Nd] = ak

        LL -= sum(log(scale)) # (10)
    return LL

    # Load the compiled maxill library into `maxill` (on Windows the OpenCL build
    # 'maxill_opencl.dll' is tried before 'maxill.dll'; other branches load by file name).
    # NOTE(review): this region is damaged -- the enclosing `try:` (paired with the
    # `except` near `HAVE_LIB` below), the inner try/except/else lines and the
    # non-Windows library file names appear to be missing, so it is not valid Python.
    # NOTE(review): the standard library has no `os.platform`, and `os` is not
    # imported in the visible file; presumably `platform.system()` was intended.
    import ctypes
        if 'windows' in os.platform.system().lower():
        maxill = ctypes.cdll.LoadLibrary('maxill_opencl.dll')
            maxill = ctypes.cdll.LoadLibrary('maxill.dll')
                maxill = ctypes.cdll.LoadLibrary('')
                    maxill = ctypes.cdll.LoadLibrary('@executable_path/../Frameworks/')
                        maxill = ctypes.cdll.LoadLibrary('')
                    except OSError:
                        maxill = ctypes.cdll.LoadLibrary('@executable_path/../Frameworks/')
    def subi_ll_lib(clazz, P0, K0, K1, K2, L, V, P, tdead_sec, segs, mcls, scls, printout=False):
        p_int = ctypes.POINTER(ctypes.c_int)
        p_float = ctypes.POINTER(ctypes.c_float)
        p_double = ctypes.POINTER(ctypes.c_double)

        clazz_ = array(clazz, dtype='int32')
        Nseg = len(segs)
        if P0 != None:
            P0_ = array(P0)
            P0_ = None
        K0_ = array(K0)
        K1_ = array(K1)
        K2_ = array(K2)
        L_ = array(L)
        V_ = array(V)
        P_ = array(P)
        dwellCounts = array([len(classes) for classes, durations in segs], dtype='int32')
        classeses = (p_int*Nseg)(*[classes.ctypes.data_as(p_int) for classes, durations in segs])
        durationses = (p_float*Nseg)(*[durations.ctypes.data_as(p_float) for classes, durations in segs])
        Nsig = len(scls[0])
        Nplex = len(mcls)
        plexicls = zeros(dtype='int32', shape=(2*Nplex))
        for i,mc in enumerate(mcls):
            plexicls[2*i  ], plexicls[2*i+1] = mc
        Nstim = len(scls)
        stimcls = zeros(shape=(Nstim, Nsig))
        for i,sc in enumerate(scls):
            stimcls[i,:] = sc
        ll = array([-1e20])
        def cdata(x, typ):
            if x == None: return x
            return x.ctypes.data_as(typ)
        if printout:
            print 'Ndwell',dwellCounts
        rtnval = maxill.subi_ll_arr(len(clazz), cdata(clazz_, p_int), cdata(P0_, p_double),
                                    cdata(K0_, p_double), cdata(K1_, p_double), cdata(K2_, p_double),
                                    cdata(L_, p_int), cdata(V_, p_int), cdata(P_, p_int),
                                    ctypes.c_double(tdead_sec), Nseg, cdata(dwellCounts, p_int), classeses, durationses,
                                    Nsig, Nplex, cdata(plexicls, p_int), Nstim, cdata(stimcls, p_double),
                                    cdata(ll, p_double), Eigen_Scipy_Ptr)
        return ll[0]

    # ctypes prototypes for the maxill msl_accel_* entry points: (name, argtypes, restype)
    for _fname, _argtypes, _restype in (
            ('msl_accel_context_init', (c_int_p,), c_void_p),
            ('msl_accel_context_free', (c_void_p,), None),
            ('msl_accel_data_init', (c_void_p, c_int, c_int_p, c_int_pp, c_float_pp, c_int, c_int, c_int_p, c_int, c_double_p), c_void_p),
            ('msl_accel_data_free', (c_void_p,), None),
            ('msl_accel_data_get_ll', (c_void_p,), c_double_p),
            ('msl_accel_models_init', (c_void_p, c_int, c_int, c_int), c_void_p),
            ('msl_accel_models_free', (c_void_p,), None),
            ('msl_accel_models_reset', (c_void_p,), None),
            ('msl_accel_models_setup_model_arr', (c_void_p, c_void_p, c_int, c_int_p, c_double_p, c_double_p, c_double_p, c_double_p, c_int_p, c_int_p, c_int_p, c_double, EigenFunc, ReportFunc, c_void_p), c_int),
            ('msl_accel_models_get_ll', (c_void_p,), c_double_p),
            ('msl_accel_ll', (c_void_p, c_void_p, c_void_p, c_int), c_int),
            ('mil_accel_ll', (c_void_p, c_void_p, c_void_p, c_int), c_int),
            ):
        _fn = getattr(maxill, _fname)
        _fn.argtypes = _argtypes
        _fn.restype = _restype
    # don't leave the loop variables in the module namespace
    del _fname, _argtypes, _restype, _fn


    class msl_accel_context(object):
        """Owns a maxill acceleration context handle (self.ctx); falsy if initialization failed."""
        def __init__(self, accel=MSL_ACCEL_OPENCL):
            c_accel = c_int(accel)
            # passed by reference; the value read back afterwards is kept in self.accel
            # (presumably the mode the library actually selected)
            self.ctx = maxill.msl_accel_context_init(byref(c_accel))
            self.accel = c_accel.value
        def __nonzero__(self):
            return bool(self.ctx)
        def __del__(self):
            self.dispose()
        def dispose(self):
            """Free the context handle; safe to call more than once."""
            # getattr: __del__ may run on an instance whose __init__ failed early
            if getattr(self, 'ctx', None):
                maxill.msl_accel_context_free(self.ctx)
                self.ctx = None
        def _read_ll(self, data, models):
            """Expose the library's LL buffers as numpy arrays on models.ll (Nmodel,) and data.ll (Nmodel, Nseg)."""
            models.ll = frombuffer(pybuf(maxill.msl_accel_models_get_ll(models.models), models.Nmodel*sizeof(c_double)), dtype='float64', count=models.Nmodel)
            data.ll = frombuffer(pybuf(maxill.msl_accel_data_get_ll(data.data), models.Nmodel*data.Nseg*sizeof(c_double)), dtype='float64', count=models.Nmodel*data.Nseg).reshape((models.Nmodel, data.Nseg))
        def run_msl(self, data, models, accel):
            """Run maxill.msl_accel_ll on (data, models), then fill models.ll and data.ll; returns the library's int result."""
            rtn = maxill.msl_accel_ll(self.ctx, data.data, models.models, accel)
            self._read_ll(data, models)
            return rtn
        def run_mil(self, data, models, accel):
            """Run maxill.mil_accel_ll on (data, models), then fill models.ll and data.ll; returns the library's int result."""
            rtn = maxill.mil_accel_ll(self.ctx, data.data, models.models, accel)
            self._read_ll(data, models)
            return rtn

    class msl_accel_data(object):
        """Owns a maxill data handle (self.data) built from multiplexed segments; falsy if initialization failed.

        segs, mcls, scls are as returned by ProcessSignals; each segment's class and
        duration arrays must be contiguous int32 / float32 numpy arrays.
        """
        def __init__(self, context, segs, mcls, scls):
            self.Nseg = len(segs)
            dwellCounts = array([len(classes) for classes, durations in segs], dtype='int32')
            classeses = (c_int_p*self.Nseg)(*[classes.ctypes.data_as(c_int_p) for classes, durations in segs])
            durationses = (c_float_p*self.Nseg)(*[durations.ctypes.data_as(c_float_p) for classes, durations in segs])
            Nsig = len(scls[0])
            Nplex = len(mcls)
            # flattened (model_class, stimulus_class) pairs
            plexicls = zeros(dtype='int32', shape=(2*Nplex))
            for i, mc in enumerate(mcls):
                plexicls[2*i], plexicls[2*i+1] = mc
            Nstim = len(scls)
            stimcls = zeros(shape=(Nstim, Nsig), dtype='float64')
            for i, sc in enumerate(scls):
                stimcls[i,:] = sc
            # NOTE(review): keep the buffers referenced in case the library retains
            # pointers to them instead of copying (not visible here)
            self._keepalive = (segs, dwellCounts, classeses, durationses, plexicls, stimcls)
            self.data = maxill.msl_accel_data_init(context.ctx, self.Nseg, dwellCounts.ctypes.data_as(c_int_p), classeses, durationses,
                                                   Nsig, Nplex, plexicls.ctypes.data_as(c_int_p), Nstim, stimcls.ctypes.data_as(c_double_p))
        def __nonzero__(self):
            return bool(self.data)
        def __del__(self):
            self.dispose()
        def dispose(self):
            """Free the data handle; safe to call more than once."""
            # getattr: __del__ may run on an instance whose __init__ failed early
            if getattr(self, 'data', None):
                maxill.msl_accel_data_free(self.data)
                self.data = None

    class msl_accel_models(object):
        def __init__(self, context, dim, Nmodel, Nsc):
            self.models = maxill.msl_accel_models_init(context.ctx, dim, Nmodel, Nsc)
            self.dim = dim
            self.Nmodel = Nmodel
        def __nonzero__(self):
            return bool(self.models)
        def __del__(self):
        def dispose(self):
            if self.models:
                self.models = None
        def reset(self):
        def setup_model(self, data, clazz, P0, K0, K1, K2, L, V, P, tdead_sec, do_report=None):
            clazz_ = array(clazz, dtype='int32')
            if P0 != None:
                P0_ = array(P0)
                P0_ = None
            K0_ = array(K0)
            K1_ = array(K1)
            K2_ = array(K2)
            L_ = array(L)
            V_ = array(V)
            P_ = array(P)
            def cdata(x, typ):
                if x is None: return x
                return x.ctypes.data_as(typ)
            def console_report(msg, obj):
                print msg
                return 0
            return maxill.msl_accel_models_setup_model_arr(self.models,, len(clazz), cdata(clazz_, c_int_p),  cdata(P0_, c_double_p),
                                                        cdata(K0_, c_double_p), cdata(K1_, c_double_p), cdata(K2_, c_double_p),
                                                           cdata(L_, c_int_p), cdata(V_, c_int_p), cdata(P_, c_int_p),
                                                        tdead_sec, Eigen_Scipy_Ptr, # EigenFunc(),
                                                        ReportFunc(do_report or console_report), None)

    maxill.qubx_durhist_calhist.argtypes = (c_float, c_int, c_int, c_int_p, c_int_pp, c_float_pp, c_int, c_float_p, c_float_p, c_float)
    maxill.qubx_durhist_calhist.restype = c_int

    def durhist_calhist(c, classes, durations, bins, td_ms):
        """Dwell-duration histogram of class c, computed by maxill.qubx_durhist_calhist.

        classes, durations: one segment's events (converted to int32 / float32 for C)
        bins:  bin array (passed as float32); the result has bins.shape
        td_ms: dead time in ms
        Returns a float32 histogram divided by the library's returned total when that is nonzero.
        """
        classes_ = ascontiguousarray(classes, dtype='int32')
        durations_ = ascontiguousarray(durations, dtype='float32')
        bins_ = ascontiguousarray(bins, dtype='float32')
        # the C side takes arrays of per-segment counts/pointers; we pass a single segment
        ndwell = c_int(len(classes_))
        idwell = classes_.ctypes.data_as(c_int_p)
        tdwell = durations_.ctypes.data_as(c_float_p)
        hst = zeros(shape=bins.shape, dtype='float32')
        total = maxill.qubx_durhist_calhist(0.0, c, 1, byref(ndwell), byref(idwell), byref(tdwell), len(bins_),
                                            bins_.ctypes.data_as(c_float_p), hst.ctypes.data_as(c_float_p), td_ms)
        if total:
            hst /= total
        return hst

    maxill.qubx_durhist_calhist_smooth.argtypes = (c_int, c_int, c_int_p, c_int_pp, c_float_pp, c_int, c_float_p, c_float_p, c_float, c_float, c_float)
    maxill.qubx_durhist_calhist_smooth.restype = c_int

    def durhist_calhist_smooth(c, classes, durations, bins, td_ms, sampling):
        """Smoothed dwell-duration histogram of class c, computed by maxill.qubx_durhist_calhist_smooth.

        classes, durations: one segment's events (converted to int32 / float32 for C)
        bins:     bin array (passed as float32; needs at least 2 entries for the bin ratio)
        td_ms:    dead time in ms
        sampling: sampling interval, passed to C multiplied by 1e3 (presumably seconds -> ms)
        Returns a float32 histogram divided by the library's returned total when that is nonzero.
        """
        classes_ = ascontiguousarray(classes, dtype='int32')
        durations_ = ascontiguousarray(durations, dtype='float32')
        bins_ = ascontiguousarray(bins, dtype='float32')
        # the C side takes arrays of per-segment counts/pointers; we pass a single segment
        ndwell = c_int(len(classes_))
        idwell = classes_.ctypes.data_as(c_int_p)
        tdwell = durations_.ctypes.data_as(c_float_p)
        hst = zeros(shape=bins.shape, dtype='float32')
        total = maxill.qubx_durhist_calhist_smooth(c, 1, byref(ndwell), byref(idwell), byref(tdwell), len(bins_),
                                                   bins_.ctypes.data_as(c_float_p), hst.ctypes.data_as(c_float_p),
                                                   td_ms, sampling*1e3, bins[1]/bins[0])
        if total:
            hst /= total
        return hst

    HAVE_LIB = True
except KeyboardInterrupt:
    print 'Compiled maxill library was not found.'
    HAVE_LIB = False

# Here we use simplex to maximize likelihood of user-supplied or stock data and model.
# TODO:  standardize output for diff-testing
#        explain

# Defaults for the command-line driver below.
TDEAD = 0.1  # dead time in ms: the sampling rate since the simulator is perfect
MODELFILE, DATAFILE = 'msl_test.qmf', 'msl_test.qsf'

if __name__ == '__main__':
    # Command-line driver: load a model (.qmf) and idealized data (.qsf), print the
    # initial LL (pure Python, then compiled if available), and maximize LL over the
    # constrained rate parameters with scipy.optimize.fmin (Nelder-Mead simplex).
    import sys
    import qubx.model
    import qubx.tree

    if '--help' in sys.argv:
        print """usage: [<tdead=0 in ms> [<model>.qmf [<data>.qsf]]]
or --verify"""
    # NOTE(review): nothing stops execution after --help (an exit appears to be missing),
    # and the help text says tdead=0 while the default is TDEAD (0.1).
    if '--verify' in sys.argv:
        import subprocess
        # re-runs this script and diffs its output against msl_test.out
        p = subprocess.Popen(sys.argv[0] + " | diff msl_test.out -", shell=True)
        # NOTE(review): p is never waited on and there is no exit, so the normal run
        # below continues alongside the verification run.

    def QSFtoSignals(session, wanted_signals=None):
        # Collect (name, segments) for every idealized channel: channel 0 always,
        # others only if named in wanted_signals (or all, if wanted_signals is None).
        raw_idls = [(i, str(dc['Name'].data), chan)
                    for i, dc, chan in itertools.izip(itertools.count(),
                                                      qubx.tree.children(session['DataChannels'], 'Channel'),
                                                      qubx.tree.children(session['Idealization'], 'Channel'))
                    if not chan.find('Segment').isNull]
        wanted_idls = [(nm, chan) for i, nm, chan in raw_idls if ((i==0) or (wanted_signals==None) or (nm in wanted_signals))]
        # NOTE(review): the segment tuple below is incomplete -- the durations and
        # amp_of_cls entries (and the closing parenthesis) appear to be missing.
        raw_segs = [(nm, [(seg['Classes'][:,0],
                          for seg in qubx.tree.children(chan, 'Segment')])
                    for nm, chan in wanted_idls]
        return [nm for nm, segs in raw_segs], [segs for nm, segs in raw_segs]
        # [name_0, ...],. [(classes_0_seg0, durations_0_seg0, amp_of_cls_0_seg0), ...seg1...]

    # NOTE(review): `and ... or` falls back to TDEAD when the given tdead is 0.
    tdead = (len(sys.argv) > 1) and float(sys.argv[1]) or TDEAD
    print 'Tdead:', tdead

    model_path = (len(sys.argv) > 2) and sys.argv[2] or MODELFILE
    data_path = (len(sys.argv) > 3) and sys.argv[3] or DATAFILE

    model = qubx.model.OpenQubModel(model_path)
    print model
    rates = model.rates
    # names of the stimulus signals the model's rates depend on
    # NOTE(review): only 'Ligand' and 'Voltage' are collected, though the rate
    # matrices below include pressure (Pres) -- confirm whether pressure signals are needed.
    wanted_signals = set([rates.get(r, 'Ligand') for r in xrange(rates.size)] +
                         [rates.get(r, 'Voltage') for r in xrange(rates.size)])

    qsf = qubx.tree.Open(data_path, True)
    signal_names, signals = QSFtoSignals(qsf, wanted_signals)
    segments, multi_classes, stimulus_classes = ProcessSignals(tdead, signals)
    print 'Events:', sum(len(cl) for cl,du in segments)

    td_sec = tdead * 1e-3
    clazz = qubx.model.ModelToClazz(model)
    K0, K1, K2, L, V, Pres = qubx.model.ModelToRateMatricesP(model, signal_names)
    K = K012toK(K0, K1, K2)
    t0 = time.time()
    LL = subi_ll(clazz, None, K0, K1, K2, L, V, Pres, td_sec, segments, multi_classes, stimulus_classes, printout=True)
    print 'Initial LL:', LL, '[', ('%.3f' % (time.time() - t0)), 'sec ]'

    if HAVE_LIB:
        t0 = time.time()
        LL = subi_ll_lib(clazz, None, K0, K1, K2, L, V, Pres, td_sec, segments, multi_classes, stimulus_classes, printout=True)
        print 'Compiled LL:', LL, '[', ('%.3f' % (time.time() - t0)), 'sec ]'

    # NOTE(review): `constraints` is never used and Asys/Bsys are printed before they
    # are assigned (NameError); presumably `Asys, Bsys = constraints` is missing.
    constraints = model.constraints_kin.get_matrices_p(K0, K1, K2, L, V, Pres)
    print 'Constraints in:'
    print Asys
    print Bsys
    Asys, Bsys = reduce_constraints(Asys, Bsys)
    Acns, Bcns, Ainv, pars = linear_constraints(Asys, Bsys, K)
    print 'Constraints out:'
    print Acns
    print Bcns
    # rescale so the optimizer starts with every free parameter at 1.0
    Ascal, Bscal, Asci, pars = start_at_ones(pars)
    if any(pars != 1.0):
        print 'Non-unit starting par!',pars
        print Ascal
        print Bscal
        print Asci

    msl_ll = HAVE_LIB and subi_ll_lib or subi_ll
    # objective for fmin: negative LL of the rates implied by the free parameters
    # NOTE(review): msl_func calls `ParsToK` but msl_iter calls `parsToK` for the same
    # step; one spelling is presumably wrong (both come from star imports).
    def msl_func(pars):
        P = parsToK(pars, Ascal, Asci, Bscal)
        K = ParsToK(P, Acns, Ainv, Bcns)
        new_K0, new_K1, new_K2 = KtoK012(K, K0)
        LL = msl_ll(clazz, None, new_K0, new_K1, new_K2, L, V, Pres, td_sec, segments, multi_classes, stimulus_classes)
        print LL,pars
        return - LL
    # NOTE(review): counter is unused; msl_iter prints the literal '%i', presumably
    # meant to be formatted with the iteration number from counter.
    counter = itertools.count()
    # per-iteration callback: copy the current rates into the model and print them
    def msl_iter(pars):
        print '---> %i <---', pars
        P = parsToK(pars, Ascal, Asci, Bscal)
        K = parsToK(P, Acns, Ainv, Bcns)
        print 'K',K.T
        new_K0, new_K1, new_K2 = KtoK012(K, K0)
        qubx.model.K012ToModel(new_K0, new_K1, new_K2, model)
        print model.rates
        print '------------->    <---'
    # NOTE(review): `iter` shadows the builtin; the final print below looks truncated.
    pars, fopt, iter, funcalls, warnflag = \
        scipy.optimize.fmin(msl_func, pars, maxfun=2000,
                            xtol=.0001, ftol=.0001, maxiter=300,
                            full_output=1, disp=1, retall=0, callback=msl_iter)
    print 'final',

# ----------------------
import struct
from functools import partial

# (c) 2010 Eric L. Frederich
# Python implementation of algorithms detailed here...
# from

def c_mem_cast(x, f=None, t=None):
    """Do a C-style memory cast: pack x with struct format f, then unpack the bytes as format t.

    In Python...

        x = 12.34
        y = c_mem_cast(x, 'd', 'q')

    ... should be equivalent to the following in C...

        double x = 12.34;
        long long y = *(long long*)&x;
    """
    return struct.unpack(t, struct.pack(f, x))[0]

# 'q' (not native 'l', which is only 4 bytes on Windows) always matches a double's 8 bytes.
dbl_to_lng = partial(c_mem_cast, f='d', t='q')
lng_to_dbl = partial(c_mem_cast, f='q', t='d')
flt_to_int = partial(c_mem_cast, f='f', t='i')
int_to_flt = partial(c_mem_cast, f='i', t='f')

def ulp_diff_maker(converter, negative_zero):
    """Return a function the_diff(a, b) giving the number of representable values between a and b.

    Getting the ULP difference of floats and doubles is similar; only the converter
    (float -> signed int of the same width) and negative_zero (the bit pattern of
    -0.0, i.e. the sign bit alone: 0x80000000 or 0x8000000000000000) differ.
    """
    def _ordered(x):
        # Reinterpret x's bits as a signed int, then remap negative (sign-magnitude)
        # values so integer order matches float order and -0.0 maps to 0.  C code
        # writes this as `0x80000000 - i`, where 0x80000000 is INT_MIN; with Python's
        # unbounded ints, INT_MIN is -negative_zero.
        i = converter(x)
        if i < 0:
            i = -negative_zero - i
        return i

    def the_diff(a, b):
        return abs(_ordered(a) - _ordered(b))

    return the_diff

# ULP distance between two doubles (converter yields a 64-bit signed int; sign bit 1 << 63)
dulpdiff = ulp_diff_maker(dbl_to_lng, negative_zero=0x8000000000000000)
# ULP distance between two floats (converter yields a 32-bit signed int; sign bit 1 << 31)
fulpdiff = ulp_diff_maker(flt_to_int, negative_zero=0x80000000)

# default to double ULP difference
ulpdiff = dulpdiff
ulpdiff.__doc__ = '''
Get the number of doubles between two doubles.