package com.feedzai.openml.h2o;

import com.feedzai.openml.data.schema.CategoricalValueSchema;
import com.feedzai.openml.data.schema.DatasetSchema;
import com.feedzai.openml.data.schema.FieldSchema;
import com.feedzai.openml.data.schema.NumericValueSchema;
import com.feedzai.openml.data.schema.StringValueSchema;
import com.feedzai.openml.mocks.MockInstance;
import com.feedzai.openml.provider.exception.ModelLoadingException;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.Objects;
import org.assertj.core.api.Assertions;
import org.junit.Test;

/* loaded from: input_file:com/feedzai/openml/h2o/ImportAllFeaturesTypesTest.class */
public class ImportAllFeaturesTypesTest {
    @Test
    public void test() throws ModelLoadingException {
        AbstractClassificationH2OModel loadModel = new H2OModelCreator(H2OAlgorithm.DISTRIBUTED_RANDOM_FOREST.getAlgorithmDescriptor()).loadModel(Paths.get(getClass().getResource("/drf-all-types-features").getPath(), new String[0]), new DatasetSchema(4, ImmutableList.of(new FieldSchema("Amount", 0, new NumericValueSchema(false)), new FieldSchema("MerchantID", 1, new StringValueSchema(true)), new FieldSchema("MCC", 2, new CategoricalValueSchema(true, ImmutableSet.of("1711", "3010", "5411", "5812", "6011", "7995", new String[]{"8398"}))), new FieldSchema("AccountCreatedAt", 3, new NumericValueSchema(true)), new FieldSchema("FraudLabel", 4, new CategoricalValueSchema(true, ImmutableSet.of("GENUINE", "FRAUD"))))));
        MockInstance mockInstance = new MockInstance(ImmutableList.of(Double.valueOf(14923.103d), "Amazon", Double.valueOf(0.0d), Double.valueOf(1.559312575E12d), Double.valueOf(0.0d)));
        Assertions.assertThat(loadModel.classify(mockInstance)).as("the score", new Object[0]).isBetween(0, 1);
        Assertions.assertThat(Arrays.stream(loadModel.getClassDistribution(mockInstance)).sum()).as("sum of the class distribution", new Object[0]).isCloseTo(1.0d, Assertions.within(Double.valueOf(0.01d)));
    }
}
