package org.tribuo.classification.sgd.kernel;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.classification.sgd.protos.KernelSVMModelProto;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.math.kernel.Kernel;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.protos.KernelProto;
import org.tribuo.math.protos.TensorProto;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.provenance.ModelProvenance;

/* loaded from: input_file:org/tribuo/classification/sgd/kernel/KernelSVMModel.class */
/**
 * A classification model that scores each label using a kernel evaluated against a fixed set of
 * support vectors.
 * <p>
 * Prediction works in three steps:
 * <ol>
 *   <li>Convert the example to a {@link SparseVector} via
 *       {@code SparseVector.createSparseVector(example, featureIDMap, true)}.</li>
 *   <li>Compute the kernel similarity between that vector and every stored support vector.</li>
 *   <li>Combine the similarity vector with the weight matrix ({@code weights.leftMultiply}) to
 *       obtain one score per label id, and predict the highest-scoring label.</li>
 * </ol>
 * The {@code true} flag adds one extra dimension, presumably a bias term (TODO confirm). That is
 * why support vectors are expected to have {@code featureDomain.size() + 1} dimensions.
 * <p>
 * Scores are not probabilities: the superclass is constructed with
 * {@code generatesProbabilities = false}.
 */
public class KernelSVMModel extends Model<Label> {
    private static final long serialVersionUID = 2L;

    /** Protobuf format version written by {@code serialize()} and the highest one accepted when deserializing. */
    public static final int CURRENT_VERSION = 0;

    /** Kernel used to compare an input vector with each support vector. */
    private final Kernel kernel;

    /** Stored support vectors. Stored by reference; {@code copy} deep-copies them. */
    private final SparseVector[] supportVectors;

    /**
     * Weight matrix. {@code deserializeFromProto} checks that it has one row per label and one
     * column per support vector; the constructor does not re-check this.
     */
    private final DenseMatrix weights;

    /**
     * Constructs a kernel SVM model. The support vector array and weight matrix are stored by
     * reference, not copied.
     *
     * @param name           the model name.
     * @param provenance     the model provenance.
     * @param featureIDMap   the feature domain.
     * @param labelIDMap     the label domain; label ids index the rows of {@code weights}.
     * @param kernel         the kernel used to compare inputs with support vectors.
     * @param supportVectors the support vectors.
     * @param weights        the weight matrix (labels x support vectors).
     */
    public KernelSVMModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap,
                          ImmutableOutputInfo<Label> labelIDMap, Kernel kernel,
                          SparseVector[] supportVectors, DenseMatrix weights) {
        super(name, provenance, featureIDMap, labelIDMap, false);
        this.kernel = kernel;
        this.supportVectors = supportVectors;
        this.weights = weights;
    }

    /**
     * Reconstructs a model from its protobuf representation.
     *
     * @param version   the serialized format version; must be in {@code [0, CURRENT_VERSION]}.
     * @param className the serialized class name (not used by this method).
     * @param message   the packed {@code KernelSVMModelProto}.
     * @return the deserialized model.
     * @throws InvalidProtocolBufferException if {@code message} does not unpack to a
     *                                        {@code KernelSVMModelProto}.
     * @throws IllegalArgumentException       if {@code version} is unsupported.
     * @throws IllegalStateException          if the output domain is empty or not a label domain, a
     *                                        support vector is not a correctly sized
     *                                        {@link SparseVector}, or the weights are not a
     *                                        correctly shaped {@link DenseMatrix}.
     */
    public static KernelSVMModel deserializeFromProto(int version, String className, Any message)
            throws InvalidProtocolBufferException {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version
                    + ", this class supports at most version " + CURRENT_VERSION);
        }
        KernelSVMModelProto proto = message.unpack(KernelSVMModelProto.class);
        ModelDataCarrier<?> carrier = ModelDataCarrier.deserialize(proto.getMetadata());

        // getOutput(0) has no element for an empty domain; reject that case explicitly instead of
        // throwing a NullPointerException from getClass().
        Object firstOutput = carrier.outputDomain().getOutput(0);
        if (firstOutput == null || !firstOutput.getClass().equals(Label.class)) {
            throw new IllegalStateException("Invalid protobuf, output domain is not a label domain, found "
                    + carrier.outputDomain().getClass());
        }
        @SuppressWarnings("unchecked") // guarded by the Label class check above
        ImmutableOutputInfo<Label> outputDomain = (ImmutableOutputInfo<Label>) carrier.outputDomain();

        // One extra dimension on top of the feature domain, matching createSparseVector(..., true).
        int expectedSize = carrier.featureDomain().size() + 1;
        List<TensorProto> vectorProtos = proto.getSupportVectorsList();
        SparseVector[] supportVectors = new SparseVector[vectorProtos.size()];
        for (int i = 0; i < supportVectors.length; i++) {
            Tensor tensor = Tensor.deserialize(vectorProtos.get(i));
            if (!(tensor instanceof SparseVector)) {
                throw new IllegalStateException("Invalid protobuf, support vector must be a sparse vector, found "
                        + tensor.getClass());
            }
            SparseVector vector = (SparseVector) tensor;
            if (vector.size() != expectedSize) {
                throw new IllegalStateException("Invalid protobuf, support vector size must equal feature domain size, found "
                        + vector.size() + ", expected " + expectedSize);
            }
            supportVectors[i] = vector;
        }

        // Deserialize as the general Tensor type first; the instanceof guard is only meaningful then.
        Tensor weightsTensor = Tensor.deserialize(proto.getWeights());
        if (!(weightsTensor instanceof DenseMatrix)) {
            throw new IllegalStateException("Invalid protobuf, weights must be a dense matrix, found "
                    + weightsTensor.getClass());
        }
        DenseMatrix weights = (DenseMatrix) weightsTensor;
        if (weights.getDimension1Size() != outputDomain.size()) {
            throw new IllegalStateException("Invalid protobuf, weights not the right size, expected "
                    + outputDomain.size() + ", found " + weights.getDimension1Size());
        }
        if (weights.getDimension2Size() != supportVectors.length) {
            throw new IllegalStateException("Invalid protobuf, weights not the right size, expected "
                    + supportVectors.length + ", found " + weights.getDimension2Size());
        }

        return new KernelSVMModel(carrier.name(), carrier.provenance(), carrier.featureDomain(),
                outputDomain, Kernel.deserialize(proto.getKernel()), supportVectors, weights);
    }

    /**
     * Returns the number of support vectors stored in this model.
     *
     * @return the number of support vectors.
     */
    public int getNumberOfSupportVectors() {
        return supportVectors.length;
    }

    /**
     * Scores every label for the supplied example and predicts the highest-scoring one.
     *
     * @param example the example to predict.
     * @return the prediction, containing a score for every label in insertion (label id) order.
     * @throws IllegalArgumentException if none of the example's features are known to the model.
     */
    public Prediction<Label> predict(Example<Label> example) {
        SparseVector features = SparseVector.createSparseVector(example, featureIDMap, true);
        // A single active element means only the extra (presumably bias) dimension is present,
        // i.e. the example had no features in this model's domain.
        if (features.numActiveElements() == 1) {
            throw new IllegalArgumentException("No features found in Example " + example.toString());
        }

        double[] similarities = new double[supportVectors.length];
        for (int i = 0; i < similarities.length; i++) {
            similarities[i] = kernel.similarity(features, supportVectors[i]);
        }
        DenseVector scores = weights.leftMultiply(DenseVector.createDenseVector(similarities));

        double maxScore = Double.NEGATIVE_INFINITY;
        Label maxLabel = null;
        Map<String, Label> scoreMap = new LinkedHashMap<>();
        for (int i = 0; i < scores.size(); i++) {
            String labelName = outputIDInfo.getOutput(i).getLabel();
            Label scored = new Label(labelName, scores.get(i));
            scoreMap.put(labelName, scored);
            if (scored.getScore() > maxScore) {
                maxScore = scored.getScore();
                maxLabel = scored;
            }
        }
        return new Prediction<>(maxLabel, scoreMap, features.numActiveElements(), example, generatesProbabilities);
    }

    /**
     * No feature ranking is computed for this model.
     *
     * @param n ignored.
     * @return an empty map.
     */
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
        return Collections.emptyMap();
    }

    /**
     * No excuse (explanation) is computed for this model.
     *
     * @param example ignored.
     * @return an empty optional.
     */
    public Optional<Excuse<Label>> getExcuse(Example<Label> example) {
        return Optional.empty();
    }

    /**
     * Serializes this model to a {@link ModelProto} that wraps a {@code KernelSVMModelProto}.
     * The proto contains the metadata, kernel, weights and support vectors and is tagged with
     * {@link #CURRENT_VERSION}.
     * <p>
     * Named {@code serialize} (the decompiler had renamed it {@code m20serialize}) so that it
     * implements the superclass serialization contract.
     *
     * @return the serialized model.
     */
    public ModelProto serialize() {
        ModelDataCarrier<Label> carrier = createDataCarrier();

        KernelSVMModelProto.Builder modelBuilder = KernelSVMModelProto.newBuilder();
        modelBuilder.setMetadata(carrier.serialize());
        modelBuilder.setKernel((KernelProto) kernel.serialize());
        modelBuilder.setWeights(weights.serialize());
        for (SparseVector vector : supportVectors) {
            modelBuilder.addSupportVectors(vector.serialize());
        }

        ModelProto.Builder builder = ModelProto.newBuilder();
        builder.setVersion(CURRENT_VERSION);
        builder.setClassName(KernelSVMModel.class.getName());
        builder.setSerializedData(Any.pack(modelBuilder.build()));
        return builder.build();
    }

    /**
     * Creates a copy of this model with a new name and provenance. The support vectors and weight
     * matrix are copied; the kernel and the feature/label domains are shared with this model.
     * <p>
     * Named {@code copy} (the decompiler had renamed it {@code m19copy}) so that it overrides the
     * superclass copy method.
     *
     * @param newName       the name of the copy.
     * @param newProvenance the provenance of the copy.
     * @return the copied model.
     */
    public KernelSVMModel copy(String newName, ModelProvenance newProvenance) {
        SparseVector[] copiedVectors = new SparseVector[supportVectors.length];
        for (int i = 0; i < copiedVectors.length; i++) {
            copiedVectors[i] = supportVectors[i].copy();
        }
        return new KernelSVMModel(newName, newProvenance, featureIDMap, outputIDInfo, kernel,
                copiedVectors, new DenseMatrix(weights));
    }
}
