package org.canova.image.loader;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import org.apache.commons.io.IOUtils;
import org.bytedeco.javacpp.indexer.UByteIndexer;
import org.bytedeco.javacpp.opencv_core;
import org.bytedeco.javacpp.opencv_imgcodecs;
import org.bytedeco.javacpp.opencv_imgproc;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/canova/image/loader/NativeImageLoader.class */
/**
 * Image loader that decodes images with OpenCV (through the JavaCPP bindings) and copies the
 * pixels into ND4J arrays.
 *
 * <p>Processing pipeline for every image:
 * <ol>
 *   <li>decode;</li>
 *   <li>convert to {@code channels} channels when {@code channels > 0} and the image differs;</li>
 *   <li>center-crop to a square when {@code centerCropIfNeeded} is set;</li>
 *   <li>resize to {@code height x width} when both are positive;</li>
 *   <li>copy into an {@link INDArray}.</li>
 * </ol>
 *
 * <p>Multi-channel images become a {@code [channels, rows, cols]} array and single-channel
 * images a {@code [rows, cols]} matrix. Values are the raw pixel intensities read through an
 * unsigned-byte indexer; no normalization is applied.
 *
 * <p>OpenCV decodes color images in BGR / BGRA order. Every channel conversion performed here
 * keeps that order, so converted images and natively decoded images agree.
 */
public class NativeImageLoader extends BaseImageLoader {

    /** OpenCV {@code IMREAD_UNCHANGED}: keep the image's own channel count. */
    private static final int IMREAD_UNCHANGED = -1;
    /** OpenCV {@code IMREAD_GRAYSCALE}: decode as a single channel. */
    private static final int IMREAD_GRAYSCALE = 0;
    /** OpenCV {@code IMREAD_COLOR}: decode as three BGR channels. */
    private static final int IMREAD_COLOR = 1;

    // Values of OpenCV's cv::ColorConversionCodes passed to cvtColor. They are declared
    // locally, with OpenCV's names, so the conversion table in colorConversionCode() is
    // readable. They must match the OpenCV enum values.
    private static final int COLOR_BGR2BGRA = 0;
    private static final int COLOR_BGRA2BGR = 1;
    private static final int COLOR_BGR2GRAY = 6;
    private static final int COLOR_GRAY2BGR = 8;
    private static final int COLOR_GRAY2BGRA = 9;
    private static final int COLOR_BGRA2GRAY = 10;

    /** Returned by {@link #colorConversionCode(int, int)} when no conversion is supported. */
    private static final int NO_CONVERSION = -1;

    /** Creates a loader that keeps the base-class defaults for size and channel count. */
    public NativeImageLoader() {
    }

    /**
     * Creates a loader that resizes images to the given size.
     *
     * @param height target number of rows
     * @param width  target number of columns
     */
    public NativeImageLoader(int height, int width) {
        this.height = height;
        this.width = width;
    }

    /**
     * Creates a loader that resizes images and converts them to a fixed channel count.
     *
     * @param height   target number of rows
     * @param width    target number of columns
     * @param channels target channel count (1, 3 or 4 are convertible)
     */
    public NativeImageLoader(int height, int width, int channels) {
        this(height, width);
        this.channels = channels;
    }

    /**
     * Creates a loader that resizes, converts channels and optionally center-crops.
     *
     * @param height             target number of rows
     * @param width              target number of columns
     * @param channels           target channel count (1, 3 or 4 are convertible)
     * @param centerCropIfNeeded whether to crop non-square images to their centered square
     *                           before resizing
     */
    public NativeImageLoader(int height, int width, int channels, boolean centerCropIfNeeded) {
        this(height, width, channels);
        this.centerCropIfNeeded = centerCropIfNeeded;
    }

    /**
     * Loads an image file and flattens it with {@code ravel()}.
     *
     * @throws IOException if the file cannot be decoded or its channels cannot be converted
     */
    @Override
    public INDArray asRowVector(File file) throws IOException {
        return asMatrix(file).ravel();
    }

    /**
     * Decodes an image from a stream and flattens it with {@code ravel()}.
     *
     * @throws IOException if the stream cannot be read or decoded, or its channels cannot be
     *                     converted
     */
    @Override
    public INDArray asRowVector(InputStream inputStream) throws IOException {
        return asMatrix(inputStream).ravel();
    }

    /**
     * Converts an already decoded OpenCV image and flattens it with {@code ravel()}.
     *
     * @throws IOException if the image's channels cannot be converted
     */
    public INDArray asRowVector(opencv_core.Mat mat) throws IOException {
        return asMatrix(mat).ravel();
    }

    /**
     * Loads an image file into an array (see the class documentation for its shape).
     *
     * @param file image file to decode
     * @return the pixel array
     * @throws IOException if OpenCV cannot decode the file or its channels cannot be converted
     */
    @Override
    public INDArray asMatrix(File file) throws IOException {
        opencv_core.Mat image = opencv_imgcodecs.imread(file.getAbsolutePath(), imreadFlagsForChannels());
        // cv::imread reports failure with an empty Mat rather than null.
        if (image == null || image.empty()) {
            throw new IOException("Could not read image from file: " + file);
        }
        return asMatrix(image);
    }

    /**
     * Decodes an encoded image (PNG, JPEG, ...) read fully from a stream into an array.
     * The stream is read to its end but not closed.
     *
     * @param inputStream stream holding the encoded image bytes
     * @return the pixel array
     * @throws IOException if the stream cannot be read, is empty, cannot be decoded, or its
     *                     channels cannot be converted
     */
    @Override
    public INDArray asMatrix(InputStream inputStream) throws IOException {
        byte[] encoded = IOUtils.toByteArray(inputStream);
        if (encoded.length == 0) {
            throw new IOException("Could not decode image from input stream");
        }
        opencv_core.Mat image = opencv_imgcodecs.imdecode(new opencv_core.Mat(encoded), imreadFlagsForChannels());
        // cv::imdecode reports failure with an empty Mat rather than null.
        if (image == null || image.empty()) {
            throw new IOException("Could not decode image from input stream");
        }
        return asMatrix(image);
    }

    /**
     * Converts a decoded OpenCV image into an array. In order, it converts channels, crops,
     * resizes and copies.
     *
     * @param mat decoded image
     * @return a {@code [channels, rows, cols]} array for multi-channel images, otherwise a
     *         {@code [rows, cols]} matrix
     * @throws IOException if the image's channel count cannot be converted to {@code channels}
     */
    public INDArray asMatrix(opencv_core.Mat mat) throws IOException {
        opencv_core.Mat image = convertChannelsIfNeeded(mat);
        if (this.centerCropIfNeeded) {
            image = centerCropIfNeeded(image);
        }
        image = scalingIfNeed(image);
        return copyToINDArray(image);
    }

    /** Picks the OpenCV imread/imdecode flag that best matches the requested channel count. */
    private int imreadFlagsForChannels() {
        switch (this.channels) {
            case 1:
                return IMREAD_GRAYSCALE;
            case 3:
                return IMREAD_COLOR;
            default:
                return IMREAD_UNCHANGED;
        }
    }

    /**
     * Runs {@code cvtColor} when {@code channels > 0} and the image has a different number of
     * channels. Otherwise returns the image itself.
     */
    private opencv_core.Mat convertChannelsIfNeeded(opencv_core.Mat mat) throws IOException {
        int from = mat.channels();
        if (this.channels <= 0 || from == this.channels) {
            return mat;
        }
        int code = colorConversionCode(from, this.channels);
        if (code == NO_CONVERSION) {
            throw new IOException("Cannot convert from " + mat.channels() + " to " + this.channels + " channels.");
        }
        opencv_core.Mat converted = new opencv_core.Mat();
        opencv_imgproc.cvtColor(mat, converted, code);
        return converted;
    }

    /**
     * Maps a (source, target) channel-count pair to a BGR-order OpenCV conversion code.
     * Every branch returns, so there is no fall-through between source channel counts.
     *
     * @return the conversion code, or {@link #NO_CONVERSION} if the pair is unsupported
     */
    private static int colorConversionCode(int from, int to) {
        switch (from) {
            case 1:
                switch (to) {
                    case 3:
                        return COLOR_GRAY2BGR;
                    case 4:
                        return COLOR_GRAY2BGRA;
                    default:
                        return NO_CONVERSION;
                }
            case 3:
                switch (to) {
                    case 1:
                        return COLOR_BGR2GRAY;
                    case 4:
                        return COLOR_BGR2BGRA;
                    default:
                        return NO_CONVERSION;
                }
            case 4:
                switch (to) {
                    case 1:
                        return COLOR_BGRA2GRAY;
                    case 3:
                        return COLOR_BGRA2BGR;
                    default:
                        return NO_CONVERSION;
                }
            default:
                return NO_CONVERSION;
        }
    }

    /** Copies the pixels of an image into a new array, one element at a time. */
    private INDArray copyToINDArray(opencv_core.Mat image) {
        int rows = image.rows();
        int cols = image.cols();
        int numChannels = image.channels();
        // NOTE(review): assumes an 8-bit (CV_8U) image. IMREAD_UNCHANGED can yield other
        // depths (e.g. 16-bit PNGs), for which createIndexer() returns a different Indexer
        // type and this assignment fails with a ClassCastException. Confirm the callers.
        UByteIndexer indexer = image.createIndexer();
        if (numChannels > 1) {
            INDArray result = Nd4j.create(new int[] {numChannels, rows, cols});
            for (int c = 0; c < numChannels; c++) {
                for (int y = 0; y < rows; y++) {
                    for (int x = 0; x < cols; x++) {
                        // OpenCV stores pixels interleaved (row, col, channel); ND4J gets planar layout.
                        result.putScalar(c, y, x, indexer.get(y, x, c));
                    }
                }
            }
            return result;
        }
        INDArray result = Nd4j.create(rows, cols);
        for (int y = 0; y < rows; y++) {
            for (int x = 0; x < cols; x++) {
                result.putScalar(y, x, indexer.get(y, x));
            }
        }
        return result;
    }

    /**
     * Crops a non-square image to its centered {@code min(rows, cols)} square. When the
     * difference is odd, the extra row or column is dropped from the bottom or right edge.
     * A square image is returned unchanged.
     *
     * @param mat image to crop
     * @return a region of interest over {@code mat} (produced by {@code Mat.apply(Rect)}),
     *         or {@code mat} itself if it is already square
     */
    protected opencv_core.Mat centerCropIfNeeded(opencv_core.Mat mat) {
        int rows = mat.rows();
        int cols = mat.cols();
        if (rows == cols) {
            return mat;
        }
        int side = Math.min(rows, cols);
        int x = (cols - side) / 2;
        int y = (rows - side) / 2;
        // opencv Rect is (x, y, width, height).
        return mat.apply(new opencv_core.Rect(x, y, side, side));
    }

    /** Resizes to this loader's {@code height x width} (see the two-argument overload). */
    protected opencv_core.Mat scalingIfNeed(opencv_core.Mat mat) {
        return scalingIfNeed(mat, this.height, this.width);
    }

    /**
     * Resizes {@code mat} to {@code dstHeight x dstWidth} when both are positive and the
     * image has a different size.
     *
     * @param mat       image to resize
     * @param dstHeight target number of rows
     * @param dstWidth  target number of columns
     * @return a new resized image, or {@code mat} itself when no resize is needed
     */
    protected opencv_core.Mat scalingIfNeed(opencv_core.Mat mat, int dstHeight, int dstWidth) {
        if (dstHeight <= 0 || dstWidth <= 0 || (mat.rows() == dstHeight && mat.cols() == dstWidth)) {
            return mat;
        }
        opencv_core.Mat resized = new opencv_core.Mat();
        // opencv Size is (width, height).
        opencv_imgproc.resize(mat, resized, new opencv_core.Size(dstWidth, dstHeight));
        return resized;
    }
}
