package ai.libs.mlplan.core;

import ai.libs.hasco.core.HASCOFactory;
import ai.libs.hasco.model.Component;
import ai.libs.hasco.serialization.ComponentLoader;
import ai.libs.hasco.variants.forwarddecomposition.HASCOViaFDAndBestFirstFactory;
import ai.libs.hasco.variants.forwarddecomposition.HASCOViaFDFactory;
import ai.libs.jaicore.basic.FileUtil;
import ai.libs.jaicore.basic.ILoggingCustomizable;
import ai.libs.jaicore.basic.TimeOut;
import ai.libs.jaicore.basic.algorithm.reduction.AlgorithmicProblemReduction;
import ai.libs.jaicore.ml.evaluation.evaluators.weka.LearningCurveExtrapolationEvaluator;
import ai.libs.jaicore.ml.evaluation.evaluators.weka.factory.ClassifierEvaluatorConstructionFailedException;
import ai.libs.jaicore.ml.evaluation.evaluators.weka.factory.IClassifierEvaluatorFactory;
import ai.libs.jaicore.ml.evaluation.evaluators.weka.factory.MonteCarloCrossValidationEvaluatorFactory;
import ai.libs.jaicore.ml.weka.dataset.splitter.IDatasetSplitter;
import ai.libs.jaicore.planning.hierarchical.algorithms.forwarddecomposition.graphgenerators.tfd.TFDNode;
import ai.libs.mlpipeline_evaluation.PerformanceDBAdapter;
import ai.libs.mlplan.multiclass.MLPlanClassifierConfig;
import ai.libs.mlplan.multiclass.wekamlplan.IClassifierFactory;
import ai.libs.mlplan.multiclass.wekamlplan.weka.PreferenceBasedNodeEvaluator;
import jaicore.search.algorithms.standard.bestfirst.StandardBestFirstFactory;
import jaicore.search.algorithms.standard.bestfirst.nodeevaluation.AlternativeNodeEvaluator;
import jaicore.search.algorithms.standard.bestfirst.nodeevaluation.INodeEvaluator;
import jaicore.search.core.interfaces.IOptimalPathInORGraphSearchFactory;
import jaicore.search.problemtransformers.GraphSearchProblemInputToGraphSearchWithSubpathEvaluationInputTransformerViaRDFS;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
import java.util.function.Predicate;
import org.aeonbits.owner.ConfigFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.core.Instances;

/* loaded from: input_file:ai/libs/mlplan/core/AbstractMLPlanBuilder.class */
/**
 * Base class for builders that collect the configuration of an {@link MLPlan} run and
 * finally create the {@link MLPlan} instance via {@link #build()}.
 *
 * <p>The builder keeps an algorithm configuration, an optional search space file together
 * with the components loaded from it, evaluator factories for the search and the selection
 * phase, and a HASCO factory whose search problem transformer is rebuilt whenever a setting
 * that influences it changes (see the private {@code update()} method).
 *
 * <p>Most {@code with...} methods return {@code this} so that calls can be chained.
 */
public abstract class AbstractMLPlanBuilder implements IMLPlanBuilder, ILoggingCustomizable {
    /** Resource path of the default algorithm configuration. */
    private static final String RES_ALGORITHM_CONFIG = "mlplan/mlplan.properties";
    /** File system path of the default algorithm configuration. */
    private static final String FS_ALGORITHM_CONFIG = "conf/mlplan.properties";
    /** Default algorithm configuration file, resolved once from the two locations above. */
    private static final File DEF_ALGORITHM_CONFIG = FileUtil.getExistingFileWithHighestPriority(RES_ALGORITHM_CONFIG, new String[]{FS_ALGORITHM_CONFIG});

    private MLPlanClassifierConfig algorithmConfig;
    private File searchSpaceFile;
    private String requestedHASCOInterface;
    private IClassifierFactory classifierFactory;
    // NOTE(review): never assigned anywhere in this class, so it is null unless set by other means.
    private PipelineValidityCheckingNodeEvaluator pipelineValidityCheckingNodeEvaluator;
    private IDatasetSplitter searchSelectionDatasetSplitter;
    private String performanceMeasureName;
    private boolean useCache;
    private Instances dataset;
    // Not static/final because setLoggerName replaces the logger instance.
    private Logger logger = LoggerFactory.getLogger(AbstractMLPlanBuilder.class);
    private String loggerName = AbstractMLPlanBuilder.class.getName();
    /** Set once prepareNodeEvaluatorInFactoryWithData has taken effect; blocks later evaluator changes. */
    private boolean factoryPreparedWithData = false;
    private final HASCOViaFDFactory hascoFactory = new HASCOViaFDFactory();
    private Predicate<TFDNode> priorizingPredicate = null;
    private INodeEvaluator<TFDNode, Double> preferredNodeEvaluator = null;
    private IClassifierEvaluatorFactory factoryForPipelineEvaluationInSearchPhase = null;
    private IClassifierEvaluatorFactory factoryForPipelineEvaluationInSelectionPhase = null;
    /** Components loaded by the most recent call to withSearchSpaceConfigFile. */
    private final Collection<Component> components = new LinkedList<>();
    private PerformanceDBAdapter dbAdapter = null;

    /**
     * Creates a builder that uses the default algorithm configuration file and the standard
     * best-first search factory.
     *
     * <p>Both initialisation steps are performed through public, overridable builder methods.
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public AbstractMLPlanBuilder() {
        withAlgorithmConfigFile(DEF_ALGORITHM_CONFIG);
        withRandomCompletionBasedBestFirstSearch();
    }

    /**
     * @return a new {@link MLPlanSKLearnBuilder}
     * @throws IOException declared for the builder's construction
     */
    public static MLPlanSKLearnBuilder forSKLearn() throws IOException {
        return new MLPlanSKLearnBuilder();
    }

    /**
     * @return a new {@link MLPlanWekaBuilder}
     * @throws IOException declared for the builder's construction
     */
    public static MLPlanWekaBuilder forWeka() throws IOException {
        return new MLPlanWekaBuilder();
    }

    /**
     * @return a new {@link MLPlanMekaBuilder}
     * @throws IOException declared for the builder's construction
     */
    public static MLPlanMekaBuilder forMeka() throws IOException {
        return new MLPlanMekaBuilder();
    }

    /**
     * Registers a preferred node evaluator. If one is already registered, the new evaluator is
     * combined with the existing one in an {@link AlternativeNodeEvaluator}, with the new
     * evaluator passed as the first argument.
     *
     * @param iNodeEvaluator the node evaluator to register
     * @return this builder
     * @throws IllegalStateException if the factory has already been prepared with data
     */
    public AbstractMLPlanBuilder withPreferredNodeEvaluator(INodeEvaluator<TFDNode, Double> iNodeEvaluator) {
        if (this.factoryPreparedWithData) {
            throw new IllegalStateException("The method prepareNodeEvaluatorInFactoryWithData has already been called. No changes to the preferred node evaluator possible anymore");
        }
        if (this.preferredNodeEvaluator == null) {
            this.preferredNodeEvaluator = iNodeEvaluator;
        } else {
            this.preferredNodeEvaluator = new AlternativeNodeEvaluator<>(iNodeEvaluator, this.preferredNodeEvaluator);
        }
        update();
        return this;
    }

    /**
     * Sets the search factory and the search problem transformer of the HASCO factory directly.
     * Unlike most other setters this does not trigger a rebuild of the transformer, so the
     * given reduction stays in place until another setter rebuilds it.
     *
     * @param iOptimalPathInORGraphSearchFactory the search factory to use
     * @param algorithmicProblemReduction the search problem transformer to use
     * @return this builder
     */
    public AbstractMLPlanBuilder withSearchFactory(IOptimalPathInORGraphSearchFactory iOptimalPathInORGraphSearchFactory, AlgorithmicProblemReduction algorithmicProblemReduction) {
        this.hascoFactory.setSearchFactory(iOptimalPathInORGraphSearchFactory);
        this.hascoFactory.setSearchProblemTransformer(algorithmicProblemReduction);
        return this;
    }

    /**
     * Configures the HASCO factory with a {@link StandardBestFirstFactory} and rebuilds the
     * search problem transformer.
     *
     * @return this builder
     */
    public AbstractMLPlanBuilder withRandomCompletionBasedBestFirstSearch() {
        this.hascoFactory.setSearchFactory(new StandardBestFirstFactory());
        update();
        return this;
    }

    /**
     * Loads the components from the currently configured search space file. The file is read
     * anew on every call; the internally cached component collection is not returned.
     *
     * @return the components found in the search space file
     * @throws IOException if the component loader fails to read the file
     */
    public Collection<Component> getComponents() throws IOException {
        return new ComponentLoader(this.searchSpaceFile).getComponents();
    }

    /**
     * Loads an algorithm configuration from the given properties file and applies it.
     *
     * @param file the properties file to load
     * @return this builder
     */
    public AbstractMLPlanBuilder withAlgorithmConfigFile(File file) {
        return withAlgorithmConfig((MLPlanClassifierConfig) ConfigFactory.create(MLPlanClassifierConfig.class, new Map[0]).loadPropertiesFromFile(file));
    }

    /**
     * Replaces the algorithm configuration, passes it to the HASCO factory and rebuilds the
     * search problem transformer.
     *
     * @param mLPlanClassifierConfig the configuration to use
     * @return this builder
     */
    public AbstractMLPlanBuilder withAlgorithmConfig(MLPlanClassifierConfig mLPlanClassifierConfig) {
        this.algorithmConfig = mLPlanClassifierConfig;
        this.hascoFactory.withAlgorithmConfig(this.algorithmConfig);
        update();
        return this;
    }

    /**
     * Stores the path of the preferred-components file in the algorithm configuration and
     * registers a {@link PreferenceBasedNodeEvaluator} built from the file's lines. If the file
     * does not exist, a warning is logged and an empty ordering is used.
     *
     * @param file the file listing the preferred components
     * @return this builder
     * @throws IOException if reading the file fails
     * @throws IllegalStateException if the factory has already been prepared with data
     */
    public AbstractMLPlanBuilder withPreferredComponentsFile(File file) throws IOException {
        getAlgorithmConfig().setProperty(MLPlanClassifierConfig.PREFERRED_COMPONENTS, file.getAbsolutePath());
        final List<String> orderedComponentNames;
        if (file.exists()) {
            orderedComponentNames = FileUtil.readFileAsList(file);
        } else {
            this.logger.warn("The configured file for preferred components \"{}\" does not exist. Not using any particular ordering.", file.getAbsolutePath());
            orderedComponentNames = new ArrayList<>();
        }
        return withPreferredNodeEvaluator(new PreferenceBasedNodeEvaluator(this.components, orderedComponentNames));
    }

    /**
     * @param str the name of the performance measure
     */
    public void setPerformanceMeasureName(String str) {
        this.performanceMeasureName = str;
    }

    /**
     * @param instances the dataset to hand to ML-Plan on {@link #build()}
     * @return this builder
     */
    public AbstractMLPlanBuilder withDataset(Instances instances) {
        this.dataset = instances;
        return this;
    }

    /**
     * Sets the search space file and replaces the cached components with those loaded from it.
     *
     * @param file the search space configuration file; checked via {@code FileUtil.requireFileExists}
     * @return this builder
     * @throws IOException if the existence check or loading the components fails
     */
    public AbstractMLPlanBuilder withSearchSpaceConfigFile(File file) throws IOException {
        FileUtil.requireFileExists(file);
        this.searchSpaceFile = file;
        this.components.clear();
        this.components.addAll(new ComponentLoader(this.searchSpaceFile).getComponents());
        return this;
    }

    /**
     * @param iClassifierFactory the factory returned by {@link #getClassifierFactory()}
     * @return this builder
     */
    public AbstractMLPlanBuilder withClassifierFactory(IClassifierFactory iClassifierFactory) {
        this.classifierFactory = iClassifierFactory;
        return this;
    }

    /**
     * @param iDatasetSplitter the splitter returned by {@link #getSearchSelectionDatasetSplitter()}
     * @return this builder
     */
    public AbstractMLPlanBuilder withDatasetSplitterForSearchSelectionSplit(IDatasetSplitter iDatasetSplitter) {
        this.searchSelectionDatasetSplitter = iDatasetSplitter;
        return this;
    }

    /**
     * @param str the name of the requested HASCO interface
     * @return this builder
     */
    public AbstractMLPlanBuilder withRequestedInterface(String str) {
        this.requestedHASCOInterface = str;
        return this;
    }

    /**
     * Writes the overall timeout (in milliseconds) into the algorithm configuration and
     * rebuilds the search problem transformer.
     *
     * @param timeOut the timeout to set
     * @return this builder
     */
    public AbstractMLPlanBuilder withTimeOut(TimeOut timeOut) {
        this.algorithmConfig.setProperty("timeout", String.valueOf(timeOut.milliseconds()));
        update();
        return this;
    }

    /**
     * @return the configured overall timeout, expressed in milliseconds
     */
    public TimeOut getTimeOut() {
        return new TimeOut(this.algorithmConfig.timeout(), TimeUnit.MILLISECONDS);
    }

    /**
     * Writes the node evaluation timeout (in milliseconds) into the algorithm configuration and
     * rebuilds the search problem transformer.
     *
     * @param timeOut the timeout to set
     * @return this builder
     */
    public AbstractMLPlanBuilder withNodeEvaluationTimeOut(TimeOut timeOut) {
        this.algorithmConfig.setProperty("hasco.random_completions.timeout_node", String.valueOf(timeOut.milliseconds()));
        update();
        return this;
    }

    /**
     * @return the configured node evaluation timeout, expressed in milliseconds
     */
    public TimeOut getNodeEvaluationTimeOut() {
        return new TimeOut(this.algorithmConfig.timeoutForNodeEvaluation(), TimeUnit.MILLISECONDS);
    }

    /**
     * Writes the candidate evaluation timeout (in milliseconds) into the algorithm configuration
     * and rebuilds the search problem transformer.
     *
     * @param timeOut the timeout to set
     * @return this builder
     */
    public AbstractMLPlanBuilder withCandidateEvaluationTimeOut(TimeOut timeOut) {
        this.algorithmConfig.setProperty("hasco.random_completions.timeout_path", String.valueOf(timeOut.milliseconds()));
        update();
        return this;
    }

    /**
     * @return the configured candidate evaluation timeout, expressed in milliseconds
     */
    public TimeOut getCandidateEvaluationTimeOut() {
        return new TimeOut(this.algorithmConfig.timeoutForCandidateEvaluation(), TimeUnit.MILLISECONDS);
    }

    /**
     * Creates the pipeline evaluator for the search phase from the configured search phase
     * evaluator factory. If the created evaluator is a {@link LearningCurveExtrapolationEvaluator},
     * it is additionally told the full dataset size.
     *
     * @param instances the data passed to the evaluator factory
     * @param i second argument passed to the evaluator factory
     * @param i2 the full dataset size handed to a learning curve extrapolation evaluator
     * @return a pipeline evaluator using the candidate evaluation timeout of the algorithm configuration
     * @throws ClassifierEvaluatorConstructionFailedException if the factory cannot create the evaluator
     * @throws NullPointerException if no search phase evaluator factory has been set
     */
    @Override // ai.libs.mlplan.core.IMLPlanBuilder
    public PipelineEvaluator getClassifierEvaluationInSearchPhase(Instances instances, int i, int i2) throws ClassifierEvaluatorConstructionFailedException {
        Objects.requireNonNull(this.factoryForPipelineEvaluationInSearchPhase, "No factory for pipeline evaluation in search phase has been set!");
        // NOTE(review): the local is declared with the concrete evaluator type although it is
        // guarded by instanceof below; presumably the factory returns a broader type and this
        // declaration should be that type plus a cast inside the guard - confirm against the
        // factory's signature before changing.
        LearningCurveExtrapolationEvaluator iClassifierEvaluator = this.factoryForPipelineEvaluationInSearchPhase.getIClassifierEvaluator(instances, i);
        if (iClassifierEvaluator instanceof LearningCurveExtrapolationEvaluator) {
            iClassifierEvaluator.setFullDatasetSize(i2);
        }
        return new PipelineEvaluator(getClassifierFactory(), iClassifierEvaluator, getAlgorithmConfig().timeoutForCandidateEvaluation());
    }

    /**
     * Creates the pipeline evaluator for the selection phase from the configured selection
     * phase evaluator factory. The evaluator is created with {@link Integer#MAX_VALUE} as its
     * third constructor argument.
     *
     * @param instances the data passed to the evaluator factory
     * @param i second argument passed to the evaluator factory
     * @return the pipeline evaluator for the selection phase
     * @throws ClassifierEvaluatorConstructionFailedException if the factory cannot create the evaluator
     * @throws IllegalStateException if no selection phase evaluator factory has been set
     */
    @Override // ai.libs.mlplan.core.IMLPlanBuilder
    public PipelineEvaluator getClassifierEvaluationInSelectionPhase(Instances instances, int i) throws ClassifierEvaluatorConstructionFailedException {
        if (this.factoryForPipelineEvaluationInSelectionPhase == null) {
            throw new IllegalStateException("No factory for pipeline evaluation in selection phase has been set!");
        }
        return new PipelineEvaluator(getClassifierFactory(), this.factoryForPipelineEvaluationInSelectionPhase.getIClassifierEvaluator(instances, i), Integer.MAX_VALUE);
    }

    /**
     * Sets the evaluator factory used in the search phase. Unlike the other {@code with...}
     * methods this one does not return the builder.
     *
     * @param iClassifierEvaluatorFactory the factory to use in the search phase
     */
    public void withSearchPhaseEvaluatorFactory(IClassifierEvaluatorFactory iClassifierEvaluatorFactory) {
        this.factoryForPipelineEvaluationInSearchPhase = iClassifierEvaluatorFactory;
    }

    /**
     * @return the evaluator factory used in the search phase, or {@code null} if none is set
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public IClassifierEvaluatorFactory getSearchEvaluatorFactory() {
        return this.factoryForPipelineEvaluationInSearchPhase;
    }

    /**
     * @param monteCarloCrossValidationEvaluatorFactory the factory to use in the selection phase
     * @return this builder
     */
    public AbstractMLPlanBuilder withSelectionPhaseEvaluatorFactory(MonteCarloCrossValidationEvaluatorFactory monteCarloCrossValidationEvaluatorFactory) {
        this.factoryForPipelineEvaluationInSelectionPhase = monteCarloCrossValidationEvaluatorFactory;
        return this;
    }

    /**
     * Writes the number of CPUs into the algorithm configuration and rebuilds the search
     * problem transformer.
     *
     * @param i the number of CPUs
     * @return this builder
     */
    public AbstractMLPlanBuilder withNumCpus(int i) {
        this.algorithmConfig.setProperty("cpus", String.valueOf(i));
        update();
        return this;
    }

    /**
     * @return the evaluator factory used in the selection phase, or {@code null} if none is set
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public IClassifierEvaluatorFactory getSelectionEvaluatorFactory() {
        return this.factoryForPipelineEvaluationInSelectionPhase;
    }

    /** @return the name of the performance measure, or {@code null} if none has been set */
    @Override // ai.libs.mlplan.core.IMLPlanBuilder
    public String getPerformanceMeasureName() {
        return this.performanceMeasureName;
    }

    /** @return the HASCO factory configured by this builder */
    @Override // ai.libs.mlplan.core.IMLPlanBuilder
    public HASCOFactory getHASCOFactory() {
        return this.hascoFactory;
    }

    /** @return the classifier factory, or {@code null} if none has been set */
    @Override // ai.libs.mlplan.core.IMLPlanBuilder
    public IClassifierFactory getClassifierFactory() {
        return this.classifierFactory;
    }

    /** @return the name of the logger currently used by this builder */
    public String getLoggerName() {
        return this.loggerName;
    }

    /**
     * Switches this builder to the logger with the given name.
     *
     * @param str the new logger name
     */
    public void setLoggerName(String str) {
        this.logger = LoggerFactory.getLogger(str);
        this.loggerName = str;
    }

    /** @return the name of the requested HASCO interface, or {@code null} if none has been set */
    @Override // ai.libs.mlplan.core.IMLPlanBuilder
    public String getRequestedInterface() {
        return this.requestedHASCOInterface;
    }

    /** @return the dataset splitter for the search/selection split, or {@code null} if none has been set */
    @Override // ai.libs.mlplan.core.IMLPlanBuilder
    public IDatasetSplitter getSearchSelectionDatasetSplitter() {
        return this.searchSelectionDatasetSplitter;
    }

    /** @return the search space configuration file, or {@code null} if none has been set */
    @Override // ai.libs.mlplan.core.IMLPlanBuilder
    public File getSearchSpaceConfigFile() {
        return this.searchSpaceFile;
    }

    /** @return the algorithm configuration currently in use */
    @Override // ai.libs.mlplan.core.IMLPlanBuilder
    public MLPlanClassifierConfig getAlgorithmConfig() {
        return this.algorithmConfig;
    }

    /** @return the value of the cache flag */
    @Override // ai.libs.mlplan.core.IMLPlanBuilder
    public boolean getUseCache() {
        return this.useCache;
    }

    /** @return the performance database adapter, or {@code null} if none is set */
    @Override // ai.libs.mlplan.core.IMLPlanBuilder
    public PerformanceDBAdapter getDBAdapter() {
        return this.dbAdapter;
    }

    /**
     * Determines the node evaluator that is finally used and rebuilds the search problem
     * transformer with it. A pipeline validity checking evaluator, if present, receives the
     * cached components and the given data and takes precedence over (or is chained in front
     * of) the preferred node evaluator.
     *
     * <p>The method does nothing unless the HASCO factory is a
     * {@link HASCOViaFDAndBestFirstFactory}.
     * NOTE(review): the factory field is initialised with a plain {@code HASCOViaFDFactory} and
     * never reassigned in this class, so this guard may never pass - confirm the type hierarchy.
     *
     * @param instances the data handed to the pipeline validity checking evaluator
     * @throws IllegalStateException if the factory has already been prepared with data
     */
    @Override // ai.libs.mlplan.core.IMLPlanBuilder
    public void prepareNodeEvaluatorInFactoryWithData(Instances instances) {
        if (!(this.hascoFactory instanceof HASCOViaFDAndBestFirstFactory)) {
            return;
        }
        if (this.factoryPreparedWithData) {
            throw new IllegalStateException("Factory has already been prepared with data. This can only be done once!");
        }
        this.factoryPreparedWithData = true;

        // Nothing to combine if neither evaluator is available.
        if (this.pipelineValidityCheckingNodeEvaluator == null && this.preferredNodeEvaluator == null) {
            return;
        }

        // Typed as the general evaluator interface: any of the three candidates below may be chosen.
        final INodeEvaluator<TFDNode, Double> actualNodeEvaluator;
        if (this.pipelineValidityCheckingNodeEvaluator == null) {
            actualNodeEvaluator = this.preferredNodeEvaluator;
        } else {
            this.pipelineValidityCheckingNodeEvaluator.setComponents(this.components);
            this.pipelineValidityCheckingNodeEvaluator.setData(instances);
            if (this.preferredNodeEvaluator != null) {
                actualNodeEvaluator = new AlternativeNodeEvaluator<>(this.pipelineValidityCheckingNodeEvaluator, this.preferredNodeEvaluator);
            } else {
                actualNodeEvaluator = this.pipelineValidityCheckingNodeEvaluator;
            }
        }
        this.preferredNodeEvaluator = actualNodeEvaluator;
        update();
    }

    /**
     * Rebuilds the search problem transformer of the HASCO factory from the current preferred
     * node evaluator, the prioritizing predicate and the algorithm configuration (random seed,
     * number of random completions, candidate and node evaluation timeouts), and then passes
     * the algorithm configuration to the HASCO factory.
     */
    private void update() {
        this.hascoFactory.setSearchProblemTransformer(
                new GraphSearchProblemInputToGraphSearchWithSubpathEvaluationInputTransformerViaRDFS(
                        this.preferredNodeEvaluator,
                        this.priorizingPredicate,
                        this.algorithmConfig.randomSeed(),
                        this.algorithmConfig.numberOfRandomCompletions(),
                        this.algorithmConfig.timeoutForCandidateEvaluation(),
                        this.algorithmConfig.timeoutForNodeEvaluation()));
        this.hascoFactory.withAlgorithmConfig(getAlgorithmConfig());
    }

    /**
     * Sets the dataset and builds ML-Plan.
     *
     * @param instances the dataset to run ML-Plan on
     * @return the configured {@link MLPlan} instance
     * @throws NullPointerException if {@code instances} is {@code null}
     */
    public MLPlan build(Instances instances) {
        this.dataset = instances;
        return build();
    }

    /**
     * Builds ML-Plan for the previously provided dataset and applies the configured overall
     * timeout to it.
     *
     * @return the configured {@link MLPlan} instance
     * @throws NullPointerException if no dataset has been provided
     */
    public MLPlan build() {
        Objects.requireNonNull(this.dataset, "A dataset needs to be provided as input to ML-Plan");
        MLPlan mlplan = new MLPlan(this, this.dataset);
        mlplan.setTimeout(getTimeOut());
        return mlplan;
    }
}
