package com.gengoai.apollo.ml;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.gengoai.Validation;
import com.gengoai.apollo.math.linalg.NDArrayFactory;
import com.gengoai.function.SerializableFunction;
import com.gengoai.stream.MStream;
import com.gengoai.stream.StreamingContext;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Random;
import java.util.stream.IntStream;
import lombok.NonNull;

/* loaded from: input_file:com/gengoai/apollo/ml/InMemoryDataSet.class */
public class InMemoryDataSet extends DataSet {
    private static final long serialVersionUID = 1;

    @JsonProperty("data")
    private final List<Datum> data = new ArrayList();

    public InMemoryDataSet(@NonNull Collection<? extends Datum> collection) {
        if (collection == null) {
            throw new NullPointerException("data is marked non-null but is null");
        }
        this.data.addAll(collection);
    }

    @JsonCreator
    public InMemoryDataSet(@NonNull @JsonProperty("data") Collection<? extends Datum> collection, @NonNull @JsonProperty("metadata") Map<String, ObservationMetadata> map, @NonNull @JsonProperty("ndarrayFactory") NDArrayFactory nDArrayFactory) {
        if (collection == null) {
            throw new NullPointerException("data is marked non-null but is null");
        }
        if (map == null) {
            throw new NullPointerException("metadataMap is marked non-null but is null");
        }
        if (nDArrayFactory == null) {
            throw new NullPointerException("factory is marked non-null but is null");
        }
        this.data.addAll(collection);
        this.metadata.putAll(map);
        this.ndArrayFactory = nDArrayFactory;
    }

    @Override // com.gengoai.apollo.ml.DataSet
    public Iterator<DataSet> batchIterator(final int i) {
        Validation.checkArgument(i > 0, "Batch size must be > 0");
        return new Iterator<DataSet>() { // from class: com.gengoai.apollo.ml.InMemoryDataSet.1
            int index = 0;

            @Override // java.util.Iterator
            public boolean hasNext() {
                return this.index < InMemoryDataSet.this.data.size();
            }

            /* JADX WARN: Can't rename method to resolve collision */
            @Override // java.util.Iterator
            public DataSet next() {
                InMemoryDataSet inMemoryDataSet = new InMemoryDataSet(InMemoryDataSet.this.data.subList(this.index, Math.min(this.index + i, InMemoryDataSet.this.data.size())), InMemoryDataSet.this.getMetadata(), InMemoryDataSet.this.getNDArrayFactory());
                this.index += i;
                return inMemoryDataSet;
            }
        };
    }

    @Override // com.gengoai.apollo.ml.DataSet
    public DataSet cache() {
        return this;
    }

    @Override // com.gengoai.apollo.ml.DataSet
    public DataSetType getType() {
        return DataSetType.InMemory;
    }

    @Override // java.lang.Iterable
    public Iterator<Datum> iterator() {
        return this.data.iterator();
    }

    @Override // com.gengoai.apollo.ml.DataSet
    public DataSet map(SerializableFunction<? super Datum, ? extends Datum> serializableFunction) {
        IntStream.range(0, this.data.size()).forEach(i -> {
            this.data.set(i, (Datum) serializableFunction.apply(this.data.get(i)));
        });
        return this;
    }

    @Override // com.gengoai.apollo.ml.DataSet
    public MStream<Datum> parallelStream() {
        return StreamingContext.local().stream(this).parallel();
    }

    @Override // com.gengoai.apollo.ml.DataSet
    public DataSet shuffle() {
        Collections.shuffle(this.data);
        return this;
    }

    @Override // com.gengoai.apollo.ml.DataSet
    public DataSet shuffle(Random random) {
        Collections.shuffle(this.data, random);
        return this;
    }

    @Override // com.gengoai.apollo.ml.DataSet
    public long size() {
        return this.data.size();
    }

    @Override // com.gengoai.apollo.ml.DataSet
    public MStream<Datum> stream() {
        return StreamingContext.local().stream(this);
    }
}
