package org.riversun.ml.spark;

import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collector;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.spark.ml.PredictionModel;
import org.apache.spark.ml.classification.DecisionTreeClassificationModel;
import org.apache.spark.ml.classification.GBTClassificationModel;
import org.apache.spark.ml.classification.RandomForestClassificationModel;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
import org.apache.spark.ml.regression.GBTRegressionModel;
import org.apache.spark.ml.regression.RandomForestRegressionModel;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructType;

/* loaded from: input_file:org/riversun/ml/spark/FeatureImportance.class */
/**
 * Extracts per-feature importance scores from a fitted tree-based Spark ML model and pairs each
 * score with the feature name recorded in the features column's ML attribute metadata.
 *
 * <p>Supported models: {@link GBTRegressionModel}, {@link GBTClassificationModel},
 * {@link RandomForestRegressionModel}, {@link RandomForestClassificationModel},
 * {@link DecisionTreeRegressionModel} and {@link DecisionTreeClassificationModel}.
 * Instances are created through {@link Builder}.
 */
public class FeatureImportance {

    /** Attribute-group keys under {@code ml_attr.attrs} whose entries carry feature idx/name pairs. */
    private static final String[] ATTRIBUTE_TYPES = {"nominal", "numeric", "binary"};

    private final PredictionModel<Vector, ?> model;
    private final StructType schema;
    private final Order sort;

    /** Fluent builder for {@link FeatureImportance}; the default order is {@link Order#DESCENDING}. */
    public static class Builder {
        private final PredictionModel<Vector, ?> model;
        private final StructType schema;
        private Order sort = Order.DESCENDING;

        /**
         * @param predictionModel fitted model whose feature importances will be reported
         * @param structType schema that contains the model's features column; its ML attribute
         *     metadata supplies the feature names
         */
        public Builder(PredictionModel<Vector, ?> predictionModel, StructType structType) {
            this.model = predictionModel;
            this.schema = structType;
        }

        /**
         * Sets the order of the returned list.
         *
         * @param order desired ordering; must not be {@code null}
         * @return this builder
         */
        public Builder sort(Order order) {
            this.sort = order;
            return this;
        }

        /**
         * Creates the {@link FeatureImportance}.
         *
         * @return a new instance configured from this builder
         * @throws NullPointerException if the model, schema or sort order is {@code null}
         */
        public FeatureImportance build() {
            if (this.model == null) {
                throw new NullPointerException("model must not be null");
            }
            if (this.schema == null) {
                throw new NullPointerException("schema must not be null");
            }
            // Fail here rather than later inside getResult()'s switch statement.
            if (this.sort == null) {
                throw new NullPointerException("sort order must not be null");
            }
            return new FeatureImportance(this);
        }
    }

    /** Ordering applied to the list returned by {@link #getResult()}. */
    public enum Order {
        /** Lowest score first. */
        ASCENDING,
        /** Highest score first. */
        DESCENDING,
        /** Feature-vector index order. */
        UNSORTED
    }

    private FeatureImportance(Builder builder) {
        this.model = builder.model;
        this.schema = builder.schema;
        this.sort = builder.sort;
    }

    /**
     * Computes the importance of every feature of the model.
     *
     * @return one {@link Importance} per element of the model's feature-importance vector, ordered
     *     by the configured {@link Order}. Each entry's rank is its 0-based position when sorted by
     *     descending score. Ties keep feature-index order. The name is {@code null} when the schema
     *     metadata does not name that feature.
     * @throws RuntimeException if the model type does not expose feature importances
     * @throws IllegalArgumentException if the schema has no column named after the model's
     *     features column
     */
    public List<Importance> getResult() {
        return zipImportances(extractImportances(this.model), this.model.getFeaturesCol(), this.schema);
    }

    /**
     * Reads the feature-importance vector from one of the supported tree-based model types.
     * {@link PredictionModel} itself has no {@code featureImportances()} method, so each concrete
     * type must be cast explicitly.
     */
    private static Vector extractImportances(PredictionModel<Vector, ?> model) {
        if (model instanceof GBTRegressionModel) {
            return ((GBTRegressionModel) model).featureImportances();
        } else if (model instanceof GBTClassificationModel) {
            return ((GBTClassificationModel) model).featureImportances();
        } else if (model instanceof RandomForestRegressionModel) {
            return ((RandomForestRegressionModel) model).featureImportances();
        } else if (model instanceof RandomForestClassificationModel) {
            return ((RandomForestClassificationModel) model).featureImportances();
        } else if (model instanceof DecisionTreeRegressionModel) {
            return ((DecisionTreeRegressionModel) model).featureImportances();
        } else if (model instanceof DecisionTreeClassificationModel) {
            return ((DecisionTreeClassificationModel) model).featureImportances();
        }
        throw new RuntimeException(model + " doesn't have feature importances.You should specify an instance of GBTRegressionModel,GBTClassificationModel,RandomForestRegressionModel,RandomForestClassificationModel,DecisionTreeRegressionModel,DecisionTreeClassificationModel");
    }

    /**
     * Builds a map from feature index to feature name from the ML attribute metadata
     * ({@code ml_attr.attrs.{nominal,numeric,binary}[].idx/name}) of the given column.
     * Missing metadata, or an attribute without a name, simply leaves that index unmapped.
     */
    private static Map<Integer, String> featureNames(StructType structType, String featuresCol) {
        Map<Integer, String> namesByIndex = new HashMap<>();
        // fieldIndex throws IllegalArgumentException (with the available columns) if absent.
        Metadata columnMetadata = structType.fields()[structType.fieldIndex(featuresCol)].metadata();
        if (!columnMetadata.contains("ml_attr")) {
            return namesByIndex;
        }
        Metadata mlAttr = columnMetadata.getMetadata("ml_attr");
        if (!mlAttr.contains("attrs")) {
            return namesByIndex;
        }
        Metadata attrs = mlAttr.getMetadata("attrs");
        for (String attributeType : ATTRIBUTE_TYPES) {
            if (!attrs.contains(attributeType)) {
                continue;
            }
            for (Metadata attribute : attrs.getMetadataArray(attributeType)) {
                if (attribute.contains("name")) {
                    // A later group overwrites an earlier one for the same index.
                    namesByIndex.put((int) attribute.getLong("idx"), attribute.getString("name"));
                }
            }
        }
        return namesByIndex;
    }

    /**
     * Pairs each score in {@code vector} with its feature name, assigns descending-score ranks, and
     * returns the entries in the configured order.
     */
    private List<Importance> zipImportances(Vector vector, String str, StructType structType) {
        Map<Integer, String> namesByIndex = featureNames(structType, str);
        double[] scores = vector.toArray();

        List<Importance> inFeatureOrder = IntStream.range(0, scores.length)
                .mapToObj(i -> new Importance(i, namesByIndex.get(i), scores[i], 0))
                .collect(Collectors.toList());

        // Stream.sorted is stable for ordered streams, so ties keep feature-index order.
        List<Importance> byScoreDescending = inFeatureOrder.stream()
                .sorted(Comparator.comparingDouble((Importance imp) -> imp.score).reversed())
                .collect(Collectors.toList());
        // The Importance objects are shared by both lists, so ranks are visible in each.
        for (int rank = 0; rank < byScoreDescending.size(); rank++) {
            byScoreDescending.get(rank).rank = rank;
        }

        switch (this.sort) {
            case ASCENDING:
                return byScoreDescending.stream()
                        .sorted(Comparator.comparingDouble((Importance imp) -> imp.score))
                        .collect(Collectors.toList());
            case DESCENDING:
                return byScoreDescending;
            case UNSORTED:
            default:
                return inFeatureOrder;
        }
    }
}
