package ai.libs.hasco.knowledgebase;

import ai.libs.hasco.core.Util;
import ai.libs.hasco.model.Component;
import ai.libs.hasco.model.ComponentInstance;
import ai.libs.jaicore.ml.weka.rangequery.learner.intervaltree.ExtendedRandomForest;
import ai.libs.jaicore.ml.weka.rangequery.learner.intervaltree.featurespace.FeatureDomain;
import ai.libs.jaicore.ml.weka.rangequery.learner.intervaltree.featurespace.FeatureSpace;
import com.google.common.collect.Sets;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.core.Instances;

/* loaded from: input_file:ai/libs/hasco/knowledgebase/FANOVAParameterImportanceEstimator.class */
/**
 * Estimates hyperparameter importance from marginal variance contributions computed by an
 * {@link ExtendedRandomForest} fitted to performance samples of a {@link PerformanceKnowledgeBase}.
 *
 * <p>For a composition, every subset of parameter indices of size 1 up to {@code sizeOfLargestSubsetToConsider}
 * (fixed to 2 by the constructor) is scored. Every parameter occurring in a subset whose score is
 * {@code >= importanceThreshold} is considered important. Important-parameter sets are cached per composition
 * identifier, and per-subset scores are cached in a dictionary. Parameters that were analysed but never found
 * important accumulate in the pruned-parameter set.
 *
 * <p>The knowledge base may be {@code null} after construction (see the three-argument constructor); it must
 * then be set via {@link #setPerformanceKnowledgeBase(PerformanceKnowledgeBase)} before the estimation methods
 * are used, as they dereference it.
 *
 * <p>NOTE(review): instances hold mutable, unsynchronized state; concurrent use is not supported by this code.
 */
public class FANOVAParameterImportanceEstimator implements IParameterImportanceEstimator {
    private static final Logger LOGGER = LoggerFactory.getLogger(FANOVAParameterImportanceEstimator.class);

    /** Value used in place of a NaN marginal variance contribution before it is compared with the threshold. */
    private static final double NAN_IMPORTANCE_SUBSTITUTE = 1.0d;

    private PerformanceKnowledgeBase performanceKnowledgeBase;
    private final String benchmarkName;
    /** Composition identifier -> (parameter index subset -> raw marginal variance contribution). */
    private final Map<String, Map<Set<Integer>, Double>> importanceDictionary;
    /** Composition identifier -> names of parameters found important for it. */
    private final Map<String, Set<String>> importantParameterMap;
    private final int minNumSamples;
    private final double importanceThreshold;
    private final int sizeOfLargestSubsetToConsider;
    /** Names of parameters that were analysed but have not been found important. */
    private final Set<String> prunedParameters;

    /**
     * Creates an estimator.
     *
     * @param performanceKnowledgeBase source of performance samples; may be {@code null} if set later
     * @param benchmarkName name of the benchmark whose samples are queried
     * @param minNumSamples value passed to the knowledge base in {@link #readyToEstimateImportance}
     * @param importanceThreshold minimum subset score for its parameters to count as important
     */
    public FANOVAParameterImportanceEstimator(final PerformanceKnowledgeBase performanceKnowledgeBase, final String benchmarkName, final int minNumSamples, final double importanceThreshold) {
        this.performanceKnowledgeBase = performanceKnowledgeBase;
        this.benchmarkName = benchmarkName;
        this.importanceDictionary = new HashMap<>();
        this.importantParameterMap = new HashMap<>();
        this.minNumSamples = minNumSamples;
        this.importanceThreshold = importanceThreshold;
        this.sizeOfLargestSubsetToConsider = 2;
        this.prunedParameters = new HashSet<>();
    }

    /**
     * Creates an estimator without a knowledge base; call
     * {@link #setPerformanceKnowledgeBase(PerformanceKnowledgeBase)} before estimating.
     *
     * @param benchmarkName name of the benchmark whose samples are queried
     * @param minNumSamples value passed to the knowledge base in {@link #readyToEstimateImportance}
     * @param importanceThreshold minimum subset score for its parameters to count as important
     */
    public FANOVAParameterImportanceEstimator(final String benchmarkName, final int minNumSamples, final double importanceThreshold) {
        this(null, benchmarkName, minNumSamples, importanceThreshold);
    }

    /**
     * Determines the important parameters of the given composition.
     *
     * <p>If fewer than two parameters exist, all of them are returned without fitting a model. Otherwise a
     * forest is fitted and every parameter subset of size 1..{@code sizeOfLargestSubsetToConsider} is scored.
     * A NaN score is replaced by {@value #NAN_IMPORTANCE_SUBSTITUTE} before thresholding, regardless of whether
     * the score was freshly computed or taken from the dictionary.
     *
     * @param composition the composition whose parameters are analysed
     * @param recompute if {@code true}, subset scores are recomputed and overwrite cached ones; otherwise cached
     *        subset scores are reused when available
     * @return names of the important parameters; for compositions with at least two parameters this is the
     *         same set instance that is cached internally
     * @throws ExtractionOfImportantParametersFailedException if building or querying the model fails
     */
    @Override
    public Set<String> extractImportantParameters(final ComponentInstance composition, final boolean recompute) throws ExtractionOfImportantParametersFailedException {
        String compositionId = Util.getComponentNamesOfComposition(composition);
        // NOTE(review): a cached result is returned even when recompute is true, so recompute only has an
        // effect for compositions that have not been analysed yet. Confirm this is intended.
        if (this.importantParameterMap.containsKey(compositionId)) {
            return this.importantParameterMap.get(compositionId);
        }
        Instances data = this.performanceKnowledgeBase.getPerformanceSamples(this.benchmarkName, composition);
        FeatureSpace space = new FeatureSpace(data);
        Set<String> importantParameters = new HashSet<>();
        if (space.getDimensionality() < 2) {
            // Nothing to rank: keep every parameter.
            for (FeatureDomain domain : space.getFeatureDomains()) {
                importantParameters.add(domain.getName());
            }
            return importantParameters;
        }
        // Provisionally prune every parameter; the important ones are removed again after the analysis.
        for (FeatureDomain domain : space.getFeatureDomains()) {
            this.prunedParameters.add(domain.getName());
        }
        ExtendedRandomForest forest = new ExtendedRandomForest();
        try {
            forest.buildClassifier(data);
            forest.prepareForest(data);
            Map<Set<Integer>, Double> subsetImportances = this.importanceDictionary.computeIfAbsent(compositionId, key -> new HashMap<>());

            // All attributes except the last one are treated as parameters.
            Set<Integer> featureIndices = new HashSet<>();
            for (int i = 0; i < data.numAttributes() - 1; i++) {
                featureIndices.add(i);
            }

            // Guava's Sets.combinations rejects subset sizes larger than the set itself.
            int largestSubsetSize = Math.min(this.sizeOfLargestSubsetToConsider, featureIndices.size());
            for (int subsetSize = 1; subsetSize <= largestSubsetSize; subsetSize++) {
                for (Set<Integer> subset : Sets.combinations(featureIndices, subsetSize)) {
                    double importance;
                    if (!recompute && subsetImportances.containsKey(subset)) {
                        LOGGER.debug("Taking value from dictionary");
                        importance = subsetImportances.get(subset);
                    } else {
                        importance = forest.computeMarginalVarianceContributionForFeatureSubset(subset);
                        // The raw value (possibly NaN) is cached; substitution happens below on every path.
                        subsetImportances.put(subset, importance);
                    }
                    if (Double.isNaN(importance)) {
                        importance = NAN_IMPORTANCE_SUBSTITUTE;
                        LOGGER.debug("importance value is NaN, so it will be set to 1");
                    }
                    LOGGER.debug("Importance value for parameter subset {}: {}", subset, importance);
                    boolean isImportant = importance >= this.importanceThreshold;
                    LOGGER.debug("Importance value {} >= {}: {}", importance, this.importanceThreshold, isImportant);
                    if (isImportant) {
                        for (int featureIndex : subset) {
                            importantParameters.add(forest.getFeatureSpace().getFeatureDomain(featureIndex).getName());
                        }
                    }
                }
            }
            this.importantParameterMap.put(compositionId, importantParameters);
            this.prunedParameters.removeAll(importantParameters);
            return importantParameters;
        } catch (Exception e) {
            throw new ExtractionOfImportantParametersFailedException("Could not build model", e);
        }
    }

    /**
     * Computes the single-parameter marginal variance contribution of every parameter of one component.
     *
     * @param component the component whose samples are analysed
     * @return parameter name -> contribution; {@code null} if the knowledge base returned no samples. If model
     *         building or evaluation fails, the error is logged and the (possibly partial) map is returned.
     */
    @Override
    public Map<String, Double> computeImportanceForSingleComponent(final Component component) {
        Instances data = this.performanceKnowledgeBase.getPerformanceSamplesForIndividualComponent(this.benchmarkName, component);
        if (data == null) {
            return null;
        }
        ExtendedRandomForest forest = new ExtendedRandomForest();
        Map<String, Double> importanceMap = new HashMap<>();
        try {
            forest.buildClassifier(data);
            // NOTE(review): unlike extractImportantParameters, prepareForest(data) is not called here; confirm
            // whether computeMarginalVarianceContributionForFeatureSubset depends on it.
            for (int i = 0; i < data.numAttributes() - 1; i++) {
                Set<Integer> singleFeature = new HashSet<>();
                singleFeature.add(i);
                importanceMap.put(data.attribute(i).name(), forest.computeMarginalVarianceContributionForFeatureSubset(singleFeature));
            }
        } catch (Exception e) {
            LOGGER.error("Could not build model and compute marginal variance contribution.", e);
        }
        return importanceMap;
    }

    /**
     * Asks the knowledge base whether enough distinct attribute values are available for the composition.
     *
     * @param componentInstance the composition to check
     * @return the result of {@code kDistinctAttributeValuesAvailable} with {@code minNumSamples} as k
     */
    @Override
    public boolean readyToEstimateImportance(final ComponentInstance componentInstance) {
        return this.performanceKnowledgeBase.kDistinctAttributeValuesAvailable(this.benchmarkName, componentInstance, this.minNumSamples);
    }

    /** @return the knowledge base currently used, possibly {@code null} */
    @Override
    public PerformanceKnowledgeBase getPerformanceKnowledgeBase() {
        return this.performanceKnowledgeBase;
    }

    /** @param performanceKnowledgeBase the knowledge base to query for performance samples */
    @Override
    public void setPerformanceKnowledgeBase(final PerformanceKnowledgeBase performanceKnowledgeBase) {
        this.performanceKnowledgeBase = performanceKnowledgeBase;
    }

    /** @return the number of parameters currently in the pruned set */
    @Override
    public int getNumberPrunedParameters() {
        return this.prunedParameters.size();
    }

    /** @return the internal pruned-parameter set itself (not a copy); changes to it affect this estimator */
    @Override
    public Set<String> getPrunedParameters() {
        return this.prunedParameters;
    }
}
