package ai.libs.mlplan.multiclass.wekamlplan;

import ai.libs.hasco.gui.statsplugin.HASCOModelStatisticsPlugin;
import ai.libs.hasco.model.Component;
import ai.libs.jaicore.basic.ILoggingCustomizable;
import ai.libs.jaicore.basic.TimeOut;
import ai.libs.jaicore.graphvisualizer.plugin.IGUIPlugin;
import ai.libs.jaicore.graphvisualizer.plugin.graphview.GraphViewPlugin;
import ai.libs.jaicore.graphvisualizer.plugin.nodeinfo.NodeInfoGUIPlugin;
import ai.libs.jaicore.graphvisualizer.plugin.solutionperformanceplotter.SolutionPerformanceTimelinePlugin;
import ai.libs.jaicore.graphvisualizer.window.AlgorithmVisualizationWindow;
import ai.libs.jaicore.ml.evaluation.IInstancesClassifier;
import ai.libs.jaicore.planning.hierarchical.algorithms.forwarddecomposition.graphgenerators.tfd.TFDNodeInfoGenerator;
import ai.libs.mlplan.core.AbstractMLPlanBuilder;
import ai.libs.mlplan.core.MLPlan;
import ai.libs.mlplan.multiclass.MLPlanClassifierConfig;
import jaicore.search.gui.plugins.rollouthistograms.SearchRolloutHistogramPlugin;
import jaicore.search.model.travesaltree.JaicoreNodeInfoGenerator;
import java.io.IOException;
import java.util.Collection;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Objects;
import javafx.application.Platform;
import javafx.embed.swing.JFXPanel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.classifiers.Classifier;
import weka.core.Capabilities;
import weka.core.CapabilitiesHandler;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;

/* loaded from: input_file:ai/libs/mlplan/multiclass/wekamlplan/MLPlanWekaClassifier.class */
public class MLPlanWekaClassifier implements Classifier, CapabilitiesHandler, OptionHandler, ILoggingCustomizable, IInstancesClassifier {
    private String loggerName;
    private final transient AbstractMLPlanBuilder builder;
    private TimeOut timeout;
    private Classifier classifierFoundByMLPlan;
    private double internalValidationErrorOfSelectedClassifier;
    private Logger logger = LoggerFactory.getLogger(MLPlanWekaClassifier.class);
    private boolean visualizationEnabled = false;

    public MLPlanWekaClassifier(AbstractMLPlanBuilder abstractMLPlanBuilder) {
        this.builder = abstractMLPlanBuilder;
        this.timeout = abstractMLPlanBuilder.getTimeOut();
    }

    public void buildClassifier(Instances instances) throws Exception {
        Objects.requireNonNull(this.timeout, "Timeout must be set before running ML-Plan.");
        MLPlan mLPlan = new MLPlan(this.builder, instances);
        mLPlan.setTimeout(this.timeout);
        if (this.loggerName != null) {
            mLPlan.setLoggerName(this.loggerName + ".mlplan");
        }
        if (this.visualizationEnabled) {
            new JFXPanel();
            Platform.runLater(new AlgorithmVisualizationWindow(mLPlan, new GraphViewPlugin(), new IGUIPlugin[]{new NodeInfoGUIPlugin(new JaicoreNodeInfoGenerator(new TFDNodeInfoGenerator())), new SearchRolloutHistogramPlugin(), new SolutionPerformanceTimelinePlugin(), new HASCOModelStatisticsPlugin()}));
        }
        this.classifierFoundByMLPlan = mLPlan.m13call();
    }

    public double[] classifyInstances(Instances instances) throws Exception {
        if (getSelectedClassifier() instanceof IInstancesClassifier) {
            return getSelectedClassifier().classifyInstances(instances);
        }
        double[] dArr = new double[instances.size()];
        for (int i = 0; i < instances.size(); i++) {
            dArr[i] = getSelectedClassifier().classifyInstance(instances.get(i));
        }
        return dArr;
    }

    public double classifyInstance(Instance instance) throws Exception {
        if (this.classifierFoundByMLPlan == null) {
            throw new IllegalStateException("Classifier has not been built yet.");
        }
        return this.classifierFoundByMLPlan.classifyInstance(instance);
    }

    public double[] distributionForInstance(Instance instance) throws Exception {
        if (this.classifierFoundByMLPlan == null) {
            throw new IllegalStateException("Classifier has not been built yet.");
        }
        return this.classifierFoundByMLPlan.distributionForInstance(instance);
    }

    public Capabilities getCapabilities() {
        Capabilities capabilities = new Capabilities(this);
        capabilities.disableAll();
        capabilities.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.DATE_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.STRING_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.RELATIONAL_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.MISSING_VALUES);
        capabilities.enable(Capabilities.Capability.NOMINAL_CLASS);
        capabilities.enable(Capabilities.Capability.NUMERIC_CLASS);
        capabilities.enable(Capabilities.Capability.DATE_CLASS);
        capabilities.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        capabilities.setMinimumNumberInstances(1);
        return capabilities;
    }

    public Enumeration<Option> listOptions() {
        return null;
    }

    public void setOptions(String[] strArr) throws Exception {
    }

    public String[] getOptions() {
        return new String[0];
    }

    public void setTimeout(TimeOut timeOut) {
        this.timeout = timeOut;
    }

    public MLPlanClassifierConfig getMLPlanConfig() {
        return this.builder.getAlgorithmConfig();
    }

    public Collection<Component> getComponents() throws IOException {
        return this.builder.getComponents();
    }

    public void setVisualizationEnabled(boolean z) {
        this.visualizationEnabled = z;
    }

    public Classifier getSelectedClassifier() {
        return this.classifierFoundByMLPlan;
    }

    public double getInternalValidationErrorOfSelectedClassifier() {
        return this.internalValidationErrorOfSelectedClassifier;
    }

    public void setLoggerName(String str) {
        this.loggerName = str;
        this.logger.info("Switching logger name to {}", str);
        this.logger = LoggerFactory.getLogger(str);
        this.logger.info("Switched ML-Plan logger to {}", str);
    }

    public String getLoggerName() {
        return this.loggerName;
    }
}
