package com.databricks.labs.automl.exploration.analysis.trees.extractors;

import com.databricks.labs.automl.exploration.analysis.common.AnalysisUtilities$;
import com.databricks.labs.automl.exploration.analysis.common.encoders.HierarchicalEncoding$;
import com.databricks.labs.automl.exploration.analysis.common.structures.ExtractorType;
import com.databricks.labs.automl.exploration.analysis.common.structures.PayloadDetermination$;
import com.databricks.labs.automl.exploration.analysis.common.structures.PipelineNodeData;
import com.databricks.labs.automl.exploration.analysis.common.structures.PipelineReport;
import com.databricks.labs.automl.exploration.analysis.common.structures.StringIndexerMappings;
import com.databricks.labs.automl.exploration.analysis.common.structures.TreeModelExtractor$DecisionTreeClassifierExtractor$;
import com.databricks.labs.automl.exploration.analysis.common.structures.TreeModelExtractor$DecisionTreeRegressorExtractor$;
import com.databricks.labs.automl.exploration.analysis.common.structures.TreeModelExtractor$GBTClassifierExtractor$;
import com.databricks.labs.automl.exploration.analysis.common.structures.TreeModelExtractor$GBTRegressorExtractor$;
import com.databricks.labs.automl.exploration.analysis.common.structures.TreeModelExtractor$RandomForestClassificationExtractor$;
import com.databricks.labs.automl.exploration.analysis.common.structures.TreeModelExtractor$RandomForestRegressorExtractor$;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.classification.DecisionTreeClassificationModel;
import org.apache.spark.ml.classification.GBTClassificationModel;
import org.apache.spark.ml.classification.RandomForestClassificationModel;
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
import org.apache.spark.ml.regression.GBTRegressionModel;
import org.apache.spark.ml.regression.RandomForestRegressionModel;
import org.apache.spark.ml.tree.Node;
import scala.Array$;
import scala.Enumeration;
import scala.MatchError;
import scala.Option;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.Some;
import scala.Tuple2;
import scala.collection.immutable.Map;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;

/* compiled from: PipelineExtractor.scala */
@ScalaSignature(bytes = "\u0006\u0001\u00114Q\u0001C\u0005\u0001\u001beA\u0001\u0002\t\u0001\u0003\u0002\u0003\u0006IA\t\u0005\u0006]\u0001!\ta\f\u0005\u0006g\u0001!I\u0001\u000e\u0005\u0006\u0011\u0002!I!\u0013\u0005\u0006\u0017\u0002!\t\u0001\u0014\u0005\u0006-\u0002!\ta\u0016\u0005\u0006C\u0002!\tA\u0019\u0002\u0012!&\u0004X\r\\5oK\u0016CHO]1di>\u0014(B\u0001\u0006\f\u0003))\u0007\u0010\u001e:bGR|'o\u001d\u0006\u0003\u00195\tQ\u0001\u001e:fKNT!AD\b\u0002\u0011\u0005t\u0017\r\\=tSNT!\u0001E\t\u0002\u0017\u0015D\b\u000f\\8sCRLwN\u001c\u0006\u0003%M\ta!Y;u_6d'B\u0001\u000b\u0016\u0003\u0011a\u0017MY:\u000b\u0005Y9\u0012A\u00033bi\u0006\u0014'/[2lg*\t\u0001$A\u0002d_6\u001c\"\u0001\u0001\u000e\u0011\u0005mqR\"\u0001\u000f\u000b\u0003u\tQa]2bY\u0006L!a\b\u000f\u0003\r\u0005s\u0017PU3g\u0003!\u0001\u0018\u000e]3mS:,7\u0001\u0001\t\u0003G1j\u0011\u0001\n\u0006\u0003K\u0019\n!!\u001c7\u000b\u0005\u001dB\u0013!B:qCJ\\'BA\u0015+\u0003\u0019\t\u0007/Y2iK*\t1&A\u0002pe\u001eL!!\f\u0013\u0003\u001bAK\u0007/\u001a7j]\u0016lu\u000eZ3m\u0003\u0019a\u0014N\\5u}Q\u0011\u0001G\r\t\u0003c\u0001i\u0011!\u0003\u0005\u0006A\t\u0001\rAI\u0001\u0017O\u0016$8\u000b\u001e:j]\u001eLe\u000eZ3yKJd\u0015MY3mgR\u0011Qg\u0012\t\u0005mu\u00025I\u0004\u00028wA\u0011\u0001\bH\u0007\u0002s)\u0011!(I\u0001\u0007yI|w\u000e\u001e 
\n\u0005qb\u0012A\u0002)sK\u0012,g-\u0003\u0002?\u007f\t\u0019Q*\u00199\u000b\u0005qb\u0002C\u0001\u001cB\u0013\t\u0011uH\u0001\u0004TiJLgn\u001a\t\u0005mu\"\u0005\t\u0005\u0002\u001c\u000b&\u0011a\t\b\u0002\u0004\u0013:$\b\"\u0002\u0011\u0004\u0001\u0004\u0011\u0013A\u0006:fg>dg/Z%oI\u0016DXM]'baBLgnZ:\u0016\u0003)\u0003BAN\u001fE\u0007\u0006yQ\r\u001f;sC\u000e$(k\\8u\u001d>$W-F\u0001N!\rYb\nU\u0005\u0003\u001fr\u0011Q!\u0011:sCf\u0004\"!\u0015+\u000e\u0003IS!a\u0015\u0013\u0002\tQ\u0014X-Z\u0005\u0003+J\u0013AAT8eK\u0006QR\r\u001f;sC\u000e$\b+\u001b9fY&tW-\u00138g_Jl\u0017\r^5p]V\t\u0001\fE\u0002\u001c\u001df\u0003\"AW0\u000e\u0003mS!\u0001X/\u0002\u0015M$(/^2ukJ,7O\u0003\u0002_\u001b\u000511m\\7n_:L!\u0001Y.\u0003\u001dAK\u0007/\u001a7j]\u0016\u0014V\r]8si\u0006!r-\u001a;WSN,\u0018\r\\5{CRLwN\u001c#bi\u0006,\u0012a\u0019\t\u000479\u0003\u0005")
/* loaded from: input_file:com/databricks/labs/automl/exploration/analysis/trees/extractors/PipelineExtractor.class */
/**
 * Extracts decision-tree structure and per-tree rule reports from a fitted Spark ML
 * {@link PipelineModel}.
 *
 * <p>The model analyzed is the last element returned by
 * {@code AnalysisUtilities.getModelFromPipeline(pipeline)}. It must be one of the six tree-based
 * models handled in {@link #extractRootNode()}: random forest, decision tree or GBT, in either
 * the classification or the regression variant.
 *
 * <p>NOTE(review): this file is decompiler output of the Scala class
 * {@code PipelineExtractor} (see the "compiled from" header). Names containing {@code $} are
 * scalac-generated and must not be renamed. This includes the mangled
 * {@code com$databricks$...$$getStringIndexerLabels} method and the synthetic
 * {@code PipelineExtractor$$anonfun$...$1} class it instantiates.
 */
public class PipelineExtractor {
    /** The fitted pipeline under analysis; assigned once by the constructor. */
    private final PipelineModel pipeline;

    /**
     * Builds a single map by collecting string-indexer label maps from the stages of a pipeline.
     *
     * <p>Every stage of {@code pipelineModel} is passed through the synthetic partial function
     * {@code PipelineExtractor$$anonfun$...$getStringIndexerLabels$1}. Its body is not visible
     * here; presumably it matches {@code StringIndexerModel} stages, so confirm that in the
     * Scala source. Each matching stage yields a {@code Map}. The resulting array of maps is
     * flattened into key/value {@code Tuple2}s and rebuilt into one immutable map with
     * {@code toMap}. If two stages produce the same key, the entry encountered later in stage
     * order wins.
     *
     * <p>NOTE(review): the scalac-mangled name suggests this is {@code private} in the Scala
     * source and was widened to {@code public} so synthetic classes can reach it. Keep the
     * name as-is.
     *
     * @param pipelineModel the pipeline whose {@code stages()} are scanned
     * @return a map keyed by {@code String} whose values map an {@code Object} key to a
     *     {@code String}. Presumably this is indexer output column to (index to original label);
     *     verify against the anonfun body.
     */
    public Map<String, Map<Object, String>> com$databricks$labs$automl$exploration$analysis$trees$extractors$PipelineExtractor$$getStringIndexerLabels(PipelineModel pipelineModel) {
        // Reads inside-out: stages() -> collect(anonfun) -> flatten the per-stage maps into pairs -> toMap.
        return new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(pipelineModel.stages())).collect(new PipelineExtractor$$anonfun$com$databricks$labs$automl$exploration$analysis$trees$extractors$PipelineExtractor$$getStringIndexerLabels$1(this), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Map.class))))).flatten(Predef$.MODULE$.$conforms(), ClassTag$.MODULE$.apply(Tuple2.class)))).toMap(Predef$.MODULE$.$conforms());
    }

    /**
     * Re-keys the string-indexer label maps by their entry in the pipeline's final-features map.
     *
     * <p>For every key of the string-indexer label map (see the mangled
     * {@code getStringIndexerLabels} method), this looks the key up in
     * {@code AnalysisUtilities.getFinalFeaturesFromPipeline(pipeline)}. It then emits the pair
     * (looked-up value, label map for that key). Presumably the looked-up value is the
     * feature's index in the assembled feature vector; TODO confirm.
     *
     * <p>NOTE(review): {@code finalFeaturesFromPipeline.apply(str)} is Scala {@code Map.apply},
     * which throws {@code NoSuchElementException} when the key is absent. If a string-indexer
     * key is not in the final-features map, this method therefore fails; an indexer applied
     * to the label column is a possible example. Confirm whether that case can occur.
     *
     * @return label maps keyed by the value found in the final-features map
     */
    private Map<Object, Map<Object, String>> resolveIndexerMappings() {
        Map<String, Map<Object, String>> com$databricks$labs$automl$exploration$analysis$trees$extractors$PipelineExtractor$$getStringIndexerLabels = com$databricks$labs$automl$exploration$analysis$trees$extractors$PipelineExtractor$$getStringIndexerLabels(this.pipeline);
        Map<String, Object> finalFeaturesFromPipeline = AnalysisUtilities$.MODULE$.getFinalFeaturesFromPipeline(this.pipeline);
        return new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) com$databricks$labs$automl$exploration$analysis$trees$extractors$PipelineExtractor$$getStringIndexerLabels.keys().toArray(ClassTag$.MODULE$.apply(String.class)))).map(str -> {
            // Decompiled form of Scala `finalFeatures(str) -> labels(str)`; both lookups throw on a missing key.
            return Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(finalFeaturesFromPipeline.apply(str)), com$databricks$labs$automl$exploration$analysis$trees$extractors$PipelineExtractor$$getStringIndexerLabels.apply(str));
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class))))).toMap(Predef$.MODULE$.$conforms());
    }

    /**
     * Returns the root tree nodes of the pipeline's final model.
     *
     * <p>The last element of {@code AnalysisUtilities.getModelFromPipeline(pipeline)} is
     * dispatched on its runtime type. A {@link TreeExtractor} is built with the matching
     * {@code TreeModelExtractor} strategy, and its {@code extractRootNode()} result is returned
     * unchanged. Presumably the array holds one root for a single decision tree and one root
     * per member tree for ensembles; confirm in {@code TreeExtractor}.
     *
     * <p>If {@code getModelFromPipeline} returns an empty array, the {@code last()} call fails
     * before dispatch. Presumably it throws {@code NoSuchElementException} per Scala
     * collection semantics.
     *
     * @return the root {@link Node}s produced by the matching {@code TreeExtractor}
     * @throws MatchError if the final model is not one of the six supported tree model types
     *     (the decompiled form of a Scala {@code match} with no default case)
     */
    public Node[] extractRootNode() {
        Node[] extractRootNode;
        Object last = Predef$.MODULE$.genericArrayOps(AnalysisUtilities$.MODULE$.getModelFromPipeline(this.pipeline)).last();
        // Each branch pairs the concrete model type with its extractor strategy; the casts select the typed TreeExtractor.
        if (last instanceof RandomForestClassificationModel) {
            extractRootNode = new TreeExtractor((RandomForestClassificationModel) last, TreeModelExtractor$RandomForestClassificationExtractor$.MODULE$).extractRootNode();
        } else if (last instanceof RandomForestRegressionModel) {
            extractRootNode = new TreeExtractor((RandomForestRegressionModel) last, TreeModelExtractor$RandomForestRegressorExtractor$.MODULE$).extractRootNode();
        } else if (last instanceof DecisionTreeClassificationModel) {
            extractRootNode = new TreeExtractor((DecisionTreeClassificationModel) last, TreeModelExtractor$DecisionTreeClassifierExtractor$.MODULE$).extractRootNode();
        } else if (last instanceof DecisionTreeRegressionModel) {
            extractRootNode = new TreeExtractor((DecisionTreeRegressionModel) last, TreeModelExtractor$DecisionTreeRegressorExtractor$.MODULE$).extractRootNode();
        } else if (last instanceof GBTClassificationModel) {
            extractRootNode = new TreeExtractor((GBTClassificationModel) last, TreeModelExtractor$GBTClassifierExtractor$.MODULE$).extractRootNode();
        } else {
            if (!(last instanceof GBTRegressionModel)) {
                // Unsupported model type: mirrors Scala's MatchError for a non-exhaustive match.
                throw new MatchError(last);
            }
            extractRootNode = new TreeExtractor((GBTRegressionModel) last, TreeModelExtractor$GBTRegressorExtractor$.MODULE$).extractRootNode();
        }
        return extractRootNode;
    }

    /**
     * Builds one {@link PipelineReport} per root node of the pipeline's final model.
     *
     * <p>The payload type ({@code PayloadDetermination.payloadType}) and the indexer mappings
     * are each computed once. Then each root node from {@link #extractRootNode()} is passed to
     * {@code Extractor.extractRules(node, payloadType, Some(mappings))}. Each result is paired
     * with its zero-based position via {@code zipWithIndex} and wrapped as
     * {@code new PipelineReport(index, data)}. Output order matches root-node order.
     *
     * <p>NOTE(review): the intermediate array is typed {@code ExtractorType}, and each element
     * is cast to {@code PipelineNodeData}. Presumably {@code extractRules} always returns a
     * {@code PipelineNodeData} for pipeline input; otherwise a {@code ClassCastException}
     * occurs here.
     *
     * @return one report per extracted tree, indexed from 0
     * @throws MatchError propagated from {@link #extractRootNode()} for unsupported model types
     */
    public PipelineReport[] extractPipelineInformation() {
        Enumeration.Value payloadType = PayloadDetermination$.MODULE$.payloadType(this.pipeline);
        Map<Object, Map<Object, String>> resolveIndexerMappings = resolveIndexerMappings();
        // Reads inside-out: root nodes -> extractRules per node -> zipWithIndex -> PipelineReport(index, data).
        return (PipelineReport[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(extractRootNode())).map(node -> {
            return Extractor$.MODULE$.extractRules(node, payloadType, new Some(resolveIndexerMappings));
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(ExtractorType.class))))).zipWithIndex(Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class))))).map(tuple2 -> {
            // _2$mcI$sp() is scalac's specialized (unboxed int) accessor for the zipWithIndex position.
            return new PipelineReport(tuple2._2$mcI$sp(), (PipelineNodeData) tuple2._1());
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(PipelineReport.class)));
    }

    /**
     * Encodes each tree report into a string for visualization.
     *
     * <p>The pipeline's vector fields and string-indexer mapping are fetched once from
     * {@code AnalysisUtilities}. Then {@code HierarchicalEncoding.performJSEncoding} is applied
     * to the {@code data()} of every report returned by {@link #extractPipelineInformation()}.
     * Results keep report order. Presumably each string is a JavaScript/JSON payload consumed
     * by a visualization front end; confirm in {@code HierarchicalEncoding}.
     *
     * @return one encoded string per extracted tree
     */
    public String[] getVisualizationData() {
        String[] pipelineVectorFields = AnalysisUtilities$.MODULE$.getPipelineVectorFields(this.pipeline);
        PipelineReport[] extractPipelineInformation = extractPipelineInformation();
        Option<StringIndexerMappings[]> stringIndexerMapping = AnalysisUtilities$.MODULE$.getStringIndexerMapping(this.pipeline);
        return (String[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(extractPipelineInformation)).map(pipelineReport -> {
            return HierarchicalEncoding$.MODULE$.performJSEncoding(pipelineReport.data(), pipelineVectorFields, stringIndexerMapping);
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(String.class)));
    }

    /**
     * Creates an extractor for the given fitted pipeline.
     *
     * <p>No validation is performed here; a {@code null} pipeline surfaces only when one of the
     * extraction methods is called.
     *
     * @param pipelineModel the fitted pipeline to analyze
     */
    public PipelineExtractor(PipelineModel pipelineModel) {
        this.pipeline = pipelineModel;
    }
}
