package ai.libs.mlpipeline_evaluation;

import ai.libs.hasco.model.ComponentInstance;
import ai.libs.hasco.serialization.CompositionSerializer;
import ai.libs.jaicore.ml.openml.OpenMLHelper;
import ai.libs.mlplan.multiclass.wekamlplan.weka.WEKAPipelineFactory;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.classifiers.Classifier;
import weka.core.converters.ConverterUtils;

/* loaded from: input_file:ai/libs/mlpipeline_evaluation/PipelineEvaluationCache.class */
public class PipelineEvaluationCache {

    private static final Logger LOGGER = LoggerFactory.getLogger(PipelineEvaluationCache.class);

    private static final String LOG_CANT_CONNECT_TO_CACHE = "Cannot connect to cache. Switching to offline mode.";
    private static final String INTERMEDIATE_RESULTS_TABLENAME = "pgotfml_hgraf.intermediate_results";

    /** Column of the intermediate results table that stores the evaluation score. */
    private static final String ERROR_RATE_COLUMN = "error_rate";

    /** Lookup query part shared by both variants: pipeline, dataset and test-evaluation settings. */
    private static final String LOOKUP_QUERY_PREFIX = "SELECT " + ERROR_RATE_COLUMN + " FROM " + INTERMEDIATE_RESULTS_TABLENAME
            + " WHERE pipeline=? AND dataset_id=? AND dataset_origin=? AND test_evaluation_technique=? AND test_split_technique=? AND test_seed=? AND ";

    /** Lookup for results computed without an inner validation step (all validation columns must be NULL). */
    private static final String LOOKUP_QUERY_WITHOUT_VALIDATION = LOOKUP_QUERY_PREFIX
            + "val_evaluation_technique IS NULL AND val_split_technique IS NULL AND val_seed IS NULL";

    /** Lookup for results computed with a specific inner validation configuration. */
    private static final String LOOKUP_QUERY_WITH_VALIDATION = LOOKUP_QUERY_PREFIX
            + "val_evaluation_technique=? AND val_split_technique=? AND val_seed=?";

    // NOTE(review): this OpenML API key is committed to source control; consider supplying it via configuration instead.
    private static final String OPENML_API_KEY = "4350e421cdc16404033ef1812ea38c01";

    private final PipelineEvaluationCacheConfigBuilder config;
    private boolean useCache = true;

    /**
     * Creates a cache for the given configuration and loads the configured dataset into it.
     *
     * <p>For {@code LOCAL} and {@code CLUSTER_LOCATION_NEW} origins the dataset id is treated as a
     * location readable by Weka's {@code ConverterUtils.DataSource}; for {@code OPENML_DATASET_ID} it
     * is parsed as an integer OpenML dataset id.
     *
     * @param pipelineEvaluationCacheConfigBuilder the configuration; the loaded dataset is stored back into it
     * @throws InvalidDatasetOriginException if the dataset origin is missing or unsupported
     * @throws Exception if loading the dataset fails
     */
    public PipelineEvaluationCache(final PipelineEvaluationCacheConfigBuilder pipelineEvaluationCacheConfigBuilder) throws Exception {
        this.config = pipelineEvaluationCacheConfigBuilder;
        DatasetOrigin origin = this.config.getDatasetOrigin();
        if (origin == null) {
            // Switching on a null enum would throw an uninformative NullPointerException.
            throw new InvalidDatasetOriginException("Invalid dataset origin.");
        }
        switch (origin) {
            case LOCAL:
            case CLUSTER_LOCATION_NEW:
                this.config.withDataset(new ConverterUtils.DataSource(this.config.getDatasetId()).getDataSet());
                return;
            case OPENML_DATASET_ID:
                OpenMLHelper.setApiKey(OPENML_API_KEY);
                this.config.withDataset(OpenMLHelper.getInstancesById(Integer.parseInt(this.config.getDatasetId())));
                return;
            default:
                throw new InvalidDatasetOriginException("Invalid dataset origin.");
        }
    }

    /**
     * Returns the cached error rate of the given pipeline, or evaluates it and caches the new result.
     *
     * <p>The database is consulted only while {@link #usesCache()} is true and the dataset origin is
     * not {@code LOCAL}. A database failure switches the cache off for this instance.
     *
     * @param componentInstance the pipeline to evaluate
     * @return the (cached or freshly computed) score
     * @throws Exception if building or evaluating the classifier fails
     */
    public double getResultOrExecuteEvaluation(final ComponentInstance componentInstance) throws Exception {
        String serializedPipeline = null;
        if (this.isCacheActive()) {
            LOGGER.debug("DB Lookup");
            serializedPipeline = CompositionSerializer.serializeComponentInstance(componentInstance).toString();
            LOGGER.debug("Pipeline: {}", serializedPipeline);
            Double cachedScore = this.doDBLookUp(serializedPipeline);
            if (cachedScore != null) {
                LOGGER.debug("Return DB result");
                return cachedScore;
            }
        }

        LOGGER.debug("Execute new evaluation");
        double score = this.evaluate(componentInstance);
        LOGGER.debug("Score: {}", score);

        // Re-check: a failed lookup may have switched the cache off in the meantime.
        if (this.isCacheActive()) {
            LOGGER.debug("Write new evaluation back into DB");
            this.uploadResultToDB(serializedPipeline, score);
        }
        return score;
    }

    /** Returns whether the database cache should currently be consulted. */
    private boolean isCacheActive() {
        return this.useCache && this.config.getDatasetOrigin() != DatasetOrigin.LOCAL;
    }

    /**
     * Looks up a stored error rate for the pipeline under the current configuration.
     *
     * @param serializedPipeline the serialized pipeline used as lookup key
     * @return the stored error rate, or {@code null} if none exists or the database is unreachable
     */
    private Double doDBLookUp(final String serializedPipeline) {
        List<String> values = new ArrayList<>(Arrays.asList(serializedPipeline, this.config.getDatasetId(),
                DatasetOrigin.mapOriginToColumnIdentifier(this.config.getDatasetOrigin()), this.config.getTestEvaluationTechnique(),
                this.config.getTestSplitTechnique(), String.valueOf(this.config.getTestSeed())));

        String query;
        if (this.doNotValidate()) {
            query = LOOKUP_QUERY_WITHOUT_VALIDATION;
        } else {
            query = LOOKUP_QUERY_WITH_VALIDATION;
            // Order must match the placeholders in LOOKUP_QUERY_WITH_VALIDATION.
            values.add(this.config.getValEvaluationTechnique());
            values.add(this.config.getValSplitTechnique());
            values.add(String.valueOf(this.config.getValSeed()));
        }

        // try-with-resources so the result set is released even on the early return paths.
        try (ResultSet results = this.config.getAdapter().getResultsOfQuery(query, values)) {
            if (results.next()) {
                return results.getDouble(ERROR_RATE_COLUMN);
            }
            return null;
        } catch (SQLException e) {
            LOGGER.warn(LOG_CANT_CONNECT_TO_CACHE, e);
            this.useCache = false;
            return null;
        }
    }

    /**
     * Instantiates the pipeline as a Weka classifier and evaluates it on the configured dataset.
     *
     * @param componentInstance the pipeline to evaluate
     * @return the score reported by {@code ConsistentMLPipelineEvaluator}
     * @throws Exception if instantiation or evaluation fails
     */
    private double evaluate(final ComponentInstance componentInstance) throws Exception {
        Classifier classifier = new WEKAPipelineFactory().m32getComponentInstantiation(componentInstance);
        if (this.doNotValidate()) {
            return ConsistentMLPipelineEvaluator.evaluateClassifier(this.config.getTestSplitTechnique(), this.config.getTestEvaluationTechnique(),
                    this.config.getTestSeed(), this.config.getData(), classifier);
        }
        return ConsistentMLPipelineEvaluator.evaluateClassifier(this.config.getTestSplitTechnique(), this.config.getTestEvaluationTechnique(),
                this.config.getTestSeed(), this.config.getValSplitTechnique(), this.config.getValEvaluationTechnique(), this.config.getValSeed(),
                this.config.getData(), classifier);
    }

    /**
     * Inserts a freshly computed result into the intermediate results table.
     * A database failure is logged and switches the cache off.
     *
     * @param serializedPipeline the serialized pipeline
     * @param errorRate the computed score
     */
    private void uploadResultToDB(final String serializedPipeline, final double errorRate) {
        Map<String, Object> row = new HashMap<>();
        row.put("pipeline", serializedPipeline);
        row.put("dataset_id", this.config.getDatasetId());
        row.put("dataset_origin", DatasetOrigin.mapOriginToColumnIdentifier(this.config.getDatasetOrigin()));
        row.put("test_evaluation_technique", this.config.getTestEvaluationTechnique());
        row.put("test_split_technique", this.config.getTestSplitTechnique());
        row.put("test_seed", this.config.getTestSeed());
        row.put(ERROR_RATE_COLUMN, errorRate);
        if (!this.doNotValidate()) {
            row.put("val_split_technique", this.config.getValSplitTechnique());
            row.put("val_evaluation_technique", this.config.getValEvaluationTechnique());
            row.put("val_seed", this.config.getValSeed());
        }
        try {
            this.config.getAdapter().insert(INTERMEDIATE_RESULTS_TABLENAME, row);
        } catch (SQLException e) {
            LOGGER.warn(LOG_CANT_CONNECT_TO_CACHE, e);
            this.useCache = false;
        }
    }

    /** Returns {@code true} when no validation split technique is configured (null or blank). */
    private boolean doNotValidate() {
        String valSplitTechnique = this.config.getValSplitTechnique();
        return valSplitTechnique == null || valSplitTechnique.trim().isEmpty();
    }

    /**
     * Configures the inner validation step used for evaluation and cache lookups.
     *
     * @param valSplitTechnique the validation split technique
     * @param valEvaluationTechnique the validation evaluation technique
     * @param valSeed the validation seed
     */
    public void configureValidation(final String valSplitTechnique, final String valEvaluationTechnique, final int valSeed) {
        this.config.withValEvaluationTechnique(valEvaluationTechnique);
        this.config.withValSplitTechnique(valSplitTechnique);
        this.config.withValSeed(valSeed);
    }

    /** Returns whether the database cache is currently enabled. */
    public boolean usesCache() {
        return this.useCache;
    }

    /**
     * Enables or disables the database cache.
     *
     * @param useCache {@code true} to consult and update the database
     */
    public void setUseCache(final boolean useCache) {
        this.useCache = useCache;
    }
}
