package io.trino.plugin.ml;

import org.assertj.core.api.Assertions;
import org.assertj.core.api.ObjectAssert;
import org.junit.jupiter.api.Test;

/**
 * Tests that each model type survives a {@link ModelUtils#serialize} / {@link ModelUtils#deserialize}
 * round trip as a non-null instance of the same class, and pins the numeric ids registered in
 * {@link ModelUtils#MODEL_SERIALIZATION_IDS}.
 */
public class TestModelSerialization {
    @Test
    public void testSvmClassifier() {
        SvmClassifier model = new SvmClassifier();
        model.train(TestUtils.getDataset());
        assertRoundTrip(model, SvmClassifier.class, "deserialized model is not a svm");
    }

    @Test
    public void testSvmRegressor() {
        SvmRegressor model = new SvmRegressor();
        model.train(TestUtils.getDataset());
        assertRoundTrip(model, SvmRegressor.class, "deserialized model is not a svm");
    }

    @Test
    public void testRegressorFeatureTransformer() {
        RegressorFeatureTransformer model = new RegressorFeatureTransformer(new SvmRegressor(), new FeatureVectorUnitNormalizer());
        model.train(TestUtils.getDataset());
        assertRoundTrip(model, RegressorFeatureTransformer.class, "deserialized model is not a regressor feature transformer");
    }

    @Test
    public void testClassifierFeatureTransformer() {
        ClassifierFeatureTransformer model = new ClassifierFeatureTransformer(new SvmClassifier(), new FeatureVectorUnitNormalizer());
        model.train(TestUtils.getDataset());
        assertRoundTrip(model, ClassifierFeatureTransformer.class, "deserialized model is not a classifier feature transformer");
    }

    @Test
    public void testVarcharClassifierAdapter() {
        StringClassifierAdapter model = new StringClassifierAdapter(new ClassifierFeatureTransformer(new SvmClassifier(), new FeatureVectorUnitNormalizer()));
        model.train(TestUtils.getDataset());
        assertRoundTrip(model, StringClassifierAdapter.class, "deserialized model is not a varchar classifier adapter");
    }

    @Test
    public void testSerializationIds() {
        // Pinning these values catches accidental renumbering. NOTE(review): presumably the ids are
        // written into serialized data, which would make renumbering a compatibility break — confirm in ModelUtils.
        // Assert on the boxed lookup result directly: an unregistered class then fails with an
        // AssertJ "actual is null" message instead of a NullPointerException from unboxing.
        Assertions.assertThat(ModelUtils.MODEL_SERIALIZATION_IDS.get(SvmClassifier.class)).isEqualTo(1);
        Assertions.assertThat(ModelUtils.MODEL_SERIALIZATION_IDS.get(SvmRegressor.class)).isEqualTo(2);
        Assertions.assertThat(ModelUtils.MODEL_SERIALIZATION_IDS.get(FeatureVectorUnitNormalizer.class)).isEqualTo(3);
        Assertions.assertThat(ModelUtils.MODEL_SERIALIZATION_IDS.get(ClassifierFeatureTransformer.class)).isEqualTo(4);
        Assertions.assertThat(ModelUtils.MODEL_SERIALIZATION_IDS.get(RegressorFeatureTransformer.class)).isEqualTo(5);
        Assertions.assertThat(ModelUtils.MODEL_SERIALIZATION_IDS.get(FeatureUnitNormalizer.class)).isEqualTo(6);
        Assertions.assertThat(ModelUtils.MODEL_SERIALIZATION_IDS.get(StringClassifierAdapter.class)).isEqualTo(7);
    }

    /**
     * Serializes {@code model}, deserializes the result, and asserts that a non-null instance of
     * {@code expectedType} comes back.
     *
     * @param model already-trained model to round-trip
     * @param expectedType class the deserialized model must be an instance of
     * @param typeMismatchDescription assertion description reported if the instance-of check fails
     */
    private static void assertRoundTrip(Model model, Class<?> expectedType, String typeMismatchDescription) {
        Model deserialized = ModelUtils.deserialize(ModelUtils.serialize(model));
        Assertions.assertThat(deserialized)
                .describedAs("deserialization failed")
                .isNotNull();
        Assertions.assertThat(deserialized)
                .describedAs(typeMismatchDescription)
                .isInstanceOf(expectedType);
    }
}
