max_inter_ll.py.html mathcode2html   
 Source file:   max_inter_ll.py
 Converted:   Sun May 10 2015 at 16:07:49
 This documentation file will not reflect any later changes in the source file.

$$\phantom{******** If you see this on the webpage then the browser could not locate *********}$$
$$\phantom{******** jsMath/easy/load.js or the variable root is set wrong in this file *********}$$
$$\newcommand{\vector}[1]{\left[\begin{array}{c} #1 \end{array}\right]}$$ $$\newenvironment{matrix}{\left[\begin{array}{cccccccccc}} {\end{array}\right]}$$ $$\newcommand{\A}{{\cal A}}$$ $$\newcommand{\W}{{\cal W}}$$

#!/usr/bin/python

#/* Copyright 2008-2014 Research Foundation State University of New York   */

#/* This file is part of QUB Express.                                      */

#/* QUB Express is free software; you can redistribute it and/or modify    */
#/* it under the terms of the GNU General Public License as published by   */
#/* the Free Software Foundation, either version 3 of the License, or      */
#/* (at your option) any later version.                                    */

#/* QUB Express is distributed in the hope that it will be useful,         */
#/* but WITHOUT ANY WARRANTY; without even the implied warranty of         */
#/* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the          */
#/* GNU General Public License for more details.                           */

#/* You should have received a copy of the GNU General Public License,     */
#/* named LICENSE.txt, in the QUB Express program directory.  If not, see  */
#/* <http://www.gnu.org/licenses/>.                                        */

import collections
import itertools
from numpy import *
import scipy
import scipy.linalg
import scipy.optimize
import traceback
import time

See also:

Up: Index


# ============= The Data =========================

# The data is a list of segments.  A segment consists of "events" ("dwells").
# An event has duration in milliseconds, and class -- index of its measured amplitude.
# For each segment there's a list of class amplitude values.
# You provide the list [(classes[], durations[]) ...]


# tdead is the longest duration of events that can't be reliably detected.
# MIL assumes you have deleted any such events, by merging them with their prior.
# Also, for computational reasons, tdead should be subtracted from each event.
# We provide a utility which merges and shortens events (makes a copy):

def ApplyDeadTime(classes, durations, tdead):
    """ApplyDeadTime(classes, durations, tdead)

    classes:      list of each event's class index
    durations:    list of each event's duration, in milliseconds
    tdead:        length of shortest unreliably perceptible event

    Returns: (classes, durations) as new arrays (the inputs are copied, not modified) with
                 * events shorter than tdead joined to the prior event
                 * consecutive events of the same class joined into one event
                 * remaining events shortened by tdead
                   (because the forward likelihood is calculated separately for
                    the first (duration-tdead) "dwell time" and the final
                    tdead of "transition time.")

    Too-short events at the very start have no prior event to join; they are dropped.
    An event counts as too short only when it is below tdead by more than
    MAX_ULP_DIFF single-precision ULPs (fulpdiff), so an event that equals tdead
    up to rounding noise is kept."""
    # work on copies; the compaction below overwrites them in place
    cl = array(classes, copy=True, dtype='int32')
    du = array(durations, copy=True)
    n_events = len(du)

    def too_short(t):
        return (t < tdead) and (fulpdiff(t, tdead) > MAX_ULP_DIFF)

    # skip initial too-shorts
    first_alive = 0
    while (first_alive < n_events) and too_short(du[first_alive]):
        first_alive += 1

    # compact in place: the write index never overtakes the read index
    i_rd, i_wr = first_alive, 0
    while i_rd < n_events:
        cls, tm = cl[i_rd], du[i_rd]
        if i_wr and ((cl[i_wr-1] == cls) or too_short(tm)):
            # the prior event already had tdead subtracted, so add the full duration
            du[i_wr-1] += tm
        else:
            cl[i_wr] = cls
            du[i_wr] = tm - tdead
            i_wr += 1
        i_rd += 1
    # return trimmed copies rather than resizing in place: ndarray.resize raises
    # when the buffer is referenced from anywhere else
    return cl[:i_wr].copy(), du[:i_wr].copy()

# and convert durations to seconds, for compatibility with Q-math

def ProcessSegments(tdead, segments):
    """Apply the dead time to every segment and convert durations to seconds.

    tdead:     dead time, in the same units as the durations (milliseconds)
    segments:  list of (classes, durations)

    Returns a list of (classes, durations) where classes comes straight from
    ApplyDeadTime and durations is a float32 array scaled by 1e-3."""
    processed = []
    for seg_classes, seg_durations in segments:
        merged_classes, merged_ms = ApplyDeadTime(seg_classes, seg_durations, tdead)
        seconds = 1e-3 * array(merged_ms, dtype='float32')
        processed.append((merged_classes, seconds))
    return processed


# ============= The Model ====================

# Our Markov model is a graph with colored vertices.  A vertex is called a "state,"
# and its color is a nonnegative integer called its "class."  States in the same
# class are indistinguishable (same amp).  To describe the vertices, provide the array
#    clazz = array([class of state s for s in len(states)])
#    Ns    = len(clazz)

# Each edge is labeled with its transition rate (probability per second).  These form the matrix
#    Q, a Ns x Ns matrix with
#    Q[a,b] = rate from state a to state b
#    Q[a,a] = - sum(Q[a,:])

# Each Q[a,b] is computed by q = k0 * ligand * e**(k1 * voltage).  You provide the Ns x Ns matrices
#    K0, K1   of kinetic parameters
#    L, V     index of the ligand or voltage constant influencing each rate, or 0
# with the diagonals undefined.

# A pair of states is either connected in both directions or neither.  To indicate un-connectedness, set
#    K0[a,b] = K0[b,a] = 0.0

If any \(Ligand_{a,b}\ \neq 0\) or \(Voltage_{a,b}\ \neq 0\), you must provide a list of Constants, where e.g. if \(Ligand_{a,b} = 2\) then Constants[2] holds the ligand concentration. We assume all segments were recorded at the same constant conditions. For global fitting, call max_inter_ll separately for each dataset and sum the LL.




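The Q matrix of rate constants (probability per second) is computed from the intrinsic rate constants K0, K1 and K2, and is potentially Ligand-, Voltage- and Pressure-sensitive. \[Q_{a,b} = K0_{a,b} \cdot Ligand_{a,b} \cdot e^{K1_{a,b} \cdot Voltage_{a,b} + K2_{a,b} \cdot Pressure_{a,b}}\] \[Q_{a,a} = - \sum_{i \neq a} Q_{a,i}\] A factor whose index in L, V or P is 0 is left out (the rate is then insensitive to that stimulus).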

def BuildQ(K0, K1, K2, L, V, P, constants=()):
    """Returns the Ns x Ns rate matrix Q as a numpy matrix.

    K0, K1, K2:  Ns x Ns kinetic parameters; diagonals are ignored
    L, V, P:     Ns x Ns integer indices into constants for the ligand, voltage
                 and pressure value acting on each rate; 0 means "not sensitive"
    constants:   sequence of stimulus values, indexed by the entries of L, V, P

    For a != b:
        Q[a,b] = K0[a,b] * exp(K1[a,b]*constants[V[a,b]] + K2[a,b]*constants[P[a,b]])
                 (* constants[L[a,b]] if L[a,b] is nonzero)
    and each diagonal entry is minus the sum of the other entries in its row."""
    Ns = K0.shape[0]
    Q = matrix(zeros(shape=K0.shape))
    for a in range(Ns):
        for b in range(Ns):
            if a == b:
                continue
            exponent = 0.0
            if V[a,b]:
                exponent += K1[a,b] * constants[V[a,b]]
            if P[a,b]:
                exponent += K2[a,b] * constants[P[a,b]]
            rate = exp(exponent)
            rate *= K0[a,b]
            if L[a,b]:
                rate *= constants[L[a,b]]
            Q[a,b] = rate
        # Q[a,a] is still 0 here, so this sums only the off-diagonal rates
        Q[a,a] = - Q[a].sum()
    return Q


As discussed in (Milescu 2005), it can be preferable to start an experiment at equilibrium. Entry probabilities (P0) are then no longer constant, but a function of rate constants. We use the "direct" method given in Neher and Sakmann:

Defining \(S = [Q | 1]\) and \(u\) a row vector of ones, \[P_{eq} = u \cdot (S \cdot S^T)^{-1}\] Actually, we'll use QtoPe(eQ) [eQ is the dead-time-corrected Q matrix, below].


def QtoPe(Q):
    """Returns the equilibrium state probabilities of Q as a 1-D array.

    Direct method: with S = [Q | 1] and u a row of ones, Pe = u * (S * S.T).I"""
    n_rows, n_cols = Q.shape
    # S is Q with one extra column of ones appended on the right
    S = matrix(ones(shape=(n_rows, n_cols + 1)))
    S[:,:n_cols] = Q
    ones_row = matrix(ones(shape=(1, n_rows)))
    gram_inverse = (S * S.T).I
    return array(ones_row * gram_inverse).reshape((n_rows,))


"Qe" aka \({}^eQ\) is the "apparent" rate constant matrix, given that events with duration <= tdead are not recorded. MIL computes the probability of staying in class a for duration t using the submatrix \({}^eQ_{aa}\) \[A(a, t) = e^{{}^eQ_{aa} t'}\] where \(t' = t - t_{dead}\) is the event duration after the dead time has been subtracted (see ApplyDeadTime).

def QtoQe(Q, clazz, td):
    """Returns eQ, the dead-time-corrected ("apparent") rate matrix.

    Q:      Ns x Ns rate matrix (numpy matrix)
    clazz:  array with the class index of each state
    td:     dead time, in the time unit of 1/Q

    For each class a (z = a-bar, all states not in a):
        eQ[a,a] = Qaa - Qaz * (I - exp(td*Qzz)) * Qzz.I * Qza
    and for each other class b (c = states in neither a nor b):
        eQ[a,b] = (Qab - Qac * (I - exp(td*Qcc)) * Qcc.I * Qcb) * exp(td*Qbb)

    The matrix inverses can raise numpy.linalg.LinAlgError for singular blocks."""
    # public entry point; scipy.linalg.matfuncs is a deprecated private module path
    expm = scipy.linalg.expm
    class_ids = set(clazz)
    eQ = matrix(zeros(shape=Q.shape))
    for a in class_ids:
        # partitions:
        A = clazz == a
        Z = clazz != a
        Qaa = Q[ix_(A, A)]
        Qaz = Q[ix_(A, Z)]
        Qza = Q[ix_(Z, A)]
        Qzz = Q[ix_(Z, Z)]
        Iz = identity(Qzz.shape[0])
        # staying and returning: excursions out of a that come back within td are missed
        missed_returning = ((Qaz * (Iz - expm(td*Qzz))) * Qzz.I) * Qza
        eQ[ix_(A, A)] = Qaa - missed_returning

        for b in (class_ids - set([a])):
            # partitions
            B = clazz == b
            C = (clazz != a) & (clazz != b)
            Qab = Q[ix_(A, B)]
            Qbb = Q[ix_(B, B)]
            if C.any():
                Qac = Q[ix_(A, C)]
                Qcc = Q[ix_(C, C)]
                Qcb = Q[ix_(C, B)]
                Ic = identity(Qcc.shape[0])
                # switching class by way of a third class within td
                escaped_indirect = ((Qac * (Ic - expm(td * Qcc))) * Qcc.I) * Qcb
            else:
                escaped_indirect = matrix(zeros(shape=Qab.shape))
            # as published:
            # eQ[ix_(A,B)] = expm(td * missed_returning) * (Qab - escaped_indirect)
            # as coded, since before 2000:
            eQ[ix_(A,B)] = (Qab - escaped_indirect) * expm(td * Qbb)
    return eQ

# Actually, we'll precompute part of exp(eQ*(t-td)) (7) so we can quickly change t:

def Spectrum(M):
    """Returns (U, lamb, U.I): eigenvector matrix, eigenvalues, and inverse of U."""
    eigenvalues, eigenvectors = linalg.eig(M)
    U = matrix(eigenvectors)
    return U, eigenvalues, U.I
def SpectrumExp(spectrum, t):
    """Returns U * diag(exp(lamb*t)) * U.I for spectrum = (U, lamb, U.I), as from Spectrum()."""
    eigvecs, eigvals, eigvecs_inv = spectrum
    decay = diag(exp(eigvals * t))
    return (eigvecs * decay) * eigvecs_inv

# and compute the forward probability over the final td using the sampled transition matrix
#    A = e**(Q*td)
# with zeros for impossible transitions (into a different class than the data says).
# These are memorized to avoid duplicate calculations, with this general utility:

def memoize(func, lbl=""):
    """Wraps func so each distinct positional-argument tuple is computed only once.

    The arguments are used as a dictionary key, so they must be hashable
    (e.g. tuples, not lists).  Keyword arguments are not supported.
    lbl is accepted as a label for debugging and is otherwise unused.

    Example:
        def slow_square(n):
            return n * n
        fast_square = memoize(slow_square)
        fast_square(12)   # computed
        fast_square(12)   # looked up
    """
    cache = {}
    def wrapper(*args):
        try:
            return cache[args]
        except KeyError:
            pass
        # computed outside the try so a KeyError raised by func is not swallowed
        value = func(*args)
        cache[args] = value
        return value
    return wrapper


def inter_ll_or_0(args): # note the only arg is a tuple for MSL args; this helps with multiprocessing.Pool.imap
    """Calls inter_ll(*args); if it raises, prints the traceback and returns 0.0.

    Only Exception subclasses are absorbed, so KeyboardInterrupt and SystemExit
    still propagate.
    NOTE(review): 0.0 is also a legal log-likelihood value, so a caller cannot
    tell a failure from a result by the return value alone."""
    try:
        return inter_ll(*args)
    except Exception:
        traceback.print_exc()
        return 0.0


def inter_ll(clazz, p0, K0, K1, K2, L, V, P, constants, tdead_sec, segs, printout=False):
    """Returns the log-likelihood of the dwell sequences in segs, summed over all segments.

    clazz:       array with the class index of each state
    p0:          entry probability of each state, or None to use QtoPe(eQ)
    K0, K1, K2, L, V, P, constants:
                 rate parameters, passed unchanged to BuildQ
    tdead_sec:   dead time, in seconds
    segs:        list of (classes, durations), e.g. from ProcessSegments
    printout:    if true, print Q, eQ and the spectra of the eQaa submatrices

    If QtoQe raises LinAlgError the raw Q is used in place of eQ (a message is printed).
    Raises Exception('Impossible!') when the state probabilities sum to zero.
    Segments with no events are skipped; they add nothing to the sum."""
    td = tdead_sec

    class_ids = set(clazz)
    # boolean state masks: states in class a, and states not in class a
    ixset = dict([(a, clazz==a) for a in class_ids])
    nxset = dict([(a, clazz!=a) for a in class_ids])
    statecount = collections.defaultdict(int)
    for a in clazz:
        statecount[a] += 1

    Q = BuildQ(K0, K1, K2, L, V, P, constants)
    if printout:
        print('Q: %s' % (Q,))
    try:
        eQ = QtoQe(Q, clazz, td)
        if printout:
            print('eQ: %s' % (eQ,))
    except linalg.LinAlgError as err:
        print('%s ; falling back to raw Q' % (err,))
        eQ = Q

    # the submatrices eQaa are used for dwell probability (3);
    # we take their spectral decomposition for quick exponentiation:
    eQaaSpectrum = dict([(a, Spectrum(eQ[ix_(ixset[a], ixset[a])])) for a in class_ids])
    def At_aa(a, t):
        U, lamb, Ui = eQaaSpectrum[a]
        return U * diag(exp(lamb*t)) * Ui
    # and memorize the short ones since they'll probably recur
    At_aa_memo = memoize(At_aa, 'At_aa')
    if printout:
        print('eQaa spectrum:')
        print(eQaaSpectrum)
    MAX_T_MEMO = 7 * td

    # memorize transition submatrices
    def eQab_ix(A, B):
        return eQ[ix_(A,B)]
    def _eQab(a, b):
        return eQab_ix(ixset[a], ixset[b])
    eQab = memoize(_eQab, 'eQab')

    # "is None", not "== None": comparing an array with == is elementwise
    if p0 is None:
        P0 = QtoPe(eQ)
    else:
        P0 = p0

    # LL = log(alpha_N * 1), where alpha_k is the vector of state probabilities at step k.
    # We find scale_k = 1 / sum(alpha_k) and reset alpha_k = alpha_k * scale_k, so that
    #     LL = - sum_k log(scale_k)
    def Scale(ak):
        # ak is a view into the running probability vector; it is rescaled in place
        mag = ak.sum()
        if mag == 0.0:
            raise Exception('Impossible!')
        scale = 1.0 / mag
        ak *= scale
        return scale

    # scale accumulator -- sum ll over all segments all events
    ll = 0.0

    for classes, durations in segs:
        Nd = len(durations)
        if Nd == 0:
            continue
        # the live probabilities always occupy the first Na (or Nb) slots of ak
        ak = matrix(zeros(shape=(1, len(clazz))))
        a = b = classes[0]
        Na = Nb = statecount[a]
        P0_a = array(P0[ixset[a]])
        ak[0,:Na] = P0_a / (P0_a.sum() or 1.0)
        ll -= log( Scale(ak[0,:Na]) )

        for k in range(Nd-1):
            a, b = b, classes[k+1]
            Na, Nb = Nb, statecount[b]
            t = durations[k]

            # dwell in class a for time t
            if t <= MAX_T_MEMO:
                At = At_aa_memo(a, t)
            else:
                At = At_aa(a, t)
            ak[0,:Na] = ak[0,:Na] * At
            ll -= log( Scale(ak[0,:Na]) )

            # transition from class a into class b
            ak[0,:Nb] = ak[0,:Na] * eQab(a, b)
            ll -= log( Scale(ak[0,:Nb]) )

        # last event: dwell, then leave to any state outside its class
        a, Na = b, Nb
        t = durations[Nd-1]
        ak[0,:Na] *= At_aa(a, t)
        ll -= log( Scale(ak[0,:Na]) )

        Nz = len(clazz) - Na
        ak[0,:Nz] = ak[0,:Na] * eQab_ix(ixset[a], nxset[a])
        ll -= log( Scale(ak[0,:Nz]) )

    return ll



try:
    import ctypes
    import os        # unused here; kept because this block has always bound `os` at module level
    import platform

    def _load_maxill():
        """Returns the first compiled maxill library that loads; raises OSError if none does."""
        names = ['maxill_opencl.dll',
                 'maxill.dll',
                 'libmaxill_opencl.so',
                 '@executable_path/../Frameworks/libmaxill_opencl.so',
                 'libmaxill.so',
                 '@executable_path/../Frameworks/libmaxill.so']
        if 'windows' in platform.system().lower():
            # the OpenCL runtime is loaded first; without it the OpenCL build is skipped
            try:
                ctypes.cdll.LoadLibrary('OpenCL.dll')
            except OSError:
                names.remove('maxill_opencl.dll')
        last_err = None
        for name in names:
            try:
                return ctypes.cdll.LoadLibrary(name)
            except OSError as err:
                last_err = err
        raise last_err

    maxill = _load_maxill()

    def inter_ll_lib(clazz, P0, K0, K1, K2, L, V, P, constants, tdead_sec, segs, printout=False):
        """Same arguments as inter_ll, evaluated by maxill.inter_ll_arr in the compiled library.

        Each segment's classes and durations must be numpy arrays; their buffers are
        handed to the library as int* and float* without conversion.
        When P0 is None the entry probabilities are computed here the same way
        inter_ll does, as QtoPe(eQ) (falling back to raw Q if QtoQe raises LinAlgError).
        printout is accepted for signature compatibility and is not used.
        Returns the value the library writes into its output slot; the slot is
        preset to -9898.0."""
        p_int = ctypes.POINTER(ctypes.c_int)
        p_float = ctypes.POINTER(ctypes.c_float)
        p_double = ctypes.POINTER(ctypes.c_double)

        def cdata(x, typ):
            return x.ctypes.data_as(typ)

        clazz_ = array(clazz, dtype='int32')
        Nseg = len(segs)
        if P0 is None:
            Q = BuildQ(K0, K1, K2, L, V, P, constants)
            try:
                eQ = QtoQe(Q, clazz_, tdead_sec)
            except linalg.LinAlgError:
                eQ = Q
            P0 = QtoPe(eQ)
        # the buffers are read as double* / int*, so force matching dtypes
        P0_ = array(P0, dtype='float64')
        K0_ = array(K0, dtype='float64')
        K1_ = array(K1, dtype='float64')
        K2_ = array(K2, dtype='float64')
        L_ = array(L, dtype='int32')
        V_ = array(V, dtype='int32')
        P_ = array(P, dtype='int32')
        constants_ = array(constants, dtype='float64')
        dwellCounts = array([len(classes) for classes, durations in segs], dtype='int32')
        classeses = (p_int*Nseg)(*[classes.ctypes.data_as(p_int) for classes, durations in segs])
        durationses = (p_float*Nseg)(*[durations.ctypes.data_as(p_float) for classes, durations in segs])
        ll = array([-9898.0])

        maxill.inter_ll_arr(len(clazz), cdata(clazz_, p_int), cdata(P0_, p_double),
                            cdata(K0_, p_double), cdata(K1_, p_double), cdata(K2_, p_double),
                            cdata(L_, p_int), cdata(V_, p_int), cdata(P_, p_int),
                            cdata(constants_, p_double), ctypes.c_double(tdead_sec),
                            Nseg, cdata(dwellCounts, p_int), classeses, durationses,
                            cdata(ll, p_double))
        return ll[0]
    HAVE_LIB = True
except Exception:
    traceback.print_exc()
    print('Compiled maxill library was not found.')
    HAVE_LIB = False


# What are the optimization parameters?  (some of) the off-diagonal elements of K0 and K1.
# Actually, we use log(K0) to keep them positive.
# Breaking with past practice, we include all the unconnected zeros,
# and put all the K0s before all the K1s.  As we'll see, this is a column vector.

def K_index(N):
    """Returns every off-diagonal (a, b) pair of an N x N matrix, in row-major order.

    The result is empty for N <= 1."""
    return [(a, b) for a in range(N) for b in range(N) if a != b]

def K012toK(K0, K1, K2):
    """Packs the off-diagonal entries of K0, K1, K2 into one column vector K.

    K holds log(K0) for every pair from K_index, then every K1, then every K2.
    A zero (unconnected) or non-positive K0 is stored as 0.0, the same value a
    rate of 1 produces; KtoK012 tells them apart using the original K0."""
    log_k0, k1, k2 = [], [], []
    for a, b in K_index(K0.shape[1]):
        rate = K0[a,b]
        log_k0.append(log(rate) if rate > 0.0 else 0.0)
        k1.append(K1[a,b])
        k2.append(K2[a,b])
    stacked = log_k0 + k1 + k2
    K = matrix(zeros(shape=(len(stacked), 1)))
    K[:,0] = matrix(array(stacked)).T
    return K

# K alone isn't enough to reconstruct K0 and K1, since 0 could mean log(1), or else unconnected.
# We disambiguate by checking the original K0 for zeros, which mean unconnected.

def KtoK012(K, last_K0):
    """Unpacks the column vector K (see K012toK) into N x N matrices K0, K1, K2.

    K:        column vector of length 3*N*(N-1): log(K0), then K1, then K2
    last_K0:  the K0 the vector was built from; wherever it is zero the pair is
              unconnected and K0 stays 0 instead of becoming exp(0) = 1

    Diagonals of the results are zero."""
    # len(K)/3 = N*(N-1)  =>  N = .5 + sqrt(.25 + len(K)/3)
    N = int(round(.5 + sqrt(.25 + len(K)/3.0)))
    ixs = K_index(N)
    n_rates = len(ixs)
    K0 = matrix(zeros(shape=(N,N)))
    K1 = matrix(zeros(shape=(N,N)))
    K2 = matrix(zeros(shape=(N,N)))
    for i, (a, b) in enumerate(ixs):
        if last_K0[a,b]:
            K0[a,b] = exp(K[i,0])
        # scalar reads K[row,0] throughout; K[row] would be a 1x1 matrix
        K1[a,b] = K[n_rates + i, 0]
        K2[a,b] = K[2*n_rates + i, 0]
    return K0, K1, K2


# But some of those K are zero and have to stay zero, and the user may have other constraints.
# We represent the constraints by the system of linear equations (11) and derive the transformations (12, 13).

# Some constraints are more a part of the representation than the model:
#   * where K0 == 0 there's no connection; fix k0 and k1
#   * where V == 0  fix k1
def AutoConstraints(K0, K1, K2, L, V, P):
    """Returns the constraints (Ain, Bin)   [11] which fix unconnected k0 and unsensitive k1, k2.

    With n = number of off-diagonal pairs (K_index), each row of Ain has 3*n
    columns (log k0, then k1, then k2) and pins one parameter to 0:
        * K0[a,b] == 0                 ->  log k0 and k1
        * V[a,b] == 0                  ->  k1
        * P[a,b] == 0                  ->  k2
    The arguments are not modified.  With no constraints, Ain is 0 x 3*n.

    NOTE(review): k2 of an unconnected pair stays free when P[a,b] is nonzero --
    confirm whether that is intended."""
    ixs = K_index(K0.shape[0])
    n = len(ixs)
    aa = []
    bb = []

    def fix_at_zero(column):
        coeffs = zeros(shape=(1,3*n))
        coeffs[0, column] = 1.0
        aa.append(coeffs)
        bb.append(0.0)

    for i, (a, b) in enumerate(ixs):
        connected = bool(K0[a,b])
        if not connected:
            fix_at_zero(i)
        # an unconnected rate is treated as voltage-insensitive, without editing V
        if not (connected and V[a,b]):
            fix_at_zero(n + i)
        if not P[a,b]:
            fix_at_zero(2*n + i)
    if aa:
        return vstack(aa), matrix(array(bb)).T
    else:
        return matrix(zeros(shape=(0,3*n))), matrix(zeros(shape=(0,1)))

# The user might also have some constraints for our linear system:
def UserConstraints(K0, K1, K2, L, V, P, constraints):
    """Returns the constraints (Ain, Bin)   [11]
    corresponding to the tuples    (name, states) in constraints.

    With n = number of off-diagonal pairs (K_index), every row of Ain has 3*n
    columns: log k0, then k1, then k2.  Recognized names:
        'FixRate', 'FixExp', 'FixPress'        states = (a, b)
            hold log k0 / k1 / k2 of a->b at its current value
        'ScaleRate'                            states = (a, b, c, d)
            hold log k0[a,b] - log k0[c,d] at its current value
        'ScaleExp', 'ScalePress'               states = (a, b, c, d)
            hold the ratio k1[a,b]/k1[c,d] (k2[a,b]/k2[c,d]) at its current value
        'LoopBal'                              states = the states around a loop
            sum of (forward - backward) log k0 around the loop is 0
        'LoopImbal'                            states = the states around a loop
            the same sum is held at its current value
    For both loop kinds, if any rate on the loop has a nonzero V (or P) index, a
    second (third) row sets the same forward-minus-backward sum of k1 (k2) to 0.
    Any other name is ignored.  With no constraints, Ain is 0 x 3*n."""
    ixs = K_index(K0.shape[0])
    ixo = dict([(ab, i) for i, ab in enumerate(ixs)])
    N = len(ixs)
    aa = []
    bb = []

    def add_row(entries, rhs):
        # entries are (column, value) pairs, applied in order (later ones overwrite)
        coeffs = zeros(shape=(1,3*N))
        for column, value in entries:
            coeffs[0, column] = value
        aa.append(coeffs)
        bb.append(rhs)

    def add_loop_rows(states, keep_current):
        # rows shared by 'LoopBal' (keep_current=False) and 'LoopImbal' (True)
        loopst = list(states) + [states[0]]
        edges = [(loopst[i], loopst[i+1]) for i in range(len(loopst)-1)]
        v_depend = p_depend = False
        rhs = 0.0
        for s, t in edges:
            if keep_current:
                rhs += log(K0[s,t]) - log(K0[t,s])
            v_depend = v_depend or V[s,t] or V[t,s]
            p_depend = p_depend or P[s,t] or P[t,s]
        for offset, wanted, row_rhs in ((0, True, rhs), (N, v_depend, 0.0), (2*N, p_depend, 0.0)):
            if not wanted:
                continue
            entries = []
            for s, t in edges:
                entries.append((offset + ixo[(s, t)], 1.0))
                entries.append((offset + ixo[(t, s)], -1.0))
            add_row(entries, row_rhs)

    for name, states in constraints:
        if name == 'FixRate':
            r = ixo[(states[0], states[1])]
            add_row([(r, 1.0)], log(K0[states[0], states[1]]))
        elif name == 'FixExp':
            r = ixo[(states[0], states[1])]
            add_row([(N+r, 1.0)], K1[states[0], states[1]])
        elif name == 'FixPress':
            r = ixo[(states[0], states[1])]
            add_row([(2*N+r, 1.0)], K2[states[0], states[1]])
        elif name == 'ScaleRate':
            r1 = ixo[(states[0], states[1])]
            r2 = ixo[(states[2], states[3])]
            add_row([(r1, 1.0), (r2, -1.0)],
                    log(K0[states[0], states[1]]) - log(K0[states[2], states[3]]))
        elif name == 'ScaleExp':
            r1 = ixo[(states[0], states[1])]
            r2 = ixo[(states[2], states[3])]
            add_row([(N+r1, 1.0),
                     (N+r2, - K1[states[0], states[1]] / K1[states[2], states[3]])], 0.0)
        elif name == 'ScalePress':
            r1 = ixo[(states[0], states[1])]
            r2 = ixo[(states[2], states[3])]
            add_row([(2*N+r1, 1.0),
                     (2*N+r2, - K2[states[0], states[1]] / K2[states[2], states[3]])], 0.0)
        elif name == 'LoopBal':
            add_loop_rows(states, False)
        elif name == 'LoopImbal':
            add_loop_rows(states, True)

    if aa:
        return vstack(aa), matrix(array(bb)).T
    else:
        return matrix(zeros(shape=(0,3*N))), matrix(zeros(shape=(0,1)))


# You could think of others too; just vstack them with the rest:
def AllConstraints(K0, K1, K2, L, V, P, constraints):
    """Returns the UserConstraints(...) rows stacked on top of the AutoConstraints(...) rows.

    If either set has no rows, the other is returned as it is."""
    user_A, user_B = UserConstraints(K0, K1, K2, L, V, P, constraints)
    auto_A, auto_B = AutoConstraints(K0, K1, K2, L, V, P)
    if user_A.shape[0] == 0:
        return auto_A, auto_B
    if auto_A.shape[0] == 0:
        return user_A, user_B
    stacked_A = vstack((user_A, auto_A))
    stacked_B = vstack((user_B, auto_B))
    return stacked_A, stacked_B

# To derive the transformation we will use singular value decomposition (svd)
# which is really well described at http://www.uwlax.edu/faculty/will/svd/index.html

# This wrapper around numpy.linalg.svd makes it behave more like LAPACK:
#   * V is transposed
#   * matrices have full dimension
# Then we repair and check it:
#   * tiny numbers are replaced with 0
#   * U and V need orthonormal columns
# And derive some useful extras:
#   * Winv, where inv(0) = 0
#   * Ainv = V*Winv*U.T
#   * Nsv = number of nonzero singular values
def adjusted_svd(A):
    """Returns (U, W, V, Winv, Ainv, Nsv) such that A = U*W*V.T.

    U, V:   singular vectors in columns, zero-padded so that W is square
    W:      diagonal matrix of singular values
    Winv:   diagonal matrix of reciprocal singular values, with inv(0) = 0
    Ainv:   V*Winv*U.T
    Nsv:    number of singular values >= 1e-10
    Entries of U, W, V smaller than 1e-10 in magnitude are replaced with 0.
    Raises Exception if U*U.T or V*V.T is not the identity to within 1e-8."""
    tiny = 1e-10
    u, s, v = linalg.svd(A, full_matrices=True)
    n_rows, n_cols = A.shape[0], A.shape[1]

    U = matrix(zeros(shape=(max(n_rows, u.shape[0]), max(n_cols, u.shape[1]))))
    U[:u.shape[0],:u.shape[1]] = u
    width = U.shape[1]
    S = zeros(shape=(width,))
    S[:s.shape[0]] = s
    N = len(s[s>=tiny])
    V = matrix(zeros(shape=(n_cols, width)))
    V[:v.shape[1],:v.shape[0]] = v.T

    # tiny numbers are replaced with 0
    U[abs(U) < tiny] = 0.0
    V[abs(V) < tiny] = 0.0
    S[abs(S) < tiny] = 0.0
    W = diag(S)
    Sinv = array([(1/x if x else 0) for x in S])
    Winv = diag(Sinv)

    # A = U*W*V.T; A.I = V*Winv*U.T
    for label, M in (('U', U), ('V', V)):
        if linalg.norm(identity(M.shape[0]) - M*M.T) > 1e-8:
            raise Exception("Bad SVD: %s's columns are not orthonormal" % label)
    return U, W, V, Winv, V*(Winv*U.T), N

# This function cleans up the constraint system prior to the transform,
# by deleting duplicate constraints.  It uses two checks in case of numerical inaccuracy.
# It's called automatically in linear_constraints().
def reduce_constraints(A, B):
    """Returns (A, B) with redundant constraint rows removed.

    Each constraint becomes one column [a_1 .. a_n, b].T (14).  A column is kept
    only if adding it raises the rank of the columns kept so far (check 1), and
    once it is kept, every later column that the kept set already reproduces to
    within 1e-5 is dropped (check 2).
    The survivors are returned as rows: A is m x n and B is m x 1.  With no
    surviving constraints the result is a 0 x n A and a 0 x 1 B."""
    Csrc = matrix(hstack((A, B)).T)  # (n+1) x m; row -1 holds B
    cols = [Csrc[:,i] for i in range(Csrc.shape[1])]
    i = 0
    rank = 0
    while i < len(cols):
        Csub = hstack(cols[:i+1]) # (15)
        U, W, V, Winv, Cinv, Nsv = adjusted_svd(Csub) # (16)
        if Nsv == rank: # check 1: added constraint must increase rank (21)
            del cols[i]
        else:
            rank = Nsv
            j = i+1
            while j < len(cols):
                # check 2: other constraints can't be lin. combinations of the prior
                image = Cinv * cols[j] # (19)
                if 1e-5 > linalg.norm(cols[j] - (Csub*image)):
                    del cols[j]
                else:
                    j += 1
            i += 1
    if not cols:
        return matrix(zeros(shape=(0, Csrc.shape[0]-1))), matrix(zeros(shape=(0, 1)))
    # rebuild from the survivors; the last Csub may still hold a column deleted by check 1
    kept = hstack(cols)
    return matrix(kept[:-1,:].T), matrix(kept[-1,:].T)
  

# Now that we have K0, K1, L, V and rate vector K,  
# we do some magic with the SVD to change your linear constraints
#    A * K = B
# into a transformation between K and the column vector P of free parameters
#    K = A*P + B
#    P = Ainv * K

def linear_constraints(Ain, Bin, K0):
    """Returns A, B, A.I, P0 [12].

    Ain, Bin:  the constraint system Ain * K = Bin
    K0:        the starting rate vector K (a column vector, as from K012toK)

    The results satisfy K = A*P + B and P = A.I * K, with P0 the free parameters
    that correspond to the starting K.  If reduce_constraints leaves no
    constraints, every rate is free: A and A.I are identities and B is zero.
    Raises Exception if the reduced system is still rank deficient."""
    A, B = reduce_constraints(Ain, Bin)
    Nc = A.shape[0]
    Nk = A.shape[1]
    if Nc == 0:
        n = len(K0)
        start = matrix(array(K0, dtype=float).reshape((n, 1)))
        return matrix(identity(n)), matrix(zeros(shape=(n,1))), matrix(identity(n)), start
    U, W, V, Winv, Ainv, rank = adjusted_svd(A)
    # rank can never exceed Nc, so this is the only rank failure that can occur
    if rank < Nc:
        raise Exception('Bad SVD: Some constraints may be incompatible')
    for i in range(Nc):
        for j in range(Nc,Nk):
            if abs(U[i,j]) > 1e-5:
                raise Exception('Bad SVD: 0 != U[c,non] == %.6g' % U.A[i,j])
    Acns = matrix(V[:,Nc:]) # (24)
    Bcns = Ainv * B         # (29)
    Out0 = Acns.T*K0        # (13)
    return Acns, Bcns, Acns.T, Out0

# And yet another linear transformation, this time to scale initial parameter values to 1.0.
# (In the case where p[i] == 0, we scale by 1 and translate by -1.)
# P = As*S + Bs
# S = As.I * (P - Bs)

def start_at_ones(P0):
    """Returns As, Bs, As.I, S0 [32].

    P0 is a column vector.  As is diagonal with each P0 entry (or 1.0 where it
    is zero), Bs is -1.0 exactly where P0 is zero, and S0 is all ones, so that
    P0 = As*S0 + Bs."""
    start = P0.A[:,0]
    scales = []
    reverses = []
    Bscal = matrix(zeros(shape=P0.shape)) # (36)
    for row, value in enumerate(start):
        scales.append(value or 1.0)
        inverse = 1/value if value else 0.0
        reverses.append(inverse or 1.0)
        if not value:
            # a zero start is reached as 1*1 - 1
            Bscal[row] = -1.0
    Ascal = matrix(diag(array(scales)))   # (35)
    Asci = matrix(diag(array(reverses)))
    return Ascal, Bscal, Asci, matrix(ones(P0.shape))

def KToPars(rates, A, Ainv, B): # (13)
    """Returns Ainv * k, where k is rates flattened into a column vector.

    A and B are accepted for symmetry with ParsToK and are not used."""
    column = matrix(zeros(shape=(rates.size, 1)))
    column[:,0] = matrix(rates.flatten()).T
    return Ainv * column

# variant 1: pars is a python or nd- array, such as S provided by an optimizer
def parsToK(pars, A, Ainv, B): # (12, 32)
    """Flattens pars (python or nd- array) into a column vector P and returns ParsToK(P, A, Ainv, B)."""
    column = matrix(array(pars).flatten()).T
    return ParsToK(column, A, Ainv, B)

# variant 2: pars is a column vector, such as P = parsToK(S, ...)
def ParsToK(P, A, Ainv, B): # (12, 32)
    """Returns A*P + B for a column vector P.  Ainv is accepted but not used."""
    transformed = A * P
    return transformed + B


# To set up, you provide K0,K1,K2,L,V,P and constraints, and we derive
# P and S and their transforms.  Then the optimizer adjusts S and we use it
# to build the Q matrices:
#   S --parsToK---> P --ParsToK---> K --KtoK012---> K0,K1,K2 --BuildQ---> Q



# Here we use simplex to maximize likelihood of user-supplied or stock data and model.
# TODO:  standardize output for diff-testing
#        explain


TDEAD = 0.099 # < sampling rate since the simulator is perfect
MODELFILE = 'mil_test.qmf'
DATAFILE = 'mil_test.qsf'

# Command-line driver: load a model and an idealized data file, report the initial
# log-likelihood (python and, if available, compiled), then maximize it with simplex.
# NOTE(review): ProcessSegments -> ApplyDeadTime uses fulpdiff and MAX_ULP_DIFF, which
# this file defines below this block; run as a script they are not yet bound when
# ProcessSegments is called here -- confirm the ordering / move those helpers up.
if __name__ == '__main__':
    import sys
    import qubx.model
    import qubx.tree
    qubx.tree.CHOOSE_FLAVOR('numpy')

    if '--help' in sys.argv:
        print("""usage: max_inter_ll.py [<tdead=0 in ms> [<model>.qmf [<data>.qsf]]]
or     max_inter_ll.py --verify""")
        sys.exit(0)
    if '--verify' in sys.argv:
        # re-run with defaults and diff the output against the stored reference
        import subprocess
        p = subprocess.Popen(sys.argv[0] + " | diff mil_test.out -", shell=True)
        p.communicate()
        sys.exit(0)

    # a conditional, not and/or, so an explicit tdead of 0 is honored
    tdead_ms = float(sys.argv[1]) if len(sys.argv) > 1 else TDEAD
    print('Tdead: %s' % (tdead_ms,))

    model_path = (len(sys.argv) > 2) and sys.argv[2] or MODELFILE
    data_path = (len(sys.argv) > 3) and sys.argv[3] or DATAFILE

    model = qubx.model.OpenQubModel(model_path)
    print(model)
    rates = model.rates
    # collected but not used further in this script
    wanted_signals = set([rates.get(r, 'Ligand') for r in range(rates.size)] +
                         [rates.get(r, 'Voltage') for r in range(rates.size)] +
                         [rates.get(r, 'Pressure') for r in range(rates.size)])

    qsf = qubx.tree.Open(data_path, True)
    qsf.close()
    raw_segs = [(seg['Classes'].storage.data[:,0],
                 seg['Durations'].storage.data[:,0])
                for seg in qubx.tree.children(qsf['Idealization']['Channel'], 'Segment')]
    segments = ProcessSegments(tdead_ms, raw_segs)
    # a list, not a generator: `sum` may be numpy's here
    print('Events: %s' % (sum([len(cl) for cl,du in segments]),))

    clazz = qubx.model.ModelToClazz(model)
    P0 = qubx.model.ModelToP0(model)
    print('P0 %s' % (P0,))
    K0, K1, K2, L, V, P = qubx.model.ModelToRateMatricesP(model)
    K = K012toK(K0, K1, K2)
    constants = []

    t0 = time.time()
    LL = inter_ll(clazz, P0, K0, K1, K2, L, V, P, constants, tdead_ms*1e-3, segments, printout=True)
    print('Initial LL: %s [ %.3f sec ]' % (LL, time.time() - t0))

    if HAVE_LIB:
        t0 = time.time()
        LL = inter_ll_lib(clazz, P0, K0, K1, K2, L, V, P, constants, tdead_ms*1e-3, segments)
        print('Compiled LL: %s [ %.3f sec ]' % (LL, time.time() - t0))

    # NOTE(review): presumably get_matrices_p returns the pair (A, B) of the
    # constraint system -- confirm against qubx.model
    Asys, Bsys = model.constraints_kin.get_matrices_p(K0, K1, K2, L, V, P)
    print('Constraints in:')
    print(Asys)
    print(Bsys)
    Acns, Bcns, Ainv, pars = linear_constraints(Asys, Bsys, K)
    print('Constraints out:')
    print(Acns)
    print(Bcns)
    Ascal, Bscal, Asci, pars = start_at_ones(pars)
    if (pars != 1.0).any():
        print('Non-unit starting par! %s' % (pars,))
        print(Ascal)
        print(Bscal)
        print(Asci)

    mil = HAVE_LIB and inter_ll_lib or inter_ll
    def mil_func(pars):
        # named `free`, not P, so the pressure-index matrix P is what reaches mil
        free = parsToK(pars, Ascal, Asci, Bscal)
        K = ParsToK(free, Acns, Ainv, Bcns)
        new_K0, new_K1, new_K2 = KtoK012(K, K0)
        LL = mil(clazz, P0, new_K0, new_K1, new_K2, L, V, P, constants, tdead_ms*1e-3, segments)
        print('%s %s' % (LL, pars))
        return - LL
    counter = itertools.count()
    def mil_iter(pars):
        print('---> %i <--- %s' % (next(counter), pars))
        free = parsToK(pars, Ascal, Asci, Bscal)
        K = parsToK(free, Acns, Ainv, Bcns)
        print('K %s' % (K.T,))
        new_K0, new_K1, new_K2 = KtoK012(K, K0)
        qubx.model.K012ToModel(new_K0, new_K1, new_K2, model)
        print(model.rates)
        print('------------->    <---')

    pars, fopt, n_iter, funcalls, warnflag = \
        scipy.optimize.fmin(mil_func, pars, maxfun=1000,
                            xtol=.0001, ftol=.0001, maxiter=100,
                            full_output=1, disp=1, retall=0, callback=mil_iter)
    sys.stdout.write('final ')
    mil_iter(pars)





# ----------------------
import struct
from functools import partial

# (c) 2010 Eric L. Frederich
#
# Python implementation of algorithms detailed here...
# from http://www.cygnus-software.com/papers/comparingfloats/comparingfloats.htm

def c_mem_cast(x, f=None, t=None):
    '''
    Do a c-style memory cast: pack x with struct format f, then read the
    same bytes back with struct format t.

    In Python...

    x = 12.34
    y = c_mem_cast(x, 'd', 'q')

    ... should be equivalent to the following in c...

    double    x = 12.34;
    long long y = *(long long*)&x;

    f and t must describe the same number of bytes, otherwise struct.error
    is raised.  Both are required; TypeError is raised if either is missing.
    '''
    if f is None or t is None:
        raise TypeError('c_mem_cast requires both format codes f and t')
    return struct.unpack(t, struct.pack(f, x))[0]

# 'q' (long long) is 8 bytes like 'd' (double); native 'l' is only 4 bytes on some
# platforms, which makes the double <-> long cast fail there with struct.error
dbl_to_lng = partial(c_mem_cast, f='d', t='q')
lng_to_dbl = partial(c_mem_cast, f='q', t='d')
flt_to_int = partial(c_mem_cast, f='f', t='i')
int_to_flt = partial(c_mem_cast, f='i', t='f')

def ulp_diff_maker(converter, negative_zero):
    '''
    Getting the ulp difference of floats and doubles is similar.
    Only difference is the offset and converter.

    converter:      reinterprets a float as a signed integer of the same width
                    (e.g. flt_to_int, dbl_to_lng)
    negative_zero:  the sign-bit mask of that width, as a positive number
                    (0x80000000 for floats, 0x8000000000000000 for doubles)

    Returns the_diff(a, b): how many representable values lie between a and b
    (0 when they are equal; +0.0 and -0.0 count as equal).
    '''
    def ordered(x):
        # Make the integer lexicographically ordered as a twos-complement int.
        # A negative float converts to -negative_zero + magnitude; map it to
        # -magnitude.  (The C original gets this from 32/64-bit wrap-around of
        # "negative_zero - xi"; Python ints do not wrap, so spell it out.)
        xi = converter(x)
        if xi < 0:
            xi = -(negative_zero + xi)
        return xi

    def the_diff(a, b):
        return abs(ordered(a) - ordered(b))

    return the_diff

# double ULP difference
# double ULP difference: doubles viewed as 64-bit integers, sign bit 2**63
dulpdiff = ulp_diff_maker(dbl_to_lng, 0x8000000000000000)
# float  ULP difference: values packed as 32-bit floats, sign bit 2**31
fulpdiff = ulp_diff_maker(flt_to_int, 0x80000000        )

# default to double ULP difference
# (ulpdiff and dulpdiff are the same function object, so the docstring set
#  below is visible through both names)
ulpdiff = dulpdiff
ulpdiff.__doc__ = '''
Get the number of doubles between two doubles.
'''

# ApplyDeadTime treats a duration as shorter than tdead only when
# fulpdiff(duration, tdead) exceeds this many ULPs
MAX_ULP_DIFF = 4