/*
 * Decompiled with CFR 0.152.
 */
package org.platanios.tensorflow.data.image;

import com.typesafe.scalalogging.Logger;
import com.typesafe.scalalogging.Logger$;
import java.io.ByteArrayOutputStream;
import java.io.InputStream;
import java.io.Serializable;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.util.zip.GZIPInputStream;
import org.apache.commons.compress.archivers.tar.TarArchiveEntry;
import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
import org.platanios.tensorflow.api.core.Indexer;
import org.platanios.tensorflow.api.core.IndexerConstructionWithTwoNumbers$;
import org.platanios.tensorflow.api.core.Shape;
import org.platanios.tensorflow.api.package;
import org.platanios.tensorflow.api.package$;
import org.platanios.tensorflow.api.tensors.Tensor;
import org.platanios.tensorflow.api.tensors.TensorConvertible$;
import org.platanios.tensorflow.api.types.DataType;
import org.platanios.tensorflow.api.types.SupportedType$;
import org.platanios.tensorflow.data.Loader;
import org.platanios.tensorflow.data.image.CIFARDataset;
import org.platanios.tensorflow.data.image.CIFARLoader;
import org.platanios.tensorflow.data.image.CIFARLoader$CIFAR_10$;
import org.platanios.tensorflow.data.image.CIFARLoader$CIFAR_100$;
import org.slf4j.LoggerFactory;
import scala.Function0;
import scala.Function1;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.ObjectRef;
import scala.runtime.java8.JFunction0;
import scala.runtime.java8.JFunction1;

/**
 * Module class of the Scala object {@code CIFARLoader}: downloads a CIFAR archive (CIFAR-10 or
 * CIFAR-100, selected via {@link CIFARLoader.DatasetType}) and converts its binary entries into a
 * {@link CIFARDataset} holding train/test image and label tensors.
 */
public final class CIFARLoader$
implements Loader {
    /** The singleton instance; assigned by the private constructor from the static initializer. */
    public static CIFARLoader$ MODULE$;
    private final Logger logger;

    static {
        new CIFARLoader$();
    }

    @Override
    public boolean maybeDownload(Path path, String url, int bufferSize) {
        return Loader.maybeDownload$(this, path, url, bufferSize);
    }

    @Override
    public int maybeDownload$default$3() {
        return Loader.maybeDownload$default$3$(this);
    }

    @Override
    public Logger logger() {
        return this.logger;
    }

    /**
     * Fetches (through {@link Loader#maybeDownload}) and loads a CIFAR dataset.
     *
     * <p>The archive is stored at {@code path.resolve(datasetType.compressedFilename())} and is
     * downloaded from {@code datasetType.url() + datasetType.compressedFilename()}.
     *
     * @param path        directory that holds (or will hold) the compressed archive
     * @param datasetType CIFAR variant to load
     * @param bufferSize  size in bytes of the I/O buffer; must be positive
     * @return the dataset built from the archive's train and test entries
     * @throws IllegalArgumentException if {@code bufferSize} is not positive
     */
    public CIFARDataset load(Path path, CIFARLoader.DatasetType datasetType, int bufferSize) {
        if (bufferSize <= 0) {
            // A zero-length buffer makes InputStream.read(byte[]) return 0 forever, so the entry
            // read loop would never terminate; a negative size fails only after downloading.
            throw new IllegalArgumentException("bufferSize must be positive, but was " + bufferSize + ".");
        }
        String url = datasetType.url();
        String compressedFilename = datasetType.compressedFilename();
        Path compressedPath = path.resolve(compressedFilename);
        this.maybeDownload(compressedPath, url + compressedFilename, bufferSize);
        CIFARDataset dataset = this.extractFiles(compressedPath, datasetType, bufferSize);
        if (this.logger().underlying().isInfoEnabled()) {
            this.logger().underlying().info("Finished loading the CIFAR dataset.");
        }
        return dataset;
    }

    /** Default for {@code load}'s {@code datasetType} parameter: CIFAR-10. */
    public CIFARLoader.DatasetType load$default$2() {
        return CIFARLoader$CIFAR_10$.MODULE$;
    }

    /** Default for {@code load}'s {@code bufferSize} parameter: 8192 bytes. */
    public int load$default$3() {
        return 8192;
    }

    /**
     * Reads a gzip-compressed tar archive and builds a dataset from its entries.
     *
     * <p>Entries whose name ends with any of {@code datasetType.trainFilenames()} are concatenated
     * (along axis 0) into the train images/labels. An entry whose name ends with
     * {@code datasetType.testFilename()} replaces the test images/labels. Other entries are skipped.
     *
     * @param path        location of the {@code .tar.gz} archive
     * @param datasetType CIFAR variant, used for entry matching and record decoding
     * @param bufferSize  size in bytes of the buffer used to read each entry
     * @return the populated dataset; fields for which no matching entry exists stay {@code null}
     */
    private CIFARDataset extractFiles(Path path, CIFARLoader.DatasetType datasetType, int bufferSize) {
        if (this.logger().underlying().isInfoEnabled()) {
            this.logger().underlying().info("Extracting data from file '{}'.", new Object[]{path});
        }
        CIFARDataset dataset = new CIFARDataset(datasetType, null, null, null, null);
        // Each stream in the chain is its own resource so that the file handle is released even if
        // the gzip header is rejected or reading an entry throws part-way through the archive.
        try (InputStream fileStream = Files.newInputStream(path, new OpenOption[0]);
             GZIPInputStream gzipStream = new GZIPInputStream(fileStream);
             TarArchiveInputStream inputStream = new TarArchiveInputStream((InputStream)gzipStream)) {
            // Boxed in an ObjectRef because the name-matching lambda below reads the current entry.
            ObjectRef entry = ObjectRef.create((Object)inputStream.getNextTarEntry());
            while ((TarArchiveEntry)entry.elem != null) {
                TarArchiveEntry current = (TarArchiveEntry)entry.elem;
                if (datasetType.trainFilenames().exists((Function1 & Serializable & scala.Serializable)x$1 -> BoxesRunTime.boxToBoolean((boolean)CIFARLoader$.$anonfun$extractFiles$1(entry, x$1)))) {
                    Tuple2<Tensor, Tensor> imagesAndLabels = this.readImagesAndLabels(inputStream, current, datasetType, bufferSize);
                    Tensor images = imagesAndLabels._1();
                    Tensor labels = imagesAndLabels._2();
                    // Training data may be split across several entries: append along axis 0.
                    Tensor trainImages = dataset.trainImages() == null ? images : package.tfi$.MODULE$.concatenate((Seq)Seq$.MODULE$.apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Tensor[]{dataset.trainImages(), images})), package$.MODULE$.tensorConvertibleToTensor((Object)BoxesRunTime.boxToInteger((int)0), TensorConvertible$.MODULE$.supportedTypeTensorConvertible(SupportedType$.MODULE$.intIsSupportedType())));
                    Tensor trainLabels = dataset.trainLabels() == null ? labels : package.tfi$.MODULE$.concatenate((Seq)Seq$.MODULE$.apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Tensor[]{dataset.trainLabels(), labels})), package$.MODULE$.tensorConvertibleToTensor((Object)BoxesRunTime.boxToInteger((int)0), TensorConvertible$.MODULE$.supportedTypeTensorConvertible(SupportedType$.MODULE$.intIsSupportedType())));
                    dataset = dataset.copy(dataset.copy$default$1(), trainImages, trainLabels, dataset.copy$default$4(), dataset.copy$default$5());
                } else if (current.getName().endsWith(datasetType.testFilename())) {
                    Tuple2<Tensor, Tensor> imagesAndLabels = this.readImagesAndLabels(inputStream, current, datasetType, bufferSize);
                    Tensor images = imagesAndLabels._1();
                    Tensor labels = imagesAndLabels._2();
                    dataset = dataset.copy(dataset.copy$default$1(), dataset.copy$default$2(), dataset.copy$default$3(), images, labels);
                }
                entry.elem = inputStream.getNextTarEntry();
            }
        }
        return dataset;
    }

    /** Default for {@code extractFiles}'s {@code datasetType} parameter: CIFAR-10. */
    private CIFARLoader.DatasetType extractFiles$default$2() {
        return CIFARLoader$CIFAR_10$.MODULE$;
    }

    /** Default for {@code extractFiles}'s {@code bufferSize} parameter: 8192 bytes. */
    private int extractFiles$default$3() {
        return 8192;
    }

    /**
     * Decodes the current tar entry into an (images, labels) pair.
     *
     * <p>The entry is buffered in memory and viewed as a UINT8 matrix of shape
     * {@code [numSamples, datasetType.entryByteSize()]}, where
     * {@code numSamples = entry.getSize() / entryByteSize}.
     * <ul>
     *   <li>CIFAR-10: column 0 is the label, columns {@code 1 ::} are the image bytes.</li>
     *   <li>CIFAR-100: columns {@code 0 :: 2} are the labels, columns {@code 2 ::} are the image
     *       bytes.</li>
     * </ul>
     * Image bytes are reshaped to {@code [-1, 32, 32, 3]}.
     *
     * @param inputStream archive stream positioned at the start of {@code entry}'s data
     * @param entry       the entry being read; its size determines the tensor shape
     * @param datasetType CIFAR variant that selects the record layout
     * @param bufferSize  size in bytes of the read buffer (positive)
     * @return a tuple of (images, labels)
     * @throws MatchError if {@code datasetType} is neither CIFAR-10 nor CIFAR-100
     */
    private Tuple2<Tensor, Tensor> readImagesAndLabels(TarArchiveInputStream inputStream, TarArchiveEntry entry, CIFARLoader.DatasetType datasetType, int bufferSize) {
        Tuple2 tuple2;
        // Drain the current entry: the tar stream reports -1 once this entry's data is exhausted.
        ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
        byte[] buffer = new byte[bufferSize];
        int bytesRead;
        while ((bytesRead = inputStream.read(buffer)) != -1) {
            outputStream.write(buffer, 0, bytesRead);
        }
        ByteBuffer byteBuffer = ByteBuffer.wrap(outputStream.toByteArray()).order(ByteOrder.BIG_ENDIAN);
        int entrySize = (int)entry.getSize();
        int numSamples = entrySize / datasetType.entryByteSize();
        Shape combinedShape = package$.MODULE$.Shape().apply((Seq)Predef$.MODULE$.wrapIntArray(new int[]{numSamples, datasetType.entryByteSize()}));
        Tensor combined = package$.MODULE$.Tensor().fromBuffer((DataType)package$.MODULE$.UINT8(), combinedShape, (long)entrySize, byteBuffer);
        // NOTE(review): the CIFAR binary format is documented as channel-major per image (1024 red,
        // then 1024 green, then 1024 blue bytes). Reshaping straight to [-1, 32, 32, 3] would then
        // interleave channels incorrectly; confirm, and consider reshaping to [-1, 3, 32, 32]
        // followed by a transpose to NHWC.
        if (CIFARLoader$CIFAR_10$.MODULE$.equals(datasetType)) {
            tuple2 = new Tuple2((Object)package$.MODULE$.tensorToBasicOps(combined.apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Indexer[]{package$.MODULE$.$colon$colon(), package$.MODULE$.intToIndexerConstruction(1).$colon$colon()}))).reshape((Object)package$.MODULE$.Shape().apply((Seq)Predef$.MODULE$.wrapIntArray(new int[]{-1, 32, 32, 3})), TensorConvertible$.MODULE$.shapeTensorConvertible()), (Object)combined.apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Indexer[]{package$.MODULE$.$colon$colon(), package$.MODULE$.intToIndex(0)})));
        } else if (CIFARLoader$CIFAR_100$.MODULE$.equals(datasetType)) {
            // Label columns: all rows, slice `0 :: 2`.
            Indexer[] labelIndexers = new Indexer[]{package$.MODULE$.$colon$colon(), IndexerConstructionWithTwoNumbers$.MODULE$.indexerConstructionToIndex(package$.MODULE$.intToIndexerConstruction(2).$colon$colon(package$.MODULE$.intToIndexerConstruction(0)))};
            tuple2 = new Tuple2((Object)package$.MODULE$.tensorToBasicOps(combined.apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Indexer[]{package$.MODULE$.$colon$colon(), package$.MODULE$.intToIndexerConstruction(2).$colon$colon()}))).reshape((Object)package$.MODULE$.Shape().apply((Seq)Predef$.MODULE$.wrapIntArray(new int[]{-1, 32, 32, 3})), TensorConvertible$.MODULE$.shapeTensorConvertible()), (Object)combined.apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])labelIndexers)));
        } else {
            throw new MatchError((Object)datasetType);
        }
        return tuple2;
    }

    /** Default for {@code readImagesAndLabels}'s {@code datasetType} parameter: CIFAR-10. */
    private CIFARLoader.DatasetType readImagesAndLabels$default$3() {
        return CIFARLoader$CIFAR_10$.MODULE$;
    }

    /** Default for {@code readImagesAndLabels}'s {@code bufferSize} parameter: 8192 bytes. */
    private int readImagesAndLabels$default$4() {
        return 8192;
    }

    /** Lambda body used by {@code extractFiles}: does the current entry's name end with {@code x$1}? */
    public static final /* synthetic */ boolean $anonfun$extractFiles$1(ObjectRef entry$1, String x$1) {
        return ((TarArchiveEntry)entry$1.elem).getName().endsWith(x$1);
    }

    private CIFARLoader$() {
        MODULE$ = this;
        Loader.$init$(this);
        this.logger = Logger$.MODULE$.apply(LoggerFactory.getLogger((String)"CIFAR Data Loader"));
    }
}

