package org.tribuo.classification.sgd.crf;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.tribuo.classification.sgd.crf.ChainHelper;
import org.tribuo.classification.sgd.protos.CRFParametersProto;
import org.tribuo.math.Parameters;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseSparseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.Matrix;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.protos.ParametersProto;
import org.tribuo.math.util.HeapMerger;
import org.tribuo.math.util.Merger;
import org.tribuo.protos.ProtoUtil;

/**
 * A {@link Parameters} implementation for a linear-chain CRF.
 * <p>
 * Holds three parameter tensors, exposed together by {@link #get()} in this order:
 * <ol>
 *     <li>a bias vector with one entry per label,</li>
 *     <li>a feature/label weight matrix of shape [numLabels, numFeatures],</li>
 *     <li>a label/label transition matrix of shape [numLabels, numLabels].</li>
 * </ol>
 */
public class CRFParameters implements Parameters, Serializable {
    private static final long serialVersionUID = 1L;

    /**
     * Protobuf serialization version.
     */
    public static final int CURRENT_VERSION = 0;

    private final int numLabels;
    private final int numFeatures;

    private static final Merger merger = new HeapMerger();

    /**
     * All parameters; the entries alias {@link #biases}, {@link #featureLabelWeights}
     * and {@link #labelLabelWeights} respectively.
     */
    private Tensor[] weights;

    /**
     * Label biases, aliased as {@code weights[0]}.
     */
    private DenseVector biases;

    /**
     * Feature/label weights of shape [numLabels, numFeatures], aliased as {@code weights[1]}.
     */
    private DenseMatrix featureLabelWeights;

    /**
     * Label/label transition weights of shape [numLabels, numLabels], aliased as {@code weights[2]}.
     */
    private DenseMatrix labelLabelWeights;

    /**
     * Constructs zero-initialized CRF parameters.
     * <p>
     * NOTE(review): the decompiler reports this constructor was package-private in the
     * original class; it is kept public here so existing callers keep compiling.
     * @param numFeatures The number of input features.
     * @param numLabels The number of output labels.
     */
    public CRFParameters(int numFeatures, int numLabels) {
        this(new DenseVector(numLabels), new DenseMatrix(numLabels, numFeatures), new DenseMatrix(numLabels, numLabels));
    }

    /**
     * Wraps the supplied tensors without copying them.
     * @param biases The label biases.
     * @param featureLabelWeights The feature/label weights, shape [numLabels, numFeatures].
     * @param labelLabelWeights The label/label transition weights, shape [numLabels, numLabels].
     */
    private CRFParameters(DenseVector biases, DenseMatrix featureLabelWeights, DenseMatrix labelLabelWeights) {
        this.biases = biases;
        this.featureLabelWeights = featureLabelWeights;
        this.labelLabelWeights = labelLabelWeights;
        this.weights = new Tensor[]{biases, featureLabelWeights, labelLabelWeights};
        this.numLabels = biases.size();
        this.numFeatures = featureLabelWeights.getDimension2Size();
    }

    /**
     * Deserializes the parameters from the supplied protobuf.
     * @param version The serialized object version.
     * @param className The class name (unused, present for the deserialization protocol).
     * @param message The serialized {@link CRFParametersProto}.
     * @return The deserialized parameters.
     * @throws InvalidProtocolBufferException If the message is not a {@link CRFParametersProto}.
     * @throws IllegalArgumentException If the version is unknown or the tensors have the wrong type or shape.
     */
    public static CRFParameters deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        CRFParametersProto proto = message.unpack(CRFParametersProto.class);
        int numLabels = proto.getNumLabels();
        int numFeatures = proto.getNumFeatures();
        Tensor biasTensor = ProtoUtil.deserialize(proto.getBiases());
        Tensor featureLabelTensor = ProtoUtil.deserialize(proto.getFeatureLabelWeights());
        Tensor labelLabelTensor = ProtoUtil.deserialize(proto.getLabelLabelWeights());

        if (!(biasTensor instanceof DenseVector)) {
            throw new IllegalArgumentException("Invalid protobuf, expected bias vector, found " + biasTensor.getClass().getSimpleName());
        }
        DenseVector biases = (DenseVector) biasTensor;
        if (biases.size() != numLabels) {
            throw new IllegalArgumentException("Invalid protobuf, expected bias vector with " + numLabels + " elements, but found " + biases.size());
        }

        if (!(featureLabelTensor instanceof DenseMatrix)) {
            throw new IllegalArgumentException("Invalid protobuf, expected feature/label matrix, found " + featureLabelTensor.getClass().getSimpleName());
        }
        DenseMatrix featureLabelWeights = (DenseMatrix) featureLabelTensor;
        if (featureLabelWeights.getDimension1Size() != numLabels || featureLabelWeights.getDimension2Size() != numFeatures) {
            throw new IllegalArgumentException("Invalid protobuf, expected feature/label matrix of size [" + numLabels + ", " + numFeatures + "], found " + Arrays.toString(featureLabelWeights.getShape()));
        }

        if (!(labelLabelTensor instanceof DenseMatrix)) {
            throw new IllegalArgumentException("Invalid protobuf, expected label/label matrix, found " + labelLabelTensor.getClass().getSimpleName());
        }
        DenseMatrix labelLabelWeights = (DenseMatrix) labelLabelTensor;
        if (labelLabelWeights.getDimension1Size() != numLabels || labelLabelWeights.getDimension2Size() != numLabels) {
            throw new IllegalArgumentException("Invalid protobuf, expected label/label matrix of size [" + numLabels + ", " + numLabels + "], found " + Arrays.toString(labelLabelWeights.getShape()));
        }

        return new CRFParameters(biases, featureLabelWeights, labelLabelWeights);
    }

    /**
     * Serializes these parameters into a {@link ParametersProto} wrapping a {@link CRFParametersProto}.
     * @return The serialized parameters.
     */
    @Override
    public ParametersProto serialize() {
        CRFParametersProto.Builder crfBuilder = CRFParametersProto.newBuilder();
        crfBuilder.setNumFeatures(numFeatures);
        crfBuilder.setNumLabels(numLabels);
        crfBuilder.setBiases(biases.serialize());
        crfBuilder.setFeatureLabelWeights(featureLabelWeights.serialize());
        crfBuilder.setLabelLabelWeights(labelLabelWeights.serialize());

        ParametersProto.Builder builder = ParametersProto.newBuilder();
        builder.setVersion(CURRENT_VERSION);
        builder.setClassName(CRFParameters.class.getName());
        builder.setSerializedData(Any.pack(crfBuilder.build()));
        return builder.build();
    }

    /**
     * Decompiler-generated alias retained for source compatibility.
     * @return The serialized parameters.
     * @deprecated Use {@link #serialize()}.
     */
    @Deprecated
    public ParametersProto m5serialize() {
        return serialize();
    }

    /**
     * Returns column {@code featureIndex} of the feature/label weight matrix, i.e. the
     * weight of that feature for every label.
     * @param featureIndex The feature index.
     * @return The per-label weights for the feature.
     */
    public DenseVector getFeatureWeights(int featureIndex) {
        return featureLabelWeights.getColumn(featureIndex);
    }

    /**
     * Returns the bias for the specified label.
     * @param labelIndex The label index.
     * @return The label bias.
     */
    public double getBias(int labelIndex) {
        return biases.get(labelIndex);
    }

    /**
     * Returns the feature/label weight for the specified label and feature.
     * @param labelIndex The label index (row of the [numLabels, numFeatures] matrix).
     * @param featureIndex The feature index (column).
     * @return The weight.
     */
    public double getWeight(int labelIndex, int featureIndex) {
        return featureLabelWeights.get(labelIndex, featureIndex);
    }

    /**
     * Computes the per-token label scores: {@code features[i]} left-multiplied by the
     * feature/label weights, plus the label biases.
     * @param features The per-token feature vectors.
     * @return One score vector per token.
     */
    public DenseVector[] getLocalScores(SGDVector[] features) {
        DenseVector[] localScores = new DenseVector[features.length];
        for (int i = 0; i < features.length; i++) {
            DenseVector scores = featureLabelWeights.leftMultiply(features[i]);
            scores.intersectAndAddInPlace(biases);
            localScores[i] = scores;
        }
        return localScores;
    }

    /**
     * Generates the clique values (local scores and transition weights) for the sequence.
     * @param features The per-token feature vectors.
     * @return The clique values.
     */
    public ChainHelper.ChainCliqueValues getCliqueValues(SGDVector[] features) {
        return new ChainHelper.ChainCliqueValues(getLocalScores(features), labelLabelWeights);
    }

    /**
     * Predicts the highest scoring label sequence using Viterbi decoding.
     * @param features The per-token feature vectors.
     * @return The predicted label index for each token.
     */
    public int[] predict(SGDVector[] features) {
        return ChainHelper.viterbi(getCliqueValues(features)).mapValues;
    }

    /**
     * Computes the per-token label marginals via belief propagation.
     * @param features The per-token feature vectors.
     * @return One marginal distribution vector per token.
     */
    public DenseVector[] predictMarginals(SGDVector[] features) {
        ChainHelper.ChainBPResults bpr = ChainHelper.beliefPropagation(getCliqueValues(features));
        DenseVector[] marginals = new DenseVector[features.length];
        for (int i = 0; i < features.length; i++) {
            marginals[i] = bpr.alphas[i].add(bpr.betas[i]);
            // Subtract logZ and exponentiate to turn log-scores into probabilities.
            marginals[i].expNormalize(bpr.logZ);
        }
        return marginals;
    }

    /**
     * Computes a confidence for each chunk using constrained belief propagation:
     * {@code exp(constrainedLogZ - logZ)}, where the constrained pass fixes the labels of
     * the chunk's tokens.
     * @param features The per-token feature vectors.
     * @param chunks The chunks to score.
     * @return One confidence per chunk, in the same order.
     */
    public List<Double> predictConfidenceUsingCBP(SGDVector[] features, List<Chunk> chunks) {
        ChainHelper.ChainCliqueValues cliqueValues = getCliqueValues(features);
        double logZ = ChainHelper.beliefPropagation(cliqueValues).logZ;
        int[] constraints = new int[features.length];
        List<Double> confidences = new ArrayList<>(chunks.size());
        for (Chunk chunk : chunks) {
            // -1 marks an unconstrained position; the chunk fills in its own positions.
            Arrays.fill(constraints, -1);
            chunk.unpack(constraints);
            double constrainedLogZ = ChainHelper.constrainedBeliefPropagation(cliqueValues, constraints);
            confidences.add(Math.exp(constrainedLogZ - logZ));
        }
        return confidences;
    }

    /**
     * Computes the log-likelihood of the supplied label sequence (the score of that label
     * path minus logZ) and its gradient with respect to each parameter tensor.
     * @param features The per-token feature vectors.
     * @param labels The true label index for each token.
     * @return A pair of the log-likelihood and the gradients, ordered as in {@link #get()}.
     * @throws IllegalArgumentException If {@code labels} and {@code features} differ in length.
     * @throws IllegalStateException If the sequence mixes dense and sparse feature vectors.
     */
    public Pair<Double, Tensor[]> valueAndGradient(SGDVector[] features, int[] labels) {
        if (labels.length != features.length) {
            throw new IllegalArgumentException("Expected one label per token, found " + labels.length + " labels for " + features.length + " tokens");
        }
        ChainHelper.ChainCliqueValues scores = getCliqueValues(features);
        ChainHelper.ChainBPResults bpr = ChainHelper.beliefPropagation(scores);
        double logZ = bpr.logZ;

        DenseVector biasGradient = new DenseVector(biases.size());
        DenseMatrix transitionGradient = new DenseMatrix(numLabels, numLabels);
        DenseSparseMatrix[] sparseFeatureGradients = new DenseSparseMatrix[features.length];
        DenseMatrix denseFeatureGradient = null;
        boolean sparseFeatures = false;

        double score = -logZ;
        for (int i = 0; i < features.length; i++) {
            int curLabel = labels[i];
            DenseVector curLocalScores = scores.localValues[i];
            score += curLocalScores.get(curLabel);

            SGDVector curBeta = bpr.betas[i];
            // Expected label distribution for this token, turned into (indicator - marginal).
            DenseVector localGradient = bpr.alphas[i].add(curBeta);
            localGradient.expNormalize(logZ);
            localGradient.scaleInPlace(-1.0);
            localGradient.add(curLabel, 1.0);
            biasGradient.intersectAndAddInPlace(localGradient);

            Matrix featureGradient = localGradient.outer(features[i]);
            if (featureGradient instanceof DenseSparseMatrix) {
                sparseFeatureGradients[i] = (DenseSparseMatrix) featureGradient;
                sparseFeatures = true;
            } else if (denseFeatureGradient == null) {
                denseFeatureGradient = (DenseMatrix) featureGradient;
            } else {
                denseFeatureGradient.intersectAndAddInPlace(featureGradient);
            }

            if (i >= 1) {
                // Subtract the expected pairwise (previous, current) label counts.
                DenseVector prevAlpha = bpr.alphas[i - 1];
                for (int prev = 0; prev < numLabels; prev++) {
                    double prevAlphaValue = prevAlpha.get(prev);
                    for (int cur = 0; cur < numLabels; cur++) {
                        double logPairMarginal = prevAlphaValue + labelLabelWeights.get(prev, cur) + curBeta.get(cur) + curLocalScores.get(cur) - logZ;
                        transitionGradient.add(prev, cur, -Math.exp(logPairMarginal));
                    }
                }
                // Add the observed transition.
                int prevLabel = labels[i - 1];
                score += labelLabelWeights.get(prevLabel, curLabel);
                transitionGradient.add(prevLabel, curLabel, 1.0);
            }
        }

        Tensor featureGradients;
        if (sparseFeatures) {
            featureGradients = merger.merge(sparseFeatureGradients);
            if (denseFeatureGradient != null) {
                throw new IllegalStateException("Mixture of dense and sparse features found.");
            }
        } else {
            featureGradients = denseFeatureGradient;
        }
        return new Pair<>(score, new Tensor[]{biasGradient, featureGradients, transitionGradient});
    }

    /**
     * Returns zero-valued tensors with the same shapes as the parameters.
     * @return An empty copy of the parameters.
     */
    @Override
    public Tensor[] getEmptyCopy() {
        DenseVector biasCopy = new DenseVector(biases.size());
        DenseMatrix featureCopy = new DenseMatrix(featureLabelWeights.getDimension1Size(), featureLabelWeights.getDimension2Size());
        DenseMatrix transitionCopy = new DenseMatrix(labelLabelWeights.getDimension1Size(), labelLabelWeights.getDimension2Size());
        return new Tensor[]{biasCopy, featureCopy, transitionCopy};
    }

    /**
     * Returns the live parameter array (not a copy).
     * @return The parameters: biases, feature/label weights, label/label weights.
     */
    @Override
    public Tensor[] get() {
        return weights;
    }

    /**
     * Replaces the parameters. Arrays of the wrong length are ignored.
     * @param newWeights The new parameters, ordered as in {@link #get()}.
     * @throws ClassCastException If an entry has the wrong tensor type; the parameters are left unchanged.
     */
    @Override
    public void set(Tensor[] newWeights) {
        if (newWeights.length == weights.length) {
            // Cast before assigning so a bad entry cannot leave the fields half-updated.
            DenseVector newBiases = (DenseVector) newWeights[0];
            DenseMatrix newFeatureLabelWeights = (DenseMatrix) newWeights[1];
            DenseMatrix newLabelLabelWeights = (DenseMatrix) newWeights[2];
            weights = newWeights;
            biases = newBiases;
            featureLabelWeights = newFeatureLabelWeights;
            labelLabelWeights = newLabelLabelWeights;
        }
    }

    /**
     * Adds each gradient tensor into the corresponding parameter tensor in place.
     * @param gradients The gradients, ordered as in {@link #get()}.
     */
    @Override
    public void update(Tensor[] gradients) {
        for (int i = 0; i < gradients.length; i++) {
            weights[i].intersectAndAddInPlace(gradients[i]);
        }
    }

    /**
     * Sums a collection of gradients into a single gradient.
     * <p>
     * NOTE(review): when dense feature gradients are present, the first one is used as
     * the accumulator and is modified in place.
     * @param gradients The gradients to merge.
     * @param size A sizing hint for the sparse accumulator list.
     * @return The merged gradient, ordered as in {@link #get()}.
     */
    @Override
    public Tensor[] merge(Tensor[][] gradients, int size) {
        DenseVector biasUpdate = new DenseVector(biases.size());
        List<DenseSparseMatrix> sparseUpdates = new ArrayList<>(size);
        DenseMatrix denseUpdates = null;
        DenseMatrix transitionUpdate = new DenseMatrix(labelLabelWeights.getDimension1Size(), labelLabelWeights.getDimension2Size());
        for (Tensor[] gradient : gradients) {
            biasUpdate.intersectAndAddInPlace(gradient[0]);
            Matrix featureUpdate = (Matrix) gradient[1];
            if (featureUpdate instanceof DenseSparseMatrix) {
                sparseUpdates.add((DenseSparseMatrix) featureUpdate);
            } else if (denseUpdates == null) {
                denseUpdates = (DenseMatrix) featureUpdate;
            } else {
                denseUpdates.intersectAndAddInPlace(featureUpdate);
            }
            transitionUpdate.intersectAndAddInPlace(gradient[2]);
        }

        // Combine the sparse updates, folding in any dense ones.
        Matrix featureMatrix;
        if (!sparseUpdates.isEmpty()) {
            featureMatrix = merger.merge(sparseUpdates.toArray(new DenseSparseMatrix[0]));
            if (denseUpdates != null) {
                denseUpdates.intersectAndAddInPlace(featureMatrix);
                featureMatrix = denseUpdates;
            }
        } else {
            featureMatrix = denseUpdates;
        }
        return new Tensor[]{biasUpdate, featureMatrix, transitionUpdate};
    }
}
