package org.tribuo.classification.experiments;

import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import com.oracle.labs.mlrg.olcut.config.UsageException;
import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.nio.charset.StandardCharsets;
import java.nio.file.Paths;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.tribuo.Dataset;
import org.tribuo.Model;
import org.tribuo.Trainer;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.classification.evaluation.LabelEvaluation;
import org.tribuo.classification.evaluation.LabelEvaluator;
import org.tribuo.data.DataOptions;

/* loaded from: input_file:org/tribuo/classification/experiments/RunAll.class */
public class RunAll {
    private static final Logger logger = Logger.getLogger(RunAll.class.getName());

    /* loaded from: input_file:org/tribuo/classification/experiments/RunAll$RunAllOptions.class */
    public static class RunAllOptions implements Options {
        public DataOptions general;

        @Option(charName = 'd', longName = "output-directory", usage = "Directory to write out the models and test reports.")
        public File directory;

        @Option(longName = "write-protobuf-models", usage = "Write out models in protobuf format.")
        public boolean protobuf;

        public String getOptionsDescription() {
            return "Performs the same training and test experiment on all Trainers in the supplied configuration file.";
        }
    }

    public static void main(String[] strArr) throws IOException {
        LabsLogFormatter.setAllLogFormatters();
        RunAllOptions runAllOptions = new RunAllOptions();
        try {
            ConfigurationManager configurationManager = new ConfigurationManager(strArr, runAllOptions);
            if (runAllOptions.general.trainingPath == null || runAllOptions.general.testingPath == null || runAllOptions.directory == null) {
                logger.info(configurationManager.usage());
                System.exit(1);
            }
            Pair pair = null;
            try {
                pair = runAllOptions.general.load(new LabelFactory());
            } catch (IOException e) {
                logger.log(Level.SEVERE, "Failed to load data", (Throwable) e);
                System.exit(1);
            }
            Dataset dataset = (Dataset) pair.getA();
            Dataset dataset2 = (Dataset) pair.getB();
            logger.info("Creating directory - " + runAllOptions.directory.toString());
            if (!runAllOptions.directory.exists() && !runAllOptions.directory.mkdirs()) {
                logger.warning("Failed to create directory.");
            }
            HashMap hashMap = new HashMap();
            for (Trainer trainer : configurationManager.lookupAll(Trainer.class)) {
                String simpleName = trainer.getClass().getSimpleName();
                logger.info("Training model using " + trainer.toString());
                Model train = trainer.train(dataset);
                LabelEvaluation evaluate = new LabelEvaluator().evaluate(train, dataset2);
                if (((Double) hashMap.put(simpleName, Double.valueOf(evaluate.microAveragedF1()))) != null) {
                    logger.info("Found two trainers with the name " + simpleName);
                }
                String str = runAllOptions.directory.toString() + "/" + simpleName;
                if (runAllOptions.protobuf) {
                    train.serializeToFile(Paths.get(str + ".model", new String[0]));
                } else {
                    ObjectOutputStream objectOutputStream = new ObjectOutputStream(new FileOutputStream(str + ".model"));
                    try {
                        objectOutputStream.writeObject(train);
                        objectOutputStream.close();
                    } catch (Throwable th) {
                        try {
                            objectOutputStream.close();
                        } catch (Throwable th2) {
                            th.addSuppressed(th2);
                        }
                        throw th;
                    }
                }
                PrintWriter printWriter = new PrintWriter(new OutputStreamWriter(new FileOutputStream(str + ".output"), StandardCharsets.UTF_8));
                try {
                    printWriter.println("Model = " + simpleName);
                    printWriter.println("Provenance = " + train.toString());
                    printWriter.println();
                    printWriter.println("ConfusionMatrix:\n" + evaluate.getConfusionMatrix().toString());
                    printWriter.println();
                    printWriter.println("Evaluation:\n" + evaluate.toString());
                    printWriter.close();
                } catch (Throwable th3) {
                    try {
                        printWriter.close();
                    } catch (Throwable th4) {
                        th3.addSuppressed(th4);
                    }
                    throw th3;
                }
            }
            for (Map.Entry entry : hashMap.entrySet()) {
                logger.info("Trainer = " + ((String) entry.getKey()) + ", F1 = " + entry.getValue());
            }
        } catch (UsageException e2) {
            logger.info(e2.getMessage());
        }
    }
}
