package org.datacleaner.components.machinelearning.impl;

import java.util.List;
import org.datacleaner.components.machinelearning.api.MLClassificationMetadata;
import org.datacleaner.components.machinelearning.api.MLClassificationRecord;
import org.datacleaner.components.machinelearning.api.MLClassificationTrainer;
import org.datacleaner.components.machinelearning.api.MLClassifier;
import org.datacleaner.components.machinelearning.api.MLFeatureModifier;
import org.datacleaner.components.machinelearning.api.MLTrainerCallback;
import org.datacleaner.components.machinelearning.api.MLTrainingOptions;
import smile.classification.SVM;
import smile.math.kernel.GaussianKernel;

/* loaded from: input_file:org/datacleaner/components/machinelearning/impl/SvmClasificationTrainer.class */
/**
 * {@link MLClassificationTrainer} that trains a Smile {@link SVM} using a {@link GaussianKernel}.
 *
 * <p>Training calls {@code learn} once per epoch on the complete feature matrix, then calls
 * {@code finish} and {@code trainPlattScaling} on the model before wrapping it in a
 * {@link SmileClassifier}.
 */
public class SvmClasificationTrainer implements MLClassificationTrainer {
    private final MLTrainingOptions trainingOptions;
    private final int epochs;
    private final double softMarginPenalty;
    private final SVM.Multiclass multiclass;
    private final double gaussianKernelSigma;

    /**
     * Creates a trainer with the given hyper-parameters.
     *
     * @param mLTrainingOptions training options; its classification type and column names are
     *        copied into the metadata of every classifier this trainer produces
     * @param i number of epochs, i.e. how many times {@code learn} is invoked on the training data
     * @param d sigma value handed to the {@link GaussianKernel} constructor
     * @param d2 soft margin penalty handed to the {@link SVM} constructor
     * @param multiclass multi-class strategy; only used when three or more classifications exist
     */
    public SvmClasificationTrainer(MLTrainingOptions mLTrainingOptions, int i, double d, double d2, SVM.Multiclass multiclass) {
        this.trainingOptions = mLTrainingOptions;
        this.epochs = i;
        this.gaussianKernelSigma = d;
        this.softMarginPenalty = d2;
        this.multiclass = multiclass;
    }

    /**
     * Trains an SVM on the given records and returns it wrapped as an {@link MLClassifier}.
     *
     * <p>NOTE(review): {@code iterable} is handed to three separate {@link MLFeatureUtils}
     * helpers, so it presumably has to support being iterated more than once - confirm against
     * callers.
     *
     * @param iterable the training records
     * @param list the feature modifiers used to build the feature matrix; also stored in the
     *        resulting classifier's metadata
     * @param mLTrainerCallback notified after every epoch with the 1-based epoch number and the
     *        total number of epochs
     * @return the trained classifier together with its metadata
     */
    @Override
    public MLClassifier train(Iterable<MLClassificationRecord> iterable, List<MLFeatureModifier> list, MLTrainerCallback mLTrainerCallback) {
        final double[][] features = MLFeatureUtils.toFeatureVector(iterable, list);
        final int[] labels = MLFeatureUtils.toClassificationVector(iterable);
        final List<Object> classifications = MLFeatureUtils.toClassifications(iterable);

        final SVM<double[]> svm = createSvm(classifications.size());

        for (int epoch = 1; epoch <= this.epochs; epoch++) {
            svm.learn(features, labels);
            mLTrainerCallback.epochDone(epoch, this.epochs);
        }

        svm.finish();
        // Platt scaling is fitted on the same data that was used for training.
        svm.trainPlattScaling(features, labels);

        final MLClassificationMetadata metadata = new MLClassificationMetadata(
                this.trainingOptions.getClassificationType(), classifications, this.trainingOptions.getColumnNames(), list);
        return new SmileClassifier(svm, metadata);
    }

    /**
     * Builds the untrained SVM: the two-argument constructor when there are fewer than three
     * classifications, otherwise the constructor taking the class count and multi-class strategy.
     */
    private SVM<double[]> createSvm(int numClasses) {
        final GaussianKernel kernel = new GaussianKernel(this.gaussianKernelSigma);
        if (numClasses < 3) {
            return new SVM<>(kernel, this.softMarginPenalty);
        }
        return new SVM<>(kernel, this.softMarginPenalty, numClasses, this.multiclass);
    }
}
