package dk.bayes.dsl.variable.gaussian.multivariate;

import dk.bayes.dsl.InferEngine;
import dk.bayes.dsl.Variable;
import dk.bayes.dsl.variable.gaussian.multivariatelinear.MultivariateLinearGaussian;
import dk.bayes.math.linear.Matrix;
import scala.Some;
import scala.collection.Seq$;
import scala.collection.SeqLike;

/* compiled from: inferMultivariateGaussianSimplest.scala */
/* loaded from: input_file:dk/bayes/dsl/variable/gaussian/multivariate/inferMultivariateGaussianSimplest$.class */
/**
 * Inference engine for the simplest multivariate-Gaussian case: a {@link MultivariateGaussian}
 * whose only child is an observed {@link MultivariateLinearGaussian} leaf.
 *
 * <p>This is a stateless singleton (Scala {@code object} encoding); use {@link #MODULE$}.
 */
public final class inferMultivariateGaussianSimplest$ implements InferEngine<MultivariateGaussian, MultivariateGaussian> {

    /**
     * The single shared instance. It is initialised directly here: a {@code static final} field
     * cannot legally be assigned from the constructor, as the decompiled original attempted.
     */
    public static final inferMultivariateGaussianSimplest$ MODULE$ = new inferMultivariateGaussianSimplest$();

    /**
     * Reports whether {@link #infer} can handle {@code multivariateGaussian}.
     *
     * <p>All of the following must hold:
     * <ul>
     *   <li>the variable has exactly one child, and that child is a {@link MultivariateLinearGaussian};</li>
     *   <li>that child has exactly one parent, which is this very instance (reference identity);</li>
     *   <li>the child has no children of its own;</li>
     *   <li>the child's offset {@code b} has the same size as this variable's mean {@code m};</li>
     *   <li>the child has an observed value ({@code yValue} is defined).</li>
     * </ul>
     *
     * @param multivariateGaussian the variable to test
     * @return {@code true} if every condition above holds
     */
    @Override // dk.bayes.dsl.InferEngine
    public boolean isSupported(MultivariateGaussian multivariateGaussian) {
        Some childrenMatch = Seq$.MODULE$.unapplySeq(multivariateGaussian.getChildren());
        if (childrenMatch.isEmpty() || childrenMatch.get() == null || ((SeqLike) childrenMatch.get()).lengthCompare(1) != 0) {
            return false;
        }
        Variable onlyChild = (Variable) ((SeqLike) childrenMatch.get()).apply(0);
        if (!(onlyChild instanceof MultivariateLinearGaussian)) {
            return false;
        }
        MultivariateLinearGaussian observation = (MultivariateLinearGaussian) onlyChild;
        // NOTE(review): if b has the dimension of y and m that of x, this size check restricts
        // support to square A, although infer() would work for any A -- confirm intent.
        return observation.getParents().size() == 1
                && observation.getParents().apply(0) == multivariateGaussian
                && !observation.hasChildren()
                && observation.b().size() == multivariateGaussian.m().size()
                && observation.yValue().isDefined();
    }

    /**
     * Computes the updated Gaussian for {@code multivariateGaussian} given its observed child.
     *
     * <p>With prior mean {@code m}, prior covariance {@code V}, and child parameters {@code A},
     * {@code b} and covariance {@code Vy} (presumably {@code y = A x + b + noise}; confirm against
     * {@link MultivariateLinearGaussian}), the result is:
     * <pre>
     *   S  = (V^-1 + A^T Vy^-1 A)^-1
     *   mu = S (A^T Vy^-1 (y - b) + V^-1 m)
     * </pre>
     *
     * <p>This method performs no checks of its own. It assumes {@link #isSupported} returned
     * {@code true}; otherwise the cast of the first child or the {@code yValue().get()} may fail.
     *
     * @param multivariateGaussian the variable to update
     * @return a new {@link MultivariateGaussian} with mean {@code mu} and covariance {@code S}
     */
    @Override // dk.bayes.dsl.InferEngine
    public MultivariateGaussian infer(MultivariateGaussian multivariateGaussian) {
        MultivariateLinearGaussian observation = (MultivariateLinearGaussian) multivariateGaussian.getChildren().head();
        Matrix priorPrecision = multivariateGaussian.v().inv();
        Matrix noisePrecision = observation.v().inv();
        // A^T * Vy^-1 is needed by both the covariance and the mean term; compute it once.
        // The association order matches the original expression, so the results are unchanged.
        Matrix aTransposeNoisePrecision = observation.a().t().$times(noisePrecision);
        Matrix posteriorCovariance = priorPrecision.$plus(aTransposeNoisePrecision.$times(observation.a())).inv();
        Matrix residual = ((Matrix) observation.yValue().get()).$minus(observation.b());
        Matrix posteriorMean = posteriorCovariance.$times(
                aTransposeNoisePrecision.$times(residual).$plus(priorPrecision.$times(multivariateGaussian.m())));
        return new MultivariateGaussian(posteriorMean, posteriorCovariance);
    }

    /** Private constructor: the only instance is {@link #MODULE$}. */
    private inferMultivariateGaussianSimplest$() {
    }
}
