package io.trino.plugin.ml;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:io/trino/plugin/ml/TestFeatureTransformations.class */
/**
 * Tests for {@link FeatureUnitNormalizer}: once trained on a dataset, transforming that same
 * dataset must produce feature values that do not exceed 1.
 */
public class TestFeatureTransformations {
    /**
     * Trains the normalizer on the shared fixture from {@code TestUtils.getDataset()} and checks
     * that every feature value of the transformed dataset is at most 1.
     */
    @Test
    public void testUnitNormalizer() {
        FeatureUnitNormalizer normalizer = new FeatureUnitNormalizer();
        Dataset dataset = TestUtils.getDataset();

        // Precondition: the fixture must contain at least one value above 1, otherwise the
        // upper-bound check below would pass even if the normalizer did nothing.
        boolean hasValueAboveOne = dataset.getDatapoints().stream()
                .flatMap(vector -> vector.getFeatures().values().stream())
                .anyMatch(value -> value > 1.0);
        Assertions.assertThat(hasValueAboveOne).isTrue();

        normalizer.train(dataset);
        for (FeatureVector vector : normalizer.transform(dataset).getDatapoints()) {
            for (double value : vector.getFeatures().values()) {
                // Assert on the value itself so a failure reports the offending number.
                Assertions.assertThat(value).isLessThanOrEqualTo(1.0);
            }
        }
    }

    /**
     * Builds three datapoints via {@code new FeatureVector(0, i)} for {@code i = 0, 1, 2}
     * (presumably feature index 0 with value {@code i} -- confirm against FeatureVector), all
     * labelled 0.0, and checks that the transformed feature values are exactly
     * {@code {0.0, 0.5, 1.0}}.
     */
    @Test
    public void testUnitNormalizerSimple() {
        FeatureUnitNormalizer normalizer = new FeatureUnitNormalizer();
        List<Double> labels = new ArrayList<>();
        List<FeatureVector> datapoints = new ArrayList<>();
        for (int i = 0; i < 3; i++) {
            labels.add(0.0);
            datapoints.add(new FeatureVector(0, i));
        }
        Dataset dataset = new Dataset(labels, datapoints, ImmutableMap.of());
        normalizer.train(dataset);

        // Collect every transformed feature value; a set is used because only the distinct
        // values matter for the comparison, not which datapoint produced them.
        Set<Double> transformedValues = new HashSet<>();
        for (FeatureVector vector : normalizer.transform(dataset).getDatapoints()) {
            transformedValues.addAll(vector.getFeatures().values());
        }
        Assertions.assertThat(transformedValues).isEqualTo(ImmutableSet.of(0.0, 0.5, 1.0));
    }
}
