package ai.libs.jaicore.ml.ranking.dyad.learner.search;

import ai.libs.jaicore.ml.ranking.dyad.dataset.DenseDyadRankingInstance;
import ai.libs.jaicore.ml.ranking.dyad.dataset.DyadRankingDataset;
import ai.libs.jaicore.ml.ranking.dyad.learner.Dyad;
import ai.libs.jaicore.ml.ranking.dyad.learner.algorithm.IDyadRanker;
import ai.libs.jaicore.ml.ranking.dyad.learner.util.AbstractDyadScaler;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import java.lang.Comparable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Queue;
import org.api4.java.ai.graphsearch.problem.pathsearch.pathevaluation.IEvaluatedPath;
import org.api4.java.ai.ml.core.exception.PredictionException;
import org.api4.java.ai.ml.ranking.IRanking;
import org.api4.java.ai.ml.ranking.dyad.dataset.IDyad;
import org.api4.java.common.math.IVector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/libs/jaicore/ml/ranking/dyad/learner/search/ADyadRankedNodeQueue.class */
/**
 * A {@link Queue} of evaluated search paths (used as the OPEN list of a search) whose order is
 * determined by a dyad ranker.
 *
 * <p>Every queued node is represented by a dyad {@code (contextCharacterization, characterize(node))}.
 * Whenever a node is added, the ranker ranks the dyads of all queued nodes and the queue is rebuilt
 * in the order of that ranking, so the head of the queue is the first element of the latest
 * predicted ranking.
 *
 * <p>Optionally an {@link AbstractDyadScaler} is applied to the context characterization (on
 * construction or in {@link #setScaler(AbstractDyadScaler)}) and to every new node characterization.
 *
 * <p>This class performs no synchronization.
 *
 * @param <N> node type parameter of the queued {@link IEvaluatedPath}s
 * @param <V> evaluation type of the queued {@link IEvaluatedPath}s
 */
public abstract class ADyadRankedNodeQueue<N, V extends Comparable<V>> implements Queue<IEvaluatedPath<N, ?, V>> {
    private final Logger logger = LoggerFactory.getLogger(getClass());

    /** Ranker used to order the queue; must be non-null before the first {@link #add}. */
    private IDyadRanker dyadRanker;

    /** Optional scaler; only applied while {@link #useScaler} is true. */
    protected AbstractDyadScaler scaler;

    /** True iff a non-null scaler was configured via the constructor or {@link #setScaler}. */
    private boolean useScaler;

    /** Queued nodes, in the order of the last predicted ranking. */
    private final List<IEvaluatedPath<N, ?, V>> queue = new ArrayList<>();

    /** Characterizations of the queued nodes (kept in sync with {@link #queue}). */
    private final List<IVector> nodeCharacterizations = new ArrayList<>();

    /** Unscaled context; a fresh copy of it is scaled whenever a new scaler is set. */
    private final IVector originalContextCharacterization;

    /** Copy of the original context (possibly scaled), used as the instance of every new dyad. */
    private IVector contextCharacterization;

    /** One dyad per queued node; this list is what the ranker is asked to rank. */
    private final List<IDyad> queryDyads = new ArrayList<>();

    /**
     * Node <-> characterization mapping, used to map ranked dyads back to nodes. Note that a Guava
     * {@link BiMap} rejects a value that is already bound to another key, so two nodes with equal
     * characterizations cannot be queued at the same time.
     */
    private final BiMap<IEvaluatedPath<N, ?, V>, IVector> nodesAndCharacterizationsMap = HashBiMap.create();

    /**
     * Creates a queue without ranker and scaler. {@link #add} calls the ranker, so one must be set
     * via {@link #setDyadRanker(IDyadRanker)} before nodes are added.
     *
     * @param iVector characterization of the context; a copy of it is used in every dyad
     */
    public ADyadRankedNodeQueue(IVector iVector) {
        this.contextCharacterization = iVector.addConstantToCopy(0.0d);
        this.originalContextCharacterization = iVector;
        this.logger.trace("Construct ADyadNodeQueue with contexcharacterization {}", iVector);
    }

    /**
     * Creates a queue with the given ranker and optional scaler.
     *
     * @param iVector characterization of the context; a copy of it is used in every dyad
     * @param iDyadRanker ranker used to order the queue
     * @param abstractDyadScaler scaler applied to the context and to node characterizations, or
     *        {@code null} for no scaling
     */
    public ADyadRankedNodeQueue(IVector iVector, IDyadRanker iDyadRanker, AbstractDyadScaler abstractDyadScaler) {
        this(iVector);
        this.dyadRanker = iDyadRanker;
        this.scaler = abstractDyadScaler;
        if (abstractDyadScaler != null) {
            this.useScaler = true;
            transformContextCharacterization();
        }
    }

    /**
     * Computes the characterization of a node, used as the alternative of its dyad.
     *
     * <p>The returned vector is modified afterwards: NaN entries are replaced by 0 and, when a
     * scaler is configured, it is passed to the scaler's alternative transformation (presumably
     * rescaled in place — confirm against {@link AbstractDyadScaler}). Implementations should
     * therefore return a fresh vector per call.
     *
     * @param iEvaluatedPath the node to characterize
     * @return the node's characterization
     */
    protected abstract IVector characterize(IEvaluatedPath<N, ?, V> iEvaluatedPath);

    @Override
    public int size() {
        return this.queue.size();
    }

    @Override
    public boolean isEmpty() {
        return this.queue.isEmpty();
    }

    @Override
    public boolean contains(Object obj) {
        return this.queue.contains(obj);
    }

    /**
     * Returns an iterator over the queued nodes in ranking order. Unlike the backing list's iterator,
     * {@link Iterator#remove()} goes through {@link #removeNodeAtPosition(int)} so that the
     * characterization bookkeeping stays consistent.
     */
    @Override
    public Iterator<IEvaluatedPath<N, ?, V>> iterator() {
        return new Iterator<IEvaluatedPath<N, ?, V>>() {
            private int cursor = 0;
            private int lastReturned = -1;

            @Override
            public boolean hasNext() {
                return this.cursor < ADyadRankedNodeQueue.this.queue.size();
            }

            @Override
            public IEvaluatedPath<N, ?, V> next() {
                if (!hasNext()) {
                    throw new NoSuchElementException();
                }
                this.lastReturned = this.cursor;
                return ADyadRankedNodeQueue.this.queue.get(this.cursor++);
            }

            @Override
            public void remove() {
                if (this.lastReturned < 0) {
                    throw new IllegalStateException();
                }
                ADyadRankedNodeQueue.this.removeNodeAtPosition(this.lastReturned);
                // Elements after the removed one shifted left by one.
                this.cursor = this.lastReturned;
                this.lastReturned = -1;
            }
        };
    }

    @Override
    public Object[] toArray() {
        return this.queue.toArray();
    }

    @Override
    public <T> T[] toArray(T[] tArr) {
        return this.queue.toArray(tArr);
    }

    /**
     * Removes the given node, if queued, together with its characterization and dyad.
     *
     * @return whether the node was queued
     */
    @Override
    public boolean remove(Object obj) {
        if (!(obj instanceof IEvaluatedPath)) {
            return false;
        }
        int index = this.queue.indexOf(obj);
        if (index < 0) {
            return false;
        }
        removeNodeAtPosition(index);
        return true;
    }

    @Override
    public boolean containsAll(Collection<?> collection) {
        return this.queue.containsAll(collection);
    }

    /**
     * Adds every node of the collection via {@link #add}; note that the ranker is queried once per
     * added node.
     *
     * @return whether at least one call to {@link #add} returned true
     */
    @Override
    public boolean addAll(Collection<? extends IEvaluatedPath<N, ?, V>> collection) {
        this.logger.trace("Add {} nodes", collection.size());
        boolean changed = false;
        for (IEvaluatedPath<N, ?, V> node : collection) {
            if (add(node)) {
                changed = true;
            }
        }
        return changed;
    }

    @Override
    public boolean removeAll(Collection<?> collection) {
        boolean changed = false;
        for (Object element : collection) {
            if (remove(element)) {
                changed = true;
            }
        }
        return changed;
    }

    /**
     * Removes every queued node not contained in the given collection.
     *
     * @return whether any node was removed
     */
    @Override
    public boolean retainAll(Collection<?> collection) {
        boolean changed = false;
        // Iterate backwards so that removals do not shift indices still to be visited.
        for (int i = this.queue.size() - 1; i >= 0; i--) {
            if (!collection.contains(this.queue.get(i))) {
                removeNodeAtPosition(i);
                changed = true;
            }
        }
        return changed;
    }

    /** Removes all nodes together with their characterizations and dyads. */
    @Override
    public void clear() {
        this.queue.clear();
        this.nodesAndCharacterizationsMap.clear();
        this.nodeCharacterizations.clear();
        // Without this, stale dyads would be ranked again on the next add.
        this.queryDyads.clear();
    }

    /**
     * Characterizes the node, adds its dyad and re-ranks the whole queue with the dyad ranker.
     *
     * <p>If ranking fails or the thread is interrupted, the node is not added and all bookkeeping
     * (characterizations, dyads, node mapping, queue order) is left as it was before the call. Any
     * runtime exception (e.g. no ranker set) is likewise rolled back and then propagated.
     *
     * @param iEvaluatedPath node to add
     * @return {@code false} for {@code null}, on a prediction failure or on interruption (in which
     *         case the interrupt flag is restored); {@code true} if the node was added or is already
     *         queued
     */
    @Override
    public boolean add(IEvaluatedPath<N, ?, V> iEvaluatedPath) {
        if (iEvaluatedPath == null) {
            return false;
        }
        if (this.queue.contains(iEvaluatedPath)) {
            return true;
        }
        final int characterizationCount = this.nodeCharacterizations.size();
        final int dyadCount = this.queryDyads.size();
        boolean committed = false;
        try {
            this.logger.debug("Add node to OPEN.");
            IVector characterization = characterize(iEvaluatedPath);
            this.nodeCharacterizations.add(characterization);
            Dyad dyad = new Dyad(this.contextCharacterization, characterization);
            this.queryDyads.add(dyad);
            if (this.useScaler) {
                // The dyad shares the characterization vector, so scaling affects the queried dyad.
                this.scaler.transformAlternatives(singletonDataset(dyad));
            }
            replaceNaNByZeroes(characterization);
            this.nodesAndCharacterizationsMap.put(iEvaluatedPath, characterization);
            IRanking predict = this.dyadRanker.predict(new DenseDyadRankingInstance(this.queryDyads));
            reorderQueue(predict);
            committed = true;
            return true;
        } catch (PredictionException e) {
            this.logger.warn("Failed to characterize: {}", e.getLocalizedMessage(), e);
            return false;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return false;
        } finally {
            if (!committed) {
                undoPartialAdd(iEvaluatedPath, characterizationCount, dyadCount);
            }
        }
    }

    /** Rebuilds {@link #queue} in the order of the predicted ranking; unknown dyads are skipped. */
    private void reorderQueue(IRanking predict) {
        IRanking ranking = (IRanking) predict.getPrediction();
        BiMap<IVector, IEvaluatedPath<N, ?, V>> nodesByCharacterization = this.nodesAndCharacterizationsMap.inverse();
        List<IEvaluatedPath<N, ?, V>> reordered = new ArrayList<>(ranking.size());
        for (int i = 0; i < ranking.size(); i++) {
            IEvaluatedPath<N, ?, V> node = nodesByCharacterization.get(((IDyad) ranking.get(i)).getAlternative());
            if (node != null) {
                reordered.add(node);
            } else {
                this.logger.warn("Got a node in a prediction that doesnt exist");
            }
        }
        // Swap only once the full order is known, so a failure above leaves the old order intact.
        this.queue.clear();
        this.queue.addAll(reordered);
    }

    /** Reverts the bookkeeping changes made by an {@link #add} call that did not complete. */
    private void undoPartialAdd(IEvaluatedPath<N, ?, V> node, int characterizationCount, int dyadCount) {
        truncate(this.nodeCharacterizations, characterizationCount);
        truncate(this.queryDyads, dyadCount);
        // No-op if the node was never put (the node was not queued before this add).
        this.nodesAndCharacterizationsMap.remove(node);
    }

    /** Shrinks the list to the given size by dropping trailing elements. */
    private static void truncate(List<?> list, int size) {
        if (list.size() > size) {
            list.subList(size, list.size()).clear();
        }
    }

    /** Wraps a single dyad into a dataset, the input format expected by the scaler. */
    private static DyadRankingDataset singletonDataset(Dyad dyad) {
        DyadRankingDataset dataset = new DyadRankingDataset();
        dataset.add(new DenseDyadRankingInstance(Arrays.<IDyad>asList(dyad)));
        return dataset;
    }

    /** Replaces every NaN entry of the given vector by 0, in place. */
    private void replaceNaNByZeroes(IVector iVector) {
        for (int i = 0; i < iVector.length(); i++) {
            if (Double.isNaN(iVector.getValue(i))) {
                iVector.setValue(i, 0.0d);
            }
        }
    }

    @Override
    public boolean offer(IEvaluatedPath<N, ?, V> iEvaluatedPath) {
        return add(iEvaluatedPath);
    }

    /**
     * Removes and returns the head of the queue.
     *
     * @throws NoSuchElementException if the queue is empty
     */
    @Override
    public IEvaluatedPath<N, ?, V> remove() {
        if (this.queue.isEmpty()) {
            throw new NoSuchElementException("Queue is empty");
        }
        return removeNodeAtPosition(0);
    }

    /**
     * Removes the node at the given position together with its characterization and dyad.
     *
     * @param i position in ranking order
     * @return the removed node
     * @throws IndexOutOfBoundsException if {@code i} is not a valid position
     */
    public IEvaluatedPath<N, ?, V> removeNodeAtPosition(int i) {
        IEvaluatedPath<N, ?, V> removed = this.queue.remove(i);
        this.logger.trace("Retrieve node from OPEN. Index: {}", i);
        IVector characterization = this.nodesAndCharacterizationsMap.remove(removed);
        this.nodeCharacterizations.remove(characterization);
        // Drop the first dyad whose alternative is this node's characterization (if any).
        for (Iterator<IDyad> it = this.queryDyads.iterator(); it.hasNext();) {
            if (it.next().getAlternative().equals(characterization)) {
                it.remove();
                break;
            }
        }
        return removed;
    }

    @Override
    public IEvaluatedPath<N, ?, V> poll() {
        return this.queue.isEmpty() ? null : remove();
    }

    /**
     * Returns, without removing, the head of the queue.
     *
     * @throws NoSuchElementException if the queue is empty
     */
    @Override
    public IEvaluatedPath<N, ?, V> element() {
        if (this.queue.isEmpty()) {
            throw new NoSuchElementException("Queue is empty");
        }
        return this.queue.get(0);
    }

    @Override
    public IEvaluatedPath<N, ?, V> peek() {
        if (this.queue.isEmpty()) {
            return null;
        }
        this.logger.trace("Peek from OPEN.");
        return element();
    }

    public IDyadRanker getDyadRanker() {
        return this.dyadRanker;
    }

    /**
     * Replaces the ranker used by subsequent {@link #add} calls. The current order is not
     * recomputed.
     */
    public void setDyadRanker(IDyadRanker iDyadRanker) {
        this.logger.trace("Update dyad ranker. Was {} now is {}", classOf(this.dyadRanker), classOf(iDyadRanker));
        this.dyadRanker = iDyadRanker;
    }

    public AbstractDyadScaler getScaler() {
        return this.scaler;
    }

    /**
     * Sets the scaler and re-derives the context characterization from the unscaled original.
     * Passing {@code null} disables scaling, matching the constructor's handling of a null scaler.
     *
     * <p>NOTE(review): dyads of nodes that are already queued keep the previous context vector and
     * their previously scaled characterizations; only nodes added afterwards use the new scaler.
     */
    public void setScaler(AbstractDyadScaler abstractDyadScaler) {
        this.contextCharacterization = this.originalContextCharacterization.addConstantToCopy(0.0d);
        if (abstractDyadScaler == null) {
            this.logger.trace("Disable scaler.");
            this.useScaler = false;
            this.scaler = null;
            return;
        }
        if (this.useScaler) {
            this.logger.trace("Update scaler. Was {} now is {}", classOf(this.scaler), abstractDyadScaler.getClass());
        } else {
            this.logger.trace("Now using scaler {}.", abstractDyadScaler.getClass());
            this.useScaler = true;
        }
        this.scaler = abstractDyadScaler;
        transformContextCharacterization();
    }

    /**
     * Applies the scaler's instance transformation to {@link #contextCharacterization}. The context is
     * wrapped in a dyad (as both instance and alternative) because the scaler operates on datasets.
     * NOTE(review): relies on {@code transformInstances} modifying the vector in place — confirm.
     */
    private void transformContextCharacterization() {
        this.logger.trace("Transform context characterization with scaler {}", this.scaler.getClass());
        this.scaler.transformInstances(singletonDataset(new Dyad(this.contextCharacterization, this.contextCharacterization)));
    }

    /** Null-safe {@code getClass()} for logging. */
    private static Class<?> classOf(Object object) {
        return object == null ? null : object.getClass();
    }
}
