/*
 * Decompiled with CFR 0.152.
 */
package stacey;

import beast.core.Description;
import beast.core.Input;
import beast.core.Operator;
import beast.core.parameter.RealParameter;
import beast.evolution.tree.Node;
import beast.evolution.tree.Tree;
import beast.evolution.tree.TreeInterface;
import beast.util.Randomizer;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import stacey.debugtune.Checks;
import stacey.util.Bindings;
import stacey.util.BitUnion;
import stacey.util.UnionArrays;

@Description(value="An implementation of the rubber-band operator from Yang and Rannala (2003).It changes the height of a SMC-tree node, and all gene tree nodes in the three adjacent branches.")
public class ThreeBranchAdjuster
extends Operator {
    public Input<Tree> smcTreeInput = new Input("smcTree", "The species tree or minimal clusters tree", Input.Validate.REQUIRED);
    public Input<List<Tree>> geneTreesInput = new Input("geneTree", "All gene trees", new ArrayList());
    public Input<RealParameter> popSFInput = new Input("popSF", "The population scaling factor for the STACEY coalescent", Input.Validate.REQUIRED);
    public Input<Double> tuningInput = new Input("tuning", "A tuning parameter. The default is 1.0, and allowed values are positive numbers. Larger values make bigger jumps and so decrease the acceptance ratio. Experiments in the range [0.1,10.0] seem sensible.");
    public Input<Long> delayInput = new Input("delay", "Number of times the operator is disabled.");
    private Tree sTree;
    private List<Tree> gTrees;
    private int callCount = 0;
    private boolean sTreeTooSmall;
    private long delay = 0L;
    private double tuning = 1.0;
    private UnionArrays unionArrays;
    private final boolean debugFlag = Boolean.valueOf(System.getProperty("stacey.debug"));
    private int numberofdebugchecks = 0;
    private static final int maxnumberofdebugchecks = 100000;
    private static final Comparator<Node> COALESCENCE_ORDER = new Comparator<Node>(){

        @Override
        public int compare(Node a, Node b) {
            return Double.compare(a.getHeight(), b.getHeight());
        }
    };

    public void initAndValidate() {
        this.sTree = (Tree)this.smcTreeInput.get();
        this.gTrees = (List)this.geneTreesInput.get();
        boolean bl = this.sTreeTooSmall = this.sTree.getLeafNodeCount() < 3;
        if (this.delayInput.get() != null && this.delayInput.get() != null) {
            this.delay = (Long)this.delayInput.get();
        }
        Bindings bindings = Bindings.initialise((TreeInterface)this.sTree, this.gTrees);
        this.unionArrays = UnionArrays.initialise((TreeInterface)this.sTree, this.gTrees, bindings);
    }

    public double proposal() {
        if (this.sTreeTooSmall) {
            return Double.NEGATIVE_INFINITY;
        }
        ++this.callCount;
        if ((long)this.callCount < this.delay) {
            return Double.NEGATIVE_INFINITY;
        }
        if (this.debugFlag && this.numberofdebugchecks < 100000) {
            Checks.allTreesAndCompatibility((TreeInterface)this.sTree, this.gTrees, "ThreeBranchAdjuster", "before move");
            ++this.numberofdebugchecks;
        }
        this.unionArrays.update();
        double logHR = this.doThreeBranchAdjust();
        this.unionArrays.reset();
        if (this.debugFlag && this.numberofdebugchecks < 100000) {
            Checks.allTreesAndCompatibility((TreeInterface)this.sTree, this.gTrees, "ThreeBranchAdjuster", "after move");
            ++this.numberofdebugchecks;
        }
        return logHR;
    }

    private double doThreeBranchAdjust() {
        double hgtRange;
        double newHeight;
        double sAncHeight;
        int sNodeNr;
        double logHR = 0.0;
        for (int j = 0; j < this.gTrees.size(); ++j) {
            this.gTrees.get(j).startEditing(null);
        }
        while (this.sTree.getNode(sNodeNr = Randomizer.nextInt((int)this.sTree.getNodeCount())).isLeaf()) {
        }
        Node sNode = this.sTree.getNode(sNodeNr);
        boolean isRoot = this.sTree.getNode(sNodeNr).isRoot();
        double sHeight = sNode.getHeight();
        if (isRoot) {
            double maxGtreeHeight = Double.NEGATIVE_INFINITY;
            for (int j = 0; j < this.gTrees.size(); ++j) {
                Tree gTree = this.gTrees.get(j);
                maxGtreeHeight = Math.max(maxGtreeHeight, gTree.getRoot().getHeight());
            }
            sAncHeight = maxGtreeHeight;
        } else {
            sAncHeight = sNode.getParent().getHeight();
        }
        double sLftHeight = sNode.getChild(0).getHeight();
        double sRgtHeight = sNode.getChild(1).getHeight();
        BitUnion sUnion = this.unionArrays.sNodeUnion(sNodeNr);
        BitUnion sLftUnion = this.unionArrays.sNodeUnion(sNode.getChild(0).getNr());
        BitUnion sRgtUnion = this.unionArrays.sNodeUnion(sNode.getChild(1).getNr());
        ArrayList<Node> gNodes = new ArrayList<Node>();
        ArrayList<Node> gLftNodes = new ArrayList<Node>();
        ArrayList<Node> gRgtNodes = new ArrayList<Node>();
        for (int j = 0; j < this.gTrees.size(); ++j) {
            Tree gTree = this.gTrees.get(j);
            for (int i = gTree.getLeafNodeCount(); i < gTree.getNodeCount(); ++i) {
                double yHeight = gTree.getNode(i).getHeight();
                BitUnion yUnion = this.unionArrays.gNodeUnion(j, i);
                int timesAdded = 0;
                if (yUnion.isContainedIn(sUnion) && yHeight < sAncHeight && yHeight >= sHeight) {
                    gNodes.add(gTree.getNode(i));
                    ++timesAdded;
                }
                if (yUnion.isContainedIn(sLftUnion) && yHeight < sHeight && yHeight > sLftHeight) {
                    gLftNodes.add(gTree.getNode(i));
                    ++timesAdded;
                }
                if (yUnion.isContainedIn(sRgtUnion) && yHeight < sHeight && yHeight > sRgtHeight) {
                    gRgtNodes.add(gTree.getNode(i));
                    ++timesAdded;
                }
                assert (timesAdded <= 1);
            }
        }
        double minSH = Math.max(sLftHeight, sRgtHeight);
        double maxSH = sAncHeight;
        double popSF = ((RealParameter)this.popSFInput.get()).getValue();
        if (this.tuningInput != null && this.tuningInput.get() != null) {
            this.tuning = (Double)this.tuningInput.get();
        }
        if ((newHeight = sHeight + Randomizer.uniform((double)(-(hgtRange = Math.min(this.tuning * 40.0 * popSF / (double)this.gTrees.size(), 0.5 * (maxSH - minSH)))), (double)hgtRange)) < minSH) {
            newHeight = 2.0 * minSH - newHeight;
        } else if (newHeight > maxSH) {
            newHeight = 2.0 * maxSH - newHeight;
        }
        sNode.setHeight(newHeight);
        logHR += this.setNewGNodeHeightsFromSHeight(gNodes, sHeight, newHeight, sAncHeight);
        logHR += this.setNewGNodeHeightsFromSHeight(gLftNodes, sHeight, newHeight, sLftHeight);
        return logHR += this.setNewGNodeHeightsFromSHeight(gRgtNodes, sHeight, newHeight, sRgtHeight);
    }

    private double setNewGNodeHeightsFromSHeight(ArrayList<Node> nodes, double oldSH, double newSH, double limit) {
        double logHR = 0.0;
        for (int i = 0; i < nodes.size(); ++i) {
            Node node = nodes.get(i);
            double oldGH = node.getHeight();
            double scale = (newSH - limit) / (oldSH - limit);
            double newGH = limit + (oldGH - limit) * scale;
            node.setHeight(newGH);
            logHR += Math.log(scale);
        }
        return logHR;
    }

    private void setNewGNodeHeightsFromWeights(ArrayList<Node> nodes, ArrayList<Double> weights, double eta) {
        for (int i = 0; i < nodes.size(); ++i) {
            Node node = nodes.get(i);
            double wt = weights.get(i + 1);
            node.setHeight(node.getHeight() + wt * eta);
        }
    }

    private void updateBoundsFromArrays(double[] interval, double firstH, ArrayList<Node> nodes, double lastH, ArrayList<Double> weights) {
        int n = nodes.size();
        if (n > 0) {
            this.updateBounds(interval, firstH, nodes.get(0).getHeight(), weights.get(0), weights.get(1));
            int i = 0;
            while (i + 1 < n) {
                this.updateBounds(interval, nodes.get(i).getHeight(), nodes.get(i + 1).getHeight(), weights.get(i + 1), weights.get(i + 2));
                ++i;
            }
            this.updateBounds(interval, nodes.get(n - 1).getHeight(), lastH, weights.get(n), weights.get(n + 1));
        } else {
            this.updateBounds(interval, firstH, lastH, weights.get(0), weights.get(1));
        }
    }

    private void updateBounds(double[] interval, double h1, double h2, double w1, double w2) {
        double wDiff = w2 - w1;
        double hDiff = h1 - h2;
        if (wDiff > 0.0) {
            interval[0] = Math.max(interval[0], hDiff / wDiff);
        }
        if (wDiff < 0.0) {
            interval[1] = Math.min(interval[1], hDiff / wDiff);
        }
    }
}

