/*
 * Decompiled with CFR 0.152.
 */
package stacey;

import beast.core.Description;
import beast.core.Input;
import beast.core.Operator;
import beast.evolution.tree.Node;
import beast.evolution.tree.Tree;
import beast.evolution.tree.TreeInterface;
import beast.util.Randomizer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import stacey.debugtune.Checks;
import stacey.util.Bindings;
import stacey.util.BitUnion;
import stacey.util.UnionArrays;

/**
 * Scales internal node heights of the SMC-tree (species or minimal clusters tree) and of all
 * gene trees together, with a per-node weight that decreases with distance from a randomly
 * chosen "focal" SMC-tree node.
 *
 * <p>One proposal:
 * <ol>
 * <li>picks a focal node uniformly among the non-root internal SMC-tree nodes that have at
 *     least one internal child;</li>
 * <li>assigns every SMC-tree node its edge distance from the focal node, and every gene-tree
 *     node a distance to the nearest gene node linked to the focal node's two child clades;</li>
 * <li>maps distances to weights with {@code distanceToWeight}: 1 at the focal node, falling to
 *     0 at the tree root's distance and beyond;</li>
 * <li>draws one log scale factor {@code logSF} uniformly from an interval chosen so that every
 *     child stays below its parent and every gene node returned by
 *     {@code UnionArrays.getStraddlers} stays above the corresponding SMC-tree node;</li>
 * <li>multiplies the height of each non-root internal node with positive weight {@code w} by
 *     {@code exp(logSF * w)} and returns the sum of the {@code logSF * w} terms as the log
 *     Hastings ratio.</li>
 * </ol>
 */
@Description(value="A move which scales some node heights in the SMC-tree and gene trees. The further away a node is from a 'focal' SMC-tree node, the less it is affected.")
public class FocusedNodeHeightScaler
extends Operator {
    public Input<Tree> smcTreeInput = new Input<>("smcTree", "The species tree or minimal clusters tree", Input.Validate.REQUIRED);
    public Input<List<Tree>> geneTreesInput = new Input<>("geneTree", "All gene trees", new ArrayList<>());
    public Input<Long> delayInput = new Input<>("delay", "Number of times the operator is disabled.");

    /** Upper bound on the number of debug consistency checks run when stacey.debug is set. */
    private static final int MAX_NUMBER_OF_DEBUG_CHECKS = 100000;
    /**
     * Smallest SMC-tree leaf count for which the move is attempted. With fewer leaves there
     * may be no node eligible as focus (e.g. a balanced 4-leaf tree), and the rejection loop
     * in {@code setUpForFSMove} would never terminate.
     */
    private static final int MIN_LEAF_COUNT = 5;
    /** Base of the exponential used to map distances to weights. */
    private static final double WEIGHT_BASE = 8.0;
    /** Shrinks the admissible interval slightly so the proposal never reaches a boundary. */
    private static final double NEARLY_ONE = 0.99999999;

    private Tree sTree;
    private List<Tree> gTrees;
    /** Number of proposals to reject before the move becomes active. */
    private long delay = 0L;
    /** Number of proposals rejected so far because of {@code delay}; stops at {@code delay}. */
    private long callCount = 0L;
    private boolean sTreeTooSmall;
    private UnionArrays unionArrays;
    // Per-node work arrays, indexed by node number and reused across proposals.
    private double[] sTreeLogHeights;
    private int[] sTreeDistances;
    private double[] sTreeWeights;
    private double[][] gTreeLogHeights;
    private int[][] gTreeDistances;
    private double[][] gTreeWeights;
    private final boolean debugFlag = Boolean.getBoolean("stacey.debug");
    private int numberofdebugchecks = 0;

    @Override
    public void initAndValidate() {
        sTree = smcTreeInput.get();
        gTrees = geneTreesInput.get();
        sTreeTooSmall = sTree.getLeafNodeCount() < MIN_LEAF_COUNT;
        if (delayInput != null && delayInput.get() != null) {
            delay = delayInput.get();
        }
        Bindings bindings = Bindings.initialise((TreeInterface)sTree, gTrees);
        unionArrays = UnionArrays.initialise((TreeInterface)sTree, gTrees, bindings);

        int sNodeCount = sTree.getNodeCount();
        sTreeLogHeights = new double[sNodeCount];
        sTreeDistances = new int[sNodeCount];
        sTreeWeights = new double[sNodeCount];

        int geneTreeCount = gTrees.size();
        gTreeLogHeights = new double[geneTreeCount][];
        gTreeDistances = new int[geneTreeCount][];
        gTreeWeights = new double[geneTreeCount][];
        for (int j = 0; j < geneTreeCount; ++j) {
            int gNodeCount = gTrees.get(j).getNodeCount();
            gTreeLogHeights[j] = new double[gNodeCount];
            gTreeDistances[j] = new int[gNodeCount];
            gTreeWeights[j] = new double[gNodeCount];
        }
    }

    /**
     * Performs one focused scaling move.
     *
     * @return the log Hastings ratio, or negative infinity when the SMC-tree is too small or
     *         the operator is still within its {@code delay} period
     */
    @Override
    public double proposal() {
        if (sTreeTooSmall) {
            return Double.NEGATIVE_INFINITY;
        }
        // Reject exactly `delay` calls, as documented by the input; stop counting afterwards
        // so the counter cannot wrap.
        if (callCount < delay) {
            ++callCount;
            return Double.NEGATIVE_INFINITY;
        }
        runDebugCheck("before move");
        unionArrays.update();
        double logHR = doFocusedScalerMove();
        unionArrays.reset();
        runDebugCheck("after move");
        return logHR;
    }

    /** Validates all trees when stacey.debug is set, up to MAX_NUMBER_OF_DEBUG_CHECKS times. */
    private void runDebugCheck(String when) {
        if (debugFlag && numberofdebugchecks < MAX_NUMBER_OF_DEBUG_CHECKS) {
            Checks.allTreesAndCompatibility((TreeInterface)sTree, gTrees, "FocusedNodeHeightScaler", when);
            ++numberofdebugchecks;
        }
    }

    /**
     * Chooses a focus, computes distances/weights/bounds for every tree, draws the log scale
     * factor and applies it.
     *
     * @return the log Hastings ratio (sum of {@code logSF * weight} over scaled nodes)
     */
    private double doFocusedScalerMove() {
        sTree.startEditing(this);
        for (Tree gTree : gTrees) {
            gTree.startEditing(this);
        }

        OpFSinfoSMCNodeUnions fsNUnions = setUpForFSMove(sTreeLogHeights, sTreeDistances);
        fillInWeights(sTreeDistances, sTree.getRoot().getNr(), sTreeWeights);
        for (int j = 0; j < gTrees.size(); ++j) {
            fillInGTreeDistances(j, fsNUnions, gTreeDistances[j]);
            fillInAnyTreeLogHeights(gTrees.get(j), gTreeLogHeights[j]);
            fillInWeights(gTreeDistances[j], gTrees.get(j).getRoot().getNr(), gTreeWeights[j]);
        }

        // Intersect the admissible log-scale-factor intervals of all constraints.
        double[] interval = {Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY};
        getAnyTreeInternalFSBounds(interval, sTree, sTreeLogHeights, sTreeWeights);
        OpFSinfoSMCNode[] fsNodeInfos = getAllInternalFSMoveSNInfos(sTreeLogHeights, sTreeWeights);
        for (int j = 0; j < gTrees.size(); ++j) {
            getGTreeFSBounds(interval, j, fsNodeInfos, gTreeLogHeights[j], gTreeWeights[j]);
        }
        assert (!Double.isInfinite(interval[0]));
        assert (!Double.isInfinite(interval[1]));
        assert (interval[0] < 0.0);
        assert (interval[1] > 0.0);

        double logSF = Randomizer.uniform(NEARLY_ONE * interval[0], NEARLY_ONE * interval[1]);
        double logHR = fsMoveDoScale(sTree, logSF, sTreeWeights);
        for (int j = 0; j < gTrees.size(); ++j) {
            logHR += fsMoveDoScale(gTrees.get(j), logSF, gTreeWeights[j]);
        }
        return logHR;
    }

    /** Converts node distances to weights, relative to the distance of the tree's root. */
    private void fillInWeights(int[] distances, int rootNr, double[] weights) {
        int rootDistance = distances[rootNr];
        for (int n = 0; n < distances.length; ++n) {
            weights[n] = distanceToWeight(distances[n], rootDistance);
        }
    }

    /**
     * Fills SMC-tree log-heights, picks the focal node, and fills SMC-tree distances from it.
     *
     * @return the clade unions of the focal node and of its two children
     */
    private OpFSinfoSMCNodeUnions setUpForFSMove(double[] logHeights, int[] distances) {
        fillInAnyTreeLogHeights(sTree, logHeights);
        // Rejection sampling: uniform over eligible nodes. Termination relies on MIN_LEAF_COUNT.
        int s;
        do {
            s = Randomizer.nextInt(sTree.getNodeCount());
        } while (!nodeCanBeFocus(s));
        Node sN = sTree.getNode(s);
        BitUnion sunion = unionArrays.sNodeUnion(s);
        BitUnion lftunion = unionArrays.sNodeUnion(sN.getChild(0).getNr());
        BitUnion rgtunion = unionArrays.sNodeUnion(sN.getChild(1).getNr());
        fillInSTreeDistances(sN, distances);
        return new OpFSinfoSMCNodeUnions(sunion, lftunion, rgtunion);
    }

    /** A node can be the focus if it is internal, not the root, and has an internal child. */
    private boolean nodeCanBeFocus(int s) {
        Node node = sTree.getNode(s);
        if (node.isLeaf() || node.isRoot()) {
            return false;
        }
        return !node.getChild(0).isLeaf() || !node.getChild(1).isLeaf();
    }

    /** Fills {@code distances} with the number of SMC-tree edges between each node and {@code node}. */
    private void fillInSTreeDistances(Node node, int[] distances) {
        assert (sTree.getNodeCount() == distances.length);
        Arrays.fill(distances, Integer.MAX_VALUE);
        distances[node.getNr()] = 0;
        // Upward pass from the focus to the root...
        while (!node.isRoot()) {
            Node ancN = node.getParent();
            int anc = ancN.getNr();
            distances[anc] = Math.min(distances[anc], distances[node.getNr()] + 1);
            node = ancN;
        }
        // ...then a downward pass reaches every other node.
        fillInAnySubtreeDistances(sTree.getRoot(), distances);
    }

    /** Collects clade unions, log-height and weight of every internal SMC-tree node. */
    private OpFSinfoSMCNode[] getAllInternalFSMoveSNInfos(double[] logHeights, double[] weights) {
        OpFSinfoSMCNode[] fsNInfos = new OpFSinfoSMCNode[sTree.getInternalNodeCount()];
        int i = 0;
        for (int s = 0; s < sTree.getNodeCount(); ++s) {
            Node node = sTree.getNode(s);
            if (node.isLeaf()) {
                continue;
            }
            assert (i < fsNInfos.length);
            BitUnion lftunion = unionArrays.sNodeUnion(node.getChild(0).getNr());
            BitUnion rgtunion = unionArrays.sNodeUnion(node.getChild(1).getNr());
            fsNInfos[i++] = new OpFSinfoSMCNode(lftunion, rgtunion, logHeights[s], weights[s]);
        }
        assert (i == fsNInfos.length);
        return fsNInfos;
    }

    /** Narrows {@code interval} with gene tree j's own parent/child and SMC-compatibility constraints. */
    private void getGTreeFSBounds(double[] interval, int j, OpFSinfoSMCNode[] fsNodeInfos, double[] logHeights, double[] weights) {
        getAnyTreeInternalFSBounds(interval, gTrees.get(j), logHeights, weights);
        for (OpFSinfoSMCNode fsNodeInfo : fsNodeInfos) {
            getGTreeFocusedScalerCompatibilityBounds(interval, j, fsNodeInfo, logHeights, weights);
        }
    }

    /**
     * Fills gene tree j's distances: 1 for nodes closely linked to the focus, 2 for nodes
     * linked to it, otherwise the edge distance to the nearest such node
     * ({@code Integer.MAX_VALUE} if there is none).
     */
    private void fillInGTreeDistances(int j, OpFSinfoSMCNodeUnions fsNUnions, int[] distances) {
        TreeInterface gTree = gTrees.get(j);
        int nodeCount = gTree.getNodeCount();
        assert (nodeCount == distances.length);
        BitUnion spp = fsNUnions.getNodeUnion();
        BitUnion sppl = fsNUnions.getLftUnion();
        BitUnion sppr = fsNUnions.getRgtUnion();
        for (int i = 0; i < nodeCount; ++i) {
            if (gNodeSNodeCloselyLinked(j, i, spp, sppl, sppr)) {
                distances[i] = 1;
            } else if (gNodeSNodeLinked(j, i, sppl, sppr)) {
                distances[i] = 2;
            } else {
                distances[i] = Integer.MAX_VALUE;
            }
        }
        // Upward pass from every source node. The walk uses its own cursor so that the scan
        // over i continues after each walk.
        for (int i = 0; i < nodeCount; ++i) {
            if (distances[i] == Integer.MAX_VALUE) {
                continue;
            }
            int cur = i;
            while (!gTree.getNode(cur).isRoot()) {
                int anc = gTree.getNode(cur).getParent().getNr();
                distances[anc] = Math.min(distances[anc], distances[cur] + 1);
                cur = anc;
            }
        }
        fillInAnySubtreeDistances(gTree.getRoot(), distances);
    }

    /**
     * Narrows {@code interval} so that, after scaling, every gene node returned by
     * {@code getStraddlers} for this SMC node's child unions stays above the SMC node.
     */
    private void getGTreeFocusedScalerCompatibilityBounds(double[] interval, int j, OpFSinfoSMCNode fsNodeInfo, double[] logHeights, double[] weights) {
        double sLogHeight = fsNodeInfo.getLogHeight();
        double sWeight = fsNodeInfo.getWeight();
        List<Node> straddlers = unionArrays.getStraddlers(j, fsNodeInfo.getLftUnion(), fsNodeInfo.getRgtUnion());
        for (Node s : straddlers) {
            double weightDiff = sWeight - weights[s.getNr()];
            double loghDiff = logHeights[s.getNr()] - sLogHeight;
            // Need loghDiff - logSF * weightDiff > 0.
            if (weightDiff > 0.0) {
                interval[1] = Math.min(interval[1], loghDiff / weightDiff);
            } else if (weightDiff < 0.0) {
                interval[0] = Math.max(interval[0], loghDiff / weightDiff);
            }
        }
    }

    /**
     * True when gene node i is internal, its clade lies within {@code spp}, and its two
     * children's clades lie one within {@code sppL} and the other within {@code sppR}.
     */
    private boolean gNodeSNodeCloselyLinked(int j, int i, BitUnion spp, BitUnion sppL, BitUnion sppR) {
        Node iN = gTrees.get(j).getNode(i);
        if (iN.isLeaf() || !unionArrays.gNodeUnion(j, i).isContainedIn(spp)) {
            return false;
        }
        int lft = iN.getChild(0).getNr();
        int rgt = iN.getChild(1).getNr();
        boolean LwithinL = unionArrays.gNodeUnion(j, lft).isContainedIn(sppL);
        boolean RwithinL = unionArrays.gNodeUnion(j, rgt).isContainedIn(sppL);
        boolean LwithinR = unionArrays.gNodeUnion(j, lft).isContainedIn(sppR);
        boolean RwithinR = unionArrays.gNodeUnion(j, rgt).isContainedIn(sppR);
        return (LwithinL && RwithinR) || (LwithinR && RwithinL);
    }

    /**
     * True when gene node i is internal, its clade overlaps both {@code sppL} and
     * {@code sppR}, and neither child's clade does.
     */
    private boolean gNodeSNodeLinked(int j, int i, BitUnion sppL, BitUnion sppR) {
        Node iN = gTrees.get(j).getNode(i);
        if (iN.isLeaf()) {
            return false;
        }
        int lft = iN.getChild(0).getNr();
        int rgt = iN.getChild(1).getNr();
        boolean linked = unionArrays.gNodeUnion(j, i).overlaps(sppL) && unionArrays.gNodeUnion(j, i).overlaps(sppR);
        boolean Llinked = unionArrays.gNodeUnion(j, lft).overlaps(sppL) && unionArrays.gNodeUnion(j, lft).overlaps(sppR);
        boolean Rlinked = unionArrays.gNodeUnion(j, rgt).overlaps(sppL) && unionArrays.gNodeUnion(j, rgt).overlaps(sppR);
        return linked && !Llinked && !Rlinked;
    }

    /** Log of each node height; leaves get negative infinity. */
    private void fillInAnyTreeLogHeights(TreeInterface tree, double[] logHeights) {
        assert (tree.getNodeCount() == logHeights.length);
        for (int i = 0; i < logHeights.length; ++i) {
            Node node = tree.getNode(i);
            logHeights[i] = node.isLeaf() ? Double.NEGATIVE_INFINITY : Math.log(node.getHeight());
        }
    }

    /**
     * Downward pass: lowers each node's distance to its parent's distance + 1 if that is
     * smaller. A parent at {@code Integer.MAX_VALUE} (unreachable) is skipped, because adding
     * 1 would overflow to {@code Integer.MIN_VALUE}.
     */
    private void fillInAnySubtreeDistances(Node node, int[] distances) {
        if (!node.isRoot()) {
            int parentDistance = distances[node.getParent().getNr()];
            if (parentDistance != Integer.MAX_VALUE) {
                distances[node.getNr()] = Math.min(distances[node.getNr()], parentDistance + 1);
            }
        }
        if (!node.isLeaf()) {
            fillInAnySubtreeDistances(node.getChild(0), distances);
            fillInAnySubtreeDistances(node.getChild(1), distances);
        }
    }

    /**
     * Maps a distance to a weight in [0, 1]: 0 when {@code distance >= rootDistance},
     * otherwise {@code (b^x - 1) / (b - 1)} with {@code x = (rootDistance - distance) / rootDistance}
     * and {@code b = WEIGHT_BASE}.
     */
    private double distanceToWeight(int distance, int rootDistance) {
        if (distance >= rootDistance) {
            return 0.0;
        }
        // Subtract in double so extreme int values cannot overflow.
        double x = ((double)rootDistance - (double)distance) / (double)rootDistance;
        return (Math.pow(WEIGHT_BASE, x) - 1.0) / (WEIGHT_BASE - 1.0);
    }

    /**
     * Narrows {@code interval} so that, after scaling, every non-root internal node of
     * {@code tree} stays below its parent.
     */
    public void getAnyTreeInternalFSBounds(double[] interval, TreeInterface tree, double[] logHeights, double[] weights) {
        for (int n = 0; n < tree.getNodeCount(); ++n) {
            Node node = tree.getNode(n);
            if (node.isLeaf() || node.isRoot()) {
                continue;
            }
            int anc = node.getParent().getNr();
            double weightDiff = weights[n] - weights[anc];
            double loghDiff = logHeights[anc] - logHeights[n];
            // Need loghDiff - logSF * weightDiff > 0.
            if (weightDiff > 0.0) {
                interval[1] = Math.min(interval[1], loghDiff / weightDiff);
            } else if (weightDiff < 0.0) {
                interval[0] = Math.max(interval[0], loghDiff / weightDiff);
            }
        }
    }

    /**
     * Multiplies the height of each non-root internal node with positive weight by
     * {@code exp(logSF * weight)}.
     *
     * @return the sum of {@code logSF * weight} over the scaled nodes
     */
    public double fsMoveDoScale(TreeInterface tree, double logSF, double[] weights) {
        double logHR = 0.0;
        for (int s = 0; s < tree.getNodeCount(); ++s) {
            Node sN = tree.getNode(s);
            if (sN.isLeaf() || sN.isRoot() || !(weights[s] > 0.0)) {
                continue;
            }
            double logSFW = logSF * weights[s];
            sN.setHeight(sN.getHeight() * Math.exp(logSFW));
            logHR += logSFW;
        }
        return logHR;
    }

    /** Child clade unions, log-height and weight of one internal SMC-tree node. */
    private static final class OpFSinfoSMCNode {
        private final BitUnion lftUnion;
        private final BitUnion rgtUnion;
        private final double logHeight;
        private final double weight;

        OpFSinfoSMCNode(BitUnion lftUnion, BitUnion rgtUnion, double logHeight, double weight) {
            this.lftUnion = lftUnion;
            this.rgtUnion = rgtUnion;
            this.logHeight = logHeight;
            this.weight = weight;
        }

        BitUnion getLftUnion() {
            return lftUnion;
        }

        BitUnion getRgtUnion() {
            return rgtUnion;
        }

        double getLogHeight() {
            return logHeight;
        }

        double getWeight() {
            return weight;
        }
    }

    /** Clade unions of the focal SMC-tree node and of its two children. */
    private static final class OpFSinfoSMCNodeUnions {
        private final BitUnion nodeUnion;
        private final BitUnion lftUnion;
        private final BitUnion rgtUnion;

        OpFSinfoSMCNodeUnions(BitUnion nodeUnion, BitUnion lftUnion, BitUnion rgtUnion) {
            this.nodeUnion = nodeUnion;
            this.lftUnion = lftUnion;
            this.rgtUnion = rgtUnion;
        }

        BitUnion getNodeUnion() {
            return nodeUnion;
        }

        BitUnion getLftUnion() {
            return lftUnion;
        }

        BitUnion getRgtUnion() {
            return rgtUnion;
        }
    }
}

