package ai.libs.jaicore.ml.ranking.dyad.learner.activelearning;

import ai.libs.jaicore.ml.ranking.dyad.dataset.DyadRankingDataset;
import ai.libs.jaicore.ml.ranking.dyad.dataset.SparseDyadRankingInstance;
import ai.libs.jaicore.ml.ranking.dyad.learner.algorithm.PLNetDyadRanker;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import org.api4.java.ai.ml.core.exception.TrainingException;
import org.api4.java.ai.ml.ranking.dyad.dataset.IDyad;
import org.nd4j.linalg.primitives.Pair;

/**
 * Pool-based active dyad ranker that selects its pairwise queries by an upper-confidence-bound (UCB) score.
 *
 * <p>For every element of a minibatch, one problem instance is picked at random (via {@code getRandom()}) from
 * {@code getInstanceFeatures()}. Every dyad the pool provider offers for that instance is scored with
 * {@code skill + standardDeviation}. The skill comes from the underlying {@link PLNetDyadRanker}; the standard
 * deviation comes from {@code getDyadStats()}. The two highest-scoring dyads form a pairwise query. The pool
 * provider's answer to that query is added to the minibatch, and the ranker is updated on the full minibatch.
 */
public class UCBPoolBasedActiveDyadRanker extends ARandomlyInitializingDyadRanker {

    /**
     * Creates a new UCB-based active dyad ranker.
     *
     * <p>All arguments are forwarded unchanged to {@link ARandomlyInitializingDyadRanker}. See that class for the
     * meaning of the three integer settings.
     *
     * @param pLNetDyadRanker the dyad ranker that is trained actively
     * @param iDyadRankingPoolProvider the pool that supplies candidate dyads and answers ranking queries
     * @param i first integer setting, forwarded to the superclass
     * @param i2 second integer setting, forwarded to the superclass
     * @param i3 third integer setting, forwarded to the superclass
     */
    public UCBPoolBasedActiveDyadRanker(PLNetDyadRanker pLNetDyadRanker, IDyadRankingPoolProvider iDyadRankingPoolProvider, int i, int i2, int i3) {
        super(pLNetDyadRanker, iDyadRankingPoolProvider, i, i2, i3);
    }

    /**
     * Performs one active-learning step made of {@code getMinibatchSize()} UCB-selected pairwise queries.
     *
     * @throws TrainingException if updating the ranker on the collected minibatch fails
     * @throws InterruptedException if the thread is interrupted (propagated from a callee)
     * @throws IllegalStateException if the pool offers fewer than two dyads for a sampled instance, because a
     *         pairwise query needs two distinct alternatives
     */
    @Override
    public void activelyTrainWithOneInstance() throws TrainingException, InterruptedException {
        DyadRankingDataset minibatch = new DyadRankingDataset();
        for (int sample = 0; sample < getMinibatchSize(); sample++) {
            int instanceIndex = getRandom().nextInt(getInstanceFeatures().size());
            List<IDyad> candidates = new ArrayList<>(this.poolProvider.getDyadsByInstance(getInstanceFeatures().get(instanceIndex)));
            if (candidates.size() < 2) {
                // Without this check the code below would fail with an uninformative IndexOutOfBoundsException.
                throw new IllegalStateException("Cannot build a pairwise query: the pool offers " + candidates.size()
                        + " dyad(s) for the sampled instance, but at least 2 are required.");
            }

            List<Pair<IDyad, Double>> ranked = sortByUpperConfidenceBound(candidates);
            IDyad best = ranked.get(0).getFirst();
            IDyad secondBest = ranked.get(1).getFirst();

            // Both dyads are taken from the same instance's pool, so the best dyad's context serves as the
            // query context.
            // NOTE(review): the cast below was carried over from the decompiled source. Confirm it matches the
            // type actually returned by poolProvider.query.
            minibatch.add((DyadRankingDataset) this.poolProvider.query(
                    new SparseDyadRankingInstance(best.getContext(), pairOf(best.getAlternative(), secondBest.getAlternative()))));
        }
        updateRanker(minibatch);
    }

    /**
     * Scores each dyad by its UCB value ({@code skill + standard deviation}) and sorts the result by descending
     * score.
     *
     * <p>The sort orders by the negated score ascending, using {@link Double#compare} semantics. As a result, NaN
     * scores end up last, and equal scores keep the order in which the pool returned them (the sort is stable).
     *
     * @param dyads the candidate dyads to score
     * @return a new mutable list of (dyad, UCB score) pairs, highest score first
     */
    private List<Pair<IDyad, Double>> sortByUpperConfidenceBound(List<IDyad> dyads) {
        List<Pair<IDyad, Double>> scored = new ArrayList<>(dyads.size());
        for (IDyad dyad : dyads) {
            double upperConfidenceBound = this.ranker.getSkillForDyad(dyad) + getDyadStats().get(dyad).getStandardDeviation();
            scored.add(new Pair<>(dyad, upperConfidenceBound));
        }
        Collections.sort(scored, Comparator.comparingDouble((Pair<IDyad, Double> scoredDyad) -> -scoredDyad.getRight()));
        return scored;
    }

    /**
     * Returns a new mutable two-element list holding {@code first} followed by {@code second}.
     *
     * @param first the element at index 0
     * @param second the element at index 1
     * @param <T> the element type
     * @return a new {@link ArrayList} of size 2
     */
    private static <T> List<T> pairOf(T first, T second) {
        List<T> pair = new ArrayList<>(2);
        pair.add(first);
        pair.add(second);
        return pair;
    }
}
