package ai.libs.mlplan.metamining.dyadranking;

import ai.libs.hasco.core.HASCOUtil;
import ai.libs.jaicore.basic.algorithm.AlgorithmInitializedEvent;
import ai.libs.jaicore.basic.sets.Pair;
import ai.libs.jaicore.components.api.IComponent;
import ai.libs.jaicore.components.api.IComponentInstance;
import ai.libs.jaicore.components.api.IComponentRepository;
import ai.libs.jaicore.components.model.ComponentInstance;
import ai.libs.jaicore.math.linearalgebra.DenseDoubleVector;
import ai.libs.jaicore.ml.classification.loss.dataset.EClassificationPerformanceMeasure;
import ai.libs.jaicore.ml.core.evaluation.evaluator.FixedSplitClassifierEvaluator;
import ai.libs.jaicore.ml.ranking.dyad.dataset.DyadRankingDataset;
import ai.libs.jaicore.ml.ranking.dyad.dataset.SparseDyadRankingInstance;
import ai.libs.jaicore.ml.ranking.dyad.learner.algorithm.PLNetDyadRanker;
import ai.libs.jaicore.ml.ranking.dyad.learner.util.DyadMinMaxScaler;
import ai.libs.jaicore.ml.weka.WekaUtil;
import ai.libs.jaicore.ml.weka.dataset.WekaInstances;
import ai.libs.jaicore.planning.hierarchical.algorithms.forwarddecomposition.graphgenerators.tfd.TFDNode;
import ai.libs.jaicore.search.algorithms.standard.bestfirst.events.EvaluatedSearchSolutionCandidateFoundEvent;
import ai.libs.jaicore.search.algorithms.standard.bestfirst.events.FValueEvent;
import ai.libs.jaicore.search.algorithms.standard.bestfirst.nodeevaluation.RandomizedDepthFirstNodeEvaluator;
import ai.libs.jaicore.search.algorithms.standard.gbf.SolutionEventBus;
import ai.libs.jaicore.search.algorithms.standard.random.RandomSearch;
import ai.libs.jaicore.search.model.other.EvaluatedSearchGraphPath;
import ai.libs.jaicore.search.model.other.SearchGraphPath;
import ai.libs.jaicore.search.probleminputs.GraphSearchWithSubpathEvaluationsInput;
import ai.libs.mlplan.core.ILearnerFactory;
import ai.libs.mlplan.metamining.pipelinecharacterizing.ComponentInstanceVectorFeatureGenerator;
import ai.libs.mlplan.metamining.pipelinecharacterizing.IPipelineCharacterizer;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.lang.Comparable;
import java.nio.file.Paths;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Random;
import java.util.concurrent.CompletionService;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.aeonbits.owner.ConfigFactory;
import org.apache.commons.collections.BidiMap;
import org.apache.commons.collections.bidimap.DualHashBidiMap;
import org.api4.java.ai.graphsearch.problem.IPathSearchInput;
import org.api4.java.ai.graphsearch.problem.implicit.graphgenerator.IPathGoalTester;
import org.api4.java.ai.graphsearch.problem.pathsearch.pathevaluation.IPotentiallyGraphDependentPathEvaluator;
import org.api4.java.ai.graphsearch.problem.pathsearch.pathevaluation.IPotentiallySolutionReportingPathEvaluator;
import org.api4.java.ai.graphsearch.problem.pathsearch.pathevaluation.PathEvaluationException;
import org.api4.java.ai.ml.classification.IClassifier;
import org.api4.java.ai.ml.core.dataset.splitter.SplitFailedException;
import org.api4.java.ai.ml.core.exception.PredictionException;
import org.api4.java.ai.ml.core.learner.ISupervisedLearner;
import org.api4.java.ai.ml.ranking.IRanking;
import org.api4.java.ai.ml.ranking.dyad.dataset.IDyad;
import org.api4.java.algorithm.IAlgorithm;
import org.api4.java.algorithm.exceptions.AlgorithmExecutionCanceledException;
import org.api4.java.common.attributedobjects.IObjectEvaluator;
import org.api4.java.common.math.IVector;
import org.api4.java.datastructure.graph.ILabeledPath;
import org.api4.java.datastructure.graph.implicit.IGraphGenerator;
import org.openml.webapplication.fantail.dc.LandmarkerCharacterizer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.core.Instances;

/* loaded from: input_file:ai/libs/mlplan/metamining/dyadranking/DyadRankingBasedNodeEvaluator.class */
/**
 * Path evaluator that estimates the quality of a (partial) search path via dyad ranking.
 *
 * <p>For a path to evaluate, the evaluator:
 * <ol>
 * <li>draws up to {@code numberOfRandomSamples} random completions of the path with a {@link RandomSearch};</li>
 * <li>turns each completion into a {@link ComponentInstance} (pipeline);</li>
 * <li>characterizes the pipelines as feature vectors, optionally extended by landmarker error rates;</li>
 * <li>ranks the pipelines with a pre-trained {@link PLNetDyadRanker} against the dataset meta features;</li>
 * <li>evaluates the top {@code numberOfEvaluations} pipelines with the configured pipeline evaluator.</li>
 * </ol>
 * The smallest observed score is returned as the path's f-value.
 *
 * <p>Expected setup before {@link #evaluate}: {@link #setPipelineEvaluator}, {@link #setDataset},
 * {@link #setGenerator} and, when landmarkers are enabled, {@link #setClassifierFactory}.
 *
 * @param <T> node type of the search graph (completed paths are expected to end in a {@link TFDNode})
 * @param <A> arc type of the search graph
 * @param <V> score type; the fallback score is a {@link Double}, so V is presumably Double — TODO confirm
 */
public class DyadRankingBasedNodeEvaluator<T, A, V extends Comparable<V>> implements IPotentiallyGraphDependentPathEvaluator<T, A, V>, IPotentiallySolutionReportingPathEvaluator<T, A, V> {

    private static final Logger logger = LoggerFactory.getLogger(DyadRankingBasedNodeEvaluator.class);

    /** Score reported when no pipeline below a path could be completed, ranked or evaluated. */
    private static final double FALLBACK_SCORE = 9000.0d;

    /** Seconds to wait for the next finished pipeline evaluation (and for its result). */
    private static final long EVALUATION_POLL_TIMEOUT_SECONDS = 20L;

    /** Seed of the stratified split used to set up landmarking. */
    private static final long LANDMARKER_SPLIT_SEED = 42L;

    /** Fraction of the dataset that forms the pool from which landmarker training sets are sampled. */
    private static final double LANDMARKER_SPLIT_RATIO = 0.8d;

    /** Bidirectional map between completed search paths (node lists) and the pipelines they encode. */
    private final BidiMap pathToPipelines = new DualHashBidiMap();
    private RandomSearch<T, A> randomPathCompleter;
    private IObjectEvaluator<ComponentInstance, V> pipelineEvaluator;
    private final Collection<IComponent> components;
    private final int randomlyCompletedPaths;
    private Instances evaluationDataset;
    private double[] datasetMetaFeatures;
    private final int evaluatedPaths;
    private final Random random;
    private final PLNetDyadRanker dyadRanker = new PLNetDyadRanker();
    private final IPipelineCharacterizer characterizer;
    private final int landmarkerSampleSize;
    private final int[] landmarkers;
    private Instances[][] landmarkerSets;
    private ILearnerFactory<IClassifier> classifierFactory;
    private final boolean useLandmarkers;
    private Instant firstEvaluation;
    private final SolutionEventBus<T> eventBus = new SolutionEventBus<>();
    private IGraphGenerator<T, A> graphGenerator;
    private IPathGoalTester<T, A> goalTester;
    /** Optional scaler for pipeline feature vectors; stays {@code null} if it could not be loaded. */
    private DyadMinMaxScaler scaler;

    /**
     * Sets the factory used to instantiate classifiers when computing landmarker scores.
     *
     * @param iLearnerFactory factory turning a component instance into a classifier
     */
    public void setClassifierFactory(ILearnerFactory<IClassifier> iLearnerFactory) {
        this.classifierFactory = iLearnerFactory;
    }

    /**
     * Creates an evaluator whose configuration is obtained from
     * {@code ConfigFactory.create(DyadRankingBasedNodeEvaluatorConfig.class)}.
     *
     * @param iComponentRepository components from which pipelines are composed
     */
    public DyadRankingBasedNodeEvaluator(IComponentRepository iComponentRepository) {
        this(iComponentRepository, ConfigFactory.create(DyadRankingBasedNodeEvaluatorConfig.class));
    }

    /**
     * Creates an evaluator and loads the PLNet model and the feature scaler from the configured paths.
     *
     * <p>Load failures are logged, not thrown. A missing scaler simply disables scaling.
     *
     * @param iComponentRepository components from which pipelines are composed
     * @param config seed, sample/evaluation counts, landmarker settings and model/scaler paths
     */
    public DyadRankingBasedNodeEvaluator(IComponentRepository iComponentRepository, DyadRankingBasedNodeEvaluatorConfig config) {
        this.components = iComponentRepository;
        this.random = new Random(config.getSeed());
        this.evaluatedPaths = config.getNumberOfEvaluations();
        this.randomlyCompletedPaths = config.getNumberOfRandomSamples();
        logger.debug("Initialized DyadRankingBasedNodeEvaluator with evalNum: {} and completionNum: {}", this.evaluatedPaths, this.randomlyCompletedPaths);
        this.characterizer = new ComponentInstanceVectorFeatureGenerator(iComponentRepository);
        this.landmarkers = config.getLandmarkers();
        this.landmarkerSampleSize = config.getLandmarkerSampleSize();
        this.useLandmarkers = config.useLandmarkers();
        this.loadDyadRankerModel(config.getPlNetPath());
        this.scaler = loadScaler(config.scalerPath());
    }

    /** Loads the PLNet model into {@link #dyadRanker}; a failure is logged and leaves the ranker untrained. */
    private void loadDyadRankerModel(String plNetPath) {
        try {
            this.dyadRanker.loadModelFromFile(Paths.get(plNetPath).toString());
        } catch (IOException e) {
            logger.error("Could not load model for plnet in {}", Paths.get(plNetPath), e);
        }
    }

    /**
     * Deserializes the feature scaler from the given file.
     *
     * @return the scaler, or {@code null} if it could not be read
     */
    private static DyadMinMaxScaler loadScaler(String scalerPath) {
        // NOTE(review): Java-native deserialization; scalerPath must point to a trusted file.
        // Both streams are declared as resources so the file is closed even if the ObjectInputStream constructor fails.
        try (FileInputStream fileIn = new FileInputStream(Paths.get(scalerPath).toFile());
                ObjectInputStream objectIn = new ObjectInputStream(fileIn)) {
            return (DyadMinMaxScaler) objectIn.readObject();
        } catch (IOException e) {
            logger.error("Could not load scaler for plnet in {}", Paths.get(scalerPath), e);
        } catch (ClassNotFoundException e) {
            logger.error("Could not read scaler.", e);
        }
        return null;
    }

    /**
     * Estimates the f-value of a path by ranking random completions and evaluating the best-ranked ones.
     *
     * @param iLabeledPath path to evaluate
     * @return {@code null} if the path already is a goal path or if the evaluation of the candidates
     *         failed; the fallback score 9000 if no pipeline could be completed, ranked or scored;
     *         otherwise the smallest score among the evaluated pipelines
     * @throws InterruptedException if path completion or candidate evaluation was interrupted or timed
     *         out (the thread's interrupt flag is cleared, the original failure is attached as cause)
     * @throws PathEvaluationException if the dyad ranker failed to rank the pipelines
     * @throws IllegalStateException if {@link #setGenerator} has not been called yet
     */
    @Override
    @SuppressWarnings("unchecked")
    public V evaluate(ILabeledPath<T, A> iLabeledPath) throws InterruptedException, PathEvaluationException {
        if (this.randomPathCompleter == null) {
            throw new IllegalStateException("setGenerator must be called before paths can be evaluated.");
        }
        if (this.firstEvaluation == null) {
            this.firstEvaluation = Instant.now();
        }
        if (((IPathSearchInput<T, A>) this.randomPathCompleter.getInput()).getGoalTester().isGoal(iLabeledPath)) {
            return null;
        }
        Instant startOfEvaluation = Instant.now();

        // The completer can only complete paths whose nodes it already knows.
        if (!this.randomPathCompleter.knowsNode(iLabeledPath.getHead())) {
            synchronized (this.randomPathCompleter) {
                this.randomPathCompleter.appendPathToNode(iLabeledPath);
            }
        }

        List<List<T>> randomPaths;
        try {
            randomPaths = this.getNRandomPaths(iLabeledPath);
        } catch (InterruptedException | TimeoutException e) {
            logger.error("Interrupted in path completion!");
            throw asInterruptedException(e);
        }

        List<ComponentInstance> rankedPipelines;
        try {
            rankedPipelines = this.getDyadRankedPaths(randomPaths);
        } catch (PredictionException e) {
            throw new PathEvaluationException("Could not rank nodes", e);
        }
        if (rankedPipelines.isEmpty()) {
            return this.fallbackScore();
        }

        int numberOfEvaluations = Math.max(0, Math.min(this.evaluatedPaths, rankedPipelines.size()));
        List<Pair<ComponentInstance, V>> evaluatedSolutions;
        try {
            evaluatedSolutions = this.evaluateTopKPaths(rankedPipelines.subList(0, numberOfEvaluations));
        } catch (InterruptedException e) {
            logger.error("Interrupted while predicting next best solution");
            throw asInterruptedException(e);
        }

        Duration evaluationTime = Duration.between(startOfEvaluation, Instant.now());
        logger.info("Evaluation took {}ms", evaluationTime.toMillis());
        V bestSolution = this.getBestSolution(evaluatedSolutions);
        logger.info("Best solution is {}, {}", bestSolution, evaluatedSolutions.stream().map(Pair::getY).collect(Collectors.toList()));
        if (bestSolution == null) {
            return this.fallbackScore();
        }
        this.eventBus.post(new FValueEvent((IAlgorithm) null, bestSolution, evaluationTime.toMillis()));
        return bestSolution;
    }

    /** Returns {@link #FALLBACK_SCORE} as {@code V}; assumes V is Double (callers otherwise get a ClassCastException). */
    @SuppressWarnings("unchecked")
    private V fallbackScore() {
        return (V) (Object) Double.valueOf(FALLBACK_SCORE);
    }

    /**
     * Wraps a failure that aborted the evaluation into an {@link InterruptedException} with the failure as cause.
     * The interrupt flag is cleared, following the convention for a thrown InterruptedException.
     */
    private static InterruptedException asInterruptedException(Exception cause) {
        Thread.interrupted();
        InterruptedException interruption = new InterruptedException();
        interruption.initCause(cause);
        return interruption;
    }

    /**
     * Draws up to {@link #randomlyCompletedPaths} random completions of the given path.
     *
     * <p>Stops early if the completer was canceled or finds no further completion.
     *
     * @return full node lists, each consisting of the given path followed by its completion
     */
    private List<List<T>> getNRandomPaths(ILabeledPath<T, A> iLabeledPath) throws InterruptedException, TimeoutException {
        List<List<T>> completedPaths = new ArrayList<>();
        for (int i = 0; i < this.randomlyCompletedPaths; i++) {
            List<T> completedPath;
            synchronized (this.randomPathCompleter) {
                if (this.randomPathCompleter.isCanceled()) {
                    logger.info("Completer has been canceled (perhaps due a cancel on the evaluator). Canceling RDFS");
                    break;
                }
                SearchGraphPath<T, A> completion;
                try {
                    completion = this.randomPathCompleter.nextSolutionUnderSubPath(iLabeledPath);
                } catch (AlgorithmExecutionCanceledException e) {
                    logger.info("Completer has been canceled. Returning control.");
                    break;
                }
                if (completion == null) {
                    logger.info("No completion was found for path {}.", iLabeledPath.getNodes());
                    break;
                }
                completedPath = new ArrayList<>(iLabeledPath.getNodes());
                List<T> completionNodes = completion.getNodes();
                // The completion starts at the head of the given path, which is already contained.
                if (!completionNodes.isEmpty()) {
                    completedPath.addAll(completionNodes.subList(1, completionNodes.size()));
                }
            }
            completedPaths.add(completedPath);
        }
        logger.info("Returning {} paths", completedPaths.size());
        return completedPaths;
    }

    /**
     * Converts completed paths into pipelines, remembers the path of each pipeline, and returns the pipelines ranked
     * by the dyad ranker (best first, presumably — depends on the ranker's ordering; TODO confirm).
     */
    private List<ComponentInstance> getDyadRankedPaths(List<List<T>> completedPaths) throws PredictionException, InterruptedException {
        Map<IVector, ComponentInstance> alternativeToPipeline = new HashMap<>();
        for (List<T> completedPath : completedPaths) {
            TFDNode goalNode = (TFDNode) completedPath.get(completedPath.size() - 1);
            ComponentInstance pipeline = (ComponentInstance) HASCOUtil.getSolutionCompositionFromState(this.components, goalNode.getState(), true);
            this.pathToPipelines.put(completedPath, pipeline);
            IVector alternative = this.useLandmarkers ? this.evaluateLandmarkersForAlgorithm(pipeline) : this.characterizePipeline(pipeline);
            // Pipelines with identical feature vectors collapse into one entry (the last one wins).
            alternativeToPipeline.put(alternative, pipeline);
        }
        return this.rankRandomPipelines(alternativeToPipeline);
    }

    /** Characterizes a pipeline and, if a scaler is available, applies it to the resulting vector. */
    private IVector characterizePipeline(ComponentInstance pipeline) {
        IVector features = new DenseDoubleVector(this.characterizer.characterize(pipeline));
        if (this.scaler != null) {
            // NOTE(review): relies on transformAlternatives scaling the wrapped vector in place — confirm.
            this.scaler.transformAlternatives(new DyadRankingDataset(Arrays.asList(new SparseDyadRankingInstance(new DenseDoubleVector(this.datasetMetaFeatures), Arrays.asList(features)))));
        }
        return features;
    }

    /**
     * Builds the pipeline's feature vector extended by one averaged error rate per landmarker size.
     *
     * <p>Each error rate is the mean error of the pipeline trained on each sampled landmarker set and tested on
     * {@link #evaluationDataset}. A failed evaluation is logged and contributes 0 to the sum.
     */
    private IVector evaluateLandmarkersForAlgorithm(ComponentInstance componentInstance) throws InterruptedException {
        double[] pipelineFeatures = this.characterizer.characterize(componentInstance);
        double[] features = new double[this.characterizer.getLengthOfCharacterization() + this.landmarkers.length];
        System.arraycopy(pipelineFeatures, 0, features, 0, pipelineFeatures.length);
        for (int i = 0; i < this.landmarkers.length; i++) {
            Instances[] trainingSets = this.landmarkerSets[i];
            double summedErrorRate = 0.0d;
            for (Instances trainingSet : trainingSets) {
                try {
                    FixedSplitClassifierEvaluator evaluator = new FixedSplitClassifierEvaluator(new WekaInstances(trainingSet), new WekaInstances(this.evaluationDataset), EClassificationPerformanceMeasure.ERRORRATE);
                    summedErrorRate += evaluator.evaluate((ISupervisedLearner) this.classifierFactory.getComponentInstantiation(componentInstance)).doubleValue();
                } catch (InterruptedException e) {
                    throw e;
                } catch (Exception e) {
                    logger.error("Couldn't get classifier for {}", componentInstance, e);
                }
            }
            features[pipelineFeatures.length + i] = trainingSets.length == 0 ? 0.0d : summedErrorRate / trainingSets.length;
        }
        return new DenseDoubleVector(features);
    }

    /** Ranks the alternatives against the dataset meta features and maps the ranked alternatives back to pipelines. */
    private List<ComponentInstance> rankRandomPipelines(Map<IVector, ComponentInstance> alternativeToPipeline) throws PredictionException, InterruptedException {
        IRanking ranking = this.dyadRanker.predict(new SparseDyadRankingInstance(new DenseDoubleVector(this.datasetMetaFeatures), new ArrayList<>(alternativeToPipeline.keySet())));
        List<ComponentInstance> rankedPipelines = new ArrayList<>();
        Iterator<?> rankedDyads = ranking.iterator();
        while (rankedDyads.hasNext()) {
            // Lookup relies on the alternative vector's equals/hashCode matching the map key — TODO confirm.
            rankedPipelines.add(alternativeToPipeline.get(((IDyad) rankedDyads.next()).getAlternative()));
        }
        return rankedPipelines;
    }

    /**
     * Evaluates the given pipelines sequentially on a dedicated worker thread.
     *
     * <p>The method waits at most {@link #EVALUATION_POLL_TIMEOUT_SECONDS} per result. The worker is shut down on
     * return, which interrupts any evaluation that is still running.
     *
     * @return the successfully evaluated pipelines with their scores, in completion order
     */
    private List<Pair<ComponentInstance, V>> evaluateTopKPaths(List<ComponentInstance> pipelines) throws InterruptedException {
        ExecutorService executor = Executors.newFixedThreadPool(1);
        try {
            CompletionService<Pair<ComponentInstance, V>> completionService = new ExecutorCompletionService<>(executor);
            for (ComponentInstance pipeline : pipelines) {
                completionService.submit(() -> this.evaluatePipeline(pipeline));
            }
            List<Pair<ComponentInstance, V>> evaluatedSolutions = new ArrayList<>();
            for (int i = 0; i < pipelines.size(); i++) {
                logger.info("Got {} solutions. Waiting for iteration {} of max iterations {}", evaluatedSolutions.size(), i + 1, pipelines.size());
                Future<Pair<ComponentInstance, V>> future = completionService.poll(EVALUATION_POLL_TIMEOUT_SECONDS, TimeUnit.SECONDS);
                if (future == null) {
                    logger.info("Didn't receive any futures (expected {} futures)", pipelines.size());
                    continue;
                }
                try {
                    Pair<ComponentInstance, V> solution = future.get(EVALUATION_POLL_TIMEOUT_SECONDS, TimeUnit.SECONDS);
                    if (solution != null) {
                        logger.info("Evaluation was successful. Adding {} to solutions", solution.getY());
                        evaluatedSolutions.add(solution);
                    } else {
                        logger.info("No solution was found while waiting up to 20s.");
                        future.cancel(true);
                    }
                } catch (InterruptedException e) {
                    throw e;
                } catch (Exception e) {
                    logger.info("Got exception while evaluating {}", e.getMessage(), e);
                }
            }
            return evaluatedSolutions;
        } finally {
            // Without this, every evaluated node would leak one pool thread (and unfinished evaluations would keep running).
            executor.shutdownNow();
        }
    }

    /**
     * Evaluates one pipeline and posts it as a solution candidate.
     *
     * @return the pipeline with its score, or {@code null} if the evaluation failed
     */
    private Pair<ComponentInstance, V> evaluatePipeline(ComponentInstance pipeline) {
        try {
            Instant start = Instant.now();
            V score = this.pipelineEvaluator.evaluate(pipeline);
            this.postSolution(pipeline, Duration.between(start, Instant.now()).toMillis(), score);
            return new Pair<>(pipeline, score);
        } catch (Exception e) {
            logger.error("Couldn't evaluate {}", pipeline, e);
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            return null;
        }
    }

    /** Returns the smallest non-null score, or {@code null} if there is none. */
    private V getBestSolution(List<Pair<ComponentInstance, V>> evaluatedSolutions) {
        return evaluatedSolutions.stream().map(Pair::getY).filter(Objects::nonNull).min(Comparator.naturalOrder()).orElse(null);
    }

    /**
     * Stores the search graph and goal tester and (re)creates the random path completer over them.
     *
     * @param iGraphGenerator graph generator of the search
     * @param iPathGoalTester goal tester of the search
     */
    @Override
    public void setGenerator(IGraphGenerator<T, A> iGraphGenerator, IPathGoalTester<T, A> iPathGoalTester) {
        this.graphGenerator = iGraphGenerator;
        this.goalTester = iPathGoalTester;
        this.initializeRandomSearch();
    }

    /** Creates the randomized depth-first completer and steps it until it has emitted its initialization event. */
    private void initializeRandomSearch() {
        this.randomPathCompleter = new RandomSearch<>(new GraphSearchWithSubpathEvaluationsInput(this.graphGenerator, this.goalTester, new RandomizedDepthFirstNodeEvaluator(this.random)), (Predicate) null, this.random);
        while (!(this.randomPathCompleter.next() instanceof AlgorithmInitializedEvent)) {
            // events before initialization is complete are irrelevant here
        }
    }

    /**
     * Computes the dataset meta features used by the dyad ranker.
     *
     * <p>If landmarkers are enabled, the method also does the following:
     * <ul>
     * <li>splits the dataset stratified (80/20, fixed seed) into a sampling pool and {@link #evaluationDataset};</li>
     * <li>samples the landmarker training sets from that pool.</li>
     * </ul>
     *
     * <p>An interruption restores the thread's interrupt flag. Other characterization failures are logged.
     *
     * @param instances the dataset to characterize
     * @throws IllegalArgumentException if the stratified split fails
     */
    public void setDataset(Instances instances) {
        try {
            if (this.useLandmarkers) {
                List<?> split = WekaUtil.getStratifiedSplit(instances, LANDMARKER_SPLIT_SEED, LANDMARKER_SPLIT_RATIO);
                Instances samplingPool = (Instances) split.get(0);
                this.evaluationDataset = (Instances) split.get(1);
                this.datasetMetaFeatures = computeMetaFeatures(instances);
                this.setUpLandmarkingDatasets(instances, samplingPool);
            } else {
                this.datasetMetaFeatures = computeMetaFeatures(instances);
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        } catch (SplitFailedException e) {
            throw new IllegalArgumentException("Could not split the dataset for landmarking.", e);
        } catch (Exception e) {
            logger.error("Failed to characterize the dataset", e);
        }
    }

    /** Returns the landmarker meta features of the dataset in the iteration order of the characterizer's map. */
    private static double[] computeMetaFeatures(Instances instances) throws Exception {
        return new LandmarkerCharacterizer().characterize(instances).values().stream().mapToDouble(Double::doubleValue).toArray();
    }

    /**
     * Samples {@link #landmarkerSampleSize} training sets per landmarker size.
     *
     * <p>Each training set is drawn with replacement from the sampling pool, and its size equals the landmarker
     * value.
     *
     * @param header dataset whose structure the samples copy
     * @param samplingPool instances to sample from
     */
    private void setUpLandmarkingDatasets(Instances header, Instances samplingPool) {
        this.landmarkerSets = new Instances[this.landmarkers.length][this.landmarkerSampleSize];
        for (int i = 0; i < this.landmarkers.length; i++) {
            int sampleSize = this.landmarkers[i];
            for (int j = 0; j < this.landmarkerSampleSize; j++) {
                Instances sample = new Instances(header, sampleSize);
                for (int k = 0; k < sampleSize; k++) {
                    sample.add(samplingPool.get(this.random.nextInt(samplingPool.size())));
                }
                this.landmarkerSets[i][j] = sample;
            }
        }
    }

    /**
     * Posts an evaluated pipeline, together with the search path it came from, to the solution event bus.
     *
     * <p>Failures while posting are logged and swallowed.
     *
     * @param componentInstance the evaluated pipeline
     * @param j evaluation time in milliseconds
     * @param v the pipeline's score
     */
    protected void postSolution(ComponentInstance componentInstance, long j, V v) {
        try {
            EvaluatedSearchGraphPath solutionPath = new EvaluatedSearchGraphPath((List) this.pathToPipelines.getKey(componentInstance), (List) null, v);
            solutionPath.setAnnotation("fTime", j);
            solutionPath.setAnnotation("timeToSolution", Duration.between(this.firstEvaluation, Instant.now()).toMillis());
            solutionPath.setAnnotation("nodesEvaluatedToSolution", this.randomlyCompletedPaths);
            logger.debug("Posting solution {}", solutionPath);
            this.eventBus.post(new EvaluatedSearchSolutionCandidateFoundEvent((IAlgorithm) null, solutionPath));
        } catch (Exception e) {
            logger.error("Couldn't post solution to event bus.", e);
        }
    }

    /**
     * Sets the evaluator used to score the top-ranked pipelines.
     *
     * @param iObjectEvaluator pipeline evaluator; smaller scores are considered better
     */
    public void setPipelineEvaluator(IObjectEvaluator<ComponentInstance, V> iObjectEvaluator) {
        this.pipelineEvaluator = iObjectEvaluator;
    }

    /** Always {@code true}: the graph generator is needed to complete paths randomly. */
    @Override
    public boolean requiresGraphGenerator() {
        return true;
    }

    /**
     * Registers a listener on the internal solution event bus.
     *
     * @param obj listener object
     */
    @Override
    public void registerSolutionListener(Object obj) {
        this.eventBus.register(obj);
    }

    /** Always {@code true}: evaluated pipelines are posted as solution candidates. */
    @Override
    public boolean reportsSolutions() {
        return true;
    }
}
