package io.trino.plugin.ml;

import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import io.airlift.slice.SizeOf;
import io.airlift.slice.Slice;
import io.trino.array.ObjectBigArray;
import io.trino.array.SliceBigArray;
import io.trino.spi.function.AccumulatorStateFactory;
import io.trino.spi.function.GroupedAccumulatorState;
import java.util.ArrayList;
import java.util.List;
import libsvm.svm_parameter;

/**
 * {@link AccumulatorStateFactory} for {@link LearnState}.
 *
 * <p>Creates either a {@link SingleLearnState}, which holds one set of labels/feature vectors,
 * or a {@link GroupedLearnState}, which keeps one set per group id in big arrays.
 */
public class LearnStateFactory implements AccumulatorStateFactory<LearnState> {
    /** Shallow size of an {@link ArrayList} instance, charged once per list a state allocates. */
    private static final long ARRAY_LIST_SIZE = SizeOf.instanceSize(ArrayList.class);
    /** Shallow size of a libsvm {@link svm_parameter} instance. */
    private static final long SVM_PARAMETERS_SIZE = SizeOf.instanceSize(svm_parameter.class);

    @Override
    public LearnState createSingleState() {
        return new SingleLearnState();
    }

    @Override
    public LearnState createGroupedState() {
        return new GroupedLearnState();
    }

    /**
     * {@link LearnState} for grouped aggregation.
     *
     * <p>Labels, feature vectors and parameters are stored per group in big arrays indexed by the
     * group id last passed to {@link #setGroupId(long)}. The label enumeration is a single map for
     * the whole state, so label ids are shared by all groups.
     */
    public static class GroupedLearnState implements GroupedAccumulatorState, LearnState {
        private final ObjectBigArray<List<Double>> labelsArray = new ObjectBigArray<>();
        private final ObjectBigArray<List<FeatureVector>> featureVectorsArray = new ObjectBigArray<>();
        private final SliceBigArray parametersArray = new SliceBigArray();
        private final BiMap<String, Integer> labelEnumeration = HashBiMap.create();
        private long groupId;
        private int nextLabel;
        // Bytes reported via addMemoryUsage plus per-group list / parameter overhead.
        private long size;

        @Override
        public void setGroupId(long groupId) {
            this.groupId = groupId;
        }

        @Override
        public void ensureCapacity(long capacity) {
            labelsArray.ensureCapacity(capacity);
            featureVectorsArray.ensureCapacity(capacity);
            parametersArray.ensureCapacity(capacity);
        }

        /**
         * Returns the estimated memory footprint of this state in bytes.
         *
         * <p>All three big arrays are grown by {@link #ensureCapacity(long)}, so all three are
         * counted, including the parameters array and the slices it retains.
         */
        @Override
        public long getEstimatedSize() {
            return size
                    + labelsArray.sizeOf()
                    + featureVectorsArray.sizeOf()
                    + parametersArray.sizeOf();
        }

        @Override
        public BiMap<String, Integer> getLabelEnumeration() {
            return labelEnumeration;
        }

        /**
         * Returns the integer id for {@code label}.
         *
         * <p>An unseen label is first assigned the next sequential id, starting at 0.
         */
        @Override
        public int enumerateLabel(String label) {
            Integer id = labelEnumeration.get(label);
            if (id == null) {
                id = nextLabel++;
                labelEnumeration.put(label, id);
            }
            return id;
        }

        /** Returns the mutable label list for the current group, creating it on first access. */
        @Override
        public List<Double> getLabels() {
            List<Double> labels = labelsArray.get(groupId);
            if (labels == null) {
                labels = new ArrayList<>();
                size += ARRAY_LIST_SIZE;
                // NOTE(review): charges one svm_parameter per group that has labels, presumably
                // for the parameters built when training that group -- confirm against callers.
                size += SVM_PARAMETERS_SIZE;
                labelsArray.set(groupId, labels);
            }
            return labels;
        }

        /** Returns the mutable feature-vector list for the current group, creating it on first access. */
        @Override
        public List<FeatureVector> getFeatureVectors() {
            List<FeatureVector> featureVectors = featureVectorsArray.get(groupId);
            if (featureVectors == null) {
                featureVectors = new ArrayList<>();
                size += ARRAY_LIST_SIZE;
                featureVectorsArray.set(groupId, featureVectors);
            }
            return featureVectors;
        }

        /** Returns the parameters stored for the current group (presumably null if never set). */
        @Override
        public Slice getParameters() {
            return parametersArray.get(groupId);
        }

        /** Stores {@code parameters} for the current group. */
        @Override
        public void setParameters(Slice parameters) {
            parametersArray.set(groupId, parameters);
        }

        /** Adds {@code bytes} to this state's memory estimate. */
        @Override
        public void addMemoryUsage(long bytes) {
            size += bytes;
        }
    }

    /** {@link LearnState} holding a single set of labels, feature vectors and parameters. */
    public static class SingleLearnState implements LearnState {
        private final List<Double> labels = new ArrayList<>();
        private final List<FeatureVector> featureVectors = new ArrayList<>();
        private final BiMap<String, Integer> labelEnumeration = HashBiMap.create();
        private int nextLabel;
        // null until setParameters is called
        private Slice parameters;
        // Bytes reported via addMemoryUsage.
        private long size;

        /**
         * Returns the estimated memory footprint of this state in bytes.
         *
         * <p>The estimate covers reported usage, the two list instances, and the retained size of
         * the parameters slice when one is set.
         */
        @Override
        public long getEstimatedSize() {
            long parametersSize = parameters == null ? 0 : parameters.getRetainedSize();
            return size + (2 * ARRAY_LIST_SIZE) + parametersSize;
        }

        @Override
        public BiMap<String, Integer> getLabelEnumeration() {
            return labelEnumeration;
        }

        /**
         * Returns the integer id for {@code label}.
         *
         * <p>An unseen label is first assigned the next sequential id, starting at 0.
         */
        @Override
        public int enumerateLabel(String label) {
            Integer id = labelEnumeration.get(label);
            if (id == null) {
                id = nextLabel++;
                labelEnumeration.put(label, id);
            }
            return id;
        }

        @Override
        public List<Double> getLabels() {
            return labels;
        }

        @Override
        public List<FeatureVector> getFeatureVectors() {
            return featureVectors;
        }

        @Override
        public Slice getParameters() {
            return parameters;
        }

        @Override
        public void setParameters(Slice parameters) {
            this.parameters = parameters;
        }

        /** Adds {@code bytes} to this state's memory estimate. */
        @Override
        public void addMemoryUsage(long bytes) {
            size += bytes;
        }
    }
}
