package ai.libs.jaicore.ml.ranking.dyad.learner.activelearning;

import ai.libs.jaicore.ml.ranking.dyad.dataset.DyadRankingDataset;
import ai.libs.jaicore.ml.ranking.dyad.dataset.SparseDyadRankingInstance;
import ai.libs.jaicore.ml.ranking.dyad.learner.algorithm.PLNetDyadRanker;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
import org.api4.java.ai.ml.core.exception.TrainingException;
import org.api4.java.ai.ml.ranking.dyad.dataset.IDyad;
import org.api4.java.ai.ml.ranking.dyad.dataset.IDyadRankingDataset;
import org.api4.java.common.math.IVector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/libs/jaicore/ml/ranking/dyad/learner/activelearning/ARandomlyInitializingDyadRanker.class */
/**
 * Active dyad ranker that starts with a phase of random queries.
 *
 * <p>For the first {@code numberRandomQueriesAtStart} iterations of {@link #activelyTrain(int)}, each
 * iteration draws a minibatch of random pairwise queries. Each query uses two dyads of one randomly
 * chosen instance. The minibatch is answered by the pool provider and used to update the ranker.
 * Every later iteration is delegated to {@link #activelyTrainWithOneInstance()}, which subclasses
 * implement.
 *
 * <p>Each call to {@link #updateRanker(DyadRankingDataset)} records, per dyad in the pool, the skill
 * the freshly fitted ranker assigns to it. These values are kept in a {@link SummaryStatistics}
 * object that subclasses can read via {@link #getDyadStats()}.
 */
public abstract class ARandomlyInitializingDyadRanker extends ActiveDyadRanker {

    private static final Logger LOGGER = LoggerFactory.getLogger(ARandomlyInitializingDyadRanker.class);

    private final int numberRandomQueriesAtStart;
    private final Map<IDyad, SummaryStatistics> dyadStats;
    private final List<IVector> instanceFeatures;
    private final Random random;
    private final int minibatchSize;
    private int iteration;

    /**
     * Creates the ranker and registers an empty statistics object for every dyad in the pool.
     *
     * @param ranker the dyad ranker to be trained
     * @param poolProvider the pool that supplies instances, dyads and query answers
     * @param seed seed of the random number generator used to sample the random queries
     * @param numberRandomQueriesAtStart number of initial iterations that use random minibatches
     * @param minibatchSize number of random pairwise queries drawn per random-phase iteration
     */
    public ARandomlyInitializingDyadRanker(PLNetDyadRanker ranker, IDyadRankingPoolProvider poolProvider, int seed, int numberRandomQueriesAtStart, int minibatchSize) {
        super(ranker, poolProvider);
        this.dyadStats = new HashMap<>();
        this.instanceFeatures = new ArrayList<>(poolProvider.getInstanceFeatures());
        this.numberRandomQueriesAtStart = numberRandomQueriesAtStart;
        this.minibatchSize = minibatchSize;
        this.iteration = 0;
        for (IVector instance : this.instanceFeatures) {
            for (IDyad dyad : poolProvider.getDyadsByInstance(instance)) {
                this.dyadStats.put(dyad, new SummaryStatistics());
            }
        }
        this.random = new Random(seed);
    }

    /**
     * Runs {@code numberOfQueries} training iterations.
     *
     * <p>While fewer than {@code numberRandomQueriesAtStart} iterations have been performed, an
     * iteration trains on a random minibatch. After that, it calls
     * {@link #activelyTrainWithOneInstance()}.
     *
     * @param numberOfQueries number of iterations to perform
     * @throws TrainingException if {@link #activelyTrainWithOneInstance()} fails; failures while
     *     updating on a random minibatch are logged instead of thrown
     * @throws InterruptedException if training is interrupted
     */
    @Override // ai.libs.jaicore.ml.ranking.dyad.learner.activelearning.ActiveDyadRanker
    public void activelyTrain(int numberOfQueries) throws TrainingException, InterruptedException {
        for (int query = 0; query < numberOfQueries; query++) {
            if (this.iteration < this.numberRandomQueriesAtStart) {
                this.trainOnRandomMinibatch();
            } else {
                this.activelyTrainWithOneInstance();
            }
            this.iteration++;
        }
    }

    /**
     * Builds up to {@code minibatchSize} random pairwise queries, asks the pool provider for their
     * answers, and updates the ranker with the resulting minibatch.
     */
    private void trainOnRandomMinibatch() throws InterruptedException {
        DyadRankingDataset minibatch = new DyadRankingDataset();
        if (!this.instanceFeatures.isEmpty()) {
            for (int slot = 0; slot < this.minibatchSize; slot++) {
                // Shuffling permanently reorders instanceFeatures; the first entry is the sampled instance.
                Collections.shuffle(this.instanceFeatures, this.random);
                List<IDyad> dyads = new ArrayList<>(this.poolProvider.getDyadsByInstance(this.instanceFeatures.get(0)));
                if (dyads.size() < 2) {
                    // A pairwise query needs two alternatives; previously this crashed with an IndexOutOfBoundsException.
                    LOGGER.warn("Skipping a random query because the sampled instance has only {} dyad(s) in the pool.", dyads.size());
                    continue;
                }
                Collections.shuffle(dyads, this.random);
                IDyad first = dyads.get(0);
                IDyad second = dyads.get(1);
                List<IVector> alternatives = new ArrayList<>(2);
                alternatives.add(first.getAlternative());
                alternatives.add(second.getAlternative());
                // NOTE(review): this cast mirrors the compiled code; confirm it against the declared return type of query(...).
                minibatch.add((DyadRankingDataset) this.poolProvider.query(new SparseDyadRankingInstance(first.getContext(), alternatives)));
            }
        }
        try {
            this.updateRanker(minibatch);
        } catch (TrainingException e) {
            LOGGER.error("Updating the dyad ranking learner did not succeed.", e);
        }
    }

    /** @return the number of initial iterations that train on random minibatches */
    public int getNumberRandomQueriesAtStart() {
        return this.numberRandomQueriesAtStart;
    }

    /** @return the number of iterations performed so far by {@link #activelyTrain(int)} */
    public int getIteration() {
        return this.iteration;
    }

    /**
     * @return the live map from each dyad to the statistics of the skills recorded by
     *     {@link #updateRanker(DyadRankingDataset)}
     */
    public Map<IDyad, SummaryStatistics> getDyadStats() {
        return this.dyadStats;
    }

    /**
     * @return the live, mutable list of instance feature vectors; the random phase reorders it by
     *     shuffling
     */
    public List<IVector> getInstanceFeatures() {
        return this.instanceFeatures;
    }

    /** @return the seeded random number generator shared by the random phase and subclasses */
    public Random getRandom() {
        return this.random;
    }

    /** @return the number of random pairwise queries drawn per random-phase iteration */
    public int getMinibatchSize() {
        return this.minibatchSize;
    }

    /**
     * Performs one query of the non-random phase. {@link #activelyTrain(int)} calls this once the
     * random phase is over.
     *
     * @throws TrainingException if training the ranker fails
     * @throws InterruptedException if training is interrupted
     */
    @Override // ai.libs.jaicore.ml.ranking.dyad.learner.activelearning.ActiveDyadRanker
    public abstract void activelyTrainWithOneInstance() throws TrainingException, InterruptedException;

    /**
     * Fits the ranker on the given dataset, then records the skill the ranker assigns to every dyad
     * of every instance in {@link #getInstanceFeatures()}.
     *
     * @param dyadRankingDataset the dataset to fit the ranker on
     * @throws TrainingException if fitting the ranker fails
     * @throws InterruptedException if fitting is interrupted
     */
    public void updateRanker(DyadRankingDataset dyadRankingDataset) throws TrainingException, InterruptedException {
        this.ranker.fit((IDyadRankingDataset) dyadRankingDataset);
        for (IVector instance : this.getInstanceFeatures()) {
            for (IDyad dyad : this.poolProvider.getDyadsByInstance(instance)) {
                // computeIfAbsent guards against dyads that were not in the pool at construction time.
                this.dyadStats.computeIfAbsent(dyad, unseen -> new SummaryStatistics()).addValue(this.ranker.getSkillForDyad(dyad));
            }
        }
    }
}
