package org.tribuo.common.libsvm;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;
import libsvm.svm_model;
import libsvm.svm_node;
import libsvm.svm_parameter;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Trainer;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;
import org.tribuo.util.Util;

/* loaded from: input_file:org/tribuo/common/libsvm/LibSVMTrainer.class */
/**
 * Abstract base for trainers that fit models using the {@code libsvm} classes
 * ({@link svm_parameter}, {@link svm_node}, {@link svm_model}).
 * <p>
 * The configurable fields mirror the entries of {@link svm_parameter}. They are
 * either populated externally through the {@link Config} annotations and then
 * packed into {@link #parameters} by {@link #postConfig()}, or unpacked from an
 * {@link SVMParameters} instance by the programmatic constructor.
 * <p>
 * Subclasses supply the output-specific pieces: converting a dataset into
 * libsvm arrays ({@link #extractData}), running the optimiser
 * ({@link #trainModels}) and wrapping the result ({@link #createModel}).
 * <p>
 * NOTE(review): the invocation counter is a plain {@code int} that is
 * incremented without synchronisation, so concurrent calls to {@code train} on
 * one instance may lose updates.
 * <p>
 * NOTE(review): the method names {@code m4train}, {@code m5getProvenance} and
 * {@code m3train} look like decompiler-assigned replacements for {@code train}
 * and {@code getProvenance}; they are kept as-is so existing callers of this
 * source keep linking.
 *
 * @param <T> the output type this trainer produces models for
 */
public abstract class LibSVMTrainer<T extends Output<T>> implements Trainer<T> {
    private static final Logger logger = Logger.getLogger(LibSVMTrainer.class.getName());

    /** The libsvm parameter struct built from the fields below. */
    protected svm_parameter parameters;

    @Config(mandatory = true, description = "Type of SVM algorithm.")
    protected SVMType<T> svmType;

    @Config(description = "Type of Kernel.")
    private KernelType kernelType = KernelType.LINEAR;

    @Config(description = "Polynomial degree.")
    private int degree = 3;

    @Config(description = "Width of the RBF kernel, or scalar on sigmoid kernel.")
    private double gamma = 0.0d;

    @Config(description = "Polynomial coefficient or shift in sigmoid kernel.")
    private double coef0 = 0.0d;

    @Config(description = "nu value in NU SVM.")
    private double nu = 0.5d;

    @Config(description = "Internal cache size, most of the time should be left at default.")
    private double cache_size = 500.0d;

    @Config(description = "Cost parameter for incorrect predictions.")
    private double cost = 1.0d;

    @Config(description = "Tolerance of the termination criterion.")
    private double eps = 0.001d;

    @Config(description = "Epsilon in EPSILON_SVR.")
    private double p = 0.001d;

    // NOTE(review): LIBSVM documents `shrinking` as a switch for the solver's
    // shrinking heuristic rather than a regulariser - confirm the description text.
    @Config(description = "Regularise the weight parameters.")
    private boolean shrinking = true;

    @Config(description = "Generate probability estimates.")
    private boolean probability = false;

    /** Number of times {@link #train(Dataset, Map)} has been entered past its validation check. */
    private int trainInvocationCounter = 0;

    /**
     * Creates a trainer holding only the default field values.
     * <p>
     * {@link #svmType} and {@link #parameters} are left {@code null}; the latter is
     * built by {@link #postConfig()}.
     */
    protected LibSVMTrainer() { }

    /**
     * Creates a trainer from an already assembled parameter object.
     * <p>
     * The individual fields are unpacked from the supplied parameters so they agree
     * with the {@link svm_parameter} instance stored in {@link #parameters}.
     *
     * @param sVMParameters the SVM type, kernel type and libsvm parameters to use
     */
    protected LibSVMTrainer(SVMParameters<T> sVMParameters) {
        this.parameters = sVMParameters.getParameters();
        this.svmType = sVMParameters.getSvmType();
        this.kernelType = sVMParameters.getKernelType();
        this.degree = this.parameters.degree;
        // gamma is read through the wrapper's accessor rather than from the raw struct.
        this.gamma = sVMParameters.getGamma();
        this.coef0 = this.parameters.coef0;
        this.nu = this.parameters.nu;
        this.cache_size = this.parameters.cache_size;
        this.cost = this.parameters.C;
        this.eps = this.parameters.eps;
        this.p = this.parameters.p;
        // libsvm stores these two flags as ints where 1 means enabled.
        this.shrinking = this.parameters.shrinking == 1;
        this.probability = this.parameters.probability == 1;
    }

    /**
     * Builds a fresh {@link svm_parameter} from the current field values and stores
     * it in {@link #parameters}.
     * <p>
     * Presumably invoked by the configuration system once the {@link Config} fields
     * have been populated - confirm against the configuration library.
     */
    public void postConfig() {
        svm_parameter built = new svm_parameter();
        built.svm_type = this.svmType.getNativeType();
        built.kernel_type = this.kernelType.getNativeType();
        built.degree = this.degree;
        built.gamma = this.gamma;
        built.coef0 = this.coef0;
        built.nu = this.nu;
        built.cache_size = this.cache_size;
        built.C = this.cost;
        built.eps = this.eps;
        built.p = this.p;
        built.shrinking = this.shrinking ? 1 : 0;
        built.probability = this.probability ? 1 : 0;
        this.parameters = built;
    }

    /**
     * Describes this trainer by its libsvm parameters.
     *
     * @return a string of the form {@code LibSVMTrainer(svm_params=...)}
     */
    @Override
    public String toString() {
        return "LibSVMTrainer(svm_params=" + SVMParameters.svmParamsToString(this.parameters) + ")";
    }

    /* renamed from: train, reason: merged with bridge method [inline-methods] */
    /**
     * Trains a model without any run provenance.
     *
     * @param dataset the training data
     * @return the trained model
     */
    public LibSVMModel<T> m4train(Dataset<T> dataset) {
        return train(dataset, Collections.emptyMap());
    }

    /**
     * Trains a model on the supplied dataset.
     * <p>
     * The model provenance (including this trainer's provenance) is captured before
     * the invocation counter is incremented.
     *
     * @param dataset the training data; must not contain unknown outputs
     * @param map     run provenance entries attached to the model provenance
     * @return the trained model
     * @throws IllegalArgumentException if the dataset reports any unknown outputs
     */
    public LibSVMModel<T> train(Dataset<T> dataset, Map<String, Provenance> map) {
        if (dataset.getOutputInfo().getUnknownCount() > 0) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }
        ImmutableFeatureMap featureIDMap = dataset.getFeatureIDMap();
        ImmutableOutputInfo<T> outputIDInfo = dataset.getOutputIDInfo();
        ModelProvenance modelProvenance = new ModelProvenance(
                LibSVMModel.class.getName(),
                OffsetDateTime.now(),
                dataset.getProvenance(),
                m5getProvenance(),
                map);
        this.trainInvocationCounter++;

        svm_parameter curParams = setupParameters(outputIDInfo);
        Pair<svm_node[][], double[][]> data = extractData(dataset, outputIDInfo, featureIDMap);
        // The second argument is one more than the number of features in the map.
        List<svm_model> models = trainModels(curParams, featureIDMap.size() + 1, data.getA(), data.getB());
        return createModel(modelProvenance, featureIDMap, outputIDInfo, models);
    }

    /**
     * Wraps the trained libsvm models in a {@link LibSVMModel}.
     *
     * @param modelProvenance     the provenance to attach to the model
     * @param immutableFeatureMap the feature domain used in training
     * @param immutableOutputInfo the output domain used in training
     * @param list                the trained libsvm models
     * @return the model
     */
    protected abstract LibSVMModel<T> createModel(ModelProvenance modelProvenance, ImmutableFeatureMap immutableFeatureMap, ImmutableOutputInfo<T> immutableOutputInfo, List<svm_model> list);

    /**
     * Runs libsvm training on the extracted data.
     *
     * @param svm_parameterVar the parameters returned by {@link #setupParameters}
     * @param i                the feature map size plus one, as passed by {@link #train(Dataset, Map)}
     * @param svm_nodeVarArr   the feature arrays, the first element of the {@link #extractData} pair
     * @param dArr             the output arrays, the second element of the {@link #extractData} pair
     * @return the trained libsvm models
     */
    protected abstract List<svm_model> trainModels(svm_parameter svm_parameterVar, int i, svm_node[][] svm_nodeVarArr, double[][] dArr);

    /**
     * Converts the dataset into the array pair consumed by {@link #trainModels}.
     *
     * @param dataset             the training data
     * @param immutableOutputInfo the output domain
     * @param immutableFeatureMap the feature domain
     * @return a pair of feature arrays and output arrays
     */
    protected abstract Pair<svm_node[][], double[][]> extractData(Dataset<T> dataset, ImmutableOutputInfo<T> immutableOutputInfo, ImmutableFeatureMap immutableFeatureMap);

    /**
     * Produces the parameters used for a single training run by passing
     * {@link #parameters} through {@code SVMParameters.copyParameters}.
     * <p>
     * The output information is not used by this base implementation.
     *
     * @param immutableOutputInfo the output domain of the training data
     * @return the parameters for this run
     */
    protected svm_parameter setupParameters(ImmutableOutputInfo<T> immutableOutputInfo) {
        return SVMParameters.copyParameters(this.parameters);
    }

    /**
     * Returns how many times {@link #train(Dataset, Map)} has incremented the counter.
     *
     * @return the invocation count
     */
    public int getInvocationCount() {
        return this.trainInvocationCounter;
    }

    /**
     * Converts an example into an array of {@link svm_node}s ordered by ascending
     * feature id.
     * <p>
     * Features whose id is negative are skipped. When the same id occurs more than
     * once the values are summed into a single node.
     *
     * @param example             the example to convert
     * @param immutableFeatureMap the map used to look up feature ids by name
     * @param list                a scratch buffer which is cleared and reused, or
     *                            {@code null} to allocate a new one
     * @param <T>                 the output type of the example
     * @return the nodes for the example
     */
    public static <T extends Output<T>> svm_node[] exampleToNodes(Example<T> example, ImmutableFeatureMap immutableFeatureMap, List<svm_node> list) {
        List<svm_node> nodes = (list != null) ? list : new ArrayList<>();
        nodes.clear();
        int largestIdSeen = -1;
        for (Feature feature : example) {
            int id = immutableFeatureMap.getID(feature.getName());
            double value = feature.getValue();
            if (id > largestIdSeen) {
                // Ids arriving in ascending order can simply be appended.
                largestIdSeen = id;
                nodes.add(newNode(id, value));
            } else if (id > -1) {
                // Out of order or repeated id: locate its slot in the sorted buffer.
                int position = Util.binarySearch(nodes, id, node -> node.index);
                if (position < 0) {
                    // A negative result encodes the insertion point as -(point + 1).
                    nodes.add(-(position + 1), newNode(id, value));
                } else {
                    nodes.get(position).value += value;
                }
            }
        }
        return nodes.toArray(new svm_node[0]);
    }

    /** Creates a node with the given index and value. */
    private static svm_node newNode(int index, double value) {
        svm_node node = new svm_node();
        node.index = index;
        node.value = value;
        return node;
    }

    /* renamed from: getProvenance, reason: merged with bridge method [inline-methods] */
    /**
     * Creates the provenance describing this trainer.
     *
     * @return a new {@link TrainerProvenanceImpl} built from this instance
     */
    public TrainerProvenance m5getProvenance() {
        return new TrainerProvenanceImpl(this);
    }

    /* renamed from: train, reason: collision with other method in class */
    /**
     * Raw-typed bridge which delegates to {@link #train(Dataset, Map)}.
     *
     * @param dataset the training data
     * @param map     the run provenance
     * @return the trained model
     */
    @SuppressWarnings({"rawtypes", "unchecked"}) // the raw signature is kept for existing callers
    public /* bridge */ /* synthetic */ Model m3train(Dataset dataset, Map map) {
        return train(dataset, (Map<String, Provenance>) map);
    }
}
