package io.cdap.mmds.modeler.train;

import com.google.common.collect.ImmutableMap;
import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.mmds.Constants;
import io.cdap.mmds.api.AlgorithmType;
import io.cdap.mmds.api.Modeler;
import io.cdap.mmds.data.EvaluationMetrics;
import io.cdap.mmds.data.ModelTrainerInfo;
import io.cdap.mmds.modeler.Modelers;
import io.cdap.mmds.modeler.feature.FeatureGeneratorTrainer;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.PredictionModel;
import org.apache.spark.ml.Predictor;
import org.apache.spark.ml.feature.IndexToString;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.ml.util.MLWritable;
import org.apache.spark.mllib.evaluation.MulticlassMetrics;
import org.apache.spark.mllib.evaluation.RegressionMetrics;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.DataTypes;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;

/* loaded from: input_file:lib/mmds-model-1.10.2.jar:io/cdap/mmds/modeler/train/ModelTrainer.class */
public class ModelTrainer {
    private static final Logger LOG = LoggerFactory.getLogger(ModelTrainer.class.getName());
    private final String algorithm;
    private final String outcomeField;
    private final Schema.Type outcomeType;
    private final Map<String, String> trainingParams;
    private final List<String> featureNames = new ArrayList();
    private final Set<String> categoricalFeatures = new HashSet();
    private final Schema schema;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:lib/mmds-model-1.10.2.jar:io/cdap/mmds/modeler/train/ModelTrainer$PredictionLabelFunction.class */
    public static class PredictionLabelFunction implements Function<Row, Tuple2> {
        private PredictionLabelFunction() {
        }

        public Tuple2 call(Row row) throws Exception {
            return new Tuple2(row.get(0), row.get(1));
        }
    }

    public ModelTrainer(ModelTrainerInfo modelTrainerInfo) {
        this.algorithm = modelTrainerInfo.getModel().getAlgorithm();
        this.trainingParams = ImmutableMap.copyOf((Map) modelTrainerInfo.getModel().getHyperparameters());
        this.schema = modelTrainerInfo.getDataSplitStats().getSchema();
        this.outcomeField = modelTrainerInfo.getExperiment().getOutcome();
        this.outcomeType = Schema.Type.valueOf(modelTrainerInfo.getExperiment().getOutcomeType().toUpperCase());
        for (Schema.Field field : this.schema.getFields()) {
            String name = field.getName();
            if (!name.equals(this.outcomeField)) {
                this.featureNames.add(name);
                Schema schema = field.getSchema();
                if (isCategorical(schema.isNullable() ? schema.getNonNullable().getType() : schema.getType())) {
                    this.categoricalFeatures.add(name);
                }
            }
        }
    }

    private boolean isCategorical(Schema.Type type) {
        return type == Schema.Type.STRING || type == Schema.Type.BOOLEAN;
    }

    public ModelOutput train(Dataset<Row> dataset, Dataset<Row> dataset2) throws IOException {
        EvaluationMetrics evaluationMetrics;
        Dataset<Row> drop = dataset.na().drop(new String[]{this.outcomeField});
        Dataset<Row> drop2 = dataset2.na().drop(new String[]{this.outcomeField});
        LOG.info("Generating features for training and test data.");
        FeatureGeneratorTrainer featureGeneratorTrainer = new FeatureGeneratorTrainer(this.featureNames, this.categoricalFeatures);
        Dataset<Row> generateFeatures = featureGeneratorTrainer.generateFeatures(drop, this.outcomeField);
        LOG.info("Training features successfully generated.");
        Dataset<Row> generateFeatures2 = featureGeneratorTrainer.generateFeatures(drop2, this.outcomeField);
        LOG.info("Test features successfully generated.");
        String str = this.outcomeField;
        StringIndexerModel stringIndexerModel = null;
        boolean isCategorical = isCategorical(this.outcomeType);
        String str2 = Constants.TRAINER_PREDICTION_FIELD;
        if (isCategorical) {
            String str3 = this.outcomeField;
            if (this.outcomeType == Schema.Type.BOOLEAN) {
                str3 = "_c_" + this.outcomeField;
                Column cast = new Column(this.outcomeField).cast(DataTypes.StringType);
                generateFeatures = generateFeatures.withColumn(str3, cast);
                generateFeatures2 = generateFeatures2.withColumn(str3, cast);
            }
            str = "_t_" + this.outcomeField;
            stringIndexerModel = new StringIndexer().setInputCol(str3).setOutputCol(str).fit(generateFeatures);
            generateFeatures = stringIndexerModel.transform(generateFeatures);
            generateFeatures2 = stringIndexerModel.transform(generateFeatures2);
            str2 = "_n_" + str2;
        }
        Modeler modeler = Modelers.getModeler(this.algorithm);
        Predictor createPredictor = modeler.createPredictor(this.trainingParams);
        createPredictor.setLabelCol(str);
        createPredictor.setFeaturesCol(Constants.FEATURES_FIELD);
        createPredictor.setPredictionCol(str2);
        LOG.info("Training model...");
        MLWritable fit = createPredictor.fit(generateFeatures);
        LOG.info("Model successfully trained.");
        LOG.info("Generating predictions on test data.");
        Dataset transform = fit.transform(generateFeatures2);
        LOG.info("Predictions successfully generated.");
        if (isCategorical(this.outcomeType)) {
            transform = new IndexToString().setLabels(stringIndexerModel.labels()).setInputCol(str2).setOutputCol(Constants.TRAINER_PREDICTION_FIELD).transform(transform);
        }
        LOG.info("Calculating evaluation metrics...");
        RDD rdd = transform.select(new Column[]{new Column(str2), new Column(str).cast(DataTypes.DoubleType)}).toJavaRDD().map(new PredictionLabelFunction()).rdd();
        try {
            if (modeler.getAlgorithm().getType() == AlgorithmType.REGRESSION) {
                RegressionMetrics regressionMetrics = new RegressionMetrics(rdd, false);
                double rootMeanSquaredError = regressionMetrics.rootMeanSquaredError();
                double r2 = regressionMetrics.r2();
                double meanAbsoluteError = regressionMetrics.meanAbsoluteError();
                double explainedVariance = regressionMetrics.explainedVariance();
                LOG.info("root mean squared error = {}, r2 = {}, mean absolute error = {}, explained variance = {}", new Object[]{Double.valueOf(rootMeanSquaredError), Double.valueOf(r2), Double.valueOf(meanAbsoluteError), Double.valueOf(explainedVariance)});
                evaluationMetrics = new EvaluationMetrics(rootMeanSquaredError, r2, explainedVariance, meanAbsoluteError);
            } else {
                MulticlassMetrics multiclassMetrics = new MulticlassMetrics(rdd);
                double weightedPrecision = multiclassMetrics.weightedPrecision();
                double weightedRecall = multiclassMetrics.weightedRecall();
                double weightedFMeasure = multiclassMetrics.weightedFMeasure();
                LOG.info("precision = {}, recall = {}, f1 = {}", new Object[]{Double.valueOf(weightedPrecision), Double.valueOf(weightedRecall), Double.valueOf(weightedFMeasure)});
                evaluationMetrics = new EvaluationMetrics(weightedPrecision, weightedRecall, weightedFMeasure);
            }
            Column[] columnArr = new Column[this.schema.getFields().size() + 1];
            columnArr[0] = new Column(Constants.TRAINER_PREDICTION_FIELD);
            int i = 1;
            Iterator it = this.schema.getFields().iterator();
            while (it.hasNext()) {
                columnArr[i] = new Column(((Schema.Field) it.next()).getName());
                i++;
            }
            return ModelOutput.builder().setTargetIndexModel(stringIndexerModel).setFeatureGenModel(featureGeneratorTrainer.getFeatureGenModel()).setModel(fit).setEvaluationMetrics(evaluationMetrics).setFeatureNames(this.featureNames).setCategoricalFeatures(this.categoricalFeatures).setPredictions(transform.select(columnArr)).setAlgorithmType(modeler.getAlgorithm().getType()).setSchema(this.schema).build();
        } catch (IllegalArgumentException e) {
            throw new RuntimeException("Failed to get evaluation metrics for the model. Please check the logs for warnings or errors related to training problems. If there were training problems, please check that your features are the correct type. String features should represent categories, with multiple records for each category value. For example, an ID should not be used as a feature, as there is a unique value for each record.", e);
        }
    }
}
