#!/usr/bin/env python

import os
from scipy import stats
from scipy.special import gammaln
from scipy.special import gamma
import numpy as np
#from numba import jit
import pickle

# Global functions
def thProbFX(tt, debug = False):
    """
    Convert linear predictors into multinomial probabilities (softmax).

    Parameters
    ----------
    tt : array_like
        Linear predictor (log-scale weight) for each category.
    debug : bool
        Unused; kept for interface compatibility.

    Returns
    -------
    ndarray
        exp(tt) normalised to sum to 1.
    """
    tt = np.asarray(tt, dtype=float)
    if tt.size == 0:
        # nothing to normalise; mirror the empty input
        return tt
    # Subtracting the maximum leaves the ratios unchanged but keeps np.exp
    # from overflowing to inf (which would produce nan probabilities).
    weights = np.exp(tt - np.max(tt))
    return weights / np.sum(weights)

def logit(x):
    """
    Log-odds of a probability: log(x) - log(1 - x).
    """
    logP = np.log(x)
    logQ = np.log(1 - x)
    return logP - logQ

def inv_logit(x):
    """
    Inverse logit: map a log-odds value (scalar or array) to a probability.

    Computed as exp(-log(1 + exp(-x))) through np.logaddexp, which is
    algebraically equal to exp(x) / (1 + exp(x)) but does not evaluate
    inf / inf (nan) when x is large.
    """
    return np.exp(-np.logaddexp(0.0, np.negative(x)))

def distxy(x1,y1,x2,y2):
    """
    Euclidean distance between the points (x1, y1) and (x2, y2).
    """
    xSq = np.power(x1 - x2, 2)
    ySq = np.power(y1 - y2, 2)
    return np.sqrt(xSq + ySq)

#@jit
def matrixsub(arr1, arr2):
    """
    Pairwise differences between two 1-D arrays.

    Returns an array `out` of shape (len(arr1), len(arr2)) with
    out[y, x] = arr1[y] - arr2[x], stored in arr1's dtype. It is the
    building block of the distance matrix computed by distmat.
    """
    arr1 = np.asarray(arr1)
    arr2 = np.asarray(arr2)
    # A broadcasted outer difference replaces the Python double loop; the
    # cast keeps the result in arr1's dtype as the loop version did.
    return np.subtract.outer(arr1, arr2).astype(arr1.dtype, copy=False)

def distmat(x1, y1, x2, y2):
    """
    Euclidean distance matrix between two sets of points.

    Row k corresponds to point (x1[k], y1[k]) and column m to point
    (x2[m], y2[m]).
    """
    eastDiff = matrixsub(x1, x2)
    northDiff = matrixsub(y1, y2)
    sqDist = eastDiff**2.0 + northDiff**2.0
    return np.sqrt(sqDist)

def initialPStoatTrapCaptFX(params, basicdata, availTrapNights, location,
    g0Param, debug = False):
    """
    Initial per-trap capture probability for a stoat located in cell `location`.

    The nightly capture probability for each trap is
    g0Param * exp(-d2 / params.var2), where d2 is taken from column
    `location` of basicdata.distTrapToCell2; it is compounded over
    availTrapNights. `debug` is unused.
    """
    sqDistToTraps = basicdata.distTrapToCell2[:, location]
    kernel = np.exp(-(sqDistToTraps) / params.var2)
    pMissOneNight = 1. - g0Param * kernel
    pMissAllNights = np.power(pMissOneNight, availTrapNights)
    # a miss probability of exactly 1 is replaced by .9999 so the returned
    # capture probability is never exactly zero
    pMissAllNights = np.where(pMissAllNights == 1., .9999, pMissAllNights)
    return 1 - pMissAllNights


def NpredInitialFX(nsession, N, removeDat, rpara, Npred):
    """
    Fill in the initial predicted population size for sessions 1..nsession-1.

    Npred[k] is rpara[k] times the animals left after removals in session
    k - 1 (floored at zero). Npred is modified in place and also returned;
    Npred[0] is left untouched.
    """
    for nxt in range(1, nsession):
        prev = nxt - 1
        survivors = np.maximum(N[prev] - removeDat[prev], 0)
        Npred[nxt] = rpara[nxt] * survivors
    return Npred

def removeDatFX(nsession, stoat, session):
    """
    Total of `stoat` within each session 0..nsession-1, as an integer array.
    """
    perSession = [np.sum(stoat[session == k]) for k in range(nsession)]
    return np.array(perSession, dtype = int)

def multinomial_pmf(probs, counts):
    """
    Log of the multinomial probability mass function.

    Despite the name, the value returned is on the log scale.

    Parameters
    ----------
    probs : ndarray
        Category probabilities; must sum to 1 (within the tolerance below).
    counts : ndarray
        Observed count per category, same size as probs.

    Returns
    -------
    float
        log( n! / prod(counts!) * prod(probs**counts) ), n = counts.sum().

    Raises
    ------
    ValueError
        If probs does not sum to 1 or the sizes differ.
    """
    from scipy.special import xlogy

    probssum = probs.sum()
    if probssum < 0.999 or probssum > 1.0001:
        raise ValueError("probs must sum to 1")
    if probs.size != counts.size:
        raise ValueError("probs and counts must be the same size")
    logCoef = gammaln(counts.sum() + 1.0) - gammaln(counts + 1.0).sum()
    # xlogy returns 0 where counts == 0 even if probs == 0, whereas
    # np.log(probs) * counts would give -inf * 0 = nan for such cells.
    return logCoef + np.sum(xlogy(counts, probs))


def gamma_pdf(xx, shape, scale):
    """
    Gamma probability density at xx for the given shape and scale.

    The density xx**(shape-1) * exp(-xx/scale) / (gamma(shape) * scale**shape)
    is evaluated on the log scale and exponentiated at the end, so large
    shape values do not overflow gamma(shape) or xx**(shape-1) into nan.
    """
    from scipy.special import xlogy

    # xlogy(a, x) = a * log(x), defined as 0 when a == 0 (shape == 1 at xx == 0)
    logPdf = (xlogy(shape - 1.0, xx) - xx / scale
              - gammaln(shape) - shape * np.log(scale))
    return np.exp(logPdf)



class StoatData(object):
    """
    Latent per-stoat state arrays, one slot per (session, stoat id) pair.

    Every array has params.maxN * basicdata.nsession rows, ordered session by
    session, so `stoatSession == i` selects the params.maxN slots of session i.
    """
    def __init__(self, basicdata, params, debug = False):
        """
        class and functions to create latent variables of stoat location, trapped, trap id, trap probabilities

        Initial values are drawn at random (np.random), so results depend on
        the state of the global NumPy random generator. `debug` is unused.
        """
        # session identifier
        stoatSession = np.repeat(np.array(range(basicdata.nsession)), params.maxN)
        # stoat id in each session
        stoatID = np.tile(np.arange(0, params.maxN), basicdata.nsession)
        # indicator if a stoat is present in a session
        stoatPres = np.zeros(np.multiply(params.maxN, basicdata.nsession))
        # location of a present stoat
        stoatLoc = np.zeros(np.multiply(params.maxN, basicdata.nsession), dtype=int)
        #indicator of present stoats that are removed
        stoatRemove = np.zeros(np.multiply(params.maxN, basicdata.nsession))
        # Trap id capturing a trapped stoat
        stoatTrapID = np.zeros(np.multiply(params.maxN, basicdata.nsession), dtype=int)
        # Probability of a present stoat being captured
        stoatPCapt = np.zeros(np.multiply(params.maxN, basicdata.nsession))
        # Probability of given stoat being captured by the trap that caught it
        stoatTrapPCapt =  np.zeros(np.multiply(params.maxN, basicdata.nsession))
        # array of capture states (1) captured; (2) vulnerable but not capt; (3) not vulnerable
        stoatCaptState = np.zeros((np.multiply(params.maxN, basicdata.nsession), 3), dtype = int)
        # array of capture states probabilities (1) captured; (2) vulnerable but not capt; (3) not vulnerable
        stoatStateProb = np.zeros((np.multiply(params.maxN, basicdata.nsession), 3))
        # loop to get initial values for latent variables
        for i in range(basicdata.nsession):
            # Boolean-mask indexing returns copies, so each tmp* array is
            # edited locally and written back into the full array afterwards.
            tmpPres = stoatPres[stoatSession ==i]
            # the first params.N[i] slots of the session hold the present stoats
            tmpPres[0:params.N[i]] = 1
            stoatPres[stoatSession == i] = tmpPres
            tmpLoc = stoatLoc[stoatSession == i]
            tmpPCapt = stoatPCapt[stoatSession == i]
            tmpCaptState = stoatCaptState[stoatSession == i]
            tmpStateProb = stoatStateProb[stoatSession == i]

            tmpRemove = stoatRemove[stoatSession == i]
            # pick at random which present stoats are the removeDat[i] removed ones
            selStoatRem = np.random.permutation(range(params.N[i]))[0:basicdata.removeDat[i]]
            tmpRemove[selStoatRem] = 1
            stoatRemove[stoatSession == i] = tmpRemove
            # 1 = capturing trap not yet assigned to a removed stoat
            selTrapSprungStillOpen = np.ones(basicdata.removeDat[i])

            tstoatcaptsess = basicdata.trapStoatCapt[basicdata.trapSession == i]
            # ids of the traps that recorded a capture in this session
            # NOTE(review): assumes their number equals removeDat[i] - confirm
            selTrapIDCaptStoat = basicdata.trapID[tstoatcaptsess == 1]

            tmpstoatTrapID = stoatTrapID[stoatSession == i]
            tmpstoatTrapPCapt = stoatTrapPCapt[stoatSession == i]
            tnightsavailSession = basicdata.trapNightsAvail[basicdata.trapSession == i]

            for j in range(params.N[i]):

                # draw one cell for stoat j from the multinomial cell probabilities
                tmpLoc[j] = basicdata.cellID[np.random.multinomial(1, params.thMultiNom, 1).flatten() == 1]

                # per-trap capture probability at that cell, combined over traps
                tmpPTrapCapt = initialPStoatTrapCaptFX(params, basicdata, tnightsavailSession , tmpLoc[j],
                    params.g0, debug = False)
                tmpPCapt[j] = 1. - np.prod(1. - tmpPTrapCapt)
                # get state probabilities
                tmpStateProb[j, 0] = tmpPCapt[j] * (1.0 - params.vt[i])
                tmpStateProb[j, 1] = (1.0 - tmpPCapt[j]) * (1.0 - params.vt[i])
                tmpStateProb[j, 2] = params.vt[i]

                if tmpRemove[j] == 1:

                    # capturing traps that have not been assigned yet
                    sprungTrapsStillOpen = selTrapIDCaptStoat[selTrapSprungStillOpen == 1]

                    pCaptAvailTrap = tmpPTrapCapt[sprungTrapsStillOpen]

                    # assign the unassigned trap with the highest capture probability
                    captID =  sprungTrapsStillOpen[pCaptAvailTrap == max(pCaptAvailTrap)]

                    tmpstoatTrapID[j] = captID[0]
                    tmpstoatTrapPCapt[j] = tmpPTrapCapt[tmpstoatTrapID[j]]
                    # floor the probability at 1e-70 so it is never exactly zero
                    tmpstoatTrapPCapt[j] = np.where(tmpstoatTrapPCapt[j] < 1.00e-70, 1.00e-70, tmpstoatTrapPCapt[j])

                    # mark that trap as used
                    selTrapSprungStillOpen[selTrapIDCaptStoat == captID[0]] = 0
                    # assign captured state
                    tmpCaptState[j, 0] = 1
                else:
                    # Bernoulli draw between the two not-captured states.
                    # Despite the variable names, bernNotVulProb is the
                    # probability of column 1 (vulnerable but not captured)
                    # relative to columns 1 + 2; a draw of 1 sets column 1,
                    # a draw of 0 sets column 2 (not vulnerable).
                    bernNotVulProb = tmpStateProb[j, 1] / (tmpStateProb[j, 1] + tmpStateProb[j, 2])
                    randNotVul = np.random.binomial(1, bernNotVulProb)
                    tmpCaptState[j, 1] = randNotVul
                    tmpCaptState[j, 2] = (1 - randNotVul)

            # write the session's working copies back into the full arrays
            stoatLoc[stoatSession == i] = tmpLoc
            stoatPCapt[stoatSession == i] = tmpPCapt
            stoatTrapID[stoatSession == i] = tmpstoatTrapID
            stoatTrapPCapt[stoatSession == i] = tmpstoatTrapPCapt
            stoatCaptState[stoatSession == i] = tmpCaptState 
            stoatStateProb[stoatSession == i] = tmpStateProb

        self.stoatSession = stoatSession
        self.stoatID = stoatID
        self.stoatPres = stoatPres
        self.stoatLoc = stoatLoc
        self.stoatRemove = stoatRemove
        self.stoatTrapID = stoatTrapID
        self.stoatPCapt = stoatPCapt
        self.stoatTrapPCapt = stoatTrapPCapt
        # the *_s attributes start as independent copies of the arrays above
        self.stoatPCapt_s = stoatPCapt.copy()
        self.stoatTrapPCapt_s = stoatTrapPCapt.copy()
        self.stoatCaptState = stoatCaptState
        self.stoatStateProb = stoatStateProb
        self.stoatStateProb_s = stoatStateProb.copy()


class BasicData(object):
    """
    Read the capture, date, trap and habitat-covariate CSV files and derive
    the per-session and per-trap arrays used by the rest of the model.
    """
    def __init__(self, captFname, newCaptFname, datesFname, trapFname, 
                covFname, maxTrapNights = 14):
        """
        Parameters
        ----------
        captFname : str
            CSV of the original capture records (stored as self.capt6).
        newCaptFname : str
            CSV of the additional capture records (stored as self.newCapt).
        datesFname : str
            CSV of per-session date information (stored as self.dates).
        trapFname : str
            CSV of trap ids and coordinates (stored as self.trap).
        covFname : str
            CSV of grid-cell coordinates and covariates (stored as self.covDat).
        maxTrapNights : int
            Upper limit applied to the session interval when computing
            trap nights.
        """
        self.maxTrapNights = maxTrapNights

        self.capt6 = np.genfromtxt(captFname,  delimiter=',', names=True, 
            dtype=['i8', 'S10', 'f8', 'f8', 'f8', 'f8', 'f8', 'S10', 'S10',
            'i8', 'i8', 'i8', 'S12', 'i8', 'i8', 'i8', 'f8', 'i8', 'f8', 'i8'])

        self.newCapt = np.genfromtxt(newCaptFname, delimiter=',', names=True,
            dtype=['i8', 'S10', 'i8','f8'])  

        self.trap = np.genfromtxt(trapFname, delimiter=',', names=True,
            dtype=['S10', 'f8', 'f8', 'f8', 'f8'])  

        self.stoat = self.capt6['stoat']
        # sessions are made zero-based
        self.session = self.capt6['session'] - 1
        # an availability of 0.25 is recoded to 0.5
        self.avail = np.where(self.capt6['avail']==0.25, 0.5, self.capt6['avail'])

        self.dates = np.genfromtxt(datesFname, delimiter=',', names=True,
            dtype=['S10', 'i8', 'S10', 'i8', 'i8', 'i8', 'S10', 'i8',
            'i8', 'i8', 'f8', 'i8', 'i8'])
        self.month = self.dates['mo']
        self.interval = self.dates['interval']
        # session interval capped at maxTrapNights
        self.intervalSession = np.where(self.interval < (self.maxTrapNights + 1),
            self.interval, self.maxTrapNights)
        self.immPeriod = self.dates['immPeriod']
        self.vPeriod = self.dates['g0Period']
                
        self.getTrapIDFX()

        # run function to add new capture data from 11/13 - 7/15
        self.addNewCaptData()

        self.gettrapSession()

        self.getTrapNightsFX()

        self.covDat = np.genfromtxt(covFname, delimiter=',', names=True,
            dtype=['f8', 'f8', 'i8', 'f8', 'f8', 'f8', 'f8', 'f8', 'f8', 'f8', 'f8', 'f8'])
        self.cellX = self.covDat['x1']
        self.cellY = self.covDat['x2']       
        # coordinates shifted so the minimum is zero
        self.eastCov = self.covDat['x1'] - min(self.covDat['x1'])  
        self.northCov = self.covDat['x2'] - min(self.covDat['x2'])
        self.nw = self.covDat['cellDevNW']
        self.ndev = self.covDat['cellDevN']
        self.terrIndx = self.covDat['terrIndx']
        self.ncell = len(self.eastCov)
        
        # standardised (mean 0, sd 1) covariates
        self.scaleEast = (self.eastCov - np.mean(self.eastCov)) / np.std(self.eastCov)
        self.scaleNorth = (self.northCov - np.mean(self.northCov)) / np.std(self.northCov)
        self.scaleNW = (self.nw - np.mean(self.nw)) / np.std(self.nw)
        self.scaleDevN = (self.ndev - np.mean(self.ndev)) / np.std(self.ndev)
        self.scaleTerrIndx = (self.terrIndx - np.mean(self.terrIndx)) / np.std(self.terrIndx)
        
        self.scaleEast2 = (self.eastCov**2 - np.mean(self.eastCov**2)) / np.std(self.eastCov**2)
        self.scaleNorth2 = (self.northCov**2 - np.mean(self.northCov**2)) / np.std(self.northCov**2)
        ##########################################################
        ##########################################################
        ### change habitat covariates
        # design matrix of habitat covariates, one column per covariate
        self.xdat = np.hstack([np.expand_dims(self.scaleEast,1), 
            np.expand_dims(self.scaleNorth,1)])
        self.nbcov = np.shape(self.xdat)[1]
        ##########################################################
        ##########################################################
        self.removeDat = removeDatFX(self.nsession, self.stoat, self.session) 
        self.cellID = np.arange(0, self.ncell, dtype = int)
        self.nStoatInCellTemplate = np.zeros(self.ncell)

        # squared trap-to-cell distances (traps in rows, cells in columns)
        distTrapToCell = distmat(self.trapX, self.trapY, self.cellX, self.cellY)
        self.distTrapToCell2 = distTrapToCell**2.0

    ##############################################################
    ######################## Basicdata functions
    def getTrapIDFX(self):
        """
        Number the traps 0..ntrap-1 and translate the 'sid' of every capt6
        record into that integer trap id (self.captTrapID).
        """
        self.ntrap = len(self.trap['sid'])
        trapSid = self.trap['sid']
        self.trapID = np.arange(self.ntrap)
        self.trapX = self.trap['easting']
        self.trapY = self.trap['northing']
        captSid = self.capt6['sid']
        nCaptDat = len(captSid)
        self.captTrapID = np.empty(nCaptDat, dtype = int)           # trap ID in capt6 data
        for i in range(nCaptDat):
            # .item() makes the one-match requirement explicit (ValueError
            # otherwise) and avoids the deprecated implicit conversion of a
            # size-1 array to a scalar.
            self.captTrapID[i] = self.trapID[captSid[i] == trapSid].item()

    def addNewCaptData(self):
        """
        add capture data from 11/13 - 7/15

        Builds one record per (new session, trap) pair and appends them to
        self.session, self.captTrapID, self.stoat and self.avail. Pairs with
        no record in newCapt get stoat = 0 and avail = 0. Also sets
        self.nsession.
        """
        newSess = self.newCapt['session'] - 1
        newSid = self.newCapt['sid']
        newStoat = self.newCapt['stoat']
        newAvail = self.newCapt['avail']
        uSess = np.unique(newSess)
        nNewSess = len(uSess)
        trapSid = self.trap['sid']                  # old trap id from trap data
        nAdd = nNewSess * self.ntrap
        # new arrays to populate
        addSess = np.repeat(uSess, self.ntrap)
        addTrapID = np.tile(self.trapID, nNewSess)  # new trap id
        addStoat = np.zeros(nAdd, dtype = int)
        addAvail = np.ones(nAdd)
        for i in range(nAdd):
            # trapID is 0..ntrap-1, so it indexes trapSid directly
            linkTrapSid = trapSid[addTrapID[i]]
            tmpMask = (newSess == addSess[i]) & (newSid == linkTrapSid)
            if np.sum(tmpMask) == 0 :
                addStoat[i] = 0
                addAvail[i] = 0.0
            else:  
                # exactly one matching record is expected; .item() raises
                # ValueError if there are several
                addStoat[i] = newStoat[tmpMask].item()
                addAvail[i] = newAvail[tmpMask].item()
        self.session = np.append(self.session, addSess)
        self.captTrapID = np.append(self.captTrapID, addTrapID)
        self.stoat = np.append(self.stoat, addStoat)
        self.avail = np.append(self.avail, addAvail)
        # records with a capture get availability 0.5
        self.avail[self.stoat == 1] = 0.5
        # np.int was removed in NumPy 1.24; the builtin int is equivalent here
        self.nsession = int(max(self.session + 1))

    def gettrapSession(self):
        """
        Session index for every (session, trap) slot, session-major order.
        """
        sess = np.arange(self.nsession)
        self.trapSession = np.repeat(sess, self.ntrap)

    def getTrapNightsFX(self):
        """
        Fill self.trapNightsAvail (availability times the capped session
        interval) and self.trapStoatCapt (capture record) for every
        (session, trap) slot.
        """
        self.trapNightsAvail = np.zeros(self.nsession * self.ntrap)
        self.trapStoatCapt = np.zeros(self.nsession * self.ntrap)        # stoat captures by trap and session             
        for i in range(self.nsession):
            sessmask = self.trapSession == i
            captIDSession = self.captTrapID[self.session == i]   # trap ID in capt6 data
            availSession = self.avail[self.session == i]         # avail in capt6 data
            tmpTN = np.zeros(self.ntrap)
            tmptrapstoatcapt = np.zeros(self.ntrap)
            stoatSession = self.stoat[self.session == i]
            for j in range(len(captIDSession)):
                tid = self.trapID[self.trapID == captIDSession[j]]
                tmpTN[tid] = availSession[j] * self.intervalSession[i]
                tmptrapstoatcapt[tid] = stoatSession[j]
            self.trapNightsAvail[sessmask] = tmpTN
            self.trapStoatCapt[sessmask] = tmptrapstoatcapt
                
class Params(object):
    """
    Starting values, prior settings and working arrays for the model
    parameters. Attributes ending in `_s` hold the corresponding proposal
    values.
    """
    def __init__(self, basicdata):
        """
        Set all initial values; `basicdata` supplies the period indices,
        covariates, session count, removals and trap-to-cell distances.
        """
        # ---- sigma and the derived 2 * sigma^2 term of the distance kernel
        self.sigma = 320.6
        self.sigma_s = 320.0
        self.sigma_mean = 255
        self.sigma_sd = 60
        self.sigma_search_sd = 4.0
        self.var2 = 2.0 * (self.sigma**2.0)
        self.var2_s = 2.0 * (self.sigma_s**2.0)

        # ---- g0
        self.g0 = .03
        self.g0_s = self.g0 - .0001
        self.g0_alpha = 0.5
        self.g0_beta = 0.5
        self.g0Sd = .2

        # ---- v: prob not vulnerable or catchable - beta priors
        self.v = np.array([0.25, 0.5, 0.10])
        self.v_s = self.v + 0.01
        # per-session values looked up through the session's period index
        self.vt = self.v[basicdata.vPeriod]
        self.vt_s = self.v_s[basicdata.vPeriod]
        self.v_alpha = .5
        self.v_beta = .5
        self.v_search = 0.15

        # ---- immigration
        self.ig = np.array([0.0, 0.0, 0.0, 0.0])
        self.it = self.ig[basicdata.immPeriod]

        # ---- population size per session and growth
        self.N = np.array([460, 170, 120, 65, 120, 54, 62, 50, 40, 8, 8, 8, 10, 49, 28, 20, 44,
                                        19, 40, 32, 19, 40, 44, 30])
        self.maxN = 600
        self.rg = 3.67
        self.rs = 3.7
        self.r_shape = .001
        self.r_scale = 1000.0
        # growth multiplier applies only to sessions whose month is 11
        self.rpara = np.where(basicdata.month==11,self.rg,1)
        self.rpara_s = np.where(basicdata.month==11,self.rs,1)
        # float copy of N, then sessions 1.. overwritten by the population model
        self.Npred = NpredInitialFX(basicdata.nsession, self.N, basicdata.removeDat,
            self.rpara, self.N * 1.0)
        self.Npred_s = self.Npred.copy()

        # ---- habitat coefficients and multinomial cell probabilities
        self.b = np.array([0.018, 0.06])
        self.bs = np.random.normal(self.b, .01)
        self.nbcov = len(self.b)
        self.bPrior = 0
        self.bPriorSD = np.sqrt(100)
        self.lth = np.dot(basicdata.xdat,self.b)
        self.thMultiNom = thProbFX(self.lth, debug = False)
        self.lth_s = np.dot(basicdata.xdat,self.bs)
        self.thMultiNom_s = thProbFX(self.lth_s, debug = False)

        # ---- sampling helpers
        self.nsample = np.array([-1, 1])
        self.datseq = np.arange(self.maxN)

        # ---- per-session log-likelihood storage (uninitialised)
        nSess = basicdata.nsession
        self.llikTh = np.empty(nSess)
        self.llikTh_s = np.empty(nSess)
        self.llikR = np.empty(nSess - 1)
        self.llikR_s = np.empty(nSess - 1)
        self.llikImm = np.empty(nSess - 1)
        self.llikImm_s = np.empty(nSess - 1)
        self.llikg0Sig = np.empty(nSess)
        self.llikg0Sig_s = np.empty(nSess)

        # ---- exp(-d^2 / var2) for every trap/cell pair
        self.expTermMat =  np.exp(-(basicdata.distTrapToCell2) / self.var2)
        self.expTermMat_s =  np.exp(-(basicdata.distTrapToCell2) / self.var2_s)

class MCMC(object):
    def __init__(self, params, stoatdata, basicdata):
        self.ngibbs = 2500
        self.bgibbs = np.zeros([self.ngibbs, len(params.b)])
        self.rgibbs = np.zeros(self.ngibbs)
        self.Ngibbs = np.zeros([self.ngibbs, len(params.N)])
        self.igibbs = np.zeros((self.ngibbs, 2))
        self.ggibbs = np.zeros(self.ngibbs)
        self.siggibbs = np.zeros(self.ngibbs)
        self.vgibbs = np.zeros((self.ngibbs, 3))
        self.deviancegibbs = np.zeros(self.ngibbs)

        self.thinrate = 250 
        self.burnin = 100000     # 2000
        self.keepseq = np.arange(self.burnin, ((self.ngibbs * self.thinrate) + self.burnin), 
            self.thinrate)
        self.params = params
        self.stoatdata = stoatdata
        self.basicdata = basicdata

        # storage array for prob of trapping data
        self.pTrappingData = 0.0    #np.empty(self.basicdata.nsession)
        self.pTrappingData_s = 0.0  #np.empty(self.basicdata.nsession)
        self.pState = np.zeros(3)   # storage array for proposed state probabilities.

    def PCaptStoatFX(self, availTrapNights, eTermMat, g0_i, 
            debug = False):   # prob stoat capt in 1 of many traps
        pNoCapt = 1.0 - g0_i * eTermMat
        pNoCaptNights = pNoCapt**(availTrapNights)
        pNoCaptNights = np.where(pNoCaptNights == 1.0, 0.9999, pNoCaptNights)
        pNoCaptNightsTraps = pNoCaptNights.prod(axis = 0)
        return 1.0 - pNoCaptNightsTraps

    def PStoatTrapCaptFX(self, availTrapNights, eTermMat, g0_i, debug = False):      # prob that stoat was capt in trap
        pNoCapt = 1.0 - g0_i * eTermMat
        pNoCaptNights = pNoCapt**(availTrapNights)
        pNoCaptNights = np.where(pNoCaptNights == 1.0, 0.9999, pNoCaptNights)
        return 1.0 - pNoCaptNights


    def getBetaFX(self):
        a = self.params.g0 * ((self.params.g0 * (1.0 - self.params.g0)) / self.params.g0Sd**2.0 - 1.0)
        b = (1.0 - self.params.g0) * ((self.params.g0 * (1.0 - self.params.g0)) / self.params.g0Sd**2.0 - 1.0)
        return np.random.beta(a, b, size = None)

########
########       Block of functions for updating N
########
  
    def N_stoatdata_updateFX(self):
        """
        One sweep over all sessions: update N, update stoat locations, then
        refresh the theta and g0/sigma log-likelihoods for the session.
        Resets self.devianceSum before the sweep.
        """
        self.devianceSum = 0.0
        for sess in range(self.basicdata.nsession):
            # 1. propose a new N with its stoat data and both log-likelihoods
            proposal = self.proposeNFX(sess, debug = False)
            (Ns, sessionMask, stoatLoc_s, stoatPres_s, stoatPCapt_s, llik, llik_s,
                availTrapNights, presOnlyMask, presOnlyMask_s, trapSessionMask,
                stoatCaptState_s, stoatStateProb_s) = proposal

            # 2. accept or reject; returns the presence mask now in force
            presOnlyMask = self.N_PnowPnewFX(Ns = Ns, i = sess, llik = llik,
                llik_s = llik_s, presOnlyMask = presOnlyMask,
                presOnlyMask_s = presOnlyMask_s, stoatLoc_s = stoatLoc_s,
                stoatPres_s = stoatPres_s, stoatPCapt_s = stoatPCapt_s,
                sessionMask = sessionMask, stoatCaptState_s = stoatCaptState_s,
                stoatStateProb_s = stoatStateProb_s, debug = False)

            # 3. update stoat locations and capture probabilities
            locUpdate = self.updateStoatLocFX(i = sess, sessionMask = sessionMask,
                availTrapNights = availTrapNights, presOnlyMask = presOnlyMask,
                debug = False)
            (stoatLocSession, stoatPCaptSession, stoatRemoveSession, stoatPresSession,
                stoatTrapIDSession, stoatTrapPCaptSession, remMask,
                nStoatCellSession, stoatStateProbSession,
                stoatCaptStateSession) = locUpdate

            # 4. theta likelihood for this session
            self.thetaLikelihoodFX(sess, stoatLocSession, presOnlyMask, sessionMask)

            # 5. log-likelihoods for g0 and g0_s
            self.g0SigLLikFX(sess, availTrapNights, stoatLocSession, stoatTrapIDSession,
                stoatPCaptSession, stoatTrapPCaptSession, sessionMask, presOnlyMask,
                stoatRemoveSession, remMask, stoatStateProbSession,
                stoatCaptStateSession)



#########
########  End updating N and stoatdata
########

    def proposeNFX(self, i, debug = False):    # get proposed Ns and assoc stoat data   
                                                                        # use FX in NupdateFX
        """
        Propose a new population size Ns for session i together with the
        matching proposed stoat arrays, and compute the multinomial
        log-likelihood of the current and the proposed state.

        Returns the 13-tuple (Ns, sessionMask, stoatLoc_s, stoatPres_s,
        stoatPCapt_s, llik, llik_s, availTrapNights, presOnlyMask,
        presOnlyMask_s, trapSessionMask, stoatCaptState_s, stoatStateProb_s).
        `debug` is unused.
        """
        # random step taken from params.nsample
        nRand = np.random.permutation(self.params.nsample)[0]
        rDat = self.basicdata.removeDat[i]
        Ns = self.params.N[i] + nRand
        # The following corrections are applied in order:
        # floor at 1; step up instead when Ns does not exceed the removals;
        # never propose the current N; step down instead when above maxN.
        Ns = np.where(Ns <= 0, 1, Ns)
        Ns = np.where(Ns <= rDat, self.params.N[i] + 1, Ns)
        Ns = np.where(Ns == self.params.N[i], self.params.N[i] + 1, Ns)
        Ns = np.where(Ns >  self.params.maxN, self.params.N[i] - 1, Ns)
        Ns = np.where(Ns == self.params.N[i], self.params.N[i] - 1, Ns)
        # for the first session a proposal below 330 is replaced by N + 1
        if (i == 0) & (Ns < 330):
            Ns = self.params.N[i] + 1
        sessionMask = self.stoatdata.stoatSession == i

        # boolean-mask indexing yields copies of the session's rows
        stoatRemoveSession = self.stoatdata.stoatRemove[sessionMask]
        stoatPresSession = self.stoatdata.stoatPres[sessionMask]
        stoatPCaptSession = self.stoatdata.stoatPCapt[sessionMask]
        presOnlyMask = stoatPresSession == 1                    #mask of pres only
        stoatLoc_s = self.stoatdata.stoatLoc[sessionMask].copy()
        stoatPres_s = stoatPresSession.copy()
        stoatPCapt_s = stoatPCaptSession.copy()
        # states and state probabilities for session i
        stoatCaptStateSession = self.stoatdata.stoatCaptState[sessionMask]
        stoatStateProbSession = self.stoatdata.stoatStateProb[sessionMask]
        stoatStateProb_s = stoatStateProbSession.copy()
        stoatCaptState_s = stoatCaptStateSession.copy()

        trapSessionMask = self.basicdata.trapSession == i
        # column vector (one row per trap) so it broadcasts against expTermMat columns
        availTrapNights = np.expand_dims(self.basicdata.trapNightsAvail[trapSessionMask], 1)         
                                                            # get proposed stoat data for Ns
        (presOnlyMask_s, stoatPres_s, stoatLoc_s, stoatPCapt_s, stoatStateProb_s,
            stoatCaptState_s)  = self.proposeStoatFX(Ns, i, stoatPres_s, stoatLoc_s,
            stoatPCapt_s, availTrapNights, stoatRemoveSession, 
            stoatPresSession, presOnlyMask, stoatStateProb_s,
            stoatCaptState_s,  debug = False)

        # get multinomial log likelihood for N and Ns
        llik = self.NMultiNomLik(self.params.N[i], stoatCaptStateSession[presOnlyMask],
            stoatStateProbSession[presOnlyMask])
        llik_s = self.NMultiNomLik(Ns, stoatCaptState_s[presOnlyMask_s],
            stoatStateProb_s[presOnlyMask_s])

        return (Ns, sessionMask, stoatLoc_s, stoatPres_s, stoatPCapt_s, llik, llik_s,
            availTrapNights, presOnlyMask, presOnlyMask_s, trapSessionMask,
            stoatCaptState_s, stoatStateProb_s)


    def N_PnowPnewFX(self, Ns, i, llik, llik_s, presOnlyMask, presOnlyMask_s,
        stoatLoc_s, stoatPres_s, stoatPCapt_s, sessionMask, 
        stoatCaptState_s, stoatStateProb_s, debug = False):

        if i == 0:
            Npred2_s = self.npredFX(i, Ns, self.params.rpara[1], 0.0)     # it[1]
            poisNow = np.log(stats.poisson.pmf(self.params.N[1], self.params.Npred[1]))
            poisNew = np.log(stats.poisson.pmf(self.params.N[1], Npred2_s))
            pnow = llik + poisNow
            pnew = llik_s + poisNew
            self.params.Npred[0] = self.params.N[0]

        if (i > 0) and (i < (self.basicdata.nsession - 1)):
            self.params.Npred[i] = self.npredFX((i - 1), self.params.N[i - 1], 
                self.params.rpara[i], 0.0)                                #it[i]
            Npred2_s = self.npredFX(i, Ns, self.params.rpara[i+1], 0.0)         #it[i+1]  
            dpoisN1 =  np.log(stats.poisson.pmf(self.params.N[i], self.params.Npred[i]))
            dpoisN2 =  np.log(stats.poisson.pmf(self.params.N[i + 1], self.params.Npred[i + 1]))
            pnow = llik + dpoisN1 + dpoisN2

            dpoisN1_s = np.log(stats.poisson.pmf(Ns, self.params.Npred[i]))
            dpoisN2_s = np.log(stats.poisson.pmf(self.params.N[i + 1], Npred2_s))
            pnew = llik_s + dpoisN1_s + dpoisN2_s

        if i == (self.basicdata.nsession -1):
            self.params.Npred[i] = self.npredFX((i - 1), self.params.N[i - 1], 
                self.params.rpara[i], 0.0)
            pnow = llik + np.log(stats.poisson.pmf(self.params.N[i], self.params.Npred[i]))
            pnew = llik_s + np.log(stats.poisson.pmf(Ns, self.params.Npred[i])) 
            Npred2_s = 1.0

        rValue = np.exp(pnew - pnow)        # calc importance ratio            
        zValue = np.random.uniform(0,1, size = None)
        if rValue > zValue:
            self.devianceSum += pnew       
            if i < (self.basicdata.nsession -1):
                self.params.Npred[i + 1] = Npred2_s
            self.params.N[i] = Ns
            self.stoatdata.stoatLoc[sessionMask] = stoatLoc_s
            self.stoatdata.stoatPres[sessionMask] = stoatPres_s
            self.stoatdata.stoatPCapt[sessionMask] = stoatPCapt_s
            presOnlyMask = presOnlyMask_s.copy()
            self.stoatdata.stoatCaptState[sessionMask] = stoatCaptState_s
            self.stoatdata.stoatStateProb[sessionMask] = stoatStateProb_s
        else:
            self.devianceSum += pnow
        return presOnlyMask


    def npredFX(self, ii, nn, rr, Immigr):                        # get Npred and Npred.s for new N
        Ntmp = nn - self.basicdata.removeDat[ii]
        npredout = (rr * (Ntmp + Immigr))
        return npredout
                                                                
                                                                # FX nested in proposeNFX
    def proposeStoatFX(self, Ns, i, stoatPres_s, stoatLoc_s,
        stoatPCapt_s, availTrapNights, stoatRemoveSession, 
        stoatPresSession, presOnlyMask, stoatStateProb_s,
        stoatCaptState_s, debug = False):
        """
        Build the proposed stoat arrays for a proposed population size Ns in
        session i by adding or dropping stoats relative to params.N[i].

        The `*_s` arrays passed in are modified in place and returned as
        (presOnlyMask_s, stoatPres_s, stoatLoc_s, stoatPCapt_s,
        stoatStateProb_s, stoatCaptState_s). `debug` is unused.
        """
        
        if Ns > self.params.N[i]:
            # slots of the session not currently occupied by a present stoat
            potentialAddIndx = self.params.datseq[presOnlyMask == False]
            nadd = Ns - self.params.N[i]            # number stoats added
            # uniformly random cell for each added stoat
            randindx = np.random.randint(0, self.basicdata.ncell, nadd)
            add_id = np.random.choice(potentialAddIndx, nadd, replace = False)
            stoatPres_s[add_id] = 1
            stoatLoc_s[add_id] = randindx
            pCaptNew = self.PCaptStoatFX(availTrapNights, self.params.expTermMat[:,randindx],
                self.params.g0)
            stoatPCapt_s[add_id] = pCaptNew
            # getStateProbs fills self.pState and draws self.vulnState
            # NOTE(review): pState holds one stoat's probabilities, so this
            # presumably relies on nadd == 1 - confirm
            self.getStateProbs(pCaptNew, self.params.vt[i]) 
            stoatStateProb_s[add_id] = self.pState
            # added stoats are never in the captured state (column 0)
            stoatCaptState_s[add_id, 1] = self.vulnState            
            stoatCaptState_s[add_id, 2] = 1 - self.vulnState           

        if Ns < self.params.N[i]:
            nrem = self.params.N[i] - Ns
            # only present stoats that were not removed may be dropped
            removeSessionMask = stoatRemoveSession == 0
            stoatPresMask = stoatPresSession == 1
            potentialRemoveMask = stoatPresMask & removeSessionMask
            potentialRemove = self.params.datseq[potentialRemoveMask]        # select from 0:maxN
            remIndx = np.random.choice(potentialRemove, nrem, replace = False)
            # clear every latent value of the dropped stoats
            stoatLoc_s[remIndx] = 0
            stoatPres_s[remIndx] = 0
            stoatPCapt_s[remIndx] = 0.0
            stoatStateProb_s[remIndx] = 0.0
            stoatCaptState_s[remIndx] = 0

        presOnlyMask_s = stoatPres_s == 1
        return (presOnlyMask_s, stoatPres_s, stoatLoc_s, stoatPCapt_s, stoatStateProb_s,
            stoatCaptState_s) 

    def getStateProbs(self, pcapt, vpara):
        """
        get multinomial probs of 3 states
        """
        self.pState[0] = pcapt * (1.0 - vpara)
        self.pState[1] = (1.0 - pcapt) * (1.0 - vpara)
        self.pState[2] = vpara
        bernVulnProb = self.pState[1] / (self.pState[1] + self.pState[2])
        self.vulnState = np.random.binomial(1, bernVulnProb)


    def stateProb_Multi_Obs(self, pcapt, vpara):
        """
        get multinomial probs of 3 states for multiple observations
        """
        probArray = np.zeros((len(pcapt), 3))
        probArray[:, 0] = pcapt * (1.0 - vpara)
        probArray[:, 1] = (1.0 - pcapt) * (1.0 - vpara)
        probArray[:, 2] = vpara
        return probArray

    @staticmethod
    def NMultiNomLik(n, captState, stateProb):     # FX get multinomial log lik
        stateTotals = np.sum(captState, axis = 0) 
        comboTerm = gammaln(n + 1) - gammaln(stateTotals[1] + 1) - gammaln(stateTotals[2] + 1)
        prod1 = captState * stateProb
        sum1 = np.sum(prod1, axis = 1)
        lProb = np.sum(np.log(sum1))
        llik_term = comboTerm + lProb
        return llik_term        

    @staticmethod
    def NLogLikFX(n, remd, captp, sessrem, debug=False):     # FX get binomial log lik
                                                             # FX nested in ProposeNFX
        comboTerm = gammaln(n + 1) - gammaln(n - remd + 1)
        first = captp * sessrem
        second = (1.0 - captp) * (1.0 - sessrem)
        binProball = first + second
        lProb = np.sum(np.log(binProball))
        llik_term = comboTerm + lProb
        return llik_term


    @staticmethod
    def ZLoc_MultiNomLik(captState, stateProb):     # FX get multinomial log lik
        prod1 = captState * stateProb
        sum1 = np.sum(prod1, axis = 1)
        lProb = np.sum(np.log(sum1))
        return lProb


#########
#########           End block of functions for updating N


#######
#######            Begin functions for updating stoatLoc
#######
    def updateStoatLocFX(self, i, sessionMask, availTrapNights,
        presOnlyMask, debug = False):
        """
        Metropolis update of the cell locations of a random subset of stoats
        in session i.

        ceil(0.2 * N[i]) individuals are drawn without replacement from those
        selected by presOnlyMask and given new cells drawn uniformly from
        [0, ncell).  Their capture probabilities and state probabilities are
        recomputed, and the whole set of moves is accepted or rejected in one
        step using the sum of three log terms: capture-state likelihood,
        location multinomial likelihood and the stoat-trap capture
        probabilities of moved individuals that were removed.  On acceptance
        the session rows of self.stoatdata (stoatLoc, stoatPCapt,
        stoatTrapPCapt, stoatStateProb) are overwritten.

        i: session index (used for self.params.N[i] and self.params.vt[i]).
        sessionMask: selects the rows of the stoatdata arrays for session i.
        availTrapNights: trap nights; indexed here by trap id.
        presOnlyMask: mask applied to self.params.datseq and to the session
            arrays -- presumably flags individuals present in the session;
            verify against caller.
        debug: not used in this method.

        Returns a 10-tuple of session-level arrays (location, capture prob,
        removal flag, presence, trap id, stoat-trap capture prob, removal
        mask, stoats per cell, state probs, capture states), reflecting the
        accepted state.
        """

        stoatPresSession = self.stoatdata.stoatPres[sessionMask]
        stoatLocSession = self.stoatdata.stoatLoc[sessionMask]
        stoatPCaptSession = self.stoatdata.stoatPCapt[sessionMask]
        stoatRemoveSession = self.stoatdata.stoatRemove[sessionMask]
        remMask = stoatRemoveSession == 1               # mask only stoats removed
        # "_s" copies hold the proposed values
        stoatPCapt_s = stoatPCaptSession.copy()
        stoatLoc_s = stoatLocSession.copy()
        stoatTrapIDSession = self.stoatdata.stoatTrapID[sessionMask]
        stoatTrapPCaptSession = self.stoatdata.stoatTrapPCapt[sessionMask]
        # states and state probabilities for session i
        stoatCaptStateSession = self.stoatdata.stoatCaptState[sessionMask]
        stoatStateProbSession = self.stoatdata.stoatStateProb[sessionMask]
        stoatStateProb_s = stoatStateProbSession.copy()
        
        # number of individuals to move: 20% of N[i], rounded up
        nChangeLoc = int(np.ceil(self.params.N[i] * .20))
        potentialChangeID = self.params.datseq[presOnlyMask]
        changeID =  np.random.choice(potentialChangeID, nChangeLoc, replace = False)

        # changeMask is built over datseq but applied to session arrays below,
        # so datseq is assumed to have the session arrays' length -- TODO confirm
        # NOTE(review): np.in1d is deprecated in NumPy 2.x in favour of np.isin
        changeMask = np.in1d(self.params.datseq, changeID)
        # proposed cells, drawn uniformly over all cells
        Loc_s = np.random.randint(self.basicdata.ncell,size = nChangeLoc)
        stoatLoc_s[changeMask] = Loc_s

        # capture probability at the proposed cells (same order as Loc_s)
        pCaptNew = self.PCaptStoatFX(availTrapNights, self.params.expTermMat[:, Loc_s],
            self.params.g0)

        stoatPCapt_s[changeMask] = pCaptNew

        stoatStateProb_s[changeMask] = self.stateProb_Multi_Obs(pCaptNew, self.params.vt[i])

        # get multinomial log likelihood
        # (current vs proposed state probabilities, moved individuals only)
        Z_llik = self.ZLoc_MultiNomLik(stoatCaptStateSession[changeMask],
            stoatStateProbSession[changeMask])
        Z_llik_s = self.ZLoc_MultiNomLik(stoatCaptStateSession[changeMask],
            stoatStateProb_s[changeMask])

        # location likelihood of current and proposed cell counts
        (Loc_llik, nStoatCellSession) = self.locMultiNom(stoatLocSession, presOnlyMask, 
            self.params.thMultiNom, debug = False)
        (Loc_llik_s, nStoatCellSession_s) = self.locMultiNom(stoatLoc_s, presOnlyMask, 
            self.params.thMultiNom, debug = False)

        # moved individuals that were also removed: their stoat-trap capture
        # probability depends on the new cell
        changeRemoveMask = remMask & changeMask
        stoatTrapPCaptChangeRemove = stoatTrapPCaptSession[changeRemoveMask] 

        Loc_sChangeRemove = stoatLoc_s[changeRemoveMask]
        tid = stoatTrapIDSession[changeRemoveMask]

        # exp term for each (trap, proposed cell) pair
        etermsess = self.params.expTermMat[tid, Loc_sChangeRemove]

        tnightsavail = availTrapNights[tid]
        pCaptTrapNew = self.PStoatTrapCaptFX(tnightsavail, etermsess, 
            self.params.g0, debug = False)
        stoatTrapPCaptChangeRemove_s = pCaptTrapNew

        Trap_llik = np.sum(np.log(stoatTrapPCaptChangeRemove))
        Trap_llik_s = np.sum(np.log(stoatTrapPCaptChangeRemove_s))

        pnow = Z_llik + Loc_llik + Trap_llik
        pnew = Z_llik_s + Loc_llik_s + Trap_llik_s

        rValue = np.exp(pnew - pnow)        # calc importance ratio
        zValue = np.random.uniform(0,1, size = None)

        if rValue > zValue:
            # accept: write proposal into the class data and into the local
            # session arrays that are returned
            self.stoatdata.stoatLoc[sessionMask] = stoatLoc_s
            self.stoatdata.stoatPCapt[sessionMask] = stoatPCapt_s
            stoatPCaptSession = stoatPCapt_s.copy()
            stoatLocSession = stoatLoc_s.copy()
            stoatTrapPCaptSession[changeRemoveMask] = stoatTrapPCaptChangeRemove_s
            self.stoatdata.stoatTrapPCapt[sessionMask] = stoatTrapPCaptSession
            nStoatCellSession = nStoatCellSession_s
            self.stoatdata.stoatStateProb[sessionMask] = stoatStateProb_s
            stoatStateProbSession = stoatStateProb_s.copy()

        return (stoatLocSession, stoatPCaptSession, stoatRemoveSession, stoatPresSession, 
            stoatTrapIDSession, stoatTrapPCaptSession,remMask, nStoatCellSession,
            stoatStateProbSession, stoatCaptStateSession)


    def locMultiNom(self, stoatLocTest, presOnlyMask, thmn, debug = False):
        """
        Multinomial density of the number of stoats per cell.

        Counts, for every one of the ncell cells, how many entries of
        stoatLocTest[presOnlyMask] fall in it and passes the counts with the
        cell probabilities thmn to multinomial_pmf (defined elsewhere in this
        file; callers treat its result as a log density).

        Returns (density value, counts per cell).  debug is not used.
        """
        occupiedCells = stoatLocTest[presOnlyMask]
        countPerCell = np.bincount(occupiedCells, minlength = self.basicdata.ncell)
        return (multinomial_pmf(thmn, countPerCell), countPerCell)
        

#######
######      End function for updating stoat location



#######
###########         begin Beta - theta update ##########
#####
    def thetaLikelihoodFX(self, i, stoatLocSession, presOnlyMask, sessionMask):
        """
        Store the session-i location likelihood under the current and the
        proposed cell probabilities.

        The per-cell stoat counts are evaluated with multinomial_pmf against
        params.thMultiNom (result in params.llikTh[i]) and against
        params.thMultiNom_s (result in params.llikTh_s[i]).  sessionMask is
        not used.
        """
        pars = self.params
        cellCounts = np.bincount(stoatLocSession[presOnlyMask],
            minlength = self.basicdata.ncell)
        for store, cellProb in ((pars.llikTh, pars.thMultiNom),
                (pars.llikTh_s, pars.thMultiNom_s)):
            store[i] = multinomial_pmf(cellProb, cellCounts)


    def betaUpdateFX(self):
        """
        Update all betas at the same time with one Metropolis step, then
        draw the next proposal.

        The current (b) and proposed (bs) coefficients are compared through
        the summed session likelihoods llikTh / llikTh_s plus normal log
        priors.  If accepted, b, lth, llikTh and thMultiNom take the proposed
        values.  Afterwards a new bs is drawn around b (sd 0.02) and lth_s /
        thMultiNom_s are recomputed from it.
        """
        pars = self.params
        logPriorNow = stats.norm.logpdf(pars.b, pars.bPrior, pars.bPriorSD)
        logPriorNew = stats.norm.logpdf(pars.bs, pars.bPrior, pars.bPriorSD)

        logPostNow = np.sum(pars.llikTh) + np.sum(logPriorNow)
        logPostNew = np.sum(pars.llikTh_s) + np.sum(logPriorNew)
        acceptRatio = np.exp(logPostNew - logPostNow)
        if acceptRatio > np.random.uniform(0, 1.0, size = None):
            pars.b = pars.bs.copy()
            pars.lth = pars.lth_s.copy()
            pars.llikTh = pars.llikTh_s.copy()
            pars.thMultiNom = pars.thMultiNom_s.copy()

        # next proposal, centred on the (possibly just updated) current betas
        pars.bs = np.random.normal(pars.b, .02)
        pars.lth_s = np.dot(self.basicdata.xdat, pars.bs)
        pars.thMultiNom_s = thProbFX(pars.lth_s)



#######
###########         begin r update ##########
#####
    def rLLikFX(self):
        """
        get llik for rg and rs
        i in [0, nsession - 1]
        """
        for i in range(self.basicdata.nsession - 1):
            self.params.llikR[i] = np.log(stats.poisson.pmf(self.params.N[i+1], self.params.Npred[i+1]))
            self.params.Npred_s[i + 1] = self.npredFX(i, self.params.N[i], self.params.rpara_s[i + 1],
                0.0)
            self.params.llikR_s[i] = np.log(stats.poisson.pmf(self.params.N[i+1], self.params.Npred_s[i+1]))

    def rUpdateFX(self):
        """
        Metropolis update of the growth parameter rg, followed by a new
        proposal rs.

        rLLikFX refreshes llikR / llikR_s first.  Current and proposed values
        are compared on summed log likelihood plus the log of gamma_pdf
        (shape r_shape, scale r_scale).  On acceptance rg, rpara and Npred
        take the proposed values.  The next rs is drawn on the log scale
        around rg (sd 0.16); rpara_s equals rs where month == 11 and 1.0
        elsewhere.
        """
        pars = self.params
        self.rLLikFX()

        logPostNow = np.sum(pars.llikR) + np.log(gamma_pdf(pars.rg,
            pars.r_shape, (pars.r_scale)))
        logPostNew = np.sum(pars.llikR_s) + np.log(gamma_pdf(pars.rs,
            pars.r_shape, (pars.r_scale)))
        acceptRatio = np.exp(logPostNew - logPostNow)

        if acceptRatio > np.random.uniform(0, 1.0, size = None):
            pars.rg = pars.rs
            pars.rpara = pars.rpara_s.copy()
            pars.Npred = pars.Npred_s.copy()

        # log-normal random walk for the next proposal
        pars.rs = np.exp(np.random.normal(np.log(pars.rg), 0.16, size = None))
        pars.rpara_s = np.where(self.basicdata.month==11, pars.rs, 1.0)

#######
###########         begin I update ##########
#####
    def immLLikFX(self):
        """
        get llik for ig and is (immigration parameter)
        """
        for i in range(self.basicdata.nsession - 1):
            self.params.llikImm[i] = np.log(stats.poisson.pmf(self.params.N[i+1], self.params.Npred[i+1]))
            self.params.Npred_s[i + 1] = self.npredFX(i, self.params.N[i], self.params.rpara[i + 1],
                self.params.it_s[i + 1])
            self.params.llikImm_s[i] = np.log(stats.poisson.pmf(self.params.N[i+1], self.params.Npred_s[i+1]))
    
    def immUpdateFX(self):
        """
        Metropolis update of the immigration parameter ig, followed by a new
        proposal i_s.

        immLLikFX refreshes llikImm / llikImm_s first.  Only the first two
        elements of ig / i_s enter the gamma prior (shape imm_shape, scale
        imm_scale) and are re-proposed.  On acceptance ig, it and Npred take
        the proposed values.  The new proposal is a log-scale random walk
        (sd 0.6) and it_s is i_s indexed by basicdata.immPeriod.
        """
        pars = self.params
        self.immLLikFX()

        logPriorNow = np.log(gamma_pdf(pars.ig[:2],
            pars.imm_shape, (pars.imm_scale)))
        logPostNow = np.sum(pars.llikImm) + np.sum(logPriorNow)
        logPriorNew = np.log(gamma_pdf(pars.i_s[:2],
            pars.imm_shape, (pars.imm_scale)))
        logPostNew = np.sum(pars.llikImm_s) + np.sum(logPriorNew)

        acceptRatio = np.exp(logPostNew - logPostNow)
        if acceptRatio > np.random.uniform(0.0, 1.0, size = None):
            pars.ig = pars.i_s.copy()
            pars.it = pars.it_s.copy()
            pars.Npred = pars.Npred_s.copy()

        # propose new values for the first two immigration parameters
        pars.i_s[:2] = np.exp(np.random.normal(np.log(pars.ig[:2]), 0.6))
        # expand the proposal to one value per session
        pars.it_s = pars.i_s[self.basicdata.immPeriod]

#######
###########         begin g0 and Sigma  update ##########
#####

    def g0Sig_PStoatCaptFX(self, availTrapNights, eterm, g0Param, debug = False):
        """
        Probability of NOT being captured over the available trap nights.

        Computes (1 - g0Param * eterm) ** availTrapNights and replaces every
        value >= 0.9999 by 0.9999.  Note that, despite the name, the
        no-capture probability is what is returned.  debug is not used.
        """
        ceiling = 0.9999
        pMissAllNights = (1.0 - g0Param * eterm)**(availTrapNights)
        return np.where(pMissAllNights >= ceiling, ceiling, pMissAllNights)


    @staticmethod
    def g0Sig_ProbsFX(captp, sessrem, debug=False):     # FX get binomial log lik
        first = captp * sessrem
        second = (1.0 - captp) * (1.0 - sessrem)
        binProball = first + second  
        llik_g0Sig = np.sum(np.log(binProball))
        return llik_g0Sig         

    def getPTrappingData(self, i, availTrapNights, eterm, g0_i, pNoCaptAll_s):
        """
        Accumulate the probability of the session-i trapping data (captures
        per trap, without predator identity) under proposed and current
        parameters.

        Proposed: capture probabilities are 1 - prod(pNoCaptAll_s, axis 1),
        normalised to sum to one.  Current: PCaptAllPreds_FX(availTrapNights,
        eterm, g0_i), normalised the same way.  Both are evaluated with
        multinomial_pmf against basicdata.trapStoatCapt for the session and
        added to self.pTrappingData_s / self.pTrappingData.
        """
        inSession = self.basicdata.trapSession == i
        captPerTrap = self.basicdata.trapStoatCapt[inSession]

        # proposed parameters
        pCaptProposed = 1.0 - np.prod(pNoCaptAll_s, axis = 1)
        self.pTrappingData_s += multinomial_pmf(
            pCaptProposed / np.sum(pCaptProposed), captPerTrap)

        # current parameters
        pCaptCurrent = self.PCaptAllPreds_FX(availTrapNights,  eterm, g0_i)
        self.pTrappingData += multinomial_pmf(
            pCaptCurrent / np.sum(pCaptCurrent), captPerTrap)

    def getTrapMultinomial(self, i, availTrapNights, pNoCaptAll_s, remMask, remPresMask,
                presLoc, stoatTrapPCaptSession, stoatTrapPNoCapt_s):
        """
        Get probabilities of trapping data given parameters and location
        The prob that traps catch the specific predators that they did 
        """
        stoatTrapPairProb = stoatTrapPCaptSession[remMask]      # prob capt for stoat-trap pairs of stoats captured
        stoatTrapPairProb_s = 1.0 - stoatTrapPNoCapt_s      # proposed prob capt for stoat-trap pairs

        # get sum pCapt across traps for all stoats that were captured
        ######################
        # current go and sigma
        tempNoCaptProb = self.g0Sig_PStoatCaptFX(availTrapNights, 
                self.params.expTermMat[:, presLoc],self.params.g0)
        sumCaptProb = np.sum((1.0 - tempNoCaptProb[:, remPresMask]), axis = 0)      # sum pCapt across traps
        multiNomPTrap = stoatTrapPairProb / sumCaptProb
        # current stoat trap multinomial prob
        self.params.LL_multiNomPTrap += np.sum(np.log(multiNomPTrap))
        #######################
        # proposed go and sigma
        sumCaptProb_s = np.sum((1.0 - pNoCaptAll_s), axis = 0)   # sum pCapt across traps 
        sumCaptProb_s = sumCaptProb_s[remPresMask]              # reduce to those removed
        multiNomPTrap_s = stoatTrapPairProb_s / sumCaptProb_s
        self.params.LL_multiNomPTrap_s += np.sum(np.log(multiNomPTrap_s))

    def g0SigLLikFX(self, i, availTrapNights, stoatLocSession, stoatTrapIDSession,
        stoatPCaptSession, stoatTrapPCaptSession, sessionMask, presOnlyMask, 
        stoatRemoveSession, remMask, stoatStateProbSession, stoatCaptStateSession):
        """
        Session-i capture-state log likelihood under the current and the
        proposed g0 / sigma.

        Using the proposed expTermMat_s and g0_s, the method recomputes for
        the stoats selected by presOnlyMask their overall capture
        probability, the stoat-trap capture probability of removed stoats and
        the state probabilities, and stores them in the "_s" arrays of
        self.stoatdata (stoatPCapt_s, stoatTrapPCapt_s, stoatStateProb_s) so
        they are available if the proposal is later accepted.  It then writes
        params.llikg0Sig[i] (current state probs) and params.llikg0Sig_s[i]
        (proposed state probs).  Nothing is returned.
        """
        presLoc = stoatLocSession[presOnlyMask]
#        (uLoc, indxLoc) = np.unique(presLoc, return_inverse = True) #unique Loc so limit matrix computation
                                                                    #get indices to feedback pCapt to stoat data
        # the np.unique step above is disabled, so there is one column per
        # selected stoat (not per unique location)
        eterm_s = self.params.expTermMat_s[:, presLoc]                 # proposed exp term, one column per selected stoat
        pNoCaptAll_s = self.g0Sig_PStoatCaptFX(availTrapNights, eterm_s, 
            self.params.g0_s)
        pNoCapt_s = pNoCaptAll_s.prod(axis = 0)                     # product over first axis (indexed by trap id below)
        pCapt_s = 1.0 - pNoCapt_s                                   # proposed p Capt per selected stoat
        stoatPCaptLoc_sSession = stoatPCaptSession.copy()
        stoatPCaptLoc_sSession[presOnlyMask] = pCapt_s
        self.stoatdata.stoatPCapt_s[sessionMask] = stoatPCaptLoc_sSession #keep this in case keep new g0 and sig
                                                                    # the following gets info only for capt stoats
        remPresMask = stoatRemoveSession[presOnlyMask] == 1         # mask of removed stoats in sess i
        remIndx = presLoc[remPresMask]                              # locations only of captured stoats (not used below)
        remtid = stoatTrapIDSession[remMask]                        # trap id of captured stoats
        # pairs each trap id with the column of the corresponding removed
        # stoat; relies on remMask and remPresMask listing the removed stoats
        # in the same order -- TODO confirm
        stoatTrapPNoCapt_s = pNoCaptAll_s[remtid, remPresMask]          # p No Capt of captured stoats only
        stoatTrapPCapt_sSession = stoatTrapPCaptSession.copy()      # template
        stoatTrapPCapt_sSession[remMask] = 1.0 - stoatTrapPNoCapt_s # p capt fill in template
        self.stoatdata.stoatTrapPCapt_s[sessionMask] = stoatTrapPCapt_sSession # fill in class data in case keep new params.
        # get trapping likelihood
        captState_i = stoatCaptStateSession[presOnlyMask]
        stoatStateProb_s = self.stateProb_Multi_Obs(pCapt_s, self.params.vt[i])
        stoatStateProbSession_s = stoatStateProbSession.copy()
        stoatStateProbSession_s[presOnlyMask] = stoatStateProb_s
        self.stoatdata.stoatStateProb_s[sessionMask] = stoatStateProbSession_s
        self.params.llikg0Sig[i] = self.ZLoc_MultiNomLik(captState_i, stoatStateProbSession[presOnlyMask])
        self.params.llikg0Sig_s[i] = self.ZLoc_MultiNomLik(captState_i, stoatStateProb_s)
        #############
        # multinomial prob of traps catching preds that they caught or did not catch.
        # only stoats in states 0 (caught) because competition is among traps for captured stoats       
        # NOTE(review): eterm is assigned but never used in this method, so
        # this branch currently has no effect; it looks like the place where
        # getPTrappingData / getTrapMultinomial would be called -- confirm.
        if self.basicdata.removeDat[i] > 0:
            eterm = self.params.expTermMat[:, presLoc]                 # current exp term, one column per selected stoat
        #############

    def PCaptAllPreds_FX(self, availTrapNights,  eTermMat, g0Sess):
        """
        Probability of at least one capture, combined along axis 1.

        The no-capture probability (1 - g0Sess * eTermMat) ** availTrapNights
        is capped at 0.9999, multiplied along axis 1, and subtracted from 1.
        """
        pMiss = (1.0 - (g0Sess * eTermMat))**(availTrapNights)
        tooHigh = pMiss >= .9999
        pMiss[tooHigh] = .9999
        return (1.0 - np.prod(pMiss, axis = 1))

    def g0SigUpdateFX(self):
        """
        Joint Metropolis update of g0 and sigma, followed by a new proposal.

        Current and proposed values are compared on the summed session
        likelihoods (llikg0Sig / llikg0Sig_s) plus a beta log prior on g0 and
        a normal log prior on sigma.  The trapping-data terms
        (self.pTrappingData / self.pTrappingData_s) are not part of the
        ratio; they are only reset to 0.0 at the end.

        On acceptance g0, sigma, expTermMat and the stoatdata arrays
        stoatPCapt, stoatTrapPCapt and stoatStateProb take the proposed
        values.  Then g0_s is proposed on the logit scale (sd g0Sd), sigma_s
        on the natural scale (sd sigma_search_sd), and var2_s / expTermMat_s
        are recomputed from sigma_s.

        The priors use logpdf rather than log(pdf): the normal pdf underflows
        to 0 for a sigma far from sigma_mean, which made the log prior -inf.
        """
        pars = self.params
        sdat = self.stoatdata

        pnow = (np.sum(pars.llikg0Sig)
            + stats.beta.logpdf(pars.g0, pars.g0_alpha, pars.g0_beta)
            + stats.norm.logpdf(pars.sigma, pars.sigma_mean, pars.sigma_sd))
        pnew = (np.sum(pars.llikg0Sig_s)
            + stats.beta.logpdf(pars.g0_s, pars.g0_alpha, pars.g0_beta)
            + stats.norm.logpdf(pars.sigma_s, pars.sigma_mean, pars.sigma_sd))
        rValue = np.exp(pnew - pnow)        # calc importance ratio
        zValue = np.random.uniform(0.0, 1.0, size = None)

        if rValue > zValue:
            pars.g0 = pars.g0_s
            pars.sigma = pars.sigma_s
            sdat.stoatPCapt = sdat.stoatPCapt_s.copy()
            sdat.stoatTrapPCapt = sdat.stoatTrapPCapt_s.copy()
            pars.expTermMat = pars.expTermMat_s.copy()
            sdat.stoatStateProb = sdat.stoatStateProb_s.copy()
            # NOTE(review): params.var2 is not refreshed here; confirm nothing
            # reads it after initialisation.

        # next proposal
        pars.g0_s = inv_logit(np.random.normal(logit(pars.g0), pars.g0Sd))
        pars.sigma_s = np.random.normal(pars.sigma, pars.sigma_search_sd)
        pars.var2_s = 2.0 * (pars.sigma_s**2.0)
        pars.expTermMat_s =  np.exp(-(self.basicdata.distTrapToCell2) / pars.var2_s)
        self.pTrappingData = 0.0
        self.pTrappingData_s = 0.0

#######
###########         begin capture state update ##########
#####

    def captStateUpdate(self):
        """
        function to update the capture state of each present but not captured individ
        """
        # loop thru sessions
        for i in range(self.basicdata.nsession):
            sessionMask = self.stoatdata.stoatSession == i
            captStateSession = self.stoatdata.stoatCaptState[sessionMask]
            vMask = np.sum(captStateSession[:, 1:], axis = 1) == 1
            captState = captStateSession[vMask]
            stateProbSession = self.stoatdata.stoatStateProb[sessionMask]
            stateProb = stateProbSession[vMask]
            nV = np.sum(vMask)
            nStates = np.sum(captState, axis = 0)
            # loop thru individuals present but not captured
            for j in range(nV):
                captState_s = captState[j].copy()
                cs1 = 1 - captState[j, 1]
                cs2 = 1 - captState[j, 2]
                captState_s[1] = cs1
                captState_s[2] = cs2

                lcombo = -gammaln(nStates[1] + 1) - gammaln(nStates[2] + 1)
                nStates_s = nStates.copy()
                nStates_s[1] = nStates[1] + captState_s[1] - captState_s[2] 
                nStates_s[2] = nStates[2] - captState_s[1] + captState_s[2] 
                lcombo_s = -gammaln(nStates_s[1] + 1) - gammaln(nStates_s[2] + 1)
                lProb = np.log(np.sum(captState[j] * stateProb[j]))
                lProb_s = np.log(np.sum(captState_s * stateProb[j]))
 
                # calc importance ratio
                pnow = lcombo + lProb
                pnew = lcombo_s + lProb_s
                rValue = np.exp(pnew - pnow)        # calc importance ratio
                zValue = np.random.uniform(0.0, 1.0, size = None)
                if rValue > zValue:
                    captState[j] = captState_s
                    nStates = nStates_s.copy()
            captStateSession[vMask] = captState
            self.stoatdata.stoatCaptState[sessionMask] = captStateSession


#######
###########         begin probability of capture state update ##########
#####

    def v_update(self):
        """
        Metropolis update of v, the parameter behind the third capture-state
        probability (column 2 of the state-probability array).

        A proposal v_s is drawn on the logit scale around v (sd v_search) and
        expanded to sessions (vt_s) and to individuals (self.vSession_s).
        Proposed state probabilities for all individuals are stored in
        stoatdata.stoatStateProb_s.  Rows of stoatCaptState that sum to 1 are
        used to compare current and proposed state probabilities, together
        with a beta(v_alpha, v_beta) log prior.  On acceptance v, vt and
        stoatdata.stoatStateProb take the proposed values.

        The prior uses stats.beta.logpdf instead of log(pdf) so that values
        of v close to 0 or 1 cannot underflow the density before the log is
        taken.
        """
        pars = self.params
        sdat = self.stoatdata

        # proposal and its expansion to sessions / individuals
        pars.v_s = inv_logit(np.random.normal(logit(pars.v), pars.v_search))
        pars.vt_s = pars.v_s[self.basicdata.vPeriod]
        self.vSession = pars.vt[sdat.stoatSession]
        self.vSession_s = pars.vt_s[sdat.stoatSession]

        sdat.stoatStateProb_s = self.stateProb_Multi_Obs(sdat.stoatPCapt,
            self.vSession_s)
        presMask = np.sum(sdat.stoatCaptState, axis = 1) == 1
        stateProb_s = sdat.stoatStateProb_s[presMask]
        stateProb = sdat.stoatStateProb[presMask]
        captState = sdat.stoatCaptState[presMask]
        vLLik = self.ZLoc_MultiNomLik(captState, stateProb)
        vLLik_s = self.ZLoc_MultiNomLik(captState, stateProb_s)

        pnow = vLLik + np.sum(stats.beta.logpdf(pars.v, pars.v_alpha, pars.v_beta))
        pnew = vLLik_s + np.sum(stats.beta.logpdf(pars.v_s, pars.v_alpha, pars.v_beta))

        rValue = np.exp(pnew - pnow)        # calc importance ratio
        zValue = np.random.uniform(0.0, 1.0, size = None)
        if rValue > zValue:
            pars.v = pars.v_s.copy()
            pars.vt = pars.vt_s.copy()
            sdat.stoatStateProb = sdat.stoatStateProb_s.copy()
             


########            Main mcmc function
########
    def mcmcFX(self):
        cc = 0
        for g in range(self.ngibbs * self.thinrate + self.burnin):

            self.N_stoatdata_updateFX()

            self.betaUpdateFX()

            self.rUpdateFX()

#            self.immUpdateFX()

            self.g0SigUpdateFX()

            self.v_update()

            self.captStateUpdate()     

            if g in self.keepseq:
                self.Ngibbs[cc] = self.params.N
                self.bgibbs[cc] = self.params.b
                self.rgibbs[cc] = self.params.rg
#                self.igibbs[cc] = self.params.ig[:2]
                self.ggibbs[cc] = self.params.g0
                self.siggibbs[cc] = self.params.sigma
                self.vgibbs[cc] = self.params.v
                self.deviancegibbs[cc] = self.devianceSum * -2.0
                cc = cc + 1
#        return (self.bgibbs, self.Ngibbs, self.rgibbs, self.igibbs, self.ggibbs, 
#                self.siggibbs, self.deviancegibbs)

########            Pickle results to directory
########

class Gibbs(object):
    """
    Plain container for the retained MCMC samples, so that they can be
    pickled without the rest of the sampler object.
    """
    # sample arrays copied (by reference) from the sampler; igibbs is left out
    _KEPT = ('Ngibbs', 'bgibbs', 'rgibbs', 'ggibbs', 'siggibbs', 'vgibbs')

    def __init__(self, mcmcobj):
        for attrName in self._KEPT:
            setattr(self, attrName, getattr(mcmcobj, attrName))



########            Main function
#######
def _loadPickle(path):
    """Return the object un-pickled from path; the file is closed even on error."""
    # NOTE(security): pickle.load can execute arbitrary code -- only read
    # files written by a trusted previous run.
    with open(path, 'rb') as fileobj:
        return pickle.load(fileobj)


def _dumpPickle(obj, path):
    """Pickle obj to path; the file is closed even if pickling fails."""
    with open(path, 'wb') as fileobj:
        pickle.dump(obj, fileobj)


def main(basicfile=None, paramsfile=None, stoatfile=None):
    """
    Build (or restore) the data objects, run the MCMC and pickle the results.

    basicfile, paramsfile, stoatfile: optional names of pickled BasicData,
        Params and StoatData objects inside the project directory.  When a
        name is None the corresponding object is created from scratch.

    The project directory is taken from the STOATSPROJDIR environment
    variable (default '.').  Four files are written there: out_params.pkl,
    out_basicdata.pkl, out_stoatdata.pkl and out_gibbs.pkl.
    """

    #np.seterr(all='raise')

    # path to project directory to read in data and write results
    stoatpath = os.getenv('STOATSPROJDIR', default='.')

    # paths and data to read in
    captDatFile = os.path.join(stoatpath,'capt13.csv')
    newCaptDatFile = os.path.join(stoatpath,'capt1113_715.csv')
    dateDatFile = os.path.join(stoatpath,'datesBind.csv')
    trapDatFile = os.path.join(stoatpath,'traploc5.csv')
    covDatFile = os.path.join(stoatpath,'covDat.csv')

    # basicdata: build from the csv files, or restore a pickled object
    if basicfile is None:
        basicdata = BasicData(captDatFile, newCaptDatFile, dateDatFile,
                    trapDatFile, covDatFile)
    else:
        basicdata = _loadPickle(os.path.join(stoatpath, basicfile))

    # params: initialise from basicdata, or restore from a previous run
    if paramsfile is None:
        params = Params(basicdata)
    else:
        params = _loadPickle(os.path.join(stoatpath, paramsfile))

    # stoatdata: initialise, or restore a pickled object
    if stoatfile is None:
        stoatdata = StoatData(basicdata, params, debug = False)
    else:
        stoatdata = _loadPickle(os.path.join(stoatpath, stoatfile))

    mcmcobj = MCMC(params, stoatdata, basicdata)
    # run mcmcFX - gibbs loop
    mcmcobj.mcmcFX()

    gibbsobj = Gibbs(mcmcobj)

    # pickle parameters from present run to be used to initiate new runs
    _dumpPickle(params, os.path.join(stoatpath,'out_params.pkl'))

    # pickle basic data from present run to be used to initiate new runs
    _dumpPickle(basicdata, os.path.join(stoatpath,'out_basicdata.pkl'))

    # pickle predator data from present run to be used to initiate new runs
    _dumpPickle(stoatdata, os.path.join(stoatpath,'out_stoatdata.pkl'))

    # pickle mcmc results for post processing in gibbsProcessing.py
    _dumpPickle(gibbsobj, os.path.join(stoatpath,'out_gibbs.pkl'))

# Run a fresh analysis (no pickled inputs) when executed as a script.
if __name__ == '__main__':
    main()

  
  
  
