package io.trino.plugin.ml;

import io.trino.plugin.ml.type.ModelType;
import io.trino.plugin.ml.type.RegressorType;
import java.util.List;
import java.util.Objects;

/**
 * A {@link Regressor} that applies a {@link FeatureTransformation} to its input before
 * delegating to a wrapped {@link Regressor}.
 *
 * <p>Both training and prediction go through the transformation, so the wrapped regressor
 * only ever sees transformed features.
 */
public class RegressorFeatureTransformer implements Regressor {
    /** Number of models written by {@link #getSerializedData()}: the regressor, then the transformation. */
    private static final int SERIALIZED_MODEL_COUNT = 2;

    private final Regressor regressor;
    private final FeatureTransformation transformation;

    /**
     * Creates a regressor that transforms features before regressing on them.
     *
     * @param regressor the underlying regressor that receives transformed features
     * @param featureTransformation the transformation applied to every dataset / feature vector
     * @throws NullPointerException if either argument is {@code null}
     */
    public RegressorFeatureTransformer(Regressor regressor, FeatureTransformation featureTransformation) {
        this.regressor = Objects.requireNonNull(regressor, "regressor is null");
        this.transformation = Objects.requireNonNull(featureTransformation, "transformation is null");
    }

    /**
     * Returns {@link RegressorType#REGRESSOR}.
     */
    @Override // io.trino.plugin.ml.Model
    public ModelType getType() {
        return RegressorType.REGRESSOR;
    }

    /**
     * Serializes the wrapped regressor followed by the transformation, using
     * {@code ModelUtils.serializeModels}. {@link #deserialize(byte[])} relies on this order.
     */
    @Override // io.trino.plugin.ml.Model
    public byte[] getSerializedData() {
        return ModelUtils.serializeModels(this.regressor, this.transformation);
    }

    /**
     * Reconstructs an instance from bytes produced by {@link #getSerializedData()}.
     *
     * @param bArr serialized data; expected to decode to exactly two models, a {@link Regressor}
     *     followed by a {@link FeatureTransformation}
     * @return the reconstructed transformer
     * @throws IllegalArgumentException if the data does not decode to that pair of models
     */
    public static RegressorFeatureTransformer deserialize(byte[] bArr) {
        List<Model> models = ModelUtils.deserializeModels(bArr);
        // Validate up front so corrupt/mismatched data fails with a clear message instead of an
        // IndexOutOfBoundsException or ClassCastException from deep inside this method.
        if (models.size() != SERIALIZED_MODEL_COUNT) {
            throw new IllegalArgumentException(
                    "expected " + SERIALIZED_MODEL_COUNT + " serialized models (regressor, transformation), got " + models.size());
        }
        Model first = models.get(0);
        Model second = models.get(1);
        if (!(first instanceof Regressor)) {
            throw new IllegalArgumentException("first serialized model is not a Regressor: " + describe(first));
        }
        if (!(second instanceof FeatureTransformation)) {
            throw new IllegalArgumentException("second serialized model is not a FeatureTransformation: " + describe(second));
        }
        return new RegressorFeatureTransformer((Regressor) first, (FeatureTransformation) second);
    }

    /** Returns the runtime class name of {@code model}, or {@code "null"}, for error messages. */
    private static String describe(Model model) {
        return model == null ? "null" : model.getClass().getName();
    }

    /**
     * Transforms {@code featureVector} and returns the wrapped regressor's prediction for it.
     */
    @Override // io.trino.plugin.ml.Regressor
    public double regress(FeatureVector featureVector) {
        return this.regressor.regress(this.transformation.transform(featureVector));
    }

    /**
     * Trains the transformation on the raw dataset, then trains the wrapped regressor on the
     * transformed dataset.
     */
    @Override // io.trino.plugin.ml.Model
    public void train(Dataset dataset) {
        // Order matters: the transformation must be fitted before it is used to transform
        // the data the regressor is trained on.
        this.transformation.train(dataset);
        this.regressor.train(this.transformation.transform(dataset));
    }
}
