package org.datavec.image.loader;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteOrder;
import org.apache.commons.io.IOUtils;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.indexer.DoubleIndexer;
import org.bytedeco.javacpp.indexer.FloatIndexer;
import org.bytedeco.javacpp.indexer.Indexer;
import org.bytedeco.javacpp.indexer.IntIndexer;
import org.bytedeco.javacpp.indexer.UByteIndexer;
import org.bytedeco.javacpp.indexer.UShortIndexer;
import org.bytedeco.javacpp.lept;
import org.bytedeco.javacpp.opencv_core;
import org.bytedeco.javacpp.opencv_imgcodecs;
import org.bytedeco.javacpp.opencv_imgproc;
import org.bytedeco.javacv.OpenCVFrameConverter;
import org.datavec.image.data.ImageWritable;
import org.datavec.image.transform.ImageTransform;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;

/* loaded from: input_file:org/datavec/image/loader/NativeImageLoader.class */
/**
 * Image loader that decodes images with OpenCV, falling back to Leptonica when OpenCV
 * cannot decode the input, and copies the pixels into an {@link INDArray}.
 *
 * <p>The arrays returned by the {@code asMatrix} methods have the shape
 * {@code [1, channels, rows, cols]}.
 */
public class NativeImageLoader extends BaseImageLoader {
    /** File extensions accepted by this loader, listed in lower case and in upper case. */
    public static final String[] ALLOWED_FORMATS = {"bmp", "gif", "jpg", "jpeg", "jp2", "pbm", "pgm", "ppm", "pnm", "png", "tif", "tiff", "exr", "webp", "BMP", "GIF", "JPG", "JPEG", "JP2", "PBM", "PGM", "PPM", "PNM", "PNG", "TIF", "TIFF", "EXR", "WEBP"};

    /** OpenCV read flag: "any depth" (2) combined with "any color" (4). */
    private static final int READ_ANY_DEPTH_ANY_COLOR = 6;
    /** Leptonica colormap removal mode: convert to full color. */
    private static final int REMOVE_COLORMAP_TO_FULL_COLOR = 2;
    /** OpenCV color conversion code: BGR to RGBA. */
    private static final int BGR_TO_RGBA = 2;
    /** OpenCV color conversion code: RGBA to BGR. */
    private static final int RGBA_TO_BGR = 3;
    /** OpenCV color conversion code: BGR to gray. */
    private static final int BGR_TO_GRAY = 6;
    /** OpenCV color conversion code: gray to BGR. */
    private static final int GRAY_TO_BGR = 8;
    /** OpenCV color conversion code: gray to RGBA. */
    private static final int GRAY_TO_RGBA = 9;
    /** OpenCV color conversion code: RGBA to gray. */
    private static final int RGBA_TO_GRAY = 11;
    /** Returned by {@link #colorConversionCode(int, int)} when no conversion is known. */
    private static final int NO_CONVERSION = -1;
    /** Largest number of channel pairs handed to {@code mixChannels} in {@link #convert(lept.PIX)}. */
    private static final int MAX_CHANNEL_PAIRS = 4;

    /** Converts between frames and {@code Mat}; only created when an image transform is configured. */
    OpenCVFrameConverter.ToMat converter;

    /** Creates a loader that keeps the defaults inherited from {@link BaseImageLoader}. */
    public NativeImageLoader() {
        this.converter = null;
    }

    /**
     * Creates a loader producing images of the given size.
     *
     * @param height target number of rows
     * @param width target number of columns
     */
    public NativeImageLoader(int height, int width) {
        this.converter = null;
        this.height = height;
        this.width = width;
    }

    /**
     * Creates a loader producing images of the given size and channel count.
     *
     * @param height target number of rows
     * @param width target number of columns
     * @param channels target number of channels
     */
    public NativeImageLoader(int height, int width, int channels) {
        this.converter = null;
        this.height = height;
        this.width = width;
        this.channels = channels;
    }

    /**
     * Creates a loader with the given size and channel count that optionally crops before scaling.
     *
     * @param height target number of rows
     * @param width target number of columns
     * @param channels target number of channels
     * @param centerCropIfNeeded whether {@link #centerCropIfNeeded(opencv_core.Mat)} is applied
     */
    public NativeImageLoader(int height, int width, int channels, boolean centerCropIfNeeded) {
        this(height, width, channels);
        this.centerCropIfNeeded = centerCropIfNeeded;
    }

    /**
     * Creates a loader that runs an image transform before any other processing.
     *
     * @param height target number of rows
     * @param width target number of columns
     * @param channels target number of channels
     * @param imageTransform transform applied to every image that is loaded
     */
    public NativeImageLoader(int height, int width, int channels, ImageTransform imageTransform) {
        this(height, width, channels);
        this.imageTransform = imageTransform;
        this.converter = new OpenCVFrameConverter.ToMat();
    }

    /**
     * Creates a loader that runs an image transform and divides pixel values by a constant.
     *
     * @param height target number of rows
     * @param width target number of columns
     * @param channels target number of channels
     * @param imageTransform transform applied to every image that is loaded
     * @param normalizeValue divisor for the pixel values; normalization is enabled only when positive
     */
    public NativeImageLoader(int height, int width, int channels, ImageTransform imageTransform, double normalizeValue) {
        this(height, width, channels, imageTransform);
        this.normalizeIfNeeded = normalizeValue > 0.0d;
        this.normalizeValue = normalizeValue;
    }

    /** Returns {@link #ALLOWED_FORMATS}. */
    @Override // org.datavec.image.loader.BaseImageLoader
    public String[] getAllowedFormats() {
        return ALLOWED_FORMATS;
    }

    /**
     * Loads an image file and flattens the result of {@link #asMatrix(File)}.
     *
     * @param file image file to read
     * @return the raveled image data
     * @throws IOException if the image cannot be read or converted
     */
    @Override // org.datavec.image.loader.BaseImageLoader
    public INDArray asRowVector(File file) throws IOException {
        return asMatrix(file).ravel();
    }

    /**
     * Decodes an image from a stream and flattens the result of {@link #asMatrix(InputStream)}.
     *
     * @param inputStream stream holding the encoded image
     * @return the raveled image data
     * @throws IOException if the image cannot be decoded or converted
     */
    @Override // org.datavec.image.loader.BaseImageLoader
    public INDArray asRowVector(InputStream inputStream) throws IOException {
        return asMatrix(inputStream).ravel();
    }

    /**
     * Flattens the result of {@link #asMatrix(opencv_core.Mat)}.
     *
     * @param mat decoded image
     * @return the raveled image data
     * @throws IOException if the channel count cannot be converted
     */
    public INDArray asRowVector(opencv_core.Mat mat) throws IOException {
        return asMatrix(mat).ravel();
    }

    /**
     * Copies a Leptonica image into a newly allocated OpenCV {@code Mat}.
     *
     * <p>Colormapped images are expanded to full color and images with fewer than 8 bits per
     * pixel are expanded to 8 bits first. Any temporary image created for that purpose is
     * destroyed before returning; the caller keeps ownership of {@code pix}.
     *
     * @param pix source image
     * @return a {@code Mat} with {@code depth / 8} 8-bit channels holding a copy of the pixels
     * @throws IllegalArgumentException if the image has a depth below 8 other than 1, 2 or 4
     */
    static opencv_core.Mat convert(lept.PIX pix) {
        lept.PIX temporary = null;
        if (pix.colormap() != null) {
            temporary = lept.pixRemoveColormap(pix, REMOVE_COLORMAP_TO_FULL_COLOR);
            pix = temporary;
        } else if (pix.d() < 8) {
            switch (pix.d()) {
                case 1:
                    temporary = lept.pixConvert1To8((lept.PIX) null, pix, (byte) 0, (byte) 255);
                    break;
                case 2:
                    temporary = lept.pixConvert2To8(pix, (byte) 0, (byte) 85, (byte) 170, (byte) 255, 0);
                    break;
                case 4:
                    temporary = lept.pixConvert4To8(pix, 0);
                    break;
                default:
                    // Fail explicitly instead of dereferencing a null image further down.
                    throw new IllegalArgumentException("Unsupported PIX depth: " + pix.d());
            }
            pix = temporary;
        }
        try {
            int rows = pix.h();
            int cols = pix.w();
            int channelCount = pix.d() / 8;
            // View over the Leptonica buffer (no copy); rows are wpl 32-bit words apart.
            opencv_core.Mat view = new opencv_core.Mat(rows, cols, opencv_core.CV_8UC(channelCount), pix.data(), 4 * pix.wpl());
            opencv_core.Mat copy = new opencv_core.Mat(rows, cols, opencv_core.CV_8UC(channelCount));
            // Multi-channel images get their channel order reversed on little-endian hosts;
            // everything else is copied channel by channel.
            boolean reverse = channelCount > 1 && ByteOrder.nativeOrder().equals(ByteOrder.LITTLE_ENDIAN);
            int[] fromTo = channelMapping(channelCount, reverse);
            opencv_core.mixChannels(view, 1, copy, 1, fromTo, fromTo.length / 2);
            return copy;
        } finally {
            if (temporary != null) {
                lept.pixDestroy(temporary);
            }
        }
    }

    /**
     * Builds the {@code fromTo} index pairs for {@code mixChannels}, never naming a channel
     * the image does not have.
     */
    private static int[] channelMapping(int channelCount, boolean reverse) {
        int pairs = Math.min(channelCount, MAX_CHANNEL_PAIRS);
        int[] fromTo = new int[2 * pairs];
        for (int c = 0; c < pairs; c++) {
            fromTo[2 * c] = c;
            fromTo[2 * c + 1] = reverse ? channelCount - 1 - c : c;
        }
        return fromTo;
    }

    /**
     * Reads an image file, trying OpenCV first and Leptonica second.
     *
     * @param file image file to read
     * @return image data with shape {@code [1, channels, rows, cols]}
     * @throws IOException if neither library can read the file, or the channel count cannot be converted
     */
    @Override // org.datavec.image.loader.BaseImageLoader
    public INDArray asMatrix(File file) throws IOException {
        String path = file.getAbsolutePath();
        opencv_core.Mat image = opencv_imgcodecs.imread(path, READ_ANY_DEPTH_ANY_COLOR);
        if (image == null || image.empty()) {
            lept.PIX pix = lept.pixRead(path);
            if (pix == null) {
                throw new IOException("Could not read image from file: " + file);
            }
            try {
                image = convert(pix);
            } finally {
                // Release the Leptonica image even when the conversion fails.
                lept.pixDestroy(pix);
            }
        }
        return asMatrix(image);
    }

    /**
     * Reads the whole stream into memory and decodes it, trying OpenCV first and Leptonica second.
     * The stream is not closed by this method.
     *
     * @param inputStream stream holding the encoded image
     * @return image data with shape {@code [1, channels, rows, cols]}
     * @throws IOException if reading fails, neither library can decode the data, or the channel
     *         count cannot be converted
     */
    @Override // org.datavec.image.loader.BaseImageLoader
    public INDArray asMatrix(InputStream inputStream) throws IOException {
        byte[] encoded = IOUtils.toByteArray(inputStream);
        opencv_core.Mat image = opencv_imgcodecs.imdecode(new opencv_core.Mat(encoded), READ_ANY_DEPTH_ANY_COLOR);
        if (image == null || image.empty()) {
            lept.PIX pix = lept.pixReadMem(encoded, encoded.length);
            if (pix == null) {
                throw new IOException("Could not decode image from input stream");
            }
            try {
                image = convert(pix);
            } finally {
                // Release the Leptonica image even when the conversion fails.
                lept.pixDestroy(pix);
            }
        }
        return asMatrix(image);
    }

    /**
     * Converts a decoded image into an {@link INDArray}.
     *
     * <p>Processing order: optional image transform, channel conversion to the configured
     * channel count, optional center crop, scaling to the configured size, pixel copy into a
     * {@code [channels, rows, cols]} array, optional normalization, and finally a reshape that
     * prepends a dimension of size 1.
     *
     * @param mat decoded image
     * @return image data with shape {@code [1, channels, rows, cols]}
     * @throws IOException if no conversion exists between the image's and the configured channel count
     */
    public INDArray asMatrix(opencv_core.Mat mat) throws IOException {
        if (this.imageTransform != null && this.converter != null) {
            ImageWritable transformed = this.imageTransform.transform(new ImageWritable(this.converter.convert(mat)));
            mat = this.converter.convert(transformed.getFrame());
        }
        if (this.channels > 0 && mat.channels() != this.channels) {
            int code = colorConversionCode(mat.channels(), this.channels);
            if (code < 0) {
                throw new IOException("Cannot convert from " + mat.channels() + " to " + this.channels + " channels.");
            }
            opencv_core.Mat recolored = new opencv_core.Mat();
            opencv_imgproc.cvtColor(mat, recolored, code);
            mat = recolored;
        }
        if (this.centerCropIfNeeded) {
            mat = centerCropIfNeeded(mat);
        }
        opencv_core.Mat image = scalingIfNeed(mat);

        int rows = image.rows();
        int cols = image.cols();
        int channelCount = image.channels();
        Indexer pixels = image.createIndexer();
        INDArray result = Nd4j.create(new int[]{channelCount, rows, cols});

        // Fast path: write straight into the array's backing buffer with typed indexers.
        Pointer buffer = result.data().pointer();
        int[] stride = result.stride();
        long[] sizes = {channelCount, rows, cols};
        long[] strides = {stride[0], stride[1], stride[2]};
        boolean copied = false;
        if (buffer instanceof FloatPointer) {
            copied = copyToFloat(pixels, FloatIndexer.create((FloatPointer) buffer, sizes, strides), channelCount, rows, cols);
        } else if (buffer instanceof DoublePointer) {
            copied = copyToDouble(pixels, DoubleIndexer.create((DoublePointer) buffer, sizes, strides), channelCount, rows, cols);
        }
        if (!copied) {
            // Slow path for any other buffer or pixel type. The array is always rank 3, so
            // three indices are used regardless of the channel count.
            for (int c = 0; c < channelCount; c++) {
                for (int y = 0; y < rows; y++) {
                    for (int x = 0; x < cols; x++) {
                        result.putScalar(c, y, x, pixels.getDouble(new long[]{y, x, c}));
                    }
                }
            }
        }
        // The return value is unused. NOTE(review): presumably this call only keeps the Mat
        // reachable until the copy above has finished - confirm before removing.
        image.data();

        if (this.normalizeIfNeeded) {
            result = normalizeIfNeeded(result);
        }
        return result.reshape(ArrayUtil.combine(new int[]{1}, result.shape()));
    }

    /**
     * Looks up the OpenCV color conversion code between two channel counts.
     *
     * @return the conversion code, or {@link #NO_CONVERSION} when the combination is unsupported
     */
    private static int colorConversionCode(int sourceChannels, int targetChannels) {
        switch (sourceChannels) {
            case 1:
                switch (targetChannels) {
                    case 3:
                        return GRAY_TO_BGR;
                    case 4:
                        return GRAY_TO_RGBA;
                    default:
                        return NO_CONVERSION;
                }
            case 3:
                switch (targetChannels) {
                    case 1:
                        return BGR_TO_GRAY;
                    case 4:
                        return BGR_TO_RGBA;
                    default:
                        return NO_CONVERSION;
                }
            case 4:
                switch (targetChannels) {
                    case 1:
                        return RGBA_TO_GRAY;
                    case 3:
                        return RGBA_TO_BGR;
                    default:
                        return NO_CONVERSION;
                }
            default:
                return NO_CONVERSION;
        }
    }

    /**
     * Copies pixels from a {@code (row, col, channel)} indexer into a
     * {@code (channel, row, col)} float indexer.
     *
     * @return {@code false} if the source indexer type is not handled, in which case nothing is copied
     */
    private static boolean copyToFloat(Indexer source, FloatIndexer target, int channelCount, int rows, int cols) {
        if (source instanceof UByteIndexer) {
            UByteIndexer typed = (UByteIndexer) source;
            for (int c = 0; c < channelCount; c++) {
                for (int y = 0; y < rows; y++) {
                    for (int x = 0; x < cols; x++) {
                        target.put(c, y, x, typed.get(y, x, c));
                    }
                }
            }
            return true;
        }
        if (source instanceof UShortIndexer) {
            UShortIndexer typed = (UShortIndexer) source;
            for (int c = 0; c < channelCount; c++) {
                for (int y = 0; y < rows; y++) {
                    for (int x = 0; x < cols; x++) {
                        target.put(c, y, x, typed.get(y, x, c));
                    }
                }
            }
            return true;
        }
        if (source instanceof IntIndexer) {
            IntIndexer typed = (IntIndexer) source;
            for (int c = 0; c < channelCount; c++) {
                for (int y = 0; y < rows; y++) {
                    for (int x = 0; x < cols; x++) {
                        target.put(c, y, x, typed.get(y, x, c));
                    }
                }
            }
            return true;
        }
        if (source instanceof FloatIndexer) {
            FloatIndexer typed = (FloatIndexer) source;
            for (int c = 0; c < channelCount; c++) {
                for (int y = 0; y < rows; y++) {
                    for (int x = 0; x < cols; x++) {
                        target.put(c, y, x, typed.get(y, x, c));
                    }
                }
            }
            return true;
        }
        return false;
    }

    /**
     * Copies pixels from a {@code (row, col, channel)} indexer into a
     * {@code (channel, row, col)} double indexer.
     *
     * @return {@code false} if the source indexer type is not handled, in which case nothing is copied
     */
    private static boolean copyToDouble(Indexer source, DoubleIndexer target, int channelCount, int rows, int cols) {
        if (source instanceof UByteIndexer) {
            UByteIndexer typed = (UByteIndexer) source;
            for (int c = 0; c < channelCount; c++) {
                for (int y = 0; y < rows; y++) {
                    for (int x = 0; x < cols; x++) {
                        target.put(c, y, x, typed.get(y, x, c));
                    }
                }
            }
            return true;
        }
        if (source instanceof UShortIndexer) {
            UShortIndexer typed = (UShortIndexer) source;
            for (int c = 0; c < channelCount; c++) {
                for (int y = 0; y < rows; y++) {
                    for (int x = 0; x < cols; x++) {
                        target.put(c, y, x, typed.get(y, x, c));
                    }
                }
            }
            return true;
        }
        if (source instanceof IntIndexer) {
            IntIndexer typed = (IntIndexer) source;
            for (int c = 0; c < channelCount; c++) {
                for (int y = 0; y < rows; y++) {
                    for (int x = 0; x < cols; x++) {
                        target.put(c, y, x, typed.get(y, x, c));
                    }
                }
            }
            return true;
        }
        if (source instanceof FloatIndexer) {
            FloatIndexer typed = (FloatIndexer) source;
            for (int c = 0; c < channelCount; c++) {
                for (int y = 0; y < rows; y++) {
                    for (int x = 0; x < cols; x++) {
                        target.put(c, y, x, typed.get(y, x, c));
                    }
                }
            }
            return true;
        }
        return false;
    }

    /**
     * Divides every element by the configured normalization value.
     *
     * @param iNDArray image data
     * @return the result of the division
     */
    protected INDArray normalizeIfNeeded(INDArray iNDArray) {
        return iNDArray.div(this.normalizeValue);
    }

    /**
     * Crops the image to the largest centered square.
     *
     * <p>The longer dimension is trimmed equally on both sides (the extra pixel of an odd
     * difference comes off the far side). Square images are returned with their full extent.
     *
     * @param mat image to crop
     * @return the centered square region of {@code mat}
     */
    protected opencv_core.Mat centerCropIfNeeded(opencv_core.Mat mat) {
        int rows = mat.rows();
        int cols = mat.cols();
        int side = Math.min(rows, cols);
        int x = (cols - side) / 2;
        int y = (rows - side) / 2;
        return mat.apply(new opencv_core.Rect(x, y, side, side));
    }

    /**
     * Scales the image to the configured height and width.
     *
     * @param mat image to scale
     * @return the scaled image, or {@code mat} itself when no scaling is required
     */
    protected opencv_core.Mat scalingIfNeed(opencv_core.Mat mat) {
        return scalingIfNeed(mat, this.height, this.width);
    }

    /**
     * Scales the image to the given size when both dimensions are positive and differ from
     * the image's current size.
     *
     * @param mat image to scale
     * @param dstHeight target number of rows; values of zero or less disable scaling
     * @param dstWidth target number of columns; values of zero or less disable scaling
     * @return the scaled image, or {@code mat} itself when no scaling is required
     */
    protected opencv_core.Mat scalingIfNeed(opencv_core.Mat mat, int dstHeight, int dstWidth) {
        boolean sizeRequested = dstHeight > 0 && dstWidth > 0;
        if (!sizeRequested || (mat.rows() == dstHeight && mat.cols() == dstWidth)) {
            return mat;
        }
        opencv_core.Mat scaled = new opencv_core.Mat();
        opencv_imgproc.resize(mat, scaled, new opencv_core.Size(dstWidth, dstHeight));
        return scaled;
    }
}
