package org.emergentorder.onnx.backends;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtSession;
import cats.effect.IO;
import cats.effect.IO$;
import cats.effect.package$;
import cats.effect.unsafe.implicits$;
import onnx.onnx.ModelProto;
import onnx.onnx.NodeProto;
import onnx.onnx.ValueInfoProto;
import org.emergentorder.compiletime.TensorShapeDenotation;
import org.emergentorder.compiletime.TensorShapeDenotationOf;
import org.emergentorder.io.kjaer.compiletime.Shape;
import org.emergentorder.io.kjaer.compiletime.ShapeOf;
import org.emergentorder.onnx.OpToONNXBytesConverter;
import org.emergentorder.onnx.Tensors$Tensor$;
import org.emergentorder.onnx.package;
import scala.$less$colon$less$;
import scala.MatchError;
import scala.Predef$;
import scala.Product;
import scala.Tuple1;
import scala.Tuple2;
import scala.Tuple3;
import scala.Tuple3$;
import scala.collection.ArrayOps$;
import scala.collection.Iterable;
import scala.collection.immutable.List;
import scala.collection.immutable.Map;
import scala.jdk.CollectionConverters$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichInt$;
import scala.runtime.Statics;
import scala.runtime.Tuples$;

/* compiled from: ORTModelBackend.scala */
/* loaded from: input_file:org/emergentorder/onnx/backends/ORTModelBackend.class */
/**
 * ONNX Runtime backed {@code Model}.
 *
 * <p>It creates one {@link OrtSession} from serialized ONNX model bytes and caches the session's
 * input node names, input shapes and output node names. {@link #fullModel} runs the whole model.
 *
 * <p>NOTE(review): this class is decompiled from {@code ORTModelBackend.scala}. Members such as
 * the {@code _setter_$} methods and {@code ORTOperatorBackend.$init$} are Scala trait-encoding
 * plumbing; prefer editing the Scala source.
 */
public class ORTModelBackend extends package.Model implements OpToONNXBytesConverter, ORTOperatorBackend {
    /** Environment value assigned through the ORTOperatorBackend trait setter below. */
    private OrtEnvironment env;
    /** Core-count value assigned through the ORTOperatorBackend trait setter below. */
    private int coreCount;
    /** Session created in the constructor from the serialized model bytes. */
    private final OrtSession session;
    /** (input node names, input shapes, output node names), computed once from {@link #session}. */
    private final Tuple3 allNodeNamesAndDims;

    /**
     * Creates a backend for a serialized ONNX model.
     *
     * @param bArr serialized ONNX model bytes, handed to {@link #getSession(byte[])}
     */
    public ORTModelBackend(byte[] bArr) {
        // Trait initialization runs first so env/coreCount are set before the session is built.
        ORTOperatorBackend.$init$(this);
        this.session = getSession(bArr);
        this.allNodeNamesAndDims = getInputAndOutputNodeNamesAndDims(this.session);
        // Emitted by the Scala compiler to publish the final fields assigned above.
        Statics.releaseFence();
    }

    /** Forwards to the OpToONNXBytesConverter trait implementation. */
    public NodeProto opToNode(String str, String str2, String str3, Map map, String str4) {
        return OpToONNXBytesConverter.opToNode$(this, str, str2, str3, map, str4);
    }

    /** Forwards to the OpToONNXBytesConverter trait implementation. */
    public ValueInfoProto createInputValueInfoProto(int i, int[] iArr, String str) {
        return OpToONNXBytesConverter.createInputValueInfoProto$(this, i, iArr, str);
    }

    /** Forwards to the OpToONNXBytesConverter trait implementation. */
    public ModelProto opToModelProto(String str, Tuple2[] tuple2Arr, Map map) {
        return OpToONNXBytesConverter.opToModelProto$(this, str, tuple2Arr, map);
    }

    @Override // org.emergentorder.onnx.backends.ORTOperatorBackend
    public OrtEnvironment env() {
        return this.env;
    }

    @Override // org.emergentorder.onnx.backends.ORTOperatorBackend
    public int coreCount() {
        return this.coreCount;
    }

    /** Trait-encoding setter used during {@code ORTOperatorBackend.$init$}. */
    @Override // org.emergentorder.onnx.backends.ORTOperatorBackend
    public void org$emergentorder$onnx$backends$ORTOperatorBackend$_setter_$env_$eq(OrtEnvironment ortEnvironment) {
        this.env = ortEnvironment;
    }

    /** Trait-encoding setter used during {@code ORTOperatorBackend.$init$}. */
    @Override // org.emergentorder.onnx.backends.ORTOperatorBackend
    public void org$emergentorder$onnx$backends$ORTOperatorBackend$_setter_$coreCount_$eq(int i) {
        this.coreCount = i;
    }

    /**
     * Delegates to the ORTOperatorBackend default implementation.
     *
     * <p>The unqualified {@code getSession(bArr)} would call this same method forever.
     */
    @Override // org.emergentorder.onnx.backends.ORTOperatorBackend
    public OrtSession getSession(byte[] bArr) {
        return ORTOperatorBackend.super.getSession(bArr);
    }

    /** Delegates to the ORTOperatorBackend default implementation (avoids self-recursion). */
    @Override // org.emergentorder.onnx.backends.ORTOperatorBackend
    public IO runModel(OrtSession ortSession, OnnxTensor[] onnxTensorArr, List list, List list2, String str, TensorShapeDenotationOf tensorShapeDenotationOf, ShapeOf shapeOf) {
        return ORTOperatorBackend.super.runModel(ortSession, onnxTensorArr, list, list2, str, tensorShapeDenotationOf, shapeOf);
    }

    /** Delegates to the ORTOperatorBackend default implementation (avoids self-recursion). */
    @Override // org.emergentorder.onnx.backends.ORTOperatorBackend
    public IO callByteArrayOp(Product product, List list, String str, Map map, ShapeOf shapeOf, String str2, TensorShapeDenotationOf tensorShapeDenotationOf) {
        return ORTOperatorBackend.super.callByteArrayOp(product, list, str, map, shapeOf, str2, tensorShapeDenotationOf);
    }

    /** Delegates to the ORTOperatorBackend default implementation (avoids self-recursion). */
    @Override // org.emergentorder.onnx.backends.ORTOperatorBackend
    public IO callOp(String str, String str2, Product product, Map map, String str3, TensorShapeDenotationOf tensorShapeDenotationOf, ShapeOf shapeOf) {
        return ORTOperatorBackend.super.callOp(str, str2, product, map, str3, tensorShapeDenotationOf, shapeOf);
    }

    /** Delegates to the ORTOperatorBackend default implementation (avoids self-recursion). */
    @Override // org.emergentorder.onnx.backends.ORTOperatorBackend
    public ModelProto modelToPersist(ModelProto modelProto, String str) {
        return ORTOperatorBackend.super.modelToPersist(modelProto, str);
    }

    /**
     * Reads node metadata from {@code ortSession}.
     *
     * <p>Input names and input shapes come from the session in their respective iteration orders;
     * output names come from {@code getOutputNames()}.
     *
     * @param ortSession session to inspect; this parameter is now used (the old code read the
     *     {@code session()} field instead)
     * @return (input node names, input shapes as {@code long[][]}, output node names)
     */
    public Tuple3<List<String>, long[][], List<String>> getInputAndOutputNodeNamesAndDims(OrtSession ortSession) {
        List<String> inputNames = CollectionConverters$.MODULE$.SetHasAsScala(ortSession.getInputNames()).asScala().toList();
        // Each input is assumed to be a tensor; non-tensor inputs fail the TensorInfo cast.
        long[][] inputDims = (long[][]) ((Iterable) CollectionConverters$.MODULE$.CollectionHasAsScala(ortSession.getInputInfo().values()).asScala().map(nodeInfo -> {
            return ((ai.onnxruntime.TensorInfo) nodeInfo.getInfo()).getShape();
        })).toArray(ClassTag$.MODULE$.apply(Long.TYPE).wrap());
        List<String> outputNames = CollectionConverters$.MODULE$.SetHasAsScala(ortSession.getOutputNames()).asScala().toList();
        return Tuple3$.MODULE$.apply(inputNames, inputDims, outputNames);
    }

    /** @return the session created in the constructor */
    public OrtSession session() {
        return this.session;
    }

    /** @return cached (input node names, input shapes, output node names) */
    public Tuple3<List<String>, long[][], List<String>> allNodeNamesAndDims() {
        return this.allNodeNamesAndDims;
    }

    /**
     * Runs the whole model on the tensors contained in {@code product}.
     *
     * <p>Input {@link OnnxTensor}s are acquired through a cats-effect {@code Resource} and closed
     * when {@code use} finishes. {@code memoize().unsafeRunSync} only materializes the memoized
     * inner IO, which is then returned; the model executes when that IO is run.
     *
     * @param product tuple of input tensors, each element expected to be an {@link IO} tensor
     * @param str forwarded to {@code runModel} as its String argument
     * @param tensorShapeDenotationOf forwarded to {@code runModel}
     * @param shapeOf forwarded to {@code runModel}
     * @return memoized IO producing the model output
     */
    public <T, Tt extends String, Td extends TensorShapeDenotation, S extends Shape> IO<Tuple2<Object, Tuple3<Tt, Td, S>>> fullModel(Product product, String str, TensorShapeDenotationOf<Td> tensorShapeDenotationOf, ShapeOf<S> shapeOf) {
        return (IO) ((IO) package$.MODULE$.Resource().make(inputTensors(product, Tuples$.MODULE$.size(product)), onnxTensorArr -> {
            return IO$.MODULE$.apply(() -> {
                closeAll(onnxTensorArr);
                return BoxedUnit.UNIT;
            });
        }, IO$.MODULE$.asyncForIO()).use(onnxTensorArr2 -> {
            return runModel(session(), onnxTensorArr2, (List) allNodeNamesAndDims()._1(), (List) allNodeNamesAndDims()._3(), str, tensorShapeDenotationOf, shapeOf);
        }, IO$.MODULE$.asyncForIO())).memoize().unsafeRunSync(implicits$.MODULE$.global());
    }

    /**
     * Releases the native session owned by this backend.
     *
     * <p>The previous implementation was empty, which leaked the session. Run every IO returned by
     * {@link #fullModel} before closing. {@code env()} is not closed here.
     *
     * @throws IllegalStateException wrapping the {@code OrtException} if the session fails to close
     */
    @Override // org.emergentorder.onnx.backends.ORTOperatorBackend, java.lang.AutoCloseable
    public void close() {
        try {
            session().close();
        } catch (ai.onnxruntime.OrtException e) {
            throw new IllegalStateException("Failed to close ORT session", e);
        }
    }

    /**
     * Converts element {@code i} of {@code product} into an {@link OnnxTensor} inside IO.
     *
     * @throws MatchError if the element is not an {@link IO}
     */
    private final IO inputTensorAt(Product product, int i) {
        // Isolate element i as a 1-tuple: drop(i).take(1).
        Tuple1 take = Tuples$.MODULE$.take(Tuples$.MODULE$.drop(product, i), 1);
        if (!(take instanceof Tuple1)) {
            throw new MatchError(take);
        }
        Object apply = Tuples$.MODULE$.apply(take, 0);
        if (!(apply instanceof IO)) {
            throw new MatchError(apply);
        }
        IO io = (IO) apply;
        // Combine the tensor's data and shape, then build the native tensor in env().
        return Tensors$Tensor$.MODULE$.data(io).map(obj -> {
            return Tensors$Tensor$.MODULE$.shape(io).map(iArr -> {
                return ORTTensorUtils$.MODULE$.getOnnxTensor(obj, iArr, env());
            });
        }).flatten($less$colon$less$.MODULE$.refl());
    }

    /** Builds IO of an OnnxTensor[] from the first {@code i} elements of {@code product}, in order. */
    private final IO inputTensors(Product product, int i) {
        return ((IO) cats.implicits$.MODULE$.toTraverseOps(RichInt$.MODULE$.until$extension(Predef$.MODULE$.intWrapper(0), i).map(obj -> {
            return inputTensorAt(product, BoxesRunTime.unboxToInt(obj));
        }).toList(), cats.implicits$.MODULE$.catsStdInstancesForList()).sequence($less$colon$less$.MODULE$.refl(), IO$.MODULE$.asyncForIO())).map(list -> {
            return (OnnxTensor[]) list.toArray(ClassTag$.MODULE$.apply(OnnxTensor.class));
        });
    }

    /** Closes every tensor in order; the old version built a throwaway Unit array instead. */
    private static void closeAll(OnnxTensor[] onnxTensorArr) {
        for (OnnxTensor onnxTensor : onnxTensorArr) {
            onnxTensor.close();
        }
    }
}
