/*
 * Decompiled with CFR 0.152.
 */
package stacey.util;

import beast.core.Description;
import beast.evolution.tree.Node;
import beast.evolution.tree.Tree;
import beast.evolution.tree.TreeInterface;
import java.util.ArrayList;
import java.util.List;
import stacey.util.Bindings;
import stacey.util.BitUnion;

/**
 * Caches, for every node of the SMC-tree and of every gene tree, the set of species
 * (or minimal clusters) found below that node, stored as a {@link BitUnion}.
 *
 * <p>Tip unions are filled in once at construction from {@link Bindings}. Internal-node
 * unions are only (re)computed by {@link #update()} and {@link #updateSMCTreeAndGTree(int)}.
 * The cached sets therefore reflect the tree topologies as of the last such call.
 * A single shared instance is obtained through {@link #initialise}.
 */
@Description(value="Important utility class for STACEY, used by several operators. It adds a set of species (or minimal clusters) to every node in the SMC-tree and in all gene trees.")
public class UnionArrays {
    /** Unions for the SMC-tree, indexed by {@code Node.getNr()}. */
    private final BitUnion[] sUnions;
    /** Unions for gene trees: {@code gUnions[j][i]} belongs to node {@code i} of gene tree {@code j}. */
    private final BitUnion[][] gUnions;
    private final TreeInterface sTree;
    private final List<Tree> gTrees;
    /** Lazily created shared instance; see {@link #initialise}. */
    private static volatile UnionArrays unionArrays = null;

    /**
     * Returns the shared instance, creating it on the first call.
     *
     * <p>NOTE(review): the arguments are only used by the call that creates the instance;
     * later calls return the existing instance even if they pass different trees or bindings.
     *
     * @param sTree    the SMC-tree
     * @param gTrees   the gene trees; gene tree {@code j} is {@code gTrees.get(j)}
     * @param bindings supplies the tip unions of the SMC-tree and of the gene trees
     * @return the shared instance
     */
    public static UnionArrays initialise(TreeInterface sTree, List<Tree> gTrees, Bindings bindings) {
        // Double-checked locking on the volatile field. A plain check-then-act could
        // construct, and hand out, more than one instance when called from several threads.
        UnionArrays instance = unionArrays;
        if (instance == null) {
            synchronized (UnionArrays.class) {
                instance = unionArrays;
                if (instance == null) {
                    instance = new UnionArrays(sTree, gTrees, bindings);
                    unionArrays = instance;
                }
            }
        }
        return instance;
    }

    /** Recomputes the internal-node unions of the SMC-tree and of every gene tree. */
    public void update() {
        beastSubtreeToUnions(sUnions, sTree.getRoot());
        for (int j = 0; j < gUnions.length; ++j) {
            beastSubtreeToUnions(gUnions[j], gTrees.get(j).getRoot());
        }
    }

    /**
     * Recomputes the internal-node unions of the SMC-tree and of gene tree {@code j} only.
     *
     * @param j index of the gene tree to refresh
     */
    public void updateSMCTreeAndGTree(int j) {
        beastSubtreeToUnions(sUnions, sTree.getRoot());
        beastSubtreeToUnions(gUnions[j], gTrees.get(j).getRoot());
    }

    /** Currently a no-op. */
    public void reset() {
    }

    /**
     * Returns the cached union for SMC-tree node {@code n}.
     *
     * <p>The returned object is the cached instance itself, not a copy, so callers should not
     * modify it.
     *
     * @param n node number in the SMC-tree
     * @return the union of that node
     */
    public BitUnion sNodeUnion(int n) {
        return sUnions[n];
    }

    /**
     * Returns the cached union for node {@code i} of gene tree {@code j}.
     *
     * <p>The returned object is the cached instance itself, not a copy, so callers should not
     * modify it.
     *
     * @param j gene tree index
     * @param i node number within that gene tree
     * @return the union of that node
     */
    public BitUnion gNodeUnion(int j, int i) {
        return gUnions[j][i];
    }

    /**
     * Collects the nodes of gene tree {@code j} where a lineage overlapping {@code x} meets a
     * lineage overlapping {@code y}: one child's union overlaps {@code x} and the other's
     * overlaps {@code y}.
     *
     * @param j gene tree index
     * @param x first union; must not overlap {@code y}
     * @param y second union
     * @return the straddling nodes, in pre-order of the search
     */
    public ArrayList<Node> getStraddlers(int j, BitUnion x, BitUnion y) {
        assert (!x.overlaps(y));
        ArrayList<Node> s = new ArrayList<Node>();
        subtreeStraddlers(s, j, x, y, gTrees.get(j).getRoot());
        return s;
    }

    /**
     * Returns the number of the SMC-tree node reached by descending from the root into
     * whichever child's union contains {@code x}, stopping when neither child's union does.
     *
     * @param x union to locate; assumed to be contained in the root's union — TODO confirm at call sites
     * @return an SMC-tree node number
     */
    public int smcTreeNodeNrOfUnion(BitUnion x) {
        return nodeIndexOfUnionInSubSTree(sTree.getRoot(), x);
    }

    /**
     * Finds the SMC-tree node whose branch hosts gene-tree node {@code gnode}.
     *
     * <p>It starts from the SMC-tree node that contains the gene node's union. It then climbs
     * while the parent is strictly lower than the gene node. If the gene node lies above the
     * SMC-tree root, the root is returned.
     *
     * @param j     gene tree index
     * @param gnode a node of gene tree {@code j}
     * @return the host SMC-tree node number
     */
    public int hostNodeNrOfGNode(int j, Node gnode) {
        double height = gnode.getHeight();
        int n = smcTreeNodeNrOfUnion(gNodeUnion(j, gnode.getNr()));
        assert (sTree.getNode(n).getHeight() <= height);
        Node hostS = sTree.getNode(n);
        while (!hostS.isRoot() && hostS.getParent().getHeight() < height) {
            hostS = hostS.getParent();
        }
        return hostS.getNr();
    }

    /** Recursive worker for {@link #smcTreeNodeNrOfUnion}; assumes a binary tree. */
    private int nodeIndexOfUnionInSubSTree(Node node, BitUnion x) {
        if (node.isLeaf()) {
            return node.getNr();
        }
        Node lftNode = node.getChild(0);
        Node rgtNode = node.getChild(1);
        if (x.isContainedIn(sUnions[lftNode.getNr()])) {
            return nodeIndexOfUnionInSubSTree(lftNode, x);
        }
        if (x.isContainedIn(sUnions[rgtNode.getNr()])) {
            return nodeIndexOfUnionInSubSTree(rgtNode, x);
        }
        return node.getNr();
    }

    /**
     * Allocates one {@link BitUnion} per node of every tree and fills in the tip unions.
     *
     * <p>Every union is constructed with {@code bindings.smcTreeTipCount()}. Only nodes whose
     * number is below the tip count (SMC-tree tips, or gene-tree leaves) receive a union from
     * {@code bindings}. This assumes tips are numbered before internal nodes. Internal-node
     * unions stay as freshly constructed (presumably empty) until {@link #update()} runs.
     */
    private UnionArrays(TreeInterface sTree, List<Tree> gTrees, Bindings bindings) {
        this.sTree = sTree;
        this.gTrees = gTrees;
        this.sUnions = new BitUnion[sTree.getNodeCount()];
        int nSMCs = bindings.smcTreeTipCount();
        for (int n = 0; n < sUnions.length; ++n) {
            sUnions[n] = new BitUnion(nSMCs);
            if (n < nSMCs) {
                sUnions[n].replaceWith(bindings.tipUnionOfSMCNodeNr(n));
            }
        }
        this.gUnions = new BitUnion[gTrees.size()][];
        for (int j = 0; j < gUnions.length; ++j) {
            gUnions[j] = new BitUnion[gTrees.get(j).getNodeCount()];
            int nGTips = gTrees.get(j).getLeafNodeCount();
            for (int i = 0; i < gUnions[j].length; ++i) {
                gUnions[j][i] = new BitUnion(nSMCs);
                if (i < nGTips) {
                    gUnions[j][i].replaceWith(bindings.tipUnionOfGNodeNr(j, i));
                }
            }
        }
    }

    /**
     * Recomputes, bottom-up, the union of every internal node below and including
     * {@code node}. Each node's union is set from its left child's union and then combined
     * with its right child's union. Assumes a binary tree.
     */
    private void beastSubtreeToUnions(BitUnion[] unions, Node node) {
        // Tip unions are fixed at construction. A leaf, including the root of a single-node
        // tree, has no children to combine; calling getChild on it would fail.
        if (node.isLeaf()) {
            return;
        }
        Node lftNode = node.getChild(0);
        Node rgtNode = node.getChild(1);
        beastSubtreeToUnions(unions, lftNode);
        beastSubtreeToUnions(unions, rgtNode);
        BitUnion target = unions[node.getNr()];
        target.replaceWith(unions[lftNode.getNr()]);
        target.union(unions[rgtNode.getNr()]);
    }

    /**
     * Recursive worker for {@link #getStraddlers}. It only descends into a child whose union
     * overlaps both {@code x} and {@code y}, since further straddlers can only lie there.
     */
    private void subtreeStraddlers(ArrayList<Node> s, int j, BitUnion x, BitUnion y, Node node) {
        // A leaf has no children, so it cannot straddle. This also makes a single-tip gene
        // tree, or descent into a leaf, safe.
        if (node.isLeaf()) {
            return;
        }
        Node lftN = node.getChild(0);
        Node rgtN = node.getChild(1);
        assert (lftN != null);
        assert (rgtN != null);
        BitUnion lftUnion = gUnions[j][lftN.getNr()];
        BitUnion rgtUnion = gUnions[j][rgtN.getNr()];
        boolean lftx = lftUnion.overlaps(x);
        boolean lfty = lftUnion.overlaps(y);
        boolean rgtx = rgtUnion.overlaps(x);
        boolean rgty = rgtUnion.overlaps(y);
        if ((lftx && rgty) || (lfty && rgtx)) {
            s.add(node);
        }
        if (lftx && lfty) {
            subtreeStraddlers(s, j, x, y, lftN);
        }
        if (rgtx && rgty) {
            subtreeStraddlers(s, j, x, y, rgtN);
        }
    }
}

