package com.feedzai.openml.python;

import com.feedzai.openml.data.Instance;
import com.feedzai.openml.data.schema.AbstractValueSchema;
import com.feedzai.openml.data.schema.CategoricalValueSchema;
import com.feedzai.openml.data.schema.DatasetSchema;
import com.feedzai.openml.model.ClassificationMLModel;
import com.feedzai.openml.provider.exception.ModelLoadingException;
import com.feedzai.openml.python.jep.instance.JepInstance;
import com.feedzai.openml.util.data.ClassificationDatasetSchemaUtil;
import com.feedzai.openml.util.data.encoding.EncodingHelper;
import com.google.common.collect.ImmutableList;
import java.io.Serializable;
import java.nio.file.Path;
import java.util.concurrent.ExecutionException;
import java.util.function.Function;
import java.util.stream.IntStream;
import jep.JepException;
import jep.NDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/feedzai/openml/python/ClassificationPythonModel.class */
/**
 * {@link ClassificationMLModel} that delegates predictions to a model object living inside an
 * embedded Python interpreter driven through a {@link JepInstance}.
 *
 * <p>The Python object referenced by {@code id} must expose two callables (by default
 * {@value #DEFAULT_CLASSIFY_FUNCTION_NAME} and {@value #DEFAULT_GETCLASSDISTRIBUTION_FUNCTION_NAME}).
 * Each callable receives the instance's predictive fields (every field except the target) as a
 * {@code 1 x n} {@link NDArray} of doubles.
 */
public class ClassificationPythonModel implements ClassificationMLModel {
    private static final Logger logger = LoggerFactory.getLogger(ClassificationPythonModel.class);

    /** Default name of the Python callable used by {@link #classify(Instance)}. */
    public static final String DEFAULT_CLASSIFY_FUNCTION_NAME = "classify";

    /** Default name of the Python callable used by {@link #getClassDistribution(Instance)}. */
    public static final String DEFAULT_GETCLASSDISTRIBUTION_FUNCTION_NAME = "getClassDistribution";

    /** Maps the (string) class predicted in Python to its index in the target's nominal values. */
    private final Function<Serializable, Integer> classToIndexConverter;

    /** Interpreter that hosts the Python model object. */
    private final JepInstance jepInstance;

    /** Indexes of all schema fields except the target, in schema order. */
    private final int[] predictiveFieldIndexes;

    private final DatasetSchema schema;

    /** Name of the Python variable holding the model object. */
    private final String id;

    private final String classifyFunctionName;
    private final String getClassDistributionFunctionName;

    /**
     * Creates a model that calls custom-named Python functions.
     *
     * @param jepInstance                      interpreter hosting the Python model object
     * @param datasetSchema                    schema of the instances; its target must be categorical
     * @param id                               name of the Python variable holding the model object
     * @param classifyFunctionName             name of the model's classification callable
     * @param getClassDistributionFunctionName name of the model's class-distribution callable
     * @throws IllegalArgumentException if the target field of {@code datasetSchema} is not categorical
     */
    public ClassificationPythonModel(final JepInstance jepInstance,
                                     final DatasetSchema datasetSchema,
                                     final String id,
                                     final String classifyFunctionName,
                                     final String getClassDistributionFunctionName) {
        this.jepInstance = jepInstance;
        this.schema = datasetSchema;
        final int targetIndex = datasetSchema.getTargetIndex();
        this.predictiveFieldIndexes = IntStream.range(0, datasetSchema.getFieldSchemas().size())
                .filter(fieldIndex -> fieldIndex != targetIndex)
                .toArray();
        this.id = id;
        this.classifyFunctionName = classifyFunctionName;
        this.getClassDistributionFunctionName = getClassDistributionFunctionName;
        this.classToIndexConverter = getClassToIndexConverter(datasetSchema);
    }

    /**
     * Creates a model that calls the default Python functions
     * ({@value #DEFAULT_CLASSIFY_FUNCTION_NAME} / {@value #DEFAULT_GETCLASSDISTRIBUTION_FUNCTION_NAME}).
     *
     * @param jepInstance   interpreter hosting the Python model object
     * @param datasetSchema schema of the instances; its target must be categorical
     * @param id            name of the Python variable holding the model object
     * @throws IllegalArgumentException if the target field of {@code datasetSchema} is not categorical
     */
    public ClassificationPythonModel(final JepInstance jepInstance, final DatasetSchema datasetSchema, final String id) {
        this(jepInstance, datasetSchema, id, DEFAULT_CLASSIFY_FUNCTION_NAME, DEFAULT_GETCLASSDISTRIBUTION_FUNCTION_NAME);
    }

    /**
     * Saving is not supported for Python-backed models.
     *
     * @return always {@code false}
     */
    public boolean save(final Path path, final String name) {
        return false;
    }

    public DatasetSchema getSchema() {
        return this.schema;
    }

    /**
     * Calls the Python class-distribution function and returns its result as doubles.
     *
     * <p>The Python result is wrapped in {@code numpy.array(...)}. The data of the resulting
     * {@link NDArray} is returned as a {@code double[]}. {@code float[]}, {@code int[]} and
     * {@code long[]} data are widened element by element.
     *
     * @param instance instance to score
     * @return the per-class values produced by the Python model
     * @throws RuntimeException if the Python evaluation fails
     */
    public double[] getClassDistribution(final Instance instance) {
        final Object data = ((NDArray) this.invokeFunction(
                instance, this.getClassDistributionFunctionName, "numpy.array(%s)")).getData();
        return toDoubleArray(data);
    }

    /**
     * Widens the primitive array backing an {@link NDArray} to {@code double[]}.
     * A {@code double[]} is returned as-is; any other type is cast exactly as before.
     */
    private static double[] toDoubleArray(final Object data) {
        if (data instanceof float[]) {
            final float[] values = (float[]) data;
            final double[] widened = new double[values.length];
            for (int i = 0; i < values.length; i++) {
                widened[i] = values[i];
            }
            return widened;
        }
        if (data instanceof int[]) {
            return IntStream.of((int[]) data).asDoubleStream().toArray();
        }
        if (data instanceof long[]) {
            final long[] values = (long[]) data;
            final double[] widened = new double[values.length];
            for (int i = 0; i < values.length; i++) {
                widened[i] = values[i];
            }
            return widened;
        }
        // double[] (or null) passes through; unsupported element types fail here with ClassCastException.
        return (double[]) data;
    }

    /**
     * Calls the Python classify function and maps its first returned element to a class index.
     *
     * <p>The Python expression evaluated is {@code str(<classify>(instance)[0])}. The resulting
     * string is looked up among the target's nominal values.
     *
     * @param instance instance to classify
     * @return the index of the predicted class
     * @throws NullPointerException if the predicted class is not one of the target's nominal values
     *                              (this exception type is kept for backward compatibility)
     * @throws RuntimeException     if the Python evaluation fails
     */
    public int classify(final Instance instance) {
        final String predictedClass = (String) this.invokeFunction(instance, this.classifyFunctionName, "str(%s[0])");
        final Integer classIndex = this.classToIndexConverter.apply(predictedClass);
        if (classIndex != null) {
            return classIndex;
        }

        final AbstractValueSchema valueSchema = this.schema.getTargetFieldSchema().getValueSchema();
        final String message = (String) ClassificationDatasetSchemaUtil.withCategoricalValueSchema(valueSchema, categoricalValueSchema -> {
            return String.format("Unexpected class provided by model: %s. Expected values: %s", predictedClass, categoricalValueSchema.getNominalValues());
        }).orElseThrow(() -> {
            return new RuntimeException("The target variable is not a categorical value: " + valueSchema);
        });
        // Explicit check instead of catching the unboxing NPE; the exception type is unchanged for callers.
        final NullPointerException error = new NullPointerException(message);
        logger.warn(message, error);
        throw error;
    }

    /**
     * Checks that the Python object named {@code modelId} exposes both configured functions as callables.
     *
     * @param jepInstance interpreter in which to perform the check
     * @param modelId     name of the Python variable holding the model object
     * @throws ModelLoadingException if either function is missing or not callable, or if the check
     *                               is interrupted (the thread's interrupt flag is restored)
     */
    public void validate(final JepInstance jepInstance, final String modelId) throws ModelLoadingException {
        try {
            jepInstance.submitEvaluation(interpreter -> {
                for (final String functionName : ImmutableList.of(this.classifyFunctionName, this.getClassDistributionFunctionName)) {
                    final String isCallableExpression =
                            String.format("callable(getattr(%s, \"%s\", None))", modelId, functionName);
                    if (!((Boolean) interpreter.getValue(isCallableExpression)).booleanValue()) {
                        throw new JepException(String.format("Model does not implement %s function.", functionName));
                    }
                }
                return null;
            }).get();
        } catch (final InterruptedException | ExecutionException e) {
            if (e instanceof InterruptedException) {
                // Preserve the interrupt for code further up the stack.
                Thread.currentThread().interrupt();
            }
            logger.error(e.getMessage(), e);
            throw new ModelLoadingException(e.getMessage(), e);
        }
    }

    /** Stops the interpreter hosting this model. */
    public void close() {
        this.jepInstance.stop();
    }

    /**
     * Builds the class-to-index converter for the schema's target field.
     *
     * @throws IllegalArgumentException if the target field is not categorical
     */
    private static Function<Serializable, Integer> getClassToIndexConverter(final DatasetSchema datasetSchema) {
        final AbstractValueSchema targetValueSchema = datasetSchema.getTargetFieldSchema().getValueSchema();
        if (targetValueSchema instanceof CategoricalValueSchema) {
            return EncodingHelper.classToIndexConverter((CategoricalValueSchema) targetValueSchema);
        }
        logger.warn("Provided schema's target field is not categorical: {}", datasetSchema);
        throw new IllegalArgumentException("Classification models require Categorical target fields. Got " + targetValueSchema);
    }

    /**
     * Evaluates {@code <id>.<functionName>(encodedInstance)} in the interpreter and returns the value
     * of the expression built by substituting that call into {@code resultExpressionFormat}.
     *
     * <p>This sets the Python globals {@code classification_function} and {@code encodedInstance}.
     *
     * @param instance               instance whose predictive fields are sent to Python
     * @param functionName           name of the callable on the model object
     * @param resultExpressionFormat {@link String#format} pattern with one {@code %s} for the call
     * @return the evaluated Python value converted by Jep
     * @throws RuntimeException wrapping any evaluation failure (the interrupt flag is restored if interrupted)
     */
    private <T> T invokeFunction(final Instance instance, final String functionName, final String resultExpressionFormat) {
        final int featureCount = this.predictiveFieldIndexes.length;
        final double[] features = new double[featureCount];
        for (int i = 0; i < featureCount; i++) {
            features[i] = instance.getValue(this.predictiveFieldIndexes[i]);
        }
        final NDArray encodedInstance = new NDArray(features, 1, featureCount);
        try {
            return this.jepInstance.submitEvaluation(interpreter -> {
                interpreter.eval(String.format("classification_function = %s.%s", this.id, functionName));
                interpreter.set("encodedInstance", encodedInstance);
                return interpreter.getValue(String.format(resultExpressionFormat, "classification_function(encodedInstance)"));
            }).get();
        } catch (final Exception e) {
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            logger.warn("Error during instance evaluation.");
            throw new RuntimeException("Error during instance evaluation.", e);
        }
    }
}
