package com.gengoai.apollo.ml;

import com.gengoai.Copyable;
import com.gengoai.apollo.math.linalg.NDArrayFactory;
import com.gengoai.function.SerializableFunction;
import com.gengoai.stream.MStream;
import com.gengoai.stream.StorageLevel;
import java.lang.invoke.SerializedLambda;
import java.util.Iterator;
import java.util.Map;
import java.util.Random;
import lombok.NonNull;

/* loaded from: input_file:com/gengoai/apollo/ml/StreamingDataSet.class */
/**
 * A {@link DataSet} whose examples are read from an {@link MStream} of {@link Datum}. Whether the
 * dataset is local or distributed is decided by the backing stream's context (see
 * {@link #getType()}).
 */
public class StreamingDataSet extends DataSet {
    /** Backing stream of examples; reassigned by {@link #shuffle(Random)}. */
    private MStream<Datum> stream;

    /**
     * Creates a dataset over the given stream without copying any metadata or setting an NDArray
     * factory.
     *
     * @param stream the stream of examples backing this dataset
     * @throws NullPointerException if {@code stream} is null
     */
    public StreamingDataSet(@NonNull MStream<Datum> stream) {
        if (stream == null) {
            throw new NullPointerException("stream is marked non-null but is null");
        }
        this.stream = stream;
    }

    /**
     * Creates a dataset over the given stream, copying every entry of {@code metadataMap} into this
     * dataset's metadata and using {@code factory} as its NDArray factory.
     *
     * @param stream      the stream of examples backing this dataset
     * @param metadataMap observation metadata whose entries are copied (the map itself is not retained)
     * @param factory     the NDArray factory for this dataset
     * @throws NullPointerException if any argument is null
     */
    public StreamingDataSet(@NonNull MStream<Datum> stream,
                            @NonNull Map<String, ObservationMetadata> metadataMap,
                            @NonNull NDArrayFactory factory) {
        if (stream == null) {
            throw new NullPointerException("stream is marked non-null but is null");
        }
        if (metadataMap == null) {
            throw new NullPointerException("metadataMap is marked non-null but is null");
        }
        if (factory == null) {
            throw new NullPointerException("factory is marked non-null but is null");
        }
        this.stream = stream;
        this.metadata.putAll(metadataMap);
        this.ndArrayFactory = factory;
    }

    /**
     * Iterates over this dataset in batches. The backing stream is split with
     * {@link MStream#partition} (presumably into chunks of at most {@code batchSize} examples; TODO
     * confirm against MStream). Each chunk is re-streamed in this dataset's streaming context, each
     * datum is copied via {@code Datum#copy()}, and the result is wrapped with
     * {@link #datasetOf(MStream)}.
     *
     * @param batchSize the partition size passed to {@link MStream#partition}
     * @return an iterator over the batch datasets
     */
    @Override
    public Iterator<DataSet> batchIterator(int batchSize) {
        // The explicit <DataSet> witness makes the mapped stream an MStream<DataSet>, matching the
        // declared Iterator<DataSet> return type, instead of inferring MStream<StreamingDataSet>.
        return stream.partition(batchSize)
                     .<DataSet>map(partition -> datasetOf(getType().getStreamingContext()
                                                                   .stream(partition)
                                                                   .map(Datum::copy)))
                     .iterator();
    }

    /**
     * Collects the backing stream into an {@link InMemoryDataSet} and gives it a deep copy of this
     * dataset's metadata.
     *
     * @return the in-memory copy
     */
    @Override
    public DataSet cache() {
        InMemoryDataSet cached = new InMemoryDataSet(stream.collect());
        copyMetadataInto(cached);
        // NOTE(review): unlike datasetOf, the NDArray factory is not carried over here; confirm
        // whether InMemoryDataSet should inherit it.
        return cached;
    }

    /**
     * Wraps {@code source} in a new {@link StreamingDataSet}. The new dataset receives this
     * dataset's metadata entries (copied into its own map) and this dataset's NDArray factory.
     *
     * @param source the stream of examples for the new dataset
     * @return the new dataset
     */
    protected StreamingDataSet datasetOf(MStream<Datum> source) {
        return new StreamingDataSet(source, getMetadata(), getNDArrayFactory());
    }

    /**
     * Reports whether the dataset is distributed or local streaming, based on whether the backing
     * stream's context is distributed.
     *
     * @return {@link DataSetType#Distributed} or {@link DataSetType#LocalStreaming}
     */
    @Override
    public DataSetType getType() {
        return stream.getContext().isDistributed() ? DataSetType.Distributed : DataSetType.LocalStreaming;
    }

    /**
     * @return the backing stream's iterator
     */
    @Override
    public Iterator<Datum> iterator() {
        return stream.iterator();
    }

    /**
     * Lazily applies {@code function} to every datum and returns a new {@link StreamingDataSet} over
     * the mapped stream. The new dataset gets a deep copy of this dataset's metadata and shares this
     * dataset's NDArray factory.
     *
     * @param function the per-datum transformation
     * @return the mapped dataset
     * @throws NullPointerException if {@code function} is null
     */
    @Override
    public DataSet map(@NonNull SerializableFunction<? super Datum, ? extends Datum> function) {
        if (function == null) {
            throw new NullPointerException("function is marked non-null but is null");
        }
        StreamingDataSet mapped = new StreamingDataSet(stream.map(function));
        copyMetadataInto(mapped);
        // The single-argument constructor does not set the factory; carry it over so the mapped
        // dataset is consistent with those produced by datasetOf(MStream).
        mapped.ndArrayFactory = this.ndArrayFactory;
        return mapped;
    }

    /**
     * @return the backing stream switched to parallel mode via {@code MStream#parallel()}
     */
    @Override
    public MStream<Datum> parallelStream() {
        return stream.parallel();
    }

    /**
     * For a non-distributed stream, defers to {@link DataSet#persist()}. Otherwise, asks the backing
     * stream to persist itself with {@link StorageLevel#OnDisk} and returns this dataset.
     *
     * @return the persisted dataset
     */
    @Override
    public DataSet persist() {
        if (!stream.isDistributed()) {
            return super.persist();
        }
        // NOTE(review): the result of MStream#persist is ignored; this assumes persist marks the
        // existing stream in place. Confirm against the MStream implementation.
        stream.persist(StorageLevel.OnDisk);
        return this;
    }

    /**
     * Shuffles in place by replacing the backing stream with its shuffled version.
     *
     * @param random the source of randomness for the shuffle
     * @return this dataset (mutated)
     */
    @Override
    public DataSet shuffle(Random random) {
        stream = stream.shuffle(random);
        return this;
    }

    /**
     * @return the number of examples, computed via {@code MStream#count()} on each call
     */
    @Override
    public long size() {
        return stream.count();
    }

    /**
     * @return the backing stream itself (not a copy)
     */
    @Override
    public MStream<Datum> stream() {
        return stream;
    }

    /** Puts a deep copy of this dataset's metadata entries into {@code target}'s metadata map. */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private void copyMetadataInto(DataSet target) {
        target.metadata.putAll((Map) Copyable.deepCopy(this.metadata));
    }
}
