package dev.langchain4j.model.scoring.onnx;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.util.PairList;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Scores (query, document) pairs with a cross-encoder model loaded from an ONNX file.
 *
 * <p>Each pair is tokenized together by a Hugging Face tokenizer and fed to the model as
 * {@code input_ids}, {@code attention_mask} and, only when the model declares it,
 * {@code token_type_ids}. The first value of each row of the model's first output is used as
 * the pair's score, optionally passed through a sigmoid.
 *
 * <p>NOTE(review): this class never closes its {@link OrtSession} or tokenizer; their lifetime
 * is tied to the instance.
 */
public class OnnxScoringBertCrossEncoder {
    private final OrtEnvironment environment;
    private final OrtSession session;
    /** Input names declared by the ONNX model; used to decide whether to pass token type ids. */
    private final Set<String> expectedInputs;
    private final HuggingFaceTokenizer tokenizer;
    /** When {@code true}, raw model outputs are mapped through a sigmoid before being returned. */
    private final boolean normalize;

    /** Scores for one {@link #scoreAll} call together with an estimated input token count. */
    public static class ScoringAndTokenCount {
        /** One score per document, in the same order as the documents passed to scoreAll. */
        List<Double> scores;
        /** Sum of the untruncated query and document token counts, excluding special tokens. */
        int tokenCount;

        ScoringAndTokenCount(List<Double> list, int i) {
            this.scores = list;
            this.tokenCount = i;
        }
    }

    /**
     * Loads the ONNX model and the tokenizer.
     *
     * @param modelPath path of the ONNX model file, passed to {@code OrtEnvironment#createSession}
     * @param options ONNX Runtime session options
     * @param tokenizerPath path handed to {@code HuggingFaceTokenizer.newInstance}
     * @param modelMaxLength maximum sequence length of the model; the tokenizer is configured with
     *     {@code modelMaxLength - 2} (presumably to leave room for special tokens — confirm)
     * @param normalize whether scores are passed through a sigmoid
     * @throws RuntimeException wrapping any failure while creating the session or tokenizer
     */
    public OnnxScoringBertCrossEncoder(
            String modelPath,
            OrtSession.SessionOptions options,
            String tokenizerPath,
            int modelMaxLength,
            boolean normalize) {
        this.normalize = normalize;
        OrtSession createdSession = null;
        try {
            this.environment = OrtEnvironment.getEnvironment();
            createdSession = this.environment.createSession(modelPath, options);
            this.expectedInputs = createdSession.getInputNames();

            // A plain map instead of double-brace initialisation, which would create an
            // anonymous HashMap subclass holding a reference to this half-built instance.
            Map<String, String> tokenizerOptions = new HashMap<>();
            tokenizerOptions.put("padding", "true");
            tokenizerOptions.put("truncation", "LONGEST_FIRST");
            tokenizerOptions.put("modelMaxLength", String.valueOf(modelMaxLength - 2));

            this.tokenizer = HuggingFaceTokenizer.newInstance(Paths.get(tokenizerPath), tokenizerOptions);
            this.session = createdSession;
        } catch (Exception e) {
            // Do not leak the native session if a later step (e.g. tokenizer loading) failed.
            if (createdSession != null) {
                try {
                    createdSession.close();
                } catch (OrtException closeFailure) {
                    e.addSuppressed(closeFailure);
                }
            }
            throw new RuntimeException(e);
        }
    }

    /**
     * Scores every document against the query in a single batched model call.
     *
     * @param str the query, paired with every document
     * @param list the documents to score
     * @return one score per document (same order) and an estimated token count; an empty list
     *     yields empty scores and a token count of 0 without invoking the model
     * @throws RuntimeException wrapping any {@link OrtException} raised by ONNX Runtime
     */
    public ScoringAndTokenCount scoreAll(String str, List<String> list) {
        if (list.isEmpty()) {
            // Nothing to score: skip tokenization and avoid building zero-row native tensors.
            return new ScoringAndTokenCount(new ArrayList<>(), 0);
        }

        // "- 2" discounts the special tokens the tokenizer adds to each sequence
        // (assumed to be [CLS]/[SEP] for a BERT-style tokenizer — confirm).
        int queryTokenCount = this.tokenizer.tokenize(str).size() - 2;
        int tokenCount = 0;
        PairList<String, String> pairs = new PairList<>();
        for (String document : list) {
            pairs.add(str, document);
            // Counted on the untruncated text, so this may exceed what the model actually sees.
            tokenCount += queryTokenCount + this.tokenizer.tokenize(document).size() - 2;
        }

        // try-with-resources releases the native result even if toScore throws.
        try (OrtSession.Result result = encode(pairs)) {
            return new ScoringAndTokenCount(toScore(result), tokenCount);
        } catch (OrtException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Tokenizes all pairs as one batch and runs the model on it.
     *
     * <p>The caller owns the returned result and must close it. The input tensors are closed here
     * once {@code session.run} has returned.
     */
    private OrtSession.Result encode(PairList<String, String> pairList) throws OrtException {
        Encoding[] encodings = this.tokenizer.batchEncode(pairList);
        int batchSize = encodings.length;

        // One row per pair. Padding is enabled in the constructor, presumably so every row has
        // the same length, which a rectangular tensor requires.
        long[][] inputIds = new long[batchSize][];
        long[][] attentionMask = new long[batchSize][];
        long[][] tokenTypeIds = new long[batchSize][];
        for (int row = 0; row < batchSize; row++) {
            inputIds[row] = encodings[row].getIds();
            attentionMask[row] = encodings[row].getAttentionMask();
            tokenTypeIds[row] = encodings[row].getTypeIds();
        }

        try (OnnxTensor inputIdsTensor = OnnxTensor.createTensor(this.environment, inputIds);
                OnnxTensor attentionMaskTensor = OnnxTensor.createTensor(this.environment, attentionMask);
                OnnxTensor tokenTypeIdsTensor = OnnxTensor.createTensor(this.environment, tokenTypeIds)) {
            Map<String, OnnxTensor> inputs = new HashMap<>();
            inputs.put("input_ids", inputIdsTensor);
            inputs.put("attention_mask", attentionMaskTensor);
            // Some exported models (e.g. without segment embeddings) do not declare this input.
            if (this.expectedInputs.contains("token_type_ids")) {
                inputs.put("token_type_ids", tokenTypeIdsTensor);
            }
            return this.session.run(inputs);
        }
    }

    /**
     * Extracts one score per batch row from the model's first output.
     *
     * <p>Assumes output 0 is a 2-D float array whose first column is the relevance logit — confirm
     * against the exported model.
     */
    private List<Double> toScore(OrtSession.Result result) throws OrtException {
        float[][] logits = (float[][]) result.get(0).getValue();
        List<Double> scores = new ArrayList<>(logits.length);
        for (float[] row : logits) {
            double logit = row[0];
            scores.add(this.normalize ? sigmoid(row[0]) : logit);
        }
        return scores;
    }

    /** Logistic function mapping a raw logit into (0, 1). */
    private static double sigmoid(float f) {
        return 1.0d / (1.0d + Math.exp(-f));
    }
}
