package com.datasqrl.vector;

import com.google.common.cache.CacheBuilder;
import com.google.common.cache.LoadingCache;
import java.util.List;
import java.util.concurrent.ExecutionException;
import org.apache.flink.table.functions.ScalarFunction;
import org.tribuo.interop.onnx.extractors.OnnxRunner;
import org.tribuo.interop.onnx.extractors.Tokenizer;

/* loaded from: input_file:com/datasqrl/vector/OnnxEmbed.class */
/**
 * Flink scalar function that turns a piece of text into an embedding vector using an ONNX model.
 *
 * <p>Models are looked up by the identifier passed as the second argument of {@link #eval} and
 * kept in a bounded cache so each model is only loaded once per function instance.
 */
public class OnnxEmbed extends ScalarFunction {

    /**
     * Loaded models keyed by the model identifier given to {@link #eval}. At most 100 models are
     * retained; Guava evicts entries beyond that size. Loading is delegated to
     * {@code CacheLoaderImpl}, which is defined outside this class.
     *
     * <p>NOTE(review): evicted {@link CachedModel}s are not explicitly released here. If
     * {@link OnnxRunner} holds native resources, consider closing them via a removal listener —
     * confirm against the OnnxRunner implementation.
     */
    public LoadingCache<String, CachedModel> models =
            CacheBuilder.newBuilder().maximumSize(100).build(new CacheLoaderImpl());

    /** An ONNX runner paired with the tokenizer that prepares its input. */
    public static final class CachedModel {
        private final OnnxRunner runner;
        private final Tokenizer tokenizer;

        public CachedModel(OnnxRunner onnxRunner, Tokenizer tokenizer) {
            this.runner = onnxRunner;
            this.tokenizer = tokenizer;
        }

        /**
         * Computes the embedding of {@code text}.
         *
         * <p>The text is tokenized, the tokens are converted into model input using the runner's
         * environment, and the runner produces the embedding from that input plus the raw tokens.
         *
         * @param text the text to embed
         * @return the embedding produced by the runner
         * @throws Exception whatever tokenization, conversion or model execution throws
         */
        public double[] embedd(String text) throws Exception {
            List<String> tokens = tokenizer.tokenize(text);
            return runner.run(tokenizer.convertTokens(runner.env, tokens), tokens);
        }

        public OnnxRunner getRunner() {
            return runner;
        }

        public Tokenizer getTokenizer() {
            return tokenizer;
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof CachedModel)) {
                return false;
            }
            CachedModel other = (CachedModel) obj;
            boolean sameRunner =
                    runner == null ? other.runner == null : runner.equals(other.runner);
            boolean sameTokenizer =
                    tokenizer == null ? other.tokenizer == null : tokenizer.equals(other.tokenizer);
            return sameRunner && sameTokenizer;
        }

        @Override
        public int hashCode() {
            // Same 59/43 scheme as the previous (Lombok-generated) implementation,
            // so hash values are unchanged.
            int result = 59 + (runner == null ? 43 : runner.hashCode());
            return result * 59 + (tokenizer == null ? 43 : tokenizer.hashCode());
        }

        @Override
        public String toString() {
            return "OnnxEmbed.CachedModel(runner=" + getRunner() + ", tokenizer=" + getTokenizer() + ")";
        }
    }

    /**
     * Embeds {@code text} with the model identified by {@code modelId}.
     *
     * @param text the text to embed; {@code null} yields {@code null}
     * @param modelId identifier handed to the model cache/loader; {@code null} yields {@code null}
     * @return the embedding converted by {@code VectorFunctions.convert}, or {@code null} if either
     *     argument is {@code null}
     * @throws RuntimeException if the model cannot be loaded or the embedding fails; unchecked
     *     exceptions are propagated unchanged, checked ones are wrapped with the model id in the
     *     message and the original failure as the cause
     */
    public FlinkVectorType eval(String text, String modelId) {
        if (text == null || modelId == null) {
            return null;
        }

        CachedModel model;
        try {
            model = models.get(modelId);
        } catch (ExecutionException e) {
            // The loader threw a checked exception; surface the real failure as the cause.
            throw new RuntimeException("Failed to load ONNX model '" + modelId + "'", e.getCause());
        }

        try {
            return VectorFunctions.convert(model.embedd(text));
        } catch (RuntimeException e) {
            throw e;
        } catch (Exception e) {
            // embedd declares a checked Exception; wrap it with enough context to find the model.
            throw new RuntimeException("Failed to embed text with ONNX model '" + modelId + "'", e);
        }
    }
}
