package weka.distributed;

import distributed.core.DistributedJob;
import distributed.core.DistributedJobConfig;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.List;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.IteratedSingleClassifierEnhancer;
import weka.classifiers.UpdateableBatchProcessor;
import weka.classifiers.UpdateableClassifier;
import weka.classifiers.bayes.NaiveBayesUpdateable;
import weka.classifiers.meta.AggregateableFilteredClassifier;
import weka.classifiers.meta.AggregateableFilteredClassifierUpdateable;
import weka.classifiers.meta.FilteredClassifier;
import weka.classifiers.meta.FilteredClassifierUpdateable;
import weka.classifiers.trees.REPTree;
import weka.core.Aggregateable;
import weka.core.Environment;
import weka.core.EnvironmentHandler;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.MakePreconstructedFilter;
import weka.filters.MultiFilter;
import weka.filters.PreconstructedFilter;
import weka.filters.StreamableFilter;
import weka.filters.unsupervised.instance.ReservoirSample;

/* loaded from: input_file:weka/distributed/WekaClassifierMapTask.class */
/**
 * Map task that trains a Weka classifier on the rows it is given.
 * <p>
 * Usage: configure via {@link #setOptions(String[])} or the bean setters, call
 * {@link #setup(Instances)}, feed rows through {@link #processInstance(Instance)}
 * and finish with {@link #finalizeTask()}. Updateable classifiers are trained
 * incrementally (row by row) unless batch learning is forced; otherwise rows are
 * buffered (optionally through a reservoir sample) and the classifier is built in
 * {@link #finalizeTask()}. Configured filters are applied by wrapping the base
 * classifier in a filtered classifier during {@link #setup(Instances)}.
 */
public class WekaClassifierMapTask implements OptionHandler, EnvironmentHandler, Serializable {

    /** Environment variable holding the total number of map tasks in the job. */
    public static final String TOTAL_NUMBER_OF_MAPS = "total.num.maps";

    private static final long serialVersionUID = -5953696466790594368L;

    /** When true, Updateable classifiers are trained in batch mode instead of incrementally. */
    protected boolean m_forceBatchForUpdateable;

    /** When true, setup() keeps the existing Updateable model rather than rebuilding it. */
    protected boolean m_continueTrainingUpdateable;

    /** Number of instances used for training since the last setup(). */
    protected int m_numTrainingInstances;

    /** Number of instances passed to processInstance() since the last setup(). */
    protected int m_numInstances;

    /** Data header; in batch mode it also accumulates the training rows. */
    protected Instances m_trainingHeader;

    /** Whether to reservoir-sample the training rows when batch learning. */
    protected boolean m_useReservoirSampling;

    /** Reservoir used when sampling is enabled (null otherwise). */
    protected ReservoirSample m_reservoir;

    /** Ignore the classifier's Aggregateable status and produce a Vote ensemble. */
    protected boolean m_forceVotedEnsemble;

    /** The classifier to train (possibly replaced by a filtered wrapper in setup()). */
    protected Classifier m_classifier = new REPTree();

    /** Total number of folds; values <= 1 mean "use all the data". */
    protected int m_totalFolds = 1;

    /** 1-based fold to train on; values < 1 mean "use all the data". */
    protected int m_foldNumber = -1;

    /** Environment for variable substitution; transient, so null after deserialization. */
    protected transient Environment m_env = Environment.getSystemWide();

    /** Reservoir sample size; must be > 0 when sampling is enabled. */
    protected int m_sampleSize = -1;

    /** Filters to pre-process data with; PreconstructedFilters are kept at the front. */
    protected List<Filter> m_filtersToUse = new ArrayList<Filter>();

    /** Random seed (may contain environment variables). */
    protected String m_seed = "1";

    /**
     * Describes the command-line options understood by {@link #setOptions(String[])}.
     *
     * @return an enumeration of the available options
     */
    @Override
    public Enumeration<Option> listOptions() {
        Vector<Option> result = new Vector<Option>();
        result.add(new Option("\tThe fully qualified base classifier to use. Classifier options\n\tcan be supplied after a '--'", "W", 1, "-W"));
        // -force-batch is read with Utils.getFlag, i.e. it takes no argument.
        result.add(new Option("\tForce batch learning for Updateable classifiers.", "force-batch", 0, "-force-batch"));
        result.add(new Option("\tForce Vote-based ensemble creation (i.e. ignore\n\ta classifier's Aggregateable status and create a final Vote\n\tensemble)", "force-vote", 0, "-force-vote"));
        result.add(new Option("\tUse reservoir sampling for batch learning", "use-sampling", 0, "-use-sampling"));
        result.add(new Option("\tSpecify a filter to pre-process the data with.\n\tFor Aggregateable classifiers the filter must be a StreamableFilter,\n\tmeaning that the output format produced by the filter must be able to\n\t be determined directly from the input data format (this makes the data format\n\tcompatible across map tasks). Many unsupervised attribute-based\n\tfilters are StreamableFilters. If a Vote ensemble is being produced,\n\tthis constraint does not apply and any filter may be used.\n\tThis option may be supplied multiple times in order to apply more\n\tthan one filter.", "filter", 1, "-filter"));
        result.add(new Option("\tSample size if reservoir sampling is being used", "sample-size", 1, "-sample-size <num instances>"));
        result.add(new Option("\tTraining fold to use (default = -1, i.e. use all the data)", "fold-number", 1, "-fold-number <fold num>"));
        result.add(new Option("\tTotal number of folds. Use in conjunction with -fold-number (default = 1, i.e. use all the data)", "total-folds", 1, "-total-folds <num folds>"));
        result.add(new Option("\tRandom seed for fold generation.", "seed", 1, "-seed <integer>"));
        return result.elements();
    }

    /**
     * Parses the options described by {@link #listOptions()}.
     * <p>
     * Filters given with {@code -filter} are appended to the filters already
     * configured on this task.
     *
     * @param options the option array (consumed entries are blanked by Weka's Utils)
     * @throws Exception if a classifier or filter cannot be created or configured
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        String classifierName = Utils.getOption("W", options);
        if (!DistributedJobConfig.isEmpty(classifierName)) {
            // Everything after "--" belongs to the base classifier.
            setClassifier(AbstractClassifier.forName(classifierName, Utils.partitionOptions(options)));
        }
        setForceBatchLearningForUpdateableClassifiers(Utils.getFlag("force-batch", options));

        String foldNumber = Utils.getOption("fold-number", options);
        if (!DistributedJobConfig.isEmpty(foldNumber)) {
            setFoldNumber(Integer.parseInt(foldNumber));
        }
        String totalFolds = Utils.getOption("total-folds", options);
        if (!DistributedJobConfig.isEmpty(totalFolds)) {
            setTotalNumFolds(Integer.parseInt(totalFolds));
        }

        setUseReservoirSamplingWhenBatchLearning(Utils.getFlag("use-sampling", options));
        setForceVotedEnsembleCreation(Utils.getFlag("force-vote", options));

        String sampleSize = Utils.getOption("sample-size", options);
        if (!DistributedJobConfig.isEmpty(sampleSize)) {
            setReservoirSampleSize(Integer.parseInt(sampleSize));
        }
        String seed = Utils.getOption("seed", options);
        if (!DistributedJobConfig.isEmpty(seed)) {
            setSeed(seed);
        }

        // -filter may be repeated; consume occurrences until none remain.
        String filterSpec = Utils.getOption("filter", options);
        while (!DistributedJobConfig.isEmpty(filterSpec)) {
            m_filtersToUse.add(createFilter(filterSpec));
            filterSpec = Utils.getOption("filter", options);
        }
    }

    /**
     * Instantiates and configures a filter from a "classname [options]" specification.
     *
     * @param filterSpec the filter specification string
     * @return the configured filter
     * @throws IllegalArgumentException if the spec is empty or does not name a Filter
     * @throws Exception if the class cannot be loaded/instantiated or its options are invalid
     */
    private Filter createFilter(String filterSpec) throws Exception {
        String[] spec = Utils.splitOptions(filterSpec);
        if (spec.length == 0) {
            throw new IllegalArgumentException("Invalid filter specification string");
        }
        Object candidate = Class.forName(spec[0]).newInstance();
        if (!(candidate instanceof Filter)) {
            throw new IllegalArgumentException(spec[0] + " is not a Filter");
        }
        // Blank out the class name so only the filter's own options remain.
        spec[0] = "";
        if (candidate instanceof OptionHandler) {
            ((OptionHandler) candidate).setOptions(spec);
        }
        return (Filter) candidate;
    }

    /**
     * Builds a "classname [options]" specification for a filter.
     *
     * @param filter the filter to describe
     * @return the filter's class name, followed by its options if it is an OptionHandler
     */
    protected String getFilterSpec(Filter filter) {
        String name = filter.getClass().getName();
        if (filter instanceof OptionHandler) {
            return name + " " + Utils.joinOptions(((OptionHandler) filter).getOptions());
        }
        return name;
    }

    /**
     * Returns the current configuration as an option array.
     *
     * @return the options, with base classifier options after a trailing "--"
     */
    @Override
    public String[] getOptions() {
        List<String> result = new ArrayList<String>();
        if (m_classifier != null) {
            result.add("-W");
            result.add(m_classifier.getClass().getName());
        }
        if (getForceBatchLearningForUpdateableClassifiers()) {
            result.add("-force-batch");
        }
        if (getForceVotedEnsembleCreation()) {
            result.add("-force-vote");
        }
        result.add("-fold-number");
        result.add(String.valueOf(getFoldNumber()));
        result.add("-total-folds");
        result.add(String.valueOf(getTotalNumFolds()));
        if (getUseReservoirSamplingWhenBatchLearning()) {
            result.add("-use-sampling");
            result.add("-sample-size");
            result.add(String.valueOf(getReservoirSampleSize()));
        }
        // Never put a null value into the option array.
        if (!DistributedJobConfig.isEmpty(getSeed())) {
            result.add("-seed");
            result.add(getSeed());
        }
        if (m_filtersToUse != null) {
            for (Filter filter : m_filtersToUse) {
                result.add("-filter");
                result.add(getFilterSpec(filter));
            }
        }
        if (m_classifier instanceof OptionHandler) {
            result.add("--");
            for (String option : ((OptionHandler) m_classifier).getOptions()) {
                result.add(option);
            }
        }
        return result.toArray(new String[result.size()]);
    }

    /**
     * Sets the classifier to train.
     * <p>
     * For an IteratedSingleClassifierEnhancer, if the {@value #TOTAL_NUMBER_OF_MAPS}
     * environment variable holds a positive integer N, the number of iterations is
     * divided by (N - 1) when N > 1 (by 1 otherwise), with a minimum of 1. Missing,
     * non-numeric or non-positive values leave the iterations untouched.
     *
     * @param classifier the classifier to use
     */
    public void setClassifier(Classifier classifier) {
        m_classifier = classifier;
        if (!(classifier instanceof IteratedSingleClassifierEnhancer)) {
            return;
        }
        String numMapsSpec = environment().getVariableValue(TOTAL_NUMBER_OF_MAPS);
        if (numMapsSpec == null || numMapsSpec.length() == 0) {
            return;
        }
        int numMaps;
        try {
            numMaps = Integer.parseInt(numMapsSpec.trim());
        } catch (NumberFormatException e) {
            System.err.println("Trouble parsing " + TOTAL_NUMBER_OF_MAPS + " value: " + numMapsSpec);
            return;
        }
        if (numMaps < 1) {
            // Guard against division by zero / nonsensical negative map counts.
            return;
        }
        if (numMaps > 1) {
            numMaps--;
        }
        IteratedSingleClassifierEnhancer ensemble = (IteratedSingleClassifierEnhancer) classifier;
        ensemble.setNumIterations(Math.max(1, ensemble.getNumIterations() / numMaps));
    }

    /** @return the classifier being trained (possibly a filtered wrapper after setup()) */
    public Classifier getClassifier() {
        return m_classifier;
    }

    /** @return tip text for the classifier property */
    public String classifierTipText() {
        return "The classifier to use";
    }

    /** @param seed the random seed; may contain environment variables */
    public void setSeed(String seed) {
        m_seed = seed;
    }

    /** @return the random seed string */
    public String getSeed() {
        return m_seed;
    }

    /** @return tip text for the seed property */
    public String seedTipText() {
        return "Random seed for shuffling the training data and reservoir sampling (batch learning only).";
    }

    /** @param force true to always produce a Vote ensemble */
    public void setForceVotedEnsembleCreation(boolean force) {
        m_forceVotedEnsemble = force;
    }

    /** @return true if Vote ensemble creation is forced */
    public boolean getForceVotedEnsembleCreation() {
        return m_forceVotedEnsemble;
    }

    /** @return tip text for the forceVotedEnsembleCreation property */
    public String forceVotedEnsembleCreationTipText() {
        return "Force the creation of a Vote ensemble even if the base classifier is directly aggregateable";
    }

    /**
     * Misnamed tip-text accessor, kept for backward compatibility.
     *
     * @return the same text as {@link #forceVotedEnsembleCreationTipText()}
     */
    public String forceVotedEnsembleCreation() {
        return forceVotedEnsembleCreationTipText();
    }

    /** @return the user-specified filters (PreconstructedFilters are excluded) */
    public Filter[] getFiltersToUse() {
        List<Filter> userFilters = new ArrayList<Filter>();
        for (Filter filter : m_filtersToUse) {
            if (!(filter instanceof PreconstructedFilter)) {
                userFilters.add(filter);
            }
        }
        return userFilters.toArray(new Filter[userFilters.size()]);
    }

    /**
     * Replaces all configured filters (including any previously added
     * PreconstructedFilters) with the non-preconstructed filters in the array.
     *
     * @param filters the filters to use; null or empty clears the list
     */
    public void setFiltersToUse(Filter[] filters) {
        m_filtersToUse.clear();
        if (filters == null) {
            return;
        }
        for (Filter filter : filters) {
            if (!(filter instanceof PreconstructedFilter)) {
                m_filtersToUse.add(filter);
            }
        }
    }

    /** @return tip text for the filtersToUse property */
    public String filtersToUseTipText() {
        return "Filters to pre-process the data with before passing it to the classifier. Note that in order to remain directly aggregateable to a single model StreamableFilters must be used with Aggregateable classifiers.";
    }

    /**
     * Adds a preconstructed filter at the front of the filter list.
     *
     * @param filter the preconstructed filter (must also be a Filter)
     */
    public void addPreconstructedFilterToUse(PreconstructedFilter filter) {
        m_filtersToUse.add(0, (Filter) filter);
    }

    /** @param use true to reservoir-sample rows when batch learning */
    public void setUseReservoirSamplingWhenBatchLearning(boolean use) {
        m_useReservoirSampling = use;
    }

    /** @return true if reservoir sampling is used when batch learning */
    public boolean getUseReservoirSamplingWhenBatchLearning() {
        return m_useReservoirSampling;
    }

    /** @return tip text for the useReservoirSamplingWhenBatchLearning property */
    public String useReservoirSamplingWhenBatchLearningTipText() {
        return "Apply reservoir sampling to enforce a maximum number of training instances when batch learning";
    }

    /** @param size the reservoir sample size (number of rows) */
    public void setReservoirSampleSize(int size) {
        m_sampleSize = size;
    }

    /** @return the reservoir sample size */
    public int getReservoirSampleSize() {
        return m_sampleSize;
    }

    /** @return tip text for the reservoirSampleSize property */
    public String reservoirSampleSizeTipText() {
        return "The sample size (number of instances/rows) for reservoir sampling";
    }

    /** @param force true to batch-train Updateable classifiers */
    public void setForceBatchLearningForUpdateableClassifiers(boolean force) {
        m_forceBatchForUpdateable = force;
    }

    /** @return true if Updateable classifiers are batch-trained */
    public boolean getForceBatchLearningForUpdateableClassifiers() {
        return m_forceBatchForUpdateable;
    }

    /** @return tip text for the forceBatchLearningForUpdateableClassifiers property */
    public String forceBatchLearningForUpdateableClassifiersTipText() {
        return "Use batch training even if the base classifier is an incremental (Updateable) one.";
    }

    /** @param continueTraining true to keep updating the existing Updateable model */
    public void setContinueTrainingUpdateableClassifier(boolean continueTraining) {
        m_continueTrainingUpdateable = continueTraining;
    }

    /** @return true if training of an existing Updateable model is continued */
    public boolean getContinueTrainingUpdateableClassifier() {
        return m_continueTrainingUpdateable;
    }

    /** @return tip text for the continueTrainingUpdateableClassifier property */
    public String continueTrainingUpdateableClassifierTipText() {
        return "Continue training (updating) an incremental classifier with the incoming data (rather than start from scratch)";
    }

    /** @param foldNumber the 1-based fold to train on, or -1 for all data */
    public void setFoldNumber(int foldNumber) {
        m_foldNumber = foldNumber;
    }

    /** @return the fold number */
    public int getFoldNumber() {
        return m_foldNumber;
    }

    /** @return tip text for the foldNumber property */
    public String foldNumberTipText() {
        return "Set the fold number to train the classifier with. Default (-1) is to use all the data for training the classifier. Use in conjunction with setTotalNumFolds()";
    }

    /** @param totalFolds the total number of folds */
    public void setTotalNumFolds(int totalFolds) {
        m_totalFolds = totalFolds;
    }

    /** @return the total number of folds */
    public int getTotalNumFolds() {
        return m_totalFolds;
    }

    /** @return tip text for the totalNumFolds property */
    public String totalNumFoldsTipText() {
        return "The total number of folds. Use in conjunction with setFoldNumber(). Only has an effect if setFoldNumber() is set to something other than -1.";
    }

    /**
     * Tip-text accessor whose name does not match the totalNumFolds property;
     * kept for backward compatibility.
     *
     * @return the same text as {@link #totalNumFoldsTipText()}
     */
    public String totalNumberOfFoldsTipText() {
        return totalNumFoldsTipText();
    }

    /** @return the number of instances used for training since the last setup() */
    public int getNumTrainingInstances() {
        return m_numTrainingInstances;
    }

    /**
     * Appends every instance of the given dataset to the training buffer.
     *
     * @param instances the instances to add
     */
    public void addToTrainingHeader(Instances instances) {
        for (int i = 0; i < instances.numInstances(); i++) {
            m_trainingHeader.add(instances.instance(i));
        }
    }

    /**
     * Appends a single instance to the training buffer.
     *
     * @param instance the instance to add
     */
    public void addToTrainingHeader(Instance instance) {
        m_trainingHeader.add(instance);
    }

    /**
     * Wraps the base classifier in a filtered classifier when filters are configured.
     * <p>
     * Does nothing when continuing incremental training or when no filters are set.
     * Non-aggregateable classifiers (or forced Vote ensembles) get a
     * FilteredClassifier(Updateable); aggregateable ones get an
     * AggregateableFilteredClassifier(Updateable) with a preconstructed filter.
     *
     * @param instances data whose format initialises a preconstructed filter
     * @throws DistributedWekaException if a filter is incompatible with the classifier
     *         or the filter cannot be initialised
     */
    protected void determineWrapperClassifierToUse(Instances instances) throws DistributedWekaException {
        if (m_continueTrainingUpdateable || m_filtersToUse.size() == 0) {
            return;
        }
        if (m_classifier instanceof Aggregateable && !m_forceVotedEnsemble) {
            m_classifier = wrapForAggregation(instances);
        } else {
            m_classifier = wrapForVote();
        }
    }

    /** Builds a (possibly updateable) FilteredClassifier around the base classifier. */
    private FilteredClassifier wrapForVote() throws DistributedWekaException {
        FilteredClassifier wrapper;
        if (getClassifier() instanceof UpdateableClassifier && !m_forceBatchForUpdateable) {
            // Incremental filtering requires every filter to be streamable.
            for (Filter filter : m_filtersToUse) {
                if (!(filter instanceof StreamableFilter)) {
                    throw new DistributedWekaException("Base classifier is Updateable. In this case all filters must be StreamableFilters but " + filter.getClass().getName() + " is not.");
                }
            }
            wrapper = new FilteredClassifierUpdateable();
        } else {
            wrapper = new FilteredClassifier();
        }
        wrapper.setClassifier(getClassifier());
        if (m_filtersToUse.size() > 1) {
            MultiFilter chain = new MultiFilter();
            chain.setFilters(m_filtersToUse.toArray(new Filter[m_filtersToUse.size()]));
            wrapper.setFilter(chain);
        } else {
            wrapper.setFilter(m_filtersToUse.get(0));
        }
        return wrapper;
    }

    /** Builds an aggregateable filtered classifier using a preconstructed filter chain. */
    private Classifier wrapForAggregation(Instances instances) throws DistributedWekaException {
        List<Filter> preconstructed = new ArrayList<Filter>();
        for (Filter filter : m_filtersToUse) {
            if (filter instanceof PreconstructedFilter) {
                preconstructed.add(filter);
            } else if (filter instanceof StreamableFilter) {
                preconstructed.add(new MakePreconstructedFilter(filter));
            } else {
                throw new DistributedWekaException("Base classifier is Aggregateable. In this case all filters must be StreamableFilters but " + filter.getClass().getName() + " is not Streamable");
            }
        }
        Filter filter;
        if (preconstructed.size() > 1) {
            MultiFilter chain = new MultiFilter();
            chain.setFilters(preconstructed.toArray(new Filter[preconstructed.size()]));
            filter = new MakePreconstructedFilter(chain);
        } else {
            filter = preconstructed.get(0);
        }
        try {
            filter.setInputFormat(instances);
            if (!(m_classifier instanceof UpdateableClassifier) || getForceBatchLearningForUpdateableClassifiers()) {
                AggregateableFilteredClassifier wrapper = new AggregateableFilteredClassifier();
                wrapper.setClassifier(getClassifier());
                wrapper.setPreConstructedFilter(filter);
                return wrapper;
            }
            AggregateableFilteredClassifierUpdateable wrapper = new AggregateableFilteredClassifierUpdateable();
            wrapper.setClassifier(getClassifier());
            wrapper.setPreConstructedFilter(filter);
            return wrapper;
        } catch (Exception e) {
            throw new DistributedWekaException(e);
        }
    }

    /**
     * Prepares the task for a new pass over the data.
     *
     * @param instances the data (its structure becomes the training header); for
     *        incremental learning the classifier is built on it unless training is continued
     * @throws DistributedWekaException if no classifier/class index is set, the fold
     *         settings are inconsistent, or classifier/reservoir initialisation fails
     */
    public void setup(Instances instances) throws DistributedWekaException {
        if (m_classifier == null) {
            throw new DistributedWekaException("No classifier has been configured!");
        }
        m_trainingHeader = new Instances(instances, 0);
        if (m_trainingHeader.classIndex() < 0) {
            throw new DistributedWekaException("No class index set in the data!");
        }
        if (m_totalFolds > 1 && m_foldNumber > m_totalFolds) {
            throw new DistributedWekaException("Fold number " + m_foldNumber + " is larger than the total number of folds (" + m_totalFolds + ")");
        }

        determineWrapperClassifierToUse(instances);

        if (isIncrementalTraining() && !m_continueTrainingUpdateable) {
            try {
                m_classifier.buildClassifier(instances);
            } catch (Exception e) {
                throw new DistributedWekaException(e);
            }
        }

        // Reset so a reservoir from an earlier setup() is never reused.
        m_reservoir = null;
        if (getUseReservoirSamplingWhenBatchLearning()) {
            initReservoir();
        }
        m_numTrainingInstances = 0;
        m_numInstances = 0;
    }

    /** Creates and initialises the reservoir sampler from the current settings. */
    private void initReservoir() throws DistributedWekaException {
        if (getReservoirSampleSize() <= 0) {
            throw new DistributedWekaException("Reservoir sampling requested, but no sample size set.");
        }
        ReservoirSample reservoir = new ReservoirSample();
        reservoir.setSampleSize(getReservoirSampleSize());
        int seed = 1;
        String seedSpec = resolveSeed();
        if (seedSpec != null) {
            try {
                seed = Integer.parseInt(seedSpec);
            } catch (NumberFormatException e) {
                System.err.println("Trouble parsing random seed value: " + seedSpec);
            }
        }
        reservoir.setRandomSeed(seed);
        try {
            reservoir.setInputFormat(m_trainingHeader);
        } catch (Exception e) {
            throw new DistributedWekaException(e);
        }
        m_reservoir = reservoir;
    }

    /**
     * Processes one row: updates an incremental classifier (skipping held-out rows
     * when a fold is selected), or buffers the row for batch training.
     *
     * @param instance the row to process
     * @throws DistributedWekaException if updating the classifier fails
     */
    public void processInstance(Instance instance) throws DistributedWekaException {
        if (isIncrementalTraining()) {
            // Row i belongs to fold (i % totalFolds) + 1; that fold is held out.
            boolean heldOut = m_totalFolds > 1 && m_foldNumber >= 1 && m_numInstances % m_totalFolds == m_foldNumber - 1;
            if (!heldOut) {
                try {
                    ((UpdateableClassifier) m_classifier).updateClassifier(instance);
                    m_numTrainingInstances++;
                } catch (Exception e) {
                    throw new DistributedWekaException(e);
                }
            }
        } else if (m_reservoir != null) {
            m_reservoir.input(instance);
        } else {
            m_trainingHeader.add(instance);
        }
        m_numInstances++;
    }

    /**
     * Completes training: signals batch completion to incremental classifiers, or
     * shuffles, optionally stratifies/selects a training fold, and builds the
     * classifier on the buffered rows.
     *
     * @throws DistributedWekaException if fold selection or training fails
     */
    public void finalizeTask() throws DistributedWekaException {
        System.gc();
        Runtime runtime = Runtime.getRuntime();
        System.err.println("[ClassifierMapTask] Memory (free/total/max.) in bytes: " + String.format("%,d", runtime.freeMemory()) + " / " + String.format("%,d", runtime.totalMemory()) + " / " + String.format("%,d", runtime.maxMemory()));

        if (isIncrementalTraining()) {
            if (m_classifier instanceof UpdateableBatchProcessor) {
                try {
                    System.err.println("Calling batch finished on updateable classifier...");
                    ((UpdateableBatchProcessor) m_classifier).batchFinished();
                } catch (Exception e) {
                    throw new DistributedWekaException(e);
                }
            }
            return;
        }

        if (m_reservoir != null) {
            m_reservoir.batchFinished();
            while (m_reservoir.numPendingOutput() > 0) {
                m_trainingHeader.add(m_reservoir.output());
            }
        }
        m_trainingHeader.compactify();

        long seed = 1L;
        String seedSpec = resolveSeed();
        if (seedSpec != null) {
            try {
                seed = Long.parseLong(seedSpec);
            } catch (NumberFormatException e) {
                System.err.println("Trouble parsing random seed value: " + seedSpec);
            }
        }
        Random random = new Random(seed);
        m_trainingHeader.randomize(random);

        try {
            if (m_trainingHeader.classAttribute().isNominal() && m_totalFolds > 1) {
                m_trainingHeader.stratify(m_totalFolds);
            }
            Instances training = m_trainingHeader;
            if (m_totalFolds > 1 && m_foldNumber >= 1) {
                // trainCV rejects e.g. more folds than instances; report it uniformly.
                training = training.trainCV(m_totalFolds, m_foldNumber - 1, random);
            }
            m_numTrainingInstances = training.numInstances();
            m_classifier.buildClassifier(training);
        } catch (Exception e) {
            throw new DistributedWekaException(e);
        }
    }

    /** @return true if the current classifier is trained row by row */
    private boolean isIncrementalTraining() {
        return m_classifier instanceof UpdateableClassifier && !m_forceBatchForUpdateable;
    }

    /**
     * Returns the seed with environment variables substituted (best effort).
     *
     * @return the seed string, or null if no seed is set
     */
    private String resolveSeed() {
        String seed = getSeed();
        if (DistributedJobConfig.isEmpty(seed)) {
            return null;
        }
        try {
            seed = environment().substitute(seed);
        } catch (Exception ignored) {
            // Best effort: if substitution fails the raw seed string is used.
        }
        return seed;
    }

    /** @return the environment, restoring the system-wide one if unset (e.g. after deserialization) */
    private Environment environment() {
        if (m_env == null) {
            m_env = Environment.getSystemWide();
        }
        return m_env;
    }

    /** @param environment the environment used for variable substitution */
    @Override
    public void setEnvironment(Environment environment) {
        m_env = environment;
    }

    /**
     * Demo driver: batch-trains the configured classifier on the ARFF file given by
     * {@code -t} (last attribute is the class), then trains NaiveBayesUpdateable
     * incrementally for two passes.
     *
     * @param args command-line options
     */
    public static void main(String[] args) {
        try {
            WekaClassifierMapTask batchTask = new WekaClassifierMapTask();
            if (Utils.getFlag('h', args)) {
                System.err.println(DistributedJob.makeOptionsStr(batchTask));
                System.exit(1);
            }
            Instances data = loadArff(Utils.getOption("t", args));
            data.setClassIndex(data.numAttributes() - 1);

            batchTask.setOptions(args);
            batchTask.setup(new Instances(data, 0));
            feedAll(batchTask, data);
            batchTask.finalizeTask();
            System.err.println("Batch trained classifier:\n" + batchTask.getClassifier().toString());

            WekaClassifierMapTask incrementalTask = new WekaClassifierMapTask();
            incrementalTask.setClassifier(new NaiveBayesUpdateable());
            incrementalTask.setup(new Instances(data, 0));
            feedAll(incrementalTask, data);
            System.err.println("Incremental training (iteration 1):\n" + incrementalTask.getClassifier().toString());

            incrementalTask.setContinueTrainingUpdateableClassifier(true);
            incrementalTask.setup(new Instances(data, 0));
            feedAll(incrementalTask, data);
            System.err.println("Incremental training (iteration 2):\n" + incrementalTask.getClassifier().toString());
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Reads an ARFF file, always closing the underlying reader. */
    private static Instances loadArff(String path) throws Exception {
        BufferedReader reader = new BufferedReader(new FileReader(path));
        try {
            return new Instances(reader);
        } finally {
            reader.close();
        }
    }

    /** Passes every instance of {@code data} to {@code task.processInstance}. */
    private static void feedAll(WekaClassifierMapTask task, Instances data) throws DistributedWekaException {
        for (int i = 0; i < data.numInstances(); i++) {
            task.processInstance(data.instance(i));
        }
    }
}
