package com.gengoai.apollo.ml.data;

import com.gengoai.apollo.ml.DataSet;
import com.gengoai.apollo.ml.Datum;
import com.gengoai.apollo.ml.StreamingDataSet;
import com.gengoai.apollo.ml.observation.Observation;
import com.gengoai.apollo.ml.observation.Variable;
import com.gengoai.io.CSV;
import com.gengoai.io.resource.Resource;
import com.gengoai.stream.StreamingContext;
import com.gengoai.stream.spark.SparkStream;
import java.io.IOException;
import java.lang.invoke.SerializedLambda;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import lombok.NonNull;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;

/* loaded from: input_file:com/gengoai/apollo/ml/data/DistributedCSVDataSetReader.class */
/**
 * A {@link DataSetReader} that loads a CSV resource through Spark's CSV data source and exposes the
 * rows as a {@link StreamingDataSet} backed by a {@link SparkStream}.
 *
 * <p>Each row becomes one {@link Datum} keyed by column name. When a {@link Schema} is supplied every
 * cell is converted via {@code Schema.convert(column, text)}; otherwise numeric cells become real-valued
 * variables and all other cells become binary variables.
 *
 * <p>The row-mapping function handed to Spark captures this reader instance.
 * NOTE(review): that presumably requires this reader (and its {@code CSV} / {@code Schema}) to be
 * serializable for distributed execution — confirm against {@code DataSetReader}.
 */
public class DistributedCSVDataSetReader implements DataSetReader {
    private static final long serialVersionUID = 1L;
    /** CSV format description (delimiter, escape, quote, comment, header flag). */
    private final CSV csv;
    /** Optional schema used to convert cells; {@code null} selects the default conversion. */
    private final Schema schema;

    /**
     * Creates a reader that uses the default cell conversion (no schema).
     *
     * @param csv the CSV format description
     * @throws NullPointerException if {@code csv} is null
     */
    public DistributedCSVDataSetReader(@NonNull CSV csv) {
        if (csv == null) {
            throw new NullPointerException("csv is marked non-null but is null");
        }
        this.csv = csv;
        this.schema = null;
    }

    /**
     * Creates a reader that converts every cell through the given schema.
     *
     * @param csv    the CSV format description
     * @param schema the schema used to convert each cell into an observation
     * @throws NullPointerException     if {@code csv} or {@code schema} is null
     * @throws IllegalArgumentException if the CSV neither has a header row nor an explicitly defined header
     */
    public DistributedCSVDataSetReader(@NonNull CSV csv, @NonNull Schema schema) {
        if (csv == null) {
            throw new NullPointerException("csv is marked non-null but is null");
        }
        if (schema == null) {
            throw new NullPointerException("schema is marked non-null but is null");
        }
        this.csv = csv;
        this.schema = schema;
        if (!csv.getHasHeader() && csv.getHeader().isEmpty()) {
            throw new IllegalArgumentException("Either the CSV must have a header or one must be defined.");
        }
    }

    /**
     * Reads the CSV at the resource's path with Spark and wraps the rows in a {@link StreamingDataSet}.
     * The metadata type of every column is set to {@link Variable}.
     *
     * @param resource the resource whose {@code path()} is handed to Spark's CSV reader
     * @return a streaming data set with one {@link Datum} per CSV row
     * @throws IOException          declared by the {@link DataSetReader} contract
     * @throws NullPointerException if {@code resource} is null
     */
    @Override // com.gengoai.apollo.ml.data.DataSetReader
    public DataSet read(@NonNull Resource resource) throws IOException {
        if (resource == null) {
            throw new NullPointerException("dataResource is marked non-null but is null");
        }
        // NOTE(review): DataFrameReader.option is overloaded for String/boolean/long/double values. If the
        // CSV getters below return char, the long overload would be selected and Spark would receive the
        // numeric code point instead of the character — confirm the getter return types.
        Dataset<Row> rows = new SQLContext(StreamingContext.distributed().sparkSession())
                .read()
                .option("delimiter", this.csv.getDelimiter())
                .option("escape", this.csv.getEscape())
                .option("quote", this.csv.getQuote())
                .option("comment", this.csv.getComment())
                .option("header", this.csv.getHasHeader())
                .csv(resource.path());
        // Column names are resolved once up front; the list is captured by the row-mapping lambda.
        final List<String> columns = Arrays.asList(rows.columns());
        StreamingDataSet dataSet = new StreamingDataSet(
                new SparkStream<>(rows.toJavaRDD().map(row -> rowToDatum(columns, row))));
        for (String column : columns) {
            dataSet.updateMetadata(column, metadata -> {
                metadata.setType(Variable.class);
            });
        }
        return dataSet;
    }

    /**
     * Converts one Spark row into a {@link Datum} keyed by column name.
     *
     * <p>Null cells are skipped, leaving that column absent from the datum; dereferencing them would
     * otherwise fail with a {@link NullPointerException}.
     *
     * @param list the column names, positionally aligned with the row's fields
     * @param row  the Spark row to convert
     * @return the datum holding one observation per non-null cell
     */
    private Datum rowToDatum(List<String> list, Row row) {
        Datum datum = new Datum();
        for (int i = 0; i < list.size(); i++) {
            String column = list.get(i);
            Object value = row.get(i);
            if (value == null) {
                // Missing value: nothing to convert for this column.
                continue;
            }
            if (this.schema != null) {
                datum.put(column, (Observation) this.schema.convert(column, value.toString()));
            } else if (value instanceof Number) {
                datum.put(column, (Observation) Variable.real(column, ((Number) value).doubleValue()));
            } else {
                datum.put(column, (Observation) Variable.binary(column, value.toString()));
            }
        }
        return datum;
    }
}
