package org.tribuo.interop.onnx;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.tribuo.Example;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.impl.ArrayExample;
import org.tribuo.multilabel.MultiLabel;
import org.tribuo.multilabel.MultiLabelFactory;

/**
 * Tests for {@link MultiLabelTransformer}, which turns ONNX tensors of per-label scores into Tribuo
 * {@link MultiLabel} outputs and {@link Prediction}s.
 *
 * <p>Every test uses four labels, A to D, mapped to ids 0 to 3. Column {@code i} of each score row
 * is therefore the score for {@code LABELS[i]}. Expected scores are written as {@code float}
 * literals widened to {@code double}, because the tensor stores {@code float}s and the transformer
 * reports {@code double}s.
 */
public class MultiLabelTransformerTest {
    private static final MultiLabelFactory factory = new MultiLabelFactory();
    private static final MultiLabelTransformer transformer = new MultiLabelTransformer();

    /** Label names indexed by their output id, i.e. by tensor column. Never mutated. */
    private static final String[] LABELS = {"A", "B", "C", "D"};

    /** Scores for which B (0.51) and C (0.8) are expected to be predicted. */
    private static final float[] ROW_B_C = {0.1f, 0.51f, 0.8f, 0.0f};

    /** Scores for which A (0.9) and D (1.0) are expected to be predicted. */
    private static final float[] ROW_A_D = {0.9f, 0.1f, 0.2f, 1.0f};

    /** All-zero scores, for which no label is expected to be predicted. */
    private static final float[] ROW_NONE = {0.0f, 0.0f, 0.0f, 0.0f};

    @Test
    public void multilabelTest() {
        // try-with-resources closes the native tensor and the environment even if an assertion fails.
        try (OrtEnvironment env = OrtEnvironment.getEnvironment();
             OnnxTensor tensor = OnnxTensor.createTensor(env, new float[][]{ROW_B_C})) {
            ImmutableOutputInfo<MultiLabel> info = buildOutputInfo();

            MultiLabel output = transformer.transformToOutput(Collections.singletonList(tensor), info);
            assertPredictedLabels(output, "B", "C");
            assertLabelScore(0.51f, output, "B");
            assertLabelScore(0.8f, output, "C");

            Prediction<MultiLabel> prediction = transformer.transformToPrediction(
                    Collections.singletonList(tensor), info, 1, new ArrayExample<>(output));
            assertPredictedLabels(prediction.getOutput(), "B", "C");
            // The per-label output scores should be the raw tensor values for every label.
            assertOutputScores(prediction, ROW_B_C);
        } catch (OrtException e) {
            Assertions.fail(e);
        }
    }

    @Test
    public void multiLabelBatchTest() {
        // Three rows: two with different predicted label pairs, and one where nothing is predicted.
        try (OrtEnvironment env = OrtEnvironment.getEnvironment();
             OnnxTensor tensor =
                     OnnxTensor.createTensor(env, new float[][]{ROW_B_C, ROW_A_D, ROW_NONE})) {
            ImmutableOutputInfo<MultiLabel> info = buildOutputInfo();

            List<MultiLabel> outputs =
                    transformer.transformToBatchOutput(Collections.singletonList(tensor), info);
            Assertions.assertEquals(3, outputs.size());

            MultiLabel first = outputs.get(0);
            assertPredictedLabels(first, "B", "C");
            assertLabelScore(0.51f, first, "B");
            assertLabelScore(0.8f, first, "C");

            MultiLabel second = outputs.get(1);
            assertPredictedLabels(second, "A", "D");
            assertLabelScore(0.9f, second, "A");
            assertLabelScore(1.0f, second, "D");

            MultiLabel third = outputs.get(2);
            assertPredictedLabels(third);
            Assertions.assertEquals(0, third.getLabelSet().size());

            List<Example<MultiLabel>> examples = new ArrayList<>();
            for (MultiLabel output : outputs) {
                examples.add(new ArrayExample<>(output));
            }
            List<Prediction<MultiLabel>> predictions = transformer.transformToBatchPrediction(
                    Collections.singletonList(tensor), info, new int[]{1, 1, 1}, examples);
            Assertions.assertEquals(3, predictions.size());

            Prediction<MultiLabel> firstPrediction = predictions.get(0);
            assertPredictedLabels(firstPrediction.getOutput(), "B", "C");
            assertOutputScores(firstPrediction, ROW_B_C);

            Prediction<MultiLabel> secondPrediction = predictions.get(1);
            assertPredictedLabels(secondPrediction.getOutput(), "A", "D");
            assertOutputScores(secondPrediction, ROW_A_D);

            Prediction<MultiLabel> thirdPrediction = predictions.get(2);
            assertPredictedLabels(thirdPrediction.getOutput());
            assertOutputScores(thirdPrediction, ROW_NONE);
        } catch (OrtException e) {
            Assertions.fail(e);
        }
    }

    /** Builds output info that maps each label in {@link #LABELS} to its array index as its id. */
    private static ImmutableOutputInfo<MultiLabel> buildOutputInfo() {
        Map<MultiLabel, Integer> mapping = new HashMap<>();
        for (int id = 0; id < LABELS.length; id++) {
            mapping.put(new MultiLabel(LABELS[id]), id);
        }
        return factory.constructInfoForExternalModel(mapping);
    }

    /**
     * Asserts that, among {@link #LABELS}, the output contains exactly the given labels.
     *
     * @param output the predicted multi-label output
     * @param expected the labels that must be present; every other label must be absent
     */
    private static void assertPredictedLabels(MultiLabel output, String... expected) {
        List<String> wanted = Arrays.asList(expected);
        for (String label : LABELS) {
            Assertions.assertEquals(wanted.contains(label), output.contains(label),
                    "presence of label " + label);
        }
    }

    /**
     * Asserts that the score stored in {@code output} for {@code label} equals {@code expected}
     * widened to {@code double}. The comparison is exact, with no tolerance.
     */
    private static void assertLabelScore(float expected, MultiLabel output, String label) {
        Assertions.assertEquals((double) expected, output.getLabelScore(new Label(label)).getAsDouble());
    }

    /**
     * Asserts that the prediction's per-label output scores match {@code expected}, where
     * {@code expected[i]} is the score for {@code LABELS[i]}.
     */
    private static void assertOutputScores(Prediction<MultiLabel> prediction, float[] expected) {
        Map<String, MultiLabel> scores = prediction.getOutputScores();
        for (int i = 0; i < LABELS.length; i++) {
            assertLabelScore(expected[i], scores.get(LABELS[i]), LABELS[i]);
        }
    }
}
