package org.tribuo.interop.onnx.extractors;

import ai.onnxruntime.NodeInfo;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.TensorInfo;
import java.nio.FloatBuffer;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:org/tribuo/interop/onnx/extractors/OnnxRunner.class */
/**
 * Thin wrapper around an ONNX Runtime session that feeds token ids to a model and pools the
 * per-token output vectors into a single feature vector.
 *
 * <p>The model is fed the inputs {@code input_ids}, {@code attention_mask} and, when it declares
 * more than two inputs, {@code token_type_ids}. The first model output is read as a flat float
 * buffer made of consecutive vectors of {@link #dimension} values each.
 *
 * <p>NOTE(review): the input names suggest a BERT-style transformer whose first output is shaped
 * (batch, sequence, hidden) — confirm against the models actually loaded.
 *
 * <p>Instances own a native session; call {@link #close()} when finished.
 */
public class OnnxRunner implements AutoCloseable {
    /** Axis of the first output's shape that holds the per-token vector length. */
    private static final int HIDDEN_AXIS = 2;

    /**
     * Sequence positions added on top of the caller's token count when sizing the mask tensors.
     * NOTE(review): presumably the leading/trailing special tokens ([CLS]/[SEP]) — confirm.
     */
    private static final int EXTRA_POSITIONS = 2;

    public final OrtEnvironment env = OrtEnvironment.getEnvironment();
    public final OrtSession session;
    public final int dimension;
    public final int inputSize;

    /**
     * Loads the model at {@code path} and records its input count and output vector length.
     *
     * @param path location of the ONNX model file.
     * @throws IllegalArgumentException if the first output is not a tensor with at least three
     *     axes, or its third axis is not a positive {@code int}-sized value.
     * @throws Exception if ONNX Runtime fails to load or inspect the model.
     */
    public OnnxRunner(Path path) throws Exception {
        OrtSession created = this.env.createSession(path.toString(), new OrtSession.SessionOptions());
        int inputs;
        int vectorLength;
        try {
            inputs = created.getInputInfo().size();
            vectorLength = readVectorLength(created);
        } catch (Exception e) {
            // Do not leak the native session when the model turns out to be unusable.
            try {
                created.close();
            } catch (OrtException closeFailure) {
                e.addSuppressed(closeFailure);
            }
            throw e;
        }
        this.session = created;
        this.inputSize = inputs;
        this.dimension = vectorLength;
    }

    /**
     * Builds a {@code 1 x i} int64 tensor with every element set to {@code j}.
     *
     * <p>The caller owns the returned tensor and must close it.
     *
     * @param i number of elements in the single row.
     * @param j value stored in every element.
     * @return the newly allocated tensor.
     * @throws OrtException if ONNX Runtime cannot allocate the tensor.
     */
    public OnnxTensor createTensor(int i, long j) throws OrtException {
        long[] row = new long[i];
        Arrays.fill(row, j);
        // A two-dimensional array is required: one row (the batch) of i values.
        return OnnxTensor.createTensor(this.env, new long[][] {row});
    }

    /**
     * Runs the model and returns the element-wise mean of the token vectors.
     *
     * <p>Only the size of {@code list} is used: it sizes the mask tensors (plus two extra
     * positions) and is the number of vectors averaged. The first vector of the output is
     * skipped. An empty {@code list} yields a vector of {@code NaN} (0 divided by 0).
     *
     * <p>{@code onnxTensor} is not closed by this method; it remains owned by the caller.
     *
     * @param onnxTensor tensor passed to the model as {@code input_ids}.
     * @param list the tokens that {@code onnxTensor} encodes.
     * @return an array of {@link #dimension} values holding the mean token vector.
     * @throws Exception if inference fails or the output cannot be read.
     */
    public double[] run(OnnxTensor onnxTensor, List<String> list) throws Exception {
        int tokenCount = list.size();
        int sequenceLength = tokenCount + EXTRA_POSITIONS;
        boolean needsTokenTypes = this.inputSize > 2;

        // try-with-resources releases the native buffers even when inference throws;
        // a null resource is permitted and simply skipped.
        try (OnnxTensor attentionMask = createTensor(sequenceLength, 1L);
                OnnxTensor tokenTypeIds = needsTokenTypes ? createTensor(sequenceLength, 0L) : null) {
            Map<String, OnnxTensor> inputs = new HashMap<>(3);
            inputs.put("input_ids", onnxTensor);
            inputs.put("attention_mask", attentionMask);
            if (needsTokenTypes) {
                inputs.put("token_type_ids", tokenTypeIds);
            }
            try (OrtSession.Result result = this.session.run(inputs)) {
                // NOTE(review): an earlier revision also averaged this with the first output
                // vector but discarded that average and returned the token mean. The returned
                // value is kept as-is so downstream features do not change; confirm whether
                // the combined vector was the intended result.
                return extractMeanTokenVector(result, tokenCount, true);
            }
        }
    }

    /**
     * Releases the native session. The shared {@link OrtEnvironment} is left untouched.
     *
     * @throws OrtException if ONNX Runtime fails to close the session.
     */
    @Override
    public void close() throws OrtException {
        this.session.close();
    }

    /**
     * Reads the per-token vector length from the third axis of the model's first output.
     */
    private static int readVectorLength(OrtSession session) throws OrtException {
        Map<String, NodeInfo> outputs = session.getOutputInfo();
        if (outputs.isEmpty()) {
            throw new IllegalArgumentException("Model declares no outputs");
        }
        NodeInfo firstOutput = outputs.values().iterator().next();
        if (!(firstOutput.getInfo() instanceof TensorInfo)) {
            throw new IllegalArgumentException(
                    "First model output '" + firstOutput.getName() + "' is not a tensor");
        }
        long[] shape = ((TensorInfo) firstOutput.getInfo()).getShape();
        if (shape.length <= HIDDEN_AXIS) {
            throw new IllegalArgumentException(
                    "First model output must have at least " + (HIDDEN_AXIS + 1)
                            + " axes, found shape " + Arrays.toString(shape));
        }
        long length = shape[HIDDEN_AXIS];
        if (length <= 0 || length > Integer.MAX_VALUE) {
            throw new IllegalArgumentException(
                    "Unsupported vector length " + length + " in shape " + Arrays.toString(shape));
        }
        return (int) length;
    }

    /**
     * Sums {@code tokenCount} consecutive vectors of the first output, starting after the first
     * vector, and optionally divides the sums by {@code tokenCount}.
     */
    private double[] extractMeanTokenVector(OrtSession.Result result, int tokenCount, boolean average) {
        FloatBuffer buffer = ((OnnxTensor) result.get(0)).getFloatBuffer();
        if (buffer == null) {
            throw new IllegalStateException("First model output is not a float tensor");
        }
        double[] pooled = new double[this.dimension];
        float[] tokenVector = new float[this.dimension]; // scratch space reused for every token
        buffer.position(this.dimension); // skip the vector at sequence position 0
        for (int token = 0; token < tokenCount; token++) {
            buffer.get(tokenVector);
            for (int d = 0; d < tokenVector.length; d++) {
                pooled[d] += tokenVector[d];
            }
        }
        if (average) {
            for (int d = 0; d < pooled.length; d++) {
                pooled[d] /= tokenCount;
            }
        }
        return pooled;
    }
}
