package ai.libs.mlplan.multiclass.wekamlplan.sklearn;

import ai.libs.hasco.exceptions.ComponentInstantiationFailedException;
import ai.libs.hasco.model.CategoricalParameterDomain;
import ai.libs.hasco.model.ComponentInstance;
import ai.libs.hasco.model.Parameter;
import ai.libs.jaicore.basic.ILoggingCustomizable;
import ai.libs.jaicore.ml.scikitwrapper.ScikitLearnWrapper;
import ai.libs.mlplan.multiclass.wekamlplan.IClassifierFactory;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.StringJoiner;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.classifiers.Classifier;

/* loaded from: input_file:ai/libs/mlplan/multiclass/wekamlplan/sklearn/SKLearnClassifierFactory.class */
/**
 * Translates a ground {@link ComponentInstance} tree into Python construction code plus the
 * matching {@code from ... import ...} statements, and wraps both in a {@link ScikitLearnWrapper}.
 */
public class SKLearnClassifierFactory implements IClassifierFactory, ILoggingCustomizable {
    /** Categorical domains contained in {True, False} are emitted verbatim as Python booleans. */
    private static final CategoricalParameterDomain BOOL_DOMAIN = new CategoricalParameterDomain(Arrays.asList("True", "False"));

    private Logger logger = LoggerFactory.getLogger(SKLearnClassifierFactory.class);
    private String loggerName;

    /**
     * Builds a {@link ScikitLearnWrapper} for the given ground component instance.
     *
     * @param componentInstance root of the component instance tree to translate
     * @return the wrapper, or {@code null} if constructing the wrapper raised an {@link IOException}
     *     (the failure is logged with its stack trace)
     * @throws ComponentInstantiationFailedException declared for compatibility; not thrown by this
     *     implementation
     */
    public Classifier getComponentInstantiation(ComponentInstance componentInstance) throws ComponentInstantiationFailedException {
        this.logger.info("Parse ground component instance {} to ScikitLearnWrapper object.", componentInstance);
        // Insertion-ordered so the generated import section is deterministic across runs.
        Set<String> imports = new LinkedHashSet<>();
        String constructInstruction = extractSKLearnConstructInstruction(componentInstance, imports);
        String importStatements = String.join("", imports);
        try {
            return new ScikitLearnWrapper(constructInstruction, importStatements, true);
        } catch (IOException e) {
            // Trailing Throwable argument makes SLF4J log the stack trace instead of dropping it.
            this.logger.error("Could not create sklearn wrapper for construction {} and imports {}.", constructInstruction, importStatements, e);
            return null;
        }
    }

    /**
     * Alias of {@link #getComponentInstantiation(ComponentInstance)} kept for callers that were
     * written against the decompiler-generated method name.
     *
     * @param componentInstance root of the component instance tree to translate
     * @return see {@link #getComponentInstantiation(ComponentInstance)}
     * @throws ComponentInstantiationFailedException see {@link #getComponentInstantiation(ComponentInstance)}
     * @deprecated use {@link #getComponentInstantiation(ComponentInstance)}
     */
    @Deprecated
    public Classifier m29getComponentInstantiation(ComponentInstance componentInstance) throws ComponentInstantiationFailedException {
        return getComponentInstantiation(componentInstance);
    }

    /**
     * Recursively renders the Python expression that constructs {@code componentInstance} and
     * records every import statement the expression needs.
     *
     * <p>Rendering rules, by component name:
     * <ul>
     *   <li>starts with {@code mlplan.util.model.make_forward}: {@code <source>,<base>} (no import,
     *       no call wrapper);</li>
     *   <li>starts with {@code sklearn.feature_selection.f_classif}: {@code f_classif(features, targets)};</li>
     *   <li>contains {@code make_pipeline}: {@code make_pipeline(<preprocessor>,<classifier>)};</li>
     *   <li>contains {@code make_union}: {@code make_union(<p1>,<p2>)};</li>
     *   <li>otherwise: {@code Name(param=value,...,iface=<expr>,...)} with parameter values first,
     *       then required-interface children.</li>
     * </ul>
     *
     * @param componentInstance the instance to translate
     * @param set receives {@code "from <module> import <Name>\n"} lines; mutated in place
     * @return the Python construction expression
     * @throws NullPointerException if a required interface consumed by a combinator is unsatisfied
     */
    public String extractSKLearnConstructInstruction(ComponentInstance componentInstance, Set<String> set) {
        String componentName = componentInstance.getComponent().getName();
        if (componentName.startsWith("mlplan.util.model.make_forward")) {
            return renderPair(componentInstance, "source", "base", set);
        }
        String className = registerImport(componentName, set);
        if (componentName.startsWith("sklearn.feature_selection.f_classif")) {
            return "f_classif(features, targets)";
        }
        StringBuilder expression = new StringBuilder(className).append('(');
        if (componentName.contains("make_pipeline")) {
            expression.append(renderPair(componentInstance, "preprocessor", "classifier", set));
        } else if (componentName.contains("make_union")) {
            expression.append(renderPair(componentInstance, "p1", "p2", set));
        } else {
            expression.append(renderKeywordArguments(componentInstance, set));
        }
        return expression.append(')').toString();
    }

    /**
     * Adds {@code from <module> import <Name>} for a dotted component name and returns the last
     * segment. A name without dots yields {@code from X import X}.
     */
    private static String registerImport(String componentName, Set<String> imports) {
        String[] segments = componentName.split("\\.");
        String importedName = segments[segments.length - 1];
        String module = segments.length == 1 ? segments[0] : String.join(".", Arrays.copyOf(segments, segments.length - 1));
        imports.add("from " + module + " import " + importedName + "\n");
        return importedName;
    }

    /** Renders two required-interface children as {@code <first>,<second>}, left to right. */
    private String renderPair(ComponentInstance parent, String firstInterface, String secondInterface, Set<String> imports) {
        String first = extractSKLearnConstructInstruction(requiredChild(parent, firstInterface), imports);
        String second = extractSKLearnConstructInstruction(requiredChild(parent, secondInterface), imports);
        return first + "," + second;
    }

    /** Looks up a required-interface child, failing with a descriptive message if it is absent. */
    private static ComponentInstance requiredChild(ComponentInstance parent, String interfaceName) {
        ComponentInstance child = (ComponentInstance) parent.getSatisfactionOfRequiredInterfaces().get(interfaceName);
        return Objects.requireNonNull(child,
                () -> "Component " + parent.getComponent().getName() + " has no satisfaction for required interface '" + interfaceName + "'.");
    }

    /** Renders parameter values, then required-interface children, as comma-separated keyword arguments. */
    private String renderKeywordArguments(ComponentInstance componentInstance, Set<String> imports) {
        StringJoiner arguments = new StringJoiner(",");
        for (Map.Entry<String, String> parameterValue : componentInstance.getParameterValues().entrySet()) {
            Parameter parameter = componentInstance.getComponent().getParameterWithName(parameterValue.getKey());
            arguments.add(parameterValue.getKey() + "=" + toPythonLiteral(parameter, parameterValue.getValue()));
        }
        for (Map.Entry<String, ComponentInstance> child : componentInstance.getSatisfactionOfRequiredInterfaces().entrySet()) {
            arguments.add(child.getKey() + "=" + extractSKLearnConstructInstruction(child.getValue(), imports));
        }
        return arguments.toString();
    }

    /**
     * Renders a parameter value as a Python literal.
     *
     * <p>{@code null} becomes {@code None}. Numeric parameters and categorical parameters whose
     * domain lies within {True, False} are emitted verbatim. Any other value (including one whose
     * parameter definition could not be found) is emitted as an int if it parses as one, else as a
     * float (non-finite values as {@code float("NaN")} etc., since bare {@code NaN} is not valid
     * Python), else as an escaped, double-quoted string.
     */
    private static String toPythonLiteral(Parameter parameter, String value) {
        if (value == null) {
            return "None";
        }
        // NOTE(review): getParameterWithName is assumed to return null for unknown names — confirm.
        if (parameter != null && (parameter.isNumeric() || (parameter.isCategorical() && BOOL_DOMAIN.subsumes(parameter.getDefaultDomain())))) {
            return value;
        }
        try {
            return Integer.toString(Integer.parseInt(value));
        } catch (NumberFormatException notAnInt) {
            // Not an integer; try a floating-point literal next.
        }
        try {
            double number = Double.parseDouble(value);
            return Double.isFinite(number) ? Double.toString(number) : "float(\"" + number + "\")";
        } catch (NumberFormatException notANumber) {
            // Not numeric; emit as a string literal.
        }
        return quotePythonString(value);
    }

    /** Wraps {@code value} in double quotes, escaping characters that would break a Python string literal. */
    private static String quotePythonString(String value) {
        StringBuilder quoted = new StringBuilder(value.length() + 2).append('"');
        for (char c : value.toCharArray()) {
            switch (c) {
                case '\\':
                    quoted.append("\\\\");
                    break;
                case '"':
                    quoted.append("\\\"");
                    break;
                case '\n':
                    quoted.append("\\n");
                    break;
                case '\r':
                    quoted.append("\\r");
                    break;
                default:
                    quoted.append(c);
                    break;
            }
        }
        return quoted.append('"').toString();
    }

    /** @return the name last passed to {@link #setLoggerName(String)}, or {@code null} if never set */
    public String getLoggerName() {
        return this.loggerName;
    }

    /**
     * Replaces this factory's logger with the SLF4J logger of the given name, announcing the switch
     * on both the old and the new logger at debug level.
     *
     * @param str name of the logger to use from now on
     */
    public void setLoggerName(String str) {
        this.loggerName = str;
        this.logger.debug("Switching logger name to {}", str);
        this.logger = LoggerFactory.getLogger(str);
        this.logger.debug("Switched SKLearnClassifierFactory logger to {}", str);
    }
}
