package io.trino.plugin.ml;

import com.google.common.base.Preconditions;
import com.google.common.cache.CacheBuilder;
import com.google.common.hash.HashCode;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.trino.cache.CacheUtils;
import io.trino.cache.NonEvictableCache;
import io.trino.cache.SafeCaches;
import io.trino.plugin.ml.type.ClassifierType;
import io.trino.plugin.ml.type.RegressorType;
import io.trino.spi.block.Block;
import io.trino.spi.function.ScalarFunction;
import io.trino.spi.function.SqlType;

/* loaded from: input_file:io/trino/plugin/ml/MLFunctions.class */
/**
 * Static SQL scalar functions that apply a serialized machine-learning model to a feature map.
 *
 * <p>Each function converts its {@code map(bigint,double)} argument into a {@link FeatureVector},
 * obtains the {@link Model} for the supplied serialized model (deserializing it only on a cache
 * miss), verifies that the model has the type the function expects, and then evaluates it.
 *
 * <p>This is a utility class and cannot be instantiated.
 */
public final class MLFunctions {
    /** Deserialized models keyed by {@code ModelUtils.modelHash(slice)}; built with a maximum size of 5. */
    private static final NonEvictableCache<HashCode, Model> MODEL_CACHE = SafeCaches.buildNonEvictableCache(CacheBuilder.newBuilder().maximumSize(5));

    /** SQL type signature shared by the feature-map argument of every function in this class. */
    private static final String MAP_BIGINT_DOUBLE = "map(bigint,double)";

    private MLFunctions() {
    }

    /**
     * Scalar function {@code classify} for models of SQL type {@code Classifier(varchar)}.
     *
     * @param block the {@code map(bigint,double)} feature map
     * @param slice the serialized model
     * @return the label produced by the classifier, encoded as a UTF-8 slice
     * @throws IllegalArgumentException if the loaded model's type is not
     *         {@code ClassifierType.VARCHAR_CLASSIFIER}
     */
    @ScalarFunction("classify")
    @SqlType("varchar")
    public static Slice varcharClassify(@SqlType(MAP_BIGINT_DOUBLE) Block block, @SqlType("Classifier(varchar)") Slice slice) {
        FeatureVector features = ModelUtils.toFeatures(block);
        Model model = getOrLoadModel(slice);
        Preconditions.checkArgument(model.getType().equals(ClassifierType.VARCHAR_CLASSIFIER), "model is not a Classifier(varchar)");
        String label = (String) ((Classifier) model).classify(features);
        return Slices.utf8Slice(label);
    }

    /**
     * Scalar function for models of SQL type {@code Classifier(bigint)}.
     *
     * @param block the {@code map(bigint,double)} feature map
     * @param slice the serialized model
     * @return the label produced by the classifier, widened from {@code Integer} to {@code long}
     * @throws IllegalArgumentException if the loaded model's type is not
     *         {@code ClassifierType.BIGINT_CLASSIFIER}
     */
    @ScalarFunction
    @SqlType("bigint")
    public static long classify(@SqlType(MAP_BIGINT_DOUBLE) Block block, @SqlType("Classifier(bigint)") Slice slice) {
        FeatureVector features = ModelUtils.toFeatures(block);
        // Keep the loaded model in a local so the instance that was type-checked is the one evaluated.
        Model model = getOrLoadModel(slice);
        Preconditions.checkArgument(model.getType().equals(ClassifierType.BIGINT_CLASSIFIER), "model is not a Classifier(bigint)");
        Integer label = (Integer) ((Classifier) model).classify(features);
        return label.intValue();
    }

    /**
     * Scalar function for models of SQL type {@code Regressor}.
     *
     * @param block the {@code map(bigint,double)} feature map
     * @param slice the serialized model
     * @return the value produced by the regressor
     * @throws IllegalArgumentException if the loaded model's type is not
     *         {@code RegressorType.REGRESSOR}
     */
    @ScalarFunction
    @SqlType("double")
    public static double regress(@SqlType(MAP_BIGINT_DOUBLE) Block block, @SqlType("Regressor") Slice slice) {
        FeatureVector features = ModelUtils.toFeatures(block);
        Model model = getOrLoadModel(slice);
        Preconditions.checkArgument(model.getType().equals(RegressorType.REGRESSOR), "model is not a regressor");
        return ((Regressor) model).regress(features);
    }

    /**
     * Returns the model for the given serialized form, consulting {@link #MODEL_CACHE} first.
     *
     * <p>The cache key is {@code ModelUtils.modelHash(slice)}; the loader passed to the cache
     * deserializes the slice with {@code ModelUtils.deserialize(slice)}.
     *
     * @param slice the serialized model
     * @return the cached or newly deserialized model
     */
    private static Model getOrLoadModel(Slice slice) {
        HashCode modelHash = ModelUtils.modelHash(slice);
        return (Model) CacheUtils.uncheckedCacheGet(MODEL_CACHE, modelHash, () -> ModelUtils.deserialize(slice));
    }
}
