package org.platanios.tensorflow.api.ops.io.data;

import org.platanios.tensorflow.api.core.Shape;
import org.platanios.tensorflow.api.core.Shape$;
import org.platanios.tensorflow.api.core.package$exception$;
import org.platanios.tensorflow.api.core.package$exception$InvalidDataTypeException;
import org.platanios.tensorflow.api.core.package$exception$InvalidDataTypeException$;
import org.platanios.tensorflow.api.core.package$exception$InvalidShapeException;
import org.platanios.tensorflow.api.core.package$exception$InvalidShapeException$;
import org.platanios.tensorflow.api.implicits.helpers.DataTypeAuxToDataType;
import org.platanios.tensorflow.api.implicits.helpers.OutputToTensor;
import org.platanios.tensorflow.api.implicits.helpers.OutputToTensor$;
import org.platanios.tensorflow.api.ops.Callback$;
import org.platanios.tensorflow.api.ops.Callback$ArgType$;
import org.platanios.tensorflow.api.ops.Function;
import org.platanios.tensorflow.api.ops.Function$ArgType$;
import org.platanios.tensorflow.api.ops.Op;
import org.platanios.tensorflow.api.ops.Output;
import org.platanios.tensorflow.api.ops.io.data.Dataset;
import org.platanios.tensorflow.api.ops.io.data.FlatMapDataset;
import org.platanios.tensorflow.api.ops.io.data.MapDataset;
import org.platanios.tensorflow.api.ops.io.data.RepeatDataset;
import org.platanios.tensorflow.api.tensors.Tensor;
import org.platanios.tensorflow.api.tensors.Tensor$;
import org.platanios.tensorflow.api.tensors.TensorConvertible$;
import org.platanios.tensorflow.api.types.DataType;
import org.platanios.tensorflow.api.types.SupportedType$;
import org.platanios.tensorflow.jni.OutOfRangeException;
import scala.Function0;
import scala.Predef$;
import scala.Tuple2;
import scala.Tuple3;
import scala.collection.Iterable;
import scala.collection.IterableLike;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.Null$;
import scala.runtime.Tuple3Zipped$;
import scala.runtime.Tuple3Zipped$Ops$;

/* compiled from: Dataset.scala */
/* loaded from: input_file:org/platanios/tensorflow/api/ops/io/data/Dataset$.class */
/**
 * Decompiled module class of the Scala companion {@code object Dataset} (compiled from
 * Dataset.scala). It provides {@code fromGenerator}, which builds a dataset whose elements are
 * pulled from a Scala generator function, and {@code datasetPaddedBatch}, which validates its
 * arguments and constructs a {@code PaddedBatchDataset} op.
 *
 * <p>NOTE(review): this looks like JADX decompiler output rather than hand-written Java (see the
 * JADX markers below and the Scala-mangled member names); presumably changes should be made in
 * Dataset.scala instead — confirm before editing this file.
 */
public final class Dataset$ {
    /** Singleton instance; assigned by the private constructor, which the static initializer runs. */
    public static Dataset$ MODULE$;

    static {
        new Dataset$();
    }

    /**
     * Returns the constant {@code "Dataset"}. By Scala's name mangling ({@code <init>} is encoded as
     * {@code $lessinit$greater}), this is presumably the default value of the first constructor
     * parameter of {@code Dataset}.
     */
    public <T, O, D, S> String $lessinit$greater$default$1() {
        return "Dataset";
    }

    /**
     * Creates a dataset whose elements are produced by the generator function {@code function0}.
     *
     * <p>The pipeline built here is:
     * <ol>
     *   <li>a {@code TensorDataset} over an INT64 tensor constructed from the value 0;</li>
     *   <li>mapped through a host callback that returns {@code generatorState.nextId()} as an INT64
     *       tensor (presumably so each pass over the dataset gets its own generator iterator via
     *       {@code GeneratorState.getIterator} — confirm);</li>
     *   <li>flat-mapped with {@code flatMapFn$1}, which repeats that id and pulls one generator
     *       element per step through {@code generatorMapFn$1}.</li>
     * </ol>
     *
     * @param function0 function returning the {@code Iterable} whose elements become dataset elements
     * @param da element data type(s), converted with {@code dataTypeAuxToDataType.castDataType}
     * @param s element shape(s); may be {@code null}, in which case one unknown shape is used per
     *     flattened component and no static shapes are set on the generated outputs
     * @param dataTypeAuxToDataType evidence used to convert {@code da}
     * @param data structure helper used to flatten/unflatten data types, shapes, tensors and outputs
     * @param outputToTensor evidence forwarded to the final {@code map}/{@code flatMap} calls
     * @param argType evidence forwarded to the final {@code map}/{@code flatMap} calls
     * @return the generator-backed dataset
     */
    /* JADX WARN: Multi-variable type inference failed */
    public <T, O, DA, D, S> Dataset<T, O, D, S> fromGenerator(Function0<Iterable<T>> function0, DA da, S s, DataTypeAuxToDataType<DA> dataTypeAuxToDataType, Data<T> data, OutputToTensor<O> outputToTensor, Function.ArgType<O> argType) {
        Object castDataType = dataTypeAuxToDataType.castDataType(da);
        // Use the caller's shapes if given; otherwise one unknown shape per flattened component.
        Object unflattenShapes = s != null ? s : data.unflattenShapes(castDataType, (Seq) Seq$.MODULE$.fill(data.size(castDataType), () -> {
            return Shape$.MODULE$.unknown(Shape$.MODULE$.unknown$default$1());
        }));
        // Flat per-component views, used to validate each generated tensor.
        Seq<DataType> flattenedDataTypes = data.flattenedDataTypes(castDataType);
        Seq<Shape> flattenedShapes = data.flattenedShapes(unflattenShapes);
        // State shared by the host callbacks below (id allocation and iterator lookup).
        Dataset.GeneratorState generatorState = new Dataset.GeneratorState(function0, data);
        // Seed dataset over an INT64 tensor built from the value 0.
        MapDataset.MapDatasetOps<T, O, D, S> datasetToMapDatasetOps = org.platanios.tensorflow.api.package$.MODULE$.datasetToMapDatasetOps(new TensorDataset(Tensor$.MODULE$.apply(org.platanios.tensorflow.api.types.package$.MODULE$.INT64(), BoxesRunTime.boxToInteger(0), Predef$.MODULE$.wrapIntArray(new int[0]), TensorConvertible$.MODULE$.supportedTypeTensorConvertible(SupportedType$.MODULE$.intIsSupportedType())), TensorDataset$.MODULE$.apply$default$2(), org.platanios.tensorflow.api.package$.MODULE$.tensorDataHelper(), OutputToTensor$.MODULE$.outputToTensor(), Function$ArgType$.MODULE$.outputArgType()));
        // Replace each seed element with a new id from generatorState.nextId(), via a host callback.
        FlatMapDataset.FlatMapDatasetOps datasetToFlatMapDatasetOps = org.platanios.tensorflow.api.package$.MODULE$.datasetToFlatMapDatasetOps(datasetToMapDatasetOps.map(output -> {
            return (Output) Callback$.MODULE$.callback(boxedUnit -> {
                return Tensor$.MODULE$.apply(org.platanios.tensorflow.api.types.package$.MODULE$.INT64(), BoxesRunTime.boxToLong(generatorState.nextId()), Predef$.MODULE$.wrapLongArray(new long[0]), TensorConvertible$.MODULE$.supportedTypeTensorConvertible(SupportedType$.MODULE$.longIsSupportedType()));
            }, BoxedUnit.UNIT, org.platanios.tensorflow.api.types.package$.MODULE$.INT64(), true, Callback$.MODULE$.callback$default$5(), Callback$ArgType$.MODULE$.unitArgType(), Callback$ArgType$.MODULE$.tensorArgType());
        }, datasetToMapDatasetOps.map$default$2(), datasetToMapDatasetOps.map$default$3(), OutputToTensor$.MODULE$.outputToTensor(), org.platanios.tensorflow.api.package$.MODULE$.tensorDataHelper(), Function$ArgType$.MODULE$.outputArgType()));
        // Expand each id into the stream of generator elements.
        return datasetToFlatMapDatasetOps.flatMap(output2 -> {
            return flatMapFn$1(output2, s, data, outputToTensor, argType, castDataType, flattenedDataTypes, flattenedShapes, generatorState);
        }, datasetToFlatMapDatasetOps.flatMap$default$2(), outputToTensor, data, argType);
    }

    /**
     * Default value of {@code fromGenerator}'s third parameter {@code s} (the shapes), which is
     * {@code null}, meaning "shapes not provided".
     */
    public <T, O, DA, D, S> Null$ fromGenerator$default$3() {
        return null;
    }

    /**
     * Validates its arguments and creates a {@code PaddedBatchDataset} op, returning the op's first
     * output.
     *
     * <p>Rank {@code -1} is accepted wherever a rank is checked (presumably meaning "unknown rank").
     * The data type of the padding values is not checked here.
     *
     * @param output first op input; presumably the handle of the dataset to batch — confirm
     * @param output2 {@code batchSize}: must be an INT64 tensor of rank 0
     * @param seq {@code paddedShapes}: INT64 tensors of rank 1
     * @param seq2 {@code paddingValues}: tensors of rank 0; must have as many entries as {@code seq}
     * @param seq3 {@code outputShapes}: stored as the {@code output_shapes} attribute; must have as
     *     many entries as {@code seq}
     * @param str name passed to the op builder
     * @return output 0 of the created op
     * @throws IllegalArgumentException if any of the checks above fails
     */
    public Output datasetPaddedBatch(Output output, Output output2, Seq<Output> seq, Seq<Output> seq2, Seq<Shape> seq3, String str) throws IllegalArgumentException {
        DataType dataType = output2.dataType();
        DataType.Aux<Object> INT64 = org.platanios.tensorflow.api.types.package$.MODULE$.INT64();
        // Null-safe "dataType != INT64" (the compiled form of Scala's !=).
        if (dataType != null ? !dataType.equals(INT64) : INT64 != null) {
            throw new IllegalArgumentException(new StringBuilder(50).append("'batchSize' (dataType = ").append(output2.dataType()).append(") must be an INT64 tensor.").toString());
        }
        // Rank -1 passes; any positive rank is rejected.
        if (output2.rank() != -1 && output2.rank() > 0) {
            throw new IllegalArgumentException(new StringBuilder(41).append("'batchSize' (rank = ").append(output2.rank()).append(") must be equal to 0.").toString());
        }
        if (seq.exists(output3 -> {
            return BoxesRunTime.boxToBoolean($anonfun$datasetPaddedBatch$1(output3));
        })) {
            throw new IllegalArgumentException("'paddedShapes' must all be INT64 tensors.");
        }
        if (seq.exists(output4 -> {
            return BoxesRunTime.boxToBoolean($anonfun$datasetPaddedBatch$2(output4));
        })) {
            throw new IllegalArgumentException("'paddedShapes' must all be vector tensors (i.e., must have rank 1).");
        }
        if (seq2.exists(output5 -> {
            return BoxesRunTime.boxToBoolean($anonfun$datasetPaddedBatch$3(output5));
        })) {
            throw new IllegalArgumentException("'paddingValues' must all be scalar tensors (i.e., must have rank 0).");
        }
        if (seq.size() != seq2.size()) {
            throw new IllegalArgumentException(new StringBuilder(99).append("'paddedShapes' (number = ").append(seq.size()).append(") and 'paddingValues' (number = ").append(seq2.size()).append(") must ").append("contain the same number of tensors.").toString());
        }
        if (seq.size() != seq3.size()) {
            throw new IllegalArgumentException(new StringBuilder(98).append("'paddedShapes' (number = ").append(seq.size()).append(") and 'outputShapes' (number = ").append(seq3.size()).append(") must ").append("contain the same number of tensors.").toString());
        }
        // Inputs: dataset, batch size, padded shapes (list), padding values (list).
        return new Op.Builder("PaddedBatchDataset", str).addInput(output).addInput(output2).addInputList(seq).addInputList(seq2).setAttribute("output_shapes", (Shape[]) seq3.toArray(ClassTag$.MODULE$.apply(Shape.class))).build().outputs()[0];
    }

    /**
     * Default value of {@code datasetPaddedBatch}'s sixth parameter {@code str} (the op name).
     * Note that it differs from the op type string {@code "PaddedBatchDataset"}.
     */
    public String datasetPaddedBatch$default$6() {
        return "DatasetPaddedBatch";
    }

    /**
     * Checks one generated tensor against its expected data type and shape.
     *
     * @throws package$exception$InvalidDataTypeException if {@code tensor}'s data type is not
     *     {@code dataType}
     * @throws package$exception$InvalidShapeException if {@code tensor}'s shape is not compatible
     *     with {@code shape}
     */
    public static final /* synthetic */ void $anonfun$fromGenerator$2(Tensor tensor, DataType dataType, Shape shape) {
        DataType dataType2 = tensor.dataType();
        if (dataType2 != null ? !dataType2.equals(dataType) : dataType != null) {
            throw new package$exception$InvalidDataTypeException(new StringBuilder(42).append("The generator yielded an element of type ").append(tensor.dataType()).append(" ").append(new StringBuilder(39).append("where an element of type ").append(dataType).append(" was expected.").toString()).toString(), package$exception$InvalidDataTypeException$.MODULE$.apply$default$2());
        }
        if (!tensor.shape().isCompatibleWith(shape)) {
            throw new package$exception$InvalidShapeException(new StringBuilder(45).append("The generator yielded an element with shape ").append(tensor.shape()).append(" ").append(new StringBuilder(42).append("where an element with shape ").append(shape).append(" was expected.").toString()).toString(), package$exception$InvalidShapeException$.MODULE$.apply$default$2());
        }
    }

    /**
     * Host-side callback body: pulls the next generator element for a generator id, flattens it,
     * and validates each component.
     *
     * @param tensor tensor whose scalar value (unboxed as a {@code long}) selects the iterator from
     *     {@code generatorState}
     * @param data helper used to flatten the generated element
     * @param seq expected data type of each flattened component
     * @param seq2 expected shape of each flattened component
     * @param generatorState shared generator state
     * @return the flattened tensors of the next generator element
     * @throws OutOfRangeException when the iterator has no more elements (the exception is built by
     *     {@code package$exception$.MODULE$.OutOfRangeException()}; presumably this is how the
     *     dataset signals end of input — confirm)
     */
    /* JADX INFO: Access modifiers changed from: private */
    public static final Seq generatorScalaCallback$1(Tensor tensor, Data data, Seq seq, Seq seq2, Dataset.GeneratorState generatorState) throws OutOfRangeException {
        scala.collection.Iterator iterator = generatorState.getIterator(BoxesRunTime.unboxToLong(tensor.scalar()));
        if (!iterator.hasNext()) {
            throw package$exception$.MODULE$.OutOfRangeException().apply("The iterator does not contain any more elements.");
        }
        Seq<Tensor> flattenedTensors = data.flattenedTensors(iterator.next());
        // Validate each (tensor, expected type, expected shape) triple.
        // NOTE(review): Scala's Tuple3Zipped presumably stops at the shortest of the three
        // sequences, so a generator element with a different number of components than expected
        // may not be rejected here — confirm and consider an explicit size check in Dataset.scala.
        Tuple3Zipped$.MODULE$.foreach$extension(Tuple3Zipped$Ops$.MODULE$.zipped$extension(Predef$.MODULE$.tuple3ToZippedOps(new Tuple3(flattenedTensors, seq, seq2)), Predef$.MODULE$.$conforms(), Predef$.MODULE$.$conforms(), Predef$.MODULE$.$conforms()), (tensor2, dataType, shape) -> {
            $anonfun$fromGenerator$2(tensor2, dataType, shape);
            return BoxedUnit.UNIT;
        });
        return flattenedTensors;
    }

    /**
     * Sets the static shape of one generated output to its user-provided shape.
     *
     * @param tuple2 pair of ({@code Output}, expected {@code Shape})
     * @throws IllegalArgumentException if the output's current shape is not compatible with the
     *     provided shape
     */
    public static final /* synthetic */ void $anonfun$fromGenerator$4(Tuple2 tuple2) {
        if (!((Output) tuple2._1()).shape().isCompatibleWith((Shape) tuple2._2())) {
            throw new IllegalArgumentException(new StringBuilder(63).append("Generator output shape ").append(((Output) tuple2._1()).shape()).append(" is not compatible with provided shape ").append(tuple2._2()).append(".").toString());
        }
        ((Output) tuple2._1()).setShape((Shape) tuple2._2());
    }

    /**
     * Graph-side map function: turns a generator-id output into the structured outputs of one
     * generator element, produced by a host callback around {@code generatorScalaCallback$1}.
     *
     * @param output generator-id output fed to the callback
     * @param obj user-provided shapes ({@code s} in {@code fromGenerator}); when non-null, each
     *     callback output's static shape is checked and set from {@code seq2}
     * @param data helper used to unflatten the callback outputs
     * @param obj2 resolved element data type(s) used as the unflatten template
     * @param seq flattened data types (also the callback's output types)
     * @param seq2 flattened shapes
     * @param generatorState shared generator state
     * @return the unflattened element outputs
     */
    /* JADX INFO: Access modifiers changed from: private */
    public static final Object generatorMapFn$1(Output output, Object obj, Data data, Object obj2, Seq seq, Seq seq2, Dataset.GeneratorState generatorState) {
        Seq<Output> seq3 = (Seq) Callback$.MODULE$.callback(tensor -> {
            return generatorScalaCallback$1(tensor, data, seq, seq2, generatorState);
        }, output, seq, true, Callback$.MODULE$.callback$default$5(), Callback$ArgType$.MODULE$.tensorArgType(), Callback$ArgType$.MODULE$.tensorSeqArgType(Seq$.MODULE$.canBuildFrom(), Seq$.MODULE$.canBuildFrom()));
        // Static shapes are only set when the caller supplied shapes explicitly.
        // NOTE(review): Seq.zip truncates to the shorter sequence; presumably the lengths always
        // match here because the callback's outputs come from seq — confirm.
        if (obj != null) {
            ((IterableLike) seq3.zip(seq2, Seq$.MODULE$.canBuildFrom())).foreach(tuple2 -> {
                $anonfun$fromGenerator$4(tuple2);
                return BoxedUnit.UNIT;
            });
        }
        return data.unflattenOutputs(obj2, seq3);
    }

    /**
     * Flat-map function: wraps a generator id in an {@code OutputDataset}, repeats it with the
     * default {@code repeat} arguments, and maps each repetition through {@code generatorMapFn$1}.
     * Presumably the repeat is unbounded and the stream ends when {@code generatorScalaCallback$1}
     * throws {@code OutOfRangeException} — confirm.
     *
     * @return the dataset of generator elements for this id
     */
    /* JADX INFO: Access modifiers changed from: private */
    public static final Dataset flatMapFn$1(Output output, Object obj, Data data, OutputToTensor outputToTensor, Function.ArgType argType, Object obj2, Seq seq, Seq seq2, Dataset.GeneratorState generatorState) {
        org.platanios.tensorflow.api.package$ package_ = org.platanios.tensorflow.api.package$.MODULE$;
        RepeatDataset.RepeatDatasetOps datasetToRepeatDatasetOps = org.platanios.tensorflow.api.package$.MODULE$.datasetToRepeatDatasetOps(new OutputDataset(output, OutputDataset$.MODULE$.apply$default$2(), OutputToTensor$.MODULE$.outputToTensor(), org.platanios.tensorflow.api.package$.MODULE$.tensorDataHelper(), Function$ArgType$.MODULE$.outputArgType()));
        MapDataset.MapDatasetOps datasetToMapDatasetOps = package_.datasetToMapDatasetOps(datasetToRepeatDatasetOps.repeat(datasetToRepeatDatasetOps.repeat$default$1(), datasetToRepeatDatasetOps.repeat$default$2()));
        return datasetToMapDatasetOps.map(output2 -> {
            return generatorMapFn$1(output2, obj, data, obj2, seq, seq2, generatorState);
        }, datasetToMapDatasetOps.map$default$2(), datasetToMapDatasetOps.map$default$3(), outputToTensor, data, argType);
    }

    /** Predicate for {@code datasetPaddedBatch}: {@code true} when {@code output}'s data type is not INT64. */
    public static final /* synthetic */ boolean $anonfun$datasetPaddedBatch$1(Output output) {
        DataType dataType = output.dataType();
        DataType.Aux<Object> INT64 = org.platanios.tensorflow.api.types.package$.MODULE$.INT64();
        return dataType != null ? !dataType.equals(INT64) : INT64 != null;
    }

    /** Predicate for {@code datasetPaddedBatch}: {@code true} when the rank is neither -1 nor 1. */
    public static final /* synthetic */ boolean $anonfun$datasetPaddedBatch$2(Output output) {
        return (output.rank() == -1 || output.rank() == 1) ? false : true;
    }

    /** Predicate for {@code datasetPaddedBatch}: {@code true} when the rank is neither -1 nor 0. */
    public static final /* synthetic */ boolean $anonfun$datasetPaddedBatch$3(Output output) {
        return (output.rank() == -1 || output.rank() == 0) ? false : true;
    }

    /** Private constructor; publishes the singleton through {@link #MODULE$}. */
    private Dataset$() {
        MODULE$ = this;
    }
}
