package com.feedzai.openml.h2o;

import com.feedzai.openml.data.Instance;
import com.feedzai.openml.data.schema.CategoricalValueSchema;
import com.feedzai.openml.data.schema.DatasetSchema;
import com.feedzai.openml.data.schema.FieldSchema;
import com.feedzai.openml.data.schema.NumericValueSchema;
import com.feedzai.openml.model.ClassificationMLModel;
import com.feedzai.openml.util.data.encoding.EncodingHelper;
import com.google.common.base.Preconditions;
import hex.genmodel.GenModel;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.exception.PredictException;
import hex.genmodel.easy.prediction.AbstractPrediction;
import java.io.Closeable;
import java.io.IOException;
import java.nio.file.Path;
import java.util.List;
import org.apache.commons.io.FileUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/feedzai/openml/h2o/AbstractClassificationH2OModel.class */
/**
 * Base class for classification models backed by an H2O {@link GenModel}.
 *
 * <p>The generated model is wrapped in an {@link EasyPredictModelWrapper}; incoming
 * {@link Instance}s are translated into H2O {@link RowData} according to the
 * {@link DatasetSchema} supplied at construction time and then scored through the wrapper.
 *
 * <p>All calls into the wrapper's {@code predict} method made by this class are serialized
 * on a private lock.
 */
abstract class AbstractClassificationH2OModel implements ClassificationMLModel {

    /** Logger for this class. */
    private static final Logger logger = LoggerFactory.getLogger(AbstractClassificationH2OModel.class);

    /** H2O wrapper used to score rows. */
    protected final EasyPredictModelWrapper modelWrapper;

    /** Directory the model was loaded from; it is the source copied by {@link #save(Path, String)}. */
    private final Path modelPath;

    /** Schema describing the fields (and their order) of the instances to score. */
    protected final DatasetSchema schema;

    /** Resource released when this model is closed. */
    private final Closeable closeable;

    /** Monitor guarding calls to {@code modelWrapper.predict}. */
    private final Object predictLock = new Object();

    /**
     * Creates a model around the given H2O generated model.
     *
     * <p>The wrapper is configured with {@code setConvertUnknownCategoricalLevelsToNa(true)}.
     *
     * <p>NOTE(review): a decompiler note on this constructor indicated it was originally
     * package-private; visibility is kept as found so existing callers are unaffected.
     *
     * @param genModel      the H2O generated model to wrap; must not be {@code null}
     * @param path          location of the model on disk; must not be {@code null}
     * @param datasetSchema schema of the instances this model scores; must not be {@code null}
     * @param closeable     resource to release on {@link #close()}; must not be {@code null}
     * @throws NullPointerException if any argument is {@code null}
     */
    public AbstractClassificationH2OModel(GenModel genModel, Path path, DatasetSchema datasetSchema, Closeable closeable) {
        // Validate every argument up front so a bad call fails with a descriptive message
        // before the (comparatively expensive) wrapper is built.
        Preconditions.checkNotNull(genModel, "the H2O model cannot be null");
        this.modelPath = Preconditions.checkNotNull(path, "path of the model cannot be null");
        this.schema = Preconditions.checkNotNull(datasetSchema, "dataset schema cannot be null");
        this.closeable = Preconditions.checkNotNull(closeable, "the closeable cannot be null");

        this.modelWrapper = new EasyPredictModelWrapper(
                new EasyPredictModelWrapper.Config()
                        .setModel(genModel)
                        .setConvertUnknownCategoricalLevelsToNa(true));
    }

    /**
     * Saves the model by copying the directory it was loaded from into {@code path}.
     *
     * @param path destination directory
     * @param str  name of the model; only used in the error message
     * @return {@code true} when the copy completed
     * @throws RuntimeException wrapping the {@link IOException} if the copy fails
     */
    public boolean save(Path path, String str) {
        try {
            FileUtils.copyDirectory(this.modelPath.toFile(), path.toFile());
            return true;
        } catch (IOException e) {
            final String message = String.format("Error saving model %s to %s", str, path);
            logger.error(message, e);
            throw new RuntimeException(message, e);
        }
    }

    /**
     * Returns the schema supplied at construction time.
     *
     * @return the dataset schema of this model
     */
    public DatasetSchema getSchema() {
        return this.schema;
    }

    /**
     * Closes the resource supplied at construction time.
     *
     * @throws IOException if closing the underlying resource fails
     */
    public void close() throws IOException {
        this.closeable.close();
    }

    /**
     * Scores a single instance with the wrapped H2O model.
     *
     * <p>NOTE(review): a decompiler note on this method indicated it was originally
     * {@code protected}; visibility is kept as found so existing callers are unaffected.
     *
     * @param instance the instance to score
     * @param <P>       the concrete prediction type expected by the caller
     * @return the prediction produced by the wrapper, cast to {@code P}
     * @throws RuntimeException wrapping the {@link PredictException} if scoring fails
     */
    public <P extends AbstractPrediction> P predictInstance(Instance instance) {
        final RowData rowData = convertInstanceToRowData(instance);
        try {
            synchronized (this.predictLock) {
                // The caller picks P; the wrapper only promises an AbstractPrediction. Because of
                // erasure a mismatch is not detected here but where the caller uses the result.
                @SuppressWarnings("unchecked")
                final P prediction = (P) this.modelWrapper.predict(rowData, this.modelWrapper.getModelCategory());
                return prediction;
            }
        } catch (PredictException e) {
            throw new RuntimeException(String.format("The model failed to classify the event[%s]!", rowData), e);
        }
    }

    /**
     * Translates an instance into an H2O row keyed by field name.
     *
     * <p>Fields are visited in schema order and the i-th field reads the i-th value of the
     * instance. Categorical fields are decoded from their numeric encoding through
     * {@link EncodingHelper#decodeDoubleToCategory}, numeric fields are stored as boxed
     * {@link Double}s and any other field is stored as the instance's string value.
     *
     * @param instance the instance to convert
     * @return a new row holding one entry per field of the schema
     */
    private RowData convertInstanceToRowData(Instance instance) {
        final RowData rowData = new RowData();
        final List<FieldSchema> fieldSchemas = this.schema.getFieldSchemas();
        for (int i = 0; i < fieldSchemas.size(); i++) {
            final FieldSchema fieldSchema = fieldSchemas.get(i);
            final String fieldName = fieldSchema.getFieldName();
            if (fieldSchema.getValueSchema() instanceof CategoricalValueSchema) {
                rowData.put(fieldName, EncodingHelper.decodeDoubleToCategory(instance.getValue(i), fieldSchema.getValueSchema()));
            } else if (fieldSchema.getValueSchema() instanceof NumericValueSchema) {
                rowData.put(fieldName, Double.valueOf(instance.getValue(i)));
            } else {
                rowData.put(fieldName, instance.getStringValue(i));
            }
        }
        return rowData;
    }
}
