package org.emergentorder.onnx.backends;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtSession;
import io.kjaer.compiletime.Shape;
import io.kjaer.compiletime.ShapeOf;
import onnx.onnx.ModelProto;
import onnx.onnx.NodeProto;
import onnx.onnx.ValueInfoProto;
import org.emergentorder.compiletime.TensorShapeDenotation;
import org.emergentorder.compiletime.TensorShapeDenotationOf;
import org.emergentorder.onnx.OpToONNXBytesConverter;
import org.emergentorder.onnx.Tensors$Tensor$;
import scala.$less$colon$less$;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Product;
import scala.Some;
import scala.Some$;
import scala.Tuple2;
import scala.Tuple3;
import scala.collection.ArrayOps$;
import scala.collection.IterableOnceOps;
import scala.collection.immutable.List;
import scala.collection.immutable.Map;
import scala.jdk.CollectionConverters$;
import scala.package$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;
import scala.runtime.ScalaRunTime$;
import scala.runtime.Tuples$;
import scala.util.Using$;
import scala.util.Using$Releasable$AutoCloseableIsReleasable$;

/* compiled from: ORTOperatorBackend.scala */
/* loaded from: input_file:org/emergentorder/onnx/backends/ORTOperatorBackend.class */
public interface ORTOperatorBackend extends OpToONNXBytesConverter {
    /**
     * Trait initializer emitted by the Scala compiler: populates the {@code env} and
     * {@code coreCount} vals through their synthetic setters.
     */
    static void $init$(ORTOperatorBackend oRTOperatorBackend) {
        oRTOperatorBackend.org$emergentorder$onnx$backends$ORTOperatorBackend$_setter_$env_$eq(OrtEnvironment.getEnvironment());
        oRTOperatorBackend.org$emergentorder$onnx$backends$ORTOperatorBackend$_setter_$coreCount_$eq(Runtime.getRuntime().availableProcessors());
    }

    /** ONNX Runtime environment, obtained from {@code OrtEnvironment.getEnvironment()} in {@code $init$}. */
    OrtEnvironment env();

    /** Synthetic setter used by {@code $init$} to initialise {@link #env()}. Do not rename. */
    void org$emergentorder$onnx$backends$ORTOperatorBackend$_setter_$env_$eq(OrtEnvironment ortEnvironment);

    /** Number of available processors; used as the intra-op thread count for every session. */
    int coreCount();

    /** Synthetic setter used by {@code $init$} to initialise {@link #coreCount()}. Do not rename. */
    void org$emergentorder$onnx$backends$ORTOperatorBackend$_setter_$coreCount_$eq(int i);

    /**
     * Creates an ONNX Runtime session for a serialized model.
     *
     * <p>The {@code SessionOptions} object owns native memory. It is closed as soon as the
     * session has been created, because ONNX Runtime copies the options into the session.
     * Previously it was never closed and leaked on every call.
     *
     * @param bArr serialized ONNX model bytes
     * @return a new session; the caller is responsible for closing it
     */
    default OrtSession getSession(byte[] bArr) {
        try (OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions()) {
            sessionOptions.setIntraOpNumThreads(coreCount());
            return env().createSession(bArr, sessionOptions);
        }
    }

    /**
     * Runs {@code ortSession} on the given inputs and wraps its first output as a Tensor.
     *
     * <p>Inputs are bound to names by zipping {@code list} with {@code onnxTensorArr} in order.
     * The runtime shape of the output must equal the statically expected shape from
     * {@code shapeOf}, otherwise {@code require} fails.
     *
     * <p>The {@code OrtSession.Result} (and with it the native output tensor) is closed after the
     * data has been copied out. Previously it was leaked.
     * NOTE(review): this assumes {@code ORTTensorUtils.getArrayFromOnnxTensor} copies the data
     * (ORT's buffer/value accessors return copies); confirm before relying on it.
     *
     * @param ortSession session to run
     * @param onnxTensorArr input tensors, in the same order as {@code list}
     * @param list input names
     * @param list2 output names (not read by this method)
     * @param str element-type denotation passed to the result tensor
     * @param tensorShapeDenotationOf supplies the shape denotation of the result
     * @param shapeOf supplies the expected output shape
     * @return the first model output as a Tensor
     */
    default <T, Tt extends String, Td extends TensorShapeDenotation, S extends Shape> Tuple2<Object, Tuple3<Tt, Td, S>> runModel(OrtSession ortSession, OnnxTensor[] onnxTensorArr, List<String> list, List<String> list2, String str, TensorShapeDenotationOf<Td> tensorShapeDenotationOf, ShapeOf<S> shapeOf) {
        try (OrtSession.Result result = ortSession.run(CollectionConverters$.MODULE$.MapHasAsJava(((IterableOnceOps) list.zip(Predef$.MODULE$.wrapRefArray(onnxTensorArr))).toMap($less$colon$less$.MODULE$.refl())).asJava())) {
            OnnxTensor onnxTensor = (OnnxTensor) result.get(0);
            // ORT reports dimensions as long[]; narrow them to int[] for comparison with Shape.
            int[] iArr = (int[]) ArrayOps$.MODULE$.map$extension(Predef$.MODULE$.longArrayOps(onnxTensor.getInfo().getShape()), j -> {
                return (int) j;
            }, ClassTag$.MODULE$.apply(Integer.TYPE));
            Shape value = shapeOf.value();
            TensorShapeDenotation value2 = tensorShapeDenotationOf.value();
            Predef$.MODULE$.require(Predef$.MODULE$.wrapIntArray(iArr).sameElements(value.toSeq()));
            // The return expression is evaluated (data copied out) before the Result is closed.
            return Tensors$Tensor$.MODULE$.apply(ORTTensorUtils$.MODULE$.getArrayFromOnnxTensor(onnxTensor), str, value2, value);
        }
    }

    /**
     * Executes a serialized single-op model against the tensors held in {@code product}.
     *
     * <p>Each tuple element is either a Tensor or an {@code Option} of one; {@code None} elements are
     * skipped, and anything else raises {@code MatchError}. Input names are the string form of the
     * hash code of each element's {@code toString}. When the elements are not all distinct, the
     * element index is appended before hashing so that repeated inputs get different names.
     *
     * <p>The input {@code OnnxTensor}s created here are closed once the model has run. Previously
     * they were leaked. The session itself is released by {@code Using.resource}.
     * NOTE(review): this assumes {@code ORTTensorUtils.getOnnxTensor} returns a fresh tensor per call
     * that no one else holds; confirm.
     */
    default <T, Tt extends String, Td extends TensorShapeDenotation, S extends Shape> Tuple2<Object, Tuple3<Tt, Td, S>> callByteArrayOp(byte[] bArr, Product product, ShapeOf<S> shapeOf, String str, TensorShapeDenotationOf<Td> tensorShapeDenotationOf) {
        Object[] inputs = Tuples$.MODULE$.toArray(product);
        // Loop-invariant: computed once instead of once per element inside the naming lambda.
        boolean allDistinct = ArrayOps$.MODULE$.size$extension(Predef$.MODULE$.refArrayOps((Object[]) ArrayOps$.MODULE$.distinct$extension(Predef$.MODULE$.refArrayOps(inputs)))) == Tuples$.MODULE$.size(product);
        List list = Predef$.MODULE$.wrapRefArray((Object[]) ArrayOps$.MODULE$.map$extension(Predef$.MODULE$.refArrayOps(ArrayOps$.MODULE$.zipWithIndex$extension(Predef$.MODULE$.refArrayOps(inputs))), tuple2 -> {
            return BoxesRunTime.boxToInteger(new StringBuilder(0).append(tuple2._1().toString()).append(allDistinct ? "" : BoxesRunTime.boxToInteger(i$1(tuple2)).toString()).toString().hashCode()).toString();
        }, ClassTag$.MODULE$.apply(String.class))).toList();
        // Single output name: the string form of the input-name list.
        List list2 = (List) package$.MODULE$.List().apply(ScalaRunTime$.MODULE$.wrapRefArray(new String[]{list.toString()}));
        OnnxTensor[] onnxTensorArr = (OnnxTensor[]) ArrayOps$.MODULE$.flatten$extension(Predef$.MODULE$.refArrayOps((Object[]) ArrayOps$.MODULE$.map$extension(Predef$.MODULE$.refArrayOps(inputs), obj -> {
            if (!(obj instanceof Option)) {
                if (!(obj instanceof Tuple2)) {
                    throw new MatchError(obj);
                }
                Tuple2 tuple22 = (Tuple2) obj;
                return Some$.MODULE$.apply(ORTTensorUtils$.MODULE$.getOnnxTensor(Tensors$Tensor$.MODULE$.data(tuple22), Tensors$Tensor$.MODULE$.shape(tuple22), env()));
            }
            Some some = (Option) obj;
            if (some instanceof Some) {
                Tuple2 tuple23 = (Tuple2) some.value();
                return Some$.MODULE$.apply(ORTTensorUtils$.MODULE$.getOnnxTensor(Tensors$Tensor$.MODULE$.data(tuple23), Tensors$Tensor$.MODULE$.shape(tuple23), env()));
            }
            if (None$.MODULE$.equals(some)) {
                return None$.MODULE$;
            }
            throw new MatchError(some);
        }, ClassTag$.MODULE$.apply(Option.class))), Predef$.MODULE$.$conforms(), ClassTag$.MODULE$.apply(OnnxTensor.class));
        try {
            return (Tuple2) Using$.MODULE$.resource(getSession(bArr), ortSession -> {
                return runModel(ortSession, onnxTensorArr, list, list2, str, tensorShapeDenotationOf, shapeOf);
            }, Using$Releasable$AutoCloseableIsReleasable$.MODULE$);
        } finally {
            // Input tensors hold native memory; release them whether or not the run succeeded.
            for (OnnxTensor onnxTensor : onnxTensorArr) {
                onnxTensor.close();
            }
        }
    }

    /**
     * Builds a single-op model with {@code opToModelProto} and executes it via
     * {@link #callByteArrayOp}. {@code str} is not read here.
     */
    default <T, Tt extends String, Td extends TensorShapeDenotation, S extends Shape> Tuple2<Object, Tuple3<Tt, Td, S>> callOp(String str, String str2, Product product, Map<String, Object> map, String str3, TensorShapeDenotationOf<Td> tensorShapeDenotationOf, ShapeOf<S> shapeOf) {
        return callByteArrayOp(opToModelProto(str2, product, map).toByteArray(), product, shapeOf, str3, tensorShapeDenotationOf);
    }

    /**
     * Returns a copy of {@code modelProto} whose graph keeps only its first node and first output.
     * Both that node's output and the graph output are renamed to {@code str}.
     */
    default ModelProto modelToPersist(ModelProto modelProto, String str) {
        return modelProto.clearGraph().withGraph(modelProto.getGraph().clearNode().withNode(package$.MODULE$.Seq().apply(ScalaRunTime$.MODULE$.wrapRefArray(new NodeProto[]{((NodeProto) modelProto.getGraph().node().apply(0)).clearOutput().withOutput(package$.MODULE$.Seq().apply(ScalaRunTime$.MODULE$.wrapRefArray(new String[]{str})))}))).clearOutput().withOutput(package$.MODULE$.Seq().apply(ScalaRunTime$.MODULE$.wrapRefArray(new ValueInfoProto[]{((ValueInfoProto) modelProto.getGraph().output().apply(0)).clearName().withName(str)}))));
    }

    /**
     * Closes the ONNX Runtime environment.
     * NOTE(review): {@code OrtEnvironment.getEnvironment()} may hand out a process-wide shared
     * instance; confirm that closing it here cannot affect other backends.
     */
    default void close() {
        env().close();
    }

    /** Extracts the zipWithIndex index (second tuple field) as an unboxed int. */
    private static int i$1(Tuple2 tuple2) {
        return BoxesRunTime.unboxToInt(tuple2._2());
    }
}
