package com.databricks.labs.automl;

import com.databricks.labs.automl.params.AutomationOutput;
import com.databricks.labs.automl.params.ConfusionOutput;
import com.databricks.labs.automl.params.DataGeneration;
import com.databricks.labs.automl.params.FeatureImportanceOutput;
import com.databricks.labs.automl.params.FeatureImportancePredictionOutput;
import com.databricks.labs.automl.params.FeatureImportanceReturn;
import com.databricks.labs.automl.params.GenerationalReport;
import com.databricks.labs.automl.params.GenericModelReturn;
import com.databricks.labs.automl.params.PredictionOutput;
import com.databricks.labs.automl.params.RandomForestModelsWithResults;
import com.databricks.labs.automl.params.TreeSplitReport;
import com.databricks.labs.automl.params.TunerOutput;
import com.databricks.labs.automl.reports.DecisionTreeSplits;
import com.databricks.labs.automl.reports.RandomForestFeatureImportance;
import com.databricks.labs.automl.tracking.MLFlowReportStructure;
import java.util.Arrays;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.functions$;
import scala.Array$;
import scala.Predef$;
import scala.Tuple3;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;

/* compiled from: ManualRunner.scala */
@ScalaSignature(bytes = "\u0006\u0001\u001d3AAC\u0006\u0001)!A\u0011\u0004\u0001B\u0001B\u0003%!\u0004C\u0003!\u0001\u0011\u0005\u0011\u0005C\u0003%\u0001\u0011\u0005S\u0005C\u0003*\u0001\u0011\u0005#\u0006C\u0003/\u0001\u0011\u0005s\u0006C\u00034\u0001\u0011\u0005C\u0007C\u00039\u0001\u0011\u0005\u0013\bC\u0003>\u0001\u0011\u0005c\bC\u0003C\u0001\u0011\u00053I\u0001\u0007NC:,\u0018\r\u001c*v]:,'O\u0003\u0002\r\u001b\u00051\u0011-\u001e;p[2T!AD\b\u0002\t1\f'm\u001d\u0006\u0003!E\t!\u0002Z1uC\n\u0014\u0018nY6t\u0015\u0005\u0011\u0012aA2p[\u000e\u00011C\u0001\u0001\u0016!\t1r#D\u0001\f\u0013\tA2B\u0001\tBkR|W.\u0019;j_:\u0014VO\u001c8fe\u0006YA-\u0019;b!\u0006LHn\\1e!\tYb$D\u0001\u001d\u0015\ti2\"\u0001\u0004qCJ\fWn]\u0005\u0003?q\u0011a\u0002R1uC\u001e+g.\u001a:bi&|g.\u0001\u0004=S:LGO\u0010\u000b\u0003E\r\u0002\"A\u0006\u0001\t\u000be\u0011\u0001\u0019\u0001\u000e\u00023\u0015D\b\u000f\\8sK\u001a+\u0017\r^;sK&k\u0007o\u001c:uC:\u001cWm\u001d\u000b\u0002MA\u00111dJ\u0005\u0003Qq\u0011qCR3biV\u0014X-S7q_J$\u0018M\\2f%\u0016$XO\u001d8\u0002\u0007I,h\u000eF\u0001,!\tYB&\u0003\u0002.9\t\u0001\u0012)\u001e;p[\u0006$\u0018n\u001c8PkR\u0004X\u000f^\u0001\u0017O\u0016tWM]1uK\u0012+7-[:j_:\u001c\u0006\u000f\\5ugR\t\u0001\u0007\u0005\u0002\u001cc%\u0011!\u0007\b\u0002\u0010)J,Wm\u00159mSR\u0014V\r]8si\u0006)\"/\u001e8XSRDg)Z1ukJ,7)\u001e7mS:<G#A\u001b\u0011\u0005m1\u0014BA\u001c\u001d\u0005]1U-\u0019;ve\u0016LU\u000e]8si\u0006t7-Z(viB,H/A\u0010sk:4U-\u0019;ve\u0016\u001cU\u000f\u001c7j]\u001e<\u0016\u000e\u001e5Qe\u0016$\u0017n\u0019;j_:$\u0012A\u000f\t\u00037mJ!\u0001\u0010\u000f\u0003C\u0019+\u0017\r^;sK&k\u0007o\u001c:uC:\u001cW\r\u0015:fI&\u001cG/[8o\u001fV$\b/\u001e;\u0002#I,hnV5uQB\u0013X\rZ5di&|g\u000eF\u0001@!\tY\u0002)\u0003\u0002B9\t\u0001\u0002K]3eS\u000e$\u0018n\u001c8PkR\u0004X\u000f^\u0001\u0017eVtw+\u001b;i\u0007>tg-^:j_:\u0014V\r]8siR\tA\t\u0005\u0002\u001c\u000b&\u0011a\t\b\u0002\u0010\u0007>tg-^:j_:|U\u000f\u001e9vi\u0002")
/* loaded from: input_file:com/databricks/labs/automl/ManualRunner.class */
/*
 * AutomationRunner whose stages all read their data, feature fields and model type from a
 * caller-supplied DataGeneration payload.
 *
 * This is a compilable Java rendering of the class compiled from ManualRunner.scala. The
 * decompiled form placed super(...) calls inside anonymous-class instance initializers (not
 * legal Java) and dropped the constructor arguments of the PredictionOutput built in
 * runWithPrediction(); both are restored here. The report wrappers are created in static
 * factories so that, like the decompiled code (which passed a null outer instance), they do
 * not hold a reference to this runner.
 */
public class ManualRunner extends AutomationRunner {
    /** The caller-supplied payload; its data, fields and model type feed every run below. */
    private final DataGeneration dataPayload;

    /**
     * Creates a runner over a prepared payload.
     *
     * @param dataGeneration payload whose {@code data()} is also handed to the base runner
     */
    public ManualRunner(DataGeneration dataGeneration) {
        super(dataGeneration.data());
        this.dataPayload = dataGeneration;
    }

    /**
     * Runs a random-forest feature-importance pass over the payload's fields, applying the
     * cutoff type and value from the main configuration.
     *
     * @return the fitted models/results, the results data set, the field names kept by the
     *     cutoff, and the payload's model type
     */
    @Override // com.databricks.labs.automl.AutomationRunner
    public FeatureImportanceReturn exploreFeatureImportances() {
        Tuple3<RandomForestModelsWithResults, Dataset<Row>, String[]> importances =
                new RandomForestFeatureImportance(
                                this.dataPayload.data(),
                                _featureImportancesConfig(),
                                this.dataPayload.modelType())
                        .setCutoffType(_mainConfig().featureImportanceCutoffType())
                        .setCutoffValue(_mainConfig().featureImportanceCutoffValue())
                        .runFeatureImportances(this.dataPayload.fields());
        return new FeatureImportanceReturn(
                importances._1(), importances._2(), importances._3(), this.dataPayload.modelType());
    }

    /**
     * Tunes models on the payload.
     *
     * @return an output whose model/generation reports delegate to the tuning result
     */
    @Override // com.databricks.labs.automl.AutomationRunner
    public AutomationOutput run() {
        TunerOutput tunerResult = executeTuning(this.dataPayload, executeTuning$default$2());
        return automationOutputOf(tunerResult);
    }

    /**
     * Runs a decision-tree split analysis over the payload's data and fields.
     *
     * @return the split report produced by {@link DecisionTreeSplits}
     */
    @Override // com.databricks.labs.automl.AutomationRunner
    public TreeSplitReport generateDecisionSplits() {
        return new DecisionTreeSplits(
                        this.dataPayload.data(), _treeSplitsConfig(), this.dataPayload.modelType())
                .runTreeSplitAnalysis(this.dataPayload.fields());
    }

    /**
     * Culls features by importance, then runs a fresh {@link AutomationRunner} (sharing this
     * runner's main configuration) on only the kept fields plus the label column.
     *
     * @return the feature-importance data together with the culled run's reports
     */
    @Override // com.databricks.labs.automl.AutomationRunner
    public FeatureImportanceOutput runWithFeatureCulling() {
        FeatureImportanceReturn featureResults = exploreFeatureImportances();
        Dataset<Row> culledData = selectColumns(withLabelColumn(featureResults.fields()));
        AutomationOutput runResults = runnerFor(culledData).run();
        return featureImportanceOutputOf(featureResults.data(), runResults);
    }

    /**
     * Culls features by importance, tunes models on only the kept fields plus the label
     * column, and predicts with the best resulting model.
     *
     * @return the feature-importance data, the predictions and the culled tuning reports
     */
    @Override // com.databricks.labs.automl.AutomationRunner
    public FeatureImportancePredictionOutput runFeatureCullingWithPrediction() {
        FeatureImportanceReturn featureResults = exploreFeatureImportances();
        String[] culledFields = withLabelColumn(featureResults.fields());
        Dataset<Row> culledData = selectColumns(culledFields);
        // NOTE(review): the label column is included in the payload's field list here
        // (unchanged from the original); confirm DataGeneration.fields may contain the label.
        DataGeneration culledPayload =
                new DataGeneration(culledData, culledFields, this.dataPayload.modelType());
        AutomationRunner culledRunner = runnerFor(culledData);
        TunerOutput tunerResult =
                culledRunner.executeTuning(culledPayload, culledRunner.executeTuning$default$2());
        // Prediction is issued through this runner, as in the original implementation.
        Dataset<Row> predictions =
                predictFromBestModel(
                        tunerResult.modelReport(), tunerResult.rawData(), tunerResult.modelSelection());
        return featureImportancePredictionOutputOf(featureResults.data(), predictions, tunerResult);
    }

    /**
     * Tunes models on the payload and predicts with the best one.
     *
     * @return the predictions together with the tuning reports
     */
    @Override // com.databricks.labs.automl.AutomationRunner
    public PredictionOutput runWithPrediction() {
        TunerOutput tunerResult = executeTuning(this.dataPayload, executeTuning$default$2());
        Dataset<Row> predictions =
                predictFromBestModel(
                        tunerResult.modelReport(), tunerResult.rawData(), tunerResult.modelSelection());
        return predictionOutputOf(predictions, tunerResult);
    }

    /**
     * Runs {@link #runWithPrediction()} and counts rows per (prediction, label) pair.
     *
     * @return the predictions, the per-pair {@code count} table and the tuning reports
     */
    @Override // com.databricks.labs.automl.AutomationRunner
    public ConfusionOutput runWithConfusionReport() {
        PredictionOutput predictionPayload = runWithPrediction();
        Dataset<Row> confusionData =
                predictionPayload
                        .dataWithPredictions()
                        .select("prediction", _labelCol())
                        .groupBy("prediction", _labelCol())
                        .agg(functions$.MODULE$.count("*").alias("count"));
        return confusionOutputOf(predictionPayload, confusionData);
    }

    /** Returns a copy of {@code fields} with the configured label column appended. */
    private String[] withLabelColumn(String[] fields) {
        String[] columns = Arrays.copyOf(fields, fields.length + 1);
        columns[fields.length] = _mainConfig().labelCol();
        return columns;
    }

    /** Projects the payload's data onto the named columns, in the given order. */
    private Dataset<Row> selectColumns(String[] columnNames) {
        Column[] columns =
                Arrays.stream(columnNames)
                        .map(name -> functions$.MODULE$.col(name))
                        .toArray(Column[]::new);
        return this.dataPayload.data().select(columns);
    }

    /** Creates a new runner over {@code data} that shares this runner's main configuration. */
    private AutomationRunner runnerFor(Dataset<Row> data) {
        return (AutomationRunner) new AutomationRunner(data).setMainConfig(_mainConfig());
    }

    /** Wraps a tuning result as an {@link AutomationOutput} whose reports delegate to it. */
    private static AutomationOutput automationOutputOf(final TunerOutput tunerResult) {
        return new AutomationOutput(tunerResult.mlFlowOutput()) {
            @Override // com.databricks.labs.automl.params.Output
            public GenericModelReturn[] modelReport() {
                return tunerResult.modelReport();
            }

            @Override // com.databricks.labs.automl.params.Output
            public GenerationalReport[] generationReport() {
                return tunerResult.generationReport();
            }

            @Override // com.databricks.labs.automl.params.Output
            public Dataset<Row> modelReportDataFrame() {
                return tunerResult.modelReportDataFrame();
            }

            @Override // com.databricks.labs.automl.params.Output
            public Dataset<Row> generationReportDataFrame() {
                return tunerResult.generationReportDataFrame();
            }
        };
    }

    /** Pairs feature-importance data with the reports of a run over the culled features. */
    private static FeatureImportanceOutput featureImportanceOutputOf(
            Dataset<Row> importanceData, final AutomationOutput runResults) {
        return new FeatureImportanceOutput(importanceData, runResults.mlFlowOutput()) {
            @Override // com.databricks.labs.automl.params.Output
            public GenericModelReturn[] modelReport() {
                return runResults.modelReport();
            }

            @Override // com.databricks.labs.automl.params.Output
            public GenerationalReport[] generationReport() {
                return runResults.generationReport();
            }

            @Override // com.databricks.labs.automl.params.Output
            public Dataset<Row> modelReportDataFrame() {
                return runResults.modelReportDataFrame();
            }

            @Override // com.databricks.labs.automl.params.Output
            public Dataset<Row> generationReportDataFrame() {
                return runResults.generationReportDataFrame();
            }
        };
    }

    /** Pairs feature-importance data and predictions with the reports of a culled tuning run. */
    private static FeatureImportancePredictionOutput featureImportancePredictionOutputOf(
            Dataset<Row> importanceData, Dataset<Row> predictions, final TunerOutput tunerResult) {
        return new FeatureImportancePredictionOutput(
                importanceData, predictions, tunerResult.mlFlowOutput()) {
            @Override // com.databricks.labs.automl.params.Output
            public GenericModelReturn[] modelReport() {
                return tunerResult.modelReport();
            }

            @Override // com.databricks.labs.automl.params.Output
            public GenerationalReport[] generationReport() {
                return tunerResult.generationReport();
            }

            @Override // com.databricks.labs.automl.params.Output
            public Dataset<Row> modelReportDataFrame() {
                return tunerResult.modelReportDataFrame();
            }

            @Override // com.databricks.labs.automl.params.Output
            public Dataset<Row> generationReportDataFrame() {
                return tunerResult.generationReportDataFrame();
            }
        };
    }

    /**
     * Pairs predictions with the tuning reports. The decompiled version lost these superclass
     * constructor arguments; they are passed explicitly here.
     */
    private static PredictionOutput predictionOutputOf(
            Dataset<Row> predictions, final TunerOutput tunerResult) {
        return new PredictionOutput(predictions, tunerResult.mlFlowOutput()) {
            @Override // com.databricks.labs.automl.params.Output
            public GenericModelReturn[] modelReport() {
                return tunerResult.modelReport();
            }

            @Override // com.databricks.labs.automl.params.Output
            public GenerationalReport[] generationReport() {
                return tunerResult.generationReport();
            }

            @Override // com.databricks.labs.automl.params.Output
            public Dataset<Row> modelReportDataFrame() {
                return tunerResult.modelReportDataFrame();
            }

            @Override // com.databricks.labs.automl.params.Output
            public Dataset<Row> generationReportDataFrame() {
                return tunerResult.generationReportDataFrame();
            }
        };
    }

    /** Pairs a prediction run and its (prediction, label) counts into a confusion report. */
    private static ConfusionOutput confusionOutputOf(
            final PredictionOutput predictionPayload, Dataset<Row> confusionData) {
        return new ConfusionOutput(
                predictionPayload.dataWithPredictions(),
                confusionData,
                predictionPayload.mlFlowOutput()) {
            @Override // com.databricks.labs.automl.params.Output
            public GenericModelReturn[] modelReport() {
                return predictionPayload.modelReport();
            }

            @Override // com.databricks.labs.automl.params.Output
            public GenerationalReport[] generationReport() {
                return predictionPayload.generationReport();
            }

            @Override // com.databricks.labs.automl.params.Output
            public Dataset<Row> modelReportDataFrame() {
                return predictionPayload.modelReportDataFrame();
            }

            @Override // com.databricks.labs.automl.params.Output
            public Dataset<Row> generationReportDataFrame() {
                return predictionPayload.generationReportDataFrame();
            }
        };
    }
}
