package ai.libs.jaicore.math.probability.pl;

import ai.libs.jaicore.basic.IOwnerBasedAlgorithmConfig;
import ai.libs.jaicore.basic.algorithm.AAlgorithm;
import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import it.unimi.dsi.fastutil.doubles.DoubleList;
import it.unimi.dsi.fastutil.doubles.DoubleListIterator;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import it.unimi.dsi.fastutil.ints.IntList;
import it.unimi.dsi.fastutil.shorts.ShortList;
import java.util.Iterator;
import java.util.List;
import org.api4.java.algorithm.events.IAlgorithmEvent;
import org.api4.java.algorithm.exceptions.AlgorithmException;
import org.api4.java.algorithm.exceptions.AlgorithmExecutionCanceledException;
import org.api4.java.algorithm.exceptions.AlgorithmTimeoutedException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/libs/jaicore/math/probability/pl/PLMMAlgorithm.class */
/**
 * Iterative minorization-maximization (MM) estimator for the skill vector of a Plackett-Luce model,
 * fitted to the rankings of a {@link PLInferenceProblem}.
 *
 * <p>Each iteration replaces every object's skill by (number of rankings in which the object is not
 * ranked last) divided by a denominator accumulated from the inverse "tail sums" of skills in each
 * ranking. The result is then normalized by the absolute value of the entry sum. Iteration stops
 * once the L1 distance between two consecutive skill vectors is at most
 * {@link #CONVERGENCE_THRESHOLD}.
 *
 * <p>Only full rankings are supported, i.e. every ranking must contain exactly
 * {@code getNumObjects()} entries.
 */
public class PLMMAlgorithm extends AAlgorithm<PLInferenceProblem, DoubleList> {

    /** Iteration stops once the L1 distance between consecutive skill vectors is at most this value. */
    private static final double CONVERGENCE_THRESHOLD = 1.0E-5d;

    private final List<ShortList> rankings;
    private final int numRankings;
    private final int numObjects;
    /** Entry {@code i} counts the rankings in which object {@code i} is not placed last. */
    private final IntList winVector;
    private DoubleList skillVector;
    private Logger logger = LoggerFactory.getLogger(PLMMAlgorithm.class);

    /**
     * Creates the algorithm with the default (uniform) initial skill vector and no explicit config.
     *
     * @param input the inference problem holding the rankings
     */
    public PLMMAlgorithm(final PLInferenceProblem input) {
        this(input, null, null);
    }

    /**
     * Creates the algorithm with the default (uniform) initial skill vector.
     *
     * @param input  the inference problem holding the rankings
     * @param config the algorithm configuration passed on to {@link AAlgorithm}
     */
    public PLMMAlgorithm(final PLInferenceProblem input, final IOwnerBasedAlgorithmConfig config) {
        this(input, null, config);
    }

    /**
     * Creates the algorithm.
     *
     * @param input              the inference problem holding the rankings
     * @param initialSkillVector start vector for the iteration; if {@code null}, the uniform vector of
     *                           {@link #getDefaultSkillVector(int)} is used
     * @param config             the algorithm configuration passed on to {@link AAlgorithm}
     * @throws IllegalArgumentException      if the problem has fewer than two objects
     * @throws UnsupportedOperationException if any ranking is not a full ranking
     */
    public PLMMAlgorithm(final PLInferenceProblem input, final DoubleList initialSkillVector, final IOwnerBasedAlgorithmConfig config) {
        super(config, input);
        this.numRankings = getInput().getRankings().size();
        this.numObjects = getInput().getNumObjects();
        if (this.numObjects < 2) {
            throw new IllegalArgumentException("Cannot create PL-Algorithm for choice problems with only one option.");
        }
        this.rankings = input.getRankings();
        for (final ShortList ranking : this.rankings) {
            if (ranking.size() != this.numObjects) {
                throw new UnsupportedOperationException("This MM implementation only supports full rankings!");
            }
        }
        this.skillVector = initialSkillVector != null ? initialSkillVector : getDefaultSkillVector(this.numObjects);
        this.winVector = getWinVector();
    }

    /**
     * Builds a uniform skill vector.
     *
     * @param numObjects number of entries
     * @return a new list with {@code numObjects} entries, each equal to {@code 1.0 / numObjects}
     *         (empty for non-positive {@code numObjects})
     */
    public static DoubleList getDefaultSkillVector(final int numObjects) {
        final DoubleArrayList uniform = new DoubleArrayList();
        final double share = 1.0d / numObjects;
        for (int i = 0; i < numObjects; i++) {
            uniform.add(share);
        }
        return uniform;
    }

    /**
     * Runs MM updates until convergence.
     *
     * <p>The first update is never compared against the initial vector, so at least two updates are
     * always performed.
     *
     * @return always {@code null}; the result is available through {@link #getSkillVector()}
     */
    @Override
    public IAlgorithmEvent nextWithException() throws InterruptedException, AlgorithmExecutionCanceledException, AlgorithmTimeoutedException, AlgorithmException {
        DoubleList previous = null;
        double change;
        do {
            this.skillVector = normalizeSkillVector(getUpdatedSkillVectorImproved(this.skillVector));
            change = previous == null ? Double.MAX_VALUE : l1Distance(this.skillVector, previous);
            previous = this.skillVector;
        } while (change > CONVERGENCE_THRESHOLD);
        return null;
    }

    /** Sum of absolute entry-wise differences over the first {@code numObjects} entries. */
    private double l1Distance(final DoubleList a, final DoubleList b) {
        double distance = 0.0d;
        for (int i = 0; i < this.numObjects; i++) {
            distance += Math.abs(a.getDouble(i) - b.getDouble(i));
        }
        return distance;
    }

    /**
     * Divides every entry by the absolute value of the entry sum.
     *
     * @throws IllegalArgumentException if an entry is NaN or the entries sum to zero
     */
    private DoubleList normalizeSkillVector(final DoubleList skills) {
        double sum = 0.0d;
        for (final DoubleListIterator it = skills.iterator(); it.hasNext();) {
            final double skill = it.nextDouble();
            if (Double.isNaN(skill)) {
                throw new IllegalArgumentException("Skill vector has NaN entry: " + skills);
            }
            sum += skill;
        }
        final double norm = Math.abs(sum);
        if (norm == 0.0d) {
            throw new IllegalArgumentException("Cannot normalize null skill vector: " + skills);
        }
        final DoubleArrayList normalized = new DoubleArrayList(skills.size());
        for (final DoubleListIterator it = skills.iterator(); it.hasNext();) {
            normalized.add(it.nextDouble() / norm);
        }
        return normalized;
    }

    /** Returns the skill of the object placed at {@code position} in {@code ranking}. */
    private double getSkillOfRankedObject(final ShortList ranking, final int position, final DoubleList skills) {
        return skills.getDouble(ranking.getShort(position));
    }

    /**
     * Computes one (unnormalized) MM update of the skill vector.
     *
     * @param skills current skill vector, indexed by object id
     * @return new vector whose entry {@code i} is {@code winVector[i]} divided by object i's denominator
     * @throws IllegalStateException if a denominator is zero
     */
    private DoubleList getUpdatedSkillVectorImproved(final DoubleList skills) {
        // inverseTailSums[r][t] = 1 / (sum of skills of the objects at positions t..n-1 of ranking r),
        // for t = 0..n-2 (the last position alone never contributes).
        final double[][] inverseTailSums = new double[this.numRankings][];
        for (int r = 0; r < this.numRankings; r++) {
            final ShortList ranking = this.rankings.get(r);
            final double[] tail = new double[ranking.size() - 1];
            tail[tail.length - 1] = getSkillOfRankedObject(ranking, this.numObjects - 1, skills) + getSkillOfRankedObject(ranking, this.numObjects - 2, skills);
            for (int pos = this.numObjects - 3; pos >= 0; pos--) {
                tail[pos] = tail[pos + 1] + getSkillOfRankedObject(ranking, pos, skills);
            }
            for (int pos = 0; pos < tail.length; pos++) {
                tail[pos] = 1.0d / tail[pos];
            }
            inverseTailSums[r] = tail;
        }

        // int (not short) counter so that numObjects == 32768 does not overflow the loop variable.
        final DoubleArrayList updated = new DoubleArrayList(this.numObjects);
        for (int object = 0; object < this.numObjects; object++) {
            double denominator = 0.0d;
            for (int r = 0; r < this.numRankings; r++) {
                final ShortList ranking = this.rankings.get(r);
                final double[] tail = inverseTailSums[r];
                // accumulate inverse tail sums up to and including the object's own position
                for (int pos = 0; pos < tail.length; pos++) {
                    denominator += tail[pos];
                    if (ranking.getShort(pos) == object) {
                        break;
                    }
                }
            }
            if (denominator == 0.0d) {
                throw new IllegalStateException("Denominator in PL-model must not be null.");
            }
            updated.add(this.winVector.getInt(object) / denominator);
        }
        this.logger.debug("Updated vector: {}", updated);
        return updated;
    }

    /** Counts, for every object, the rankings in which its first occurrence is not the last position. */
    private IntList getWinVector() {
        final IntArrayList wins = new IntArrayList(this.numObjects);
        for (int object = 0; object < this.numObjects; object++) {
            int count = 0;
            for (final ShortList ranking : this.rankings) {
                if (ranking.indexOf((short) object) < ranking.size() - 1) {
                    count++;
                }
            }
            wins.add(count);
        }
        return wins;
    }

    /**
     * Runs the algorithm to convergence and returns the fitted skill vector.
     *
     * @return the normalized skill vector, one entry per object
     * @throws IllegalStateException if the result has the wrong size or contains NaN entries
     */
    @Override
    public DoubleList call() throws InterruptedException, AlgorithmExecutionCanceledException, AlgorithmTimeoutedException, AlgorithmException {
        next();
        if (this.skillVector.size() != this.numObjects) {
            throw new IllegalStateException("Have " + this.skillVector.size() + " skills (" + this.skillVector + ") for " + this.numObjects + " objects.");
        }
        for (final DoubleListIterator it = this.skillVector.iterator(); it.hasNext();) {
            if (Double.isNaN(it.nextDouble())) {
                throw new IllegalStateException("Illegal skill return value: " + this.skillVector);
            }
        }
        return this.skillVector;
    }

    /**
     * Alias of {@link #call()} kept for callers of the decompiler-renamed method.
     *
     * @deprecated use {@link #call()} instead
     */
    @Deprecated
    public DoubleList m2call() throws InterruptedException, AlgorithmExecutionCanceledException, AlgorithmTimeoutedException, AlgorithmException {
        return call();
    }

    /** @return the current skill vector (the fitted one after {@link #call()}) */
    public DoubleList getSkillVector() {
        return this.skillVector;
    }

    /** Replaces the logger by one with the given name. */
    public void setLoggerName(final String name) {
        this.logger = LoggerFactory.getLogger(name);
    }

    /** @return the name of the logger currently in use */
    public String getLoggerName() {
        return this.logger.getName();
    }
}
