package com.feedzai.openml.h2o;

import com.feedzai.openml.data.Instance;
import com.feedzai.openml.data.schema.CategoricalValueSchema;
import com.feedzai.openml.data.schema.DatasetSchema;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSortedSet;
import com.google.common.collect.Iterables;
import hex.ModelCategory;
import hex.genmodel.GenModel;
import java.io.Closeable;
import java.io.IOException;
import java.nio.file.Path;
import java.util.Optional;
import java.util.SortedSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/feedzai/openml/h2o/SupervisedClassificationH2OModel.class */
/**
 * H2O-backed supervised classification model.
 *
 * <p>H2O reports predicted labels and class probabilities in the order of its own response
 * domain. This class re-maps them onto the ordering of the nominal values declared by the target
 * field of the {@link DatasetSchema}. That ordering is the iteration order of the target's
 * {@link SortedSet} of nominal values.
 */
public class SupervisedClassificationH2OModel extends AbstractClassificationH2OModel {

    /** Logger for this class. */
    private static final Logger logger = LoggerFactory.getLogger(SupervisedClassificationH2OModel.class);

    /**
     * Creates a new supervised classification model.
     *
     * <p>NOTE(review): the decompiler reports this constructor was package-private in the original
     * build. It is kept public here so that no existing caller breaks.
     *
     * @param genModel      the H2O model, passed unchanged to the superclass constructor
     * @param path          passed unchanged to the superclass constructor
     * @param datasetSchema the schema of the dataset; must declare a target field
     * @param closeable     passed unchanged to the superclass constructor
     * @throws IllegalArgumentException if {@code datasetSchema} has no target field
     */
    public SupervisedClassificationH2OModel(final GenModel genModel,
                                            final Path path,
                                            final DatasetSchema datasetSchema,
                                            final Closeable closeable) {
        super(genModel, path, datasetSchema, closeable);
        Preconditions.checkArgument(datasetSchema.getTargetFieldSchema().isPresent(), "Supervised models require a schema with target field.");
    }

    /**
     * Returns the class probability distribution predicted for {@code instance}.
     *
     * <p>Position {@code k} of the returned array holds the probability of the {@code k}-th
     * nominal value of the target field, in sorted order. Target values absent from H2O's
     * response domain get probability {@code 0.0}.
     *
     * @param instance the instance to score
     * @return the class distribution, indexed like the target field's nominal values
     * @throws IllegalStateException if a value of H2O's response domain is not a nominal value of
     *                               the target field
     */
    public double[] getClassDistribution(final Instance instance) {
        // NOTE(review): the decompiled code branched on isMultiClassification() with two identical
        // arms, most likely where casts to binomial/multinomial prediction types were lost.
        // Confirm that predictInstance's return type exposes classProbabilities directly.
        return convertDistribution(predictInstance(instance).classProbabilities);
    }

    /**
     * Re-orders an H2O class distribution (indexed by H2O's response domain) into the target
     * field's nominal value order.
     *
     * @param h2oDistribution probabilities indexed by H2O's response domain
     * @return probabilities indexed by the target field's nominal values
     */
    private double[] convertDistribution(final double[] h2oDistribution) {
        final SortedSet<String> targetValues = getTargetValues();
        final double[] distribution = new double[targetValues.size()];
        final String[] domainValues = this.modelWrapper.getResponseDomainValues();
        for (int i = 0; i < domainValues.length; i++) {
            distribution[indexOfTargetValue(targetValues, domainValues[i])] = h2oDistribution[i];
        }
        return distribution;
    }

    /**
     * Returns the nominal values of the target field, or an empty set when no target field is
     * present.
     *
     * <p>Throws {@link ClassCastException} if the target's value schema is not a
     * {@link CategoricalValueSchema}.
     *
     * @return the target field's nominal values
     */
    private SortedSet<String> getTargetValues() {
        return this.schema.getTargetFieldSchema()
                .map(fieldSchema -> fieldSchema.getValueSchema())
                .map(CategoricalValueSchema.class::cast)
                .map(CategoricalValueSchema::getNominalValues)
                .orElse(ImmutableSortedSet.of());
    }

    /**
     * Returns the index of the class predicted for {@code instance}.
     *
     * @param instance the instance to score
     * @return the position of the predicted label among the target field's nominal values
     * @throws IllegalStateException if the predicted label is not a nominal value of the target
     *                               field
     */
    public int classify(final Instance instance) {
        return convertClassification(predictInstance(instance).labelIndex);
    }

    /**
     * Converts a label index in H2O's response domain into an index into the target field's
     * nominal values.
     *
     * @param h2oLabelIndex index into H2O's response domain
     * @return the matching index into the target field's nominal values
     */
    private int convertClassification(final int h2oLabelIndex) {
        final String label = this.modelWrapper.getResponseDomainValues()[h2oLabelIndex];
        return indexOfTargetValue(getTargetValues(), label);
    }

    /**
     * Finds the position of {@code value} within {@code targetValues}.
     *
     * <p>Shared by the distribution and label conversions so that both fail the same way on a
     * domain mismatch, instead of one of them returning {@code -1}.
     *
     * @param targetValues the target field's nominal values
     * @param value        a value from H2O's response domain
     * @return the iteration index of {@code value} within {@code targetValues}
     * @throws IllegalStateException if {@code value} is not contained in {@code targetValues}
     */
    private static int indexOfTargetValue(final SortedSet<String> targetValues, final String value) {
        final int index = Iterables.indexOf(targetValues, value::equals);
        if (index == -1) {
            final String message = String.format("Unexpected value found: %s. Feature domain: %s", value, targetValues);
            logger.error(message);
            throw new IllegalStateException(message);
        }
        return index;
    }

    /** Delegates to {@link AbstractClassificationH2OModel#close()}. */
    @Override
    public void close() throws IOException {
        super.close();
    }

    /** Delegates to {@link AbstractClassificationH2OModel#getSchema()}. */
    @Override
    public DatasetSchema getSchema() {
        return super.getSchema();
    }

    /** Delegates to {@link AbstractClassificationH2OModel#save(Path, String)}. */
    @Override
    public boolean save(final Path path, final String str) {
        return super.save(path, str);
    }
}
