package com.feedzai.openml.h2o;

import com.feedzai.openml.data.Instance;
import com.feedzai.openml.data.schema.CategoricalValueSchema;
import com.feedzai.openml.data.schema.DatasetSchema;
import com.feedzai.openml.data.schema.FieldSchema;
import com.feedzai.openml.data.schema.NumericValueSchema;
import com.feedzai.openml.mocks.MockInstance;
import com.feedzai.openml.provider.exception.ModelLoadingException;
import com.feedzai.openml.util.algorithm.MLAlgorithmEnum;
import com.feedzai.openml.util.provider.AbstractProviderCategoricalTargetTest;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.assertj.core.api.Assertions;
import org.assertj.core.util.Lists;
import org.junit.Test;

/* loaded from: input_file:com/feedzai/openml/h2o/H2OModelProviderLoadTest.class */
/**
 * Tests for loading H2O models through {@link H2OModelProvider}.
 *
 * <p>Two model fixtures are referenced by name: {@code "deeplearning"} (loaded with
 * {@link H2OAlgorithm#DEEP_LEARNING}) and {@code "drf"} (loaded with
 * {@link H2OAlgorithm#DISTRIBUTED_RANDOM_FOREST}). Both are loaded against a three-field schema
 * ({@code date}, {@code amount}, {@code fraud}) whose target is the categorical {@code fraud} field.
 *
 * <p>NOTE(review): the {@code mNget...} method names below were produced by a decompiler (see the
 * "renamed from" remarks on each). Presumably the original names are required to implement the
 * parent's abstract methods - confirm against the original sources before relying on this file.
 */
public class H2OModelProviderLoadTest extends AbstractProviderCategoricalTargetTest<AbstractClassificationH2OModel, H2OModelCreator, H2OModelProvider> {

    /** Name of the model fixture loaded with {@link H2OAlgorithm#DEEP_LEARNING}. */
    private static final String MOJO_MODEL_FILE = "deeplearning";

    /** Name of the model fixture loaded with {@link H2OAlgorithm#DISTRIBUTED_RANDOM_FOREST}. */
    public static final String POJO_MODEL_FILE = "drf";

    /** Nominal values of the categorical target field used by both models. */
    private static final Set<String> TARGET_VALUES = ImmutableSet.of("true", "false");

    /**
     * Builds a mock instance with three values, matching the three fields of
     * {@link #createDatasetSchema(Set)}.
     *
     * @return a mock instance holding {@code {3423432.0, 7.0, 0.0}}
     */
    public Instance getDummyInstance() {
        return new MockInstance(new double[]{3423432.0d, 7.0d, 0.0d});
    }

    /**
     * Builds a second mock instance whose first two values differ from {@link #getDummyInstance()}.
     *
     * @return a mock instance holding {@code {3223434.0, 6.0, 0.0}}
     */
    public Instance getDummyInstanceDifferentResult() {
        return new MockInstance(new double[]{3223434.0d, 6.0d, 0.0d});
    }

    /**
     * Loads the deep learning model fixture.
     *
     * <p>Renamed by the decompiler from {@code getFirstModel} (merged with bridge method).
     *
     * @return the loaded model
     * @throws ModelLoadingException if the model cannot be loaded
     */
    public AbstractClassificationH2OModel m6getFirstModel() throws ModelLoadingException {
        return loadModel(H2OAlgorithm.DEEP_LEARNING, MOJO_MODEL_FILE, TARGET_VALUES);
    }

    /**
     * Loads the distributed random forest model fixture.
     *
     * <p>Renamed by the decompiler from {@code getSecondModel} (merged with bridge method).
     *
     * @return the loaded model
     * @throws ModelLoadingException if the model cannot be loaded
     */
    public AbstractClassificationH2OModel m5getSecondModel() throws ModelLoadingException {
        return loadModel(H2OAlgorithm.DISTRIBUTED_RANDOM_FOREST, POJO_MODEL_FILE, TARGET_VALUES);
    }

    /**
     * Returns the integers {@code 0} (inclusive) to {@code TARGET_VALUES.size()} (exclusive),
     * i.e. one index per target nominal value.
     *
     * @return the set of class indexes for the first model
     */
    public Set<Integer> getClassifyValuesOfFirstModel() {
        return IntStream.range(0, TARGET_VALUES.size())
                .boxed()
                .collect(Collectors.toSet());
    }

    /**
     * Returns the same class indexes as {@link #getClassifyValuesOfFirstModel()}, since both
     * models are loaded with the same target values.
     *
     * @return the set of class indexes for the second model
     */
    public Set<Integer> getClassifyValuesOfSecondModel() {
        return getClassifyValuesOfFirstModel();
    }

    /**
     * Obtains the model loader for {@link H2OAlgorithm#DEEP_LEARNING}.
     *
     * <p>Renamed by the decompiler from {@code getFirstMachineLearningModelLoader}
     * (merged with bridge method).
     *
     * @return the model creator for the deep learning algorithm
     */
    public H2OModelCreator m4getFirstMachineLearningModelLoader() {
        return getMachineLearningModelLoader(H2OAlgorithm.DEEP_LEARNING);
    }

    /**
     * Creates a fresh provider instance.
     *
     * <p>Renamed by the decompiler from {@code getMachineLearningProvider}
     * (merged with bridge method).
     *
     * @return a new {@link H2OModelProvider}
     */
    public H2OModelProvider m3getMachineLearningProvider() {
        return new H2OModelProvider();
    }

    /**
     * @return {@link H2OAlgorithm#DEEP_LEARNING}
     */
    public MLAlgorithmEnum getValidAlgorithm() {
        return H2OAlgorithm.DEEP_LEARNING;
    }

    /**
     * @return the name of the deep learning model fixture
     */
    public String getValidModelDirName() {
        return MOJO_MODEL_FILE;
    }

    /**
     * @return the nominal values of the target field used when loading the first model
     */
    public Set<String> getFirstModelTargetNominalValues() {
        return TARGET_VALUES;
    }

    /**
     * Creates the schema used by the tests: numeric {@code date} (index 0), numeric
     * {@code amount} (index 1) and categorical {@code fraud} (index 2), with {@code fraud}
     * as the target.
     *
     * @param set nominal values of the categorical target field
     * @return the three-field schema with target index 2
     */
    public DatasetSchema createDatasetSchema(Set<String> set) {
        // ImmutableList.of infers the element type; a raw builder() chain would yield ImmutableList<Object>.
        return new DatasetSchema(
                2,
                ImmutableList.of(
                        new FieldSchema("date", 0, new NumericValueSchema(false)),
                        new FieldSchema("amount", 1, new NumericValueSchema(false)),
                        new FieldSchema("fraud", 2, new CategoricalValueSchema(false, set))
                )
        );
    }

    /**
     * Checks that loading the {@code drf} model with a schema that is missing one non-target
     * field fails with a {@link ModelLoadingException}.
     */
    @Test
    public final void testLoadModelWithPartialSchema() {
        final DatasetSchema partialSchema = removeNonTargetVariable(createDatasetSchema(TARGET_VALUES));
        final H2OModelCreator loader = getMachineLearningModelLoader(H2OAlgorithm.DISTRIBUTED_RANDOM_FOREST);
        final String modelPath = getClass().getResource("/" + POJO_MODEL_FILE).getPath();

        Assertions.assertThatThrownBy(() -> loader.loadModel(Paths.get(modelPath), partialSchema))
                .isInstanceOf(ModelLoadingException.class);
    }

    /**
     * Returns a copy of the given schema with exactly one non-target field removed and the
     * remaining fields re-indexed contiguously from 0.
     *
     * <p>The removed field is the last one, unless the last one is the target, in which case the
     * second to last is removed. When the schema has no target, the first field is removed.
     *
     * @param datasetSchema the schema to shrink; must have at least two fields
     * @return a new schema with one field fewer and, when a target exists, the target index
     *         adjusted to the target field's new position
     * @throws IllegalStateException if the schema has fewer than two fields
     */
    private DatasetSchema removeNonTargetVariable(DatasetSchema datasetSchema) {
        final ArrayList<FieldSchema> originalFields = Lists.newArrayList(datasetSchema.getFieldSchemas());
        if (originalFields.size() < 2) {
            throw new IllegalStateException("This schema does not have enough fields for this test to work.");
        }

        final int lastIndex = originalFields.size() - 1;
        final int indexToRemove = datasetSchema.getTargetIndex()
                .map(targetIndex -> targetIndex == lastIndex ? lastIndex - 1 : lastIndex)
                .orElse(0);

        final ArrayList<FieldSchema> remainingFields = new ArrayList<>(lastIndex);
        for (int index = 0; index < originalFields.size(); index++) {
            if (index == indexToRemove) {
                continue;
            }
            final FieldSchema original = originalFields.get(index);
            // The new field index is the position the field takes in the shrunken list.
            remainingFields.add(
                    new FieldSchema(original.getFieldName(), remainingFields.size(), original.getValueSchema())
            );
        }

        // A target in the last position shifts down by one because the field just before it was
        // removed; any other target keeps its index because the removed field came after it.
        return datasetSchema.getTargetIndex()
                .map(targetIndex -> targetIndex == lastIndex ? lastIndex - 1 : targetIndex)
                .map(newTargetIndex -> new DatasetSchema(newTargetIndex, remainingFields))
                .orElseGet(() -> new DatasetSchema(remainingFields));
    }
}
