package org.tensorflow.framework.data.impl;

import java.util.List;
import java.util.stream.Collectors;
import org.tensorflow.Operand;
import org.tensorflow.framework.data.Dataset;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Ops;
import org.tensorflow.types.family.TType;

/* loaded from: input_file:org/tensorflow/framework/data/impl/TensorSliceDataset.class */
/**
 * A {@link Dataset} built from slices of a list of component tensors.
 *
 * <p>Each component's element shape is {@code shape().tail()} of the corresponding input tensor,
 * i.e. the tensor's shape with its leading dimension removed. The dataset op itself is created via
 * {@code ops.data.tensorSliceDataset}.
 */
public class TensorSliceDataset extends Dataset {

    /**
     * Creates a dataset from slices of the given tensors.
     *
     * @param ops the {@link Ops} instance used to build the underlying dataset op
     * @param tensors the component tensors to slice, one per dataset component. Presumably they
     *     must all share the same leading dimension size; that is not checked here
     *     (TODO confirm the underlying op enforces it).
     * @param dtypes the element type of each component, one entry per element of {@code tensors}
     * @throws IllegalArgumentException if {@code tensors} and {@code dtypes} differ in size
     */
    public TensorSliceDataset(Ops ops, List<Operand<?>> tensors, List<Class<? extends TType>> dtypes) {
        // Java evaluates arguments left to right, so makeVariant (and its size check) runs
        // before outputShapes is computed for the superclass.
        super(ops, makeVariant(ops, tensors, dtypes), dtypes, outputShapes(tensors));
    }

    /**
     * Computes the per-element shape of each component.
     *
     * @param tensors the component tensors
     * @return a list with one shape per tensor: that tensor's shape without its first dimension
     */
    private static List<Shape> outputShapes(List<Operand<?>> tensors) {
        // Collectors.toList() already yields List<Shape>; no raw-type cast is needed.
        return tensors.stream()
                .map(tensor -> tensor.shape().tail())
                .collect(Collectors.toList());
    }

    /**
     * Validates the inputs and builds the tensor-slice dataset op.
     *
     * @param ops the {@link Ops} instance used to create the op
     * @param tensors the component tensors to slice
     * @param dtypes the component element types; only the count is checked here
     * @return the operand produced by {@code ops.data.tensorSliceDataset}
     * @throws IllegalArgumentException if {@code tensors} and {@code dtypes} differ in size
     */
    private static Operand<?> makeVariant(
            Ops ops, List<Operand<?>> tensors, List<Class<? extends TType>> dtypes) {
        if (tensors.size() != dtypes.size()) {
            throw new IllegalArgumentException("Lists `tensors` and `dtypes` must have the same number of elements.");
        }
        return ops.data.tensorSliceDataset(tensors, outputShapes(tensors));
    }
}
