package ai.libs.mlplan.multiclass.wekamlplan.weka.model;

import ai.libs.jaicore.ml.WekaUtil;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.math.stat.descriptive.DescriptiveStatistics;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.attributeSelection.ASEvaluation;
import weka.attributeSelection.ASSearch;
import weka.attributeSelection.AttributeSelection;
import weka.classifiers.Classifier;
import weka.classifiers.SingleClassifierEnhancer;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;

/* loaded from: input_file:ai/libs/mlplan/multiclass/wekamlplan/weka/model/MLPipeline.class */
/**
 * A Weka classifier that chains a list of {@link SupervisedFilterSelector} preprocessors in front
 * of a base classifier.
 *
 * <p>{@link #buildClassifier(Instances)} prepares every preprocessor that is not prepared yet on the
 * output of the preceding preprocessors. It then trains the base classifier on the output of the
 * last preprocessor. At prediction time every instance is pushed through the same preprocessors, in
 * list order, before it is handed to the base classifier.
 *
 * <p>The pipeline records elapsed times in milliseconds for training and for executing both the
 * preprocessors and the base classifier. This class performs no synchronization.
 */
public class MLPipeline extends SingleClassifierEnhancer implements Classifier, Serializable {
    private static final Logger logger = LoggerFactory.getLogger(MLPipeline.class);
    private static final long NANOS_PER_MILLI = 1_000_000L;

    /** Preprocessors, applied in list order before the base classifier. */
    private final List<SupervisedFilterSelector> preprocessors = new ArrayList<>();
    /** True once {@link #buildClassifier(Instances)} has completed successfully. */
    private boolean trained = false;
    /** Total ms spent preparing preprocessors during the last call to buildClassifier. */
    private int timeForTrainingPreprocessors;
    /** Ms spent training the base classifier during the last call to buildClassifier. */
    private int timeForTrainingClassifier;
    /** Per-call ms spent in the preprocessors at prediction time; null until trained. */
    private DescriptiveStatistics timeForExecutingPreprocessors;
    /** Per-call ms spent in the base classifier at prediction time; null until trained. */
    private DescriptiveStatistics timeForExecutingClassifier;

    /**
     * Creates a pipeline from an explicit list of preprocessors.
     *
     * @param list the preprocessors, applied in list order; the list is copied
     * @param classifier the base classifier; must not be null
     * @throws IllegalArgumentException if {@code classifier} is null
     * @throws NullPointerException if {@code list} is null
     */
    public MLPipeline(List<SupervisedFilterSelector> list, Classifier classifier) {
        if (classifier == null) {
            throw new IllegalArgumentException("Base classifier must not be null!");
        }
        this.preprocessors.addAll(list);
        super.setClassifier(classifier);
    }

    /**
     * Creates a pipeline with at most one attribute-selection preprocessor.
     *
     * <p>An {@link AttributeSelection} preprocessor is added only if both {@code aSSearch} and
     * {@code aSEvaluation} are non-null. Otherwise the pipeline has no preprocessors.
     *
     * @param aSSearch the attribute search strategy, or null
     * @param aSEvaluation the attribute evaluator, or null
     * @param classifier the base classifier; must not be null
     * @throws IllegalArgumentException if {@code classifier} is null
     */
    public MLPipeline(ASSearch aSSearch, ASEvaluation aSEvaluation, Classifier classifier) {
        if (classifier == null) {
            throw new IllegalArgumentException("Base classifier must not be null!");
        }
        if (aSSearch != null && aSEvaluation != null) {
            AttributeSelection attributeSelection = new AttributeSelection();
            attributeSelection.setSearch(aSSearch);
            attributeSelection.setEvaluator(aSEvaluation);
            this.preprocessors.add(new SupervisedFilterSelector(aSSearch, aSEvaluation, attributeSelection));
        }
        super.setClassifier(classifier);
    }

    /**
     * Trains the pipeline.
     *
     * <p>Every preprocessor that is not prepared yet is prepared on the current (already partially
     * preprocessed) training data. Each preprocessor is applied in order, and the base classifier is
     * trained on the final result. Preprocessors that are already prepared, for example from an
     * earlier call, are only applied and not re-prepared.
     *
     * <p>The recorded preprocessor training time is the sum over the preprocessors prepared in this
     * call. If this method throws, the pipeline is left in the untrained state.
     *
     * @param instances the training data
     * @throws Exception if a preprocessor or the base classifier fails
     */
    public void buildClassifier(Instances instances) throws Exception {
        // Invalidate any previous model first, so a failed rebuild cannot be used for predictions.
        this.trained = false;
        this.timeForTrainingPreprocessors = 0;

        int attributesBefore = instances.numAttributes();
        logger.info("Starting to build the preprocessors of the pipeline.");
        Instances data = instances;
        for (SupervisedFilterSelector preprocessor : this.preprocessors) {
            data = prepareIfNeededAndApply(preprocessor, data);
        }
        logger.info("Reduced number of attributes from {} to {}", attributesBefore, data.numAttributes());

        long classifierStart = System.nanoTime();
        super.getClassifier().buildClassifier(data);
        this.timeForTrainingClassifier = (int) elapsedMillis(classifierStart);

        this.timeForExecutingPreprocessors = new DescriptiveStatistics();
        this.timeForExecutingClassifier = new DescriptiveStatistics();
        this.trained = true;
    }

    /**
     * Prepares {@code preprocessor} on {@code data} unless it is already prepared, and returns the
     * result of applying it to {@code data}.
     */
    private Instances prepareIfNeededAndApply(SupervisedFilterSelector preprocessor, Instances data) throws Exception {
        Instances filtered = null;
        if (!preprocessor.isPrepared()) {
            try {
                long start = System.nanoTime();
                preprocessor.prepare(data);
                this.timeForTrainingPreprocessors += (int) elapsedMillis(start);
                // Reuse this result below instead of applying the preprocessor a second time.
                filtered = preprocessor.apply(data);
                int classesAfter = filtered.numClasses();
                if (data.numClasses() != classesAfter) {
                    logger.info("{} changed number of classes from {} to {}", preprocessor.getSelector(), data.numClasses(), classesAfter);
                }
            } catch (NullPointerException e) {
                // Best effort: an NPE while preparing/inspecting is only logged. The preprocessor
                // is then applied again below, where a persistent failure will surface.
                logger.error("Could not apply preprocessor", e);
                filtered = null;
            }
        }
        return filtered != null ? filtered : preprocessor.apply(data);
    }

    /** Returns the milliseconds elapsed since {@code startNanos}, a value of {@link System#nanoTime()}. */
    private static long elapsedMillis(long startNanos) {
        return (System.nanoTime() - startNanos) / NANOS_PER_MILLI;
    }

    /** Pushes {@code instance} through all preprocessors in order and records the elapsed time. */
    private Instance applyPreprocessors(Instance instance) throws Exception {
        long start = System.nanoTime();
        Instance current = instance;
        for (SupervisedFilterSelector preprocessor : this.preprocessors) {
            current = preprocessor.apply(current);
        }
        this.timeForExecutingPreprocessors.addValue(elapsedMillis(start));
        return current;
    }

    /**
     * Validates the pipeline state and the input, then returns the preprocessed instance.
     *
     * @throws IllegalStateException if the pipeline is untrained or a preprocessor returned null
     * @throws IllegalArgumentException if {@code instance} is null
     */
    private Instance preprocessForPrediction(Instance instance) throws Exception {
        if (!this.trained) {
            throw new IllegalStateException("Cannot make predictions on untrained pipeline!");
        }
        if (instance == null) {
            throw new IllegalArgumentException("Cannot make predictions for null-instance");
        }
        int attributesBefore = instance.numAttributes();
        Instance transformed = applyPreprocessors(instance);
        if (transformed == null) {
            throw new IllegalStateException("The filter has turned the instance into NULL");
        }
        // DEBUG rather than INFO: this runs once per predicted instance.
        if (attributesBefore != transformed.numAttributes() && logger.isDebugEnabled()) {
            logger.debug("Reduced number of attributes from {} to {}", attributesBefore, transformed.numAttributes());
        }
        return transformed;
    }

    /**
     * Predicts the value for a single instance with the base classifier after preprocessing.
     *
     * @param instance the instance to classify; must not be null
     * @return the base classifier's prediction for the preprocessed instance
     * @throws IllegalStateException if the pipeline is untrained or a preprocessor returned null
     * @throws IllegalArgumentException if {@code instance} is null
     * @throws Exception if a preprocessor or the base classifier fails
     */
    public double classifyInstance(Instance instance) throws Exception {
        Instance transformed = preprocessForPrediction(instance);
        long start = System.nanoTime();
        double prediction = super.getClassifier().classifyInstance(transformed);
        this.timeForExecutingClassifier.addValue(elapsedMillis(start));
        return prediction;
    }

    /**
     * Applies {@link #classifyInstance(Instance)} to every instance, in order.
     *
     * @param instances the instances to classify
     * @return one prediction per instance, in the same order
     * @throws Exception if any single prediction fails
     */
    public double[] classifyInstances(Instances instances) throws Exception {
        int size = instances.size();
        double[] predictions = new double[size];
        for (int i = 0; i < size; i++) {
            predictions[i] = classifyInstance(instances.get(i));
        }
        return predictions;
    }

    /**
     * Computes the base classifier's distribution for a single instance after preprocessing.
     *
     * @param instance the instance; must not be null
     * @return the distribution returned by the base classifier
     * @throws IllegalStateException if the pipeline is untrained or a preprocessor returned null
     * @throws IllegalArgumentException if {@code instance} is null
     * @throws Exception if a preprocessor or the base classifier fails
     */
    public double[] distributionForInstance(Instance instance) throws Exception {
        Instance transformed = preprocessForPrediction(instance);
        long start = System.nanoTime();
        double[] distribution = super.getClassifier().distributionForInstance(transformed);
        this.timeForExecutingClassifier.addValue(elapsedMillis(start));
        return distribution;
    }

    /** Returns the capabilities of the base classifier. */
    public Capabilities getCapabilities() {
        return super.getClassifier().getCapabilities();
    }

    /** Returns the base classifier at the end of the pipeline. */
    public Classifier getBaseClassifier() {
        return super.getClassifier();
    }

    /** Returns the live (mutable) list of preprocessors backing this pipeline. */
    public List<SupervisedFilterSelector> getPreprocessors() {
        return this.preprocessors;
    }

    public String toString() {
        return getPreprocessors() + " (preprocessors), " + WekaUtil.getClassifierDescriptor(getBaseClassifier()) + " (classifier)";
    }

    /** Returns the total ms spent preparing preprocessors in the last call to buildClassifier. */
    public long getTimeForTrainingPreprocessor() {
        return this.timeForTrainingPreprocessors;
    }

    /** Returns the ms spent training the base classifier in the last call to buildClassifier. */
    public long getTimeForTrainingClassifier() {
        return this.timeForTrainingClassifier;
    }

    /** Returns per-prediction preprocessing times in ms, or null if the pipeline was never trained. */
    public DescriptiveStatistics getTimeForExecutingPreprocessor() {
        return this.timeForExecutingPreprocessors;
    }

    /** Returns per-prediction base-classifier times in ms, or null if the pipeline was never trained. */
    public DescriptiveStatistics getTimeForExecutingClassifier() {
        return this.timeForExecutingClassifier;
    }
}
