/*
 * Decompiled with CFR 0.152.
 */
package dr.inference.operators;

import dr.inference.distribution.LinearRegression;
import dr.inference.distribution.MultivariateDistributionLikelihood;
import dr.inference.model.Parameter;
import dr.inference.operators.GibbsOperator;
import dr.inference.operators.MCMCOperator;
import dr.inference.operators.SimpleMCMCOperator;
import dr.math.distributions.MultivariateDistribution;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.math.matrixAlgebra.SymmetricMatrix;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AttributeRule;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;

public class RegressionGibbsEffectOperator
extends SimpleMCMCOperator
implements GibbsOperator {
    public static final String GIBBS_OPERATOR = "regressionGibbsEffectOperator";
    private LinearRegression linearModel;
    private Parameter effect;
    private Parameter indicators;
    private boolean hasNoIndicators = true;
    private MultivariateDistribution effectPrior;
    private int dim;
    private int effectNumber;
    private int N;
    private int numEffects;
    private double[][] X;
    private double[] mean = null;
    private double[][] variance = null;
    private double[][] precision = null;
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser(){
        private XMLSyntaxRule[] rules = new XMLSyntaxRule[]{AttributeRule.newDoubleRule("weight"), new ElementRule(Parameter.class), new ElementRule(MultivariateDistributionLikelihood.class), new ElementRule(LinearRegression.class), new ElementRule("indicator", new XMLSyntaxRule[]{new ElementRule(Parameter.class)}, true)};

        @Override
        public String getParserName() {
            return RegressionGibbsEffectOperator.GIBBS_OPERATOR;
        }

        @Override
        public Object parseXMLObject(XMLObject xMLObject) throws XMLParseException {
            double d = xMLObject.getDoubleAttribute("weight");
            LinearRegression linearRegression = (LinearRegression)xMLObject.getChild(LinearRegression.class);
            Parameter parameter = (Parameter)xMLObject.getChild(Parameter.class);
            MultivariateDistributionLikelihood multivariateDistributionLikelihood = (MultivariateDistributionLikelihood)xMLObject.getChild(MultivariateDistributionLikelihood.class);
            if (multivariateDistributionLikelihood.getDistribution().getType().compareTo("MultivariateNormal") != 0) {
                throw new XMLParseException("Only a multivariate normal prior is conjugate");
            }
            XMLObject xMLObject2 = xMLObject.getChild("indicator");
            Parameter parameter2 = null;
            if (xMLObject2 != null) {
                parameter2 = (Parameter)xMLObject2.getChild(Parameter.class);
            }
            RegressionGibbsEffectOperator regressionGibbsEffectOperator = new RegressionGibbsEffectOperator(linearRegression, parameter, parameter2, multivariateDistributionLikelihood);
            regressionGibbsEffectOperator.setWeight(d);
            return regressionGibbsEffectOperator;
        }

        @Override
        public String getParserDescription() {
            return "This element returns a multivariate Gibbs operator on an internal node trait.";
        }

        @Override
        public Class getReturnType() {
            return MCMCOperator.class;
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return this.rules;
        }
    };

    public RegressionGibbsEffectOperator(LinearRegression linearRegression, Parameter parameter, Parameter parameter2, MultivariateDistributionLikelihood multivariateDistributionLikelihood) {
        this.linearModel = linearRegression;
        this.effect = parameter;
        this.indicators = parameter2;
        if (parameter2 != null) {
            this.hasNoIndicators = false;
            if (parameter2.getDimension() != parameter.getDimension()) {
                throw new RuntimeException("Indicator and effect dimensions must match");
            }
        }
        this.effectNumber = linearRegression.getEffectNumber(parameter);
        this.effectPrior = multivariateDistributionLikelihood.getDistribution();
        this.dim = parameter.getDimension();
        this.N = linearRegression.getDependentVariable().getDimension();
        this.numEffects = linearRegression.getNumberOfFixedEffects();
        this.X = linearRegression.getX(this.effectNumber);
    }

    public int getStepCount() {
        return 1;
    }

    public void computeForwardDensity(double[] dArray, double[][] dArray2, double[][] dArray3) {
        int n;
        int n2;
        int n3;
        int n4;
        int n5;
        Object object;
        double[] dArray4 = this.linearModel.getTransformedDependentParameter();
        double[] dArray5 = this.linearModel.getScale();
        for (int i = 0; i < this.numEffects; ++i) {
            if (i == this.effectNumber) continue;
            object = this.linearModel.getXBeta(i);
            for (int j = 0; j < this.N; ++j) {
                int n6 = j;
                dArray4[n6] = dArray4[n6] - object[j];
            }
        }
        double[] dArray6 = this.effectPrior.getMean();
        object = this.effectPrior.getScaleMatrix();
        double[][] dArray7 = new double[this.dim][this.N];
        for (int i = 0; i < this.dim; ++i) {
            if (!this.hasNoIndicators && this.indicators.getParameterValue(i) != 1.0) continue;
            for (n5 = 0; n5 < this.N; ++n5) {
                dArray7[i][n5] = this.X[n5][i] * dArray5[n5];
            }
        }
        double[][] dArray8 = new double[this.dim][this.dim];
        for (n5 = 0; n5 < this.dim; ++n5) {
            if (!this.hasNoIndicators && this.indicators.getParameterValue(n5) != 1.0) continue;
            for (n4 = n5; n4 < this.dim; ++n4) {
                if (!this.hasNoIndicators && this.indicators.getParameterValue(n4) != 1.0) continue;
                for (n3 = 0; n3 < this.N; ++n3) {
                    double[] dArray9 = dArray8[n5];
                    int n7 = n4;
                    dArray9[n7] = dArray9[n7] + dArray7[n5][n3] * this.X[n3][n4];
                }
                dArray8[n4][n5] = dArray8[n5][n4];
            }
        }
        double[][] dArray10 = new double[this.dim][this.dim];
        for (n4 = 0; n4 < this.dim; ++n4) {
            for (n3 = n4; n3 < this.dim; ++n3) {
                double d = dArray8[n4][n3] + object[n4][n3];
                dArray10[n4][n3] = d;
                dArray10[n3][n4] = d;
            }
        }
        double[] dArray11 = new double[this.dim];
        for (n3 = 0; n3 < this.dim; ++n3) {
            for (n2 = 0; n2 < this.N; ++n2) {
                int n8 = n3;
                dArray11[n8] = dArray11[n8] + dArray7[n3][n2] * dArray4[n2];
            }
        }
        double[] dArray12 = new double[this.dim];
        for (n2 = 0; n2 < this.dim; ++n2) {
            for (n = 0; n < this.dim; ++n) {
                int n9 = n2;
                dArray12[n9] = dArray12[n9] + object[n2][n] * dArray6[n];
            }
        }
        double[] dArray13 = new double[this.dim];
        for (n = 0; n < this.dim; ++n) {
            dArray13[n] = dArray12[n] + dArray11[n];
        }
        double[][] dArray14 = new SymmetricMatrix(dArray10).inverse().toComponents();
        for (int i = 0; i < this.dim; ++i) {
            dArray[i] = 0.0;
            for (int j = 0; j < this.dim; ++j) {
                int n10 = i;
                dArray[n10] = dArray[n10] + dArray14[i][j] * dArray13[j];
                dArray2[i][j] = dArray14[i][j];
                dArray3[i][j] = dArray10[i][j];
            }
        }
    }

    public double[] getLastMean() {
        return this.mean;
    }

    public double[][] getLastVariance() {
        return this.variance;
    }

    public double[][] getLastPrecision() {
        return this.precision;
    }

    @Override
    public double doOperation() {
        if (this.mean == null) {
            this.mean = new double[this.dim];
        }
        if (this.variance == null) {
            this.variance = new double[this.dim][this.dim];
        }
        if (this.precision == null) {
            this.precision = new double[this.dim][this.dim];
        }
        this.computeForwardDensity(this.mean, this.variance, this.precision);
        double[] dArray = MultivariateNormalDistribution.nextMultivariateNormalVariance(this.mean, this.variance);
        for (int i = 0; i < this.dim; ++i) {
            this.effect.setParameterValue(i, dArray[i]);
        }
        return 0.0;
    }

    public String getPerformanceSuggestion() {
        return null;
    }

    @Override
    public String getOperatorName() {
        return GIBBS_OPERATOR;
    }
}

