package com.databricks.labs.automl.exploration.analysis.trees.extractors;

import com.databricks.labs.automl.exploration.analysis.common.AnalysisUtilities$;
import com.databricks.labs.automl.exploration.analysis.common.structures.NoParam$;
import com.databricks.labs.automl.exploration.analysis.common.structures.ParamWrapper;
import com.databricks.labs.automl.exploration.analysis.common.structures.TreeModelExtractor$DecisionTreeClassifierExtractor$;
import com.databricks.labs.automl.exploration.analysis.common.structures.TreeModelExtractor$DecisionTreeRegressorExtractor$;
import com.databricks.labs.automl.exploration.analysis.common.structures.TreeModelExtractor$GBTClassifierExtractor$;
import com.databricks.labs.automl.exploration.analysis.common.structures.TreeModelExtractor$GBTRegressorExtractor$;
import com.databricks.labs.automl.exploration.analysis.common.structures.TreeModelExtractor$RandomForestClassificationExtractor$;
import com.databricks.labs.automl.exploration.analysis.common.structures.TreeModelExtractor$RandomForestRegressorExtractor$;
import com.databricks.labs.automl.exploration.analysis.common.structures.VisualizationOutput;
import com.databricks.labs.automl.exploration.analysis.trees.scripts.HTMLGenerators$;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.classification.DecisionTreeClassificationModel;
import org.apache.spark.ml.classification.GBTClassificationModel;
import org.apache.spark.ml.classification.RandomForestClassificationModel;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
import org.apache.spark.ml.regression.GBTRegressionModel;
import org.apache.spark.ml.regression.RandomForestRegressionModel;
import scala.Array$;
import scala.Predef$;
import scala.Some;
import scala.Tuple2;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ClassTag$;

/* compiled from: VisualizationExtractor.scala */
/* loaded from: input_file:com/databricks/labs/automl/exploration/analysis/trees/extractors/VisualizationExtractor$.class */
/**
 * Java view of the Scala singleton object {@code VisualizationExtractor}. It builds D3 tree
 * visualizations for Spark ML tree-based models, either from a bare fitted model or from a fitted
 * {@link PipelineModel}.
 *
 * <p>NOTE(review): this file is decompiler output of VisualizationExtractor.scala; any change made
 * here should be mirrored in the Scala source.
 */
public final class VisualizationExtractor$ {
    /** The singleton instance; assigned by the private constructor during class initialization. */
    public static VisualizationExtractor$ MODULE$;

    static {
        new VisualizationExtractor$();
    }

    /**
     * Builds one visualization per entry of the tree data extracted from {@code t}.
     *
     * <p>Feature names come from the supplied VectorAssembler's input columns when one is present
     * and non-null. Otherwise they come from {@code paramWrapper2}.
     *
     * @param t the fitted model; must be a DecisionTreeRegressionModel,
     *     DecisionTreeClassificationModel, RandomForestRegressionModel,
     *     RandomForestClassificationModel, GBTRegressionModel or GBTClassificationModel
     * @param str passed unchanged as the second argument of
     *     {@code HTMLGenerators.createD3TreeVisualization} (meaning defined there — TODO confirm)
     * @param paramWrapper optional VectorAssembler used to build the model's feature vector
     * @param paramWrapper2 optional feature names, used when no VectorAssembler is available
     * @return one VisualizationOutput per extracted entry, tagged with its zero-based position
     * @throws IllegalArgumentException if neither {@code paramWrapper} nor {@code paramWrapper2}
     *     is defined (raised by Scala's {@code Predef.require})
     * @throws UnsupportedOperationException if {@code t} is not one of the supported model types
     */
    public <T> VisualizationOutput[] extractModelVisualizationDataFromModel(T t, String str, ParamWrapper<VectorAssembler> paramWrapper, ParamWrapper<String[]> paramWrapper2) {
        Predef$.MODULE$.require(
                paramWrapper.mo73asOption().isDefined() || paramWrapper2.mo73asOption().isDefined(),
                // The two sentence fragments previously lacked a separating space ("fieldnames").
                () -> "Either the VectorAssembler used to build the model to test must be supplied or an Array of field "
                        + "names from the Vector must be supplied.");

        // Prefer the assembler's input columns; fall back to the explicit name array. The require
        // above guarantees paramWrapper2 is defined whenever paramWrapper is not, but a Some(null)
        // assembler with no explicit names still reaches the fallback's get().
        Some mo73asOption = paramWrapper.mo73asOption();
        VectorAssembler vectorAssembler = (mo73asOption instanceof Some) ? (VectorAssembler) mo73asOption.value() : null;
        String[] inputCols = (vectorAssembler != null)
                ? vectorAssembler.getInputCols()
                : (String[]) paramWrapper2.mo73asOption().get();

        // Each supported model type is paired with its matching TreeModelExtractor strategy.
        String[] visualizationData;
        if (t instanceof DecisionTreeRegressionModel) {
            visualizationData = new TreeExtractor((DecisionTreeRegressionModel) t, TreeModelExtractor$DecisionTreeRegressorExtractor$.MODULE$).getVisualizationData(inputCols);
        } else if (t instanceof DecisionTreeClassificationModel) {
            visualizationData = new TreeExtractor((DecisionTreeClassificationModel) t, TreeModelExtractor$DecisionTreeClassifierExtractor$.MODULE$).getVisualizationData(inputCols);
        } else if (t instanceof RandomForestRegressionModel) {
            visualizationData = new TreeExtractor((RandomForestRegressionModel) t, TreeModelExtractor$RandomForestRegressorExtractor$.MODULE$).getVisualizationData(inputCols);
        } else if (t instanceof RandomForestClassificationModel) {
            visualizationData = new TreeExtractor((RandomForestClassificationModel) t, TreeModelExtractor$RandomForestClassificationExtractor$.MODULE$).getVisualizationData(inputCols);
        } else if (t instanceof GBTRegressionModel) {
            visualizationData = new TreeExtractor((GBTRegressionModel) t, TreeModelExtractor$GBTRegressorExtractor$.MODULE$).getVisualizationData(inputCols);
        } else if (t instanceof GBTClassificationModel) {
            visualizationData = new TreeExtractor((GBTClassificationModel) t, TreeModelExtractor$GBTClassifierExtractor$.MODULE$).getVisualizationData(inputCols);
        } else {
            throw new UnsupportedOperationException("The model supplied is not supported.");
        }

        // zipWithIndex pairs each entry with its position; map turns each pair into a
        // VisualizationOutput(index, html).
        // NOTE(review): extractModelData(t) depends only on t but is recomputed for every entry;
        // it could be hoisted out of the lambda once its return type is confirmed.
        return (VisualizationOutput[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(visualizationData)).zipWithIndex(Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class))))).map(tuple2 -> {
            return new VisualizationOutput(tuple2._2$mcI$sp(), HTMLGenerators$.MODULE$.createD3TreeVisualization((String) tuple2._1(), str, ModelConfigExtractor$.MODULE$.extractModelData(t)));
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(VisualizationOutput.class)));
    }

    /** Scala default-argument accessor for the third parameter: no VectorAssembler ({@code NoParam}). */
    public <T> ParamWrapper<VectorAssembler> extractModelVisualizationDataFromModel$default$3() {
        return NoParam$.MODULE$;
    }

    /** Scala default-argument accessor for the fourth parameter: no field names ({@code NoParam}). */
    public <T> ParamWrapper<String[]> extractModelVisualizationDataFromModel$default$4() {
        return NoParam$.MODULE$;
    }

    /**
     * Builds one visualization per entry of the tree data extracted from a fitted pipeline.
     *
     * <p>The model configuration shown in every visualization comes from the last element returned
     * by {@code AnalysisUtilities.getModelFromPipeline(pipelineModel)}.
     *
     * @param pipelineModel the fitted pipeline to visualize
     * @param str passed unchanged as the second argument of
     *     {@code HTMLGenerators.createD3TreeVisualization} (meaning defined there — TODO confirm)
     * @return one VisualizationOutput per extracted entry, tagged with its zero-based position
     */
    public VisualizationOutput[] extractModelVisualizationDataFromPipeline(PipelineModel pipelineModel, String str) {
        // NOTE(review): last() will fail if getModelFromPipeline returns an empty array — confirm
        // whether that can happen for a valid pipeline.
        Object last = Predef$.MODULE$.genericArrayOps(AnalysisUtilities$.MODULE$.getModelFromPipeline(pipelineModel)).last();
        return (VisualizationOutput[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(new PipelineExtractor(pipelineModel).getVisualizationData())).zipWithIndex(Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class))))).map(tuple2 -> {
            return new VisualizationOutput(tuple2._2$mcI$sp(), HTMLGenerators$.MODULE$.createD3TreeVisualization((String) tuple2._1(), str, ModelConfigExtractor$.MODULE$.extractModelData(last)));
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(VisualizationOutput.class)));
    }

    /** Registers this instance as the singleton; invoked once from the static initializer. */
    private VisualizationExtractor$() {
        MODULE$ = this;
    }
}
