package ai.libs.jaicore.ml.ranking.dyad.learner.activelearning;

import ai.libs.jaicore.ml.ranking.dyad.dataset.DyadRankingDataset;
import ai.libs.jaicore.ml.ranking.dyad.dataset.SparseDyadRankingInstance;
import ai.libs.jaicore.ml.ranking.dyad.learner.algorithm.PLNetDyadRanker;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import org.api4.java.ai.ml.core.exception.TrainingException;
import org.api4.java.ai.ml.ranking.dyad.dataset.IDyad;
import org.api4.java.ai.ml.ranking.dyad.dataset.IDyadRankingInstance;
import org.api4.java.common.math.IVector;
import org.nd4j.linalg.primitives.Pair;

/* loaded from: input_file:ai/libs/jaicore/ml/ranking/dyad/learner/activelearning/PrototypicalPoolBasedActiveDyadRanker.class */
/**
 * Active dyad ranker that builds each training minibatch from two sources:
 * freshly queried pairs of dyads for randomly chosen pool instances, and a
 * random sample of rankings that were already queried earlier.
 *
 * <p>For every newly chosen instance the wrapped ranker is asked for the pair
 * of dyads it is least certain about; that pair is sent to the pool provider
 * and the answer is used for training.
 */
public class PrototypicalPoolBasedActiveDyadRanker extends ARandomlyInitializingDyadRanker {
    /** Every ranking returned by the pool provider so far; the source for replayed "old" instances. */
    private final List<IDyadRankingInstance> seenInstances;

    /** Fraction of a minibatch that should be filled with previously seen rankings. */
    private double ratioOfOldInstancesForMinibatch;

    /** Stored configuration value; it is only exposed through its accessors and not read by the training step in this class. */
    private int lengthOfTopRankingToConsider;

    /**
     * Creates the ranker.
     *
     * @param ranker the PL-Net ranker that is trained actively
     * @param poolProvider provider of the instance pool and of the answers to queries
     * @param minibatchSize forwarded to the superclass constructor as its last argument
     *        (NOTE(review): presumably the minibatch size; confirm against
     *        {@code ARandomlyInitializingDyadRanker})
     * @param lengthOfTopRankingToConsider value stored in this object and returned by
     *        {@link #getLengthOfTopRankingToConsider()}
     * @param ratioOfOldInstancesForMinibatch fraction of each minibatch to fill with
     *        previously seen rankings
     * @param numberOfRandomQueriesAtStart forwarded to the superclass constructor as its
     *        fourth argument (NOTE(review): presumably the number of initial random
     *        queries; confirm against the superclass)
     * @param seed forwarded to the superclass constructor as its third argument
     *        (NOTE(review): presumably a random seed; confirm against the superclass)
     */
    public PrototypicalPoolBasedActiveDyadRanker(PLNetDyadRanker ranker, IDyadRankingPoolProvider poolProvider, int minibatchSize, int lengthOfTopRankingToConsider, double ratioOfOldInstancesForMinibatch, int numberOfRandomQueriesAtStart, int seed) {
        super(ranker, poolProvider, seed, numberOfRandomQueriesAtStart, minibatchSize);
        this.seenInstances = new ArrayList<>(poolProvider.getPool().size());
        this.ratioOfOldInstancesForMinibatch = ratioOfOldInstancesForMinibatch;
        this.lengthOfTopRankingToConsider = lengthOfTopRankingToConsider;
    }

    /**
     * Performs one active training step: queries the least certain pair of dyads for a
     * number of randomly chosen pool instances, mixes in previously seen rankings and
     * updates the ranker with the resulting minibatch.
     *
     * @throws TrainingException declared by the overridden method; raised by the calls made here
     * @throws InterruptedException declared by the overridden method; raised by the calls made here
     */
    @Override
    public void activelyTrainWithOneInstance() throws TrainingException, InterruptedException {
        int minibatchSize = getMinibatchSize();

        // Random order of all pool instances; the first few become this step's queries.
        List<IVector> candidateInstances = new ArrayList<>();
        for (IVector instanceFeatures : this.poolProvider.getInstanceFeatures()) {
            candidateInstances.add(instanceFeatures);
        }
        Collections.shuffle(candidateInstances);

        // Clamp to [0, minibatchSize] so an out-of-range ratio can neither produce a
        // negative sub-list bound nor a minibatch larger than the configured size.
        int requestedOldInstances = (int) (this.ratioOfOldInstancesForMinibatch * minibatchSize);
        int numberOfOldInstances = Math.max(0, Math.min(Math.min(requestedOldInstances, minibatchSize), this.seenInstances.size()));

        // Never ask for more new instances than the pool actually offers.
        int numberOfNewInstances = Math.min(minibatchSize - numberOfOldInstances, candidateInstances.size());

        DyadRankingDataset minibatch = new DyadRankingDataset();
        for (IVector instanceFeatures : candidateInstances.subList(0, numberOfNewInstances)) {
            List<IDyad> dyads = new ArrayList<>(this.poolProvider.getDyadsByInstance(instanceFeatures));
            if (dyads.size() < 2) {
                // A pair cannot be formed. This ends querying for the whole step (it does
                // not merely skip the instance), matching the established behavior.
                break;
            }

            IVector context = dyads.get(0).getContext();
            SparseDyadRankingInstance queryInstance = new SparseDyadRankingInstance(context, collectAlternatives(dyads, dyads.size()));
            IDyadRankingInstance leastCertainPair = this.ranker.getPairWithLeastCertainty(queryInstance);

            List<IVector> pairAlternatives = collectAlternatives(leastCertainPair, leastCertainPair.getNumberOfRankedElements());
            IVector pairContext = leastCertainPair.getLabel().get(0).getContext();
            IDyadRankingInstance answer = this.poolProvider.query(new SparseDyadRankingInstance(pairContext, pairAlternatives));

            this.seenInstances.add(answer);
            minibatch.add(answer);
        }

        // NOTE(review): rankings queried just above are already part of seenInstances,
        // so they may be drawn again here as "old" instances.
        Collections.shuffle(this.seenInstances);
        minibatch.addAll(this.seenInstances.subList(0, numberOfOldInstances));

        updateRanker(minibatch);
    }

    /** Collects the alternative of every dyad, in iteration order, into a new list. */
    private static List<IVector> collectAlternatives(Iterable<IDyad> dyads, int expectedSize) {
        List<IVector> alternatives = new ArrayList<>(expectedSize);
        for (IDyad dyad : dyads) {
            alternatives.add(dyad.getAlternative());
        }
        return alternatives;
    }

    /**
     * @return the fraction of each minibatch that is filled with previously seen rankings
     */
    public double getRatioOfOldInstancesForMinibatch() {
        return this.ratioOfOldInstancesForMinibatch;
    }

    /**
     * @param d the fraction of each minibatch to fill with previously seen rankings;
     *        the resulting count is clamped to the range from zero to the minibatch
     *        size during training
     */
    public void setRatioOfOldInstancesForMinibatch(double d) {
        this.ratioOfOldInstancesForMinibatch = d;
    }

    /**
     * @return the stored length of the top ranking to consider
     */
    public int getLengthOfTopRankingToConsider() {
        return this.lengthOfTopRankingToConsider;
    }

    /**
     * @param i the new length of the top ranking to consider
     */
    public void setLengthOfTopRankingToConsider(int i) {
        this.lengthOfTopRankingToConsider = i;
    }
}
