package ai.konduit.serving.pipeline.util;

import ai.konduit.serving.pipeline.api.data.NDArray;
import ai.konduit.serving.pipeline.api.data.NDArrayType;
import java.util.Arrays;

/* loaded from: input_file:ai/konduit/serving/pipeline/util/NDArrayUtils.class */
/**
 * Static helpers for converting and inspecting {@link NDArray} values and for re-laying-out
 * rank-4 {@code float} arrays. This class is not instantiable.
 */
public final class NDArrayUtils {
    private NDArrayUtils() {
    }

    /**
     * Converts a floating-point {@link NDArray} to one backed by {@code double[][]}.
     *
     * <p>If the array's type is not {@link NDArrayType#FLOAT}, {@link NDArrayType#FLOAT16} or
     * {@link NDArrayType#BFLOAT16}, the input instance is returned unchanged.
     *
     * <p>The data is read with {@code getAs(float[][].class)}, so this only handles arrays that
     * can be viewed as {@code float[][]}. NOTE(review): whether other ranks, or FLOAT16/BFLOAT16
     * arrays, convert to {@code float[][]} depends on the NDArray implementation — confirm.
     *
     * @param nDArray the array to convert
     * @return a new NDArray created from a {@code double[][]} copy of the data, or
     *     {@code nDArray} itself when it is not one of the float types listed above
     */
    public static NDArray FloatNDArrayToDouble(NDArray nDArray) {
        NDArrayType type = nDArray.type();
        if (type != NDArrayType.FLOAT && type != NDArrayType.FLOAT16 && type != NDArrayType.BFLOAT16) {
            return nDArray;
        }
        float[][] fArr = (float[][]) nDArray.getAs(float[][].class);
        // Size the result from the data actually being copied rather than from shape(),
        // so the output always matches the source rows exactly.
        double[][] dArr = new double[fArr.length][];
        for (int row = 0; row < fArr.length; row++) {
            float[] src = fArr[row];
            double[] dst = new double[src.length];
            for (int col = 0; col < src.length; col++) {
                dst[col] = src[col]; // float -> double widening is exact; no boxing needed
            }
            dArr[row] = dst;
        }
        return NDArray.create(dArr);
    }

    /**
     * Returns the data of a rank-1 array, or of a rank-2 array whose first dimension is 1,
     * as a flat {@code double[]}.
     *
     * <p>NOTE(review): whether non-double arrays can be read via {@code getAs(double[].class)}
     * depends on the NDArray implementation — confirm.
     *
     * @param nDArray an array of shape {@code [n]} or {@code [1, n]}
     * @return the array's values as {@code double[]}
     * @throws UnsupportedOperationException for any other shape
     */
    public static double[] squeeze(NDArray nDArray) {
        long[] shape = nDArray.shape();
        if (shape.length == 1) {
            return (double[]) nDArray.getAs(double[].class);
        }
        if (shape.length == 2 && shape[0] == 1) {
            // Single leading row: drop the outer dimension.
            return ((double[][]) nDArray.getAs(double[][].class))[0];
        }
        throw new UnsupportedOperationException(
                "Failed squeezing NDArray with shape " + Arrays.toString(shape));
    }

    /**
     * Finds the largest element of an array and its index.
     *
     * <p>Only strictly greater values replace the running maximum, so ties resolve to the lowest
     * index. Every comparison with NaN is false, so NaN elements after index 0 are never
     * selected. If {@code dArr[0]} is NaN, it is returned together with index 0.
     *
     * @param dArr the values to search; must be non-empty
     * @return a two-element array {@code {maxValue, indexOfMax}}; the index is stored as a double
     * @throws NullPointerException if {@code dArr} is null
     * @throws IllegalArgumentException if {@code dArr} is empty
     */
    public static double[] getMaxValueAndIndex(double[] dArr) {
        if (dArr.length == 0) {
            throw new IllegalArgumentException("Cannot find the maximum of an empty array");
        }
        double max = dArr[0];
        int maxIndex = 0;
        for (int i = 1; i < dArr.length; i++) {
            if (dArr[i] > max) {
                max = dArr[i];
                maxIndex = i;
            }
        }
        return new double[]{max, maxIndex};
    }

    /**
     * Transposes a rank-4 array from NCHW to NHWC layout:
     * {@code out[n][h][w][c] = in[n][c][h][w]}.
     *
     * <p>The C, H and W sizes are taken from {@code fArr[0]}, {@code fArr[0][0]} and
     * {@code fArr[0][0][0]}, so the input must be rectangular (not jagged). A non-empty input
     * must have C and H of at least 1. An empty batch yields an empty result.
     *
     * @param fArr input indexed as {@code [n][c][h][w]}
     * @return a new array indexed as {@code [n][h][w][c]}
     */
    public static float[][][][] nchwToNhwc(float[][][][] fArr) {
        int batch = fArr.length;
        if (batch == 0) {
            return new float[0][][][];
        }
        int channels = fArr[0].length;
        int height = fArr[0][0].length;
        int width = fArr[0][0][0].length;
        float[][][][] out = new float[batch][height][width][channels];
        for (int n = 0; n < batch; n++) {
            float[][][] src = fArr[n];
            float[][][] dst = out[n];
            for (int y = 0; y < height; y++) {
                for (int x = 0; x < width; x++) {
                    float[] pixel = dst[y][x];
                    for (int c = 0; c < channels; c++) {
                        pixel[c] = src[c][y][x];
                    }
                }
            }
        }
        return out;
    }

    /**
     * Transposes a rank-4 array from NHWC to NCHW layout:
     * {@code out[n][c][h][w] = in[n][h][w][c]}.
     *
     * <p>The H, W and C sizes are taken from {@code fArr[0]}, {@code fArr[0][0]} and
     * {@code fArr[0][0][0]}, so the input must be rectangular (not jagged). A non-empty input
     * must have H and W of at least 1. An empty batch yields an empty result.
     *
     * @param fArr input indexed as {@code [n][h][w][c]}
     * @return a new array indexed as {@code [n][c][h][w]}
     */
    public static float[][][][] nhwcToNchw(float[][][][] fArr) {
        int batch = fArr.length;
        if (batch == 0) {
            return new float[0][][][];
        }
        int height = fArr[0].length;
        int width = fArr[0][0].length;
        int channels = fArr[0][0][0].length;
        float[][][][] out = new float[batch][channels][height][width];
        for (int n = 0; n < batch; n++) {
            float[][][] src = fArr[n];
            float[][][] dst = out[n];
            for (int y = 0; y < height; y++) {
                for (int x = 0; x < width; x++) {
                    float[] pixel = src[y][x];
                    for (int c = 0; c < channels; c++) {
                        dst[c][y][x] = pixel[c];
                    }
                }
            }
        }
        return out;
    }
}
