package com.github.thorbenlindhauer.factor;

import com.github.thorbenlindhauer.exception.FactorOperationException;
import com.github.thorbenlindhauer.exception.ModelStructureException;
import com.github.thorbenlindhauer.math.MathUtil;
import com.github.thorbenlindhauer.variable.Scope;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.linear.RealVectorChangingVisitor;

/* loaded from: input_file:com/github/thorbenlindhauer/factor/CanonicalGaussianFactor.class */
/**
 * A Gaussian factor over continuous variables, stored in canonical (information) form.
 *
 * <p>The value of the factor at an assignment {@code x} is
 * {@code exp(-1/2 * x^T K x + h^T x + g)} (see {@link #getValueForAssignment(double[])}), where
 * {@code K} is the precision matrix, {@code h} the scaled mean vector and {@code g} the
 * normalization constant, which is kept as an additive log-space term.
 */
public class CanonicalGaussianFactor implements GaussianFactor {

    /** {@code 2 * pi}, used by the Gaussian normalization terms. */
    private static final double TWO_PI = 2.0d * Math.PI;

    protected Scope scope;
    protected RealMatrix precisionMatrix;
    protected RealVector scaledMeanVector;
    protected double normalizationConstant;

    /**
     * Creates a factor from its canonical parameters.
     *
     * @param scope the variables of the factor; must not contain discrete variables
     * @param precisionMatrix the precision matrix {@code K}
     * @param scaledMeanVector the scaled mean vector {@code h}
     * @param normalizationConstant the log-space constant {@code g}
     * @throws ModelStructureException if {@code scope} contains discrete variables
     */
    public CanonicalGaussianFactor(Scope scope, RealMatrix precisionMatrix, RealVector scaledMeanVector, double normalizationConstant) {
        if (!scope.getDiscreteVariables().isEmpty()) {
            throw new ModelStructureException("Cannot define a Gaussian factor with discrete variables " + scope.getDiscreteVariables());
        }
        this.scope = scope;
        this.precisionMatrix = precisionMatrix;
        this.scaledMeanVector = scaledMeanVector;
        this.normalizationConstant = normalizationConstant;
    }

    /**
     * Multiplies this factor with another one. In canonical form this is the sum of the
     * parameters, after both factors are lifted to the union of their scopes.
     *
     * @param gaussianFactor the other factor
     * @return a new factor over the union of both scopes
     */
    @Override // com.github.thorbenlindhauer.factor.Factor
    public GaussianFactor product(GaussianFactor gaussianFactor) {
        Scope otherScope = gaussianFactor.getVariables();
        Scope union = this.scope.union(otherScope);
        int size = union.size();

        // For each union position: index into the respective factor's parameters, or -1 if absent.
        int[] thisMapping = union.createContinuousVariableMapping(this.scope);
        int[] otherMapping = union.createContinuousVariableMapping(otherScope);
        RealMatrix otherPrecision = gaussianFactor.getPrecisionMatrix();

        Array2DRowRealMatrix productPrecision = new Array2DRowRealMatrix(size, size);
        for (int col = 0; col < size; col++) {
            RealVector column = new ArrayRealVector(size);
            if (thisMapping[col] >= 0) {
                column = column.add(padVector(this.precisionMatrix.getColumnVector(thisMapping[col]), size, thisMapping));
            }
            if (otherMapping[col] >= 0) {
                column = column.add(padVector(otherPrecision.getColumnVector(otherMapping[col]), size, otherMapping));
            }
            productPrecision.setColumnVector(col, column);
        }

        RealVector productScaledMean = padVector(this.scaledMeanVector, size, thisMapping)
            .add(padVector(gaussianFactor.getScaledMeanVector(), size, otherMapping));
        return new CanonicalGaussianFactor(union, productPrecision, productScaledMean,
            this.normalizationConstant + gaussianFactor.getNormalizationConstant());
    }

    /**
     * Lifts {@code vector} into a vector of length {@code targetSize}.
     *
     * <p>Entry {@code k} of the result is {@code vector[mapping[k]]} when {@code mapping[k] >= 0}
     * and {@code 0} otherwise.
     *
     * @param vector the source vector
     * @param targetSize length of the resulting vector
     * @param mapping for each target position, the source index or a negative value for "absent"
     * @return the padded vector (a new instance)
     */
    protected static RealVector padVector(final RealVector vector, int targetSize, final int[] mapping) {
        ArrayRealVector padded = new ArrayRealVector(targetSize);
        padded.walkInOptimizedOrder(new RealVectorChangingVisitor() {
            public double visit(int index, double currentValue) {
                int sourceIndex = mapping[index];
                return sourceIndex >= 0 ? vector.getEntry(sourceIndex) : 0.0d;
            }

            public void start(int dimension, int start, int end) {
                // nothing to initialize
            }

            public double end() {
                return 0.0d;
            }
        });
        return padded;
    }

    /**
     * Divides this factor by another one. In canonical form this is the difference of the
     * parameters; the divisor's scope must be a subset of this factor's scope.
     *
     * @param gaussianFactor the divisor
     * @return a new factor over this factor's scope
     * @throws FactorOperationException if the divisor's scope is not contained in this scope
     */
    @Override // com.github.thorbenlindhauer.factor.Factor
    public GaussianFactor division(GaussianFactor gaussianFactor) {
        Scope divisorScope = gaussianFactor.getVariables();
        if (!this.scope.contains(divisorScope.getVariableIds())) {
            throw new FactorOperationException("Divisor scope " + divisorScope + " is not a subset of this factor's scope " + this.scope);
        }
        int size = this.scope.size();
        int[] divisorMapping = this.scope.createContinuousVariableMapping(divisorScope);
        RealMatrix divisorPrecision = gaussianFactor.getPrecisionMatrix();

        RealMatrix quotientPrecision = this.precisionMatrix.copy();
        for (int col = 0; col < size; col++) {
            if (divisorMapping[col] >= 0) {
                RealVector divisorColumn = padVector(divisorPrecision.getColumnVector(divisorMapping[col]), size, divisorMapping);
                quotientPrecision.setColumnVector(col, quotientPrecision.getColumnVector(col).subtract(divisorColumn));
            }
        }

        // subtract(...) already returns a fresh vector, so no defensive copy is needed.
        RealVector quotientScaledMean = this.scaledMeanVector.subtract(padVector(gaussianFactor.getScaledMeanVector(), size, divisorMapping));
        return new CanonicalGaussianFactor(this.scope, quotientPrecision, quotientScaledMean,
            this.normalizationConstant - gaussianFactor.getNormalizationConstant());
    }

    /**
     * Integrates out every variable of this factor that is not in {@code scope}.
     *
     * <p>With {@code X} the retained and {@code Y} the eliminated variables, the result has
     * {@code K' = K_XX - K_XY K_YY^-1 K_YX}, {@code h' = h_X - K_XY K_YY^-1 h_Y} and
     * {@code g' = g + 1/2 (log det(2 pi K_YY^-1) + h_Y^T K_YY^-1 h_Y)}.
     *
     * <p>The method is named {@code marginal2} because of a decompiler rename (originally
     * {@code marginal}); the name is kept for compatibility.
     *
     * @param scope the variables to keep
     * @return this factor if it is already contained in {@code scope}, otherwise the marginal
     */
    @Override // com.github.thorbenlindhauer.factor.Factor
    public GaussianFactor marginal2(Scope scope) {
        if (scope.contains(this.scope)) {
            return this;
        }
        Scope kept = this.scope.intersect(scope);
        Scope eliminated = this.scope.reduceBy(kept);
        int[] keptIdx = kept.createContinuousVariableMapping(this.scope);
        int[] elimIdx = eliminated.createContinuousVariableMapping(this.scope);

        RealMatrix precisionKK = this.precisionMatrix.getSubMatrix(keptIdx, keptIdx);
        RealMatrix precisionEE = this.precisionMatrix.getSubMatrix(elimIdx, elimIdx);
        RealMatrix precisionKE = this.precisionMatrix.getSubMatrix(keptIdx, elimIdx);
        RealMatrix precisionEK = this.precisionMatrix.getSubMatrix(elimIdx, keptIdx);
        RealMatrix precisionEEInverse = new MathUtil(precisionEE).invert();

        RealMatrix marginalPrecision = precisionKK.subtract(precisionKE.multiply(precisionEEInverse.multiply(precisionEK)));

        RealVector scaledMeanK = getSubVector(this.scaledMeanVector, keptIdx);
        RealVector scaledMeanE = getSubVector(this.scaledMeanVector, elimIdx);
        RealVector marginalScaledMean = scaledMeanK.subtract(precisionKE.operate(precisionEEInverse.operate(scaledMeanE)));

        double logDet = Math.log(new MathUtil(precisionEEInverse.scalarMultiply(TWO_PI)).determinant());
        double quadratic = scaledMeanE.dotProduct(precisionEEInverse.operate(scaledMeanE));
        return new CanonicalGaussianFactor(kept, marginalPrecision, marginalScaledMean,
            this.normalizationConstant + (0.5d * (logDet + quadratic)));
    }

    /**
     * Selects the entries of {@code realVector} at the positions listed in {@code iArr}.
     *
     * @param realVector the source vector
     * @param iArr indices into {@code realVector}
     * @return a new vector whose {@code i}-th entry is {@code realVector[iArr[i]]}
     */
    protected RealVector getSubVector(RealVector realVector, int[] iArr) {
        ArrayRealVector result = new ArrayRealVector(iArr.length);
        for (int i = 0; i < iArr.length; i++) {
            result.setEntry(i, realVector.getEntry(iArr[i]));
        }
        return result;
    }

    /**
     * Reduces this factor by fixing observed values of some of its variables.
     *
     * <p>Only observed variables that belong to this factor's scope are used. Observations of
     * other variables are ignored; previously they caused an index-out-of-range failure. With
     * {@code R} the remaining and {@code O} the observed variables (values {@code y}), the result
     * has {@code K' = K_RR}, {@code h' = h_R - K_RO y} and
     * {@code g' = g + h_O^T y - 1/2 y^T K_OO y}.
     *
     * @param scope the observed variables
     * @param dArr observed values; assumed aligned with the order of {@code scope}'s variables
     * @return this factor if no observed variable is in its scope, otherwise the reduced factor
     * @throws ModelStructureException if the number of values does not match the number of
     *     observed variables
     */
    @Override // com.github.thorbenlindhauer.factor.GaussianFactor
    public GaussianFactor observation(Scope scope, double[] dArr) {
        if (scope.getVariables().size() != dArr.length) {
            throw new ModelStructureException("Observed variables and values do not match");
        }
        Scope observedHere = this.scope.intersect(scope);
        if (observedHere.isEmpty()) {
            return this;
        }

        // Pick the observed values of the variables this factor actually contains, in the order
        // of observedHere.
        int[] valuePositions = observedHere.createContinuousVariableMapping(scope);
        double[] relevantValues = new double[valuePositions.length];
        for (int i = 0; i < valuePositions.length; i++) {
            relevantValues[i] = dArr[valuePositions[i]];
        }
        ArrayRealVector observedValues = new ArrayRealVector(relevantValues);

        Scope remaining = this.scope.reduceBy(observedHere);
        int[] remainingIdx = remaining.createContinuousVariableMapping(this.scope);
        int[] observedIdx = observedHere.createContinuousVariableMapping(this.scope);

        RealMatrix precisionRR = this.precisionMatrix.getSubMatrix(remainingIdx, remainingIdx);
        RealMatrix precisionRO = this.precisionMatrix.getSubMatrix(remainingIdx, observedIdx);
        RealMatrix precisionOO = this.precisionMatrix.getSubMatrix(observedIdx, observedIdx);

        RealVector reducedScaledMean = getSubVector(this.scaledMeanVector, remainingIdx).subtract(precisionRO.operate(observedValues));
        double reducedConstant = (this.normalizationConstant + getSubVector(this.scaledMeanVector, observedIdx).dotProduct(observedValues))
            - (0.5d * observedValues.dotProduct(precisionOO.operate(observedValues)));
        return new CanonicalGaussianFactor(remaining, precisionRR, reducedScaledMean, reducedConstant);
    }

    /**
     * Returns this factor unchanged; canonical factors are not renormalized here.
     *
     * <p>The method is named {@code normalize2} because of a decompiler rename (originally
     * {@code normalize}); the name is kept for compatibility.
     *
     * @return {@code this}
     */
    @Override // com.github.thorbenlindhauer.factor.Factor
    public GaussianFactor normalize2() {
        return this;
    }

    /**
     * Not supported for canonical Gaussian factors.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // com.github.thorbenlindhauer.factor.Factor
    public GaussianFactor invert() {
        throw new UnsupportedOperationException("not yet implemented");
    }

    /** @return the scope of this factor */
    @Override // com.github.thorbenlindhauer.factor.Factor
    public Scope getVariables() {
        return this.scope;
    }

    /** @return the precision matrix {@code K} (the internal instance, not a copy) */
    @Override // com.github.thorbenlindhauer.factor.GaussianFactor
    public RealMatrix getPrecisionMatrix() {
        return this.precisionMatrix;
    }

    /** @return the scaled mean vector {@code h} (the internal instance, not a copy) */
    @Override // com.github.thorbenlindhauer.factor.GaussianFactor
    public RealVector getScaledMeanVector() {
        return this.scaledMeanVector;
    }

    /** @return the log-space normalization constant {@code g} */
    @Override // com.github.thorbenlindhauer.factor.GaussianFactor
    public double getNormalizationConstant() {
        return this.normalizationConstant;
    }

    /** @return the covariance matrix {@code K^-1}, computed by inverting the precision matrix on each call */
    @Override // com.github.thorbenlindhauer.factor.GaussianFactor
    public RealMatrix getCovarianceMatrix() {
        return new MathUtil(this.precisionMatrix).invert();
    }

    /** @return the mean vector {@code K^-1 h}, computed by inverting the precision matrix on each call */
    @Override // com.github.thorbenlindhauer.factor.GaussianFactor
    public RealVector getMeanVector() {
        return new MathUtil(this.precisionMatrix).invert().operate(this.scaledMeanVector);
    }

    /**
     * Evaluates {@code exp(-1/2 x^T K x + h^T x + g)}.
     *
     * @param dArr the assignment {@code x}; presumably ordered like this factor's scope (verify
     *     against callers)
     * @return the factor value at {@code x}
     */
    @Override // com.github.thorbenlindhauer.factor.GaussianFactor
    public double getValueForAssignment(double[] dArr) {
        ArrayRealVector x = new ArrayRealVector(dArr);
        double exponent = ((-0.5d) * x.dotProduct(this.precisionMatrix.operate(x)))
            + this.scaledMeanVector.dotProduct(x)
            + this.normalizationConstant;
        return Math.exp(exponent);
    }

    /**
     * Converts a moment-form Gaussian {@code N(mu, Sigma)} to canonical form.
     *
     * <p>The result has {@code K = Sigma^-1}, {@code h = Sigma^-1 mu} and
     * {@code g = -1/2 mu^T Sigma^-1 mu - log((2 pi)^(n/2) sqrt(det Sigma))}.
     *
     * @param scope the variables of the Gaussian
     * @param realVector the mean vector {@code mu}
     * @param realMatrix the covariance matrix {@code Sigma}
     * @return the equivalent canonical factor
     */
    public static CanonicalGaussianFactor fromMomentForm(Scope scope, RealVector realVector, RealMatrix realMatrix) {
        MathUtil covarianceUtil = new MathUtil(realMatrix);
        RealMatrix precision = covarianceUtil.invert();
        RealVector scaledMean = precision.operate(realVector);
        double logNormalizer = Math.log(Math.pow(TWO_PI, realVector.getDimension() / 2.0d) * Math.sqrt(covarianceUtil.determinant()));
        double constant = (-(0.5d * scaledMean.dotProduct(realVector))) - logNormalizer;
        return new CanonicalGaussianFactor(scope, precision, scaledMean, constant);
    }

    /**
     * Builds a canonical factor over {@code scope} from a linear-Gaussian conditional.
     *
     * <p>Let {@code X} be the variables of {@code scope} that are not in {@code scope2}, and
     * {@code Y} the variables of {@code scope2}. Let {@code mu = realVector},
     * {@code Sigma = realMatrix} and {@code A = realMatrix2}. The precision matrix built below is
     * {@code [[Sigma^-1, -Sigma^-1 A], [-A^T Sigma^-1, A^T Sigma^-1 A]]}. This presumably
     * corresponds to {@code X | Y ~ N(mu + A y, Sigma)}; confirm the intended semantics of
     * {@code realMatrix2} with callers.
     *
     * @param scope all variables of the resulting factor
     * @param scope2 the conditioning variables {@code Y}
     * @param realVector the mean offset {@code mu} for {@code X}
     * @param realMatrix the covariance {@code Sigma} of {@code X}
     * @param realMatrix2 the coefficient matrix {@code A} ({@code |X|} rows, {@code |Y|} columns)
     * @return the canonical factor over {@code scope}
     */
    public static CanonicalGaussianFactor fromConditionalForm(Scope scope, Scope scope2, RealVector realVector, RealMatrix realMatrix, RealMatrix realMatrix2) {
        MathUtil covarianceUtil = new MathUtil(realMatrix);
        RealMatrix covarianceInverse = covarianceUtil.invert();
        RealMatrix coefficientsTimesInverse = realMatrix2.transpose().multiply(covarianceInverse);                  // A^T Sigma^-1
        RealMatrix lowerLeft = coefficientsTimesInverse.scalarMultiply(-1.0d);                                      // -A^T Sigma^-1
        RealMatrix upperRight = covarianceInverse.transpose().multiply(realMatrix2).scalarMultiply(-1.0d);          // -Sigma^-T A
        RealMatrix lowerRight = coefficientsTimesInverse.multiply(realMatrix2);                                     // A^T Sigma^-1 A

        int size = scope.size();
        Scope conditioned = scope.reduceBy(scope2);
        int[] conditionedMapping = scope.createContinuousVariableMapping(conditioned);
        int[] conditioningMapping = scope.createContinuousVariableMapping(scope2);

        Array2DRowRealMatrix precision = new Array2DRowRealMatrix(size, size);
        // Upper block row of the precision ([Sigma^-1, -Sigma^-1 A], padded); mu^T times it yields h.
        Array2DRowRealMatrix meanProjection = new Array2DRowRealMatrix(size, size);
        for (int col = 0; col < size; col++) {
            RealVector column = new ArrayRealVector(size);
            if (conditionedMapping[col] >= 0) {
                RealVector upperPart = column.add(padVector(covarianceInverse.getColumnVector(conditionedMapping[col]), size, conditionedMapping));
                meanProjection.setColumnVector(col, upperPart);
                column = upperPart.add(padVector(lowerLeft.getColumnVector(conditionedMapping[col]), size, conditioningMapping));
                precision.setColumnVector(col, column);
            }
            if (conditioningMapping[col] >= 0) {
                RealVector upperPart = column.add(padVector(upperRight.getColumnVector(conditioningMapping[col]), size, conditionedMapping));
                meanProjection.setColumnVector(col, upperPart);
                precision.setColumnVector(col, upperPart.add(padVector(lowerRight.getColumnVector(conditioningMapping[col]), size, conditioningMapping)));
            }
        }

        Array2DRowRealMatrix paddedMeanRow = new Array2DRowRealMatrix(1, size);
        paddedMeanRow.setRowVector(0, padVector(realVector, size, conditionedMapping));
        RealVector scaledMean = paddedMeanRow.multiply(meanProjection).getRowVector(0);

        double constant = ((-0.5d) * realVector.dotProduct(covarianceInverse.operate(realVector)))
            - Math.log(Math.pow(TWO_PI, conditioned.size() / 2.0d) * Math.sqrt(covarianceUtil.determinant()));
        return new CanonicalGaussianFactor(scope, precision, scaledMean, constant);
    }
}
