package ai.libs.jaicore.ml.ranking.dyad.learner.activelearning;

import ai.libs.jaicore.ml.ranking.dyad.dataset.DenseDyadRankingInstance;
import ai.libs.jaicore.ml.ranking.dyad.dataset.DyadRankingDataset;
import ai.libs.jaicore.ml.ranking.dyad.dataset.SparseDyadRankingInstance;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import org.api4.java.ai.ml.ranking.dyad.dataset.IDyad;
import org.api4.java.ai.ml.ranking.dyad.dataset.IDyadRankingDataset;
import org.api4.java.ai.ml.ranking.dyad.dataset.IDyadRankingInstance;
import org.api4.java.common.math.IVector;
import org.nd4j.linalg.primitives.Pair;

/* loaded from: input_file:ai/libs/jaicore/ml/ranking/dyad/learner/activelearning/DyadDatasetPoolProvider.class */
/**
 * Pool provider for active learning of dyad rankings, backed by a fixed dyad ranking dataset.
 * <p>
 * Every ranking of the dataset is kept in the pool. Its dyads are indexed by context (instance
 * features) and by alternative. {@link #query(IDyadRankingInstance)} acts as the oracle: it orders
 * the dyads of a sparse query by their position in the stored ranking that shares their context.
 */
public class DyadDatasetPoolProvider implements IDyadRankingPoolProvider {

    /** All rankings this provider was constructed with, in dataset iteration order. */
    private final List<IDyadRankingInstance> pool;

    /**
     * Answers returned by {@link #query(IDyadRankingInstance)}. Insertion-ordered so that
     * {@link #getQueriedRankings()} is reproducible across runs.
     */
    private final HashSet<IDyadRankingInstance> queriedRankings = new LinkedHashSet<>();

    /** Number of successfully answered queries. */
    private int numberQueries = 0;

    /** Whether dyads of an answered query are removed from the dyad indices. */
    private boolean removeDyadsWhenQueried = false;

    /** Dyads still in the pool, grouped by their context vector. */
    private final HashMap<IVector, Set<IDyad>> dyadsByInstances = new HashMap<>();

    /** Dyads still in the pool, grouped by their alternative vector. */
    private final HashMap<IVector, Set<IDyad>> dyadsByAlternatives = new HashMap<>();

    /**
     * Ranking per context of its first-ranked dyad. A later ranking with the same context replaces
     * an earlier one.
     */
    private final HashMap<IVector, IDyadRankingInstance> dyadRankingsByInstances = new HashMap<>();

    /**
     * Ranking per alternative of its first-ranked dyad. A later ranking with the same key replaces
     * an earlier one. NOTE(review): written but not read anywhere in this class.
     */
    private final HashMap<IVector, IDyadRankingInstance> dyadRankingsByAlternatives = new HashMap<>();

    /**
     * Creates a pool containing every ranking of the given dataset.
     *
     * @param iDyadRankingDataset the dataset whose rankings form the pool and serve as the oracle
     */
    public DyadDatasetPoolProvider(IDyadRankingDataset iDyadRankingDataset) {
        this.pool = new ArrayList<>(iDyadRankingDataset.size());
        for (IDyadRankingInstance instance : iDyadRankingDataset) {
            addDyadRankingInstance(instance);
        }
    }

    /**
     * Returns the pool of rankings.
     *
     * @return the internal (mutable) list of all rankings this provider was built from
     */
    public Collection<IDyadRankingInstance> getPool() {
        return this.pool;
    }

    /**
     * Answers a query by ranking its dyads according to the stored ranking for their context.
     * <p>
     * Each dyad gets the index at which it appears in the stored ranking for its context. Dyads
     * absent from that ranking get the ranking's length. Dyads with an unknown context get -1. The
     * dyads are then stably sorted by ascending index. If removal is enabled, the answered dyads are
     * removed from the dyad indices.
     *
     * @param iDyadRankingInstance the query; must be a {@link SparseDyadRankingInstance}
     * @return a dense ranking of the query's dyads
     * @throws IllegalArgumentException if the query is not a {@link SparseDyadRankingInstance}
     */
    public IDyadRankingInstance query(IDyadRankingInstance iDyadRankingInstance) {
        if (!(iDyadRankingInstance instanceof SparseDyadRankingInstance)) {
            throw new IllegalArgumentException("Currently only supports SparseDyadRankingInstances!");
        }
        // Count only queries that are actually answered; rejected arguments are not oracle calls.
        this.numberQueries++;
        SparseDyadRankingInstance sparseInstance = (SparseDyadRankingInstance) iDyadRankingInstance;

        // Compute each position once up front instead of inside the comparator.
        List<Pair<IDyad, Integer>> dyadsWithPosition = new ArrayList<>(sparseInstance.getNumberOfRankedElements());
        for (IDyad dyad : sparseInstance) {
            dyadsWithPosition.add(new Pair<>(dyad, getPositionInRankingByInstanceFeatures(dyad)));
        }
        // List.sort is stable, so dyads with equal positions keep their query order.
        dyadsWithPosition.sort(Comparator.comparingInt((Pair<IDyad, Integer> p) -> p.getRight()));

        List<IDyad> ranking = new ArrayList<>(dyadsWithPosition.size());
        for (Pair<IDyad, Integer> pair : dyadsWithPosition) {
            ranking.add(pair.getFirst());
        }
        DenseDyadRankingInstance answer = new DenseDyadRankingInstance(ranking);

        if (this.removeDyadsWhenQueried) {
            for (IDyad dyad : ranking) {
                removeDyadFromPool(dyad);
            }
        }
        this.queriedRankings.add(answer);
        return answer;
    }

    /**
     * Returns the dyads still in the pool that have the given context.
     *
     * @param iVector the context (instance feature) vector
     * @return the internal set for that context, or a new empty set if none exists
     */
    @Override // ai.libs.jaicore.ml.ranking.dyad.learner.activelearning.IDyadRankingPoolProvider
    public Set<IDyad> getDyadsByInstance(IVector iVector) {
        return this.dyadsByInstances.getOrDefault(iVector, new HashSet<>());
    }

    /**
     * Returns the dyads still in the pool that have the given alternative.
     *
     * @param iVector the alternative vector
     * @return the internal set for that alternative, or a new empty set if none exists
     */
    @Override // ai.libs.jaicore.ml.ranking.dyad.learner.activelearning.IDyadRankingPoolProvider
    public Set<IDyad> getDyadsByAlternative(IVector iVector) {
        return this.dyadsByAlternatives.getOrDefault(iVector, new HashSet<>());
    }

    /**
     * Adds a ranking to the pool and indexes its dyads by context and by alternative.
     * <p>
     * An empty ranking is added to the pool but not indexed, because it has no first dyad to key on.
     *
     * @param instance the ranking to add
     */
    private void addDyadRankingInstance(IDyadRankingInstance instance) {
        this.pool.add(instance);
        if (instance.getNumberOfRankedElements() == 0) {
            return;
        }
        IDyad topDyad = (IDyad) instance.getLabel().get(0);
        this.dyadRankingsByInstances.put(topDyad.getContext(), instance);
        this.dyadRankingsByAlternatives.put(topDyad.getAlternative(), instance);
        for (IDyad dyad : instance) {
            this.dyadsByInstances.computeIfAbsent(dyad.getContext(), key -> new HashSet<>()).add(dyad);
            this.dyadsByAlternatives.computeIfAbsent(dyad.getAlternative(), key -> new HashSet<>()).add(dyad);
        }
    }

    /**
     * Looks up the position of a dyad in the stored ranking for its context.
     *
     * @param iDyad the dyad to locate
     * @return the dyad's 0-based index in that ranking; the ranking's length if the dyad does not
     *     occur in it; -1 if no ranking is stored for the dyad's context
     */
    private int getPositionInRankingByInstanceFeatures(IDyad iDyad) {
        IDyadRankingInstance ranking = this.dyadRankingsByInstances.get(iDyad.getContext());
        if (ranking == null) {
            return -1;
        }
        int length = ranking.getNumberOfRankedElements();
        for (int position = 0; position < length; position++) {
            if (((IDyad) ranking.getLabel().get(position)).equals(iDyad)) {
                return position;
            }
        }
        return length;
    }

    /**
     * Returns the contexts that still have dyads in the pool.
     *
     * @return a live key-set view of the context index
     */
    @Override // ai.libs.jaicore.ml.ranking.dyad.learner.activelearning.IDyadRankingPoolProvider
    public Collection<IVector> getInstanceFeatures() {
        return this.dyadsByInstances.keySet();
    }

    /**
     * Removes a dyad from both the context index and the alternative index.
     *
     * @param iDyad the dyad to remove
     */
    private void removeDyadFromPool(IDyad iDyad) {
        removeFromIndex(this.dyadsByInstances, iDyad.getContext(), iDyad);
        removeFromIndex(this.dyadsByAlternatives, iDyad.getAlternative(), iDyad);
    }

    /**
     * Removes {@code dyad} from the set stored under {@code key}. The whole entry is dropped once
     * fewer than two dyads remain (presumably because a single dyad cannot form a ranking).
     *
     * @param index the index to update
     * @param key the context or alternative vector the dyad is filed under
     * @param dyad the dyad to remove
     */
    private static void removeFromIndex(HashMap<IVector, Set<IDyad>> index, IVector key, IDyad dyad) {
        Set<IDyad> dyads = index.get(key);
        if (dyads == null) {
            return;
        }
        dyads.remove(dyad);
        if (dyads.size() < 2) {
            index.remove(key);
        }
    }

    /**
     * Enables or disables removing answered dyads from the pool.
     *
     * @param z {@code true} to remove a query's dyads from the indices once it has been answered
     */
    @Override // ai.libs.jaicore.ml.ranking.dyad.learner.activelearning.IDyadRankingPoolProvider
    public void setRemoveDyadsWhenQueried(boolean z) {
        this.removeDyadsWhenQueried = z;
    }

    /**
     * Returns the number of dyads still in the pool.
     *
     * @return the total number of dyads across all context groups
     */
    @Override // ai.libs.jaicore.ml.ranking.dyad.learner.activelearning.IDyadRankingPoolProvider
    public int getPoolSize() {
        return this.dyadsByInstances.values().stream().mapToInt(Set::size).sum();
    }

    /**
     * Returns how many queries have been answered.
     *
     * @return the number of successfully answered queries
     */
    public int getNumberQueries() {
        return this.numberQueries;
    }

    /**
     * Returns the distinct answers given so far.
     *
     * @return a new dataset containing every answered ranking, in the order first answered
     */
    @Override // ai.libs.jaicore.ml.ranking.dyad.learner.activelearning.IDyadRankingPoolProvider
    public DyadRankingDataset getQueriedRankings() {
        List<IDyadRankingInstance> snapshot = new ArrayList<>(this.queriedRankings);
        return new DyadRankingDataset(snapshot);
    }
}
