package org.tribuo.regression.rtree;

import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import com.oracle.labs.mlrg.olcut.config.UsageException;
import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.IOException;
import java.util.logging.Logger;
import org.tribuo.Dataset;
import org.tribuo.SparseModel;
import org.tribuo.SparseTrainer;
import org.tribuo.data.DataOptions;
import org.tribuo.regression.RegressionFactory;
import org.tribuo.regression.evaluation.RegressionEvaluation;
import org.tribuo.regression.rtree.impurity.MeanAbsoluteError;
import org.tribuo.regression.rtree.impurity.MeanSquaredError;
import org.tribuo.regression.rtree.impurity.RegressorImpurity;
import org.tribuo.util.Util;

/**
 * Command line driver which trains a CART regression tree on a training dataset and
 * evaluates it on a testing dataset.
 * <p>
 * Both a training path and a testing path must be supplied through {@link DataOptions};
 * if either is missing the usage message is logged and no data is loaded.
 */
public class TrainTest {
    private static final Logger logger = Logger.getLogger(TrainTest.class.getName());

    /**
     * The impurity measure used when scoring candidate splits.
     */
    public enum ImpurityType {
        /** Uses {@link MeanSquaredError}. */
        MSE,
        /** Uses {@link MeanAbsoluteError}. */
        MAE
    }

    /**
     * Command line options for {@link TrainTest}.
     */
    public static class RegressionTreeOptions implements Options {
        /** Shared data options (training/testing paths, seed, output path, loading and saving). */
        public DataOptions general;

        @Option(longName = "csv-response-split-char", usage = "Character to split the CSV response on to generate multiple regression dimensions. Defaults to ':'.")
        public char splitChar = ':';

        @Option(charName = 'd', longName = "max-depth", usage = "Maximum depth in the decision tree.")
        public int depth = 6;

        @Option(charName = 'e', longName = "split-fraction", usage = "Fraction of features in split.")
        public float fraction = 1.0f;

        @Option(charName = 'm', longName = "min-child-weight", usage = "Minimum child weight.")
        public float minChildWeight = 5.0f;

        @Option(charName = 'p', longName = "min-impurity-decrease", usage = "Minimum decrease in impurity required in order for the node to be split.")
        public float minImpurityDecrease = 0.0f;

        @Option(charName = 'r', longName = "use-random-split-points", usage = "Whether to choose split points for features at random.")
        public boolean useRandomSplitPoints = false;

        // Only forwarded to the CART_JOINT trainer; ignored for CART_INDEPENDENT.
        @Option(charName = 'n', longName = "normalize", usage = "Normalize the leaf outputs so each leaf sums to 1.0.")
        public boolean normalize = false;

        @Option(charName = 'i', longName = "impurity", usage = "Impurity measure to use. Defaults to MSE.")
        public ImpurityType impurityType = ImpurityType.MSE;

        @Option(charName = 't', longName = "tree-type", usage = "Tree type.")
        public TreeType treeType = TreeType.CART_INDEPENDENT;

        @Option(longName = "print-tree", usage = "Prints the decision tree.")
        public boolean printTree;

        public String getOptionsDescription() {
            return "Trains and tests a CART regression model on the specified datasets.";
        }
    }

    /**
     * The kind of regression tree trainer to use.
     */
    public enum TreeType {
        /** Uses {@link CARTRegressionTrainer}. */
        CART_INDEPENDENT,
        /** Uses {@link CARTJointRegressionTrainer}; the only type that honours the normalize option. */
        CART_JOINT
    }

    /**
     * Parses the command line, trains a regression tree and evaluates it on the test data.
     * <p>
     * The evaluation is printed to standard output, and the model is saved if an output
     * path was configured.
     *
     * @param args The command line arguments.
     * @throws IOException If an I/O error occurs, e.g. while loading data or saving the model.
     */
    public static void main(String[] args) throws IOException {
        LabsLogFormatter.setAllLogFormatters();

        RegressionTreeOptions options = new RegressionTreeOptions();
        ConfigurationManager cm;
        try {
            cm = new ConfigurationManager(args, options);
        } catch (UsageException e) {
            logger.info(e.getMessage());
            return;
        }

        // Validate the required paths before any data is read, so the usage message is
        // shown instead of loading being attempted with a missing path.
        if (options.general.trainingPath == null || options.general.testingPath == null) {
            logger.info(cm.usage());
            return;
        }

        RegressorImpurity impurity = createImpurity(options.impurityType);
        if (impurity == null) {
            return;
        }
        SparseTrainer trainer = createTrainer(options, impurity);
        if (trainer == null) {
            return;
        }

        RegressionFactory factory = new RegressionFactory(options.splitChar);
        Pair data = options.general.load(factory);
        Dataset trainData = (Dataset) data.getA();
        Dataset testData = (Dataset) data.getB();

        logger.info("Training using " + trainer.toString());
        long trainStart = System.currentTimeMillis();
        SparseModel model = trainer.train(trainData);
        logger.info("Finished training regressor " + Util.formatDuration(trainStart, System.currentTimeMillis()));
        if (options.printTree) {
            logger.info(model.toString());
        }
        logger.info("Selected features: " + model.getActiveFeatures());

        long evalStart = System.currentTimeMillis();
        RegressionEvaluation evaluation = factory.getEvaluator().evaluate(model, testData);
        logger.info("Finished evaluating model " + Util.formatDuration(evalStart, System.currentTimeMillis()));
        System.out.println(evaluation.toString());

        if (options.general.outputPath != null) {
            options.general.saveModel(model);
        }
    }

    /**
     * Instantiates the impurity measure selected on the command line.
     *
     * @param type The requested impurity type.
     * @return The impurity measure, or {@code null} (after logging) if the type is not handled.
     */
    private static RegressorImpurity createImpurity(ImpurityType type) {
        switch (type) {
            case MAE:
                return new MeanAbsoluteError();
            case MSE:
                return new MeanSquaredError();
            default:
                logger.severe("unknown impurity type " + type);
                return null;
        }
    }

    /**
     * Instantiates the tree trainer selected on the command line.
     *
     * @param o        The parsed options supplying the tree hyperparameters and seed.
     * @param impurity The impurity measure the trainer should use.
     * @return The trainer, or {@code null} (after logging) if the tree type is not handled.
     */
    private static SparseTrainer createTrainer(RegressionTreeOptions o, RegressorImpurity impurity) {
        switch (o.treeType) {
            case CART_INDEPENDENT:
                return new CARTRegressionTrainer(o.depth, o.minChildWeight, o.minImpurityDecrease, o.fraction, o.useRandomSplitPoints, impurity, o.general.seed);
            case CART_JOINT:
                return new CARTJointRegressionTrainer(o.depth, o.minChildWeight, o.minImpurityDecrease, o.fraction, o.useRandomSplitPoints, impurity, o.normalize, o.general.seed);
            default:
                logger.severe("unknown tree type " + o.treeType);
                return null;
        }
    }
}
