package com.gengoai.apollo.ml;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.gengoai.apollo.math.linalg.NDArrayFactory;
import com.gengoai.apollo.ml.observation.Observation;
import com.gengoai.function.SerializableFunction;
import com.gengoai.function.Unchecked;
import com.gengoai.io.MultiFileWriter;
import com.gengoai.io.SaveMode;
import com.gengoai.io.resource.Resource;
import com.gengoai.json.Json;
import com.gengoai.stream.MStream;
import com.gengoai.stream.StreamingContext;
import com.gengoai.tuple.Tuples;
import java.io.IOException;
import java.io.Serializable;
import java.lang.invoke.SerializedLambda;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
import java.util.stream.Stream;
import lombok.NonNull;

/**
 * Base type for a collection of {@link Datum} together with per-source {@link ObservationMetadata} and the
 * {@link NDArrayFactory} associated with the data set.
 *
 * <p>When deserialized with Jackson, {@link InMemoryDataSet} is used as the concrete type (see the
 * {@code @JsonDeserialize} annotation on this class).
 */
@JsonDeserialize(as = InMemoryDataSet.class)
public abstract class DataSet implements Iterable<Datum>, Serializable {
    private static final long serialVersionUID = 1;

    /**
     * Maximum number of datum examined by {@link #probe()}.
     */
    public static int PROBE_LIMIT = 500;

    /**
     * Metadata keyed by observation source name. A concurrent map is used; {@link #probe()} updates it from a
     * parallel stream via {@link #updateMetadata(String, Consumer)}.
     */
    protected final Map<String, ObservationMetadata> metadata = new ConcurrentHashMap<>();

    /**
     * Factory associated with this data set; defaults to {@code NDArrayFactory.ND}.
     */
    @NonNull
    protected NDArrayFactory ndArrayFactory = NDArrayFactory.ND;

    /**
     * Returns an iterator over this data set split into batches.
     *
     * @param batchSize the batch size (presumably the number of datum per batch; confirm with implementations)
     * @return an iterator whose elements are themselves {@code DataSet}s
     */
    public abstract Iterator<DataSet> batchIterator(int batchSize);

    /**
     * Caches this data set.
     *
     * @return the cached data set (whether this is {@code this} or a new instance is implementation-defined)
     */
    public abstract DataSet cache();

    /**
     * Collects every datum of {@link #stream()} into a list.
     *
     * @return the list produced by {@code stream().collect()}
     */
    public List<Datum> collect() {
        return stream().collect();
    }

    /**
     * Looks up the metadata recorded for the given observation source.
     *
     * @param source the observation source name
     * @return the metadata for {@code source}, or {@code null} if none has been recorded
     * @throws NullPointerException if {@code source} is null
     */
    public ObservationMetadata getMetadata(@NonNull String source) {
        if (source == null) {
            throw new NullPointerException("source is marked non-null but is null");
        }
        return this.metadata.get(source);
    }

    /**
     * Returns the live metadata map. It is not a copy, so modifications are visible to this data set.
     *
     * @return the backing map of source name to metadata
     */
    public Map<String, ObservationMetadata> getMetadata() {
        return this.metadata;
    }

    /**
     * @return the {@link NDArrayFactory} associated with this data set
     */
    public NDArrayFactory getNDArrayFactory() {
        return this.ndArrayFactory;
    }

    /**
     * Sets the {@link NDArrayFactory} associated with this data set.
     *
     * @param ndArrayFactory the factory to use
     * @return this data set, for chaining
     * @throws NullPointerException if {@code ndArrayFactory} is null
     */
    public DataSet setNDArrayFactory(@NonNull NDArrayFactory ndArrayFactory) {
        if (ndArrayFactory == null) {
            throw new NullPointerException("ndArrayFactory is marked non-null but is null");
        }
        this.ndArrayFactory = ndArrayFactory;
        return this;
    }

    /**
     * @return the streaming context of this data set's {@link DataSetType}
     */
    @JsonIgnore
    public StreamingContext getStreamingContext() {
        return getType().getStreamingContext();
    }

    /**
     * @return the type of this data set
     */
    @JsonIgnore
    public abstract DataSetType getType();

    /**
     * Applies the given function to every datum.
     *
     * @param function the transformation applied to each datum
     * @return the resulting data set (whether this is {@code this} or a new instance is implementation-defined)
     */
    public abstract DataSet map(@NonNull SerializableFunction<? super Datum, ? extends Datum> function);

    /**
     * @return a parallel stream over the datum in this data set
     */
    public abstract MStream<Datum> parallelStream();

    /**
     * Copies this data set into a {@link SQLiteDataSet} stored at the given resource. The new data set also
     * receives this data set's metadata and NDArray factory.
     *
     * @param resource where the SQLite data set is stored
     * @return the new SQLite-backed data set
     * @throws NullPointerException if {@code resource} is null
     */
    public DataSet persist(@NonNull Resource resource) {
        if (resource == null) {
            throw new NullPointerException("resource is marked non-null but is null");
        }
        return inheritStateInto(new SQLiteDataSet(resource, stream().javaStream()));
    }

    /**
     * Copies this data set into a {@link SQLiteDataSet} created without an explicit resource. Where it is stored
     * is decided by {@code SQLiteDataSet}; presumably a temporary location, so confirm there. The new data set
     * also receives this data set's metadata and NDArray factory.
     *
     * @return the new SQLite-backed data set
     */
    public DataSet persist() {
        // The cast keeps overload resolution on SQLiteDataSet's Stream<Datum> constructor.
        return inheritStateInto(new SQLiteDataSet((Stream<Datum>) stream().javaStream()));
    }

    /**
     * Copies this data set's metadata and NDArray factory onto {@code target}, in that order.
     */
    private DataSet inheritStateInto(DataSet target) {
        target.putAllMetadata(getMetadata());
        target.setNDArrayFactory(getNDArrayFactory());
        return target;
    }

    /**
     * Inspects up to {@link #PROBE_LIMIT} datum and records, for every observation source, the observation's
     * class and dimension in the metadata. The dimension is the NDArray length for NDArray observations and
     * {@code 0} otherwise.
     *
     * <p>NOTE(review): if one source yields several distinct (class, dimension) pairs, each is applied through
     * {@link #updateMetadata(String, Consumer)} and the last one applied wins. The order is not deterministic on
     * a parallel stream.
     *
     * @return this data set, for chaining
     */
    public DataSet probe() {
        parallelStream()
                .limit(PROBE_LIMIT)
                .flatMap(datum -> datum.entrySet().stream())
                .map(entry -> {
                    Observation observation = (Observation) entry.getValue();
                    long dimension = observation.isNDArray() ? observation.asNDArray().length() : 0L;
                    return Tuples.$((String) entry.getKey(), observation.getClass(), dimension);
                })
                .distinct()
                .forEach(tuple3 -> updateMetadata((String) tuple3.v1, observationMetadata -> {
                    observationMetadata.setType((Class) tuple3.v2);
                    observationMetadata.setDimension(((Long) tuple3.v3).longValue());
                }));
        return this;
    }

    /**
     * Adds all the given metadata entries, replacing any existing entries with the same source name.
     *
     * @param metadata the entries to add
     * @return this data set, for chaining
     * @throws NullPointerException if {@code metadata} is null
     */
    public DataSet putAllMetadata(@NonNull Map<String, ObservationMetadata> metadata) {
        if (metadata == null) {
            throw new NullPointerException("metadata is marked non-null but is null");
        }
        this.metadata.putAll(metadata);
        return this;
    }

    /**
     * Removes the metadata recorded for the given observation source, if any.
     *
     * @param source the observation source name
     * @return this data set, for chaining
     * @throws NullPointerException if {@code source} is null
     */
    public DataSet removeMetadata(@NonNull String source) {
        if (source == null) {
            throw new NullPointerException("source is marked non-null but is null");
        }
        this.metadata.remove(source);
        return this;
    }

    /**
     * Saves this data set under {@code resource}. Each datum is written as one line of JSON through a
     * {@link MultiFileWriter} using the file prefix {@code "part-"}. Once the writer is closed, the metadata is
     * written as JSON to {@code metadata.json}. Nothing is written when {@code saveMode.validate(resource)}
     * returns {@code false}.
     *
     * @param resource     the directory to write into (created via {@code mkdirs()})
     * @param numberOfFiles passed to {@code MultiFileWriter}; presumably the number of part files (confirm)
     * @param saveMode     decides whether saving may proceed for {@code resource}
     * @throws IOException          if writing the data or metadata fails
     * @throws NullPointerException if {@code resource} or {@code saveMode} is null
     */
    public void save(@NonNull Resource resource, int numberOfFiles, @NonNull SaveMode saveMode) throws IOException {
        if (resource == null) {
            throw new NullPointerException("resource is marked non-null but is null");
        }
        if (saveMode == null) {
            throw new NullPointerException("saveMode is marked non-null but is null");
        }
        if (!saveMode.validate(resource)) {
            return;
        }
        resource.mkdirs();
        // The writer is closed exactly once, before the metadata file is written.
        try (MultiFileWriter writer = new MultiFileWriter(resource, "part-", numberOfFiles)) {
            parallelStream().forEach(Unchecked.consumer(datum -> writer.write(Json.dumps(datum) + "\n")));
        }
        resource.getChild("metadata.json").write(Json.dumps(getMetadata()));
    }

    /**
     * Shuffles this data set using a newly created {@link Random}.
     *
     * @return the result of {@link #shuffle(Random)}
     */
    public DataSet shuffle() {
        return shuffle(new Random());
    }

    /**
     * Shuffles this data set using the given random number generator.
     *
     * @param random the source of randomness
     * @return the shuffled data set (whether this is {@code this} or a new instance is implementation-defined)
     */
    public abstract DataSet shuffle(Random random);

    /**
     * @return the number of datum in this data set
     */
    public abstract long size();

    /**
     * @return a stream over the datum in this data set
     */
    public abstract MStream<Datum> stream();

    /**
     * Returns the list produced by {@code stream().take(n)}. Presumably this holds at most {@code n} datum;
     * confirm with {@code MStream}.
     *
     * @param n how many datum to take
     * @return the taken datum
     */
    public List<Datum> take(int n) {
        return stream().take(n);
    }

    /**
     * Atomically updates the metadata for {@code source}. A new {@link ObservationMetadata} is created first if
     * none exists.
     *
     * @param source  the observation source name
     * @param updater mutates the (possibly newly created) metadata in place
     * @return this data set, for chaining
     * @throws NullPointerException if {@code source} or {@code updater} is null
     */
    public DataSet updateMetadata(@NonNull String source, @NonNull Consumer<ObservationMetadata> updater) {
        if (source == null) {
            throw new NullPointerException("source is marked non-null but is null");
        }
        if (updater == null) {
            throw new NullPointerException("updater is marked non-null but is null");
        }
        this.metadata.compute(source, (key, existing) -> {
            ObservationMetadata target = existing == null ? new ObservationMetadata() : existing;
            updater.accept(target);
            return target;
        });
        return this;
    }
}
