package org.hpccsystems.spark;

import java.io.Serializable;
import java.lang.invoke.SerializedLambda;
import java.util.Arrays;
import net.razorvine.pickle.Unpickler;
import org.apache.log4j.Logger;
import org.apache.spark.InterruptibleIterator;
import org.apache.spark.Partition;
import org.apache.spark.SparkContext;
import org.apache.spark.TaskContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.linalg.DenseVector;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.execution.python.EvaluatePython;
import org.apache.spark.sql.types.StructType;
import org.hpccsystems.commons.ecl.FieldDef;
import org.hpccsystems.dfs.client.DataPartition;
import org.hpccsystems.dfs.client.HpccRemoteFileReader;
import scala.collection.Iterator;
import scala.collection.JavaConverters;
import scala.collection.Seq;
import scala.collection.mutable.ArrayBuffer;
import scala.collection.mutable.Buffer;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;

/* loaded from: input_file:org/hpccsystems/spark/HpccRDD.class */
/**
 * A Spark {@link RDD} of {@link Row}s whose partitions are HPCC {@link DataPartition}s.
 *
 * <p>Each partition is read on an executor through an {@link HpccRemoteFileReader} that decodes
 * records with the original record definition and builds rows via a
 * {@link GenericRowRecordBuilder} configured with the projected record definition.
 */
public class HpccRDD extends RDD<Row> implements Serializable {
    private static final long serialVersionUID = 1;
    private static final Logger log = Logger.getLogger(HpccRDD.class.getName());
    private static final ClassTag<Row> CT_RECORD = ClassTag$.MODULE$.apply(Row.class);

    private final InternalPartition[] parts;
    private final FieldDef originalRecordDef;
    private final FieldDef projectedRecordDef;

    /**
     * Spark partition wrapping a single HPCC {@link DataPartition}.
     *
     * <p>Declared {@code static} so that serializing a partition does not also serialize the
     * enclosing RDD through a hidden outer-instance reference.
     */
    public static class InternalPartition implements Partition {
        private static final long serialVersionUID = 1;

        /** The wrapped HPCC file partition. */
        public DataPartition partition;

        private InternalPartition(DataPartition partition) {
            this.partition = partition;
        }

        /** Hashes by partition index. */
        @Override
        public int hashCode() {
            return index();
        }

        /** Returns the index reported by the wrapped HPCC partition. */
        @Override
        public int index() {
            return this.partition.index();
        }
    }

    /**
     * Registers Spark's Python picklers and maps the PySpark {@code pyspark.sql.types.Row} and
     * {@code _create_row} constructors to {@link RowConstructor} for unpickling.
     *
     * <p>Invoked at the start of every {@link #compute} call. NOTE(review): assumes repeating
     * these registrations is harmless — confirm.
     */
    private static void registerPicklingFunctions() {
        EvaluatePython.registerPicklers();
        Unpickler.registerConstructor("pyspark.sql.types", "Row", new RowConstructor());
        Unpickler.registerConstructor("pyspark.sql.types", "_create_row", new RowConstructor());
    }

    /**
     * Creates an RDD that reads rows without projection; the record definition is used both to
     * decode and to shape the rows.
     *
     * @param sparkContext the owning Spark context
     * @param dataPartitionArr the HPCC partitions to read, one Spark partition each
     * @param fieldDef the record definition of the HPCC file
     */
    public HpccRDD(SparkContext sparkContext, DataPartition[] dataPartitionArr, FieldDef fieldDef) {
        this(sparkContext, dataPartitionArr, fieldDef, fieldDef);
    }

    /**
     * Creates an RDD that reads rows with an optional projection.
     *
     * @param sparkContext the owning Spark context
     * @param dataPartitionArr the HPCC partitions to read, one Spark partition each
     * @param fieldDef the original record definition, used by the reader to decode records
     * @param fieldDef2 the projected record definition, used to build rows and derive the schema
     */
    public HpccRDD(SparkContext sparkContext, DataPartition[] dataPartitionArr, FieldDef fieldDef, FieldDef fieldDef2) {
        // No parent RDDs: this is a source RDD, so the dependency list is empty.
        super(sparkContext, new ArrayBuffer<>(), CT_RECORD);
        this.parts = new InternalPartition[dataPartitionArr.length];
        for (int i = 0; i < dataPartitionArr.length; i++) {
            this.parts[i] = new InternalPartition(dataPartitionArr[i]);
        }
        this.originalRecordDef = fieldDef;
        this.projectedRecordDef = fieldDef2;
    }

    /** Wraps this RDD as a {@link JavaRDD} for use from the Java API. */
    public JavaRDD<Row> asJavaRDD() {
        return new JavaRDD<>(this, CT_RECORD);
    }

    /**
     * Converts each row into an MLlib {@link LabeledPoint} with a dense feature vector.
     *
     * @param str name of the label column; its values are read with {@link Row#getDouble}
     * @param strArr names of the feature columns, in feature-vector order; values are read with
     *     {@link Row#getDouble}
     * @return an RDD of labeled points
     * @throws IllegalArgumentException if the schema cannot be derived or a column cannot be
     *     resolved; the original exception is attached as the cause
     */
    public RDD<LabeledPoint> makeMLLibLabeledPoint(String str, String[] strArr) throws IllegalArgumentException {
        try {
            StructType schema = SparkSchemaTranslator.toSparkSchema(this.projectedRecordDef);
            int labelIndex = schema.fieldIndex(str);
            int[] featureIndices = resolveFieldIndices(schema, strArr);
            // The closure captures only an int and an int[] (never `this`), keeping it small to ship.
            JavaRDD<LabeledPoint> points = asJavaRDD().map(
                    row -> new LabeledPoint(row.getDouble(labelIndex), new DenseVector(readDoubles(row, featureIndices))));
            return points.rdd();
        } catch (Exception e) {
            throw new IllegalArgumentException(e.getMessage(), e);
        }
    }

    /**
     * Converts each row into a dense MLlib {@link Vector} built from the named columns.
     *
     * @param strArr names of the columns, in vector order; values are read with
     *     {@link Row#getDouble}
     * @return an RDD of dense vectors
     * @throws IllegalArgumentException if the schema cannot be derived or a column cannot be
     *     resolved; the original exception is attached as the cause
     */
    public RDD<Vector> makeMLLibVector(String[] strArr) throws IllegalArgumentException {
        try {
            StructType schema = SparkSchemaTranslator.toSparkSchema(this.projectedRecordDef);
            int[] featureIndices = resolveFieldIndices(schema, strArr);
            // Typed as JavaRDD<Vector> so map() infers Vector rather than DenseVector; an
            // RDD<DenseVector> is not assignable to the RDD<Vector> return type.
            JavaRDD<Vector> vectors = asJavaRDD().map(row -> new DenseVector(readDoubles(row, featureIndices)));
            return vectors.rdd();
        } catch (Exception e) {
            throw new IllegalArgumentException(e.getMessage(), e);
        }
    }

    /** Resolves each column name to its position in {@code schema}, preserving order. */
    private static int[] resolveFieldIndices(StructType schema, String[] names) {
        int[] indices = new int[names.length];
        for (int i = 0; i < names.length; i++) {
            indices[i] = schema.fieldIndex(names[i]);
        }
        return indices;
    }

    /** Reads the double values at {@code indices} from {@code row}, in order. */
    private static double[] readDoubles(Row row, int[] indices) {
        double[] values = new double[indices.length];
        for (int i = 0; i < indices.length; i++) {
            values[i] = row.getDouble(indices[i]);
        }
        return values;
    }

    /**
     * Reads one HPCC partition as an iterator of rows.
     *
     * <p>The remote reader is closed when the task completes. Failures are logged and rethrown so
     * Spark reports the real cause, rather than failing later on a {@code null} iterator.
     *
     * @param partition an {@link InternalPartition} produced by {@link #getPartitions()}
     * @param taskContext the running task's context
     * @return an interruptible iterator over the partition's rows
     * @throws IllegalStateException if a record definition is missing or the reader cannot be
     *     created
     */
    @Override
    public InterruptibleIterator<Row> compute(Partition partition, TaskContext taskContext) {
        registerPicklingFunctions();
        InternalPartition internalPartition = (InternalPartition) partition;
        FieldDef originalDef = this.originalRecordDef;
        FieldDef projectedDef = this.projectedRecordDef;
        if (originalDef == null) {
            throw abort("Original record definition is null. Aborting.", null);
        }
        if (projectedDef == null) {
            throw abort("Projected record definition is null. Aborting.", null);
        }

        HpccRemoteFileReader<Row> reader = openReader(internalPartition.partition, originalDef, projectedDef);
        // Block-bodied (void) lambda so it resolves to the TaskCompletionListener overload only.
        taskContext.addTaskCompletionListener(completedContext -> {
            closeReader(reader, internalPartition.index());
        });

        Iterator<Row> rows = JavaConverters.asScalaIteratorConverter(reader).asScala();
        return new InterruptibleIterator<>(taskContext, rows);
    }

    /** Opens a remote reader for one partition, converting any failure into an unchecked error. */
    private static HpccRemoteFileReader<Row> openReader(DataPartition dataPartition, FieldDef originalDef, FieldDef projectedDef) {
        try {
            return new HpccRemoteFileReader<>(dataPartition, originalDef, new GenericRowRecordBuilder(projectedDef));
        } catch (Exception e) {
            throw abort("Failed to create remote file reader with error: " + e.getMessage(), e);
        }
    }

    /** Best-effort close at task completion: failures are logged, never propagated. */
    private static void closeReader(HpccRemoteFileReader<Row> reader, int partitionIndex) {
        try {
            reader.close();
        } catch (Exception e) {
            log.warn("Failed to close remote file reader for partition " + partitionIndex + ": " + e.getMessage(), e);
        }
    }

    /** Logs {@code message} (with {@code cause}, if any) and returns an exception to throw. */
    private static IllegalStateException abort(String message, Throwable cause) {
        log.error(message, cause);
        return new IllegalStateException(message, cause);
    }

    /**
     * Returns the copy locations reported by the HPCC partition so Spark can prefer those hosts.
     *
     * @param partition an {@link InternalPartition} produced by {@link #getPartitions()}
     */
    @Override
    public Seq<String> getPreferredLocations(Partition partition) {
        String[] copyLocations = ((InternalPartition) partition).partition.getCopyLocations();
        return JavaConverters.asScalaBufferConverter(Arrays.asList(copyLocations)).asScala().seq();
    }

    /**
     * Returns the partitions built in the constructor, one per HPCC data partition, in the order
     * they were supplied.
     *
     * <p>NOTE(review): Spark expects {@code partitions[i].index() == i}; this relies on the HPCC
     * partitions being supplied in index order — confirm at the call site.
     */
    @Override
    public Partition[] getPartitions() {
        return this.parts;
    }
}
