#!/usr/bin/env python

import os
from scipy import stats
from scipy.special import gammaln
from scipy.special import gamma
from scipy.stats.mstats import mquantiles
import numpy as np
import pylab as P
import prettytable
import datetime
import pickle

def logit(x):
    """
    Log-odds (logit) transform, element-wise.

    Computed as log(x) - log(1 - x); x is expected to lie in (0, 1).
    """
    logNumer = np.log(x)
    logDenom = np.log(1 - x)
    return logNumer - logDenom

def inv_logit(x):
    """
    inverse logit function

    Returns exp(x) / (1 + exp(x)) (the logistic sigmoid), element-wise.
    Evaluated with scipy.special.expit, which is numerically stable: the
    naive formula overflows to inf / inf = nan for large positive x.
    """
    from scipy.special import expit
    return expit(x)

def thProbFX(tt, debug = False):
    """
    multinomial probability

    Softmax of ``tt``: exp(tt) / sum(exp(tt)), normalised over *all*
    elements and returned with the same shape as ``tt``.

    The maximum is subtracted before exponentiating. This leaves the result
    mathematically unchanged, but prevents overflow (inf / inf -> nan) when
    the scores are large. ``debug`` is accepted for backward compatibility
    and is not used.
    """
    tt = np.asarray(tt, dtype=float)
    if tt.size == 0:
        # nothing to normalise; np.max would raise on an empty array
        return tt
    tt2 = np.exp(tt - np.max(tt))
    return tt2 / np.sum(tt2)

def distxy(x1,y1,x2,y2):
    """
    Euclidean distance between (x1, y1) and (x2, y2).

    Works element-wise (with broadcasting) when the coordinates are arrays.
    """
    dxSq = np.power(x1 - x2, 2)
    dySq = np.power(y1 - y2, 2)
    return np.sqrt(dxSq + dySq)

def matrixsub(arr1, arr2):
    """
    Pairwise differences between two 1-D arrays.

    Returns a (len(arr1), len(arr2)) array with out[y, x] = arr1[y] - arr2[x].
    The result keeps arr1's dtype, as before: differences are cast to
    arr1.dtype, so e.g. an integer arr1 truncates fractional results.
    Vectorised with np.subtract.outer instead of a Python double loop.
    """
    arr1 = np.asarray(arr1)
    diff = np.subtract.outer(arr1, np.asarray(arr2))
    return diff.astype(arr1.dtype, copy=False)

def distmat(x1, y1, x2, y2):
    """
    Euclidean distance matrix between two point sets.

    Entry [r, c] is the distance from point (x1[r], y1[r]) to point
    (x2[c], y2[c]); the result has shape (len(x1), len(x2)).
    """
    squaredDist = matrixsub(x1, x2)**2.0 + matrixsub(y1, y2)**2.0
    return np.sqrt(squaredDist)

def initialPStoatTrapCaptFX(params, basicdata, availTrapNights, location,
        g0Param, debug = False):      # prob that stoat was capt in trap
    """
    initial probability of capture of given stoat captured in specified traps

    For the cell ``location``, returns per trap
        1 - (1 - g0Param * exp(-d / params.var2)) ** availTrapNights,
    where d = basicdata.distTrapToCell2[:, location].

    A no-capture probability of exactly 1 is replaced by .9999, so the
    returned capture probability is never exactly zero. ``debug`` is unused.
    """
    sqDist = basicdata.distTrapToCell2[:, location]
    perNightEscape = 1. - g0Param * np.exp(-(sqDist) / params.var2)
    escapeAllNights = perNightEscape**(availTrapNights)
    # keep a small non-zero capture probability for far-away traps
    escapeAllNights = np.where(escapeAllNights == 1., .9999, escapeAllNights)
    return 1 - escapeAllNights


def NpredInitialFX(nsession, N, removeDat, rpara, it, Npred):
    """
    Initial predicted N from population model

    For every session s from 1 to nsession - 1, the prediction is built from
    the previous session:
        survivors = max(N[s-1] - removeDat[s-1], 0)
        Npred[s]  = rpara[s] * survivors + it[s]
    Npred[0] is left untouched. Npred is updated in place and also returned.
    """
    for sess in range(1, nsession):
        prev = sess - 1
        survivors = N[prev] - removeDat[prev]
        survivors = np.where(survivors < 0, 0, survivors)
        Npred[sess] = np.multiply(rpara[sess], survivors) + it[sess]
    return Npred

def removeDatFX(nsession, stoat, session):
    """
    get removed stoat data

    Returns an integer array of length nsession whose entry s is the sum of
    ``stoat`` over the records with ``session == s``.

    The result array is integer-typed, so non-whole sums are truncated on
    assignment.
    """
    totals = np.zeros_like(np.arange(nsession))
    for sess in range(nsession):
        inSess = session == sess
        totals[sess] = np.sum(stoat[inSess])
    return totals

        
def quantileFX(a):
    """
    Return the 5th, 50th and 95th percentiles of ``a``, in that order.

    Uses scipy.stats.mstats.mquantiles with its default plotting positions.
    """
    lower, middle, upper = 0.05, 0.5, 0.95
    return mquantiles(a, prob=[lower, middle, upper])

################
#####
##
class Params(object):
    """
    Settings for the forward simulations, including the scenario choice.
    """
    def __init__(self, simGroup=7):
        """
        parameter class for simulations

        Parameters
        ----------
        simGroup : int, optional
            Simulation scenario (see the list below). Defaults to 7, the
            scenario that was previously hard-coded, so ``Params()`` behaves
            exactly as before.
        """
        # number of iterations to simulate
        self.iter = 2000
        # number of habitat covariates
        self.ncov = 2
        # number of nights each trap is set per session
        self.maxTrapNights = 14.0
        # improved lure target g0 (mean g0 used in scenarios 6 and 7)
        self.target_g0 = 0.03

        ########
        ########
        ################################   Set the simulation scenario
        # 1 = current
        # 2 = remove july
        # 3 = add March
        # 4 = Add gap (extra traps read from Addtraps5.csv)
        # 5 = Add gap and March
        # 6 = Improve lure so the mean g0 equals target_g0
        # 7 = Add March and improve lure

        self.simGroup = simGroup
        # Plain membership test: np.in1d is deprecated in NumPy 2.0, and it
        # returned a one-element array that was then truth-tested.
        if self.simGroup in (1, 4, 6):
            self.sessStructure = 1                      # current timing
        elif self.simGroup == 2:
            self.sessStructure = 2                      # remove July
        else:
            self.sessStructure = 3                      # Add March

        ########
        ########
        ################################ End modifying scenario


##
######
###############
class Simul(object):
    """
    Forward simulation of stoat population size and trapping.

    Each of params.iter replicates draws one posterior sample from the MCMC
    output (gibbsdata). It starts from that sample's post-trapping
    population in the last observed session and steps through the
    simulated sessions: the population changes in November, and trapping
    is simulated every session. Constructing the object runs the whole
    simulation.
    """
    def __init__(self, params, gibbsdata, basicdata):
        """
        Class and functions to simulate population and trapping dynamics

        Parameters
        ----------
        params : Params
            Scenario settings (iterations, trap nights, scenario number).
        gibbsdata : object
            Posterior samples. Attributes read here: Ngibbs, rgibbs, ggibbs,
            siggibbs, vgibbs, bgibbs.
        basicdata : object
            Observed data. Attributes read here: nsession, removeDat, ntrap,
            trapID, trapX, trapY, cellX, cellY, cellID, distTrapToCell2,
            xdat.
        """
        #################
        #   Run functions

        self.getobjects(params, basicdata, gibbsdata)
        self.nMatrixFX()
        self.getDMatFX()
        self.simulFX()

        #   End running functions
        #################

    def getobjects(self, params, basicdata, gibbsdata):
        """
        bring in objects into Simul

        Also stores the calendar year of each observed session and the
        posterior mean g0, which lureDiff uses as its baseline.
        """
        self.params = params
        self.basicdata = basicdata
        self.gibbsdata = gibbsdata
        # NOTE(review): hard-coded years of the observed sessions; presumably
        # one entry per observed session -- update when the data change.
        self.years = 2000 + np.array([8,8,8,8,8,9,9,9,10,10,10,11,11,11,
                12,12,12,13,13,13,14,14,14,15])
        self.meanG0 = np.mean(self.gibbsdata.ggibbs)

    def nMatrixFX(self):
        """
        create table of quantiles of estimated population size for each session

        The first basicdata.nsession columns of nQuantileMat (post-trapping
        N) and nPreQuantMat (pre-trapping N) are filled from the posterior:
        row 0 is the 5th percentile, row 1 the mean and row 2 the 95th
        percentile. gatherSimResults later fills the last nSimSess columns.

        This method also allocates the per-iteration result arrays and draws
        the posterior sample used by each iteration.
        """
        self.getMonthYear()
        self.ngibbs = len(self.gibbsdata.rgibbs)
        nCols = self.basicdata.nsession + self.nSimSess
        self.nQuantileMat = np.zeros(shape=(3, nCols))

        # post-trapping N = estimated N minus animals removed in that session
        remDatMat = np.expand_dims(self.basicdata.removeDat,1)
        nGibbsTransposed = np.transpose(self.gibbsdata.Ngibbs)
        postTrapNGibbs = nGibbsTransposed - remDatMat
        postTrapNGibbs2 = np.transpose(postTrapNGibbs)
        nQuantsGibbs = np.apply_along_axis(quantileFX, 0, postTrapNGibbs2)
        nMeanGibbs = np.apply_along_axis(np.mean, 0, postTrapNGibbs2)
        self.nQuantileMat[:, 0:self.basicdata.nsession] = nQuantsGibbs
        # row 1 holds the mean rather than the median
        self.nQuantileMat[1, 0:self.basicdata.nsession] = nMeanGibbs
        self.nSimMat = np.zeros(shape=(self.params.iter, self.nSimSess))
        self.nPostTrapMat = np.zeros(shape=(self.params.iter, self.nSimSess))

        self.nPreQuantMat = np.zeros(shape=(3, nCols))
        nPreQuantsGibbs = np.apply_along_axis(quantileFX, 0, self.gibbsdata.Ngibbs)
        nPreMean =  np.apply_along_axis(np.mean, 0, self.gibbsdata.Ngibbs)
        self.nPreQuantMat[:, 0:self.basicdata.nsession] = nPreQuantsGibbs
        self.nPreQuantMat[1, 0:self.basicdata.nsession] = nPreMean

        self.sampID = np.random.choice(range(self.ngibbs), self.params.iter, replace = True)
        self.nMatPred = np.zeros(shape = (self.params.iter, self.nSimSess))
        self.nNovember = len(self.simMonths[self.simMonths == 11])
        self.reproPara = np.ones(self.nSimSess)
        # number of traps given reduced availability each session (1% of
        # traps); must be an int because np.random.choice rejects a float size
        self.nTrapRSel = int(np.round(self.basicdata.ntrap * 0.01))
        self.nEradications = 0
        self.nSuppressions = 0

    def getMonthYear(self):
        """
        depending on scenario, get month and years for simulations

        Sets simMonths / simYears (one entry per simulated session),
        vulnerIndx (index into the three vulnerability parameters:
        0 = January, 1 = July, 2 = November), testSessMask (the November
        2019 session used to test eradication) and nSimSess.

        Raises ValueError if params.sessStructure is not 1, 2 or 3.
        """
        structure = self.params.sessStructure
        if structure == 1:
            # current regime
            self.simMonths = np.append(np.array([7, 11]), np.tile(np.array([1, 7, 11]), 4))
            self.vulnerIndx = np.append(np.array([1, 2]), np.tile(np.array([0, 1, 2]), 4))
            self.simYears = np.append(np.array([2015, 2015]),
                        np.repeat(np.arange(2016, 2020), 3))
        elif structure == 2:
            # remove July
            self.simMonths = np.append(11, np.tile(np.array([1, 11]), 4))
            self.vulnerIndx = np.append(2, np.tile(np.array([0, 2]), 4))
            self.simYears = np.append(2015, np.repeat(np.arange(2016, 2020), 2))
        elif structure == 3:
            # Add March
            self.simMonths = np.append(np.array([7, 11]), np.tile(np.array([1, 3, 7, 11]), 4))
            self.vulnerIndx = np.append(np.array([1, 2]), np.tile(np.array([0, 1, 1, 2]), 4))
            self.simYears = np.append(np.array([2015, 2015]),
                        np.repeat(np.arange(2016, 2020), 4))
        else:
            raise ValueError('unknown sessStructure: %r' % (structure,))
        self.testSessMask = (self.simMonths == 11) & (self.simYears == 2019)
        self.nSimSess = len(self.simMonths)

    def getDMatFX(self):
        """
        calc distance from cells to traps

        Copies basicdata.distTrapToCell2 (squared trap-to-cell distances).
        In scenarios 4 and 5, traps listed in Addtraps5.csv (columns x, y,
        read from $STOATSPROJDIR like the other inputs) are appended to the
        trap coordinates, and the squared distance matrix is recomputed.
        """
        self.distTrapToCell2 = self.basicdata.distTrapToCell2.copy()
        self.ntrap = self.basicdata.ntrap
        ################ Adjust trap data if scen = 4 or 5
        if self.params.simGroup in (4, 5):
            # read in added trap data
            stoatpath = os.getenv('STOATSPROJDIR', default='.')
            self.addTraps = np.genfromtxt(os.path.join(stoatpath, 'Addtraps5.csv'),
                delimiter=',', names = True, dtype=['f8', 'f8'])
            newX = self.addTraps['x']
            newY = self.addTraps['y']
            self.trapX = np.append(self.basicdata.trapX, newX)
            self.trapY = np.append(self.basicdata.trapY, newY)
            self.ntrap = len(self.trapX)
            distTrapToCell = distmat(self.trapX, self.trapY,
                self.basicdata.cellX, self.basicdata.cellY)
            self.distTrapToCell2 = distTrapToCell**2.0

    def getAvailTrapFX(self):
        """
        randomly reduce availability of a subset of traps by .5 (sprung traps)

        nTrapRSel traps, drawn without replacement from basicdata.trapID, get
        half of params.maxTrapNights; every other trap gets the full number.
        The result is stored as an (ntrap, 1) column in self.availTrapNights.
        """
        # np.random.choice requires an integer sample size
        nSprung = int(self.nTrapRSel)
        rSel = np.random.choice(self.basicdata.trapID, nSprung, replace = False)
        availTrap = np.ones(self.ntrap)
        # NOTE(review): assumes trapID values are 0-based indices into the
        # trap arrays. basicdata.trapID is not extended in getDMatFX, so
        # traps added in scenarios 4/5 are never selected here.
        availTrap[rSel] = 0.5
        self.availTrapNights = np.expand_dims(availTrap * self.params.maxTrapNights, 1)

    def pCaptOneStoat(self):
        """
        probability of capture of stoat k

        Uses the squared trap distances of the stoat's cell (self.dArray) to
        compute per-trap capture probabilities over the available trap
        nights (self.pCaptureAllTraps). It also computes the probability of
        capture in at least one trap (self.pCaptOne).
        """
        eterm = np.exp(-(self.dArray) / self.sigmaIter**2)
        pNoCapt = 1.0 - self.g0Iter * eterm
        pNoCaptNights = pNoCapt**(self.availTrapNights)
        # keep a small non-zero capture probability for far-away traps
        pNoCaptNights = np.where(pNoCaptNights == 1.0, 0.9999, pNoCaptNights)
        pNoCaptNightsTraps = pNoCaptNights.prod(axis = 0)
        self.pCaptureAllTraps = 1.0 - pNoCaptNights
        self.pCaptOne = 1.0 - pNoCaptNightsTraps

    def idTrapTrapped(self):
        """
        id trap that captured stoat and
        update trap availability

        Each trap independently records a capture with its probability in
        self.pCaptureAllTraps. Every trap that records one has its remaining
        trap nights halved.
        """
        trapTrappedID = np.random.binomial(1, self.pCaptureAllTraps)
        maskTrap = trapTrappedID == 1
        tmpAvailTrapNights = self.availTrapNights[maskTrap]
        self.availTrapNights[maskTrap] = tmpAvailTrapNights * 0.5

    def captStateProbs(self, j):
        """
        calc multinomial capture state probabilities and put in array

        self.stateProbArray = [vulnerable & captured, vulnerable & not
        captured, not vulnerable], where self.vulner_j is the probability of
        NOT being vulnerable this session. ``j`` is unused.
        """
        vulnCapt = self.pCaptOne * (1.0 - self.vulner_j)
        vulnNotCapt = (1.0 - self.pCaptOne) * (1.0 - self.vulner_j)
        notVuln = self.vulner_j
        s1 = np.append(vulnCapt, vulnNotCapt)
        self.stateProbArray = np.append(s1, notVuln)

    def loopIndividuals(self, j):
        """
        loop thru individuals to capture stoats and ID traps that capture

        Each of the nPreTrap stoats gets a cell drawn from self.thMultiNom
        and a capture state drawn from the multinomial state probabilities.
        A captured stoat removes one animal from nNow and halves the
        remaining nights of the traps that caught it.
        """
        for k in range(self.nPreTrap):
            # id cell location
            stoatInCell = np.random.multinomial(1, self.thMultiNom, size = None)
            cellIDSess = self.basicdata.cellID[stoatInCell == 1]
            self.dArray = self.distTrapToCell2[:, cellIDSess]
            self.pCaptOneStoat()
            self.captStateProbs(j)
            captarray = np.random.multinomial(1, self.stateProbArray)
            self.captEvent = captarray[0]
            # if individual captured, update trap availability
            if self.captEvent == 1:
                self.idTrapTrapped()
                self.nNow += -1.0

    def simulFX(self):
        """
        simulation function

        For each of params.iter iterations, takes one posterior sample
        (self.sampID[i]) for the starting N, growth, g0, sigma, vulnerability
        and habitat coefficients, then runs the session loop. The results
        are summarised at the end.
        """
        for i in range(self.params.iter):
            sampID_i = self.sampID[i]
            # start from the post-trapping N of the last observed session
            nStart = self.gibbsdata.Ngibbs[sampID_i, -1] - self.basicdata.removeDat[-1]
            self.randRepro = self.gibbsdata.rgibbs[sampID_i]
            self.g0Iter = self.gibbsdata.ggibbs[sampID_i]
            # if simGroup == 6 or 7 to improve lures
            if self.params.simGroup in (6, 7):
                self.lureDiff()
            self.sigmaIter = self.gibbsdata.siggibbs[sampID_i]
            self.vIter = self.gibbsdata.vgibbs[sampID_i]      # array of three

            self.b = self.gibbsdata.bgibbs[sampID_i]
            self.mu = np.dot(self.basicdata.xdat,self.b)
            thMultiNomTemp = thProbFX(self.mu, debug = False)
            self.thMultiNom = thMultiNomTemp.flatten()
            self.nNow = nStart
            # loop thru sessions
            self.simSessLoop(i)
        # gather up simulation results
        self.gatherSimResults()

    def meanSessPara(self):
        """
        get mean parameters from posteriors

        Sets randRepro, g0Iter (lure-adjusted in scenarios 6/7), sigmaIter
        and vIter to their posterior means. It also sets the cell
        probabilities from the posterior-mean linear predictor. Not called
        within this file.
        """
        self.randRepro = np.mean(self.gibbsdata.rgibbs)
        # xdat . b is linear in b, so the mean linear predictor over all
        # posterior samples equals xdat . (posterior mean of b)
        self.b = np.mean(self.gibbsdata.bgibbs, axis = 0)
        self.mu = np.dot(self.basicdata.xdat, self.b)
        thMultiNomTemp = thProbFX(self.mu, debug = False)
        self.thMultiNom = thMultiNomTemp.flatten()

        self.g0Iter = np.mean(self.gibbsdata.ggibbs)
        # if simGroup == 6 or 7 to improve lures
        if self.params.simGroup in (6, 7):
            self.lureDiff()
        self.sigmaIter = np.mean(self.gibbsdata.siggibbs)
        self.vIter = np.mean(self.gibbsdata.vgibbs, axis = 0)      # array of three

    def getSessParameters(self):
        """
        get random variates of parameters for the session

        Draws one posterior sample and sets randRepro, g0Iter (lure-adjusted
        in scenarios 6/7), sigmaIter and vIter from it.
        """
        sessSampID = np.random.choice(self.ngibbs)
        self.randRepro = self.gibbsdata.rgibbs[sessSampID]
        self.g0Iter = self.gibbsdata.ggibbs[sessSampID]
        # if simGroup == 6 or 7 to improve lures
        if self.params.simGroup in (6, 7):
            self.lureDiff()
        self.sigmaIter = self.gibbsdata.siggibbs[sessSampID]
        self.vIter = self.gibbsdata.vgibbs[sessSampID]      # array of three

    def simSessLoop(self, i):
        """
        loop thru sessions in simulation

        For iteration i, steps through the simulated sessions:
          - in November, N is multiplied by randRepro and rounded;
          - N is capped at 150;
          - trapping is simulated whenever stoats are present;
          - pre- and post-trapping N are recorded in nSimMat / nPostTrapMat.
        An eradication is counted when N is 0 after trapping in the test
        session.
        """
        for j in range(self.nSimSess):
            if self.simMonths[j] == 11:
                # NOTE(review): named like a Poisson mean but applied
                # deterministically (rounded) -- confirm this is intended.
                lambdaPara = self.nNow * self.randRepro
                self.nPreTrap = round(lambdaPara)
                self.nNow = self.nPreTrap
            else:
                self.nPreTrap = self.nNow
            # hard cap on population size
            if self.nNow > 150:
                self.nNow = 150
                self.nPreTrap = self.nNow
            if self.nPreTrap == 0:
                self.nNow = 0
            self.nPreTrap = int(self.nPreTrap)
            if self.nPreTrap > 0:
                self.getAvailTrapFX()
                self.calcVulnerStatus(j)
                self.loopIndividuals(j)
            # populate result arrays
            self.nSimMat[i,j] = self.nPreTrap
            self.nPostTrapMat[i,j] = self.nNow
            if self.testSessMask[j] and self.nNow == 0:
                self.nEradications += 1

    def lureDiff(self):
        """
        Calc differnce in mean g0 from gibbs to target g0 for improved lure

        Shifts g0Iter by (params.target_g0 - posterior mean g0), so that the
        mean g0 across posterior samples equals the target.
        """
        addG0 = self.params.target_g0 - self.meanG0
        self.g0Iter = self.g0Iter + addG0

    def gatherSimResults(self):
        """
        gather simulation results and populate arrays

        Fills the simulated-session columns of nQuantileMat (post-trapping)
        and nPreQuantMat (pre-trapping): rows are 5th percentile, mean and
        95th percentile. Prints the eradication and suppression proportions.
        """
        probErad = self.nEradications / self.params.iter
        # NOTE(review): nSuppressions starts at 0 and is never incremented
        # in this class, so this proportion is always 0.
        probSuppression = self.nSuppressions / self.params.iter
        preTrapQuants = np.apply_along_axis(quantileFX, 0, self.nSimMat)
        postTrapQuants = np.apply_along_axis(quantileFX, 0, self.nPostTrapMat)
        meanPreN = np.apply_along_axis(np.mean, 0, self.nSimMat)
        meanPostN = np.apply_along_axis(np.mean, 0, self.nPostTrapMat)
        self.nQuantileMat[:, -self.nSimSess:] = postTrapQuants
        self.nQuantileMat[1, -self.nSimSess:] = meanPostN
        self.nPreQuantMat[:, -self.nSimSess:] = preTrapQuants
        self.nPreQuantMat[1, -self.nSimSess:] = meanPreN
        print('probability of Erad', probErad)
        print('Probability of Suppression', probSuppression)

    def calcVulnerStatus(self, j):
        """
        calc which stoats are vulnerable

        Sets self.vulner_j, the probability of an individual NOT being
        vulnerable in simulated session j, from self.vIter
        (0 = January, 1 = July, 2 = November).
        """
        # vulner indx for sim session j
        vID = self.vulnerIndx[j]
        if self.simMonths[j] == 3 and self.params.sessStructure == 3:
            # March falls between the January (0) and July (1) sessions, so
            # average those two values. The former slice [:1] held only
            # January's value, which made np.mean a no-op.
            self.vulner_j = np.mean(self.vIter[:2])
        else:
            self.vulner_j = self.vIter[vID]

class ResultsProcessing(object):
    """
    Turn a finished Simul run into a printed table, a text file and a figure.
    """
    def __init__(self, basicdata, simobj):
        """
        class and function to process results of simulation into tables and figures

        Builds month/year labels for every column of simobj.nQuantileMat
        (observed sessions followed by simulated sessions) and the output
        paths under $STOATSPROJDIR (default '.').
        """
        self.basicdata = basicdata
        self.simobj = simobj
        self.mo = np.append(self.basicdata.month, self.simobj.simMonths)
        self.yr = np.append(self.simobj.years, self.simobj.simYears)
        self.stoatpath = os.getenv('STOATSPROJDIR', default='.')
        ############### output file names -- edit these for each scenario
        self.summaryTable = os.path.join(self.stoatpath,'summary_I14_G2_R12.txt') # result table
        self.plotPNG = os.path.join(self.stoatpath, 'simCurrent.png')   # Sim image .png
        ###############

    def makeTableFX(self):
        """
        make table

        Prints one row per session: month, year, 5th percentile ('Low CI'),
        mean and 95th percentile ('High CI') of post-trapping N, rounded to
        4 decimals. The rounded values are kept in self.summaryNPost for
        writeToFileFX.
        """
        resultNPostTrap = np.round(self.simobj.nQuantileMat.transpose(), 4)

        aa = prettytable.PrettyTable(['Months', 'Years', 'Low CI', 'Mean', 'High CI'])
        months = self.mo
        years = self.yr
        for i, quantRow in enumerate(resultNPostTrap):
            row = [months[i]] + [years[i]] + quantRow.tolist()
            aa.add_row(row)
        print(aa)
        self.summaryNPost = resultNPostTrap.copy()

    def writeToFileFX(self):
        """
        write table to directory

        Writes the summary built by makeTableFX (call that first) to
        self.summaryTable as whitespace-separated columns Months, Years,
        Low_CI, Mean and High_CI, with a '#' header line.
        """
        print('shape sumNPost', self.summaryNPost.shape)
        print('shp nquantileMat', self.simobj.nQuantileMat.shape)
        m = self.summaryNPost.shape[0]
        # Concrete dtypes: np.float was removed in NumPy 1.24, and np.integer
        # is an abstract type rather than a usable field dtype.
        structured = np.empty((m,), dtype=[('Months', int), ('Years', int), ('Low CI', float),
                    ('Mean', float), ('High CI', float)])
        # copy data over
        structured['Low CI'] = self.summaryNPost[:, 0]
        structured['Mean'] = self.summaryNPost[:, 1]
        structured['High CI'] = self.summaryNPost[:, 2]

        structured['Months'] = self.mo.astype(int)
        structured['Years'] = self.yr.astype(int)

        np.savetxt(self.summaryTable, structured, fmt=['%d', '%d', '%.4f', '%.4f', '%.4f'],
                    header='Months Years Low_CI Mean High_CI')

    def plotFX(self):
        """
        plot mean and upper 95% quantile of population size for each year

        Plots the mean (row 1) and 95th percentile (row 2) of post-trapping
        N from simobj.nQuantileMat against session date, saves the figure
        to self.plotPNG and shows it.
        """
        minDate = datetime.date(2008, 5, 1)
        maxDate = datetime.date(2019, 12, 1)
        dates = [datetime.date(int(year), int(month), 1)
                 for month, year in zip(self.mo, self.yr)]

        P.figure()
        P.plot(dates, self.simobj.nQuantileMat[1,:], label = 'Mean post-trap pop.', color = 'k', linewidth = 3)
        P.plot(dates, self.simobj.nQuantileMat[2,:], label = '95th percentile', color = 'k')
        P.ylim([0, 140])
        P.xlabel('Time', fontsize = 17)
        P.ylabel('Pop. size after trapping', fontsize = 17)
        P.legend(loc='upper right')
        ax = P.gca()
        ax.text(datetime.date(2009, 1, 1), 142, 'A.', ha = 'left', fontsize=17)

        # Tick.label no longer exists in recent Matplotlib releases;
        # tick_params sets the size of all major tick labels instead.
        ax.tick_params(axis='both', which='major', labelsize=14)
        P.xlim(minDate, maxDate)

        P.xticks([datetime.date(2009, 1, 1), datetime.date(2011, 1, 1), datetime.date(2013, 1, 1),
                    datetime.date(2015, 1, 1), datetime.date(2017, 1, 1), datetime.date(2019, 1, 1)])
        P.savefig(self.plotPNG, format='png', dpi = 1000)
        P.show()
        ######################################


########            Main function
#######
def main():
    """
    Run the stoat trapping simulation end to end.

    Reads the pickled MCMC output (out_gibbs_M11.pkl and
    out_basicdata_M11.pkl) from $STOATSPROJDIR (default '.'), runs the
    simulation for the scenario configured in Params, and prints the
    summary table.
    """
    #np.seterr(all='raise')

    params = Params()

    # path to project directory to read in data and write results
    stoatpath = os.getenv('STOATSPROJDIR', default='.')

    # Read in pickled data from mcmc. `with` closes the files even if
    # unpickling fails.
    # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
    inputGibbs = os.path.join(stoatpath, 'out_gibbs_M11.pkl')
    with open(inputGibbs, 'rb') as fileobj:
        gibbsdata = pickle.load(fileobj)

    inputBasicdata = os.path.join(stoatpath, 'out_basicdata_M11.pkl')
    with open(inputBasicdata, 'rb') as fileobj:
        basicdata = pickle.load(fileobj)

    simobj = Simul(params, gibbsdata, basicdata)

    resultsobj = ResultsProcessing(basicdata, simobj)

    resultsobj.makeTableFX()

    # Optional outputs: uncomment to write the summary table / draw the plot.
    # resultsobj.writeToFileFX()
    # resultsobj.plotFX()


# Run the simulation only when executed as a script, not when imported.
if __name__ == '__main__':
    main()



