/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.onnxruntime.engine;

import ai.djl.engine.EngineException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.onnxruntime.engine.OrtNDManager;
import ai.onnxruntime.OnnxJavaType;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;

/**
 * Static helpers that convert between DJL types ({@link NDArray}, {@link DataType}) and ONNX
 * Runtime types ({@link OnnxTensor}, {@link OnnxJavaType}).
 */
final class OrtUtils {

    /** Utility class; not instantiable. */
    private OrtUtils() {
    }

    /**
     * Converts a DJL {@link NDArray} into an ONNX Runtime {@link OnnxTensor}.
     *
     * <p>The array's contents are read through {@link NDArray#toByteBuffer()}. They are
     * reinterpreted as the typed buffer matching the array's {@link DataType}, then passed to
     * {@link #toTensor(OrtEnvironment, Buffer, Shape, DataType)}.
     *
     * @param env the ONNX Runtime environment used to create the tensor
     * @param array the source array
     * @return a tensor created from the array's data and shape
     * @throws OrtException if ONNX Runtime fails to create the tensor
     * @throws EngineException if the array's data type is not supported
     */
    public static OnnxTensor toTensor(OrtEnvironment env, NDArray array) throws OrtException {
        ByteBuffer bb = array.toByteBuffer();
        DataType dataType = array.getDataType();
        Buffer buf = dataType.asDataType(bb);
        return toTensor(env, buf, array.getShape(), dataType);
    }

    /**
     * Creates an ONNX Runtime tensor from a raw buffer.
     *
     * <p>{@code data} must be the typed buffer matching {@code dataType}. Otherwise a
     * {@link ClassCastException} is thrown. The expected buffer types are:
     *
     * <ul>
     *   <li>{@code FLOAT32} - {@link FloatBuffer}
     *   <li>{@code FLOAT64} - {@link DoubleBuffer}
     *   <li>{@code INT32} - {@link IntBuffer}
     *   <li>{@code INT64} - {@link LongBuffer}
     *   <li>{@code INT8}, {@code UINT8}, {@code BOOLEAN} - {@link ByteBuffer}
     * </ul>
     *
     * @param env the ONNX Runtime environment used to create the tensor
     * @param data the tensor contents
     * @param shape the tensor shape
     * @param dataType the DJL data type describing {@code data}
     * @return a tensor wrapping the given data
     * @throws OrtException if ONNX Runtime fails to create the tensor
     * @throws EngineException if {@code dataType} is not one of the types listed above
     */
    public static OnnxTensor toTensor(
            OrtEnvironment env, Buffer data, Shape shape, DataType dataType) throws OrtException {
        long[] sh = shape.getShape();
        switch (dataType) {
            case FLOAT32:
                return OnnxTensor.createTensor(env, (FloatBuffer) data, sh);
            case FLOAT64:
                return OnnxTensor.createTensor(env, (DoubleBuffer) data, sh);
            case INT32:
                return OnnxTensor.createTensor(env, (IntBuffer) data, sh);
            case INT64:
                return OnnxTensor.createTensor(env, (LongBuffer) data, sh);
            case INT8:
            case UINT8:
                // NOTE(review): UINT8 data is tagged as INT8, so a model declaring a uint8 input
                // will see an int8 tensor. Switch to OnnxJavaType.UINT8 if the ONNX Runtime
                // version in use provides it. TODO confirm.
                return OnnxTensor.createTensor(env, (ByteBuffer) data, sh, OnnxJavaType.INT8);
            case BOOLEAN:
                return OnnxTensor.createTensor(env, (ByteBuffer) data, sh, OnnxJavaType.BOOL);
            default:
                throw new EngineException("Data type not supported: " + dataType);
        }
    }

    /**
     * Creates an ONNX Runtime string tensor.
     *
     * @param env the ONNX Runtime environment used to create the tensor
     * @param inputs the string elements
     * @param shape the tensor shape
     * @return a string tensor holding {@code inputs}
     * @throws OrtException if ONNX Runtime fails to create the tensor
     */
    public static OnnxTensor toTensor(OrtEnvironment env, String[] inputs, Shape shape)
            throws OrtException {
        return OnnxTensor.createTensor(env, inputs, shape.getShape());
    }

    /**
     * Converts an ONNX Runtime tensor into a DJL {@link NDArray} owned by {@code manager}.
     *
     * <p>If {@code manager} is an {@link OrtNDManager}, the tensor is handed to it directly and
     * is NOT closed here. For any other manager, the tensor's bytes are read and handed to
     * {@link NDManager#create(Buffer, Shape, DataType)}. The tensor is then always closed, even
     * if the conversion fails.
     *
     * @param manager the manager that will own the resulting array
     * @param tensor the source tensor
     * @return the converted array
     * @throws UnsupportedOperationException if the tensor's element type cannot be mapped, or
     *     its contents are not available as a byte buffer (for a non-ORT manager)
     */
    public static NDArray toNDArray(NDManager manager, OnnxTensor tensor) {
        if (manager instanceof OrtNDManager) {
            return ((OrtNDManager) manager).create(tensor);
        }
        // The finally block guarantees the native tensor is released on every path. It also
        // runs only after the data has been consumed by manager.create(...).
        try {
            ByteBuffer bb = tensor.getByteBuffer();
            if (bb == null) {
                // ONNX Runtime's getByteBuffer() returns null for STRING tensors.
                throw new UnsupportedOperationException(
                        "Tensor of type " + tensor.getInfo().type
                                + " has no byte buffer representation");
            }
            // Presumably ONNX Runtime lays out tensor bytes in platform-native order. TODO confirm.
            bb.order(ByteOrder.nativeOrder());
            DataType dataType = toDataType(tensor.getInfo().type);
            Shape shape = new Shape(tensor.getInfo().getShape());
            Buffer buf = dataType.asDataType(bb);
            return manager.create(buf, shape, dataType);
        } finally {
            tensor.close();
        }
    }

    /**
     * Maps an ONNX Runtime element type to the corresponding DJL {@link DataType}.
     *
     * @param javaType the ONNX Runtime element type
     * @return the matching DJL data type
     * @throws UnsupportedOperationException for any element type without a case below
     */
    public static DataType toDataType(OnnxJavaType javaType) {
        switch (javaType) {
            case FLOAT:
                return DataType.FLOAT32;
            case DOUBLE:
                return DataType.FLOAT64;
            case INT8:
                return DataType.INT8;
            case INT32:
                return DataType.INT32;
            case INT64:
                return DataType.INT64;
            case BOOL:
                return DataType.BOOLEAN;
            case UNKNOWN:
                return DataType.UNKNOWN;
            case STRING:
                return DataType.STRING;
            default:
                throw new UnsupportedOperationException("type is not supported: " + javaType);
        }
    }
}

