package org.emergentorder.onnx.backends;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtSession;
import io.kjaer.compiletime.Shape;
import io.kjaer.compiletime.ShapeOf;
import org.emergentorder.compiletime.TensorShapeDenotation;
import org.emergentorder.compiletime.TensorShapeDenotationOf;
import org.emergentorder.onnx.OpToONNXBytesConverter;
import org.emergentorder.onnx.Tensors$Tensor$;
import scala.$less$colon$less$;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Product;
import scala.Some;
import scala.Some$;
import scala.Tuple2;
import scala.Tuple3;
import scala.collection.ArrayOps$;
import scala.collection.IterableOnceOps;
import scala.collection.immutable.List;
import scala.collection.immutable.Map;
import scala.jdk.CollectionConverters$;
import scala.package$;
import scala.reflect.ClassTag$;
import scala.runtime.ScalaRunTime$;
import scala.runtime.Tuple$;
import scala.util.Using$;
import scala.util.Using$Releasable$AutoCloseableIsReleasable$;

/* compiled from: ORTOperatorBackend.scala */
/* loaded from: input_file:org/emergentorder/onnx/backends/ORTOperatorBackend.class */
/**
 * Operator backend that executes single ONNX operators through ONNX Runtime.
 *
 * <p>This file is decompiler output of a Scala trait (see the "compiled from" header). The
 * {@code $init$} method and the mangled {@code _setter_$env_$eq} method appear to be Scala
 * trait-encoding artifacts; edit the Scala source rather than this file.
 *
 * <p>Tensors appear to be encoded as {@code Tuple2(data, Tuple3(elementType, denotation, shape))},
 * judging from the return types below and the {@code Tensors.Tensor.apply(data, tt, denotation,
 * shape)} call in {@link #runModel}.
 */
public interface ORTOperatorBackend extends OpToONNXBytesConverter {
    /** Trait initializer: stores the environment returned by {@code OrtEnvironment.getEnvironment()} as {@link #env()}. */
    default void $init$() {
        org$emergentorder$onnx$backends$ORTOperatorBackend$_setter_$env_$eq(OrtEnvironment.getEnvironment());
    }

    /** The ONNX Runtime environment assigned in {@code $init$}; used to create sessions and input tensors. */
    OrtEnvironment env();

    /** Compiler-generated setter for {@link #env()}, called only from {@code $init$}. */
    void org$emergentorder$onnx$backends$ORTOperatorBackend$_setter_$env_$eq(OrtEnvironment ortEnvironment);

    /**
     * Creates a new ONNX Runtime session from a serialized model.
     *
     * @param bArr serialized ONNX model bytes
     * @return a new session; the caller owns it and must close it
     */
    default OrtSession getSession(byte[] bArr) {
        return env().createSession(bArr);
    }

    /**
     * Runs an already-created session and wraps its first output as a statically-shaped tensor.
     *
     * @param ortSession              session to run (not closed here)
     * @param onnxTensorArr           input tensors, bound positionally to {@code list}
     * @param list                    input names; zipped with {@code onnxTensorArr}, so the shorter
     *                                of the two determines how many inputs are bound
     * @param list2                   output names; NOTE(review): not used by this body, which always
     *                                takes output index 0
     * @param tt                      element-type tag stored in the returned tensor
     * @param tensorShapeDenotationOf supplies the expected shape denotation
     * @param shapeOf                 supplies the expected static shape {@code S}
     * @return the first output as a tensor tuple tagged with {@code tt}, the denotation and shape
     * @throws IllegalArgumentException (via Scala {@code require}) if the runtime output shape
     *                                  differs from the expected shape {@code S}
     */
    default <T, Tt extends String, Td extends TensorShapeDenotation, S extends Shape> Tuple2<Object, Tuple3<Tt, Td, S>> runModel(OrtSession ortSession, OnnxTensor[] onnxTensorArr, List<String> list, List<String> list2, Tt tt, TensorShapeDenotationOf<Td> tensorShapeDenotationOf, ShapeOf<S> shapeOf) {
        // Bind inputs by name (name -> tensor map) and keep only the first output.
        // NOTE(review): the object returned by ortSession.run(...) is never closed here. Confirm that
        // ORTTensorUtils.getArrayFromOnnxTensor releases the output; otherwise native memory leaks per call.
        OnnxTensor onnxTensor = ortSession.run(CollectionConverters$.MODULE$.MapHasAsJava(((IterableOnceOps) list.zip(Predef$.MODULE$.wrapRefArray(onnxTensorArr))).toMap($less$colon$less$.MODULE$.refl())).asJava()).get(0);
        // Runtime output dims (long[]) narrowed to int[] so they can be compared with the static shape.
        int[] iArr = (int[]) ArrayOps$.MODULE$.map$extension(Predef$.MODULE$.longArrayOps(onnxTensor.getInfo().getShape()), j -> {
            return (int) j;
        }, ClassTag$.MODULE$.apply(Integer.TYPE));
        Shape value = shapeOf.value();
        TensorShapeDenotation value2 = tensorShapeDenotationOf.value();
        // Guard: the shape ONNX Runtime produced must match the compile-time shape S exactly.
        Predef$.MODULE$.require(Predef$.MODULE$.wrapIntArray(iArr).sameElements(value.toSeq()));
        return Tensors$Tensor$.MODULE$.apply(ORTTensorUtils$.MODULE$.getArrayFromOnnxTensor(onnxTensor), tt, value2, value);
    }

    /**
     * Executes a serialized single-op ONNX model on the tensors contained in {@code product}.
     *
     * <p>Each element of {@code product} must be either a tensor tuple ({@code Tuple2}) or an
     * {@code Option} of one. Any other element raises {@link MatchError}. Present tensors are
     * converted to {@link OnnxTensor}s in {@link #env()}; {@code None} elements are dropped.
     *
     * @param bArr                    serialized ONNX model whose inputs are named "0".."8" and
     *                                whose output is named "outName" (the name {@link #callOp}
     *                                passes to {@code opToONNXBytes})
     * @param product                 the op's input tensors, possibly optional
     * @param shapeOf                 expected static output shape
     * @param tt                      output element-type tag
     * @param tensorShapeDenotationOf expected output shape denotation
     * @return the op's first output tensor
     */
    default <T, Tt extends String, Td extends TensorShapeDenotation, S extends Shape> Tuple2<Object, Tuple3<Tt, Td, S>> callByteArrayOp(byte[] bArr, Product product, ShapeOf<S> shapeOf, Tt tt, TensorShapeDenotationOf<Td> tensorShapeDenotationOf) {
        // Fixed input names: at most 9 inputs can be bound. Any extra tensors are silently dropped by
        // the zip in runModel. NOTE(review): confirm no op needs more than 9 inputs.
        List list = (List) package$.MODULE$.List().apply(ScalaRunTime$.MODULE$.wrapRefArray(new String[]{"0", "1", "2", "3", "4", "5", "6", "7", "8"}));
        List list2 = (List) package$.MODULE$.List().apply(ScalaRunTime$.MODULE$.wrapRefArray(new String[]{"outName"}));
        // Map each product element to Option<OnnxTensor>, then flatten to drop the Nones.
        // NOTE(review): dropping None before names are assigned shifts every later input to a lower
        // name. For example, (a, None, c) binds c to "1". Confirm opToONNXBytes numbers inputs the same way.
        // NOTE(review): the OnnxTensors created here are never closed; possible native memory leak.
        OnnxTensor[] onnxTensorArr = (OnnxTensor[]) ArrayOps$.MODULE$.flatten$extension(Predef$.MODULE$.refArrayOps((Object[]) ArrayOps$.MODULE$.map$extension(Predef$.MODULE$.refArrayOps(Tuple$.MODULE$.toArray(product)), obj -> {
            if (!(obj instanceof Option)) {
                // Not an Option: must be a bare tensor tuple, else MatchError.
                if (!(obj instanceof Tuple2)) {
                    throw new MatchError(obj);
                }
                Tuple2 tuple2 = (Tuple2) obj;
                return Some$.MODULE$.apply(ORTTensorUtils$.MODULE$.getOnnxTensor(Tensors$Tensor$.MODULE$.extension_data(tuple2), Tensors$Tensor$.MODULE$.extension_shape(tuple2), env()));
            }
            // Option-wrapped input: Some(tensor) is converted, None is passed through (and later dropped).
            Some some = (Option) obj;
            if (some instanceof Some) {
                Tuple2 tuple22 = (Tuple2) some.value();
                return Some$.MODULE$.apply(ORTTensorUtils$.MODULE$.getOnnxTensor(Tensors$Tensor$.MODULE$.extension_data(tuple22), Tensors$Tensor$.MODULE$.extension_shape(tuple22), env()));
            }
            if (None$.MODULE$.equals(some)) {
                return None$.MODULE$;
            }
            throw new MatchError(some);
        }, ClassTag$.MODULE$.apply(Option.class))), Predef$.MODULE$.$conforms(), ClassTag$.MODULE$.apply(OnnxTensor.class));
        // Using.resource closes the session after runModel returns (or throws).
        return (Tuple2) Using$.MODULE$.resource(getSession(bArr), ortSession -> {
            return runModel(ortSession, onnxTensorArr, list, list2, tt, tensorShapeDenotationOf, shapeOf);
        }, Using$Releasable$AutoCloseableIsReleasable$.MODULE$);
    }

    /**
     * Serializes a single op with {@code opToONNXBytes} and executes it via {@link #callByteArrayOp}.
     *
     * @param str                     forwarded as the first argument of {@code opToONNXBytes}
     *                                (presumably the op/node name; confirm)
     * @param str2                    forwarded as the second argument of {@code opToONNXBytes}
     *                                (presumably the ONNX op type; confirm)
     * @param product                 the op's input tensors, possibly optional
     * @param map                     op attributes forwarded to {@code opToONNXBytes}
     * @param tt                      output element-type tag
     * @param tensorShapeDenotationOf expected output shape denotation
     * @param shapeOf                 expected static output shape
     * @return the op's output tensor
     */
    default <T, Tt extends String, Td extends TensorShapeDenotation, S extends Shape> Tuple2<Object, Tuple3<Tt, Td, S>> callOp(String str, String str2, Product product, Map<String, Object> map, Tt tt, TensorShapeDenotationOf<Td> tensorShapeDenotationOf, ShapeOf<S> shapeOf) {
        // "outName" must match the output name callByteArrayOp hard-codes.
        return callByteArrayOp(opToONNXBytes(str, str2, product, "outName", map), product, shapeOf, tt, tensorShapeDenotationOf);
    }

    /**
     * Closes {@link #env()}.
     *
     * <p>NOTE(review): env() comes from {@code OrtEnvironment.getEnvironment()}, which in the ORT
     * Java API returns a shared environment. Confirm that closing it here cannot affect other
     * users of that environment.
     */
    default void close() {
        env().close();
    }
}
