/*
 * Decompiled with CFR 0.152.
 */
package stacey;

import beast.core.Description;
import beast.core.Input;
import beast.core.Operator;
import beast.evolution.tree.Node;
import beast.evolution.tree.Tree;
import beast.evolution.tree.TreeInterface;
import beast.util.Randomizer;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import stacey.debugtune.Checks;
import stacey.util.Bindings;
import stacey.util.BitUnion;
import stacey.util.UnionArrays;

/**
 * Coordinated subtree-prune-regraft (SPR) operator.
 *
 * <p>Picks a node {@code s} of the smcTree (species / minimal clusters tree) whose parent is not
 * the root. It prunes the subtree rooted at {@code s} together with its parent and regrafts the
 * parent, at its unchanged height, onto another branch crossing that height. For every gene tree
 * it then applies SPR moves to the gene-tree nodes (at or below the height of the mrca of the
 * source and destination) that join a lineage inside the moved species clade with a lineage
 * outside it.
 *
 * <p>The returned log Hastings ratio is the sum of {@code log(choiceCount)} over the forward
 * gene-tree moves, minus the same sum computed for the reverse smcTree move.
 */
@Description(value="A move which does a subtree-prune-regraft move on a smcTree node,then carries a coordinated set of similar moves on the gene trees to maintain compatibility.")
public class CoordinatedPruneRegraft
extends Operator {
    public Input<Tree> smcTreeInput = new Input<Tree>("smcTree", "The species tree or minimal clusters tree", Input.Validate.REQUIRED);
    public Input<List<Tree>> geneTreesInput = new Input<List<Tree>>("geneTree", "All gene trees", new ArrayList<Tree>());
    public Input<Long> delayInput = new Input<Long>("delay", "Number of times the operator is disabled.");

    /** Branch counts at or above this value are not tallied in the debug accept/reject histograms. */
    private static final int DEBUG_MAX_BRANCHES = 100;
    /** Upper bound on the number of debug compatibility checks (before and after moves combined). */
    private static final int MAX_NUMBER_OF_DEBUG_CHECKS = 100000;

    private Tree sTree;
    private List<Tree> gTrees;
    private long delay = 0L;
    /**
     * Number of proposals requested so far (when the tree is not too small). It is a {@code long}
     * so it cannot wrap negative; a negative count would make {@code callCount < delay} true again
     * and silently disable the operator.
     */
    private long callCount = 0L;
    private boolean sTreeTooSmall;
    private Bindings bindings;
    private UnionArrays unionArrays;
    /** True iff the system property {@code stacey.debug} equals "true" (case-insensitive). */
    private final boolean debugFlag = Boolean.getBoolean("stacey.debug");
    private int numberOfDebugChecks = 0;
    private int[] debugRejectCounts;
    private int[] debugAcceptCounts;
    /** Number of smcTree branches between source and destination of the most recent proposal. */
    private int debugLastNofBranches;

    /** Orders gene-tree SPR specifications by height, highest first. */
    private static final Comparator<OpCPRinfoGTreeSpecification> GTREESPRSPEC_ORDER = new Comparator<OpCPRinfoGTreeSpecification>(){

        @Override
        public int compare(OpCPRinfoGTreeSpecification a, OpCPRinfoGTreeSpecification b) {
            return Double.compare(b.height, a.height);
        }
    };

    /**
     * Reads the inputs and builds the species/gene bindings and union arrays. If debug mode is on,
     * it also allocates the debug histograms.
     */
    public void initAndValidate() {
        this.sTree = this.smcTreeInput.get();
        this.gTrees = this.geneTreesInput.get();
        // With fewer than three tips every non-root node is a child of the root, so rejectSubtree
        // would reject every candidate and chooseSubtreeForSPR would never terminate.
        this.sTreeTooSmall = this.sTree.getLeafNodeCount() < 3;
        if (this.delayInput != null && this.delayInput.get() != null) {
            this.delay = this.delayInput.get();
        }
        this.bindings = Bindings.initialise((TreeInterface)this.sTree, this.gTrees);
        this.unionArrays = UnionArrays.initialise((TreeInterface)this.sTree, this.gTrees, this.bindings);
        if (this.debugFlag) {
            this.debugRejectCounts = new int[DEBUG_MAX_BRANCHES + 1];
            this.debugAcceptCounts = new int[DEBUG_MAX_BRANCHES + 1];
        }
    }

    /**
     * Proposes one coordinated prune-regraft move.
     *
     * @return the log Hastings ratio. It is {@code Double.NEGATIVE_INFINITY} when the smcTree has
     *     fewer than three tips, or while the proposal count is still below {@code delay}.
     */
    public double proposal() {
        if (this.sTreeTooSmall) {
            return Double.NEGATIVE_INFINITY;
        }
        ++this.callCount;
        if (this.callCount < this.delay) {
            return Double.NEGATIVE_INFINITY;
        }
        this.debugCheckCompatibility("before move");
        this.unionArrays.update();
        double logHR = this.doCoordinatedPruneRegraft(false);
        this.unionArrays.reset();
        this.debugCheckCompatibility("after move");
        return logHR;
    }

    /**
     * In debug mode, re-reads the trees from the inputs and runs the project's tree/compatibility
     * checks. Only the first {@link #MAX_NUMBER_OF_DEBUG_CHECKS} calls do any work.
     *
     * @param when label passed through to the checker ("before move" / "after move")
     */
    private void debugCheckCompatibility(String when) {
        if (this.debugFlag && this.numberOfDebugChecks < MAX_NUMBER_OF_DEBUG_CHECKS) {
            this.sTree = this.smcTreeInput.get();
            this.gTrees = this.geneTreesInput.get();
            Checks.allTreesAndCompatibility((TreeInterface)this.sTree, this.gTrees, "CoordinatedPruneRegraft", when);
            ++this.numberOfDebugChecks;
        }
    }

    public void accept() {
        super.accept();
        this.recordDebugOutcome(this.debugAcceptCounts);
    }

    public void reject(int reason) {
        super.reject(reason);
        this.recordDebugOutcome(this.debugRejectCounts);
    }

    /**
     * In debug mode, increments the histogram bucket for the branch count of the last proposal.
     * Counts at or above {@link #DEBUG_MAX_BRANCHES} are ignored.
     */
    private void recordDebugOutcome(int[] counts) {
        if (this.debugFlag && this.debugLastNofBranches < DEBUG_MAX_BRANCHES) {
            ++counts[this.debugLastNofBranches];
        }
    }

    /**
     * Writes the superclass state. In debug mode it then appends the per-branch-count
     * accept/reject histograms.
     *
     * <p>NOTE(review): the debug text is printed directly after {@code super.storeToFile}'s output;
     * confirm it does not break reading the state file back in.
     */
    public void storeToFile(PrintWriter out) {
        super.storeToFile(out);
        if (this.debugFlag) {
            out.print("{id:\"" + this.getID() + "\"");
            for (int nb = 0; nb < DEBUG_MAX_BRANCHES; ++nb) {
                out.print(" " + nb + ", accept:" + this.debugAcceptCounts[nb] + ", reject:" + this.debugRejectCounts[nb]);
            }
            out.print("}");
        }
    }

    /**
     * Performs the smcTree SPR and the coordinated gene-tree SPRs.
     *
     * @param NNI if true, the move is restricted to an NNI-like regraft onto the sibling of the
     *     pruned node's parent
     * @return forward log choice counts minus reverse log choice counts
     */
    private double doCoordinatedPruneRegraft(boolean NNI) {
        this.sTree.startEditing(this);
        for (Tree gTree : this.gTrees) {
            gTree.startEditing(this);
        }
        int s = this.chooseSubtreeForSPR(NNI);
        int[] debugnb = new int[1];
        int d = this.chooseDestinationForSPR(s, NNI, debugnb);
        this.debugLastNofBranches = debugnb[0];
        int x = this.siblingOf(this.sTree.getNode(s)).getNr();
        // Copy the species unions of s and of its current sibling x before anything moves, so the
        // same clades can be looked up again after the union arrays are rebuilt.
        BitUnion sUnion = new BitUnion(this.bindings.smcTreeTipCount());
        BitUnion xUnion = new BitUnion(this.bindings.smcTreeTipCount());
        sUnion.replaceWith(this.unionArrays.sNodeUnion(s));
        xUnion.replaceWith(this.unionArrays.sNodeUnion(x));
        OpCPRinfoSMCTreeSpecification smi = this.makeSPRmoveInfo(s, d);
        this.doSPRmove(s, d);
        double logHR = this.doAllCoordinatedSPRMoves(smi);
        // Presumably this recomputes the unions for the modified trees (UnionArrays is opaque here).
        this.unionArrays.reset();
        this.unionArrays.update();
        // The reverse move prunes the same clade and regrafts it onto its former sibling's branch.
        int sNew = this.unionArrays.smcTreeNodeNrOfUnion(sUnion);
        int xNew = this.unionArrays.smcTreeNodeNrOfUnion(xUnion);
        OpCPRinfoSMCTreeSpecification revsmi = this.makeSPRmoveInfo(sNew, xNew);
        logHR -= this.HRForAllCoordinatedSPRMoves(revsmi);
        return logHR;
    }

    /** Applies the coordinated gene-tree moves to every gene tree; returns summed log choice counts. */
    private double doAllCoordinatedSPRMoves(OpCPRinfoSMCTreeSpecification sppSPRInfo) {
        double logHR = 0.0;
        for (int j = 0; j < this.gTrees.size(); ++j) {
            logHR += this.doCoordinatedSPRMoves(j, sppSPRInfo);
        }
        return logHR;
    }

    /**
     * Sums log choice counts over all gene trees for the given move, without applying any move.
     */
    private double HRForAllCoordinatedSPRMoves(OpCPRinfoSMCTreeSpecification sppSPRInfo) {
        double logHR = 0.0;
        for (int j = 0; j < this.gTrees.size(); ++j) {
            logHR += this.HRForCoordinatedSPRMoves(j, sppSPRInfo);
        }
        return logHR;
    }

    /**
     * Draws smcTree node numbers uniformly until one is accepted as an SPR (or NNI) source.
     *
     * @param NNI use the stricter NNI acceptance rule
     * @return the chosen node number
     */
    public int chooseSubtreeForSPR(boolean NNI) {
        assert (this.sTree.getNodeCount() > 3);
        int s;
        do {
            s = Randomizer.nextInt(this.sTree.getNodeCount());
        } while (NNI ? this.rejectSubtreeForNNI(s) : this.rejectSubtree(s));
        return s;
    }

    /** Rejects the root and children of the root, which have no grandparent to regraft under. */
    private boolean rejectSubtree(int s) {
        Node sN = this.sTree.getNode(s);
        return sN.isRoot() || sN.getParent().isRoot();
    }

    /**
     * Like {@link #rejectSubtree}. It also rejects {@code s} when the sibling of its parent is at
     * least as high as that parent.
     */
    private boolean rejectSubtreeForNNI(int s) {
        Node sN = this.sTree.getNode(s);
        if (sN.isRoot() || sN.getParent().isRoot()) {
            return true;
        }
        Node ancN = sN.getParent();
        Node sibancN = this.siblingOf(ancN);
        return sibancN.getHeight() >= ancN.getHeight();
    }

    /**
     * Chooses the smcTree node whose parent branch receives the pruned subtree.
     *
     * <p>In SPR mode the destination is drawn uniformly from the branches crossing the height of
     * {@code s}'s parent. The branches of {@code s}, its parent and its sibling are excluded.
     *
     * @param debugnb out-parameter; element 0 receives the number of branches between {@code s}
     *     and the chosen destination (3 in NNI mode)
     * @return the destination node number
     */
    private int chooseDestinationForSPR(int s, boolean NNI, int[] debugnb) {
        Node sN = this.sTree.getNode(s);
        Node ancN = sN.getParent();
        if (NNI) {
            debugnb[0] = 3;
            return this.siblingOf(ancN).getNr();
        }
        Node sibN = this.siblingOf(sN);
        double h = ancN.getHeight();
        int nodeCount = this.sTree.getNodeCount();
        List<Integer> dests = new ArrayList<Integer>(1 + nodeCount / 2);
        for (int d = 0; d < nodeCount; ++d) {
            if (d == s || d == ancN.getNr() || d == sibN.getNr()) {
                continue;
            }
            Node dN = this.sTree.getNode(d);
            // Keep d only if its parent branch spans height h (NaN-safe comparisons as before).
            if (!(dN.getHeight() <= h) || dN.isRoot() || !(dN.getParent().getHeight() >= h)) {
                continue;
            }
            dests.add(d);
        }
        int d = dests.get(Randomizer.nextInt(dests.size()));
        // Only the chosen destination's path length is needed (it feeds the debug histograms).
        int nb = this.debugNumberOfBranchesBetween(s, d);
        assert (nb >= 3);
        debugnb[0] = nb;
        return d;
    }

    /**
     * Describes an SPR of smcTree node {@code s} onto the branch above {@code d}.
     *
     * <p>It records the species union of {@code s}. For every node on the path from {@code d} up
     * to (but excluding) mrca(s, d), it also records that node's species union and its parent's
     * height, ordered bottom-up.
     */
    private OpCPRinfoSMCTreeSpecification makeSPRmoveInfo(int s, int d) {
        BitUnion sppS = this.unionArrays.sNodeUnion(s);
        int m = this.mrcaOfPair(s, d);
        List<Integer> ds = new ArrayList<Integer>();
        for (int node = d; node != m; node = this.sTree.getNode(node).getParent().getNr()) {
            ds.add(node);
        }
        BitUnion[] sppDs = new BitUnion[ds.size()];
        double[] heightsDancs = new double[ds.size()];
        for (int i = 0; i < ds.size(); ++i) {
            int di = ds.get(i);
            sppDs[i] = this.unionArrays.sNodeUnion(di);
            heightsDancs[i] = this.sTree.getNode(di).getParent().getHeight();
        }
        return new OpCPRinfoSMCTreeSpecification(sppS, sppDs, heightsDancs);
    }

    /** Applies the SPR to the smcTree, using node numbers. */
    private void doSPRmove(int s, int d) {
        this.doSPRmove(this.sTree.getNode(s), this.sTree.getNode(d));
    }

    /**
     * Returns the most recent common ancestor of two smcTree nodes. It repeatedly steps the
     * lower node to its parent until both meet.
     */
    private int mrcaOfPair(int x, int y) {
        while (x != y) {
            if (this.sTree.getNode(x).getHeight() < this.sTree.getNode(y).getHeight()) {
                x = this.sTree.getNode(x).getParent().getNr();
            } else {
                y = this.sTree.getNode(y).getParent().getNr();
            }
        }
        return x;
    }

    /** Returns the other child of {@code xN}'s parent; {@code xN} must not be the root. */
    private Node siblingOf(Node xN) {
        assert (!xN.isRoot());
        Node ancN = xN.getParent();
        return ancN.getChild(0) == xN ? ancN.getChild(1) : ancN.getChild(0);
    }

    /** Counts the branches on the path between two smcTree nodes (same climb as mrcaOfPair). */
    private int debugNumberOfBranchesBetween(int x, int y) {
        int n = 0;
        while (x != y) {
            if (this.sTree.getNode(x).getHeight() < this.sTree.getNode(y).getHeight()) {
                x = this.sTree.getNode(x).getParent().getNr();
            } else {
                y = this.sTree.getNode(y).getParent().getNr();
            }
            ++n;
        }
        return n;
    }

    /**
     * Builds the gene-tree moves required for gene tree {@code j}, sorts them highest first and
     * applies them.
     *
     * @return sum of log choice counts of the applied moves
     */
    private double doCoordinatedSPRMoves(int j, OpCPRinfoSMCTreeSpecification sppSPRInfo) {
        ArrayList<OpCPRinfoGTreeSpecification> gtsprss = this.makeListOfSPRSpecs(j, sppSPRInfo);
        Collections.sort(gtsprss, GTREESPRSPEC_ORDER);
        double logHR = 0.0;
        for (OpCPRinfoGTreeSpecification gtspr : gtsprss) {
            this.doSPRmove(gtspr.getSource(), gtspr.getDestination());
            logHR += Math.log(gtspr.getChoiceCount());
        }
        return logHR;
    }

    /**
     * Sums log choice counts of the moves gene tree {@code j} would need, without applying them.
     * Building the list still consumes Randomizer draws (via chooseDestinationAndConstructMove).
     */
    private double HRForCoordinatedSPRMoves(int j, OpCPRinfoSMCTreeSpecification sppSPRInfo) {
        double logHR = 0.0;
        for (OpCPRinfoGTreeSpecification gtspr : this.makeListOfSPRSpecs(j, sppSPRInfo)) {
            logHR += Math.log(gtspr.getChoiceCount());
        }
        return logHR;
    }

    /**
     * Builds one move for each qualifying internal node of gene tree {@code j}.
     *
     * <p>A node qualifies if its height is at most that of the top of the destination path, and
     * exactly one of its children has a species union contained in the pruned clade's union. That
     * child is the one moved.
     */
    private ArrayList<OpCPRinfoGTreeSpecification> makeListOfSPRSpecs(int j, OpCPRinfoSMCTreeSpecification sppSPRInfo) {
        TreeInterface gTree = (TreeInterface)this.gTrees.get(j);
        ArrayList<OpCPRinfoGTreeSpecification> gtsprss = new ArrayList<OpCPRinfoGTreeSpecification>(0);
        double[] heightsDancs = sppSPRInfo.getHeightsDancs();
        double hM = heightsDancs[heightsDancs.length - 1];
        BitUnion sppS = sppSPRInfo.getSppS();
        for (int a = 0; a < gTree.getNodeCount(); ++a) {
            Node aN = gTree.getNode(a);
            if (aN.isLeaf() || !(aN.getHeight() <= hM)) {
                continue;
            }
            int x = aN.getChild(0).getNr();
            int y = aN.getChild(1).getNr();
            boolean xinS = this.unionArrays.gNodeUnion(j, x).isContainedIn(sppS);
            boolean yinS = this.unionArrays.gNodeUnion(j, y).isContainedIn(sppS);
            if (xinS && !yinS) {
                gtsprss.add(this.chooseDestinationAndConstructMove(j, a, x, sppSPRInfo));
            }
            if (yinS && !xinS) {
                gtsprss.add(this.chooseDestinationAndConstructMove(j, a, y, sppSPRInfo));
            }
        }
        return gtsprss;
    }

    /**
     * Chooses, uniformly at random, where gene node {@code s} (child of {@code anc}) is regrafted.
     *
     * <p>Candidates are the non-root gene branches that span the height of {@code anc} and whose
     * species union is contained in the union of the first destination-path node whose parent is
     * at or above that height.
     */
    private OpCPRinfoGTreeSpecification chooseDestinationAndConstructMove(int j, int anc, int s, OpCPRinfoSMCTreeSpecification sppSPRInfo) {
        TreeInterface gTree = (TreeInterface)this.gTrees.get(j);
        Node source = gTree.getNode(s);
        double height = gTree.getNode(anc).getHeight();
        double[] heightsDancs = sppSPRInfo.getHeightsDancs();
        assert (height <= heightsDancs[heightsDancs.length - 1]);
        int d = 0;
        while (height > heightsDancs[d]) {
            ++d;
        }
        BitUnion sppD = sppSPRInfo.getSppDs()[d];
        ArrayList<Node> possibleDestinations = new ArrayList<Node>();
        for (int x = 0; x < gTree.getNodeCount(); ++x) {
            Node xN = gTree.getNode(x);
            // The root has no parent branch to regraft onto (and getParent() would be null).
            if (xN.isRoot() || !(xN.getHeight() <= height) || !(height <= xN.getParent().getHeight())) {
                continue;
            }
            if (this.unionArrays.gNodeUnion(j, x).isContainedIn(sppD)) {
                possibleDestinations.add(xN);
            }
        }
        int choiceCount = possibleDestinations.size();
        assert (choiceCount > 0);
        Node destination = possibleDestinations.get(Randomizer.nextInt(choiceCount));
        return new OpCPRinfoGTreeSpecification(source, height, choiceCount, destination);
    }

    /**
     * Detaches {@code s}'s parent {@code a} and reattaches it above {@code d}. The old sibling of
     * {@code s} takes {@code a}'s place under {@code a}'s former parent. Heights are not changed.
     */
    private void doSPRmove(Node s, Node d) {
        Node a = s.getParent();
        Node x = a.getChild(0) == s ? a.getChild(1) : a.getChild(0);
        Node y = a.getParent();
        Node z = d.getParent();
        z.removeChild(d);
        a.removeChild(x);
        y.removeChild(a);
        z.addChild(a);
        a.addChild(d);
        y.addChild(x);
    }

    /** Species-level description of an smcTree SPR (see makeSPRmoveInfo). */
    private static final class OpCPRinfoSMCTreeSpecification {
        private final BitUnion sppS;
        private final BitUnion[] sppDs;
        private final double[] heightsDancs;

        public OpCPRinfoSMCTreeSpecification(BitUnion sppS, BitUnion[] sppDs, double[] heightsDancs) {
            this.sppS = sppS;
            this.sppDs = sppDs;
            this.heightsDancs = heightsDancs;
        }

        public BitUnion getSppS() {
            return this.sppS;
        }

        public BitUnion[] getSppDs() {
            return this.sppDs;
        }

        public double[] getHeightsDancs() {
            return this.heightsDancs;
        }
    }

    /** One gene-tree SPR: source node, regraft height, number of candidates, chosen destination. */
    private static final class OpCPRinfoGTreeSpecification {
        private final Node source;
        private final double height;
        private final int choiceCount;
        private final Node destination;

        public OpCPRinfoGTreeSpecification(Node source, double height, int choiceCount, Node destination) {
            this.source = source;
            this.height = height;
            this.choiceCount = choiceCount;
            this.destination = destination;
        }

        public Node getSource() {
            return this.source;
        }

        public double getHeight() {
            return this.height;
        }

        public int getChoiceCount() {
            return this.choiceCount;
        }

        public Node getDestination() {
            return this.destination;
        }
    }
}

