package io.trino.plugin.ml;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import io.airlift.slice.BasicSliceInput;
import io.airlift.slice.DynamicSliceOutput;
import io.airlift.slice.Slices;
import io.trino.plugin.ml.type.ClassifierType;
import io.trino.plugin.ml.type.ModelType;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

/**
 * Adapts an integer-label {@link Classifier} into a classifier that predicts {@code String} labels.
 *
 * <p>The wrapped classifier is trained and queried with integer labels. This adapter translates
 * the predicted integer into a string through {@code labelEnumeration}. {@link #train(Dataset)}
 * fills that map from the dataset's label enumeration.
 */
public class StringClassifierAdapter implements Classifier<String> {
    private final Classifier<Integer> classifier;
    private final Map<Integer, String> labelEnumeration;

    /**
     * Creates an adapter with an empty, mutable label enumeration.
     *
     * @param classifier the integer-label classifier to wrap; must not be null
     */
    public StringClassifierAdapter(Classifier<Integer> classifier) {
        this(classifier, new HashMap<>());
    }

    /**
     * Creates an adapter around an existing label enumeration.
     *
     * <p>The map is stored as-is (not copied), so {@link #train(Dataset)} writes into it. It must
     * therefore be mutable if this adapter will be trained.
     *
     * @param classifier the integer-label classifier to wrap; must not be null
     * @param map mapping from the classifier's integer labels to string labels; must not be null
     * @throws NullPointerException if either argument is null
     */
    public StringClassifierAdapter(Classifier<Integer> classifier, Map<Integer, String> map) {
        this.classifier = Objects.requireNonNull(classifier, "classifier is null");
        this.labelEnumeration = Objects.requireNonNull(map, "labelEnumeration is null");
    }

    @Override
    public ModelType getType() {
        return ClassifierType.VARCHAR_CLASSIFIER;
    }

    /**
     * Serializes the wrapped classifier and the label enumeration.
     *
     * <p>Layout: {@code int classifierLength}, the classifier bytes, and {@code int entryCount}.
     * Each entry then follows as {@code int key}, {@code int valueLength}, and the UTF-8 value
     * bytes. {@link #deserialize(byte[])} reads the same layout.
     *
     * @return the serialized form of this adapter
     */
    @Override
    public byte[] getSerializedData() {
        byte[] classifierBytes = ModelUtils.serialize(this.classifier).getBytes();
        // Initial capacity is only a sizing hint; DynamicSliceOutput grows as needed.
        DynamicSliceOutput output = new DynamicSliceOutput(classifierBytes.length + (64 * this.labelEnumeration.size()));
        output.appendInt(classifierBytes.length);
        output.appendBytes(classifierBytes);
        output.appendInt(this.labelEnumeration.size());
        for (Map.Entry<Integer, String> entry : this.labelEnumeration.entrySet()) {
            output.appendInt(entry.getKey());
            byte[] valueBytes = entry.getValue().getBytes(StandardCharsets.UTF_8);
            output.appendInt(valueBytes.length);
            output.appendBytes(valueBytes);
        }
        return output.slice().getBytes();
    }

    /**
     * Reconstructs an adapter from the layout written by {@link #getSerializedData()}.
     *
     * @param data serialized adapter bytes
     * @return the reconstructed adapter, whose label enumeration is mutable so it can be trained further
     * @throws IllegalArgumentException if the serialized label enumeration contains a duplicate key
     * @throws ClassCastException if the embedded model is not a {@link Classifier}
     */
    @SuppressWarnings("unchecked") // the embedded model was written from a Classifier<Integer>
    public static StringClassifierAdapter deserialize(byte[] data) {
        BasicSliceInput input = Slices.wrappedBuffer(data).getInput();

        int classifierLength = input.readInt();
        Model model = ModelUtils.deserialize(input.readSlice(classifierLength));

        int entryCount = input.readInt();
        ImmutableMap.Builder<Integer, String> builder = ImmutableMap.builder();
        for (int i = 0; i < entryCount; i++) {
            // Read order matters: key, then value length, then value bytes.
            int key = input.readInt();
            int valueLength = input.readInt();
            builder.put(key, input.readSlice(valueLength).toStringUtf8());
        }

        // buildOrThrow() rejects duplicate keys in the input. Copying into a HashMap keeps the
        // enumeration mutable, matching the one-arg constructor, so train() does not fail with
        // UnsupportedOperationException on a deserialized adapter.
        Map<Integer, String> labels = new HashMap<>(builder.buildOrThrow());
        return new StringClassifierAdapter((Classifier<Integer>) model, labels);
    }

    /**
     * Classifies the features with the wrapped classifier and maps the integer prediction to its string label.
     *
     * @param featureVector the features to classify
     * @return the string label associated with the predicted integer label
     * @throws IllegalStateException if the predicted label has no entry in the label enumeration
     */
    @Override
    public String classify(FeatureVector featureVector) {
        int predicted = this.classifier.classify(featureVector);
        Preconditions.checkState(this.labelEnumeration.containsKey(predicted), "classifier predicted an unknown class %s", predicted);
        return this.labelEnumeration.get(predicted);
    }

    /**
     * Merges the dataset's label enumeration into this adapter, then trains the wrapped classifier.
     *
     * @param dataset the training data
     */
    @Override
    public void train(Dataset dataset) {
        this.labelEnumeration.putAll(dataset.getLabelEnumeration());
        this.classifier.train(dataset);
    }
}
