package io.cdap.mmds.modeler.feature;

import io.cdap.mmds.Constants;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/* loaded from: input_file:lib/mmds-model-1.9.0.jar:io/cdap/mmds/modeler/feature/FeatureGeneratorTrainer.class */
/**
 * {@link FeatureGenerator} variant that builds its feature-generation pipeline by fitting
 * it against a dataset: one {@link StringIndexer} per categorical feature followed by a
 * single {@link VectorAssembler} that writes the {@link Constants#FEATURES_FIELD} column.
 */
public class FeatureGeneratorTrainer extends FeatureGenerator {

    /**
     * Forwards both arguments unchanged to the {@link FeatureGenerator} constructor.
     *
     * @param list passed to the superclass; presumably the feature names (confirm against
     *             {@code FeatureGenerator})
     * @param set  passed to the superclass; presumably the names of the categorical
     *             features (confirm against {@code FeatureGenerator})
     */
    public FeatureGeneratorTrainer(List<String> list, Set<String> set) {
        super(list, set);
    }

    /**
     * Builds and fits the feature-generation pipeline.
     *
     * <p>For every feature, in iteration order of {@code this.features}:
     * <ul>
     *   <li>a categorical feature gets a {@link StringIndexer} stage reading
     *       {@code cleanName(feature)} and writing {@code indexedName(feature)}; the indexed
     *       column is what feeds the assembler;</li>
     *   <li>any other feature feeds the assembler directly through its
     *       {@code cleanName(feature)} column.</li>
     * </ul>
     * Each feature contributes exactly one input column to the assembler.
     *
     * @param dataset the data the pipeline is fitted on
     * @return the fitted pipeline model
     */
    @Override
    protected PipelineModel getFeatureGenModel(Dataset<Row> dataset) {
        List<PipelineStage> stages = new ArrayList<>();
        List<String> assemblerInputs = new ArrayList<>();

        for (String feature : this.features) {
            if (isCategorical(feature)) {
                String indexedColumn = indexedName(feature);
                // "skip" is StringIndexer's handleInvalid mode for dropping rows whose
                // value cannot be indexed instead of raising an error.
                stages.add(new StringIndexer()
                        .setInputCol(cleanName(feature))
                        .setOutputCol(indexedColumn)
                        .setHandleInvalid("skip"));
                assemblerInputs.add(indexedColumn);
            } else {
                // Added exactly once: listing the same column twice would duplicate it
                // in the assembled feature vector.
                assemblerInputs.add(cleanName(feature));
            }
        }

        // The assembler must come after every indexer because it reads their output columns.
        stages.add(new VectorAssembler()
                .setInputCols(assemblerInputs.toArray(new String[0]))
                .setOutputCol(Constants.FEATURES_FIELD));

        return new Pipeline()
                .setStages(stages.toArray(new PipelineStage[0]))
                .fit(dataset);
    }
}
