package org.tribuo.anomaly.libsvm;

import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.SplittableRandom;
import java.util.logging.Logger;
import libsvm.svm;
import libsvm.svm_model;
import libsvm.svm_node;
import libsvm.svm_parameter;
import libsvm.svm_problem;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.anomaly.Event;
import org.tribuo.common.libsvm.LibSVMModel;
import org.tribuo.common.libsvm.LibSVMTrainer;
import org.tribuo.common.libsvm.SVMParameters;
import org.tribuo.provenance.ModelProvenance;

/* loaded from: input_file:org/tribuo/anomaly/libsvm/LibSVMAnomalyTrainer.class */
/**
 * A trainer for anomaly detection SVMs backed by LibSVM, producing {@link LibSVMAnomalyModel}s.
 * <p>
 * The supplied {@link SVMParameters} must describe an anomaly detection SVM type, which is
 * checked in {@link #postConfig()}. Training datasets must not contain any
 * {@link Event.EventType#ANOMALOUS} events; such datasets are rejected by
 * {@link #train(Dataset, Map)}.
 */
public class LibSVMAnomalyTrainer extends LibSVMTrainer<Event> {
    private static final Logger logger = Logger.getLogger(LibSVMAnomalyTrainer.class.getName());

    /** Seed used by {@link #LibSVMAnomalyTrainer(SVMParameters)} when no seed is supplied. */
    private static final long DEFAULT_SEED = 12345L;

    /**
     * No-arg constructor, presumably for reflective/configuration-driven instantiation
     * (NOTE(review): assumed to be followed by a call to {@link #postConfig()} — confirm).
     */
    protected LibSVMAnomalyTrainer() {
    }

    /**
     * Creates a trainer with the supplied parameters and a fixed default seed (12345).
     *
     * @param parameters The SVM parameters; expected to describe an anomaly detection SVM type.
     */
    public LibSVMAnomalyTrainer(SVMParameters<Event> parameters) {
        this(parameters, DEFAULT_SEED);
    }

    /**
     * Creates a trainer with the supplied parameters and RNG seed.
     *
     * @param parameters The SVM parameters; expected to describe an anomaly detection SVM type.
     * @param seed       The seed passed to the superclass RNG.
     */
    public LibSVMAnomalyTrainer(SVMParameters<Event> parameters, long seed) {
        super(parameters, seed);
    }

    /**
     * Runs the superclass post-configuration, then validates that the configured SVM type
     * is an anomaly detection type.
     *
     * @throws IllegalArgumentException If the SVM type is a classification or regression type.
     */
    @Override
    public void postConfig() {
        super.postConfig();
        if (!this.svmType.isAnomaly()) {
            throw new IllegalArgumentException("Supplied classification or regression parameters to an anomaly detection SVM.");
        }
    }

    /**
     * Trains an anomaly detection model after checking that the dataset has no anomalous events.
     *
     * @param dataset    The training data.
     * @param provenance Run provenance forwarded to the superclass trainer.
     * @return The trained model.
     * @throws IllegalArgumentException If the dataset's output info records any ANOMALOUS events.
     */
    @Override
    public LibSVMModel<Event> train(Dataset<Event> dataset, Map<String, Provenance> provenance) {
        String anomalousName = Event.EventType.ANOMALOUS.toString();
        for (Pair<String, Long> outputCount : dataset.getOutputInfo().outputCountsIterable()) {
            // Null-safe comparison: the constant name is the receiver.
            if (anomalousName.equals(outputCount.getA()) && outputCount.getB() > 0) {
                throw new IllegalArgumentException("LibSVMAnomalyTrainer only supports EXPECTED events at training time.");
            }
        }
        return super.train(dataset, provenance);
    }

    /**
     * Wraps the trained LibSVM model(s) in a {@link LibSVMAnomalyModel}.
     *
     * @param provenance The model provenance.
     * @param featureMap The feature domain.
     * @param outputInfo The output domain.
     * @param models     The trained LibSVM models.
     * @return A new anomaly model named "svm-anomaly-detection-model".
     */
    @Override
    protected LibSVMModel<Event> createModel(ModelProvenance provenance, ImmutableFeatureMap featureMap, ImmutableOutputInfo<Event> outputInfo, List<svm_model> models) {
        return new LibSVMAnomalyModel("svm-anomaly-detection-model", provenance, featureMap, outputInfo, models);
    }

    /**
     * Trains a single LibSVM model on the extracted data.
     * <p>
     * If {@code params.gamma} is zero it is replaced (in place) with {@code 1 / numFeatures}.
     * The LibSVM global RNG is reseeded from {@code rng} before training.
     *
     * @param params      The LibSVM parameters; {@code gamma} may be mutated.
     * @param numFeatures The number of features, used to default gamma.
     * @param features    The per-example feature nodes.
     * @param outputs     The outputs; only {@code outputs[0]} is used.
     * @param rng         Source of the seed for LibSVM's RNG.
     * @return A singleton list containing the trained model.
     * @throws IllegalArgumentException If LibSVM rejects the parameters, or if gamma must be
     *                                  defaulted but {@code numFeatures} is not positive.
     */
    @Override
    protected List<svm_model> trainModels(svm_parameter params, int numFeatures, svm_node[][] features, double[][] outputs, SplittableRandom rng) {
        svm_problem problem = new svm_problem();
        problem.l = outputs[0].length;
        problem.x = features;
        problem.y = outputs[0];
        if (params.gamma == 0.0) {
            // 1/numFeatures would be Infinity or negative here, yielding a degenerate kernel.
            if (numFeatures <= 0) {
                throw new IllegalArgumentException("Cannot default gamma with a non-positive number of features: " + numFeatures);
            }
            params.gamma = 1.0 / numFeatures;
        }
        String paramError = svm.svm_check_parameter(problem, params);
        if (paramError != null) {
            throw new IllegalArgumentException("Error checking SVM parameters: " + paramError);
        }
        // svm.rand is a static, shared RNG; NOTE(review): assumed the caller serialises access — confirm.
        svm.rand.setSeed(rng.nextLong());
        return Collections.singletonList(svm.svm_train(problem, params));
    }

    /**
     * Converts the dataset into LibSVM's representation.
     *
     * @param dataset    The dataset to convert.
     * @param outputInfo The output domain (unused here).
     * @param featureMap The feature domain used to map features to LibSVM nodes.
     * @return A pair of (per-example feature nodes, a single row of per-example outputs).
     */
    @Override
    protected Pair<svm_node[][], double[][]> extractData(Dataset<Event> dataset, ImmutableOutputInfo<Event> outputInfo, ImmutableFeatureMap featureMap) {
        int numExamples = dataset.size();
        double[][] outputs = new double[1][numExamples];
        svm_node[][] features = new svm_node[numExamples][];
        // Scratch buffer reused across examples by exampleToNodes.
        List<svm_node> buffer = new ArrayList<>();
        int i = 0;
        for (Example<Event> example : dataset) {
            outputs[0][i] = extractOutput(example.getOutput());
            features[i] = exampleToNodes(example, featureMap, buffer);
            i++;
        }
        return new Pair<>(features, outputs);
    }

    /**
     * Maps an event to LibSVM's label encoding.
     *
     * @param event The event.
     * @return {@code 1.0} for EXPECTED events, {@code -1.0} for any other event type.
     */
    @Override
    protected double extractOutput(Event event) {
        return event.getType() == Event.EventType.EXPECTED ? 1.0 : -1.0;
    }

    /**
     * Raw-typed delegate to {@link #train(Dataset, Map)}; a decompiler-renamed bridge method
     * retained for backward compatibility.
     *
     * @param dataset    The training data (expected to be a {@code Dataset<Event>}).
     * @param provenance Run provenance (expected to be a {@code Map<String, Provenance>}).
     * @return The trained model.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public Model m2train(Dataset dataset, Map provenance) {
        return train((Dataset<Event>) dataset, (Map<String, Provenance>) provenance);
    }
}
