package org.emergentorder.onnx.backends;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtSession;
import cats.effect.IO;
import cats.effect.IO$;
import cats.effect.package$;
import cats.implicits$;
import java.util.Map;
import onnx.onnx.ModelProto;
import onnx.onnx.NodeProto;
import onnx.onnx.ValueInfoProto;
import org.emergentorder.compiletime.TensorShapeDenotation;
import org.emergentorder.compiletime.TensorShapeDenotationOf;
import org.emergentorder.io.kjaer.compiletime.Shape;
import org.emergentorder.io.kjaer.compiletime.ShapeOf;
import org.emergentorder.onnx.OpToONNXBytesConverter;
import org.emergentorder.onnx.Tensors$Tensor$;
import scala.$less$colon$less$;
import scala.Array$;
import scala.Array$UnapplySeqWrapper$;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Product;
import scala.Some;
import scala.Some$;
import scala.Tuple2;
import scala.Tuple3;
import scala.collection.ArrayOps$;
import scala.collection.IterableOnceOps;
import scala.collection.immutable.List;
import scala.jdk.CollectionConverters$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichInt$;
import scala.runtime.ScalaRunTime$;
import scala.runtime.Tuples$;

/* compiled from: ORTOperatorBackend.scala */
/* loaded from: input_file:org/emergentorder/onnx/backends/ORTOperatorBackend.class */
public interface ORTOperatorBackend extends OpToONNXBytesConverter, AutoCloseable {
    /**
     * Initialises the two value members of the given backend instance through their
     * synthetic setters: the ONNX Runtime environment first, then the processor count.
     *
     * @param oRTOperatorBackend the instance whose {@code env} and {@code coreCount} are populated
     */
    static void $init$(ORTOperatorBackend oRTOperatorBackend) {
        // Environment is resolved and stored before the core count is queried.
        final OrtEnvironment environment = OrtEnvironment.getEnvironment();
        oRTOperatorBackend.org$emergentorder$onnx$backends$ORTOperatorBackend$_setter_$env_$eq(environment);

        final int processorCount = Runtime.getRuntime().availableProcessors();
        oRTOperatorBackend.org$emergentorder$onnx$backends$ORTOperatorBackend$_setter_$coreCount_$eq(processorCount);
    }

    /**
     * Returns the ONNX Runtime environment held by this backend. It is used to create
     * sessions in {@code getSession} and is passed to {@code ORTTensorUtils.getOnnxTensor}
     * when input tensors are built.
     */
    OrtEnvironment env();

    /**
     * Synthetic setter backing {@link #env()}. {@code $init$} invokes it with the result of
     * {@code OrtEnvironment.getEnvironment()}.
     *
     * @param ortEnvironment the environment to store
     */
    void org$emergentorder$onnx$backends$ORTOperatorBackend$_setter_$env_$eq(OrtEnvironment ortEnvironment);

    /**
     * Returns the processor count stored for this backend. {@code getSession} uses it as the
     * intra-op thread count of every session it creates.
     */
    int coreCount();

    /**
     * Synthetic setter backing {@link #coreCount()}. {@code $init$} invokes it with
     * {@code Runtime.getRuntime().availableProcessors()}.
     *
     * @param i the processor count to store
     */
    void org$emergentorder$onnx$backends$ORTOperatorBackend$_setter_$coreCount_$eq(int i);

    /**
     * Creates an ONNX Runtime session from a serialized model, with the intra-op thread
     * count set to {@link #coreCount()}.
     *
     * <p>The returned session is owned by the caller, who is responsible for closing it.
     *
     * @param bArr the serialized model bytes handed to {@code OrtEnvironment.createSession}
     * @return a new session created from {@link #env()}
     */
    default OrtSession getSession(byte[] bArr) {
        // SessionOptions is AutoCloseable and only needed while the session is being built;
        // try-with-resources releases it on both the success and the failure path instead of
        // leaking one options object per call.
        try (OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions()) {
            sessionOptions.setIntraOpNumThreads(coreCount());
            return env().createSession(bArr, sessionOptions);
        }
    }

    /**
     * Runs a session on the given input tensors and wraps its first output as a tensor value.
     *
     * <p>Steps, all expressed as {@code IO} effects:
     * <ol>
     *   <li>the input names in {@code list} are zipped positionally with {@code onnxTensorArr}
     *       and turned into a {@code java.util.Map};</li>
     *   <li>the session is run in a blocking effect, and the run result is managed as a
     *       resource that is closed after use;</li>
     *   <li>output 0 is taken, its shape (narrowed from {@code long} to {@code int}) is required
     *       to have the same elements as the expected shape, and its contents are extracted
     *       with {@code ORTTensorUtils.getArrayFromOnnxTensor};</li>
     *   <li>the extracted array is passed to {@code Tensors.Tensor.apply} together with
     *       {@code str}, the denotation value and the shape value.</li>
     * </ol>
     *
     * @param ortSession the session to execute
     * @param onnxTensorArr the input tensors, in the same order as {@code list}
     * @param list the input names paired with {@code onnxTensorArr}
     * @param list2 not referenced by this method's body
     * @param str value forwarded unchanged as the second argument of {@code Tensors.Tensor.apply}
     * @param tensorShapeDenotationOf supplies the denotation value forwarded to {@code Tensors.Tensor.apply}
     * @param shapeOf supplies the expected output shape
     * @return an effect producing the value built by {@code Tensors.Tensor.apply}
     */
    default <T, Tt extends String, Td extends TensorShapeDenotation, S extends Shape> IO<Tuple2<Object, Tuple3<Tt, Td, S>>> runModel(OrtSession ortSession, OnnxTensor[] onnxTensorArr, List<String> list, List<String> list2, String str, TensorShapeDenotationOf<Td> tensorShapeDenotationOf, ShapeOf<S> shapeOf) {
        // Name -> tensor map in the form the session's run method accepts.
        Map asJava = CollectionConverters$.MODULE$.MapHasAsJava(((IterableOnceOps) list.zip(Predef$.MODULE$.wrapRefArray(onnxTensorArr))).toMap($less$colon$less$.MODULE$.refl())).asJava();
        Shape value = shapeOf.value();
        String str2 = str;
        TensorShapeDenotation value2 = tensorShapeDenotationOf.value();
        return ((IO) package$.MODULE$.Resource().make(IO$.MODULE$.blocking(() -> {
            // Acquire: run the session on the named inputs (captured session and input map).
            return $anonfun$1(ortSession, asJava);
        }), result -> {
            // Release: close the run result once the use block has finished.
            return IO$.MODULE$.apply(() -> {
                result.close();
                return BoxedUnit.UNIT;
            });
        }, IO$.MODULE$.asyncForIO()).use(result2 -> {
            OnnxTensor onnxTensor = result2.get(0);
            // Fails via Predef.require unless the first output's shape matches the expected shape.
            Predef$.MODULE$.require(Predef$.MODULE$.wrapIntArray((int[]) ArrayOps$.MODULE$.map$extension(Predef$.MODULE$.longArrayOps(onnxTensor.getInfo().getShape()), j -> {
                return (int) j;
            }, ClassTag$.MODULE$.apply(Integer.TYPE))).sameElements(value.toSeq()));
            // Extract the data while the result is still open.
            return IO$.MODULE$.blocking(() -> {
                return $anonfun$3$$anonfun$1(onnxTensor);
            });
        }, IO$.MODULE$.asyncForIO())).flatMap(obj -> {
            return Tensors$Tensor$.MODULE$.apply(obj, str2, value2, value);
        });
    }

    /**
     * Builds a single-operator model for the given inputs and runs it.
     *
     * <p>Each element of {@code product} must be either an {@code IO} or an {@code Option}
     * wrapping an {@code IO}; any other element raises a {@code MatchError}. Every such
     * {@code IO} is mapped to an {@code OnnxTensor} via {@code ORTTensorUtils.getOnnxTensor},
     * and {@code None} elements contribute no tensor. The resulting tensor array is used twice:
     * once to derive the per-input type code and shape passed to {@code opToModelProto}, and
     * once as the resource handed to {@code res$1}, which creates the session, runs the model
     * and closes the tensors.
     *
     * @param product tuple of inputs as described above
     * @param list input names, forwarded to {@code res$1}
     * @param str operator name passed to {@code opToModelProto}; also compared with
     *     {@code "Dropout"} when special-casing shapes
     * @param map forwarded to {@code opToModelProto} (presumably the operator attributes — confirm)
     * @param shapeOf expected output shape, forwarded to {@code res$1}
     * @param str2 forwarded to {@code res$1}
     * @param tensorShapeDenotationOf forwarded to {@code res$1}
     * @return an effect producing the result of {@code res$1}
     */
    default <T, Tt extends String, Td extends TensorShapeDenotation, S extends Shape> IO<Tuple2<Object, Tuple3<Tt, Td, S>>> callByteArrayOp(Product product, List<String> list, String str, scala.collection.immutable.Map<String, Object> map, ShapeOf<S> shapeOf, String str2, TensorShapeDenotationOf<Td> tensorShapeDenotationOf) {
        // One-element list holding the decimal string of the input tuple's size; passed on to res$1.
        List list2 = (List) scala.package$.MODULE$.List().apply(ScalaRunTime$.MODULE$.wrapRefArray(new String[]{BoxesRunTime.boxToInteger(Tuples$.MODULE$.size(product)).toString()}));
        // flatMap over the tuple elements yields zero or one IO<OnnxTensor> per element; the list
        // of effects is then sequenced into a single effect producing an OnnxTensor[].
        return ((IO) implicits$.MODULE$.toTraverseOps(Predef$.MODULE$.wrapRefArray((Object[]) ArrayOps$.MODULE$.flatMap$extension(Predef$.MODULE$.refArrayOps(Tuples$.MODULE$.toArray(product)), obj -> {
            if (!(obj instanceof Option)) {
                // Plain (non-optional) input.
                if (obj instanceof IO) {
                    return Some$.MODULE$.apply(((IO) obj).map(tuple2 -> {
                        return ORTTensorUtils$.MODULE$.getOnnxTensor(tuple2._1(), (int[]) ((Shape) ((Tuple3) tuple2._2())._3()).toSeq().toArray(ClassTag$.MODULE$.apply(Integer.TYPE)), env());
                    }));
                }
                throw new MatchError(obj);
            }
            // Optional input: Some(io) is converted like a plain input, None is dropped.
            Some some = (Option) obj;
            if (some instanceof Some) {
                return Some$.MODULE$.apply(((IO) some.value()).map(tuple22 -> {
                    return ORTTensorUtils$.MODULE$.getOnnxTensor(tuple22._1(), (int[]) ((Shape) ((Tuple3) tuple22._2())._3()).toSeq().toArray(ClassTag$.MODULE$.apply(Integer.TYPE)), env());
                }));
            }
            if (None$.MODULE$.equals(some)) {
                return None$.MODULE$;
            }
            throw new MatchError(some);
        }, ClassTag$.MODULE$.apply(IO.class))).toList(), implicits$.MODULE$.catsStdInstancesForList()).sequence($less$colon$less$.MODULE$.refl(), IO$.MODULE$.asyncForIO())).map(list3 -> {
            return (OnnxTensor[]) list3.toArray(ClassTag$.MODULE$.apply(OnnxTensor.class));
        }).memoize().flatMap(io -> {
            // `io` is the memoized tensor-building effect: it is mapped here to derive the model
            // bytes and is also passed to res$1 as the resource to acquire.
            return io.map(onnxTensorArr -> {
                // Per-input type code: five values are remapped, every other value passes through.
                // NOTE(review): presumably translates ONNX Runtime's type numbering into the codes
                // opToModelProto expects — confirm against both enumerations.
                Object intArrayOps = Predef$.MODULE$.intArrayOps((int[]) ArrayOps$.MODULE$.map$extension(Predef$.MODULE$.refArrayOps(onnxTensorArr), onnxTensor -> {
                    int i = onnxTensor.getInfo().onnxType.value;
                    switch (i) {
                        case 2:
                            return 3;
                        case 4:
                            return 5;
                        case 8:
                            return 7;
                        case 10:
                            return 1;
                        case 13:
                            return 9;
                        default:
                            return i;
                    }
                }, ClassTag$.MODULE$.apply(Integer.TYPE)));
                // Type codes are zipped with per-input shapes and given to opToModelProto; the
                // serialized model and the tensor effect then go to res$1.
                return res$1(list, shapeOf, str2, tensorShapeDenotationOf, list2, opToModelProto(str, ArrayOps$.MODULE$.zip$extension(intArrayOps, Predef$.MODULE$.wrapRefArray((Object[]) ArrayOps$.MODULE$.map$extension(Predef$.MODULE$.refArrayOps(onnxTensorArr), onnxTensor2 -> {
                    // Tensor shape narrowed from long to int.
                    int[] iArr = (int[]) ArrayOps$.MODULE$.map$extension(Predef$.MODULE$.longArrayOps(onnxTensor2.getInfo().getShape()), j -> {
                        return (int) j;
                    }, ClassTag$.MODULE$.apply(Integer.TYPE));
                    if (iArr != null) {
                        Object unapplySeq = Array$.MODULE$.unapplySeq(iArr);
                        // A shape of exactly [1] is special-cased: an empty shape for "Dropout",
                        // otherwise a fresh [1].
                        if (Array$UnapplySeqWrapper$.MODULE$.lengthCompare$extension(unapplySeq, 1) == 0 && 1 == BoxesRunTime.unboxToInt(Array$UnapplySeqWrapper$.MODULE$.apply$extension(unapplySeq, 0))) {
                            return str.equals("Dropout") ? new int[0] : new int[]{1};
                        }
                    }
                    // Any other non-null shape is used as-is.
                    if (iArr instanceof int[]) {
                        return iArr;
                    }
                    throw new MatchError(iArr);
                }, ClassTag$.MODULE$.apply(Integer.TYPE).wrap()))), map).toByteArray(), io);
            });
        }).flatten($less$colon$less$.MODULE$.refl());
    }

    /**
     * Runs a single operator by delegating to {@code callByteArrayOp}, naming the inputs
     * {@code "0"}, {@code "1"}, ... up to the size of {@code product}.
     *
     * <p>The effect returned by {@code callByteArrayOp} is memoized, and the memoizing effect is
     * executed synchronously on the global cats-effect runtime; the memoized inner effect is
     * what this method returns.
     *
     * @param str not referenced by this method's body
     * @param str2 operator name forwarded to {@code callByteArrayOp}
     * @param product tuple of operator inputs
     * @param map forwarded to {@code callByteArrayOp}
     * @param str3 forwarded to {@code callByteArrayOp}
     * @param tensorShapeDenotationOf forwarded to {@code callByteArrayOp}
     * @param shapeOf expected output shape, forwarded to {@code callByteArrayOp}
     * @return the memoized effect producing the operator's output
     */
    default <T, Tt extends String, Td extends TensorShapeDenotation, S extends Shape> IO<Tuple2<Object, Tuple3<Tt, Td, S>>> callOp(String str, String str2, Product product, scala.collection.immutable.Map<String, Object> map, String str3, TensorShapeDenotationOf<Td> tensorShapeDenotationOf, ShapeOf<S> shapeOf) {
        // Input names are the indices 0 until size(product), rendered as strings.
        final int inputCount = Tuples$.MODULE$.size(product);
        final List<String> inputNames = RichInt$.MODULE$.until$extension(Predef$.MODULE$.intWrapper(0), inputCount).toList().map(obj -> {
            return $anonfun$9(BoxesRunTime.unboxToInt(obj));
        });

        // Running the outer memoize effect yields the inner, memoized effect.
        final Object memoizedEffect = callByteArrayOp(product, inputNames, str2, map, shapeOf, str3, tensorShapeDenotationOf)
                .memoize()
                .unsafeRunSync(cats.effect.unsafe.implicits$.MODULE$.global());
        return (IO) memoizedEffect;
    }

    /**
     * Returns a copy of the model whose graph contains only its first node and first output,
     * both renamed so that the node's sole output and the graph output are called {@code str}.
     *
     * @param modelProto the model to derive from; its graph's node and output lists are each
     *     indexed at position 0
     * @param str the output name to apply
     * @return the rewritten model
     */
    default ModelProto modelToPersist(ModelProto modelProto, String str) {
        // First node, with its output list replaced by the single name `str`.
        final NodeProto renamedNode = ((NodeProto) modelProto.getGraph().node().apply(0))
                .clearOutput()
                .withOutput(scala.package$.MODULE$.Seq().apply(ScalaRunTime$.MODULE$.wrapRefArray(new String[]{str})));

        // First graph output, renamed to `str`.
        final ValueInfoProto renamedOutput = ((ValueInfoProto) modelProto.getGraph().output().apply(0))
                .clearName()
                .withName(str);

        // Rebuild the graph around just those two entries and attach it to the model.
        return modelProto.clearGraph().withGraph(
                modelProto.getGraph()
                        .clearNode()
                        .withNode(scala.package$.MODULE$.Seq().apply(ScalaRunTime$.MODULE$.wrapRefArray(new NodeProto[]{renamedNode})))
                        .clearOutput()
                        .withOutput(scala.package$.MODULE$.Seq().apply(ScalaRunTime$.MODULE$.wrapRefArray(new ValueInfoProto[]{renamedOutput}))));
    }

    /**
     * No-op implementation of {@link AutoCloseable#close()}: this default releases nothing.
     * Sessions, run results and input tensors are closed where they are used, not here.
     */
    @Override // java.lang.AutoCloseable
    default void close() {
    }

    /**
     * Lambda body used by {@code runModel}: runs the session on the given name-to-tensor map.
     *
     * @param ortSession the session to execute
     * @param map the named inputs
     * @return the run result, which the caller is responsible for closing
     */
    private static OrtSession.Result $anonfun$1(OrtSession ortSession, Map map) {
        final OrtSession.Result runResult = ortSession.run(map);
        return runResult;
    }

    /**
     * Lambda body used by {@code runModel}: extracts the contents of an output tensor through
     * {@code ORTTensorUtils.getArrayFromOnnxTensor}.
     *
     * @param onnxTensor the tensor to read
     * @return whatever {@code getArrayFromOnnxTensor} produces for that tensor
     */
    private static Object $anonfun$3$$anonfun$1(OnnxTensor onnxTensor) {
        final Object extracted = ORTTensorUtils$.MODULE$.getArrayFromOnnxTensor(onnxTensor);
        return extracted;
    }

    /**
     * Release action for the input tensors managed by {@code res$1}: closes every tensor in
     * the array, in order.
     *
     * @param onnxTensorArr the tensors to close
     */
    private static void res$1$$anonfun$1$$anonfun$1(OnnxTensor[] onnxTensorArr) {
        // A plain loop replaces the side-effect-only map: no throwaway result array is built,
        // and no value-returning lambda is needed for a void operation.
        for (OnnxTensor onnxTensor : onnxTensorArr) {
            onnxTensor.close();
        }
    }

    /**
     * Acquire action for the session managed by {@code res$1}: creates a session from the
     * serialized model via {@link #getSession(byte[])}.
     *
     * @param bArr the serialized model bytes
     * @return the newly created session
     */
    // A private interface method carries no `default` keyword; combining the two is rejected
    // by javac as an illegal modifier combination.
    private OrtSession res$1$$anonfun$2$$anonfun$1(byte[] bArr) {
        return getSession(bArr);
    }

    /**
     * Runs a serialized model against input tensors with both the tensors and the session
     * managed as resources.
     *
     * <p>The outer resource acquires the tensor array from {@code io} and closes every tensor
     * on release. Inside it, a second resource creates the session from {@code bArr} in a
     * blocking effect and closes it on release. {@code runModel} is invoked while both are open.
     *
     * @param list input names, forwarded to {@code runModel}
     * @param shapeOf expected output shape, forwarded to {@code runModel}
     * @param str forwarded to {@code runModel}
     * @param tensorShapeDenotationOf forwarded to {@code runModel}
     * @param list2 forwarded to {@code runModel}
     * @param bArr the serialized model bytes used to create the session
     * @param io effect producing the input tensor array
     * @return an effect producing the result of {@code runModel}
     */
    // Declared `private` without `default`: the two modifiers cannot be combined on an
    // interface method.
    private IO res$1(List list, ShapeOf shapeOf, String str, TensorShapeDenotationOf tensorShapeDenotationOf, List list2, byte[] bArr, IO io) {
        return (IO) package$.MODULE$.Resource().make(io, onnxTensorArr -> {
            // Release: close all input tensors.
            return IO$.MODULE$.apply(() -> {
                res$1$$anonfun$1$$anonfun$1(onnxTensorArr);
                return BoxedUnit.UNIT;
            });
        }, IO$.MODULE$.asyncForIO()).use(onnxTensorArr2 -> {
            return (IO) package$.MODULE$.Resource().make(IO$.MODULE$.blocking(() -> {
                // Acquire: build the session on this backend from the captured model bytes.
                return res$1$$anonfun$2$$anonfun$1(bArr);
            }), ortSession -> {
                // Release: close the session.
                return IO$.MODULE$.apply(() -> {
                    ortSession.close();
                    return BoxedUnit.UNIT;
                });
            }, IO$.MODULE$.asyncForIO()).use(ortSession2 -> {
                return runModel(ortSession2, onnxTensorArr2, list, list2, str, tensorShapeDenotationOf, shapeOf);
            }, IO$.MODULE$.asyncForIO());
        }, IO$.MODULE$.asyncForIO());
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Lambda body used by {@code callOp}: renders an input index as its decimal string.
     *
     * @param i the index to render
     * @return the base-10 representation of {@code i}
     */
    static /* synthetic */ String $anonfun$9(int i) {
        // Same text as boxing and calling toString(), without allocating a wrapper object.
        return Integer.toString(i);
    }
}
