package org.tribuo.regression.xgboost;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.logging.Level;
import java.util.logging.Logger;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.IEvaluation;
import ml.dmlc.xgboost4j.java.IObjective;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.Model;
import org.tribuo.common.xgboost.XGBoostModel;
import org.tribuo.common.xgboost.XGBoostTrainer;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;
import org.tribuo.regression.ImmutableRegressionInfo;
import org.tribuo.regression.Regressor;

/* loaded from: input_file:org/tribuo/regression/xgboost/XGBoostRegressionTrainer.class */
public final class XGBoostRegressionTrainer extends XGBoostTrainer<Regressor> {
    private static final Logger logger = Logger.getLogger(XGBoostRegressionTrainer.class.getName());

    @Config(description = "The type of regression.")
    private RegressionType rType;

    /* loaded from: input_file:org/tribuo/regression/xgboost/XGBoostRegressionTrainer$RegressionType.class */
    public enum RegressionType {
        LINEAR("reg:squarederror"),
        GAMMA("reg:gamma"),
        TWEEDIE("reg:tweedie"),
        PSEUDOHUBER("reg:pseudohubererror");

        public final String paramName;

        RegressionType(String str) {
            this.paramName = str;
        }
    }

    public XGBoostRegressionTrainer(int i) {
        this(RegressionType.LINEAR, i);
    }

    public XGBoostRegressionTrainer(RegressionType regressionType, int i) {
        this(regressionType, i, 0.3d, 0.0d, 6, 1.0d, 1.0d, 1.0d, 1.0d, 0.0d, 4, true, 12345L);
    }

    public XGBoostRegressionTrainer(RegressionType regressionType, int i, int i2, boolean z) {
        this(regressionType, i, 0.3d, 0.0d, 6, 1.0d, 1.0d, 1.0d, 1.0d, 0.0d, i2, z, 12345L);
    }

    public XGBoostRegressionTrainer(RegressionType regressionType, int i, double d, double d2, int i2, double d3, double d4, double d5, double d6, double d7, int i3, boolean z, long j) {
        super(i, d, d2, i2, d3, d4, d5, d6, d7, i3, z, j);
        this.rType = RegressionType.LINEAR;
        this.rType = regressionType;
        postConfig();
    }

    public XGBoostRegressionTrainer(XGBoostTrainer.BoosterType boosterType, XGBoostTrainer.TreeMethod treeMethod, RegressionType regressionType, int i, double d, double d2, int i2, double d3, double d4, double d5, double d6, double d7, int i3, XGBoostTrainer.LoggingVerbosity loggingVerbosity, long j) {
        super(boosterType, treeMethod, i, d, d2, i2, d3, d4, d5, d6, d7, i3, loggingVerbosity, j);
        this.rType = RegressionType.LINEAR;
        this.rType = regressionType;
        postConfig();
    }

    public XGBoostRegressionTrainer(RegressionType regressionType, int i, Map<String, Object> map) {
        super(i, map);
        this.rType = RegressionType.LINEAR;
        this.rType = regressionType;
        postConfig();
    }

    private XGBoostRegressionTrainer() {
        this.rType = RegressionType.LINEAR;
    }

    public void postConfig() {
        super.postConfig();
        this.parameters.put("objective", this.rType.paramName);
    }

    public synchronized XGBoostModel<Regressor> train(Dataset<Regressor> dataset, Map<String, Provenance> map) {
        return train(dataset, map, -1);
    }

    public synchronized XGBoostModel<Regressor> train(Dataset<Regressor> dataset, Map<String, Provenance> map, int i) {
        if (dataset.getOutputInfo().getUnknownCount() > 0) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }
        ImmutableFeatureMap featureIDMap = dataset.getFeatureIDMap();
        ImmutableRegressionInfo outputIDInfo = dataset.getOutputIDInfo();
        int size = outputIDInfo.size();
        if (i != -1) {
            setInvocationCount(i);
        }
        TrainerProvenance m4getProvenance = m4getProvenance();
        this.trainInvocationCounter++;
        ArrayList arrayList = new ArrayList();
        try {
            XGBoostTrainer.DMatrixTuple convertExamples = convertExamples(dataset, featureIDMap, null);
            int[] naturalOrderToIDMapping = outputIDInfo.getNaturalOrderToIDMapping();
            float[][] fArr = new float[size][dataset.size()];
            float[] fArr2 = new float[dataset.size()];
            int i2 = 0;
            Iterator it = dataset.iterator();
            while (it.hasNext()) {
                Example example = (Example) it.next();
                fArr2[i2] = example.getWeight();
                double[] values = example.getOutput().getValues();
                for (int i3 = 0; i3 < size; i3++) {
                    fArr[naturalOrderToIDMapping[i3]][i2] = (float) values[i3];
                }
                i2++;
            }
            convertExamples.data.setWeight(fArr2);
            for (int i4 = 0; i4 < size; i4++) {
                convertExamples.data.setLabel(fArr[i4]);
                arrayList.add(XGBoost.train(convertExamples.data, this.parameters, this.numTrees, Collections.emptyMap(), (IObjective) null, (IEvaluation) null));
            }
            return createModel("xgboost-regression-model", new ModelProvenance(XGBoostModel.class.getName(), OffsetDateTime.now(), dataset.getProvenance(), m4getProvenance, map), featureIDMap, outputIDInfo, arrayList, new XGBoostRegressionConverter());
        } catch (XGBoostError e) {
            logger.log(Level.SEVERE, "XGBoost threw an error", e);
            throw new IllegalStateException(e);
        }
    }

    /* renamed from: getProvenance, reason: merged with bridge method [inline-methods] */
    public TrainerProvenance m4getProvenance() {
        return new TrainerProvenanceImpl(this);
    }

    /* renamed from: train, reason: collision with other method in class */
    public /* bridge */ /* synthetic */ Model m2train(Dataset dataset, Map map, int i) {
        return train((Dataset<Regressor>) dataset, (Map<String, Provenance>) map, i);
    }

    /* renamed from: train, reason: collision with other method in class */
    public /* bridge */ /* synthetic */ Model m3train(Dataset dataset, Map map) {
        return train((Dataset<Regressor>) dataset, (Map<String, Provenance>) map);
    }
}
