package com.github.thorbenlindhauer.cluster.ep;

import com.github.thorbenlindhauer.exception.InferenceException;
import com.github.thorbenlindhauer.factor.CanonicalGaussianFactor;
import com.github.thorbenlindhauer.factor.FactorSet;
import com.github.thorbenlindhauer.factor.FactorUtil;
import com.github.thorbenlindhauer.factor.GaussianFactor;
import com.github.thorbenlindhauer.variable.ContinuousVariable;
import com.github.thorbenlindhauer.variable.Scope;
import java.util.Collections;
import java.util.Iterator;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.ArrayRealVector;

/**
 * Cluster potential resolver that replaces a univariate Gaussian by the Gaussian having the same mean
 * and variance as its truncation to the interval [{@link #lowerBound}, {@link #upperBound}]
 * (a moment-matching projection).
 *
 * <p>Either bound may be infinite ({@link Double#NEGATIVE_INFINITY} / {@link Double#POSITIVE_INFINITY})
 * to express a one-sided truncation.
 */
public class TruncatedGaussianPotentialResolver implements ClusterPotentialResolver<GaussianFactor> {
    /** Lower end of the truncation interval; may be {@link Double#NEGATIVE_INFINITY}. */
    protected double lowerBound;
    /** Upper end of the truncation interval; may be {@link Double#POSITIVE_INFINITY}. */
    protected double upperBound;
    /** The only variable this resolver is able to project onto. */
    protected ContinuousVariable predictionVariable;
    /** Standard normal distribution N(0, 1) used for density and CDF evaluations. */
    protected NormalDistribution standardNormal = new NormalDistribution();

    /**
     * @param predictionVariable the variable whose univariate Gaussian is truncated
     * @param lowerBound lower truncation bound; may be {@link Double#NEGATIVE_INFINITY}
     * @param upperBound upper truncation bound; may be {@link Double#POSITIVE_INFINITY}
     */
    public TruncatedGaussianPotentialResolver(ContinuousVariable predictionVariable, double lowerBound, double upperBound) {
        this.lowerBound = lowerBound;
        this.upperBound = upperBound;
        this.predictionVariable = predictionVariable;
    }

    /**
     * Multiplies all factors of {@code factorSet} and returns the Gaussian whose mean and variance equal
     * those of that product truncated to [lowerBound, upperBound].
     *
     * @param factorSet univariate Gaussian factors, each defined over {@link #predictionVariable} only
     * @param scope the projection scope; must consist of exactly {@link #predictionVariable}
     * @return a singleton factor set holding the moment-matched factor over {@code scope}
     * @throws InferenceException if the scope or any factor is not univariate over the prediction
     *     variable, if the product does not have a positive finite variance, or if the truncation
     *     interval carries (numerically) zero probability mass
     */
    @Override // com.github.thorbenlindhauer.cluster.ep.ClusterPotentialResolver
    public FactorSet<GaussianFactor> project(FactorSet<GaussianFactor> factorSet, Scope scope) {
        if (scope.size() != 1 || !scope.contains(this.predictionVariable.getId())) {
            throw new InferenceException("Can only project on variable " + this.predictionVariable.getId() + " not on scope " + scope);
        }
        for (GaussianFactor factor : factorSet.getFactors()) {
            Scope variables = factor.getVariables();
            if (variables.size() != 1 || !variables.contains(this.predictionVariable.getId())) {
                throw new InferenceException("Can only project univariate gaussians over variable " + this.predictionVariable.getId());
            }
        }

        GaussianFactor joint = (GaussianFactor) FactorUtil.jointDistribution(factorSet.getFactors());
        double variance = joint.getCovarianceMatrix().getEntry(0, 0);
        double mean = joint.getMeanVector().getEntry(0);
        // zero, negative, infinite or NaN variance would otherwise turn every moment below into NaN
        if (!(variance > 0.0d && variance < Double.POSITIVE_INFINITY)) {
            throw new InferenceException("Cannot truncate a gaussian with variance " + variance
                    + " over variable " + this.predictionVariable.getId());
        }
        double stdDev = Math.sqrt(variance);

        // mean and bounds in units of the standard deviation; vValue/wValue subtract the mean themselves
        double scaledLower = this.lowerBound / stdDev;
        double scaledUpper = this.upperBound / stdDev;
        double scaledMean = mean / stdDev;

        double v = vValue(scaledMean, scaledLower, scaledUpper);
        double w = wValue(v, scaledMean, scaledLower, scaledUpper);

        double truncatedMean = mean + stdDev * v;
        double truncatedVariance = variance * (1.0d - w);
        // happens when the interval lies so far in a tail that its probability mass underflows to 0
        if (Double.isNaN(truncatedMean) || Double.isNaN(truncatedVariance)) {
            throw new InferenceException("Truncation interval [" + this.lowerBound + ", " + this.upperBound
                    + "] has numerically zero probability under a gaussian with mean " + mean + " and variance " + variance);
        }

        return new FactorSet<>(Collections.singleton(CanonicalGaussianFactor.fromMomentForm(scope,
                new ArrayRealVector(new double[]{truncatedMean}),
                new Array2DRowRealMatrix(new double[]{truncatedVariance}))));
    }

    /**
     * Computes the standardized mean shift of a truncated normal,
     * {@code v = (phi(alpha) - phi(beta)) / (Phi(beta) - Phi(alpha))} with
     * {@code alpha = scaledLower - scaledMean} and {@code beta = scaledUpper - scaledMean}.
     *
     * @param scaledMean mean divided by the standard deviation
     * @param scaledLower lower bound divided by the standard deviation; may be infinite
     * @param scaledUpper upper bound divided by the standard deviation; may be infinite
     * @return the truncated mean minus the untruncated mean, in units of the standard deviation
     */
    protected double vValue(double scaledMean, double scaledLower, double scaledUpper) {
        double alpha = scaledLower - scaledMean;
        double beta = scaledUpper - scaledMean;
        return (this.standardNormal.density(alpha) - this.standardNormal.density(beta)) / standardNormalMass(alpha, beta);
    }

    /**
     * Computes the variance reduction factor of a truncated normal,
     * {@code w = v^2 + (beta*phi(beta) - alpha*phi(alpha)) / (Phi(beta) - Phi(alpha))}, such that the
     * truncated variance is {@code variance * (1 - w)}.
     *
     * @param v the value returned by {@link #vValue(double, double, double)} for the same arguments
     * @param scaledMean mean divided by the standard deviation
     * @param scaledLower lower bound divided by the standard deviation; may be infinite
     * @param scaledUpper upper bound divided by the standard deviation; may be infinite
     * @return the relative variance reduction caused by the truncation
     */
    protected double wValue(double v, double scaledMean, double scaledLower, double scaledUpper) {
        double alpha = scaledLower - scaledMean;
        double beta = scaledUpper - scaledMean;
        return (v * v) + ((timesDensity(beta) - timesDensity(alpha)) / standardNormalMass(alpha, beta));
    }

    /** Returns {@code Phi(beta) - Phi(alpha)}, the standard normal probability mass between the limits. */
    private double standardNormalMass(double alpha, double beta) {
        if (Math.min(alpha, beta) > 0.0d) {
            // both limits in the upper tail: Phi(beta) - Phi(alpha) would subtract two numbers close to 1;
            // the mirrored lower-tail form is identical by symmetry but keeps precision
            return this.standardNormal.cumulativeProbability(-alpha) - this.standardNormal.cumulativeProbability(-beta);
        }
        return this.standardNormal.cumulativeProbability(beta) - this.standardNormal.cumulativeProbability(alpha);
    }

    /** Returns {@code x * phi(x)}, using its limit 0 for infinite {@code x}. */
    private double timesDensity(double x) {
        // Infinity * 0.0 is NaN in IEEE arithmetic, which would break one-sided truncation
        return Double.isInfinite(x) ? 0.0d : x * this.standardNormal.density(x);
    }
}
