package com.feedzai.openml.h2o;

import com.feedzai.openml.data.Dataset;
import com.feedzai.openml.data.Instance;
import com.feedzai.openml.data.schema.DatasetSchema;
import com.feedzai.openml.mocks.MockDataset;
import com.feedzai.openml.mocks.MockInstance;
import com.feedzai.openml.provider.exception.ModelTrainingException;
import com.feedzai.openml.util.algorithm.MLAlgorithmEnum;
import com.feedzai.openml.util.provider.AbstractProviderModelTrainTest;
import com.google.common.collect.ImmutableMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Random;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.DoubleStream;
import java.util.stream.IntStream;
import org.assertj.core.api.AssertionsForClassTypes;
import org.junit.Before;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/feedzai/openml/h2o/H2OModelProviderTrainTest.class */
/**
 * Training tests for {@link H2OModelProvider}: the generic checks inherited from
 * {@link AbstractProviderModelTrainTest} plus Isolation Forest specific scenarios (dataset without target variable,
 * fewer instances than the sample size, and an empty dataset).
 */
public class H2OModelProviderTrainTest extends AbstractProviderModelTrainTest<AbstractClassificationH2OModel, H2OModelCreator, H2OModelProvider> implements H2ODatasetMixin {

    private static final Logger logger = LoggerFactory.getLogger(H2OModelProviderTrainTest.class);

    /** Name of the Isolation Forest test parameter holding the sample size. */
    private static final String SAMPLE_SIZE_PARAM = "sample_size";

    /**
     * Sample size assumed when the Isolation Forest test parameters do not define {@link #SAMPLE_SIZE_PARAM}.
     * NOTE(review): presumably mirrors H2O's own default — confirm.
     */
    private static final int DEFAULT_SAMPLE_SIZE = 256;

    /** Number of instances added on top of the sample size when building the default training dataset. */
    private static final int EXTRA_INSTANCES = 100;

    /** Seed shared by the dataset generator and the Isolation Forest tests, so runs are reproducible. */
    private static final long SEED = 234L;

    /** Allowed deviation from 1.0 when checking that a class distribution sums to one. */
    private static final double DISTRIBUTION_SUM_TOLERANCE = 1e-6;

    /** Training dataset rebuilt before each test by {@link #createDataset()}. */
    private Dataset dataset;

    /**
     * Builds the default training dataset (schema {@code SCHEMA}) before each test.
     */
    @Before
    public void createDataset() {
        this.dataset = createDataset(SCHEMA);
    }

    /**
     * Creates a dataset with {@link #EXTRA_INSTANCES} more instances than the Isolation Forest sample size configured
     * in {@link H2OAlgorithmTestParams#getIsolationForest()}.
     *
     * @param datasetSchema schema of the dataset and of its instances
     * @return the generated dataset
     */
    private Dataset createDataset(final DatasetSchema datasetSchema) {
        final int sampleSize = getIsolationForestSampleSize(H2OAlgorithmTestParams.getIsolationForest());
        return createDataset(datasetSchema, sampleSize + EXTRA_INSTANCES);
    }

    /**
     * Creates a dataset of {@link MockInstance}s generated from a fixed seed.
     *
     * @param datasetSchema schema of the dataset and of its instances
     * @param numInstances  number of instances to generate
     * @return the generated dataset
     */
    private Dataset createDataset(final DatasetSchema datasetSchema, final int numInstances) {
        final Random random = new Random(SEED);
        logger.info("Using dataset size of {}", numInstances);
        final List<Instance> instances = IntStream.range(0, numInstances)
                .<Instance>mapToObj(ignored -> new MockInstance(datasetSchema, random))
                .collect(Collectors.toList());
        return new MockDataset(datasetSchema, instances);
    }

    /**
     * Reads the Isolation Forest sample size from the given parameters.
     *
     * @param params Isolation Forest training parameters
     * @return the parsed {@link #SAMPLE_SIZE_PARAM} value, or {@link #DEFAULT_SAMPLE_SIZE} when absent
     * @throws NumberFormatException if the parameter is present but not an integer
     */
    private static int getIsolationForestSampleSize(final Map<String, String> params) {
        return Optional.ofNullable(params.get(SAMPLE_SIZE_PARAM))
                .map(Integer::parseInt)
                .orElse(DEFAULT_SAMPLE_SIZE);
    }

    /**
     * Trains a Deep Learning model on {@code TRAIN_DATASET} with seed 0 and no extra parameters.
     */
    @Override
    public AbstractClassificationH2OModel getFirstModel() throws ModelTrainingException {
        return getMachineLearningModelLoader(H2OAlgorithm.DEEP_LEARNING)
                .fit(TRAIN_DATASET, new Random(0L), ImmutableMap.of());
    }

    /**
     * Trains a Naive Bayes model on {@code TRAIN_DATASET} with seed 1 and no extra parameters.
     */
    @Override
    public AbstractClassificationH2OModel getSecondModel() throws ModelTrainingException {
        return getMachineLearningModelLoader(H2OAlgorithm.NAIVE_BAYES_CLASSIFIER)
                .fit(TRAIN_DATASET, new Random(1L), ImmutableMap.of());
    }

    /**
     * @return every index in {@code [0, TARGET_VALUES.size())}
     */
    public Set<Integer> getClassifyValuesOfFirstModel() {
        return allTargetValueIndexes();
    }

    /**
     * @return every index in {@code [0, TARGET_VALUES.size())}
     */
    public Set<Integer> getClassifyValuesOfSecondModel() {
        return allTargetValueIndexes();
    }

    /**
     * Builds the set of indexes {@code 0 .. TARGET_VALUES.size() - 1}.
     */
    private Set<Integer> allTargetValueIndexes() {
        return IntStream.range(0, TARGET_VALUES.size())
                .boxed()
                .collect(Collectors.toSet());
    }

    /**
     * @return the model creator for {@link H2OAlgorithm#DEEP_LEARNING}
     */
    @Override
    public H2OModelCreator getFirstMachineLearningModelLoader() {
        return getMachineLearningModelLoader(H2OAlgorithm.DEEP_LEARNING);
    }

    /**
     * @return a new {@link H2OModelProvider}
     */
    @Override
    public H2OModelProvider getMachineLearningProvider() {
        return new H2OModelProvider();
    }

    /**
     * @return a mock instance generated from seed 0
     */
    public Instance getDummyInstance() {
        return new MockInstance(createDatasetSchema(TARGET_VALUES), new Random(0L));
    }

    /**
     * @return a mock instance generated from seed 1, different from {@link #getDummyInstance()}
     */
    public Instance getDummyInstanceDifferentResult() {
        return new MockInstance(createDatasetSchema(TARGET_VALUES), new Random(1L));
    }

    /**
     * Always returns {@code SCHEMA}; the given target values are ignored.
     *
     * @param set target values (unused)
     * @return {@code SCHEMA}
     */
    public DatasetSchema createDatasetSchema(final Set<String> set) {
        return SCHEMA;
    }

    /**
     * @return {@link H2OAlgorithm#DEEP_LEARNING}
     */
    public MLAlgorithmEnum getValidAlgorithm() {
        return H2OAlgorithm.DEEP_LEARNING;
    }

    /**
     * @return {@code TARGET_VALUES}
     */
    public Set<String> getFirstModelTargetNominalValues() {
        return TARGET_VALUES;
    }

    /**
     * @return the dataset built by {@link #createDataset()} before each test
     */
    protected Dataset getTrainDataset() {
        return this.dataset;
    }

    /**
     * @return every H2O algorithm exercised by the training tests, mapped to its test parameters
     */
    protected Map<MLAlgorithmEnum, Map<String, String>> getTrainAlgorithms() {
        return ImmutableMap.<MLAlgorithmEnum, Map<String, String>>builder()
                .put(H2OAlgorithm.DEEP_LEARNING, H2OAlgorithmTestParams.getDeepLearning())
                .put(H2OAlgorithm.DISTRIBUTED_RANDOM_FOREST, H2OAlgorithmTestParams.getDrf())
                .put(H2OAlgorithm.GRADIENT_BOOSTING_MACHINE, H2OAlgorithmTestParams.getGbm())
                .put(H2OAlgorithm.NAIVE_BAYES_CLASSIFIER, H2OAlgorithmTestParams.getBayes())
                .put(H2OAlgorithm.XG_BOOST, H2OAlgorithmTestParams.getXgboost())
                .put(H2OAlgorithm.GENERALIZED_LINEAR_MODEL, H2OAlgorithmTestParams.getGlm())
                .put(H2OAlgorithm.ISOLATION_FOREST, H2OAlgorithmTestParams.getIsolationForest())
                .build();
    }

    /**
     * Trains an Isolation Forest on a dataset without a target variable and checks the scoring output: a
     * two-element distribution summing to one, and {@code classify} pointing at its larger score.
     */
    @Test
    public final void testIsolationForestWithDatasetWithoutTargetVariable() throws ModelTrainingException {
        final H2OModelCreator modelCreator = getMachineLearningModelLoader(H2OAlgorithm.ISOLATION_FOREST);
        final Map<String, String> params = H2OAlgorithmTestParams.getIsolationForest();
        final Random random = new Random(SEED);
        final Dataset datasetWithoutTarget = createDataset(SCHEMA_NO_TARGET_VARIABLE);

        final AbstractClassificationH2OModel model = modelCreator.fit(datasetWithoutTarget, random, params);

        final MockInstance instance = new MockInstance(datasetWithoutTarget.getSchema(), random);
        final double[] classDistribution = model.getClassDistribution(instance);
        AssertionsForClassTypes.assertThat(classDistribution)
                .as("Scoring instance '%s' succeeds", instance)
                .hasSize(2)
                // Tolerance instead of exact equality: summed floating-point probabilities may not be exactly 1.0.
                .matches(distribution ->
                        Math.abs(DoubleStream.of(distribution).sum() - 1.0d) <= DISTRIBUTION_SUM_TOLERANCE);

        final int classify = model.classify(instance);
        AssertionsForClassTypes.assertThat(classDistribution[classify])
                .as("The classify method returns the index of the greatest score in the class distribution")
                .isGreaterThanOrEqualTo(classDistribution[1 - classify]);
    }

    /**
     * Training an Isolation Forest on half as many instances as its sample size must not fail.
     */
    @Test
    public final void testIsolationForestWithNotEnoughInstances() {
        final H2OModelCreator modelCreator = getMachineLearningModelLoader(H2OAlgorithm.ISOLATION_FOREST);
        final Map<String, String> params = H2OAlgorithmTestParams.getIsolationForest();
        final int sampleSize = getIsolationForestSampleSize(params);
        final Random random = new Random(SEED);
        final Dataset smallDataset = createDataset(SCHEMA_NO_TARGET_VARIABLE, sampleSize / 2);

        AssertionsForClassTypes.assertThatCode(() -> modelCreator.fit(smallDataset, random, params))
                .as("The training of a model with no out of bag instances")
                .doesNotThrowAnyException();
    }

    /**
     * Training on an empty dataset must fail with a {@link ModelTrainingException} explaining why.
     */
    @Test
    public final void testExceptionIsThrownWhenDatasetIsEmpty() {
        final Dataset emptyDataset = createDataset(SCHEMA, 0);
        final Map<String, String> params = H2OAlgorithmTestParams.getIsolationForest();
        final H2OModelCreator modelCreator = getMachineLearningModelLoader(H2OAlgorithm.ISOLATION_FOREST);
        final Random random = new Random(SEED);

        AssertionsForClassTypes.assertThatThrownBy(() -> modelCreator.fit(emptyDataset, random, params))
                .isInstanceOf(ModelTrainingException.class)
                .hasMessageContaining("In order to generate the model the dataset cannot be empty")
                // NOTE(review): "/tmp/" is a hard-coded path; this may not hold where the temp dir differs
                // (e.g. Windows or a custom java.io.tmpdir) — confirm what the message actually embeds.
                .hasMessageContaining("/tmp/");
    }
}
