package io.trino.plugin.ml;

import io.airlift.slice.Slice;
import io.trino.plugin.ml.type.RegressorType;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.SqlMap;
import io.trino.spi.function.AggregationFunction;
import io.trino.spi.function.AggregationState;
import io.trino.spi.function.InputFunction;
import io.trino.spi.function.OutputFunction;
import io.trino.spi.function.SqlType;

/**
 * Aggregation {@code learn_libsvm_regressor(label, features, params)} that buffers every input row
 * in a {@link LearnState} and, at output time, trains a LibSVM-backed regressor on the buffered
 * rows and emits the serialized model as a {@code RegressorType} value.
 *
 * <p>Declared with {@code decomposable = false}: all rows are buffered into a single state and
 * the model is trained once in {@link #output}.
 */
@AggregationFunction(value = "learn_libsvm_regressor", decomposable = false)
public final class LearnLibSvmRegressorAggregation {
    private LearnLibSvmRegressorAggregation() {
    }

    /**
     * Input function for rows whose label is a {@code bigint}.
     *
     * <p>Widens the label to {@code double} and delegates to the {@code double} overload. The
     * explicit cast is required: without it, Java overload resolution selects this very
     * {@code long} overload again, which recurses until {@link StackOverflowError}.
     * NOTE(review): labels with magnitude above 2^53 lose precision in the widening conversion.
     *
     * @param learnState aggregation state that accumulates labels, feature vectors and parameters
     * @param label      regression target for this row
     * @param features   sparse feature map (feature index to value) for this row
     * @param parameters LibSVM parameter string; see the {@code double} overload for how it is stored
     */
    @InputFunction
    public static void input(@AggregationState LearnState learnState, @SqlType("bigint") long label, @SqlType("map(bigint,double)") SqlMap features, @SqlType("varchar") Slice parameters) {
        input(learnState, (double) label, features, parameters);
    }

    /**
     * Input function for rows whose label is a {@code double}.
     *
     * <p>Appends the label and the converted feature vector to the state, charges the vector's
     * estimated size to the state's memory accounting, and records the parameter string.
     *
     * @param learnState aggregation state that accumulates labels, feature vectors and parameters
     * @param label      regression target for this row
     * @param features   sparse feature map (feature index to value) for this row
     * @param parameters LibSVM parameter string; overwritten on every row, so the value from the
     *                   last row processed is the one used for training
     */
    @InputFunction
    public static void input(@AggregationState LearnState learnState, @SqlType("double") double label, @SqlType("map(bigint,double)") SqlMap features, @SqlType("varchar") Slice parameters) {
        learnState.getLabels().add(label);
        FeatureVector featureVector = ModelUtils.toFeatures(features);
        // Track buffered memory so the engine can account for the growing state.
        learnState.addMemoryUsage(featureVector.getEstimatedSize());
        learnState.getFeatureVectors().add(featureVector);
        learnState.setParameters(parameters);
    }

    /**
     * Trains a regressor on all buffered rows and writes the serialized model to {@code out}.
     *
     * <p>The model is a {@code RegressorFeatureTransformer} that applies a
     * {@code FeatureUnitNormalizer} in front of an {@code SvmRegressor} configured from the parsed
     * parameter string.
     * NOTE(review): {@code getParameters()} may be null if no rows were seen; confirm how empty
     * groups are handled before relying on this path.
     *
     * @param learnState aggregation state populated by the input functions
     * @param out        block builder that receives one serialized {@code RegressorType} value
     */
    @OutputFunction(RegressorType.NAME)
    public static void output(@AggregationState LearnState learnState, BlockBuilder out) {
        Dataset dataset = new Dataset(learnState.getLabels(), learnState.getFeatureVectors(), learnState.getLabelEnumeration().inverse());
        RegressorFeatureTransformer model = new RegressorFeatureTransformer(
                new SvmRegressor(LibSvmUtils.parseParameters(learnState.getParameters().toStringUtf8())),
                new FeatureUnitNormalizer());
        model.train(dataset);
        RegressorType.REGRESSOR.writeSlice(out, ModelUtils.serialize(model));
    }
}
