package org.tribuo.classification;

import com.oracle.labs.mlrg.olcut.config.ArgumentException;
import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.IOException;
import java.util.logging.Logger;
import org.tribuo.Dataset;
import org.tribuo.Model;
import org.tribuo.Trainer;
import org.tribuo.classification.evaluation.LabelEvaluation;
import org.tribuo.data.DataOptions;
import org.tribuo.util.Util;

/* loaded from: input_file:org/tribuo/classification/TrainTestHelper.class */
/**
 * Helper that trains a classification model on a training dataset, evaluates it on a
 * testing dataset and reports the results.
 *
 * <p>This is a static utility class and cannot be instantiated.
 */
public final class TrainTestHelper {
    private static final Logger logger = Logger.getLogger(TrainTestHelper.class.getName());

    /** Output factory handed to the data loader and used to obtain the label evaluator. */
    private static final LabelFactory factory = new LabelFactory();

    /** Static utility class; not instantiable. */
    private TrainTestHelper() {
    }

    /**
     * Loads the training and testing data described by {@code dataOptions}, trains a model with
     * {@code trainer}, evaluates it on the testing data and, if requested, saves the model.
     *
     * <p>Timing information and, when the model generates probabilities, the average and
     * weighted-average AUC are written to the logger. The evaluation summary and the confusion
     * matrix are printed to {@code System.out}. The model is only saved (via
     * {@link DataOptions#saveModel}) when {@code dataOptions.outputPath} is non-null.
     *
     * @param configurationManager The configuration manager; its usage text is logged when a data
     *                             path is missing.
     * @param dataOptions          The data options supplying the training/testing paths and the
     *                             optional model output path.
     * @param trainer              The trainer used to build the model.
     * @return The trained model.
     * @throws ArgumentException If either the training path or the testing path is null.
     * @throws IOException       If the data cannot be loaded or the model cannot be saved
     *                           (propagated from {@code dataOptions}).
     */
    public static Model<Label> run(ConfigurationManager configurationManager, DataOptions dataOptions, Trainer<Label> trainer) throws IOException {
        LabsLogFormatter.setAllLogFormatters();
        if (dataOptions.trainingPath == null || dataOptions.testingPath == null) {
            logger.info(configurationManager.usage());
            logger.info("Training Path = " + dataOptions.trainingPath + ", Testing Path = " + dataOptions.testingPath);
            throw new ArgumentException("training-file", "test-file", "Must supply both training and testing data.");
        }

        // A typed pair replaces the raw Pair + casts, so train/evaluate are no longer unchecked calls.
        Pair<Dataset<Label>, Dataset<Label>> data = dataOptions.load(factory);
        Dataset<Label> trainData = data.getA();
        logger.info("Training data has " + trainData.getFeatureIDMap().size() + " features.");
        Dataset<Label> testData = data.getB();

        logger.info("Training using " + trainer.toString());
        long trainStart = System.currentTimeMillis();
        Model<Label> model = trainer.train(trainData);
        logger.info("Finished training classifier " + Util.formatDuration(trainStart, System.currentTimeMillis()));

        long testStart = System.currentTimeMillis();
        LabelEvaluation evaluation = factory.getEvaluator().evaluate(model, testData);
        logger.info("Finished evaluating model " + Util.formatDuration(testStart, System.currentTimeMillis()));

        // AUC needs scored predictions, so it is only reported for probabilistic models.
        if (model.generatesProbabilities()) {
            logger.info("Average AUC = " + evaluation.averageAUCROC(false));
            logger.info("Average weighted AUC = " + evaluation.averageAUCROC(true));
        }
        System.out.println(evaluation.toString());
        System.out.println(evaluation.getConfusionMatrix().toString());

        if (dataOptions.outputPath != null) {
            dataOptions.saveModel(model);
        }
        return model;
    }
}
