/*
 * Decompiled with CFR 0.152.
 */
package stacey;

import beast.core.Citation;
import beast.core.Description;
import beast.core.Input;
import beast.core.State;
import beast.core.parameter.RealParameter;
import beast.evolution.alignment.TaxonSet;
import beast.evolution.tree.Node;
import beast.evolution.tree.Tree;
import beast.evolution.tree.TreeDistribution;
import beast.evolution.tree.TreeInterface;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import stacey.GtreeAndCoalFactor;
import stacey.InverseGammaComponent;
import stacey.util.Bindings;
import stacey.util.FitsHeights;
import stacey.util.InverseGammaMixture;

/**
 * Multispecies-coalescent density of a collection of gene trees given a species (SMC) tree.
 * For every species-tree branch, the per-gene-tree coalescence counts and coalescent
 * intensities are combined across the components of an inverse-gamma mixture prior on the
 * branch population (see {@link #logProbAllGTreesInSMCTree}).
 *
 * <p>Per-gene-tree results are cached and only recomputed when the species tree or that gene
 * tree is reported dirty; {@link #store()} / {@link #restore()} save and roll back the caches.
 * Setting the system property {@code stacey.debug=true} cross-checks the cached computation
 * against a full recomputation (up to a fixed number of times).
 */
@Description(value="The STACEY coalescent distribution. The central class in STACEY.")
@Citation(value="Graham Jones (2016). Algorithmic improvements to species delimitation and phylogeny estimation under the multispecies coalescent.\nJournal of Mathematical Biology. DOI 10.1007/s00285-016-1034-0\nhttp://link.springer.com/article/10.1007/s00285-016-1034-0 ", firstAuthorSurname="Jones")
public class PIOMSCoalescentDistribution
extends TreeDistribution {
    /** Gene trees, each paired with its coalescent factor. */
    public Input<List<GtreeAndCoalFactor>> geneTreesInput = new Input<>("geneTree", "All gene trees", new ArrayList<>());
    /** Components of the inverse-gamma mixture prior; their weights are normalised in initAndValidate(). */
    public Input<List<InverseGammaComponent>> priorComponentsInput = new Input<>("popPriorInvGamma", "Component of mixture of inverse gamma distributions used as a prior for the per-branch populations", new ArrayList<>(), Input.Validate.REQUIRED);
    /** Overall multiplier applied to every component's beta. */
    public Input<RealParameter> popPriorScaleInput = new Input<>("popPriorScale", "Overall scale for population size", Input.Validate.REQUIRED);
    public Input<TaxonSet> taxonSetInput = new Input<>("taxonset", "set of taxa mapping lineages to species", Input.Validate.REQUIRED);

    /** Upper bound on the number of fast-vs-robust consistency checks done in debug mode. */
    private static final int MAX_DEBUG_CHECKS = 1000000;
    /** Absolute tolerance for the fast-vs-robust log-density comparison in debug mode. */
    private static final double DEBUG_TOLERANCE = 1.0E-12;

    private Bindings bindings;
    private FitsHeights fitsHeights;
    private TreeInterface sTree;
    private int nSMCTreeNodes;
    private int nSMCTreeTips;
    private List<GtreeAndCoalFactor> gTreeCFs;
    // Declared as ArrayList because it is handed to Bindings.initialise and FitsHeights,
    // whose parameter types may require it -- confirm before widening to List.
    private ArrayList<Tree> gTrees;
    private int nGTrees;
    /** lnGammaRatiosTable[c][q] = log(Gamma(alpha_c + q) / Gamma(alpha_c)) for mixture component c. */
    private double[][] lnGammaRatiosTable;
    /** coalCounts[n][j]: number of gene-tree-j coalescence heights FitsHeights reports for species node n. */
    private int[][] coalCounts;
    /** coalIntensities[n][j]: lineage-pair exposure time on species node n's branch for gene tree j, divided by j's coal factor. */
    private double[][] coalIntensities;
    /** True when gTreeFits[j] must be recomputed. */
    private boolean[] gTreeFitIsDirty;
    /** Cached result of FitsHeights.updateFitHeightsForOneGTree(j). */
    private boolean[] gTreeFits;
    /** True when coalCounts[*][j] / coalIntensities[*][j] must be recomputed. */
    private boolean[] gTreeCountIntensityIsDirty;
    // Copies saved by store() and brought back by restore().
    private int[][] storedCoalCounts;
    private double[][] storedCoalIntensities;
    private boolean[] storedGTreeFits;
    private boolean[] storedGTreeFitIsDirty;
    private boolean[] storedGTreeCountIntensityIsDirty;
    private final boolean debugFlag = Boolean.parseBoolean(System.getProperty("stacey.debug"));
    private int numberofdebugchecks = 0;

    /**
     * Builds the gene-tree/species-tree bindings, precomputes the log-Gamma-ratio table,
     * allocates the caches (all marked dirty) and normalises the mixture weights.
     */
    public void initAndValidate() {
        super.initAndValidate();
        this.gTreeCFs = this.geneTreesInput.get();
        this.nGTrees = this.gTreeCFs.size();
        this.gTrees = new ArrayList<>(this.nGTrees);
        for (GtreeAndCoalFactor gTreeCF : this.gTreeCFs) {
            this.gTrees.add(gTreeCF.getTree());
        }
        this.sTree = (TreeInterface) this.treeInput.get();
        this.bindings = Bindings.initialise(this.sTree, this.gTrees);
        this.fitsHeights = new FitsHeights(this.sTree, this.gTrees, this.bindings);
        this.nSMCTreeNodes = this.sTree.getNodeCount();
        this.nSMCTreeTips = this.sTree.getLeafNodeCount();

        List<InverseGammaComponent> components = this.priorComponentsInput.get();
        int totalGTipCount = 0;
        for (Tree gTree : this.gTrees) {
            totalGTipCount += gTree.getLeafNodeCount();
        }
        // The table is indexed by q_b, a sum of per-gene-tree coalescence counts (below
        // totalGTipCount for binary gene trees). The extra slot keeps q_b == 0 valid even
        // when there are no gene trees at all.
        this.lnGammaRatiosTable = new double[components.size()][totalGTipCount + 1];
        for (int c = 0; c < components.size(); ++c) {
            double alpha = components.get(c).getAlpha();
            for (int q = 0; q <= totalGTipCount; ++q) {
                this.lnGammaRatiosTable[c][q] = lnRatioGammas(alpha, q);
            }
        }

        this.coalCounts = new int[this.nSMCTreeNodes][this.nGTrees];
        this.coalIntensities = new double[this.nSMCTreeNodes][this.nGTrees];
        this.storedCoalCounts = new int[this.nSMCTreeNodes][this.nGTrees];
        this.storedCoalIntensities = new double[this.nSMCTreeNodes][this.nGTrees];

        double totalweight = 0.0;
        for (InverseGammaComponent igc : components) {
            totalweight += igc.getWeight();
        }
        for (InverseGammaComponent igc : components) {
            igc.normalizeWeight(totalweight);
        }

        this.gTreeFits = new boolean[this.nGTrees];
        this.storedGTreeFits = new boolean[this.nGTrees];
        this.gTreeFitIsDirty = new boolean[this.nGTrees];
        this.storedGTreeFitIsDirty = new boolean[this.nGTrees];
        Arrays.fill(this.gTreeFitIsDirty, true);
        Arrays.fill(this.storedGTreeFitIsDirty, true);
        this.gTreeCountIntensityIsDirty = new boolean[this.nGTrees];
        this.storedGTreeCountIntensityIsDirty = new boolean[this.nGTrees];
        Arrays.fill(this.gTreeCountIntensityIsDirty, true);
        Arrays.fill(this.storedGTreeCountIntensityIsDirty, true);
    }

    /**
     * Computes the log density of all gene trees given the species tree, using the caches.
     * In debug mode, also recomputes everything from scratch and fails fast on disagreement.
     *
     * @return the log density, or negative infinity if some gene tree does not fit
     * @throws RuntimeException in debug mode if cached and full recomputation disagree
     */
    public double calculateLogP() {
        InverseGammaMixture igm = this.getInverseGammaMixture();
        double popPriorScale = this.popPriorScaleInput.get().getValue();
        double fastLogP = this.logLhoodAllGeneTreesInSMCTree(igm, popPriorScale, false);
        if (this.debugChecksEnabled()) {
            double robustLogP = this.logLhoodAllGeneTreesInSMCTree(igm, popPriorScale, true);
            // Both -infinity gives NaN here, which compares false: treated as agreement.
            if (Math.abs(fastLogP - robustLogP) > DEBUG_TOLERANCE) {
                System.err.println("BUG in calculateLogP() in PIOMSCoalescentDistribution");
                throw new RuntimeException("Fatal STACEY error.");
            }
            ++this.numberofdebugchecks;
        }
        this.logP = fastLogP;
        return this.logP;
    }

    /**
     * Marks the caches of a gene tree dirty when it, or the species tree, reports a change.
     *
     * @return always true: this distribution is always recalculated
     */
    protected boolean requiresRecalculation() {
        boolean speciesTreeDirty = this.sTree.somethingIsDirty();
        for (int j = 0; j < this.nGTrees; ++j) {
            boolean dirty = this.gTrees.get(j).somethingIsDirty() | speciesTreeDirty;
            if (dirty) {
                this.gTreeFitIsDirty[j] = true;
                this.gTreeCountIntensityIsDirty[j] = true;
            }
        }
        return true;
    }

    /** Not implemented; returns null. */
    public List<String> getArguments() {
        return null;
    }

    /** Not implemented; returns null. */
    public List<String> getConditions() {
        return null;
    }

    /** Sampling is not supported; this is a no-op. */
    public void sample(State state, Random random) {
    }

    /** Saves the fit flags, dirty flags, counts and intensities so restore() can roll back. */
    public void store() {
        System.arraycopy(this.gTreeFits, 0, this.storedGTreeFits, 0, this.storedGTreeFits.length);
        System.arraycopy(this.gTreeFitIsDirty, 0, this.storedGTreeFitIsDirty, 0, this.storedGTreeFitIsDirty.length);
        for (int n = 0; n < this.nSMCTreeNodes; ++n) {
            System.arraycopy(this.coalCounts[n], 0, this.storedCoalCounts[n], 0, this.storedCoalCounts[n].length);
            System.arraycopy(this.coalIntensities[n], 0, this.storedCoalIntensities[n], 0, this.storedCoalIntensities[n].length);
        }
        System.arraycopy(this.gTreeCountIntensityIsDirty, 0, this.storedGTreeCountIntensityIsDirty, 0, this.storedGTreeCountIntensityIsDirty.length);
        super.store();
    }

    /**
     * Rolls back everything saved by store(). The fit dirty flags are restored together with the
     * fit values so that a restored gTreeFits[j] is never paired with a flag from the rejected state.
     */
    public void restore() {
        System.arraycopy(this.storedGTreeFits, 0, this.gTreeFits, 0, this.gTreeFits.length);
        System.arraycopy(this.storedGTreeFitIsDirty, 0, this.gTreeFitIsDirty, 0, this.gTreeFitIsDirty.length);
        // Swap rather than copy: the next store() overwrites the stored side anyway.
        int[][] tmpCCs = this.coalCounts;
        this.coalCounts = this.storedCoalCounts;
        this.storedCoalCounts = tmpCCs;
        double[][] tmpCIs = this.coalIntensities;
        this.coalIntensities = this.storedCoalIntensities;
        this.storedCoalIntensities = tmpCIs;
        System.arraycopy(this.storedGTreeCountIntensityIsDirty, 0, this.gTreeCountIntensityIsDirty, 0, this.gTreeCountIntensityIsDirty.length);
        super.restore();
    }

    /** Snapshots the current weight, alpha and beta of every prior component into a mixture. */
    InverseGammaMixture getInverseGammaMixture() {
        List<InverseGammaComponent> components = this.priorComponentsInput.get();
        int n = components.size();
        InverseGammaMixture igm = new InverseGammaMixture(n);
        for (int c = 0; c < n; ++c) {
            InverseGammaComponent igc = components.get(c);
            igm.setWeight(c, igc.getWeight());
            igm.setAlpha(c, igc.getAlpha());
            igm.setBeta(c, igc.getBeta());
        }
        return igm;
    }

    /** Returns the coalescent factor of gene tree j. */
    double coalFactor(int j) {
        return this.gTreeCFs.get(j).getCoalFactor();
    }

    /**
     * Refreshes whatever caches are dirty and returns the total log density.
     *
     * @param igm           mixture prior snapshot
     * @param popPriorScale multiplier applied to every beta
     * @param robust        if true, every cache is treated as dirty (full recomputation)
     * @return log density, or negative infinity if some gene tree does not fit
     */
    private double logLhoodAllGeneTreesInSMCTree(InverseGammaMixture igm, double popPriorScale, boolean robust) {
        if (robust) {
            Arrays.fill(this.gTreeCountIntensityIsDirty, true);
            Arrays.fill(this.gTreeFitIsDirty, true);
        }
        // A tree whose fit is being re-checked, or which did not fit, needs fresh counts.
        for (int j = 0; j < this.nGTrees; ++j) {
            if (this.gTreeFitIsDirty[j] || !this.gTreeFits[j]) {
                this.gTreeCountIntensityIsDirty[j] = true;
            }
        }
        if (!this.updateGTreeFits()) {
            return Double.NEGATIVE_INFINITY;
        }
        // Used as the "ancestor height" of the species-tree root's branch.
        double maxgtreehgt = 0.0;
        for (Tree gTree : this.gTrees) {
            maxgtreehgt = Math.max(maxgtreehgt, gTree.getRoot().getHeight());
        }
        for (int j = 0; j < this.nGTrees; ++j) {
            if (this.gTreeCountIntensityIsDirty[j]) {
                this.updateCountsAndIntensities(j, maxgtreehgt);
                this.gTreeCountIntensityIsDirty[j] = false;
            }
        }
        double logPGS = this.logProbAllGTreesInSMCTree(igm, popPriorScale);
        assert (!Double.isNaN(logPGS));
        assert (!Double.isInfinite(logPGS));
        return logPGS;
    }

    /**
     * Recomputes dirty fit flags and stops at the first gene tree that does not fit.
     * A cached "does not fit" result for a non-dirty tree is honoured, not skipped.
     *
     * @return true if every gene tree fits the species tree
     */
    private boolean updateGTreeFits() {
        for (int j = 0; j < this.nGTrees; ++j) {
            if (this.gTreeFitIsDirty[j]) {
                this.gTreeFits[j] = this.fitsHeights.updateFitHeightsForOneGTree(j);
                this.gTreeFitIsDirty[j] = false;
            } else if (this.debugChecksEnabled()
                    && this.gTreeFits[j] != this.fitsHeights.updateFitHeightsForOneGTree(j)) {
                System.err.println("BUG in logLhoodAllGeneTreesInSMCTree() gTreeFits[j] wrong");
                throw new RuntimeException("Fatal STACEY error.");
            }
            if (!this.gTreeFits[j]) {
                return false;
            }
        }
        return true;
    }

    /** Recomputes coalCounts[*][j] and coalIntensities[*][j] for gene tree j. */
    private void updateCountsAndIntensities(int j, double maxgtreehgt) {
        double coalFactor = this.coalFactor(j);
        for (int n = 0; n < this.nSMCTreeNodes; ++n) {
            this.coalCounts[n][j] = this.fitsHeights.getHeightsFromSNodeNrGTree(n, j).size();
        }
        // Lineage counts entering each species branch (from below), tips first, then internal.
        int[] nLins = new int[this.nSMCTreeNodes];
        for (int n = 0; n < this.nSMCTreeTips; ++n) {
            nLins[n] = this.bindings.nLineagesForBeastTipNrAndGtree(n, j);
        }
        this.fillinSubtreeNLineages(nLins, this.sTree.getRoot(), j);

        for (int n = 0; n < this.nSMCTreeNodes; ++n) {
            List<Double> njHeights = this.fitsHeights.getHeightsFromSNodeNrGTree(n, j);
            Node node = this.sTree.getNode(n);
            double nodeHeight = node.getHeight();
            if (this.debugFlag) {
                for (double h : njHeights) {
                    if (h < nodeHeight) {
                        System.out.println("BUG in logLhoodAllGeneTreesInSMCTree() bad height (1).");
                    }
                }
            }
            Node anc = node.getParent();
            double ancHeight = (anc == null) ? maxgtreehgt : anc.getHeight();
            assert (nodeHeight <= ancHeight);
            int nCoals = njHeights.size();
            assert (nCoals == this.coalCounts[n][j]);
            assert (nCoals < nLins[n]);
            if (this.debugFlag) {
                for (double h : njHeights) {
                    if (h > ancHeight) {
                        System.out.println("BUG in logLhoodAllGeneTreesInSMCTree() bad height (2).");
                    }
                }
            }
            // Branch boundaries plus coalescence times, ascending.
            double[] heights = new double[nCoals + 2];
            heights[0] = nodeHeight;
            heights[nCoals + 1] = ancHeight;
            for (int i = 0; i < nCoals; ++i) {
                heights[i + 1] = njHeights.get(i);
            }
            if (nCoals >= 2) {
                Arrays.sort(heights);
            }
            // Each interval contributes length * k(k-1)/2, with k dropping by one per coalescence.
            double coalIntensity = 0.0;
            for (int i = 0; i < nCoals + 1; ++i) {
                int k = nLins[n] - i;
                coalIntensity += (heights[i + 1] - heights[i]) * (double) k * (double) (k - 1) * 0.5;
            }
            this.coalIntensities[n][j] = coalIntensity / coalFactor;
        }
    }

    /**
     * Fills nlineages for the internal nodes of node's subtree: lineages entering a node equal
     * the lineages entering each child minus that child's coalescences. Assumes tip entries are
     * already set and that every internal node has exactly two children (getChild(0), getChild(1)).
     */
    private void fillinSubtreeNLineages(int[] nlineages, Node node, int j) {
        int total = 0;
        for (int side = 0; side < 2; ++side) {
            Node child = node.getChild(side);
            if (!child.isLeaf()) {
                this.fillinSubtreeNLineages(nlineages, child, j);
            }
            int childNr = child.getNr();
            total += nlineages[childNr] - this.coalCounts[childNr][j];
        }
        nlineages[node.getNr()] = total;
    }

    /**
     * Sums over species-tree branches b the log of
     * sum_c lambda_c (s*beta_c)^alpha_c / (s*beta_c + gamma_b)^(alpha_c + q_b) * Gamma(alpha_c + q_b) / Gamma(alpha_c),
     * minus sum_j q_bj * log(coalFactor_j), where s = popPriorScale, q_b is the total coalescence
     * count and gamma_b the total coalescent intensity on b.
     */
    private double logProbAllGTreesInSMCTree(InverseGammaMixture igm, double popPriorScale) {
        double[] logCFs = new double[this.nGTrees];
        for (int j = 0; j < this.nGTrees; ++j) {
            logCFs[j] = Math.log(this.coalFactor(j));
        }
        double[] lambdas = igm.getWeights();
        double[] alphas = igm.getAlphas();
        double[] betas = igm.getBetas();
        double[] logProbCpts = new double[lambdas.length];
        double total = 0.0;
        for (int n = 0; n < this.nSMCTreeNodes; ++n) {
            int q_b = 0;
            double gamma_b = 0.0;
            double minusLog_r_b = 0.0;
            for (int j = 0; j < this.nGTrees; ++j) {
                q_b += this.coalCounts[n][j];
                minusLog_r_b += logCFs[j] * (double) this.coalCounts[n][j];
                gamma_b += this.coalIntensities[n][j];
            }
            for (int c = 0; c < lambdas.length; ++c) {
                double sigmabeta_c = popPriorScale * betas[c];
                logProbCpts[c] = Math.log(lambdas[c])
                        + alphas[c] * Math.log(sigmabeta_c)
                        - (alphas[c] + (double) q_b) * Math.log(sigmabeta_c + gamma_b)
                        + this.lnGammaRatiosTable[c][q_b];
            }
            total += logSumExp(logProbCpts) - minusLog_r_b;
        }
        return total;
    }

    /** Returns log(Gamma(a + n) / Gamma(a)) = sum_{i=0}^{n-1} log(a + i); 0 for n = 0. */
    private static double lnRatioGammas(double a, int n) {
        double x = 0.0;
        for (int i = 0; i < n; ++i) {
            x += Math.log(a + (double) i);
        }
        return x;
    }

    /**
     * Numerically stable log(sum(exp(x))). Returns the maximum directly when it is infinite
     * (including an empty or all -infinity input), avoiding the NaN from (-inf) - (-inf).
     */
    private static double logSumExp(double[] x) {
        double maxx = Double.NEGATIVE_INFINITY;
        for (double d : x) {
            if (d > maxx) {
                maxx = d;
            }
        }
        if (Double.isInfinite(maxx)) {
            return maxx;
        }
        double sum = 0.0;
        for (double d : x) {
            sum += Math.exp(d - maxx);
        }
        return maxx + Math.log(sum);
    }

    /** True when debug mode is on and the consistency-check budget is not exhausted. */
    private boolean debugChecksEnabled() {
        return this.debugFlag && this.numberofdebugchecks < MAX_DEBUG_CHECKS;
    }
}

