package ai.libs.mlplan.core;

import ai.libs.hasco.builder.HASCOBuilder;
import ai.libs.hasco.builder.forwarddecomposition.HASCOViaFDAndBestFirstWithRandomCompletionsBuilder;
import ai.libs.jaicore.basic.MathExt;
import ai.libs.jaicore.basic.sets.Pair;
import ai.libs.jaicore.components.model.RefinementConfiguredSoftwareConfigurationProblem;
import ai.libs.jaicore.ml.core.evaluation.evaluator.factory.ISupervisedLearnerEvaluatorFactory;
import ai.libs.jaicore.ml.core.evaluation.evaluator.factory.LearnerEvaluatorConstructionFailedException;
import ai.libs.jaicore.planning.hierarchical.algorithms.forwarddecomposition.graphgenerators.tfd.TFDNode;
import ai.libs.jaicore.search.algorithms.standard.bestfirst.nodeevaluation.AlternativeNodeEvaluator;
import ai.libs.mlplan.multiclass.IMLPlanClassifierConfig;
import ai.libs.mlplan.safeguard.IEvaluationSafeGuardFactory;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Random;
import java.util.concurrent.TimeUnit;
import java.util.function.Predicate;
import org.api4.java.ai.graphsearch.problem.pathsearch.pathevaluation.IPathEvaluator;
import org.api4.java.ai.ml.core.IDataConfigurable;
import org.api4.java.ai.ml.core.dataset.splitter.IFoldSizeConfigurableRandomDatasetSplitter;
import org.api4.java.ai.ml.core.dataset.splitter.SplitFailedException;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance;
import org.api4.java.ai.ml.core.evaluation.IPredictionPerformanceMetricConfigurable;
import org.api4.java.ai.ml.core.evaluation.ISupervisedLearnerEvaluator;
import org.api4.java.ai.ml.core.evaluation.supervised.loss.IDeterministicPredictionPerformanceMeasure;
import org.api4.java.ai.ml.core.learner.ISupervisedLearner;
import org.api4.java.algorithm.Timeout;
import org.api4.java.algorithm.exceptions.AlgorithmException;
import org.api4.java.common.attributedobjects.IObjectEvaluator;
import org.api4.java.common.control.ILoggingCustomizable;
import org.api4.java.common.control.IRandomConfigurable;
import org.slf4j.Logger;

/* loaded from: input_file:ai/libs/mlplan/core/MLPlanUtil.class */
/**
 * Static helper routines used while assembling an ML-Plan run.
 *
 * <p>It provides routines for:
 * <ul>
 * <li>splitting the input data into search and selection data,</li>
 * <li>configuring the pipeline evaluators of both phases,</li>
 * <li>configuring the HASCO builder that drives the search.</li>
 * </ul>
 * The class holds no state and cannot be instantiated.
 */
abstract class MLPlanUtil {

    /** Prevents instantiation; this class only offers static helpers. */
    private MLPlanUtil() {
    }

    /**
     * Determines which data the search phase and the selection phase are run on.
     *
     * <p>If {@code dataPortionUsedForSelection} is positive, {@code dataset} is split by {@code splitter}.
     * The requested relative size of the first fold is {@code dataPortionUsedForSelection}. The second fold
     * becomes the search data, while the selection phase receives the <em>complete</em> {@code dataset}.
     * Otherwise no split takes place: the search receives the complete dataset and the selection data is
     * {@code null}.
     *
     * @param dataset the complete input data
     * @param dataPortionUsedForSelection relative size of the first fold requested from the splitter; values
     *        {@code <= 0} (and NaN) deactivate the selection phase
     * @param random random source handed to the splitter
     * @param splitter splitter used for the separation; required only if {@code dataPortionUsedForSelection > 0}
     * @param logger logger for debug output; its name is also used to derive the splitter's logger name
     * @return pair of (search data, selection data); the selection data is {@code null} if no split is done
     * @throws InterruptedException propagated from the splitter
     * @throws AlgorithmException if the splitter reports a {@link SplitFailedException}
     * @throws IllegalArgumentException if a split is requested but {@code splitter} is {@code null}
     * @throws IllegalStateException if the splitter returns fewer than two folds or folds of unexpected size,
     *         if the search data is empty, or if the search data is larger than the selection data
     */
    public static Pair<ILabeledDataset<?>, ILabeledDataset<?>> getDataForSearchAndSelection(final ILabeledDataset<?> dataset, final double dataPortionUsedForSelection, final Random random,
            final IFoldSizeConfigurableRandomDatasetSplitter<ILabeledDataset<?>> splitter, final Logger logger) throws InterruptedException, AlgorithmException {
        final ILabeledDataset<?> dataForSearch;
        final ILabeledDataset<?> dataForSelection;
        if (dataPortionUsedForSelection > 0.0d) {
            if (splitter == null) {
                throw new IllegalArgumentException("The builder does not specify a dataset splitter for the separation between search and selection phase data.");
            }
            logger.debug("Splitting given {} data points into search data ({}%) and selection data ({}%) with splitter {}.", dataset.size(), MathExt.round((1.0d - dataPortionUsedForSelection) * 100.0d, 2),
                    MathExt.round(dataPortionUsedForSelection * 100.0d, 2), splitter.getClass().getName());
            if (splitter instanceof ILoggingCustomizable) {
                ((ILoggingCustomizable) splitter).setLoggerName(logger.getName() + ".searchselectsplitter");
            }

            /* only the split itself can fail with a SplitFailedException, so only it is guarded */
            final List<ILabeledDataset<?>> split;
            try {
                split = splitter.split(dataset, random, new double[] { dataPortionUsedForSelection });
            } catch (SplitFailedException e) {
                throw new AlgorithmException("Error in ML-Plan execution.", e);
            }
            if (split == null || split.size() < 2) {
                throw new IllegalStateException("Invalid split produced by " + splitter.getClass().getName() + "! Expected 2 folds but got " + (split == null ? "null" : String.valueOf(split.size())) + ".");
            }

            /* fold 0 has the requested selection portion; fold 1 (the remainder) is used for search. A deviation of one point is tolerated for rounding. */
            final ILabeledDataset<?> selectionFold = split.get(0);
            final ILabeledDataset<?> searchFold = split.get(1);
            final int expectedSearchSize = (int) Math.round(dataset.size() * (1.0d - dataPortionUsedForSelection));
            final int expectedSelectionFoldSize = dataset.size() - expectedSearchSize;
            if (Math.abs(expectedSearchSize - searchFold.size()) > 1 || Math.abs(expectedSelectionFoldSize - selectionFold.size()) > 1) {
                throw new IllegalStateException("Invalid split produced by " + splitter.getClass().getName() + "! Split sizes are " + searchFold.size() + "/" + selectionFold.size() + " but expected sizes were "
                        + expectedSearchSize + "/" + expectedSelectionFoldSize);
            }
            dataForSearch = searchFold;
            /* the selection phase receives the complete data, not only fold 0; the size check below relies on this */
            dataForSelection = dataset;
            logger.debug("Search/Selection split completed. Using {} data points in search and {} in selection.", dataForSearch.size(), dataForSelection.size());
        } else {
            dataForSearch = dataset;
            dataForSelection = null;
            logger.debug("Selection phase de-activated. Not splitting the data and giving everything to the search.");
        }

        if (dataForSearch.isEmpty()) {
            throw new IllegalStateException("Cannot search on no data.");
        }
        if (dataForSelection != null && dataForSelection.size() < dataForSearch.size()) {
            throw new IllegalStateException("The search data (" + dataForSearch.size() + " data points) are bigger than the selection data (" + dataForSelection.size() + " data points)!");
        }
        return new Pair<>(dataForSearch, dataForSelection);
    }

    /**
     * Creates the pipeline evaluator for the search phase and, if applicable, the one for the selection phase.
     *
     * <p>Each evaluator factory is configured with its measure, the random source and its data. Each setting
     * is only applied if the factory supports it, which is checked via {@code instanceof}. The search evaluator
     * additionally receives a safe guard if {@code safeGuardFactory} is given. A selection evaluator is only
     * created if {@code dataShownToSelection} is not {@code null}.
     *
     * @param searchEvaluatorFactory factory for the learner evaluator used during search; must not be {@code null}
     * @param searchMeasure measure configured on the search factory (if supported)
     * @param selectionEvaluatorFactory factory for the learner evaluator used during selection; must not be
     *        {@code null} if {@code dataShownToSelection} is not {@code null}
     * @param selectionMeasure measure configured on the selection factory (if supported)
     * @param random random source configured on both factories (if supported)
     * @param dataShownToSearch data configured on the search factory (if supported)
     * @param dataShownToSelection data configured on the selection factory (if supported); {@code null} means
     *        that no selection evaluator is created
     * @param safeGuardFactory optional factory for a safe guard attached to the search evaluator; may be {@code null}
     * @param learnerFactory learner factory handed to both pipeline evaluators
     * @param timeoutForCandidateEvaluation timeout handed to both pipeline evaluators
     * @return pair of (search evaluator, selection evaluator); the second entry is {@code null} if
     *         {@code dataShownToSelection} is {@code null}
     * @throws InterruptedException if interrupted while building the safe guard (propagated unchanged)
     * @throws AlgorithmException if the safe guard cannot be built
     * @throws LearnerEvaluatorConstructionFailedException presumably if a factory cannot create its learner evaluator
     * @throws NullPointerException if a required evaluator factory is missing
     */
    public static Pair<PipelineEvaluator, PipelineEvaluator> getPipelineEvaluators(final ISupervisedLearnerEvaluatorFactory<ILabeledInstance, ILabeledDataset<? extends ILabeledInstance>> searchEvaluatorFactory,
            final IDeterministicPredictionPerformanceMeasure<?, ?> searchMeasure, final ISupervisedLearnerEvaluatorFactory<ILabeledInstance, ILabeledDataset<? extends ILabeledInstance>> selectionEvaluatorFactory,
            final IDeterministicPredictionPerformanceMeasure<?, ?> selectionMeasure, final Random random, final ILabeledDataset<?> dataShownToSearch, final ILabeledDataset<?> dataShownToSelection,
            final IEvaluationSafeGuardFactory safeGuardFactory, final ILearnerFactory<? extends ISupervisedLearner<ILabeledInstance, ILabeledDataset<? extends ILabeledInstance>>> learnerFactory,
            final Timeout timeoutForCandidateEvaluation) throws InterruptedException, AlgorithmException, LearnerEvaluatorConstructionFailedException {
        /* fail fast with a clear message instead of a bare NPE after part of the setup has already happened */
        Objects.requireNonNull(searchEvaluatorFactory, "No factory for search evaluation has been set!");
        if (dataShownToSelection != null) {
            Objects.requireNonNull(selectionEvaluatorFactory, "No factory for selection evaluation has been set, but selection data is given!");
        }

        /* configure the search factory */
        configureMeasureAndRandom(searchEvaluatorFactory, searchMeasure, random);
        if (searchEvaluatorFactory instanceof IDataConfigurable) {
            ((IDataConfigurable) searchEvaluatorFactory).setData(dataShownToSearch);
        }

        /* configure the selection factory; data is only set if there is a selection phase */
        configureMeasureAndRandom(selectionEvaluatorFactory, selectionMeasure, random);
        if (dataShownToSelection != null && selectionEvaluatorFactory instanceof IDataConfigurable) {
            ((IDataConfigurable) selectionEvaluatorFactory).setData(dataShownToSelection);
        }

        /* build the search evaluator, optionally protected by a safe guard */
        final ISupervisedLearnerEvaluator<ILabeledInstance, ILabeledDataset<? extends ILabeledInstance>> searchLearnerEvaluator = searchEvaluatorFactory.getLearnerEvaluator();
        final PipelineEvaluator searchEvaluator = new PipelineEvaluator(learnerFactory, searchLearnerEvaluator, timeoutForCandidateEvaluation);
        if (safeGuardFactory != null) {
            safeGuardFactory.withEvaluator(searchLearnerEvaluator);
            try {
                searchEvaluator.setSafeGuard(safeGuardFactory.build());
            } catch (InterruptedException e) {
                // propagate interruption as-is rather than disguising it as an AlgorithmException
                throw e;
            } catch (Exception e) {
                throw new AlgorithmException("Could not build safe guard.", e);
            }
        }

        /* the selection evaluator only exists if there is selection data */
        final PipelineEvaluator selectionEvaluator = dataShownToSelection != null
                ? new PipelineEvaluator(learnerFactory, selectionEvaluatorFactory.getLearnerEvaluator(), timeoutForCandidateEvaluation)
                : null;
        return new Pair<>(searchEvaluator, selectionEvaluator);
    }

    /**
     * Sets the measure and random source on an evaluator factory, as far as the factory supports them.
     *
     * @param factory the factory to configure; {@code null} is ignored (all {@code instanceof} checks fail)
     * @param measure measure to set if the factory is {@link IPredictionPerformanceMetricConfigurable}
     * @param random random source to set if the factory is {@link IRandomConfigurable}
     */
    private static void configureMeasureAndRandom(final ISupervisedLearnerEvaluatorFactory<ILabeledInstance, ILabeledDataset<? extends ILabeledInstance>> factory,
            final IDeterministicPredictionPerformanceMeasure<?, ?> measure, final Random random) {
        if (factory instanceof IPredictionPerformanceMetricConfigurable) {
            ((IPredictionPerformanceMetricConfigurable) factory).setMeasure(measure);
        }
        if (factory instanceof IRandomConfigurable) {
            ((IRandomConfigurable) factory).setRandom(random);
        }
    }

    /**
     * Creates and configures the HASCO builder (best-first search with random completions) for ML-Plan.
     *
     * <p>The preferred node evaluator is assembled as a chain of {@link AlternativeNodeEvaluator}s, in this order:
     * <ol>
     * <li>the validity checker, if given;</li>
     * <li>a {@link PreferenceBasedNodeEvaluator}, if the config names preferred components;</li>
     * <li>the additional {@code preferredNodeEvaluators}, in list order.</li>
     * </ol>
     *
     * @param algorithmConfig source of the preferred components, number of random completions, seed and timeouts
     * @param dataset data handed to the validity-checking node evaluator (if given)
     * @param searchSpaceFile search space description file
     * @param requestedHASCOInterface the interface that the composed solution must provide
     * @param priorizingPredicate predicate passed to {@code withPriorizingPredicate}
     * @param preferredNodeEvaluators further preferred node evaluators appended to the chain; {@code null} is treated as empty
     * @param pipelineValidityCheckingNodeEvaluator optional validity checker; may be {@code null}
     * @param nameOfMethod1 first HASCO method name for the preference evaluator; required iff preferred components are configured
     * @param nameOfMethod2 second HASCO method name for the preference evaluator; required iff preferred components are configured
     * @return the configured builder
     * @throws IllegalArgumentException if the search space file cannot be read
     * @throws NullPointerException if preferred components are configured but a method name is {@code null}
     */
    public static HASCOViaFDAndBestFirstWithRandomCompletionsBuilder getHASCOBuilder(final IMLPlanClassifierConfig algorithmConfig, final ILabeledDataset<?> dataset, final File searchSpaceFile,
            final String requestedHASCOInterface, final Predicate<TFDNode> priorizingPredicate, final List<IPathEvaluator<TFDNode, String, Double>> preferredNodeEvaluators,
            final PipelineValidityCheckingNodeEvaluator pipelineValidityCheckingNodeEvaluator, final String nameOfMethod1, final String nameOfMethod2) {
        try {
            final RefinementConfiguredSoftwareConfigurationProblem problem = new RefinementConfiguredSoftwareConfigurationProblem(searchSpaceFile, requestedHASCOInterface, (IObjectEvaluator) null);
            final HASCOViaFDAndBestFirstWithRandomCompletionsBuilder builder = HASCOBuilder.get(problem).withBestFirst().withRandomCompletions();

            /* validity and preference checks come first in the chain, followed by the externally supplied evaluators */
            final List<IPathEvaluator<TFDNode, String, Double>> evaluatorChain = new ArrayList<>();
            if (pipelineValidityCheckingNodeEvaluator != null) {
                pipelineValidityCheckingNodeEvaluator.setComponents(problem.getComponents());
                pipelineValidityCheckingNodeEvaluator.setData(dataset);
                evaluatorChain.add(pipelineValidityCheckingNodeEvaluator);
            }
            if (algorithmConfig.preferredComponents() != null && !algorithmConfig.preferredComponents().isEmpty()) {
                Objects.requireNonNull(nameOfMethod1, "First HASCO method must not be null!");
                Objects.requireNonNull(nameOfMethod2, "Second HASCO method must not be null!");
                evaluatorChain.add(new PreferenceBasedNodeEvaluator(problem.getComponents(), algorithmConfig.preferredComponents(), nameOfMethod1, nameOfMethod2));
            }
            if (preferredNodeEvaluators != null) {
                evaluatorChain.addAll(preferredNodeEvaluators);
            }

            /* fold the chain left-to-right: ((e0, e1), e2), ... */
            if (!evaluatorChain.isEmpty()) {
                IPathEvaluator<TFDNode, String, Double> combined = evaluatorChain.get(0);
                for (final IPathEvaluator<TFDNode, String, Double> next : evaluatorChain.subList(1, evaluatorChain.size())) {
                    combined = new AlternativeNodeEvaluator<>(combined, next);
                }
                builder.withPreferredNodeEvaluator(combined);
            }

            builder.withNumSamples(algorithmConfig.numberOfRandomCompletions());
            builder.withSeed(algorithmConfig.seed());
            builder.withTimeoutForNode(new Timeout(algorithmConfig.timeoutForNodeEvaluation(), TimeUnit.MILLISECONDS));
            builder.withTimeoutForSingleEvaluation(new Timeout(algorithmConfig.timeoutForCandidateEvaluation(), TimeUnit.MILLISECONDS));
            builder.withPriorizingPredicate(priorizingPredicate);
            return builder;
        } catch (IOException e) {
            throw new IllegalArgumentException("Invalid configuration file " + searchSpaceFile, e);
        }
    }
}
