package org.tribuo.classification.sgd.crf;

import java.util.Arrays;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseVector;

/**
 * Static helper methods for inference in a linear chain CRF.
 * <p>
 * All scores are log-space values. A chain is described by a {@link ChainCliqueValues}: one vector
 * of local scores per chain position (indexed by label), plus a single square transition matrix
 * shared by every pair of adjacent positions, where entry {@code (from, to)} scores label
 * {@code from} at one position being followed by label {@code to} at the next.
 */
public final class ChainHelper {

    /**
     * Terms more than this many log units below the maximum are skipped by {@code sumLogProbs},
     * as their contribution relative to the largest term is below {@code exp(-30)}.
     */
    private static final double LOG_TOLERANCE = 30.0;

    /**
     * The output of {@link ChainHelper#beliefPropagation(ChainCliqueValues)}.
     */
    public static final class ChainBPResults {
        /** The log partition function: the log-sum-exp of the final forward message. */
        public final double logZ;
        /** The forward messages, one per chain position, each indexed by label. */
        public final DenseVector[] alphas;
        /** The backward messages, one per chain position, each indexed by label. */
        public final DenseVector[] betas;
        /** The clique scores these messages were computed from. */
        public final ChainCliqueValues scores;

        ChainBPResults(double logZ, DenseVector[] alphas, DenseVector[] betas, ChainCliqueValues scores) {
            this.logZ = logZ;
            this.alphas = alphas;
            this.betas = betas;
            this.scores = scores;
        }
    }

    /**
     * The log-space scores describing a chain: per-position local scores plus the label transition
     * scores shared by every adjacent pair of positions.
     */
    public static final class ChainCliqueValues {
        /** The local scores, one vector per chain position, each indexed by label. */
        public final DenseVector[] localValues;
        /** The transition scores; entry {@code (from, to)} scores label {@code from} followed by {@code to}. */
        public final DenseMatrix transitionValues;

        /**
         * Wraps the supplied scores. The references are stored directly, not copied.
         *
         * @param localValues      The per-position local scores.
         * @param transitionValues The label transition scores.
         */
        public ChainCliqueValues(DenseVector[] localValues, DenseMatrix transitionValues) {
            this.localValues = localValues;
            this.transitionValues = transitionValues;
        }
    }

    /**
     * The output of {@link ChainHelper#viterbi(ChainCliqueValues)}.
     */
    public static final class ChainViterbiResults {
        /** The score of the highest scoring label sequence. */
        public final double mapScore;
        /** The highest scoring label sequence, one label index per chain position. */
        public final int[] mapValues;
        /** The clique scores the sequence was decoded from. */
        public final ChainCliqueValues scores;

        ChainViterbiResults(double mapScore, int[] mapValues, ChainCliqueValues scores) {
            this.mapScore = mapScore;
            this.mapValues = mapValues;
            this.scores = scores;
        }
    }

    private ChainHelper() {
    }

    /**
     * Runs forward-backward belief propagation in log space.
     * <p>
     * The forward message at position {@code i} includes the local scores at {@code i}:
     * {@code alpha[i][b] = local[i][b] + logsumexp_a(trans[a][b] + alpha[i-1][b'=a])}, starting
     * from {@code alpha[0] = local[0]}. The backward message at position {@code i} excludes the
     * local scores at {@code i}:
     * {@code beta[i][a] = logsumexp_b(trans[a][b] + beta[i+1][b] + local[i+1][b])}, starting from
     * {@code beta[last] = 0}.
     * <p>
     * The input vectors are not modified; the forward messages start from copies of them.
     *
     * @param scores The chain's local and transition scores.
     * @return The forward and backward messages, and the log partition function.
     */
    public static ChainBPResults beliefPropagation(ChainCliqueValues scores) {
        final int numLabels = scores.transitionValues.getDimension1Size();
        DenseMatrix markovScores = scores.transitionValues;
        DenseVector[] localScores = scores.localValues;
        DenseVector[] alphas = new DenseVector[localScores.length];
        DenseVector[] betas = new DenseVector[localScores.length];
        for (int i = 0; i < localScores.length; i++) {
            alphas[i] = localScores[i].copy();
            betas[i] = new DenseVector(numLabels, Double.NEGATIVE_INFINITY);
        }

        // Forward pass. The scratch buffer is reused for every log-sum-exp.
        double[] tmpArray = new double[numLabels];
        for (int i = 1; i < localScores.length; i++) {
            DenseVector curAlpha = alphas[i];
            DenseVector prevAlpha = alphas[i - 1];
            for (int vi = 0; vi < numLabels; vi++) {
                for (int vj = 0; vj < numLabels; vj++) {
                    tmpArray[vj] = markovScores.get(vj, vi) + prevAlpha.get(vj);
                }
                curAlpha.add(vi, sumLogProbs(tmpArray));
            }
        }

        // Backward pass.
        betas[betas.length - 1].fill(0.0);
        for (int i = localScores.length - 2; i >= 0; i--) {
            DenseVector curBeta = betas[i];
            DenseVector nextBeta = betas[i + 1];
            DenseVector nextLocalScores = localScores[i + 1];
            for (int vi = 0; vi < numLabels; vi++) {
                for (int vj = 0; vj < numLabels; vj++) {
                    tmpArray[vj] = markovScores.get(vi, vj) + nextBeta.get(vj) + nextLocalScores.get(vj);
                }
                curBeta.set(vi, sumLogProbs(tmpArray));
            }
        }

        double logZ = sumLogProbs(alphas[alphas.length - 1]);
        return new ChainBPResults(logZ, alphas, betas, scores);
    }

    /**
     * Computes the log partition function restricted to label sequences consistent with the
     * supplied constraints, using a forward pass in log space.
     * <p>
     * {@code constraints[i] == -1} leaves position {@code i} unconstrained; any other value is
     * the only label allowed at position {@code i}. Disallowed labels receive a forward score of
     * {@code -Infinity} at every position, including the first.
     * <p>
     * The input vectors are not modified; the forward messages start from copies of them.
     *
     * @param scores      The chain's local and transition scores.
     * @param constraints One entry per chain position: {@code -1}, or the required label index.
     * @return The constrained log partition function.
     * @throws IllegalArgumentException If {@code constraints} is not the same length as the chain.
     */
    public static double constrainedBeliefPropagation(ChainCliqueValues scores, int[] constraints) {
        final int numLabels = scores.transitionValues.getDimension1Size();
        DenseMatrix markovScores = scores.transitionValues;
        DenseVector[] localScores = scores.localValues;
        if (localScores.length != constraints.length) {
            throw new IllegalArgumentException("Must have the same number of constraints as tokens");
        }
        DenseVector[] alphas = new DenseVector[localScores.length];
        for (int i = 0; i < localScores.length; i++) {
            alphas[i] = localScores[i].copy();
        }

        double[] tmpArray = new double[numLabels];
        // Starts at 0 so the first position's constraint is applied too; position 0 has no
        // incoming transitions, so its allowed labels simply keep their local scores.
        for (int i = 0; i < localScores.length; i++) {
            DenseVector curAlpha = alphas[i];
            int constraint = constraints[i];
            for (int vi = 0; vi < numLabels; vi++) {
                if (constraint != -1 && constraint != vi) {
                    curAlpha.set(vi, Double.NEGATIVE_INFINITY);
                } else if (i > 0) {
                    DenseVector prevAlpha = alphas[i - 1];
                    for (int vj = 0; vj < numLabels; vj++) {
                        tmpArray[vj] = markovScores.get(vj, vi) + prevAlpha.get(vj);
                    }
                    curAlpha.add(vi, sumLogProbs(tmpArray));
                }
            }
        }
        return sumLogProbs(alphas[alphas.length - 1]);
    }

    /**
     * Finds the highest scoring label sequence using the Viterbi algorithm in log space.
     *
     * @param scores The chain's local and transition scores.
     * @return The best sequence and its score.
     */
    public static ChainViterbiResults viterbi(ChainCliqueValues scores) {
        DenseMatrix markovScores = scores.transitionValues;
        DenseVector[] localScores = scores.localValues;
        final int numLabels = markovScores.getDimension1Size();
        final int length = localScores.length;
        // costs[i][b] is the best score of any prefix ending in label b at position i.
        DenseVector[] costs = new DenseVector[length];
        // backPointers[i][b] is the label at i-1 on that best prefix (unused for i == 0).
        int[][] backPointers = new int[length][];
        for (int i = 0; i < length; i++) {
            costs[i] = new DenseVector(numLabels, Double.NEGATIVE_INFINITY);
            backPointers[i] = new int[numLabels];
            Arrays.fill(backPointers[i], -1);
        }
        costs[0].setElements(localScores[0]);

        for (int i = 1; i < length; i++) {
            DenseVector curLocalScores = localScores[i];
            DenseVector curCost = costs[i];
            int[] curBackPointers = backPointers[i];
            DenseVector prevCost = costs[i - 1];
            for (int vi = 0; vi < numLabels; vi++) {
                double maxScore = Double.NEGATIVE_INFINITY;
                int maxIndex = -1;
                double curLocalScore = curLocalScores.get(vi);
                for (int vj = 0; vj < numLabels; vj++) {
                    double score = markovScores.get(vj, vi) + prevCost.get(vj) + curLocalScore;
                    if (score > maxScore) {
                        maxScore = score;
                        maxIndex = vj;
                    }
                }
                curCost.set(vi, maxScore);
                // If no predecessor beat -Infinity there is no best predecessor; fall back to label 0
                // so the backtrace below still produces a valid label index.
                curBackPointers[vi] = Math.max(maxIndex, 0);
            }
        }

        // Backtrace from the best final label.
        int[] mapValues = new int[length];
        DenseVector lastCost = costs[length - 1];
        mapValues[length - 1] = lastCost.indexOfMax();
        for (int i = length - 2; i >= 0; i--) {
            mapValues[i] = backPointers[i + 1][mapValues[i + 1]];
        }
        return new ChainViterbiResults(lastCost.maxValue(), mapValues, scores);
    }

    /**
     * Computes {@code log(sum_i exp(input_i))} stably by factoring out the maximum element.
     * <p>
     * Finite terms more than 30 log units below the maximum are skipped, as are infinite terms
     * other than the maximum itself. An empty vector, or one whose maximum is {@code -Infinity},
     * yields {@code -Infinity}.
     *
     * @param input The log-space values to sum.
     * @return The log of the sum of the exponentiated values.
     */
    public static double sumLogProbs(DenseVector input) {
        final int size = input.size();
        if (size == 0) {
            // The log of an empty sum is log(0).
            return Double.NEGATIVE_INFINITY;
        }
        double maxValue = input.get(0);
        int maxIdx = 0;
        for (int i = 1; i < size; i++) {
            double value = input.get(i);
            if (value > maxValue) {
                maxValue = value;
                maxIdx = i;
            }
        }
        if (maxValue == Double.NEGATIVE_INFINITY) {
            return maxValue;
        }

        boolean anyAdded = false;
        double intermediate = 0.0;
        double cutoff = maxValue - LOG_TOLERANCE;
        for (int i = 0; i < size; i++) {
            double value = input.get(i);
            if (value >= cutoff && i != maxIdx && !Double.isInfinite(value)) {
                anyAdded = true;
                intermediate += Math.exp(value - maxValue);
            }
        }
        // max + log(1 + sum_{i != max} exp(x_i - max)); log1p keeps precision when the sum is small.
        return anyAdded ? maxValue + Math.log1p(intermediate) : maxValue;
    }

    /**
     * Computes {@code log(sum_i exp(input[i]))} stably by factoring out the maximum element.
     * <p>
     * Finite terms more than 30 log units below the maximum are skipped, as are infinite terms
     * other than the maximum itself. An empty array, or one whose maximum is {@code -Infinity},
     * yields {@code -Infinity}.
     *
     * @param input The log-space values to sum.
     * @return The log of the sum of the exponentiated values.
     */
    public static double sumLogProbs(double[] input) {
        if (input.length == 0) {
            // The log of an empty sum is log(0).
            return Double.NEGATIVE_INFINITY;
        }
        double maxValue = input[0];
        int maxIdx = 0;
        for (int i = 1; i < input.length; i++) {
            if (input[i] > maxValue) {
                maxValue = input[i];
                maxIdx = i;
            }
        }
        if (maxValue == Double.NEGATIVE_INFINITY) {
            return maxValue;
        }

        boolean anyAdded = false;
        double intermediate = 0.0;
        double cutoff = maxValue - LOG_TOLERANCE;
        for (int i = 0; i < input.length; i++) {
            double value = input[i];
            if (value >= cutoff && i != maxIdx && !Double.isInfinite(value)) {
                anyAdded = true;
                intermediate += Math.exp(value - maxValue);
            }
        }
        // max + log(1 + sum_{i != max} exp(x_i - max)); log1p keeps precision when the sum is small.
        return anyAdded ? maxValue + Math.log1p(intermediate) : maxValue;
    }
}
