/*
 * Decompiled with CFR 0.152.
 */
package org.platanios.tensorflow.data.image;

import com.typesafe.scalalogging.Logger;
import com.typesafe.scalalogging.Logger$;
import java.io.ByteArrayOutputStream;
import java.io.InputStream;
import java.io.Serializable;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.util.zip.GZIPInputStream;
import org.platanios.tensorflow.api.package$;
import org.platanios.tensorflow.api.tensors.Tensor;
import org.platanios.tensorflow.api.types.DataType;
import org.platanios.tensorflow.data.Loader;
import org.platanios.tensorflow.data.image.MNISTDataset;
import org.platanios.tensorflow.data.image.MNISTLoader;
import org.platanios.tensorflow.data.image.MNISTLoader$MNIST$;
import org.slf4j.LoggerFactory;
import scala.Function0;
import scala.Function1;
import scala.Predef$;
import scala.collection.Seq;
import scala.runtime.BoxedUnit;
import scala.runtime.java8.JFunction0;
import scala.runtime.java8.JFunction1;

public final class MNISTLoader$
implements Loader {
    public static MNISTLoader$ MODULE$;
    private final Logger logger;

    static {
        new MNISTLoader$();
    }

    @Override
    public boolean maybeDownload(Path path, String url, int bufferSize) {
        return Loader.maybeDownload$(this, path, url, bufferSize);
    }

    @Override
    public int maybeDownload$default$3() {
        return Loader.maybeDownload$default$3$(this);
    }

    @Override
    public Logger logger() {
        return this.logger;
    }

    public MNISTDataset load(Path path, MNISTLoader.DatasetType datasetType, int bufferSize) {
        Path trainImagesPath = path.resolve(datasetType.trainImagesFilename());
        Path trainLabelsPath = path.resolve(datasetType.trainLabelsFilename());
        Path testImagesPath = path.resolve(datasetType.testImagesFilename());
        Path testLabelsPath = path.resolve(datasetType.testLabelsFilename());
        this.maybeDownload(trainImagesPath, new StringBuilder(0).append(datasetType.url()).append(datasetType.trainImagesFilename()).toString(), bufferSize);
        this.maybeDownload(trainLabelsPath, new StringBuilder(0).append(datasetType.url()).append(datasetType.trainLabelsFilename()).toString(), bufferSize);
        this.maybeDownload(testImagesPath, new StringBuilder(0).append(datasetType.url()).append(datasetType.testImagesFilename()).toString(), bufferSize);
        this.maybeDownload(testLabelsPath, new StringBuilder(0).append(datasetType.url()).append(datasetType.testLabelsFilename()).toString(), bufferSize);
        Tensor trainImages = this.extractImages(trainImagesPath, bufferSize);
        Tensor trainLabels = this.extractLabels(trainLabelsPath, bufferSize);
        Tensor testImages = this.extractImages(testImagesPath, bufferSize);
        Tensor testLabels = this.extractLabels(testLabelsPath, bufferSize);
        return new MNISTDataset(datasetType, trainImages, trainLabels, testImages, testLabels);
    }

    public MNISTLoader.DatasetType load$default$2() {
        return MNISTLoader$MNIST$.MODULE$;
    }

    public int load$default$3() {
        return 8192;
    }

    private Tensor extractImages(Path path, int bufferSize) {
        BoxedUnit boxedUnit;
        if (this.logger().underlying().isInfoEnabled()) {
            this.logger().underlying().info("Extracting images from file '{}'.", new Object[]{path});
            boxedUnit = BoxedUnit.UNIT;
        } else {
            boxedUnit = BoxedUnit.UNIT;
        }
        GZIPInputStream inputStream = new GZIPInputStream(Files.newInputStream(path, new OpenOption[0]));
        ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
        byte[] buffer = new byte[bufferSize];
        scala.package$.MODULE$.Stream().continually((Function0)(JFunction0.mcI.sp & Serializable & scala.Serializable)() -> inputStream.read(buffer)).takeWhile((Function1)(JFunction1.mcZI.sp & Serializable & scala.Serializable)x$1 -> x$1 != -1).foreach((Function1)(JFunction1.mcVI.sp & Serializable & scala.Serializable)x$2 -> outputStream.write(buffer, 0, x$2));
        ByteBuffer byteBuffer = ByteBuffer.wrap(outputStream.toByteArray()).order(ByteOrder.BIG_ENDIAN);
        outputStream.close();
        inputStream.close();
        int magicNumber = (int)((long)byteBuffer.getInt() & 0xFFFFFFFFL);
        if (magicNumber != 2051) {
            throw new IllegalStateException(new StringBuilder(47).append("Invalid magic number '").append(magicNumber).append("' in MNIST image file '").append(path).append("'.").toString());
        }
        int numberOfImages = (int)((long)byteBuffer.getInt() & 0xFFFFFFFFL);
        int numberOfRows = (int)((long)byteBuffer.getInt() & 0xFFFFFFFFL);
        int numberOfColumns = (int)((long)byteBuffer.getInt() & 0xFFFFFFFFL);
        int numBytes = byteBuffer.limit() - 16;
        Tensor tensor = package$.MODULE$.Tensor().fromBuffer((DataType)package$.MODULE$.UINT8(), package$.MODULE$.Shape().apply((Seq)Predef$.MODULE$.wrapIntArray(new int[]{numberOfImages, numberOfRows, numberOfColumns})), (long)numBytes, byteBuffer);
        outputStream.close();
        inputStream.close();
        return tensor;
    }

    private int extractImages$default$2() {
        return 8192;
    }

    private Tensor extractLabels(Path path, int bufferSize) {
        BoxedUnit boxedUnit;
        if (this.logger().underlying().isInfoEnabled()) {
            this.logger().underlying().info("Extracting labels from file '{}'.", new Object[]{path});
            boxedUnit = BoxedUnit.UNIT;
        } else {
            boxedUnit = BoxedUnit.UNIT;
        }
        GZIPInputStream inputStream = new GZIPInputStream(Files.newInputStream(path, new OpenOption[0]));
        ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
        byte[] buffer = new byte[bufferSize];
        scala.package$.MODULE$.Stream().continually((Function0)(JFunction0.mcI.sp & Serializable & scala.Serializable)() -> inputStream.read(buffer)).takeWhile((Function1)(JFunction1.mcZI.sp & Serializable & scala.Serializable)x$3 -> x$3 != -1).foreach((Function1)(JFunction1.mcVI.sp & Serializable & scala.Serializable)x$4 -> outputStream.write(buffer, 0, x$4));
        ByteBuffer byteBuffer = ByteBuffer.wrap(outputStream.toByteArray()).order(ByteOrder.BIG_ENDIAN);
        outputStream.close();
        inputStream.close();
        int magicNumber = (int)((long)byteBuffer.getInt() & 0xFFFFFFFFL);
        if (magicNumber != 2049) {
            throw new IllegalStateException(new StringBuilder(48).append("Invalid magic number '").append(magicNumber).append("' in MNIST labels file '").append(path).append("'.").toString());
        }
        int numberOfLabels = (int)((long)byteBuffer.getInt() & 0xFFFFFFFFL);
        int numBytes = byteBuffer.limit() - 8;
        Tensor tensor = package$.MODULE$.Tensor().fromBuffer((DataType)package$.MODULE$.UINT8(), package$.MODULE$.Shape().apply((Seq)Predef$.MODULE$.wrapIntArray(new int[]{numberOfLabels})), (long)numBytes, byteBuffer);
        outputStream.close();
        inputStream.close();
        return tensor;
    }

    private int extractLabels$default$2() {
        return 8192;
    }

    private MNISTLoader$() {
        MODULE$ = this;
        Loader.$init$(this);
        this.logger = Logger$.MODULE$.apply(LoggerFactory.getLogger((String)"MNIST Data Loader"));
    }
}

