/*
 * Decompiled with CFR 0.152.
 */
package stacey;

import beast.core.Description;
import beast.core.Input;
import beast.core.Operator;
import beast.evolution.tree.Node;
import beast.evolution.tree.Tree;
import beast.evolution.tree.TreeInterface;
import beast.util.Randomizer;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import stacey.debugtune.Checks;
import stacey.util.Bindings;
import stacey.util.BitUnion;
import stacey.util.UnionArrays;

@Description(value="A move which changes a node height in the SMC-tree, and changes some node heights in the gene trees in a way which preserves all topologies and usually maintains compatibility.")
public class NodesNudge
extends Operator {
    public Input<Tree> smcTreeInput = new Input<>("smcTree", "The species tree or minimal clusters tree", Input.Validate.REQUIRED);
    public Input<List<Tree>> geneTreesInput = new Input<>("geneTree", "All gene trees", new ArrayList<>());
    public Input<Long> delayInput = new Input<>("delay", "Number of times the operator is disabled.");

    /** Relative weights used to pick a criterion; index i corresponds to GtreeNodeCriterion.values()[i]. */
    private static final double[] GTREE_NODE_CRITERION_WEIGHTS = new double[]{1.0, 1.0, 1.0};
    /** Upper limit on how many debug consistency checks one operator instance performs. */
    private static final int MAX_DEBUG_CHECKS = 100000;
    /** Value of lastcriterion when the latest proposal() call did not make a move. */
    private static final int NO_CRITERION = -1;
    /** Cached GtreeNodeCriterion.values(), so no new array is allocated per call. */
    private static final GtreeNodeCriterion[] CRITERIA = GtreeNodeCriterion.values();

    private Tree sTree;
    private List<Tree> gTrees;
    /** Number of initial proposals that are refused outright (taken from the "delay" input). */
    private long delay = 0L;
    /**
     * Number of proposals refused so far because of the delay. It stops growing once it reaches
     * delay, so (unlike an ever-incrementing int) it cannot overflow during a long chain.
     */
    private long callCount = 0L;
    private boolean sTreeTooSmall;
    private UnionArrays unionArrays;
    // Per-criterion statistics; they are only written out by storeToFile() when debugging is on.
    private int[] rejectCounts;
    private int[] acceptCounts;
    private double[] upTotals;
    private double[] downTotals;
    /** Criterion index used by the latest proposal, or NO_CRITERION if that call made no move. */
    private int lastcriterion = NO_CRITERION;
    /** Height change dh applied by the latest proposal that made a move. */
    private double lastheightchange;
    /** True when the system property "stacey.debug" parses as true. */
    private final boolean debugFlag = Boolean.parseBoolean(System.getProperty("stacey.debug"));
    private int numberofdebugchecks = 0;

    /**
     * Reads the inputs, records whether the SMC-tree is too small to operate on, and builds the
     * union bookkeeping that relates SMC-tree nodes to gene-tree nodes.
     */
    @Override
    public void initAndValidate() {
        this.sTree = this.smcTreeInput.get();
        this.gTrees = this.geneTreesInput.get();
        // With fewer than 3 leaves there is no internal non-root node, and chooseNodeForNudge()
        // would loop forever looking for one.
        this.sTreeTooSmall = this.sTree.getLeafNodeCount() < 3;
        Long delayValue = this.delayInput.get();
        if (delayValue != null) {
            this.delay = delayValue;
        }
        Bindings bindings = Bindings.initialise((TreeInterface)this.sTree, this.gTrees);
        this.unionArrays = UnionArrays.initialise((TreeInterface)this.sTree, this.gTrees, bindings);
        int ncriteria = CRITERIA.length;
        this.rejectCounts = new int[ncriteria];
        this.acceptCounts = new int[ncriteria];
        this.upTotals = new double[ncriteria];
        this.downTotals = new double[ncriteria];
    }

    /**
     * Proposes a nudge of one SMC-tree node height together with selected gene-tree node heights.
     *
     * @return the log Hastings ratio. Returns Double.NEGATIVE_INFINITY (forcing rejection) when the
     *     SMC-tree is too small, and for each of the first {@code delay} calls.
     */
    @Override
    public double proposal() {
        this.lastcriterion = NO_CRITERION;
        if (this.sTreeTooSmall) {
            return Double.NEGATIVE_INFINITY;
        }
        if (this.callCount < this.delay) {
            ++this.callCount;
            return Double.NEGATIVE_INFINITY;
        }
        this.runDebugChecks("before move");
        this.unionArrays.update();
        double logHR;
        try {
            logHR = this.doNodesNudgeMove();
        } finally {
            // Always undo update(), even if the move throws.
            this.unionArrays.reset();
        }
        this.runDebugChecks("after move");
        return logHR;
    }

    /** Runs the tree/compatibility checks when debugging is on, up to MAX_DEBUG_CHECKS times. */
    private void runDebugChecks(String when) {
        if (this.debugFlag && this.numberofdebugchecks < MAX_DEBUG_CHECKS) {
            Checks.allTreesAndCompatibility((TreeInterface)this.sTree, this.gTrees, "NodesNudge", when);
            ++this.numberofdebugchecks;
        }
    }

    /** Records the acceptance (and the signed height change) under the criterion just used. */
    @Override
    public void accept() {
        super.accept();
        if (this.lastcriterion == NO_CRITERION) {
            return;
        }
        ++this.acceptCounts[this.lastcriterion];
        if (this.lastheightchange > 0.0) {
            this.upTotals[this.lastcriterion] += this.lastheightchange;
        } else {
            this.downTotals[this.lastcriterion] += this.lastheightchange;
        }
    }

    /**
     * Records the rejection under the criterion just used. Calls that made no move (too-small tree
     * or delay period) are not attributed to any criterion.
     */
    @Override
    public void reject(int reason) {
        super.reject(reason);
        if (this.lastcriterion == NO_CRITERION) {
            return;
        }
        ++this.rejectCounts[this.lastcriterion];
    }

    /**
     * Writes the superclass state, then (only in debug mode) appends per-criterion statistics.
     * NOTE(review): the appended text follows the superclass output on the same stream; confirm this
     * does not interfere with restoring operator state from the file.
     */
    @Override
    public void storeToFile(PrintWriter out) {
        super.storeToFile(out);
        if (!this.debugFlag) {
            return;
        }
        out.print("{id:\"" + this.getID() + "\"");
        for (int ch = 0; ch < CRITERIA.length; ++ch) {
            out.print(" " + CRITERIA[ch].toString() + ", accept:" + this.acceptCounts[ch] + ", reject:" + this.rejectCounts[ch] + ", downTotal:" + this.downTotals[ch] + ", upTotals:" + this.upTotals[ch]);
        }
        out.print("}");
    }

    /** Marks all trees as being edited, picks a gene-tree node criterion at random, and performs the nudge. */
    private double doNodesNudgeMove() {
        this.sTree.startEditing(this);
        for (Tree gTree : this.gTrees) {
            gTree.startEditing(this);
        }
        assert (CRITERIA.length == GTREE_NODE_CRITERION_WEIGHTS.length);
        int choice = Randomizer.randomChoicePDF(GTREE_NODE_CRITERION_WEIGHTS);
        this.lastcriterion = choice;
        return this.doNormalNodesNudgeMove(CRITERIA[choice]);
    }

    /**
     * Chooses an SMC-tree node and collects the gene-tree nodes selected by {@code criterion}. It then
     * draws a height change dh uniformly from the largest interval that keeps every moved node between
     * its unmoved neighbours, and applies dh to all of them.
     *
     * @return the log Hastings ratio accumulated from the individual nudges
     */
    private double doNormalNodesNudgeMove(GtreeNodeCriterion criterion) {
        OpNNinfoSMCNode nani = this.chooseNodeForNudge();
        ArrayList<ArrayList<Node>> geneNodesToNudge = new ArrayList<>(this.gTrees.size());
        for (int j = 0; j < this.gTrees.size(); ++j) {
            geneNodesToNudge.add(this.nodesToNudgeInGtree(j, nani, criterion));
        }
        double[] interval = this.getMaxNodesNudgeIntervalForGeneTrees(geneNodesToNudge);
        // The SMC node itself must stay above both children...
        interval[0] = Math.max(interval[0], nani.getLftHeight() - nani.getNodeHeight());
        interval[0] = Math.max(interval[0], nani.getRgtHeight() - nani.getNodeHeight());
        // ...and below its parent. chooseNodeForNudge() never selects the root, so the
        // infinite-ancestor branch is only a fallback.
        interval[1] = nani.getAncHeight() == Double.POSITIVE_INFINITY ? Math.min(interval[1], 0.1) : Math.min(interval[1], nani.getAncHeight() - nani.getNodeHeight());
        assert (interval[0] <= 0.0);
        assert (interval[1] >= 0.0);
        double dh = Randomizer.uniform(interval[0], interval[1]);
        double logHR = this.doSTreeNodeNudge(nani.getNodeNr(), dh);
        this.lastheightchange = dh;
        logHR += this.doNodesNudgeOnGeneTrees(dh, geneNodesToNudge);
        return logHR;
    }

    /** Intersects the per-gene-tree nudge intervals; returns {lowest dh, highest dh}. */
    private double[] getMaxNodesNudgeIntervalForGeneTrees(ArrayList<ArrayList<Node>> toNudge) {
        double[] interval = new double[]{-Double.MAX_VALUE, Double.MAX_VALUE};
        for (int pgt = 0; pgt < this.gTrees.size(); ++pgt) {
            double[] pgtint = this.getMaxNodesNudgeIntervalWithinOneGeneTree(toNudge.get(pgt));
            interval[0] = Math.max(interval[0], pgtint[0]);
            interval[1] = Math.min(interval[1], pgtint[1]);
        }
        return interval;
    }

    /**
     * Computes {lowest dh, highest dh} such that shifting every node in {@code nodes} by dh keeps each
     * moved node above any child that is not moved with it, and keeps the top node of each connected
     * group below its (unmoved) parent. An empty list gives {-Double.MAX_VALUE, Double.MAX_VALUE}.
     */
    private double[] getMaxNodesNudgeIntervalWithinOneGeneTree(ArrayList<Node> nodes) {
        ArrayList<ArrayList<Node>> conncpts = this.findConnectedComponents(nodes);
        double[] interval = new double[]{-Double.MAX_VALUE, Double.MAX_VALUE};
        for (ArrayList<Node> cpt : conncpts) {
            Set<Node> members = new HashSet<>(cpt);
            // findConnectedComponentFrom() puts the component's top node first.
            Node root = cpt.get(0);
            Node anc = root.getParent();
            if (anc != null) {
                interval[1] = Math.min(interval[1], anc.getHeight() - root.getHeight());
            }
            for (Node node : cpt) {
                double nhgt = node.getHeight();
                Node ch0 = node.getChild(0);
                Node ch1 = node.getChild(1);
                if (!members.contains(ch0)) {
                    interval[0] = Math.max(interval[0], ch0.getHeight() - nhgt);
                }
                if (!members.contains(ch1)) {
                    interval[0] = Math.max(interval[0], ch1.getHeight() - nhgt);
                }
            }
        }
        return interval;
    }

    /**
     * Partitions {@code nodes} into maximal groups connected by parent-child edges. Each group lists its
     * topmost node first, followed by the other members in depth-first order (child 0 before child 1).
     */
    private ArrayList<ArrayList<Node>> findConnectedComponents(ArrayList<Node> nodes) {
        Set<Node> members = new HashSet<>(nodes);
        Set<Node> assigned = new HashSet<>(nodes.size());
        ArrayList<ArrayList<Node>> conncpts = new ArrayList<>();
        for (Node start : nodes) {
            if (assigned.contains(start)) {
                continue;
            }
            // Climb to the highest ancestor still in the set: that is this group's top node.
            Node top = start;
            while (top.getParent() != null && members.contains(top.getParent())) {
                top = top.getParent();
            }
            ArrayList<Node> cpt = this.findConnectedComponentFrom(top, members);
            conncpts.add(cpt);
            assigned.addAll(cpt);
        }
        return conncpts;
    }

    /** Collects {@code x} and, recursively, every descendant reachable from it through nodes in {@code members}. */
    private ArrayList<Node> findConnectedComponentFrom(Node x, Set<Node> members) {
        ArrayList<Node> cpt = new ArrayList<>();
        cpt.add(x);
        if (x.getChildCount() > 0) {
            Node ch0 = x.getChild(0);
            Node ch1 = x.getChild(1);
            if (ch0 != null && members.contains(ch0)) {
                cpt.addAll(this.findConnectedComponentFrom(ch0, members));
            }
            if (ch1 != null && members.contains(ch1)) {
                cpt.addAll(this.findConnectedComponentFrom(ch1, members));
            }
        }
        return cpt;
    }

    /** Shifts the selected nodes of every gene tree by {@code dh}; returns the summed log Hastings ratio. */
    private double doNodesNudgeOnGeneTrees(double dh, ArrayList<ArrayList<Node>> toNudge) {
        double logHR = 0.0;
        for (int j = 0; j < this.gTrees.size(); ++j) {
            logHR += this.nudgeNodesInListBy(toNudge.get(j), dh);
        }
        return logHR;
    }

    /** Adds {@code dh} to the height of every node in the list; contributes 0 to the log Hastings ratio. */
    private double nudgeNodesInListBy(ArrayList<Node> nodes, double dh) {
        for (Node node : nodes) {
            node.setHeight(node.getHeight() + dh);
        }
        return 0.0;
    }

    /**
     * Picks a uniformly random internal, non-root SMC-tree node and snapshots its number, height and
     * union, its two children's heights and unions, and its parent's height.
     */
    public OpNNinfoSMCNode chooseNodeForNudge() {
        int s;
        Node sN;
        do {
            s = Randomizer.nextInt(this.sTree.getNodeCount());
            sN = this.sTree.getNode(s);
        } while (sN.isLeaf() || sN.isRoot());
        Node lftN = sN.getChild(0);
        Node rgtN = sN.getChild(1);
        Node ancN = sN.getParent();
        BitUnion sUnion = this.unionArrays.sNodeUnion(s);
        BitUnion lftUnion = this.unionArrays.sNodeUnion(lftN.getNr());
        BitUnion rgtUnion = this.unionArrays.sNodeUnion(rgtN.getNr());
        double anchgt = ancN != null ? ancN.getHeight() : Double.POSITIVE_INFINITY;
        return new OpNNinfoSMCNode(s, sN.getHeight(), sUnion, lftN.getHeight(), lftUnion, rgtN.getHeight(), rgtUnion, anchgt);
    }

    /** Shifts SMC-tree node {@code s} by {@code dh}; contributes 0 to the log Hastings ratio. */
    public double doSTreeNodeNudge(int s, double dh) {
        Node sN = this.sTree.getNode(s);
        sN.setHeight(sN.getHeight() + dh);
        return 0.0;
    }

    /**
     * Returns the internal nodes of gene tree {@code j} that should move with the chosen SMC node.
     * These are the nodes whose union is contained in the SMC node's union and that satisfy
     * {@code criterion} with respect to the SMC node's left and right child unions.
     */
    public ArrayList<Node> nodesToNudgeInGtree(int j, OpNNinfoSMCNode nani, GtreeNodeCriterion criterion) {
        TreeInterface gTree = (TreeInterface)this.gTrees.get(j);
        BitUnion spp = nani.getNodeUnion();
        BitUnion sppl = nani.getLftUnion();
        BitUnion sppr = nani.getRgtUnion();
        ArrayList<Node> nodes = new ArrayList<Node>(0);
        for (int i = 0; i < gTree.getNodeCount(); ++i) {
            Node iN = gTree.getNode(i);
            if (iN.isLeaf() || !this.unionArrays.gNodeUnion(j, i).isContainedIn(spp)) {
                continue;
            }
            boolean mixed_with_pure_children = this.mixedWithPureChildren(j, i, sppl, sppr);
            boolean mixed_with_a_pure_child = this.mixedWithAPureChild(j, i, sppl, sppr);
            boolean mixed = this.mixedNode(j, i, sppl, sppr);
            // The criteria are nested: pure children => a pure child => mixed.
            assert (mixed || !mixed_with_pure_children);
            assert (mixed || !mixed_with_a_pure_child);
            assert (mixed_with_a_pure_child || !mixed_with_pure_children);
            if (this.nodeWanted(criterion, mixed_with_pure_children, mixed_with_a_pure_child, mixed)) {
                nodes.add(iN);
            }
        }
        return nodes;
    }

    /** True if one child's union is contained in sppL and the other child's union in sppR. */
    private boolean mixedWithPureChildren(int j, int i, BitUnion sppL, BitUnion sppR) {
        TreeInterface gTree = (TreeInterface)this.gTrees.get(j);
        Node gNode = gTree.getNode(i);
        BitUnion lftUnion = this.unionArrays.gNodeUnion(j, gNode.getChild(0).getNr());
        BitUnion rgtUnion = this.unionArrays.gNodeUnion(j, gNode.getChild(1).getNr());
        boolean lftInL = lftUnion.isContainedIn(sppL);
        boolean rgtInL = rgtUnion.isContainedIn(sppL);
        boolean lftInR = lftUnion.isContainedIn(sppR);
        boolean rgtInR = rgtUnion.isContainedIn(sppR);
        return lftInL && rgtInR || lftInR && rgtInL;
    }

    /** True if the node is mixed (see mixedNode) and at least one child's union lies within sppL or sppR. */
    private boolean mixedWithAPureChild(int j, int i, BitUnion sppL, BitUnion sppR) {
        if (!this.mixedNode(j, i, sppL, sppR)) {
            return false;
        }
        TreeInterface gTree = (TreeInterface)this.gTrees.get(j);
        Node gNode = gTree.getNode(i);
        BitUnion lftUnion = this.unionArrays.gNodeUnion(j, gNode.getChild(0).getNr());
        BitUnion rgtUnion = this.unionArrays.gNodeUnion(j, gNode.getChild(1).getNr());
        return lftUnion.isContainedIn(sppL) || rgtUnion.isContainedIn(sppR) || lftUnion.isContainedIn(sppR) || rgtUnion.isContainedIn(sppL);
    }

    /** True if the gene node's union overlaps both sppL and sppR. */
    private boolean mixedNode(int j, int i, BitUnion sppL, BitUnion sppR) {
        BitUnion gunion = this.unionArrays.gNodeUnion(j, i);
        return gunion.overlaps(sppL) && gunion.overlaps(sppR);
    }

    /** Returns the flag that corresponds to {@code criterion}. */
    private boolean nodeWanted(GtreeNodeCriterion criterion, boolean mixed_with_pure_children, boolean mixed_with_a_pure_child, boolean mixed) {
        switch (criterion) {
            case MIXED_WITH_PURE_CHILDREN:
                return mixed_with_pure_children;
            case MIXED_WITH_A_PURE_CHILD:
                return mixed_with_a_pure_child;
            case MIXED:
                return mixed;
            default:
                assert (false);
                return false;
        }
    }

    /** Immutable snapshot of the chosen SMC-tree node and its immediate neighbourhood. */
    private static final class OpNNinfoSMCNode {
        private final int nodeNr;
        private final double nodeHeight;
        private final double lftHeight;
        private final BitUnion nodeUnion;
        private final BitUnion lftUnion;
        private final double rgtHeight;
        private final BitUnion rgtUnion;
        /** Parent height, or Double.POSITIVE_INFINITY when the node has no parent. */
        private final double ancHeight;

        public OpNNinfoSMCNode(int nodeNr, double nodeHeight, BitUnion nodeUnion, double lftHeight, BitUnion lftUnion, double rgtHeight, BitUnion rgtUnion, double ancHeight) {
            this.nodeNr = nodeNr;
            this.nodeHeight = nodeHeight;
            this.nodeUnion = nodeUnion;
            this.lftHeight = lftHeight;
            this.lftUnion = lftUnion;
            this.rgtHeight = rgtHeight;
            this.rgtUnion = rgtUnion;
            this.ancHeight = ancHeight;
        }

        public int getNodeNr() {
            return this.nodeNr;
        }

        public double getNodeHeight() {
            return this.nodeHeight;
        }

        public BitUnion getNodeUnion() {
            return this.nodeUnion;
        }

        public double getLftHeight() {
            return this.lftHeight;
        }

        public BitUnion getLftUnion() {
            return this.lftUnion;
        }

        public double getRgtHeight() {
            return this.rgtHeight;
        }

        public BitUnion getRgtUnion() {
            return this.rgtUnion;
        }

        public double getAncHeight() {
            return this.ancHeight;
        }
    }

    /**
     * Rules for which gene-tree nodes (already restricted to the chosen SMC node's union) are nudged:
     * MIXED_WITH_PURE_CHILDREN - one child within the left union and the other within the right union;
     * MIXED_WITH_A_PURE_CHILD - overlaps both sides and at least one child lies within one side;
     * MIXED - overlaps both the left and the right union.
     */
    public static enum GtreeNodeCriterion {
        MIXED_WITH_PURE_CHILDREN,
        MIXED_WITH_A_PURE_CHILD,
        MIXED;

    }
}

