package io.cdap.mmds.plugin;

import com.google.common.base.Joiner;
import com.google.common.collect.Sets;
import io.cdap.cdap.api.annotation.Description;
import io.cdap.cdap.api.annotation.Name;
import io.cdap.cdap.api.annotation.Plugin;
import io.cdap.cdap.api.data.format.StructuredRecord;
import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.cdap.api.dataset.lib.FileSet;
import io.cdap.cdap.api.dataset.lib.IndexedTable;
import io.cdap.cdap.api.spark.sql.DataFrames;
import io.cdap.cdap.etl.api.PipelineConfigurer;
import io.cdap.cdap.etl.api.StageConfigurer;
import io.cdap.cdap.etl.api.batch.SparkCompute;
import io.cdap.cdap.etl.api.batch.SparkExecutionPluginContext;
import io.cdap.mmds.Constants;
import io.cdap.mmds.api.AlgorithmType;
import io.cdap.mmds.api.Modeler;
import io.cdap.mmds.data.ModelKey;
import io.cdap.mmds.data.ModelMeta;
import io.cdap.mmds.data.ModelTable;
import io.cdap.mmds.modeler.Modelers;
import io.cdap.mmds.modeler.feature.FeatureGenerator;
import io.cdap.mmds.modeler.feature.FeatureGeneratorPredictor;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import javax.annotation.Nullable;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.PredictionModel;
import org.apache.spark.ml.feature.IndexToString;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.StructType;

/**
 * Spark compute stage that adds a model prediction field to each incoming record.
 *
 * <p>The model is identified by the configured experiment and model IDs. Its metadata is read from the
 * {@code Constants.Dataset.MODEL_META} table, and its stored components (feature generation data, the model
 * itself and, optionally, target indices) are read from the {@code Constants.Dataset.MODEL_COMPONENTS} file set.
 */
@Name("MLPredictor")
@Description("Uses a deployed model to add a prediction field to incoming records.")
@Plugin(type = SparkCompute.PLUGIN_TYPE)
public class MLPredictor extends SparkCompute<StructuredRecord, StructuredRecord> {
    private final PredictorConf conf;

    // The fields below are populated by initialize() and read by transform().
    private String featuregenPath;
    private String modelPath;
    // Null when the model has no stored target indices; it is only needed to map predicted indices back to
    // string labels, and initialize() rejects configurations that need it but lack it.
    @Nullable
    private String targetIndexPath;
    private Schema inputSchema;
    private Schema outputSchema;
    // Type of the output prediction field, with a nullable union unwrapped.
    private Schema.Type predictionType;
    private FeatureGenerator featureGenerator;
    private Modeler modeler;

    public MLPredictor(PredictorConf predictorConf) {
        this.conf = predictorConf;
    }

    /**
     * Validates the configuration against the stage input schema and publishes the resulting output schema.
     *
     * @throws IllegalArgumentException if the input schema is null; other validation failures come from
     *     {@code PredictorConf#validate}
     */
    @Override
    public void configurePipeline(PipelineConfigurer pipelineConfigurer) throws IllegalArgumentException {
        StageConfigurer stageConfigurer = pipelineConfigurer.getStageConfigurer();
        Schema stageInputSchema = stageConfigurer.getInputSchema();
        if (stageInputSchema == null) {
            throw new IllegalArgumentException("ML Predictor cannot be used with a null input schema. "
                + "Please connect it to stages that have a set output schema.");
        }
        conf.validate(stageInputSchema);
        stageConfigurer.setOutputSchema(conf.getOutputSchema());
    }

    /**
     * Resolves the configured model and its stored components, failing fast when anything needed at prediction
     * time is missing or incompatible with the configured prediction field.
     *
     * @throws IllegalArgumentException if the model, its algorithm, any of its feature fields, or a required
     *     component cannot be found, or if a regression model is paired with a string prediction field
     */
    @Override
    public void initialize(SparkExecutionPluginContext context) throws Exception {
        inputSchema = context.getInputSchema();
        conf.validate(inputSchema);
        outputSchema = conf.getOutputSchema();
        Schema predictionSchema = outputSchema.getField(conf.getPredictionField()).getSchema();
        predictionType = (predictionSchema.isNullable() ? predictionSchema.getNonNullable() : predictionSchema)
            .getType();

        ModelTable modelTable = new ModelTable((IndexedTable) context.getDataset(Constants.Dataset.MODEL_META));
        ModelMeta modelMeta = modelTable.get(new ModelKey(conf.getExperimentID(), conf.getModelID()));
        if (modelMeta == null) {
            throw new IllegalArgumentException(String.format("Could not find model '%s' in experiment '%s'.",
                conf.getModelID(), conf.getExperimentID()));
        }

        modeler = Modelers.getModeler(modelMeta.getAlgorithm());
        if (modeler == null) {
            throw new IllegalArgumentException(String.format(
                "Model '%s' in experiment '%s' uses unknown algorithm '%s'",
                conf.getModelID(), conf.getExperimentID(), modelMeta.getAlgorithm()));
        }
        if (modeler.getAlgorithm().getType() == AlgorithmType.REGRESSION && predictionType == Schema.Type.STRING) {
            throw new IllegalArgumentException(String.format(
                "Invalid type for prediction field '%s'. Model '%s' in experiment '%s' is a regression model, "
                    + "which only supports double predictions.",
                conf.getPredictionField(), conf.getModelID(), conf.getExperimentID()));
        }

        checkInputHasFeatures(modelMeta);

        FileSet modelComponents = (FileSet) context.getDataset(Constants.Dataset.MODEL_COMPONENTS);

        featuregenPath = getComponentPath(modelComponents, Constants.Component.FEATUREGEN);
        if (featuregenPath == null) {
            throw new IllegalArgumentException(String.format(
                "Could not find feature generation data for model '%s' in experiment '%s'. Please verify that "
                    + "the same model and model meta datasets used to train the model are used here.",
                conf.getModelID(), conf.getExperimentID()));
        }
        featureGenerator = new FeatureGeneratorPredictor(modelMeta.getFeatures(), modelMeta.getCategoricalFeatures(),
                                                         featuregenPath);

        modelPath = getComponentPath(modelComponents, Constants.Component.MODEL);
        if (modelPath == null) {
            throw new IllegalArgumentException(String.format(
                "Could not find the files for model '%s' in experiment '%s'. Please verify that the model was "
                    + "successfully trained.",
                conf.getModelID(), conf.getExperimentID()));
        }

        targetIndexPath = getComponentPath(modelComponents, Constants.Component.TARGET_INDICES);
        if (targetIndexPath == null && predictsStringLabels()) {
            throw new IllegalArgumentException(String.format(
                "Could not find target index data for model '%s' in experiment '%s'. Please change the "
                    + "prediction field type to double.",
                conf.getModelID(), conf.getExperimentID()));
        }
    }

    /**
     * Converts the input records to a DataFrame, generates model features, runs the model, and converts the
     * result back to records that match the output schema.
     */
    @Override
    public JavaRDD<StructuredRecord> transform(SparkExecutionPluginContext context,
                                               JavaRDD<StructuredRecord> input) throws Exception {
        PredictionModel<?, ?> predictor = modeler.loadPredictor(modelPath);

        StructType rowType = DataFrames.toDataType(inputSchema);
        Dataset<Row> rawData = new SQLContext(context.getSparkContext().sc())
            .createDataFrame(input.map(new RecordToRow(rowType)), rowType);

        Dataset<Row> predictions =
            predictor.transform(featureGenerator.generateFeatures(rawData, getNonFeatureOutputFields()));

        Dataset<Row> withPrediction;
        if (predictsStringLabels()) {
            // The model predicts label indices; map them back to the original string labels.
            withPrediction = new IndexToString()
                .setLabels(StringIndexerModel.load(targetIndexPath).labels())
                .setInputCol(Constants.TRAINER_PREDICTION_FIELD)
                .setOutputCol(conf.getPredictionField())
                .transform(predictions);
        } else {
            withPrediction = predictions.withColumnRenamed(Constants.TRAINER_PREDICTION_FIELD,
                                                           conf.getPredictionField());
        }

        return withPrediction.select(getOutputColumns()).toJavaRDD().map(new RowToRecord(outputSchema));
    }

    /**
     * Returns true when the model is a classifier and the output prediction field is a string, i.e. when
     * predicted indices must be converted back to labels using the stored target indices.
     */
    private boolean predictsStringLabels() {
        return modeler.getAlgorithm().getType() == AlgorithmType.CLASSIFICATION
            && predictionType == Schema.Type.STRING;
    }

    /**
     * Throws if any feature field recorded in the model metadata is absent from the input schema.
     */
    private void checkInputHasFeatures(ModelMeta modelMeta) {
        Set<String> features = new HashSet<>(modelMeta.getFeatures());
        Set<String> inputFields = new HashSet<>();
        for (Schema.Field field : inputSchema.getFields()) {
            inputFields.add(field.getName());
        }
        Set<String> missing = Sets.difference(features, inputFields);
        if (!missing.isEmpty()) {
            throw new IllegalArgumentException(String.format("Input is missing feature fields %s.",
                                                             Joiner.on(',').join(missing)));
        }
    }

    /**
     * Returns, in output-schema order, the names of output fields that are neither the prediction field nor one
     * of the feature generator's features.
     */
    private List<String> getNonFeatureOutputFields() {
        Set<String> features = new HashSet<>(featureGenerator.getFeatures());
        List<String> fields = new ArrayList<>();
        for (Schema.Field field : outputSchema.getFields()) {
            String name = field.getName();
            if (!conf.getPredictionField().equals(name) && !features.contains(name)) {
                fields.add(name);
            }
        }
        return fields;
    }

    /**
     * Builds one column reference per output-schema field, in schema order.
     */
    private Column[] getOutputColumns() {
        List<Schema.Field> fields = outputSchema.getFields();
        Column[] columns = new Column[fields.size()];
        for (int i = 0; i < columns.length; i++) {
            columns[i] = new Column(fields.get(i).getName());
        }
        return columns;
    }

    /**
     * Returns the path of a stored model component located at
     * {@code <experimentID>/<modelID>/<component>} inside the given file set.
     *
     * @return the component path, or {@code null} if no such location exists
     */
    @Nullable
    private String getComponentPath(FileSet modelComponents, String component) throws IOException {
        String experimentId = conf.getExperimentID();
        String modelId = conf.getModelID();
        // Report absence as null so initialize() can raise a descriptive error instead of letting the missing
        // path surface later as a load failure during transform().
        if (!modelComponents.getLocation(experimentId).append(modelId).append(component).exists()) {
            return null;
        }
        return modelComponents.getLocation(experimentId).append(modelId).append(component).toURI().getPath();
    }
}
