package org.tensorflow.framework.data;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.function.Function;
import org.tensorflow.Operand;
import org.tensorflow.framework.data.impl.BatchDataset;
import org.tensorflow.framework.data.impl.MapDataset;
import org.tensorflow.framework.data.impl.SkipDataset;
import org.tensorflow.framework.data.impl.TFRecordDataset;
import org.tensorflow.framework.data.impl.TakeDataset;
import org.tensorflow.framework.data.impl.TensorSliceDataset;
import org.tensorflow.framework.data.impl.TextLineDataset;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Ops;
import org.tensorflow.types.family.TType;

/* loaded from: input_file:org/tensorflow/framework/data/Dataset.class */
/**
 * Base class for a dataset represented by a variant {@link Operand} together with a per-component
 * description of its elements: one output type and one output shape per component.
 *
 * <p>Transformations ({@link #batch}, {@link #skip}, {@link #take}, {@link #map} and friends)
 * return new {@code Dataset} instances built from this dataset's variant, types and shapes.
 * Iterating over a {@code Dataset} yields one {@code List<Operand<?>>} per element, obtained from
 * a freshly created one-shot iterator.
 */
public abstract class Dataset implements Iterable<List<Operand<?>>> {
    /** Ops accessor used to create every operation built by this dataset. */
    protected Ops tf;
    private final Operand<?> variant;
    private final List<Class<? extends TType>> outputTypes;
    private final List<Shape> outputShapes;

    /**
     * Creates a dataset from its components.
     *
     * @param ops the Ops accessor used to build further operations; must not be null
     * @param operand the operand returned by {@link #getVariant()}
     * @param outputTypes the tensor type of each element component; must not be null
     * @param outputShapes the shape of each element component; must not be null and must have the
     *     same size as {@code outputTypes}
     * @throws IllegalArgumentException if {@code ops} is null or the two lists differ in size
     * @throws NullPointerException if {@code outputTypes} or {@code outputShapes} is null
     */
    public Dataset(Ops ops, Operand<?> operand, List<Class<? extends TType>> outputTypes,
            List<Shape> outputShapes) {
        if (ops == null) {
            throw new IllegalArgumentException("Ops accessor cannot be null.");
        }
        Objects.requireNonNull(outputTypes, "outputTypes");
        Objects.requireNonNull(outputShapes, "outputShapes");
        if (outputTypes.size() != outputShapes.size()) {
            throw new IllegalArgumentException("`outputTypes` and `outputShapes` must have the same size.");
        }
        this.tf = ops;
        this.variant = operand;
        this.outputTypes = outputTypes;
        this.outputShapes = outputShapes;
    }

    /**
     * Copy constructor: shares the Ops accessor, variant, output types and output shapes of
     * {@code dataset} (the lists are shared, not copied).
     *
     * <p>The decompiler reported that this was originally {@code protected}. It is kept
     * {@code public} here so that existing callers keep compiling.
     *
     * @param dataset the dataset whose state is shared
     */
    public Dataset(Dataset dataset) {
        this.tf = dataset.tf;
        this.variant = dataset.variant;
        this.outputTypes = dataset.outputTypes;
        this.outputShapes = dataset.outputShapes;
    }

    /**
     * Groups consecutive elements of this dataset into batches.
     *
     * <p>Each component's output shape gets a leading dimension of unknown size ({@code -1}).
     *
     * @param batchSize number of elements per batch, passed to {@code BatchDataset} as a constant
     * @param dropLastBatch flag passed to {@code BatchDataset} as a constant; presumably controls
     *     whether a trailing partial batch is dropped (TODO confirm against BatchDataset)
     * @return the batched dataset
     */
    public final Dataset batch(long batchSize, boolean dropLastBatch) {
        List<Shape> batchedShapes = new ArrayList<>(outputShapes.size());
        for (Shape shape : outputShapes) {
            batchedShapes.add(shape.prepend(-1L));
        }
        return new BatchDataset(tf, getVariant(), tf.constant(batchSize), tf.constant(dropLastBatch),
                outputTypes, batchedShapes);
    }

    /**
     * Equivalent to {@code batch(batchSize, false)}.
     *
     * @param batchSize number of elements per batch
     * @return the batched dataset
     */
    public final Dataset batch(long batchSize) {
        return batch(batchSize, false);
    }

    /**
     * Wraps this dataset in a {@code SkipDataset} with the given count; types and shapes are unchanged.
     *
     * @param count number of elements to skip
     * @return the resulting dataset
     */
    public final Dataset skip(long count) {
        return new SkipDataset(tf, getVariant(), tf.constant(count), getOutputTypes(), getOutputShapes());
    }

    /**
     * Wraps this dataset in a {@code TakeDataset} with the given count; types and shapes are unchanged.
     *
     * @param count number of elements to take
     * @return the resulting dataset
     */
    public final Dataset take(long count) {
        return new TakeDataset(tf, getVariant(), tf.constant(count), getOutputTypes(), getOutputShapes());
    }

    /**
     * Maps a single component of each element, leaving the other components untouched.
     *
     * @param index position of the component to transform
     * @param mapper transformation applied to that component
     * @return the mapped dataset
     */
    public Dataset mapOneComponent(int index, Function<Operand<?>, Operand<?>> mapper) {
        return map(components -> {
            // Copy so the caller-provided component list is never mutated.
            List<Operand<?>> mapped = new ArrayList<>(components);
            mapped.set(index, mapper.apply(components.get(index)));
            return mapped;
        });
    }

    /**
     * Applies the same transformation to every component of each element.
     *
     * @param mapper transformation applied to each component
     * @return the mapped dataset
     */
    public Dataset mapAllComponents(Function<Operand<?>, Operand<?>> mapper) {
        return map(components -> {
            List<Operand<?>> mapped = new ArrayList<>(components.size());
            for (Operand<?> component : components) {
                mapped.add(mapper.apply(component));
            }
            return mapped;
        });
    }

    /**
     * Creates a {@code MapDataset} from this dataset and the given element-level transformation.
     *
     * @param mapper function from an element's component list to a new component list
     * @return the mapped dataset
     */
    public Dataset map(Function<List<Operand<?>>, List<Operand<?>>> mapper) {
        return new MapDataset(this, mapper);
    }

    /**
     * Returns an iterator over this dataset's elements, backed by a new one-shot iterator.
     */
    @Override
    public Iterator<List<Operand<?>>> iterator() {
        return makeOneShotIterator().iterator();
    }

    /**
     * Creates a {@link DatasetIterator} matching this dataset's structure and registers this
     * dataset as its initializer via {@code makeInitializer(this)}.
     *
     * @return the new iterator
     */
    public DatasetIterator makeInitializeableIterator() {
        DatasetIterator iterator = DatasetIterator.fromStructure(tf, outputTypes, outputShapes);
        iterator.makeInitializer(this);
        return iterator;
    }

    /**
     * Creates a {@link DatasetIterator} that is already initialized on this dataset.
     *
     * @return the new iterator
     */
    public DatasetIterator makeOneShotIterator() {
        // makeInitializeableIterator() already calls makeInitializer(this); calling it again here
        // would only create/run a redundant initializer.
        return makeInitializeableIterator();
    }

    /**
     * Creates a {@code TensorSliceDataset} from the given tensors and component types.
     *
     * @param ops Ops accessor
     * @param tensors the tensors to slice
     * @param outputTypes the type of each component
     * @return the new dataset
     */
    public static Dataset fromTensorSlices(Ops ops, List<Operand<?>> tensors,
            List<Class<? extends TType>> outputTypes) {
        return new TensorSliceDataset(ops, tensors, outputTypes);
    }

    /**
     * Creates a {@code TFRecordDataset}. Each argument is wrapped in {@code ops.constant(...)}.
     *
     * <p>NOTE(review): argument meanings are assumed to follow the TensorFlow TFRecordDataset op
     * (filenames, compression_type, buffer_size) — confirm against the impl constructor.
     *
     * @param ops Ops accessor
     * @param filename file name passed to the op
     * @param compressionType compression type passed to the op
     * @param bufferSize buffer size passed to the op
     * @return the new dataset
     */
    public static Dataset tfRecordDataset(Ops ops, String filename, String compressionType, long bufferSize) {
        return new TFRecordDataset(ops, ops.constant(filename), ops.constant(compressionType),
                ops.constant(bufferSize));
    }

    /**
     * Creates a {@code TextLineDataset}. Each argument is wrapped in {@code ops.constant(...)}.
     *
     * <p>NOTE(review): argument meanings are assumed to follow the TensorFlow TextLineDataset op
     * (filenames, compression_type, buffer_size) — confirm against the impl constructor.
     *
     * @param ops Ops accessor
     * @param filename file name passed to the op
     * @param compressionType compression type passed to the op
     * @param bufferSize buffer size passed to the op
     * @return the new dataset
     */
    public static Dataset textLineDataset(Ops ops, String filename, String compressionType, long bufferSize) {
        return new TextLineDataset(ops, ops.constant(filename), ops.constant(compressionType),
                ops.constant(bufferSize));
    }

    /** Returns the variant operand this dataset was constructed with. */
    public Operand<?> getVariant() {
        return variant;
    }

    /** Returns the per-component output types (the list passed in at construction, not a copy). */
    public List<Class<? extends TType>> getOutputTypes() {
        return outputTypes;
    }

    /** Returns the per-component output shapes (the list passed in at construction, not a copy). */
    public List<Shape> getOutputShapes() {
        return outputShapes;
    }

    /** Returns the Ops accessor used by this dataset. */
    public Ops getOpsInstance() {
        return tf;
    }

    /** Renders the simple names of the output types and the string form of each output shape. */
    @Override
    public String toString() {
        String types = Arrays.toString(getOutputTypes().stream().map(Class::getSimpleName).toArray());
        String shapes = Arrays.toString(getOutputShapes().stream().map(Shape::toString).toArray());
        return "Dataset{outputTypes=" + types + ", outputShapes=" + shapes + "}";
    }
}
