package org.tribuo.interop.onnx;

import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import java.io.IOException;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.tribuo.MutableDataset;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.classification.evaluation.LabelEvaluation;
import org.tribuo.datasource.LibSVMDataSource;
import org.tribuo.test.Helpers;

/**
 * Tests that wrap ONNX model files in an {@link ONNXExternalModel} and evaluate them on a small
 * MNIST subset stored as LibSVM files on the test classpath.
 *
 * <p>Each test builds a feature-name to ONNX-input-index mapping and a label to ONNX-output-index
 * mapping. It then creates the external model and asserts accuracy and balanced error rate
 * against fixed expected values.
 */
public class TestOnnxRuntime {

    /** Classpath directory containing the ONNX models and LibSVM data used by these tests. */
    private static final String RESOURCE_DIR = "/org/tribuo/interop/onnx/";

    /** Number of pixel features in a flattened 28x28 MNIST image. */
    private static final int NUM_PIXELS = 784;

    /** Absolute tolerance used when comparing evaluation metrics. */
    private static final double TOLERANCE = 1.0E-6;

    @Test
    public void testCNNMNIST() throws IOException, OrtException, URISyntaxException {
        LabelFactory labelFactory = new LabelFactory();
        // The environment is not referenced directly; try-with-resources closes it exactly once,
        // attaching any close() failure as a suppressed exception instead of masking the test failure.
        try (OrtEnvironment env = OrtEnvironment.getEnvironment()) {
            // NOTE(review): the SessionOptions is passed to the model and never closed here;
            // confirm whether ONNXExternalModel takes ownership of it.
            OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
            LibSVMDataSource<Label> dataSource = loadLibSVM("mnist_test_head.libsvm", labelFactory, false);
            MutableDataset<Label> dataset = new MutableDataset<>(dataSource);

            // Feature "000" maps to input index 0, "001" to 1, and so on.
            ONNXExternalModel<Label> model = ONNXExternalModel.createOnnxModel(labelFactory,
                    featureMapping(false), outputMapping(dataset),
                    new ImageTransformer(1, 28, 28), new LabelTransformer(), sessionOptions,
                    resourcePath("cnn_mnist.onnx"), "input_image");

            LabelEvaluation evaluation = labelFactory.getEvaluator().evaluate(model, dataSource);
            Assertions.assertEquals(1.0, evaluation.accuracy(), TOLERANCE);
            Assertions.assertEquals(0.0, evaluation.balancedErrorRate(), TOLERANCE);
        }
    }

    @Test
    public void testMNIST() throws IOException, OrtException, URISyntaxException {
        LabelFactory labelFactory = new LabelFactory();
        try (OrtEnvironment env = OrtEnvironment.getEnvironment()) {
            // NOTE(review): see testCNNMNIST regarding SessionOptions ownership.
            OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
            LibSVMDataSource<Label> dataSource = loadLibSVM("mnist_test_head.libsvm", labelFactory, false);
            MutableDataset<Label> dataset = new MutableDataset<>(dataSource);

            // Features are mapped onto ONNX input indices in reverse order ("000" -> 783, ..., "783" -> 0).
            ONNXExternalModel<Label> model = ONNXExternalModel.createOnnxModel(labelFactory,
                    featureMapping(true), outputMapping(dataset),
                    new DenseTransformer(), new LabelTransformer(), sessionOptions,
                    resourcePath("lr_mnist.onnx"), "float_input");
            model.setBatchSize(1);

            LabelEvaluation evaluation = labelFactory.getEvaluator().evaluate(model, dataSource);
            Assertions.assertEquals(0.967741, evaluation.accuracy(), TOLERANCE);
            Assertions.assertEquals(0.024285, evaluation.balancedErrorRate(), TOLERANCE);
            Helpers.testModelSerialization(model, Label.class);
        }
    }

    @Test
    public void testTransposedMNIST() throws IOException, OrtException, URISyntaxException {
        LabelFactory labelFactory = new LabelFactory();
        try (OrtEnvironment env = OrtEnvironment.getEnvironment()) {
            // NOTE(review): see testCNNMNIST regarding SessionOptions ownership.
            OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
            LibSVMDataSource<Label> dataSource = loadLibSVM("transposed_mnist_test_head.libsvm", labelFactory, true);
            MutableDataset<Label> dataset = new MutableDataset<>(dataSource);

            // Same model and expected metrics as testMNIST, but with a different data file and
            // the identity feature mapping.
            ONNXExternalModel<Label> model = ONNXExternalModel.createOnnxModel(labelFactory,
                    featureMapping(false), outputMapping(dataset),
                    new DenseTransformer(), new LabelTransformer(), sessionOptions,
                    resourcePath("lr_mnist.onnx"), "float_input");
            model.setBatchSize(1);

            LabelEvaluation evaluation = labelFactory.getEvaluator().evaluate(model, dataSource);
            Assertions.assertEquals(0.967741, evaluation.accuracy(), TOLERANCE);
            Assertions.assertEquals(0.024285, evaluation.balancedErrorRate(), TOLERANCE);
        }
    }

    /**
     * Resolves a file in {@link #RESOURCE_DIR} on the test classpath.
     *
     * @param fileName the file name relative to {@link #RESOURCE_DIR}
     * @return the resource URL
     * @throws NullPointerException if the resource is not on the classpath, naming the missing path
     */
    private static URL resource(String fileName) {
        String path = RESOURCE_DIR + fileName;
        return Objects.requireNonNull(TestOnnxRuntime.class.getResource(path), "Missing test resource: " + path);
    }

    /**
     * Resolves a file in {@link #RESOURCE_DIR} to a filesystem {@link Path}.
     *
     * @param fileName the file name relative to {@link #RESOURCE_DIR}
     * @return the resource as a path
     * @throws URISyntaxException if the resource URL cannot be converted to a URI
     */
    private static Path resourcePath(String fileName) throws URISyntaxException {
        return Paths.get(resource(fileName).toURI());
    }

    /**
     * Creates a LibSVM data source over a classpath resource, using {@link #NUM_PIXELS} as the
     * feature id bound argument.
     *
     * @param fileName the file name relative to {@link #RESOURCE_DIR}
     * @param labelFactory the output factory used to parse labels
     * @param zeroIndexed the boolean flag forwarded to the LibSVMDataSource constructor
     * @return the data source
     * @throws IOException if the data source fails to read the file
     */
    private static LibSVMDataSource<Label> loadLibSVM(String fileName, LabelFactory labelFactory, boolean zeroIndexed)
            throws IOException {
        return new LibSVMDataSource<>(resource(fileName), labelFactory, zeroIndexed, NUM_PIXELS);
    }

    /**
     * Builds the mapping from zero-padded three-digit feature names ({@code "000"} to
     * {@code "783"}) to ONNX input indices.
     *
     * @param reversed if {@code true}, feature {@code i} maps to index {@code NUM_PIXELS - 1 - i};
     *                 otherwise it maps to index {@code i}
     * @return a new mutable mapping containing {@link #NUM_PIXELS} entries
     */
    private static Map<String, Integer> featureMapping(boolean reversed) {
        Map<String, Integer> mapping = new HashMap<>();
        for (int i = 0; i < NUM_PIXELS; i++) {
            mapping.put(String.format("%03d", i), reversed ? NUM_PIXELS - 1 - i : i);
        }
        return mapping;
    }

    /**
     * Builds the mapping from each label in the dataset's output domain to the ONNX output index.
     * The index is the label's name parsed as an integer.
     *
     * @param dataset the dataset whose output domain is mapped
     * @return a new mutable label to index mapping
     * @throws NumberFormatException if a label name is not an integer
     */
    private static Map<Label, Integer> outputMapping(MutableDataset<Label> dataset) {
        Map<Label, Integer> mapping = new HashMap<>();
        for (Label label : dataset.getOutputInfo().getDomain()) {
            mapping.put(label, Integer.parseInt(label.getLabel()));
        }
        return mapping;
    }
}
