package com.feedzai.openml.h2o;

import com.feedzai.openml.data.schema.FieldSchema;
import com.feedzai.openml.mocks.MockDataset;
import com.feedzai.openml.provider.exception.ModelTrainingException;
import com.google.common.primitives.Floats;
import hex.VarImp;
import java.util.Iterator;
import java.util.Random;
import org.assertj.core.api.SoftAssertions;
import org.junit.Before;
import org.junit.Test;

/* loaded from: input_file:com/feedzai/openml/h2o/H2OFeatureImportanceCreatorTest.class */
/**
 * Tests for {@link H2OFeatureImportanceCreator}.
 *
 * <p>The dataset under test is the {@code TRAIN_DATASET} constant made available through
 * {@link H2ODatasetMixin}.
 */
public class H2OFeatureImportanceCreatorTest implements H2ODatasetMixin {

    /** Fixed seed so that every run feeds the same random sequence to the code under test. */
    private static final long RANDOM_SEED = 23L;

    /** Random generator handed to the feature importance computation; recreated before each test. */
    private Random random;

    /** Instance under test; recreated before each test. */
    private H2OFeatureImportanceCreator featureImportanceCreator;

    /**
     * Creates a freshly seeded random generator and a new instance under test.
     */
    @Before
    public void setup() {
        this.random = new Random(RANDOM_SEED);
        this.featureImportanceCreator = new H2OFeatureImportanceCreator();
    }

    /**
     * Computes the feature importance for the training dataset using the GBM test parameters and checks
     * that the result is not null, that the name of every field of the dataset schema is listed in the
     * result and that no importance value is negative.
     *
     * @throws ModelTrainingException declared by the feature importance computation.
     */
    @Test
    public final void calculateFeatureImportance() throws ModelTrainingException {
        final MockDataset dataset = TRAIN_DATASET;
        final VarImp featureImportance = this.featureImportanceCreator.calculateFeatureImportance(
                dataset,
                this.random,
                H2OAlgorithmTestParams.getGbm()
        );

        final SoftAssertions softAssertions = new SoftAssertions();
        softAssertions.assertThat(featureImportance)
                .as("The feature importance result is not null")
                .isNotNull();

        // Only dereference the result when it exists: otherwise a NullPointerException would be thrown
        // here and the soft "not null" failure recorded above would never be reported by assertAll().
        if (featureImportance != null) {
            for (final FieldSchema fieldSchema : dataset.getSchema().getFieldSchemas()) {
                final String fieldName = fieldSchema.getFieldName();
                // The field name is passed as the format argument so a failure identifies the missing field.
                softAssertions.assertThat(featureImportance._names)
                        .as("Field %s is mentioned in the feature importance result.", fieldName)
                        .contains(fieldName);
            }

            softAssertions.assertThat(Floats.asList(featureImportance._varimp))
                    .as("The feature importance values")
                    .allMatch(importance -> importance >= 0.0f);
        }

        softAssertions.assertAll();
    }
}
