package io.cdap.mmds.manager;

import com.google.common.base.Ascii;
import com.google.common.base.Joiner;
import io.cdap.cdap.api.ServiceDiscoverer;
import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.cdap.api.plugin.PluginProperties;
import io.cdap.cdap.api.spark.service.SparkHttpServicePluginContext;
import io.cdap.cdap.api.spark.sql.DataFrames;
import io.cdap.cdap.etl.api.PipelineConfigurer;
import io.cdap.cdap.etl.api.Transform;
import io.cdap.mmds.NullableMath;
import io.cdap.mmds.data.ColumnSplitStats;
import io.cdap.mmds.data.DataSplitInfo;
import io.cdap.mmds.splitter.DataSplitResult;
import io.cdap.mmds.splitter.DatasetSplitter;
import io.cdap.mmds.splitter.ToCatHisto;
import io.cdap.mmds.splitter.ToDoubleValues;
import io.cdap.mmds.splitter.ToNumericHisto;
import io.cdap.mmds.stats.CategoricalHisto;
import io.cdap.mmds.stats.NumericHisto;
import io.cdap.mmds.stats.NumericStats;
import java.io.IOException;
import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SaveMode;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.twill.filesystem.Location;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;

/**
 * Splits an experiment's source data into train and test sets and computes per-column statistics
 * for both splits.
 *
 * <p>Raw text lines are parsed with the Wrangler transform plugin. The parsed data is split with a
 * {@link DatasetSplitter}, and both splits are written as Parquet under the split location.
 * BOOLEAN and STRING columns get categorical histograms. INT, LONG, FLOAT and DOUBLE columns get
 * numeric histograms. Other column types are not profiled. Closing this generator closes the
 * underlying plugin context.
 */
public class DataSplitStatsGenerator implements AutoCloseable {
    private static final Logger LOG = LoggerFactory.getLogger(DataSplitStatsGenerator.class);
    private final SparkSession sparkSession;
    private final DatasetSplitter splitter;
    private final SparkHttpServicePluginContext pluginContext;
    private final ServiceDiscoverer serviceDiscoverer;
    private final PipelineConfigurer pipelineConfigurer;

    /**
     * @param sparkSession session used to read, transform and write the data
     * @param datasetSplitter strategy that splits the parsed data into train and test sets
     * @param sparkHttpServicePluginContext plugin context used to instantiate the Wrangler plugin;
     *     closed by {@link #close()}
     * @param serviceDiscoverer passed to the Wrangler function used while parsing
     */
    public DataSplitStatsGenerator(SparkSession sparkSession, DatasetSplitter datasetSplitter, SparkHttpServicePluginContext sparkHttpServicePluginContext, ServiceDiscoverer serviceDiscoverer) {
        this.sparkSession = sparkSession;
        this.splitter = datasetSplitter;
        this.pluginContext = sparkHttpServicePluginContext;
        this.serviceDiscoverer = serviceDiscoverer;
        this.pipelineConfigurer = new WranglerPipelineConfigurer(sparkHttpServicePluginContext);
    }

    /**
     * Parses the experiment source, splits it into train and test sets, and writes both as Parquet.
     *
     * <p>Every dataset cached during the operation is unpersisted before this method returns,
     * whether it succeeds or fails.
     *
     * @param dataSplitInfo the split definition (schema, directives, splitter params), the
     *     experiment (source path) and the output location
     * @return the train and test output paths together with per-column statistics
     * @throws IOException if the output locations cannot be resolved
     * @throws IllegalStateException if the Wrangler plugin is not available
     */
    public DataSplitResult split(DataSplitInfo dataSplitInfo) throws IOException {
        Schema schema = dataSplitInfo.getDataSplit().getSchema();
        Transform transform = loadWrangler(dataSplitInfo, schema);
        transform.configurePipeline(this.pipelineConfigurer);

        Dataset<Row> wrangled = this.sparkSession.createDataFrame(
            this.sparkSession.read().format("text").load(dataSplitInfo.getExperiment().getSrcpath())
                .javaRDD()
                .flatMap(new WranglerFunction(schema, this.pluginContext, this.serviceDiscoverer)),
            DataFrames.toDataType(schema)).cache();
        Dataset<Row> train = null;
        Dataset<Row> test = null;
        try {
            long splitStart = System.nanoTime();
            Dataset<Row>[] splits = this.splitter.split(wrangled, dataSplitInfo.getDataSplit().getParams());
            train = splits[0].cache();
            test = splits[1].cache();

            Location splitLocation = dataSplitInfo.getSplitLocation();
            String trainPath = splitLocation.append("train").toURI().getPath();
            String testPath = splitLocation.append("test").toURI().getPath();
            train.write().mode(SaveMode.Overwrite).format("parquet").save(trainPath);
            test.write().mode(SaveMode.Overwrite).format("parquet").save(testPath);
            LOG.info("Time to split = {} seconds", secondsSince(splitStart));

            long statsStart = System.nanoTime();
            List<ColumnSplitStats> stats = getStats(train, test, schema);
            LOG.info("Time to get stats = {} seconds", secondsSince(statsStart));
            return new DataSplitResult(trainPath, testPath, stats);
        } finally {
            // Only the output paths escape this method, so the cached data can be released.
            // Without this, every split would stay pinned in Spark storage.
            releaseCached(wrangled, train, test);
        }
    }

    /**
     * Instantiates the Wrangler transform configured with the split's schema and directives.
     *
     * @throws IllegalStateException if the plugin cannot be found
     */
    private Transform loadWrangler(DataSplitInfo dataSplitInfo, Schema schema) {
        PluginProperties properties = PluginProperties.builder()
            .add("schema", schema.toString())
            .add("field", "*")
            .add("directives", Joiner.on("\n").join(dataSplitInfo.getDataSplit().getDirectives()))
            .add("threshold", "-1")
            .add("precondition", "false")
            .build();
        Transform transform = (Transform) this.pluginContext.usePlugin(Transform.PLUGIN_TYPE, "Wrangler", "wrangler", properties);
        if (transform == null) {
            throw new IllegalStateException("Could not find wrangler plugin. Please make sure it has been deployed with MMDS as a parent.");
        }
        return transform;
    }

    /**
     * Computes per-column statistics for the train and test splits.
     *
     * <p>One entry is produced per column that yielded values in the train split. Each entry pairs
     * the train statistics with the test statistics, which are null when the test split yielded
     * nothing for that column.
     */
    private List<ColumnSplitStats> getStats(Dataset<Row> dataset, Dataset<Row> dataset2, Schema schema) {
        List<Column> categoricalColumns = new ArrayList<>();
        List<String> categoricalNames = new ArrayList<>();
        List<Column> numericColumns = new ArrayList<>();
        List<String> numericNames = new ArrayList<>();
        for (Schema.Field field : schema.getFields()) {
            String name = field.getName();
            Schema fieldSchema = field.getSchema();
            Schema.Type type = (fieldSchema.isNullable() ? fieldSchema.getNonNullable() : fieldSchema).getType();
            Column column = new Column(name);
            switch (type) {
                case BOOLEAN:
                    categoricalColumns.add(column.cast(DataTypes.StringType));
                    categoricalNames.add(name);
                    break;
                case STRING:
                    categoricalColumns.add(column);
                    categoricalNames.add(name);
                    break;
                case INT:
                case LONG:
                case FLOAT:
                    numericColumns.add(column.cast(DataTypes.DoubleType));
                    numericNames.add(name);
                    break;
                case DOUBLE:
                    numericColumns.add(column);
                    numericNames.add(name);
                    break;
                default:
                    // Other types (bytes, records, arrays, ...) are not profiled.
                    break;
            }
        }

        List<ColumnSplitStats> stats = new ArrayList<>(schema.getFields().size());
        addCategoricalStats(dataset, dataset2, categoricalColumns, categoricalNames, stats);
        addNumericStats(dataset, dataset2, numericColumns, numericNames, stats);
        return stats;
    }

    /**
     * Builds categorical histograms for {@code columns} on both splits and appends one entry per
     * column present in the train split to {@code out}. Does nothing when there are no columns.
     */
    private static void addCategoricalStats(Dataset<Row> train, Dataset<Row> test, List<Column> columns,
                                            List<String> names, List<ColumnSplitStats> out) {
        if (columns.isEmpty()) {
            return;
        }
        Column[] selected = columns.toArray(new Column[0]);
        int partitions = columns.size();
        long start = System.nanoTime();
        Map<String, CategoricalHisto> trainHistos = train.select(selected).javaRDD()
            .flatMapToPair(new ToCatHisto(names))
            .reduceByKey(CategoricalHisto::merge, partitions)
            .collectAsMap();
        Map<String, CategoricalHisto> testHistos = test.select(selected).javaRDD()
            .flatMapToPair(new ToCatHisto(names))
            .reduceByKey(CategoricalHisto::merge, partitions)
            .collectAsMap();
        LOG.info("Time to get categorical stats = {} seconds", secondsSince(start));

        for (Map.Entry<String, CategoricalHisto> entry : trainHistos.entrySet()) {
            String column = entry.getKey();
            out.add(new ColumnSplitStats(column, entry.getValue(), testHistos.get(column)));
        }
    }

    /**
     * Builds numeric histograms for {@code columns} on both splits and appends one entry per column
     * present in the train split to {@code out}. Does nothing when there are no columns.
     *
     * <p>The first pass computes per-column stats on each split. These are merged into one
     * (min, max) range per column so that train and test histograms share the same bins. The
     * second pass computes the histograms.
     */
    private static void addNumericStats(Dataset<Row> train, Dataset<Row> test, List<Column> columns,
                                        List<String> names, List<ColumnSplitStats> out) {
        if (columns.isEmpty()) {
            return;
        }
        Column[] selected = columns.toArray(new Column[0]);
        int partitions = columns.size();
        JavaPairRDD<String, Double> trainValues = train.select(selected).javaRDD().flatMapToPair(new ToDoubleValues(names));
        JavaPairRDD<String, Double> testValues = test.select(selected).javaRDD().flatMapToPair(new ToDoubleValues(names));

        long firstPassStart = System.nanoTime();
        Map<String, NumericStats> trainStats = trainValues.mapValues(NumericStats::new)
            .reduceByKey(NumericStats::merge, partitions)
            .collectAsMap();
        Map<String, NumericStats> testStats = testValues.mapValues(NumericStats::new)
            .reduceByKey(NumericStats::merge, partitions)
            .collectAsMap();
        // Cover the union of columns. A column may be missing from one split entirely, e.g. when
        // that split is empty, and both histogram passes need a range for every key they see.
        Map<String, Tuple2<Double, Double>> ranges = new HashMap<>();
        for (Map.Entry<String, NumericStats> entry : trainStats.entrySet()) {
            ranges.put(entry.getKey(), combinedRange(entry.getValue(), testStats.get(entry.getKey())));
        }
        for (Map.Entry<String, NumericStats> entry : testStats.entrySet()) {
            if (!ranges.containsKey(entry.getKey())) {
                ranges.put(entry.getKey(), combinedRange(null, entry.getValue()));
            }
        }
        LOG.info("Time to get numeric stats, 1st pass = {} seconds", secondsSince(firstPassStart));

        long secondPassStart = System.nanoTime();
        Map<String, NumericHisto> trainHistos = trainValues.mapToPair(new ToNumericHisto(ranges))
            .reduceByKey(NumericHisto::merge, partitions)
            .collectAsMap();
        Map<String, NumericHisto> testHistos = testValues.mapToPair(new ToNumericHisto(ranges))
            .reduceByKey(NumericHisto::merge, partitions)
            .collectAsMap();
        LOG.info("Time to get numeric stats, 2nd pass = {} seconds", secondsSince(secondPassStart));

        for (Map.Entry<String, NumericHisto> entry : trainHistos.entrySet()) {
            String column = entry.getKey();
            out.add(new ColumnSplitStats(column, entry.getValue(), testHistos.get(column)));
        }
    }

    /**
     * Returns the (min, max) range spanning both splits' stats for one column.
     *
     * <p>Either argument may be null when that split yielded no values for the column, but not
     * both. When both are present, the bounds are combined with {@link NullableMath}.
     */
    private static Tuple2<Double, Double> combinedRange(NumericStats first, NumericStats second) {
        if (first == null) {
            return new Tuple2<>(second.getMin(), second.getMax());
        }
        if (second == null) {
            return new Tuple2<>(first.getMin(), first.getMax());
        }
        return new Tuple2<>(NullableMath.min(first.getMin(), second.getMin()),
                            NullableMath.max(first.getMax(), second.getMax()));
    }

    /** Whole seconds elapsed since {@code startNanos}, a value taken from {@link System#nanoTime()}. */
    private static long secondsSince(long startNanos) {
        return TimeUnit.NANOSECONDS.toSeconds(System.nanoTime() - startNanos);
    }

    /**
     * Unpersists the given datasets, skipping nulls.
     *
     * <p>This is best-effort cleanup. A failure is logged rather than thrown, so that it cannot
     * mask an exception already propagating out of the caller.
     */
    private static void releaseCached(Dataset<?>... datasets) {
        for (Dataset<?> dataset : datasets) {
            if (dataset == null) {
                continue;
            }
            try {
                dataset.unpersist();
            } catch (RuntimeException e) {
                LOG.warn("Failed to unpersist cached dataset", e);
            }
        }
    }

    /** Closes the plugin context passed to the constructor. */
    @Override
    public void close() throws Exception {
        this.pluginContext.close();
    }
}
