package ch.epfl.gsn.utils.models;

import ch.epfl.gsn.utils.models.helper.Segmenter;
import ch.epfl.gsn.utils.models.helper.Tools;
import weka.classifiers.Classifier;
import weka.classifiers.SegmentedClassifier;
import weka.core.Instance;
import weka.core.Instances;
import weka.filters.Filter;
import weka.filters.unsupervised.instance.DummyFilter;
import weka.filters.unsupervised.instance.ErrorBased;
import weka.filters.unsupervised.instance.RandomSample;
import weka.filters.unsupervised.instance.SubSample;

/* loaded from: input_file:ch/epfl/gsn/utils/models/ModelSampling.class */
/**
 * Trains and queries a segmented model over a sampled set of training instances.
 *
 * <p>Training splits the data into segments with a {@link Segmenter}, chooses a
 * sampling {@link Filter} according to the configured sampling method, and wraps
 * the base classifier returned by {@link Tools#getClassifierById(int)} in a
 * {@link SegmentedClassifier}.
 *
 * <p>Failures are reported through return values (0 from {@code train}, {@code null}
 * from {@code predict}) rather than through exceptions.
 */
public class ModelSampling {
    /** Segmentation method id; index of "BINARY" in {@link #SEGMENT_NAMES}. */
    public static final int BINARY = 0;
    /** Segmentation method id; index of "BINARY_PLUS" in {@link #SEGMENT_NAMES}. */
    public static final int BINARY_PLUS = 1;
    /** Segmentation method id; index of "HEURISTIC" in {@link #SEGMENT_NAMES}. */
    public static final int HEURISTIC = 2;
    /** Segmentation method id; index of "HEURISTIC_PLUS" in {@link #SEGMENT_NAMES}. */
    public static final int HEURISTIC_PLUS = 3;
    /** Sampling method id; index of "UNIFORM" in {@link #SAMPLING_NAMES}. */
    public static final int UNIFORM = 0;
    /** Sampling method id; index of "ERROR_BASED" in {@link #SAMPLING_NAMES}. */
    public static final int ERROR_BASED = 1;
    /** Sampling method id; index of "RANDOM" in {@link #SAMPLING_NAMES}. */
    public static final int RANDOM = 2;
    /** Names of the segmentation methods, indexed by segmentation method id. */
    public static final String[] SEGMENT_NAMES = {"BINARY", "BINARY_PLUS", "HEURISTIC", "HEURISTIC_PLUS"};
    /** Names of the sampling methods, indexed by sampling method id. */
    public static final String[] SAMPLING_NAMES = {"UNIFORM", "ERROR_BASED", "RANDOM"};
    /** Names of the base models; presumably indexed by the model id given to {@code Tools.getClassifierById}. */
    public static final String[] MODEL_NAMES = {"SVM", "LINEAR"};

    private int seg_method;
    private int samp_method;
    private int model;
    private int seg_num;
    private int samp_ratio;
    /** Most recently constructed classifier; {@code null} until {@code train} reaches the build step. */
    private Classifier classifier = null;

    /**
     * Creates an untrained model with the given configuration.
     *
     * @param model      id of the base model, passed to {@code Tools.getClassifierById}
     * @param segMethod  id of the segmentation method, passed to the {@link Segmenter}
     * @param segNum     number of segments requested from the {@link Segmenter}
     * @param sampMethod id of the sampling method ({@link #UNIFORM}, {@link #ERROR_BASED},
     *                   {@link #RANDOM}; any other value selects a {@link DummyFilter})
     * @param sampRatio  ratio handed to the selected sampling filter
     */
    public ModelSampling(int model, int segMethod, int segNum, int sampMethod, int sampRatio) {
        this.model = model;
        this.seg_method = segMethod;
        this.seg_num = segNum;
        this.samp_method = sampMethod;
        this.samp_ratio = sampRatio;
    }

    /**
     * Resolves a textual identifier to a numeric id.
     *
     * <p>A string consisting of exactly one ASCII digit is returned as that digit's
     * value, without checking it against the length of {@code strArr}. Any other
     * string is compared case-insensitively (and independently of the default
     * locale) with the entries of {@code strArr}; the index of the first match is
     * returned.
     *
     * @param strArr candidate names; may be {@code null} or contain {@code null} entries
     * @param str    single-digit id or name to look up; may be {@code null}
     * @return the resolved id, or -1 when nothing matches or an argument is {@code null}
     */
    public static int getIdFromString(String[] strArr, String str) {
        if (str == null) {
            return -1;
        }
        if (str.length() == 1) {
            char c = str.charAt(0);
            if (c >= '0' && c <= '9') {
                return c - '0';
            }
        }
        if (strArr == null) {
            return -1;
        }
        for (int idx = 0; idx < strArr.length; idx++) {
            // equalsIgnoreCase is locale-independent and tolerates null entries.
            if (str.equalsIgnoreCase(strArr[idx])) {
                return idx;
            }
        }
        return -1;
    }

    /**
     * Classifies a single instance with the current classifier.
     *
     * @param instance the instance to classify
     * @return the classifier's output, or {@code null} if classification fails for
     *         any reason (including when no classifier has been created yet)
     */
    public Double predict(Instance instance) {
        try {
            return Double.valueOf(this.classifier.classifyInstance(instance));
        } catch (Exception ignored) {
            // Best-effort contract: every failure is reported to the caller as null.
            return null;
        }
    }

    /**
     * Replaces the stored configuration and trains on the given instances.
     *
     * @param instances  training data
     * @param model      id of the base model
     * @param segMethod  id of the segmentation method
     * @param segNum     number of segments requested
     * @param sampMethod id of the sampling method
     * @param sampRatio  ratio handed to the sampling filter
     * @return 1 on success, 0 on failure (see {@link #train(Instances)})
     */
    public int train(Instances instances, int model, int segMethod, int segNum, int sampMethod, int sampRatio) {
        this.model = model;
        this.seg_method = segMethod;
        this.seg_num = segNum;
        this.samp_method = sampMethod;
        this.samp_ratio = sampRatio;
        return train(instances);
    }

    /**
     * Trains a segmented classifier on the given instances using the stored configuration.
     *
     * @param instances training data
     * @return 1 if the classifier was built; 0 if the segmenter produced no segments
     *         or any step threw an exception
     */
    public int train(Instances instances) {
        try {
            Segmenter segmenter = new Segmenter(this.seg_method, this.model);
            Double[] segments = segmenter.getSegments(this.seg_num, instances);
            if (segments == null) {
                return 0;
            }
            // Must run before the filter is built: the error-based filter reads segmenter.Pred_errors.
            segmenter.computeErrors(instances, segments);
            Filter samplingFilter = buildSamplingFilter(instances, segmenter);
            // NOTE(review): the meaning of the constant 1 passed to SegmentedClassifier is not visible here.
            this.classifier = new SegmentedClassifier(
                    Tools.getClassifierById(this.model), 1, segments, samplingFilter);
            this.classifier.buildClassifier(instances);
            return 1;
        } catch (Exception ignored) {
            // Best-effort contract: every failure is reported to the caller as 0.
            return 0;
        }
    }

    /** Creates and configures the sampling filter selected by the stored sampling method. */
    private Filter buildSamplingFilter(Instances instances, Segmenter segmenter) throws Exception {
        switch (this.samp_method) {
            case UNIFORM: {
                SubSample subSample = new SubSample();
                subSample.setInputFormat(instances);
                subSample.setRatio(this.samp_ratio);
                subSample.setM_index(0);
                return subSample;
            }
            case ERROR_BASED: {
                ErrorBased errorBased = new ErrorBased();
                errorBased.setInputFormat(instances);
                errorBased.setM_ratio(this.samp_ratio);
                errorBased.setM_errors(segmenter.Pred_errors);
                return errorBased;
            }
            case RANDOM: {
                RandomSample randomSample = new RandomSample();
                randomSample.setInputFormat(instances);
                randomSample.setM_ratio(this.samp_ratio);
                return randomSample;
            }
            default: {
                // Unknown sampling ids fall back to the dummy filter.
                DummyFilter dummyFilter = new DummyFilter();
                dummyFilter.setInputFormat(instances);
                return dummyFilter;
            }
        }
    }
}
