package weka.distributed;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.Serializable;
import java.util.Random;
import weka.classifiers.Classifier;
import weka.classifiers.UpdateableClassifier;
import weka.classifiers.evaluation.AggregateableEvaluation;
import weka.classifiers.evaluation.AggregateableEvaluationWithPriors;
import weka.classifiers.evaluation.Evaluation;
import weka.classifiers.trees.J48;
import weka.core.BatchPredictor;
import weka.core.Instance;
import weka.core.Instances;

/* loaded from: input_file:weka/distributed/WekaClassifierEvaluationMapTask.class */
/**
 * Map task that evaluates an already trained classifier on the instances it is
 * fed, optionally restricted to a single test fold of a cross-validation.
 *
 * <p>Two modes are supported:
 * <ul>
 * <li>If the classifier is an {@link UpdateableClassifier} and was not batch
 * trained, every instance is evaluated as soon as it arrives in
 * {@link #processInstance(Instance)} and nothing is buffered.</li>
 * <li>Otherwise instances are buffered in {@code m_trainingHeader} and the
 * evaluation is performed in one go by {@link #finalizeTask()}.</li>
 * </ul>
 *
 * <p>Typical call order: set classifier/fold options, {@link #setup}, one
 * {@link #processInstance(Instance)} per instance, {@link #finalizeTask()},
 * then {@link #getEvaluation()}.
 */
public class WekaClassifierEvaluationMapTask implements Serializable {
    private static final long serialVersionUID = 7662934051268629543L;

    /** Number of instances that were actually evaluated (i.e. in the test fold). */
    protected int m_numTestInstances;

    /** Total number of instances passed to {@link #processInstance(Instance)}. */
    protected int m_numInstances;

    /** Total number of cross-validation folds; values <= 1 disable fold handling. */
    protected int m_totalFolds = 1;

    /** 1-based number of the fold to test on; values < 1 mean "evaluate on everything". */
    protected int m_foldNumber = -1;

    /** True if an incremental classifier was nevertheless trained in batch mode. */
    protected boolean m_batchTrainedIncremental = false;

    /** The classifier to evaluate. */
    protected Classifier m_classifier = null;

    /** The evaluation object that accumulates the results; created in {@link #setup}. */
    protected Evaluation m_eval = null;

    /** Header of the data; also used as the buffer for instances in batch mode. */
    protected Instances m_trainingHeader = null;

    /** Seed used for shuffling buffered data and for pruning stored predictions. */
    protected long m_seed = 1;

    /** Fraction of predictions to retain; values <= 0 mean predictions are not kept. */
    protected double m_predFrac = 0.0d;

    /**
     * Returns the evaluation object holding the results computed so far.
     *
     * @return the evaluation, or {@code null} if {@link #setup} has not been called
     */
    public Evaluation getEvaluation() {
        return this.m_eval;
    }

    /**
     * Sets the (trained) classifier to evaluate.
     *
     * @param classifier the classifier to evaluate
     */
    public void setClassifier(Classifier classifier) {
        this.m_classifier = classifier;
    }

    /**
     * Returns the classifier being evaluated.
     *
     * @return the classifier, possibly {@code null}
     */
    public Classifier getClassifier() {
        return this.m_classifier;
    }

    /**
     * Sets the 1-based fold number to use as the test fold.
     *
     * @param i the fold number; values < 1 disable fold selection
     */
    public void setFoldNumber(int i) {
        this.m_foldNumber = i;
    }

    /**
     * Returns the 1-based fold number used as the test fold.
     *
     * @return the fold number
     */
    public int getFoldNumber() {
        return this.m_foldNumber;
    }

    /**
     * Sets the total number of cross-validation folds.
     *
     * @param i the number of folds; values <= 1 disable fold selection
     */
    public void setTotalNumFolds(int i) {
        this.m_totalFolds = i;
    }

    /**
     * Returns the total number of cross-validation folds.
     *
     * @return the number of folds
     */
    public int getTotalNumFolds() {
        return this.m_totalFolds;
    }

    /**
     * Sets whether an incremental classifier was trained in batch mode. When
     * true, instances are buffered and evaluated in {@link #finalizeTask()}
     * even if the classifier is an {@link UpdateableClassifier}.
     *
     * @param z true if the classifier was batch trained
     */
    public void setBatchTrainedIncremental(boolean z) {
        this.m_batchTrainedIncremental = z;
    }

    /**
     * Returns whether an incremental classifier was trained in batch mode.
     *
     * @return true if the classifier was batch trained
     */
    public boolean getBatchTrainedIncremental() {
        return this.m_batchTrainedIncremental;
    }

    /**
     * Views {@code m_eval} as the concrete type that {@link #setup} creates.
     * The prior/prediction-pruning operations used by this class are invoked
     * through this view (they are assumed to be declared on
     * {@link AggregateableEvaluationWithPriors} rather than on
     * {@link Evaluation}).
     *
     * @return the evaluation object cast to its concrete type
     */
    private AggregateableEvaluationWithPriors priorsEvaluation() {
        return (AggregateableEvaluationWithPriors) this.m_eval;
    }

    /**
     * Initializes the task: stores an empty copy of the header, creates a
     * fresh evaluation object and resets the instance counters.
     *
     * @param trainingHeader header of the data (class index must be set)
     * @param priors class priors to install in the evaluation, or {@code null}
     *          to leave the evaluation's priors untouched
     * @param priorsCount value handed to {@code setPriors} together with
     *          {@code priors}; ignored when {@code priors} is {@code null}
     * @param seed seed for shuffling buffered data and pruning predictions
     * @param predFrac fraction of predictions to retain; values <= 0 mean
     *          predictions are not retained
     * @throws Exception if no class index is set or the evaluation cannot be
     *           created
     */
    public void setup(Instances trainingHeader, double[] priors, double priorsCount, long seed, double predFrac)
            throws Exception {
        this.m_trainingHeader = new Instances(trainingHeader, 0);
        if (this.m_trainingHeader.classIndex() < 0) {
            throw new Exception("No class index set in the data!");
        }

        AggregateableEvaluationWithPriors eval = new AggregateableEvaluationWithPriors(this.m_trainingHeader);
        if (priors != null) {
            eval.setPriors(priors, priorsCount);
        }
        this.m_eval = eval;

        this.m_numInstances = 0;
        this.m_numTestInstances = 0;
        this.m_seed = seed;
        this.m_predFrac = predFrac;
    }

    /**
     * Processes one instance. For an incremental classifier that was not batch
     * trained the instance is evaluated immediately (if it belongs to the test
     * fold); in every other case it is buffered for {@link #finalizeTask()}.
     *
     * @param instance the instance to process
     * @throws Exception if evaluating the instance fails
     */
    public void processInstance(Instance instance) throws Exception {
        // instanceof is false for null, so no separate null check is required
        boolean evaluateImmediately =
                (this.m_classifier instanceof UpdateableClassifier) && !this.m_batchTrainedIncremental;

        if (evaluateImmediately) {
            // Folds are assigned round-robin by arrival order: the instance with
            // index k belongs to (0-based) fold k % totalFolds.
            boolean inTestFold = true;
            if (this.m_totalFolds > 1 && this.m_foldNumber >= 1) {
                inTestFold = (this.m_numInstances % this.m_totalFolds) == (this.m_foldNumber - 1);
            }
            if (inTestFold) {
                if (this.m_predFrac > 0.0d) {
                    this.m_eval.evaluateModelOnceAndRecordPrediction(this.m_classifier, instance);
                } else {
                    this.m_eval.evaluateModelOnce(this.m_classifier, instance);
                }
                this.m_numTestInstances++;
            }
        } else {
            // batch mode: just store the instance for later
            this.m_trainingHeader.add(instance);
        }
        this.m_numInstances++;
    }

    /**
     * Completes the evaluation. In incremental mode only the stored
     * predictions are pruned (if a fraction is to be retained). In batch mode
     * the buffered data is shuffled, stratified for a nominal class when
     * several folds are used, reduced to the test fold if one is selected, and
     * then evaluated.
     *
     * @throws Exception if no classifier has been set or evaluation fails
     */
    public void finalizeTask() throws Exception {
        if (this.m_classifier == null) {
            throw new Exception("No classifier has been set");
        }

        boolean retainPredictions = this.m_predFrac > 0.0d;

        if ((this.m_classifier instanceof UpdateableClassifier) && !this.m_batchTrainedIncremental) {
            // everything was already evaluated in processInstance()
            if (retainPredictions) {
                priorsEvaluation().prunePredictions(this.m_predFrac, this.m_seed);
            }
            return;
        }

        this.m_trainingHeader.compactify();
        // note: shuffling/stratification is applied in place to the buffered data
        Instances test = this.m_trainingHeader;
        test.randomize(new Random(this.m_seed));
        if (test.classAttribute().isNominal() && this.m_totalFolds > 1) {
            test.stratify(this.m_totalFolds);
        }
        if (this.m_totalFolds > 1 && this.m_foldNumber >= 1) {
            test = test.testCV(this.m_totalFolds, this.m_foldNumber - 1);
        }
        this.m_numTestInstances = test.numInstances();

        if (this.m_classifier instanceof BatchPredictor) {
            this.m_eval.evaluateModel(this.m_classifier, test, new Object[0]);
            // Predictions are only wanted when a positive fraction is to be
            // retained, so discard whatever the batch call stored otherwise.
            // This must cover m_predFrac == 0 (the default), not just negatives.
            if (!retainPredictions) {
                priorsEvaluation().deleteStoredPredictions();
            }
        } else {
            for (int idx = 0; idx < test.numInstances(); idx++) {
                if (retainPredictions) {
                    this.m_eval.evaluateModelOnceAndRecordPrediction(this.m_classifier, test.instance(idx));
                } else {
                    this.m_eval.evaluateModelOnce(this.m_classifier, test.instance(idx));
                }
            }
        }

        if (retainPredictions) {
            priorsEvaluation().prunePredictions(this.m_predFrac, this.m_seed);
        }
    }

    /**
     * Simple test driver: loads the data set named by the first argument, uses
     * the last attribute as the class, runs a 10-fold train/evaluate cycle with
     * J48 and prints the aggregated results to standard error.
     *
     * @param strArr command line arguments; the first one is the data file
     */
    public static void main(String[] strArr) {
        if (strArr == null || strArr.length < 1) {
            System.err.println("Usage: WekaClassifierEvaluationMapTask <ARFF file>");
            return;
        }

        final int totalFolds = 10;
        try {
            Instances data;
            // close the file as soon as the data has been read
            try (BufferedReader reader = new BufferedReader(new FileReader(strArr[0]))) {
                data = new Instances(reader);
            }
            data.setClassIndex(data.numAttributes() - 1);

            AggregateableEvaluation aggregated = null;
            WekaClassifierEvaluationMapTask evaluator = new WekaClassifierEvaluationMapTask();
            Classifier j48 = new J48();
            WekaClassifierMapTask trainer = new WekaClassifierMapTask();
            trainer.setClassifier(j48);
            trainer.setTotalNumFolds(totalFolds);

            for (int fold = 1; fold <= totalFolds; fold++) {
                System.err.println("Processing fold " + fold);

                // train on everything except the current fold
                trainer.setFoldNumber(fold);
                trainer.setup(new Instances(data, 0));
                trainer.addToTrainingHeader(data);
                trainer.finalizeTask();

                // evaluate on the current fold
                evaluator.setClassifier(trainer.getClassifier());
                evaluator.setTotalNumFolds(totalFolds);
                evaluator.setFoldNumber(fold);
                evaluator.setup(new Instances(data, 0), null, -1.0d, 1L, 0.0d);
                for (int idx = 0; idx < data.numInstances(); idx++) {
                    evaluator.processInstance(data.instance(idx));
                }
                evaluator.finalizeTask();

                Evaluation foldEval = evaluator.getEvaluation();
                if (aggregated == null) {
                    aggregated = new AggregateableEvaluation(foldEval);
                }
                aggregated.aggregate(foldEval);
            }

            System.err.println(aggregated.toSummaryString());
            System.err.println("\n" + aggregated.toClassDetailsString());
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}
