/*
 * Decompiled with CFR 0.152.
 */
package dr.oldevomodel.treelikelihood;

import dr.evolution.alignment.PatternList;
import dr.evolution.datatype.AminoAcids;
import dr.evolution.datatype.Codons;
import dr.evolution.datatype.Nucleotides;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeUtils;
import dr.evolution.util.Taxon;
import dr.evolution.util.TaxonList;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.branchratemodel.DefaultBranchRateModel;
import dr.evomodel.tree.TreeChangedEvent;
import dr.evomodel.tree.TreeModel;
import dr.inference.model.Likelihood;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.oldevomodel.sitemodel.SiteModel;
import dr.oldevomodel.substmodel.FrequencyModel;
import dr.oldevomodel.treelikelihood.AbstractTreeLikelihood;
import dr.oldevomodel.treelikelihood.AminoAcidLikelihoodCore;
import dr.oldevomodel.treelikelihood.GeneralLikelihoodCore;
import dr.oldevomodel.treelikelihood.LikelihoodCore;
import dr.oldevomodel.treelikelihood.NativeNucleotideLikelihoodCore;
import dr.oldevomodel.treelikelihood.NucleotideLikelihoodCore;
import dr.util.Identifiable;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AttributeRule;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Set;
import java.util.logging.Logger;

/**
 * Tree likelihood in which different parts of the tree can evolve under different
 * {@link SiteModel}s.
 * <p>
 * On top of the base site model, this likelihood supports:
 * <ul>
 *   <li>an optional site model used only for the branches above external (tip) nodes; on tip
 *       branches it takes precedence over any clade site model,</li>
 *   <li>any number of clade site models, each attached to the most recent common ancestor
 *       (MRCA) of a taxon set and applied to the branches below that node (and, optionally, to
 *       the stem branch above it),</li>
 *   <li>an optional "delta" parameter whose value is added to the rate-scaled branch length of
 *       every tip, or of a chosen subset of tips.</li>
 * </ul>
 * Only site models that integrate across rate categories are supported.
 */
@Deprecated
public class AdvancedTreeLikelihood
extends AbstractTreeLikelihood {
    public static final String ADVANCED_TREE_LIKELIHOOD = "advancedTreeLikelihood";
    public static final String CLADE = "clade";
    public static final String INCLUDE_STEM = "includeStem";
    public static final String TIPS = "tips";
    public static final String DELTA = "delta";
    public static final String USE_AMBIGUITIES = "useAmbiguities";
    public static final String STORE_PARTIALS = "storePartials";
    public static final String USE_SCALING = "useScaling";

    /** Logger used for all informational messages of this class. */
    private static final Logger LOGGER = Logger.getLogger("dr.evomodel");

    /** XML parser for the {@code advancedTreeLikelihood} element. */
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
        private final XMLSyntaxRule[] rules = new XMLSyntaxRule[]{
            AttributeRule.newBooleanRule(AdvancedTreeLikelihood.USE_AMBIGUITIES, true),
            AttributeRule.newBooleanRule(AdvancedTreeLikelihood.USE_SCALING, true),
            new ElementRule(AdvancedTreeLikelihood.TIPS, SiteModel.class,
                    "A siteModel that will be applied only to the tips.", 0, 1),
            new ElementRule(AdvancedTreeLikelihood.DELTA, new XMLSyntaxRule[]{
                new ElementRule(TaxonList.class,
                        "A set of taxa to which to apply the delta model to", 0, 1),
                new ElementRule(Parameter.class,
                        "A parameter that specifies the amount of extra substitutions per site at each tip.", 0, 1)
            }, true),
            new ElementRule(AdvancedTreeLikelihood.CLADE, new XMLSyntaxRule[]{
                AttributeRule.newBooleanRule(AdvancedTreeLikelihood.INCLUDE_STEM, true,
                        "determines whether or not the stem branch above this clade is included in the siteModel."),
                new ElementRule(TaxonList.class,
                        "A set of taxa which defines a clade to apply a different site model to"),
                new ElementRule(SiteModel.class,
                        "A siteModel that will be applied only to this clade")
            }, 0, Integer.MAX_VALUE),
            new ElementRule(PatternList.class),
            new ElementRule(TreeModel.class),
            new ElementRule(SiteModel.class),
            new ElementRule(BranchRateModel.class, true)
        };

        @Override
        public String getParserName() {
            return AdvancedTreeLikelihood.ADVANCED_TREE_LIKELIHOOD;
        }

        /**
         * Builds the likelihood from the base elements, then attaches the optional tips site
         * model, the optional delta parameter and every {@code clade} child element.
         */
        @Override
        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
            boolean useAmbiguities = xo.getAttribute(AdvancedTreeLikelihood.USE_AMBIGUITIES, false);
            boolean useScaling = xo.getAttribute(AdvancedTreeLikelihood.USE_SCALING, false);
            PatternList patternList = (PatternList) xo.getChild(PatternList.class);
            TreeModel treeModel = (TreeModel) xo.getChild(TreeModel.class);
            SiteModel siteModel = (SiteModel) xo.getChild(SiteModel.class);
            BranchRateModel branchRateModel = (BranchRateModel) xo.getChild(BranchRateModel.class);

            AdvancedTreeLikelihood likelihood = new AdvancedTreeLikelihood(
                    patternList, treeModel, siteModel, branchRateModel, useAmbiguities, useScaling);

            if (xo.hasChildNamed(AdvancedTreeLikelihood.TIPS)) {
                SiteModel tipsSiteModel = (SiteModel) xo.getElementFirstChild(AdvancedTreeLikelihood.TIPS);
                likelihood.addTipsSiteModel(tipsSiteModel);
            }

            XMLObject deltaXO = (XMLObject) xo.getChild(AdvancedTreeLikelihood.DELTA);
            if (deltaXO != null) {
                Parameter deltaParameter = (Parameter) deltaXO.getChild(Parameter.class);
                // May be null, in which case the delta is applied to every tip.
                TaxonList deltaTaxa = (TaxonList) deltaXO.getChild(TaxonList.class);
                likelihood.addDeltaParameter(deltaParameter, deltaTaxa);
            }

            for (int i = 0; i < xo.getChildCount(); ++i) {
                Object child = xo.getChild(i);
                if (!(child instanceof XMLObject)) {
                    continue;
                }
                XMLObject cladeXO = (XMLObject) child;
                if (!cladeXO.getName().equals(AdvancedTreeLikelihood.CLADE)) {
                    continue;
                }
                SiteModel cladeSiteModel = (SiteModel) cladeXO.getChild(SiteModel.class);
                TaxonList taxonList = (TaxonList) cladeXO.getChild(TaxonList.class);
                boolean includeStem = parseIncludeStem(cladeXO, taxonList);
                try {
                    likelihood.addCladeSiteModel(cladeSiteModel, taxonList, includeStem);
                } catch (TreeUtils.MissingTaxonException missingTaxonException) {
                    throw new XMLParseException("Taxon, " + missingTaxonException + ", in " + this.getParserName() + " was not found in the tree.");
                }
            }
            return likelihood;
        }

        /**
         * Reads the optional {@code includeStem} attribute of a clade element.
         * A single-taxon clade must include its stem (its only branch), so that case defaults
         * to {@code true} and rejects an explicit {@code false}.
         */
        private boolean parseIncludeStem(XMLObject cladeXO, TaxonList taxonList) throws XMLParseException {
            boolean singleTaxon = taxonList.getTaxonCount() == 1;
            if (!cladeXO.hasAttribute(AdvancedTreeLikelihood.INCLUDE_STEM)) {
                return singleTaxon;
            }
            boolean includeStem = cladeXO.getBooleanAttribute(AdvancedTreeLikelihood.INCLUDE_STEM);
            if (singleTaxon && !includeStem) {
                throw new XMLParseException("The site model is only applied to 1 taxon and therefore must include the stem branch");
            }
            return includeStem;
        }

        @Override
        public String getParserDescription() {
            return "This element represents the likelihood of a patternlist on a tree given the site model.";
        }

        @Override
        public Class getReturnType() {
            return Likelihood.class;
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return this.rules;
        }
    };

    /** Frequency model of the base site model; used at the root. */
    protected FrequencyModel frequencyModel = null;
    /** Base site model used wherever no tips/clade model applies. */
    protected SiteModel siteModel = null;
    protected BranchRateModel branchRateModel = null;
    /** Optional site model for tip branches; {@code null} if not set. */
    protected SiteModel tipsSiteModel = null;
    /** Optional extra substitutions per site added to tip branches; {@code null} if not set. */
    protected Parameter deltaParameter = null;
    /** Node numbers of the tips that receive the delta when it is restricted to a taxon list. */
    protected Set<Integer> deltaTips = null;
    /** {@code true} when the delta applies to every tip (no taxon list was given). */
    private boolean deltaAppliesToAllTips = false;
    protected ArrayList<Clade> cladeSiteModels = new ArrayList<Clade>();
    /** Whether every clade's cached MRCA node number is valid for the current tree. */
    private boolean commonAncestorsKnown = true;
    protected double[] rootPartials = null;
    protected double[] patternLogLikelihoods = null;
    protected int categoryCount;
    /** Scratch buffer for one transition-probability matrix (stateCount x stateCount). */
    protected double[] probabilities;
    protected LikelihoodCore likelihoodCore;

    /**
     * Creates the likelihood and initialises the likelihood core with the tip data.
     *
     * @param patternList     the site patterns; their data type selects the likelihood core
     * @param treeModel       the tree; every tip taxon must be present in {@code patternList}
     * @param siteModel       base site model; must integrate across rate categories
     * @param branchRateModel branch rate model, or {@code null} to use a {@link DefaultBranchRateModel}
     * @param useAmbiguities  if {@code true}, tips are loaded as partials rather than states
     *                        (always forced on for codon data)
     * @param useScaling      whether partial likelihood scaling was requested
     * @throws RuntimeException if the site model does not integrate across categories, or a tip
     *                          taxon is missing from the pattern list
     */
    public AdvancedTreeLikelihood(PatternList patternList, TreeModel treeModel, SiteModel siteModel, BranchRateModel branchRateModel, boolean useAmbiguities, boolean useScaling) {
        super(ADVANCED_TREE_LIKELIHOOD, patternList, treeModel);
        try {
            this.siteModel = siteModel;
            this.addModel(siteModel);
            this.frequencyModel = siteModel.getFrequencyModel();
            this.addModel(this.frequencyModel);
            if (!siteModel.integrateAcrossCategories()) {
                throw new RuntimeException("AdvancedTreeLikelihood can only use SiteModels that require integration across categories");
            }
            this.categoryCount = siteModel.getCategoryCount();

            if (patternList.getDataType() instanceof Nucleotides) {
                if (NativeNucleotideLikelihoodCore.isAvailable()) {
                    LOGGER.info("AdvancedTreeLikelihood using native nucleotide likelihood core.");
                    this.likelihoodCore = new NativeNucleotideLikelihoodCore();
                } else {
                    LOGGER.info("AdvancedTreeLikelihood Java nucleotide likelihood core.");
                    this.likelihoodCore = new NucleotideLikelihoodCore();
                }
            } else if (patternList.getDataType() instanceof AminoAcids) {
                LOGGER.info("AdvancedTreeLikelihood Java amino acid likelihood core.");
                this.likelihoodCore = new AminoAcidLikelihoodCore();
            } else if (patternList.getDataType() instanceof Codons) {
                LOGGER.info("TreeLikelihood using Java general likelihood core");
                this.likelihoodCore = new GeneralLikelihoodCore(patternList.getStateCount());
                useAmbiguities = true;
            } else {
                LOGGER.info("AdvancedTreeLikelihood using Java general likelihood core");
                this.likelihoodCore = new GeneralLikelihoodCore(patternList.getStateCount());
            }
            LOGGER.info("  " + (useAmbiguities ? "Using" : "Ignoring") + " ambiguities in tree likelihood.");
            // NOTE(review): useScaling is only reported here; it is never passed to the likelihood core.
            LOGGER.info("  Partial likelihood scaling " + (useScaling ? "on." : "off."));

            if (branchRateModel != null) {
                this.branchRateModel = branchRateModel;
                LOGGER.info("Branch rate model used: " + branchRateModel.getModelName());
            } else {
                this.branchRateModel = new DefaultBranchRateModel();
            }
            this.addModel(this.branchRateModel);

            this.probabilities = new double[this.stateCount * this.stateCount];
            // Categories are always integrated; that requirement is checked above.
            this.likelihoodCore.initialize(this.nodeCount, this.patternCount, this.categoryCount, true);

            int externalNodeCount = treeModel.getExternalNodeCount();
            int internalNodeCount = treeModel.getInternalNodeCount();
            for (int tip = 0; tip < externalNodeCount; ++tip) {
                String taxonId = treeModel.getTaxonId(tip);
                int patternIndex = patternList.getTaxonIndex(taxonId);
                if (patternIndex == -1) {
                    throw new TaxonList.MissingTaxonException("Taxon, " + taxonId + ", in tree, " + treeModel.getId() + ", is not found in patternList, " + patternList.getId());
                }
                if (useAmbiguities) {
                    this.setPartials(this.likelihoodCore, patternList, this.categoryCount, patternIndex, tip);
                } else {
                    this.setStates(this.likelihoodCore, patternList, patternIndex, tip);
                }
            }
            // Internal nodes are numbered after the tips.
            for (int i = 0; i < internalNodeCount; ++i) {
                this.likelihoodCore.createNodePartials(externalNodeCount + i);
            }
        } catch (TaxonList.MissingTaxonException missingTaxonException) {
            // Keep the original message but preserve the cause for debugging.
            throw new RuntimeException(missingTaxonException.toString(), missingTaxonException);
        }
    }

    /**
     * Applies {@code siteModel} to the clade defined by the MRCA of {@code taxonList}.
     *
     * @param siteModel   site model for the branches below the clade's MRCA
     * @param taxonList   taxa defining the clade
     * @param includeStem whether the branch above the MRCA also uses {@code siteModel}
     *                    (forced on for single-taxon clades)
     * @throws TreeUtils.MissingTaxonException if a taxon is not in the tree
     */
    public void addCladeSiteModel(SiteModel siteModel, TaxonList taxonList, boolean includeStem) throws TreeUtils.MissingTaxonException {
        LOGGER.info("SiteModel added for clade.");
        // The new clade locates its own MRCA in its constructor. The global flag is left
        // untouched so that MRCAs of existing clades are still refreshed if they are stale.
        this.cladeSiteModels.add(new Clade(siteModel, taxonList, includeStem));
        this.addModel(siteModel);
    }

    /** Uses {@code siteModel} for every branch that ends in an external (tip) node. */
    public void addTipsSiteModel(SiteModel siteModel) {
        LOGGER.info("SiteModel added for tips.");
        this.tipsSiteModel = siteModel;
        this.addModel(siteModel);
    }

    /**
     * Adds a parameter whose first value is added to the rate-scaled branch length of tips.
     *
     * @param parameter the delta parameter
     * @param taxonList tips to which the delta applies, or {@code null} for all tips
     */
    private void addDeltaParameter(Parameter parameter, TaxonList taxonList) {
        this.deltaParameter = parameter;
        this.deltaTips = new HashSet<Integer>();
        this.deltaAppliesToAllTips = taxonList == null;
        if (taxonList != null) {
            boolean first = true;
            StringBuilder message = new StringBuilder("Delta parameter added for tips: {");
            for (int i = 0; i < this.treeModel.getExternalNodeCount(); ++i) {
                NodeRef tip = this.treeModel.getExternalNode(i);
                Taxon taxon = this.treeModel.getNodeTaxon(tip);
                if (taxonList.getTaxonIndex(taxon) == -1) {
                    continue;
                }
                if (!first) {
                    message.append(", ");
                }
                first = false;
                message.append(taxon.getId());
                this.deltaTips.add(tip.getNumber());
            }
            message.append("}");
            LOGGER.info(message.toString());
            if (this.deltaTips.isEmpty()) {
                // Previously an empty set silently meant "all tips"; now it means "no tips".
                LOGGER.warning("Delta parameter taxon list matches no tips in the tree; the delta will not be applied.");
            }
        } else {
            LOGGER.info("Delta parameter added for all tips.");
        }
        this.addVariable(parameter);
    }

    /** Returns whether the delta parameter is added to the branch above tip {@code tipNumber}. */
    private boolean deltaAppliesTo(int tipNumber) {
        return this.deltaParameter != null
                && (this.deltaAppliesToAllTips || this.deltaTips.contains(tipNumber));
    }

    @Override
    protected final void handleVariableChangedEvent(Variable variable, int n, Variable.ChangeType changeType) {
        this.updateAllNodes();
        super.handleVariableChangedEvent(variable, n, changeType);
    }

    /**
     * Marks the nodes affected by a change in one of the sub-models for recomputation.
     *
     * @throws RuntimeException if {@code model} is not one of this likelihood's sub-models
     */
    @Override
    protected void handleModelChangedEvent(Model model, Object object, int index) {
        if (model == this.treeModel) {
            if (object instanceof TreeChangedEvent) {
                TreeChangedEvent event = (TreeChangedEvent) object;
                if (event.isNodeChanged()) {
                    // NOTE(review): MRCAs are only invalidated by whole-tree events; confirm that
                    // topology-changing operators never fire node-only events.
                    this.updateNodeAndChildren(event.getNode());
                } else {
                    this.updateAllNodes();
                    this.commonAncestorsKnown = false;
                }
            }
        } else if (model == this.branchRateModel || model == this.frequencyModel) {
            this.updateAllNodes();
        } else if (model instanceof SiteModel) {
            if (model == this.siteModel || model == this.tipsSiteModel) {
                this.updateAllNodes();
            } else {
                this.updateCladesUsingSiteModel((SiteModel) model);
            }
        } else {
            throw new RuntimeException("Unknown componentChangedEvent");
        }
        super.handleModelChangedEvent(model, object, index);
    }

    /**
     * Marks the subtree of every clade that uses {@code model} for update. Several clades may
     * share one site model, so all of them are updated. If no clade uses the model, every node
     * is updated instead.
     */
    private void updateCladesUsingSiteModel(SiteModel model) {
        this.ensureCommonAncestorsKnown();
        boolean found = false;
        for (Clade clade : this.cladeSiteModels) {
            if (clade.getSiteModel() == model) {
                this.updateNodeAndDescendents(this.treeModel.getNode(clade.getNode()));
                found = true;
            }
        }
        if (!found) {
            this.updateAllNodes();
        }
    }

    /** Recomputes every clade's MRCA if the cached node numbers may be stale. */
    private void ensureCommonAncestorsKnown() {
        if (!this.commonAncestorsKnown) {
            for (Clade clade : this.cladeSiteModels) {
                clade.findMRCA();
            }
            this.commonAncestorsKnown = true;
        }
    }

    @Override
    protected void storeState() {
        super.storeState();
    }

    @Override
    protected void restoreState() {
        this.updateAllNodes();
        // The cached clade MRCAs may have been located on the tree of a rejected proposal,
        // so force them to be recomputed before they are next used.
        this.commonAncestorsKnown = false;
        super.restoreState();
    }

    /**
     * Recomputes the partials that need updating and returns the pattern-weighted sum of the
     * per-pattern log likelihoods.
     */
    @Override
    protected double calculateLogLikelihood() {
        NodeRef root = this.treeModel.getRoot();
        if (this.rootPartials == null) {
            this.rootPartials = new double[this.patternCount * this.stateCount];
        }
        if (this.patternLogLikelihoods == null) {
            this.patternLogLikelihoods = new double[this.patternCount];
        }
        this.ensureCommonAncestorsKnown();
        this.traverse(this.treeModel, root, this.siteModel);
        for (int i = 0; i < this.nodeCount; ++i) {
            this.updateNode[i] = false;
        }
        double logL = 0.0;
        for (int i = 0; i < this.patternCount; ++i) {
            logL += this.patternLogLikelihoods[i] * this.patternWeights[i];
        }
        return logL;
    }

    /**
     * Post-order traversal that updates transition matrices and partials below {@code node}.
     *
     * @param tree      the tree being traversed
     * @param node      current node
     * @param siteModel site model inherited from the parent
     * @return whether anything at or below {@code node} was recomputed
     */
    private boolean traverse(Tree tree, NodeRef node, SiteModel siteModel) {
        boolean updated = false;
        int nodeNumber = node.getNumber();

        // Model for the branch above this node, and model handed down to the children.
        SiteModel branchSiteModel = siteModel;
        SiteModel childSiteModel = siteModel;
        if (this.tipsSiteModel != null && tree.isExternal(node)) {
            branchSiteModel = this.tipsSiteModel;
        } else {
            for (Clade clade : this.cladeSiteModels) {
                if (clade.getNode() == nodeNumber) {
                    childSiteModel = clade.getSiteModel();
                    if (clade.includeStem()) {
                        branchSiteModel = childSiteModel;
                    }
                    break;
                }
            }
        }

        NodeRef parent = tree.getParent(node);
        if (parent != null && this.updateNode[nodeNumber]) {
            double branchRate = this.branchRateModel.getBranchRate(tree, node);
            double branchLength = branchRate * (tree.getNodeHeight(parent) - tree.getNodeHeight(node));
            if (branchLength < 0.0) {
                throw new RuntimeException("Negative branch length: " + branchLength);
            }
            this.likelihoodCore.setNodeMatrixForUpdate(nodeNumber);
            if (tree.isExternal(node) && this.deltaAppliesTo(nodeNumber)) {
                branchLength += this.deltaParameter.getParameterValue(0);
            }
            for (int category = 0; category < this.categoryCount; ++category) {
                double distance = branchSiteModel.getRateForCategory(category) * branchLength;
                branchSiteModel.getSubstitutionModel().getTransitionProbabilities(distance, this.probabilities);
                this.likelihoodCore.setNodeMatrix(nodeNumber, category, this.probabilities);
            }
            updated = true;
        }

        if (!tree.isExternal(node)) {
            if (tree.getChildCount(node) != 2) {
                throw new RuntimeException("binary trees only!");
            }
            NodeRef child1 = tree.getChild(node, 0);
            boolean update1 = this.traverse(tree, child1, childSiteModel);
            NodeRef child2 = tree.getChild(node, 1);
            boolean update2 = this.traverse(tree, child2, childSiteModel);
            if (update1 || update2) {
                this.likelihoodCore.setNodePartialsForUpdate(nodeNumber);
                this.likelihoodCore.calculatePartials(child1.getNumber(), child2.getNumber(), nodeNumber);
                if (parent == null) {
                    // Root: integrate over rate categories and compute per-pattern log likelihoods.
                    double[] frequencies = this.frequencyModel.getFrequencies();
                    double[] proportions = branchSiteModel.getCategoryProportions();
                    this.likelihoodCore.integratePartials(nodeNumber, proportions, this.rootPartials);
                    this.likelihoodCore.calculateLogLikelihoods(this.rootPartials, frequencies, this.patternLogLikelihoods);
                }
                updated = true;
            }
        }
        return updated;
    }

    /**
     * A taxon set with its own site model. It is located in the tree by the node number of the
     * taxa's most recent common ancestor.
     */
    private class Clade {
        final SiteModel siteModel;
        /** Leaves resolved via {@link TreeUtils#getLeavesForTaxa} when the clade was created. */
        final Set<String> leafSet;
        /** Cached MRCA node number; refreshed by {@link #findMRCA()}. */
        int node;
        /** Whether the branch above the MRCA also uses {@link #siteModel}. */
        final boolean includeStem;

        Clade(SiteModel siteModel, TaxonList taxonList, boolean includeStem) throws TreeUtils.MissingTaxonException {
            this.siteModel = siteModel;
            this.leafSet = TreeUtils.getLeavesForTaxa(AdvancedTreeLikelihood.this.treeModel, taxonList);
            // A single-taxon clade's only branch is its stem, so it must be included.
            this.includeStem = includeStem || taxonList.getTaxonCount() == 1;
            this.findMRCA();
        }

        /** Recomputes the MRCA node number on the current tree. */
        void findMRCA() {
            this.node = TreeUtils.getCommonAncestorNode(AdvancedTreeLikelihood.this.treeModel, this.leafSet).getNumber();
        }

        int getNode() {
            return this.node;
        }

        boolean includeStem() {
            return this.includeStem;
        }

        SiteModel getSiteModel() {
            return this.siteModel;
        }
    }
}

