package ai.libs.jaicore.ml.weka;

import ai.libs.jaicore.basic.Maps;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Stream;
import org.api4.java.ai.ml.core.exception.PredictionException;
import org.api4.java.ai.ml.core.exception.TrainingException;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.core.Instance;
import weka.core.Instances;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Add;
import weka.filters.unsupervised.attribute.Remove;

/* loaded from: input_file:ai/libs/jaicore/ml/weka/RankingByPairwiseComparison.class */
public class RankingByPairwiseComparison {
    private RankingByPairwiseComparisonConfig config;
    private List<Integer> labelIndices;
    private Set<String> labelSet = new HashSet();
    private List<PairWiseClassifier> pwClassifiers = new LinkedList();

    /* loaded from: input_file:ai/libs/jaicore/ml/weka/RankingByPairwiseComparison$PairWiseClassifier.class */
    class PairWiseClassifier {
        private String a;
        private String b;
        private Classifier c;

        PairWiseClassifier() {
        }
    }

    public RankingByPairwiseComparison(RankingByPairwiseComparisonConfig rankingByPairwiseComparisonConfig) {
        this.config = rankingByPairwiseComparisonConfig;
    }

    private Instances applyFiltersToDataset(Instances instances) throws Exception {
        Remove remove = new Remove();
        remove.setAttributeIndicesArray(this.labelIndices.stream().mapToInt(num -> {
            return num.intValue();
        }).toArray());
        remove.setInvertSelection(false);
        remove.setInputFormat(instances);
        Instances useFilter = Filter.useFilter(instances, remove);
        Add add = new Add();
        add.setAttributeIndex("last");
        add.setNominalLabels("true,false");
        add.setAttributeName("a>b");
        add.setInputFormat(useFilter);
        Instances useFilter2 = Filter.useFilter(useFilter, add);
        useFilter2.setClassIndex(useFilter2.numAttributes() - 1);
        return useFilter2;
    }

    private static List<Integer> getLabelIndices(int i, Instances instances) {
        LinkedList linkedList = new LinkedList();
        if (i < 0) {
            for (int numAttributes = instances.numAttributes() - 1; numAttributes >= instances.numAttributes() + i; numAttributes--) {
                linkedList.add(Integer.valueOf(numAttributes));
            }
        } else {
            for (int i2 = 0; i2 < i; i2++) {
                linkedList.add(Integer.valueOf(i2));
            }
        }
        return linkedList;
    }

    public void fit(Instances instances, int i) throws Exception {
        this.labelIndices = getLabelIndices(i, instances);
        Stream<R> map = this.labelIndices.stream().map(num -> {
            return instances.attribute(num.intValue()).name();
        });
        Set<String> set = this.labelSet;
        Objects.requireNonNull(set);
        map.forEach((v1) -> {
            r1.add(v1);
        });
        Instances applyFiltersToDataset = applyFiltersToDataset(instances);
        for (int i2 = 0; i2 < this.labelIndices.size() - 1; i2++) {
            try {
                for (int i3 = i2 + 1; i3 < this.labelIndices.size(); i3++) {
                    PairWiseClassifier pairWiseClassifier = new PairWiseClassifier();
                    pairWiseClassifier.a = instances.attribute(this.labelIndices.get(i2).intValue()).name();
                    pairWiseClassifier.b = instances.attribute(this.labelIndices.get(i3).intValue()).name();
                    pairWiseClassifier.c = AbstractClassifier.forName(this.config.getBaseLearner(), (String[]) null);
                    Instances instances2 = new Instances(applyFiltersToDataset);
                    for (int i4 = 0; i4 < instances2.size(); i4++) {
                        instances2.get(i4).setValue(instances2.numAttributes() - 1, instances.get(i4).value(this.labelIndices.get(i2).intValue()) > instances.get(i4).value(this.labelIndices.get(i3).intValue()) ? "true" : "false");
                    }
                    instances2.setClassIndex(instances2.numAttributes() - 1);
                    pairWiseClassifier.c.buildClassifier(instances2);
                    this.pwClassifiers.add(pairWiseClassifier);
                }
            } catch (Exception e) {
                throw new TrainingException("Could not build ranker", e);
            }
        }
    }

    /* JADX WARN: Failed to find 'out' block for switch in B:6:0x007c. Please report as an issue. */
    public List<String> predict(Instance instance) throws PredictionException {
        try {
            Instances instances = new Instances(instance.dataset(), 0);
            instances.add(instance);
            Instances applyFiltersToDataset = applyFiltersToDataset(instances);
            HashMap hashMap = new HashMap();
            this.labelSet.stream().forEach(str -> {
                hashMap.put(str, Double.valueOf(0.0d));
            });
            for (PairWiseClassifier pairWiseClassifier : this.pwClassifiers) {
                double[] distributionForInstance = pairWiseClassifier.c.distributionForInstance(applyFiltersToDataset.get(0));
                String votingStrategy = this.config.getVotingStrategy();
                boolean z = -1;
                switch (votingStrategy.hashCode()) {
                    case -1290561483:
                        if (votingStrategy.equals(RankingByPairwiseComparisonConfig.V_VOTING_STRATEGY_PROBABILITY)) {
                            z = 2;
                            break;
                        }
                        break;
                    case 692443780:
                        if (votingStrategy.equals(RankingByPairwiseComparisonConfig.V_VOTING_STRATEGY_CLASSIFY)) {
                            z = false;
                            break;
                        }
                        break;
                }
                switch (z) {
                    case false:
                        if (distributionForInstance[0] > distributionForInstance[1]) {
                            Maps.increaseCounterInDoubleMap(hashMap, pairWiseClassifier.a);
                            break;
                        } else {
                            Maps.increaseCounterInDoubleMap(hashMap, pairWiseClassifier.b);
                            break;
                        }
                    case true:
                    default:
                        Maps.increaseCounterInDoubleMap(hashMap, pairWiseClassifier.a, distributionForInstance[0]);
                        Maps.increaseCounterInDoubleMap(hashMap, pairWiseClassifier.b, distributionForInstance[1]);
                        break;
                }
            }
            LinkedList linkedList = new LinkedList(hashMap.keySet());
            linkedList.sort((str2, str3) -> {
                return ((Double) hashMap.get(str3)).compareTo((Double) hashMap.get(str2));
            });
            return linkedList;
        } catch (Exception e) {
            throw new PredictionException("Could not create a prediction.", e);
        }
    }
}
