package ai.libs.jaicore.ml.weka.ranking.label.learner.clusterbased.modifiedisac;

import ai.libs.jaicore.basic.sets.Pair;
import ai.libs.jaicore.ml.core.learner.ASupervisedLearner;
import ai.libs.jaicore.ml.ranking.RankingPredictionBatch;
import ai.libs.jaicore.ml.ranking.label.learner.clusterbased.IGroupBasedRanker;
import ai.libs.jaicore.ml.ranking.label.learner.clusterbased.customdatatypes.Group;
import ai.libs.jaicore.ml.ranking.label.learner.clusterbased.customdatatypes.ProblemInstance;
import ai.libs.jaicore.ml.ranking.label.learner.clusterbased.customdatatypes.RankingForGroup;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.api4.java.ai.ml.core.exception.PredictionException;
import org.api4.java.ai.ml.core.exception.TrainingException;
import org.api4.java.ai.ml.ranking.IRanking;
import org.api4.java.ai.ml.ranking.IRankingPredictionBatch;
import org.api4.java.ai.ml.ranking.label.dataset.ILabelRankingDataset;
import org.api4.java.ai.ml.ranking.label.dataset.ILabelRankingInstance;
import weka.core.Instance;

/**
 * Cluster-based label ranker ("modified ISAC").
 *
 * <p>Training normalizes the problem instances, clusters the normalized points with a
 * {@link ModifiedISACGroupBuilder}, and builds one classifier ranking per cluster. In that ranking,
 * classifiers are ordered by their average recorded performance over the cluster's members, highest
 * first. Prediction normalizes the query point and returns the ranking of the cluster whose identifier
 * is closest in L1 distance.
 */
public class ModifiedISAC extends ASupervisedLearner<ILabelRankingInstance, ILabelRankingDataset, IRanking<String>, IRankingPredictionBatch> implements IGroupBasedRanker<String, ILabelRankingInstance, ILabelRankingDataset, double[]> {

    /** One classifier ranking per cluster found during the last successful {@link #fit}. */
    private ArrayList<ClassifierRankingForGroup> rankings = new ArrayList<>();
    /** Clusters found during the last successful {@link #fit}. */
    private List<Group<double[], Instance>> foundCluster;
    /** Normalizer fitted on the training instances; {@code null} until {@link #fit} succeeds. */
    private Normalizer norm;

    /**
     * Trains the ranker: normalizes the collected problem instances, clusters them, and builds a
     * per-cluster classifier ranking.
     *
     * <p>All model state is computed into locals and assigned to the fields only after every step
     * succeeded. Calling {@code fit} again therefore replaces the previous model instead of appending
     * to it, and a failed call leaves the previous model untouched.
     *
     * @param iLabelRankingDataset the training dataset. NOTE(review): this argument is currently not
     *        read; the training data comes from a no-arg {@link ModifiedISACInstanceCollector}. Confirm
     *        this is intended.
     * @throws TrainingException if any step of building the ranker fails
     */
    @Override
    public void fit(final ILabelRankingDataset iLabelRankingDataset) throws TrainingException {
        try {
            ModifiedISACInstanceCollector collector = new ModifiedISACInstanceCollector();
            List<ProblemInstance<Instance>> problemInstances = collector.getProblemInstances();

            Normalizer normalizer = new Normalizer(problemInstances);
            normalizer.setupnormalize();

            ArrayList<double[]> normalizedPoints = new ArrayList<>(problemInstances.size());
            // Maps a raw (un-normalized) feature vector to its index in problemInstances, which is also
            // its index in the collector's performance table. Keys compare by content, not identity.
            Map<List<Double>, Integer> positionOfInstance = new HashMap<>();
            int position = 0;
            for (ProblemInstance<Instance> problemInstance : problemInstances) {
                double[] rawFeatures = problemInstance.getInstance().toDoubleArray();
                normalizedPoints.add(normalizer.normalize(rawFeatures));
                // First occurrence wins for duplicate feature vectors (deterministic).
                positionOfInstance.putIfAbsent(contentKey(rawFeatures), position);
                position++;
            }

            ModifiedISACGroupBuilder groupBuilder = new ModifiedISACGroupBuilder();
            groupBuilder.setPoints(normalizedPoints);
            List<Group<double[], Instance>> clusters = groupBuilder.buildGroup(problemInstances);
            ArrayList<ClassifierRankingForGroup> groupRankings = constructRanking(collector, clusters, positionOfInstance);

            // Commit the new model only once everything succeeded.
            this.norm = normalizer;
            this.foundCluster = clusters;
            this.rankings = groupRankings;
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                // Preserve the interrupt status for callers further up the stack.
                Thread.currentThread().interrupt();
            }
            throw new TrainingException("Could not build the ranker.", e);
        }
    }

    /**
     * Builds one classifier ranking per cluster.
     *
     * @param collector source of classifier names and per-instance performance values
     * @param clusters the clusters to rank classifiers for
     * @param positionOfInstance content-keyed map from raw feature vector to performance-table index
     * @return the rankings, in the same order as {@code clusters}
     */
    private static ArrayList<ClassifierRankingForGroup> constructRanking(final ModifiedISACInstanceCollector collector, final List<Group<double[], Instance>> clusters,
            final Map<List<Double>, Integer> positionOfInstance) {
        List<String> classifierNames = collector.getAllClassifier();
        ArrayList<ClassifierRankingForGroup> result = new ArrayList<>(clusters.size());
        for (Group<double[], Instance> group : clusters) {
            double[] averages = averagePerformances(group, positionOfInstance, collector);
            result.add(new ClassifierRankingForGroup(group.getId(), rankByDescendingPerformance(classifierNames, averages)));
        }
        return result;
    }

    /**
     * Averages each classifier's performance over the members of a cluster, ignoring NaN entries.
     *
     * @return one average per classifier. The entry is NaN when the classifier has no non-NaN value
     *         in this cluster (0.0 / 0).
     */
    private static double[] averagePerformances(final Group<double[], Instance> group, final Map<List<Double>, Integer> positionOfInstance,
            final ModifiedISACInstanceCollector collector) {
        int numberOfClassifiers = collector.getNumberOfClassifier();
        double[] sums = new double[numberOfClassifiers];
        int[] counts = new int[numberOfClassifiers];
        for (ProblemInstance<?> member : group.getInstances()) {
            double[] rawFeatures = ((Instance) member.getInstance()).toDoubleArray();
            // NOTE(review): a member not found in the map falls back to position 0, as the previous
            // implementation did. Confirm this can't happen, or make it an error.
            int position = positionOfInstance.getOrDefault(contentKey(rawFeatures), 0);
            ArrayList<Pair<String, Double>> performances = collector.getCollectedClassifierandPerformance().get(position);
            for (int classifier = 0; classifier < performances.size(); classifier++) {
                double value = performances.get(classifier).getY();
                if (!Double.isNaN(value)) {
                    sums[classifier] += value;
                    counts[classifier]++;
                }
            }
        }
        double[] averages = new double[numberOfClassifiers];
        for (int classifier = 0; classifier < numberOfClassifiers; classifier++) {
            averages[classifier] = sums[classifier] / counts[classifier];
        }
        return averages;
    }

    /**
     * Orders classifier names by descending average performance.
     *
     * <p>Classifiers with equal averages keep their original relative order (stable sort).
     * Classifiers whose average cannot be compared (NaN, i.e. no observations, or negative infinity)
     * are appended at the end in their original order. Negative and zero averages are ranked normally.
     *
     * @param classifierNames classifier names, indexed like {@code averages} (assumed unique)
     * @param averages average performance per classifier
     * @return classifier names, best first
     */
    private static ArrayList<String> rankByDescendingPerformance(final List<String> classifierNames, final double[] averages) {
        List<Integer> comparable = new ArrayList<>();
        List<Integer> incomparable = new ArrayList<>();
        for (int classifier = 0; classifier < averages.length; classifier++) {
            // False for NaN and -Infinity.
            if (averages[classifier] > Double.NEGATIVE_INFINITY) {
                comparable.add(classifier);
            } else {
                incomparable.add(classifier);
            }
        }
        comparable.sort((a, b) -> Double.compare(averages[b], averages[a]));

        ArrayList<String> ranking = new ArrayList<>(averages.length);
        for (int classifier : comparable) {
            ranking.add(classifierNames.get(classifier));
        }
        for (int classifier : incomparable) {
            ranking.add(classifierNames.get(classifier));
        }
        return ranking;
    }

    /**
     * Wraps a feature vector in a content-comparable key. {@link Double#equals} compares via
     * {@link Double#doubleToLongBits}, which matches {@link Arrays#equals(double[], double[])}.
     */
    private static List<Double> contentKey(final double[] values) {
        List<Double> key = new ArrayList<>(values.length);
        for (double value : values) {
            key.add(value);
        }
        return key;
    }

    /**
     * Returns the ranking of the cluster whose identifier is closest (L1 distance) to the normalized
     * point of the given instance. On ties, the cluster visited last wins.
     *
     * @param iLabelRankingInstance the query instance
     * @return the closest cluster's ranking, or {@code null} if no cluster qualifies (for example, no
     *         rankings exist)
     * @throws IllegalStateException if the ranker has not been fitted
     */
    @Override
    public RankingForGroup<double[], String> getRanking(final ILabelRankingInstance iLabelRankingInstance) {
        if (this.norm == null) {
            throw new IllegalStateException("The ranker has not been fitted yet.");
        }
        double[] normalizedPoint = this.norm.normalize(iLabelRankingInstance.getPoint());
        L1DistanceMetric metric = new L1DistanceMetric();
        ClassifierRankingForGroup closest = null;
        double closestDistance = Double.MAX_VALUE;
        for (ClassifierRankingForGroup candidate : this.rankings) {
            double distance = metric.computeDistance((double[]) candidate.getIdentifierForGroup().getIdentifier(), normalizedPoint).doubleValue();
            if (distance <= closestDistance) {
                closest = candidate;
                closestDistance = distance;
            }
        }
        return closest;
    }

    /**
     * Returns the per-cluster rankings of the last successful fit. The returned list is the internal
     * one, not a copy.
     */
    public List<ClassifierRankingForGroup> getRankings() {
        return this.rankings;
    }

    /** Predicts a ranking for a single instance; delegates to {@link #getRanking}. */
    @Override
    public IRanking<String> predict(final ILabelRankingInstance iLabelRankingInstance) throws PredictionException, InterruptedException {
        return getRanking(iLabelRankingInstance);
    }

    /** Predicts a ranking for each instance, in input order. */
    @Override
    @SuppressWarnings({ "rawtypes", "unchecked" })
    public IRankingPredictionBatch predict(final ILabelRankingInstance[] iLabelRankingInstanceArr) throws PredictionException, InterruptedException {
        // Raw list kept: RankingPredictionBatch's constructor element type is not visible here.
        List predictions = new ArrayList(iLabelRankingInstanceArr.length);
        for (ILabelRankingInstance instance : iLabelRankingInstanceArr) {
            predictions.add(predict(instance));
        }
        return new RankingPredictionBatch(predictions);
    }
}
