package io.trino.plugin.ml;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.BiMap;
import com.google.common.collect.ImmutableBiMap;
import com.google.common.collect.ImmutableList;
import com.google.common.hash.HashCode;
import com.google.common.hash.Hashing;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.trino.spi.block.Block;
import io.trino.spi.block.SqlMap;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.DoubleType;

import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/* loaded from: input_file:io/trino/plugin/ml/ModelUtils.class */
/**
 * Serialization helpers for {@link Model} instances, plus conversion of SQL maps into {@link FeatureVector}s.
 *
 * <p>A single serialized model has this layout (offsets in bytes):
 * <pre>
 *   [0, 4)           format version (int, currently 1)
 *   [4, 36)          SHA-256 of every byte from offset 36 to the end of the blob
 *   [36, 40)         algorithm id (int, see MODEL_SERIALIZATION_IDS)
 *   [40, 44)         hyperparameter length H (int)
 *   [44, 44+H)       hyperparameter bytes
 *   [44+H, 52+H)     model data length D (long)
 *   [52+H, 52+H+D)   model data, as returned by Model#getSerializedData()
 * </pre>
 */
public final class ModelUtils {
    private static final int VERSION_OFFSET = 0;
    private static final int HASH_OFFSET = 4;
    private static final int ALGORITHM_OFFSET = 36;
    private static final int HYPERPARAMETER_LENGTH_OFFSET = 40;
    private static final int HYPERPARAMETERS_OFFSET = 44;
    private static final int CURRENT_FORMAT_VERSION = 1;

    // Length of a SHA-256 digest in bytes; HASH_OFFSET + HASH_LENGTH == ALGORITHM_OFFSET.
    private static final int HASH_LENGTH = 32;

    /**
     * Maps each serializable model class to the numeric id written at ALGORITHM_OFFSET.
     *
     * <p>Ids are spelled out as literals on purpose. {@link #deserialize(Slice)} resolves the class from the
     * stored id, so changing an existing assignment would make previously serialized models resolve to a
     * different class.
     */
    @VisibleForTesting
    static final BiMap<Class<? extends Model>, Integer> MODEL_SERIALIZATION_IDS = ImmutableBiMap.<Class<? extends Model>, Integer>builder()
            .put(SvmClassifier.class, 1)
            .put(SvmRegressor.class, 2)
            .put(FeatureVectorUnitNormalizer.class, 3)
            .put(ClassifierFeatureTransformer.class, 4)
            .put(RegressorFeatureTransformer.class, 5)
            .put(FeatureUnitNormalizer.class, 6)
            .put(StringClassifierAdapter.class, 7)
            .build();

    private ModelUtils() {
    }

    /**
     * Serializes a model into the versioned, hash-protected layout described on this class.
     *
     * @param model the model to serialize; its exact runtime class must be registered in MODEL_SERIALIZATION_IDS
     * @return a newly allocated slice holding the serialized model
     * @throws NullPointerException if {@code model} is null or its class has no registered id
     */
    public static Slice serialize(Model model) {
        Objects.requireNonNull(model, "model is null");
        Integer id = MODEL_SERIALIZATION_IDS.get(model.getClass());
        Objects.requireNonNull(id, "id is null");

        // The hyperparameter section is always written empty.
        byte[] hyperparameters = new byte[0];
        int dataLengthOffset = HYPERPARAMETERS_OFFSET + hyperparameters.length;
        int dataOffset = dataLengthOffset + Long.BYTES;
        byte[] data = model.getSerializedData();

        Slice slice = Slices.allocate(Math.addExact(dataOffset, data.length));
        slice.setInt(VERSION_OFFSET, CURRENT_FORMAT_VERSION);
        slice.setInt(ALGORITHM_OFFSET, id);
        slice.setInt(HYPERPARAMETER_LENGTH_OFFSET, hyperparameters.length);
        slice.setBytes(HYPERPARAMETERS_OFFSET, hyperparameters);
        slice.setLong(dataLengthOffset, data.length);
        slice.setBytes(dataOffset, data);

        // The hash covers everything after the hash field itself.
        byte[] hash = Hashing.sha256().hashBytes(slice.getBytes(ALGORITHM_OFFSET, slice.length() - ALGORITHM_OFFSET)).asBytes();
        Preconditions.checkState(hash.length == HASH_LENGTH, "sha256 hash code expected to be 32 bytes");
        slice.setBytes(HASH_OFFSET, hash);
        return slice;
    }

    /**
     * Returns the SHA-256 hash stored in the header of a serialized model. The hash is not verified.
     */
    public static HashCode modelHash(Slice slice) {
        return HashCode.fromBytes(slice.getBytes(HASH_OFFSET, HASH_LENGTH));
    }

    /**
     * Deserializes a model from a byte array. See {@link #deserialize(Slice)}.
     */
    public static Model deserialize(byte[] bArr) {
        return deserialize(Slices.wrappedBuffer(bArr));
    }

    /**
     * Deserializes a model previously produced by {@link #serialize(Model)}.
     *
     * <p>The method checks the format version and the stored hash. It then resolves the model class from the
     * algorithm id and reflectively invokes that class's static {@code deserialize(byte[])} on the model data.
     *
     * @throws IllegalArgumentException if the data is truncated, has an unsupported version, fails the hash
     *         check, or carries out-of-range section lengths
     * @throws NullPointerException if the algorithm id is not registered
     * @throws RuntimeException wrapping any reflection failure while invoking the model's deserializer
     */
    public static Model deserialize(Slice slice) {
        Preconditions.checkArgument(slice.length() >= Integer.BYTES, "model data is too short: %s bytes", slice.length());
        int version = slice.getInt(VERSION_OFFSET);
        Preconditions.checkArgument(version == CURRENT_FORMAT_VERSION, "Unsupported version: %s", version);
        // Smallest valid version-1 blob: fixed header plus the data-length field.
        Preconditions.checkArgument(slice.length() >= HYPERPARAMETERS_OFFSET + Long.BYTES, "model data is too short: %s bytes", slice.length());

        HashCode expectedHash = modelHash(slice);
        HashCode actualHash = Hashing.sha256().hashBytes(slice.getBytes(ALGORITHM_OFFSET, slice.length() - ALGORITHM_OFFSET));
        Preconditions.checkArgument(actualHash.equals(expectedHash), "model hash does not match data");

        int algorithmId = slice.getInt(ALGORITHM_OFFSET);
        Class<? extends Model> modelClass = MODEL_SERIALIZATION_IDS.inverse().get(algorithmId);
        Objects.requireNonNull(modelClass, () -> String.format("Unsupported algorithm %d", algorithmId));

        // Validate section lengths before using them as offsets so malformed input cannot overflow or truncate.
        int hyperparameterLength = slice.getInt(HYPERPARAMETER_LENGTH_OFFSET);
        Preconditions.checkArgument(
                hyperparameterLength >= 0 && hyperparameterLength <= slice.length() - HYPERPARAMETERS_OFFSET - Long.BYTES,
                "invalid hyperparameter length: %s",
                hyperparameterLength);
        int dataLengthOffset = HYPERPARAMETERS_OFFSET + hyperparameterLength;
        int dataOffset = dataLengthOffset + Long.BYTES;
        long dataLength = slice.getLong(dataLengthOffset);
        Preconditions.checkArgument(
                dataLength >= 0 && dataLength <= slice.length() - dataOffset,
                "invalid model data length: %s",
                dataLength);
        byte[] data = slice.getBytes(dataOffset, (int) dataLength);

        try {
            // Cast to Object so the byte[] is passed as the single argument, not spread as varargs.
            return (Model) modelClass.getMethod("deserialize", byte[].class).invoke(null, (Object) data);
        } catch (IllegalAccessException | NoSuchMethodException | InvocationTargetException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Serializes several models into one byte array.
     *
     * <p>Layout: an int model count N, then N int lengths, then the N serialized models (each as produced by
     * {@link #serialize(Model)}) concatenated in order.
     */
    public static byte[] serializeModels(Model... modelArr) {
        List<byte[]> serializedModels = new ArrayList<>(modelArr.length);
        int headerSize = Integer.BYTES * (modelArr.length + 1);
        int totalSize = headerSize;
        for (Model model : modelArr) {
            byte[] bytes = serialize(model).getBytes();
            totalSize = Math.addExact(totalSize, bytes.length);
            serializedModels.add(bytes);
        }

        Slice slice = Slices.allocate(totalSize);
        slice.setInt(0, modelArr.length);
        int dataOffset = headerSize;
        for (int i = 0; i < serializedModels.size(); i++) {
            byte[] bytes = serializedModels.get(i);
            slice.setInt(Integer.BYTES * (i + 1), bytes.length);
            slice.setBytes(dataOffset, bytes);
            dataOffset += bytes.length;
        }
        return slice.getBytes();
    }

    /**
     * Inverse of {@link #serializeModels(Model...)}: decodes the count/length header and deserializes each model.
     *
     * @throws IllegalArgumentException if the encoded model count is negative or cannot fit in the header
     */
    public static List<Model> deserializeModels(byte[] bArr) {
        Slice slice = Slices.wrappedBuffer(bArr);
        int count = slice.getInt(0);
        Preconditions.checkArgument(
                count >= 0 && count <= (slice.length() - Integer.BYTES) / Integer.BYTES,
                "invalid model count: %s",
                count);

        int dataOffset = Integer.BYTES * (count + 1);
        ImmutableList.Builder<Model> models = ImmutableList.builder();
        for (int i = 0; i < count; i++) {
            int length = slice.getInt(Integer.BYTES * (i + 1));
            models.add(deserialize(slice.getBytes(dataOffset, length)));
            dataOffset += length;
        }
        return models.build();
    }

    /**
     * Converts a {@code map(bigint, double)} SQL value into a {@link FeatureVector}.
     *
     * @param sqlMap the map to convert; {@code null} yields an empty feature vector
     * @throws IllegalArgumentException if a key is outside the {@code int} range. Narrowing such a key would
     *         silently alias it onto a different feature index.
     */
    public static FeatureVector toFeatures(SqlMap sqlMap) {
        Map<Integer, Double> features = new HashMap<>();
        if (sqlMap != null) {
            int rawOffset = sqlMap.getRawOffset();
            Block rawKeyBlock = sqlMap.getRawKeyBlock();
            Block rawValueBlock = sqlMap.getRawValueBlock();
            int size = sqlMap.getSize();
            for (int i = 0; i < size; i++) {
                long key = BigintType.BIGINT.getLong(rawKeyBlock, rawOffset + i);
                Preconditions.checkArgument(key == (int) key, "feature index %s does not fit in an int", key);
                features.put((int) key, DoubleType.DOUBLE.getDouble(rawValueBlock, rawOffset + i));
            }
        }
        return new FeatureVector(features);
    }
}
