package com.feedzai.openml.datarobot;

import com.datarobot.prediction.Predictor;
import com.feedzai.openml.data.schema.AbstractValueSchema;
import com.feedzai.openml.data.schema.CategoricalValueSchema;
import com.feedzai.openml.data.schema.DatasetSchema;
import com.feedzai.openml.data.schema.FieldSchema;
import com.feedzai.openml.java.utils.JavaFileUtils;
import com.feedzai.openml.provider.descriptor.fieldtype.ParamValidationError;
import com.feedzai.openml.provider.exception.ModelLoadingException;
import com.feedzai.openml.provider.model.MachineLearningModelLoader;
import com.feedzai.openml.util.load.LoadModelUtils;
import com.feedzai.openml.util.load.LoadSchemaUtils;
import com.feedzai.openml.util.validate.ClassificationValidationUtils;
import com.feedzai.openml.util.validate.ValidationUtils;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.io.IOException;
import java.net.URLClassLoader;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.reflect.FieldUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/feedzai/openml/datarobot/DataRobotModelCreator.class */
public class DataRobotModelCreator implements MachineLearningModelLoader<ClassificationBinaryDataRobotModel> {
    private static final String MODEL_PACKAGE_TEMPLATE = "com.datarobot.prediction.dr%s.DRModel";
    private static final Logger logger = LoggerFactory.getLogger(DataRobotModelCreator.class);
    private static final Set<String> BOOLEAN_VALUES = ImmutableSet.of("True", "False");
    private static final Predicate<String[]> IS_BOOLEAN_MODEL = strArr -> {
        return Objects.equals(BOOLEAN_VALUES, Arrays.stream(strArr).collect(Collectors.toSet()));
    };

    /* renamed from: loadModel, reason: merged with bridge method [inline-methods] */
    public ClassificationBinaryDataRobotModel m3loadModel(Path path, DatasetSchema datasetSchema) throws ModelLoadingException {
        if (!datasetSchema.getTargetIndex().isPresent()) {
            throw new ModelLoadingException("Cannot load a model with a schema that has no target variable.");
        }
        logger.info("Trying to load a model in path [{}]...", path);
        ClassificationValidationUtils.validateParamsModelToLoad(this, path, datasetSchema, ImmutableMap.of());
        Pair<Predictor, URLClassLoader> createPredictorInstance = createPredictorInstance(path);
        Predictor predictor = (Predictor) createPredictorInstance.getKey();
        int length = predictor.get_double_predictors().length + predictor.get_string_predictors().length;
        if (length != datasetSchema.getFieldSchemas().size() - 1) {
            String format = String.format("Wrong number of fields in the given schema. The model expected %d feature fields + 1 target field, but the schema had a total of %d fields only (encompassing both features and target fields).", Integer.valueOf(length), Integer.valueOf(datasetSchema.getFieldSchemas().size()));
            logger.error(format + String.format(" Schema expected by the model %s. Schema provided %s.", predictor2Str(predictor), datasetSchema));
            throw new ModelLoadingException(format);
        }
        String[] targetModelValues = getTargetModelValues(predictor);
        ClassificationBinaryDataRobotModel classificationBinaryDataRobotModel = new ClassificationBinaryDataRobotModel(predictor, checkTargetModelValuesWithSchema(datasetSchema, targetModelValues).first().equals(targetModelValues[0]), path, datasetSchema, (URLClassLoader) createPredictorInstance.getValue());
        ClassificationValidationUtils.validateClassificationModel(datasetSchema, classificationBinaryDataRobotModel);
        logger.info("Model in path [{}] loaded successfully.", path);
        return classificationBinaryDataRobotModel;
    }

    private String predictor2Str(Predictor predictor) {
        StringBuilder sb = new StringBuilder();
        for (String str : predictor.get_double_predictors()) {
            sb.append(str);
            sb.append(",");
        }
        for (String str2 : predictor.get_string_predictors()) {
            sb.append(str2);
            sb.append(",");
        }
        return sb.toString();
    }

    private Pair<Predictor, URLClassLoader> createPredictorInstance(Path path) throws ModelLoadingException {
        String path2 = LoadModelUtils.getModelFilePath(path).toAbsolutePath().toString();
        URLClassLoader urlClassLoader = JavaFileUtils.getUrlClassLoader(path2, ClassificationBinaryDataRobotModel.class.getClassLoader());
        return Pair.of((Predictor) JavaFileUtils.createNewInstanceFromClassLoader(path2, MODEL_PACKAGE_TEMPLATE, urlClassLoader), urlClassLoader);
    }

    @VisibleForTesting
    String[] getTargetModelValues(Predictor predictor) {
        try {
            return (String[]) FieldUtils.readField(predictor, "classLabels", true);
        } catch (Exception e) {
            logger.warn(String.format("Jar file of the DataRobot model may not be supported. A possible cause is that the model might be to old and a newer version is required because it lacks the \"%s\" field with the target values. Ideally, you should create a new project on DataRobot and train new models. As a workaround, the load will assume the target variable values to be 0 and 1, in that order.", "classLabels"), e);
            return new String[]{"0", "1"};
        }
    }

    SortedSet<String> checkTargetModelValuesWithSchema(DatasetSchema datasetSchema, String[] strArr) throws ModelLoadingException {
        Optional map = datasetSchema.getTargetFieldSchema().map((v0) -> {
            return v0.getValueSchema();
        });
        Class<CategoricalValueSchema> cls = CategoricalValueSchema.class;
        CategoricalValueSchema.class.getClass();
        SortedSet<String> sortedSet = (SortedSet) map.map((v1) -> {
            return r1.cast(v1);
        }).map((v0) -> {
            return v0.getNominalValues();
        }).orElseThrow(() -> {
            return new ModelLoadingException("Cannot load a model with a schema that has no target variable.");
        });
        if (IS_BOOLEAN_MODEL.test(strArr)) {
            if (Objects.equals(sortedSet.stream().map(StringUtils::lowerCase).collect(Collectors.toSet()), BOOLEAN_VALUES.stream().map(StringUtils::lowerCase).collect(Collectors.toSet()))) {
                return new TreeSet(BOOLEAN_VALUES);
            }
            String format = String.format("Incompatible target values. The model is binary and thus expects some form of: [%s], but the schema had: %s.", String.join(",", strArr), String.join(",", sortedSet));
            logger.error(format);
            throw new ModelLoadingException(format);
        }
        if (sortedSet.size() == strArr.length && sortedSet.containsAll(Arrays.asList(strArr))) {
            return sortedSet;
        }
        String format2 = String.format("Incompatible target values. model: [%s], schema: %s.", String.join(",", strArr), String.join(",", sortedSet));
        logger.error(format2);
        throw new ModelLoadingException(format2);
    }

    public DatasetSchema loadSchema(Path path) throws ModelLoadingException {
        return LoadSchemaUtils.datasetSchemaFromJson(path);
    }

    public List<ParamValidationError> validateTargetIsBinary(AbstractValueSchema abstractValueSchema) {
        ImmutableList.Builder builder = ImmutableList.builder();
        if ((abstractValueSchema instanceof CategoricalValueSchema) && ((CategoricalValueSchema) abstractValueSchema).getNominalValues().size() != 2) {
            builder.add(new ParamValidationError("At the moment only binary classification models are supported"));
        }
        return builder.build();
    }

    List<ParamValidationError> validateModelFileFormat(Path path) {
        ImmutableList.Builder builder = ImmutableList.builder();
        try {
            String path2 = LoadModelUtils.getModelFilePath(path).toAbsolutePath().toString();
            if (!JavaFileUtils.isJarFile(path2)) {
                builder.add(new ParamValidationError(String.format("Extension [%s] not recognized for a DataRobot model, the model should beexported in the [%s] extension.", path2, "jar")));
            }
        } catch (ModelLoadingException e) {
            builder.add(new ParamValidationError(String.format("Unable to find a model file in [%s].", path)));
        }
        return builder.build();
    }

    public List<ParamValidationError> validateForLoad(Path path, DatasetSchema datasetSchema, Map<String, String> map) {
        ImmutableList.Builder builder = ImmutableList.builder();
        builder.addAll(ValidationUtils.baseLoadValidations(datasetSchema, map));
        builder.addAll((List) datasetSchema.getTargetFieldSchema().map(fieldSchema -> {
            return validateForLoad(datasetSchema, fieldSchema);
        }).orElse(ImmutableList.of(new ParamValidationError("Cannot load a model with a schema that has no target variable."))));
        builder.addAll(validateModelFileFormat(path));
        builder.addAll(ValidationUtils.validateModelInDir(path));
        return builder.build();
    }

    private List<ParamValidationError> validateForLoad(DatasetSchema datasetSchema, FieldSchema fieldSchema) {
        ImmutableList.Builder builder = ImmutableList.builder();
        Optional validateCategoricalSchema = ValidationUtils.validateCategoricalSchema(datasetSchema);
        builder.getClass();
        validateCategoricalSchema.ifPresent((v1) -> {
            r1.add(v1);
        });
        builder.addAll(validateTargetIsBinary(fieldSchema.getValueSchema()));
        return builder.build();
    }
}
