package com.github.thorbenlindhauer.inference;

import com.github.thorbenlindhauer.exception.InferenceException;
import com.github.thorbenlindhauer.factor.DiscreteFactor;
import com.github.thorbenlindhauer.factor.FactorUtil;
import com.github.thorbenlindhauer.inference.variableelimination.VariableEliminationStrategy;
import com.github.thorbenlindhauer.network.GraphicalModel;
import com.github.thorbenlindhauer.variable.Scope;
import com.github.thorbenlindhauer.variable.Variable;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/* loaded from: input_file:com/github/thorbenlindhauer/inference/VariableEliminationInferencer.class */
/**
 * Answers probability queries on a discrete {@link GraphicalModel} by variable elimination.
 *
 * <p>For a query over some scope, every model variable outside that scope is eliminated in the
 * order returned by the configured {@link VariableEliminationStrategy}. Eliminating a variable
 * multiplies together all current factors that mention it and marginalizes it out of the product.
 * The remaining factors are finally multiplied into one factor over the query scope.
 */
public class VariableEliminationInferencer implements DiscreteModelInferencer {

    /** The model whose factors are combined during inference. */
    protected GraphicalModel<DiscreteFactor> graphicalModel;

    /** Supplies the order in which non-query variables are eliminated. */
    protected VariableEliminationStrategy variableEliminationStrategy;

    /**
     * Creates an inferencer for the given model.
     *
     * @param graphicalModel the model to run queries against
     * @param variableEliminationStrategy supplies the elimination order for each query
     */
    public VariableEliminationInferencer(GraphicalModel<DiscreteFactor> graphicalModel,
            VariableEliminationStrategy variableEliminationStrategy) {
        this.graphicalModel = graphicalModel;
        this.variableEliminationStrategy = variableEliminationStrategy;
    }

    /**
     * Returns the probability that the variables in {@code scope} take the values {@code iArr},
     * without any evidence.
     *
     * @param scope the query variables
     * @param iArr the assignment to the query variables
     * @return the value of the normalized marginal over {@code scope} at {@code iArr}
     */
    @Override
    public double jointProbability(Scope scope, int[] iArr) {
        return jointProbability(scope, iArr, null, null);
    }

    /**
     * Returns the value, at the assignment {@code iArr}, of the normalized joint distribution
     * after applying the evidence.
     *
     * <p>The joint over {@code scope ∪ scope2} is normalized first. The evidence is applied
     * afterwards and the result is not renormalized. Presumably this yields the joint probability
     * of query and evidence rather than the conditional one; confirm against the semantics of
     * {@code observation}.
     *
     * @param scope the query variables
     * @param iArr the assignment to the query variables
     * @param scope2 the evidence variables, or {@code null} for no evidence
     * @param iArr2 the assignment to the evidence variables (ignored if {@code scope2} is null)
     * @return the probability value read from the marginal over {@code scope}
     */
    @Override
    public double jointProbability(Scope scope, int[] iArr, Scope scope2, int[] iArr2) {
        DiscreteFactor distribution = jointProbabilityDistribution(querySpanningScope(scope, scope2)).normalize2();
        if (scope2 != null) {
            distribution = distribution.observation(scope2, iArr2);
        }
        return distribution.marginal2(scope).getValueForAssignment(iArr);
    }

    /**
     * Returns the probability of {@code iArr} for {@code scope}, given the evidence
     * {@code scope2 = iArr2}.
     *
     * <p>Unlike {@link #jointProbability(Scope, int[], Scope, int[])}, the evidence is applied
     * before normalizing. The result is therefore renormalized over the observed entries, which
     * gives the conditional probability. This assumes {@code observation} restricts the factor
     * to the evidence assignment.
     *
     * @param scope the query variables
     * @param iArr the assignment to the query variables
     * @param scope2 the evidence variables, or {@code null} for no evidence
     * @param iArr2 the assignment to the evidence variables (ignored if {@code scope2} is null)
     * @return the probability value read from the marginal over {@code scope}
     */
    @Override
    public double jointProbabilityConditionedOn(Scope scope, int[] iArr, Scope scope2, int[] iArr2) {
        DiscreteFactor distribution = jointProbabilityDistribution(querySpanningScope(scope, scope2));
        if (scope2 != null) {
            distribution = distribution.observation(scope2, iArr2);
        }
        return distribution.normalize2().marginal2(scope).getValueForAssignment(iArr);
    }

    /**
     * Returns the scope the joint distribution must cover: the query scope, extended by the
     * evidence scope when evidence is present.
     */
    private static Scope querySpanningScope(Scope queryScope, Scope evidenceScope) {
        return evidenceScope == null ? queryScope : queryScope.union(evidenceScope);
    }

    /**
     * Computes the (unnormalized) joint factor over {@code scope} by eliminating every other model
     * variable.
     *
     * @param scope the variables that must remain in the resulting factor
     * @return the product of all factors left after elimination
     * @throws InferenceException if the strategy's elimination order is inconsistent with
     *     {@code scope} (see {@link #validateEliminationOrder})
     */
    protected DiscreteFactor jointProbabilityDistribution(Scope scope) {
        List<String> eliminationOrder = this.variableEliminationStrategy.getEliminationOrder(
                this.graphicalModel,
                Arrays.asList(this.graphicalModel.getScope().reduceBy(scope).getVariableIds()));
        validateEliminationOrder(this.graphicalModel, scope, eliminationOrder);

        // Work on a private copy. The loop below removes and adds factors, and getFactors() may
        // hand out the model's own collection; mutating it would corrupt later queries.
        Set<DiscreteFactor> factors = new HashSet<>(this.graphicalModel.getFactors());
        for (String variableId : eliminationOrder) {
            Set<DiscreteFactor> involved = factorsWithVariableInScope(factors, variableId);
            factors.removeAll(involved);
            // Multiply the factors that mention the variable, then sum it out of the product.
            DiscreteFactor product = (DiscreteFactor) FactorUtil.jointDistribution(involved);
            factors.add(product.marginal2(product.getVariables().reduceBy(variableId)));
        }
        return (DiscreteFactor) FactorUtil.jointDistribution(factors);
    }

    /**
     * Selects the factors whose variables include the given variable id.
     *
     * @param set the candidate factors (not modified)
     * @param str the variable id to look for
     * @return a new set with the matching factors
     */
    protected Set<DiscreteFactor> factorsWithVariableInScope(Set<DiscreteFactor> set, String str) {
        Set<DiscreteFactor> matching = new HashSet<>();
        for (DiscreteFactor factor : set) {
            if (factor.getVariables().has(str)) {
                matching.add(factor);
            }
        }
        return matching;
    }

    /**
     * Checks that every model variable is in exactly one of: the target scope, or the elimination
     * order.
     *
     * @param graphicalModel the model whose variables are checked
     * @param scope the variables to keep
     * @param list the ids of the variables to eliminate
     * @throws InferenceException if a model variable is in neither, or in both
     */
    protected void validateEliminationOrder(GraphicalModel<DiscreteFactor> graphicalModel, Scope scope, List<String> list) {
        // Hash lookup instead of List.contains per variable; membership results are identical.
        Set<String> toEliminate = new HashSet<>(list);
        for (Variable variable : graphicalModel.getScope().getVariables()) {
            boolean kept = scope.has(variable);
            boolean eliminated = toEliminate.contains(variable.getId());
            if (!kept && !eliminated) {
                throw new InferenceException("Model variable " + variable.getId() + " is neither in the joint distribution's scope, nor in the variables to be eliminated.");
            }
            if (kept && eliminated) {
                throw new InferenceException("Model variable " + variable.getId() + " is supposed to be part of the joint probability distribution, as well as to be eliminated.");
            }
        }
    }
}
