package ai.libs.mlplan.cli.module.slc;

import ai.libs.jaicore.ml.classification.loss.dataset.AreaUnderROCCurve;
import ai.libs.jaicore.ml.classification.loss.dataset.AveragedInstanceLoss;
import ai.libs.jaicore.ml.classification.loss.dataset.EClassificationPerformanceMeasure;
import ai.libs.jaicore.ml.classification.loss.dataset.ErrorRate;
import ai.libs.jaicore.ml.classification.loss.dataset.F1Measure;
import ai.libs.jaicore.ml.classification.loss.dataset.Precision;
import ai.libs.jaicore.ml.classification.loss.dataset.Recall;
import ai.libs.jaicore.ml.classification.loss.instance.LogLoss;
import ai.libs.mlplan.cli.MLPlanCLI;
import ai.libs.mlplan.cli.module.AMLPlanCLIModule;
import ai.libs.mlplan.cli.module.IMLPlanCLIModule;
import ai.libs.mlplan.cli.module.UnsupportedModuleConfigurationException;
import ai.libs.mlplan.core.AMLPlanBuilder;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.cli.CommandLine;
import org.api4.java.ai.ml.classification.singlelabel.evaluation.ISingleLabelClassification;
import org.api4.java.ai.ml.core.dataset.schema.attribute.ICategoricalAttribute;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset;
import org.api4.java.ai.ml.core.evaluation.execution.ILearnerRunReport;
import org.api4.java.ai.ml.core.learner.ISupervisedLearner;

/* loaded from: input_file:ai/libs/mlplan/cli/module/slc/AMLPlan4ClassificationCLIModule.class */
public abstract class AMLPlan4ClassificationCLIModule extends AMLPlanCLIModule implements IMLPlanCLIModule {
    private static final String L_ERRORRATE = "ERRORRATE";
    private static final String L_LOGLOSS = "LOGLOSS";
    private static final String L_AUC = "AUC";
    private static final String L_F1 = "F1";
    private static final String L_PRECISION = "PRECISION";
    private static final String L_RECALL = "RECALL";

    /** Measures that are only defined w.r.t. a single positive class and hence require a binary target. */
    private static final List<String> BINARY_ONLY_MEASURES = Arrays.asList(L_AUC, L_F1, L_PRECISION, L_RECALL);

    /**
     * Creates the module, registering AUC, F1, PRECISION, RECALL, ERRORRATE and LOGLOSS as the
     * supported losses with ERRORRATE as the default loss.
     *
     * @param list passed through unchanged to {@link AMLPlanCLIModule} (NOTE(review): meaning defined by the super class)
     * @param str passed through unchanged to {@link AMLPlanCLIModule} (NOTE(review): meaning defined by the super class)
     */
    public AMLPlan4ClassificationCLIModule(final List<String> list, final String str) {
        super(list, str, Arrays.asList(L_AUC, L_F1, L_PRECISION, L_RECALL, L_ERRORRATE, L_LOGLOSS), L_ERRORRATE);
    }

    /**
     * Configures the performance measure of the given builder from the {@link MLPlanCLI#O_LOSS} option
     * (defaulting to ERRORRATE).
     *
     * <p>For the binary-only measures (AUC, F1, PRECISION, RECALL) the positive class is taken from
     * {@link MLPlanCLI#O_POS_CLASS_NAME} if present, otherwise from {@link MLPlanCLI#O_POS_CLASS_INDEX}.
     *
     * @param commandLine the parsed command line
     * @param labelAttribute the categorical target attribute of the training data
     * @param builder the ML-Plan builder to configure
     * @throws UnsupportedModuleConfigurationException if the positive class name is unknown, the positive
     *         class index is not an integer or (for binary-only measures) out of range, a binary-only measure
     *         is requested for a dataset with more than two labels, or the loss name is unsupported
     */
    public void configureLoss(final CommandLine commandLine, final ICategoricalAttribute labelAttribute, final AMLPlanBuilder builder) {
        final int positiveClassIndex = getPositiveClassIndex(commandLine, labelAttribute);
        final String lossName = commandLine.getOptionValue(MLPlanCLI.O_LOSS, L_ERRORRATE);
        final int numLabels = labelAttribute.getLabels().size();
        final boolean binaryOnly = BINARY_ONLY_MEASURES.contains(lossName);

        if (binaryOnly && numLabels > 2) {
            throw new UnsupportedModuleConfigurationException("Cannot use binary performance measure for non-binary classification dataset.");
        }
        // The index is only consumed by the binary-only measures; validate it there so a bad index
        // fails fast here instead of producing a meaningless measure later.
        if (binaryOnly && (positiveClassIndex < 0 || positiveClassIndex >= numLabels)) {
            throw new UnsupportedModuleConfigurationException(
                    "The index of the positive class (" + positiveClassIndex + ") is out of range for " + numLabels + " class labels.");
        }

        switch (lossName) {
            case L_ERRORRATE:
                builder.withPerformanceMeasure(EClassificationPerformanceMeasure.ERRORRATE);
                break;
            case L_LOGLOSS:
                builder.withPerformanceMeasure(new AveragedInstanceLoss(new LogLoss()));
                break;
            case L_AUC:
                builder.withPerformanceMeasure(new AreaUnderROCCurve(positiveClassIndex));
                break;
            case L_F1:
                builder.withPerformanceMeasure(new F1Measure(positiveClassIndex));
                break;
            case L_PRECISION:
                builder.withPerformanceMeasure(new Precision(positiveClassIndex));
                break;
            case L_RECALL:
                builder.withPerformanceMeasure(new Recall(positiveClassIndex));
                break;
            default:
                throw new UnsupportedModuleConfigurationException("Unsupported measure " + lossName);
        }
    }

    /**
     * Resolves the index of the positive class: by name if {@link MLPlanCLI#O_POS_CLASS_NAME} is given,
     * otherwise by parsing {@link MLPlanCLI#O_POS_CLASS_INDEX} (or its CLI default).
     */
    private static int getPositiveClassIndex(final CommandLine commandLine, final ICategoricalAttribute labelAttribute) {
        if (commandLine.hasOption(MLPlanCLI.O_POS_CLASS_NAME)) {
            final int index = labelAttribute.getLabels().indexOf(commandLine.getOptionValue(MLPlanCLI.O_POS_CLASS_NAME));
            if (index < 0) {
                throw new UnsupportedModuleConfigurationException("The provided name of the positive class is not contained in the list of class labels");
            }
            return index;
        }
        final String rawIndex = commandLine.getOptionValue(MLPlanCLI.O_POS_CLASS_INDEX, MLPlanCLI.getDefault(MLPlanCLI.O_POS_CLASS_INDEX));
        try {
            return Integer.parseInt(rawIndex);
        } catch (NumberFormatException e) {
            final UnsupportedModuleConfigurationException ex = new UnsupportedModuleConfigurationException(
                    "The provided index of the positive class is not an integer: " + rawIndex);
            ex.initCause(e); // keep the parse failure as the cause
            throw ex;
        }
    }

    /**
     * Returns the label attribute of the dataset as a categorical attribute.
     *
     * @param dataset the labeled dataset
     * @return the categorical label attribute
     * @throws UnsupportedModuleConfigurationException if the label attribute is not categorical
     */
    public ICategoricalAttribute getLabelAttribute(final ILabeledDataset dataset) {
        final Object labelAttribute = dataset.getLabelAttribute();
        if (labelAttribute instanceof ICategoricalAttribute) {
            return (ICategoricalAttribute) labelAttribute;
        }
        throw new UnsupportedModuleConfigurationException("ML-Plan for classification requires a categorical target attribute.");
    }

    /**
     * Renders the learner followed by the error rate computed over the report's prediction/ground-truth list.
     *
     * @param learner the learner that was evaluated
     * @param runReport the run report holding the predictions
     * @return a two-line textual report
     */
    @Override
    public String getRunReportAsString(final ISupervisedLearner learner, final ILearnerRunReport runReport) {
        final StringBuilder sb = new StringBuilder();
        sb.append(learner).append("\n");
        sb.append("Error-Rate: ").append(new ErrorRate().loss(runReport.getPredictionDiffList().getCastedView(Integer.class, ISingleLabelClassification.class)));
        return sb.toString();
    }
}
