package org.grouplens.lenskit.predict.ordrec;

import it.unimi.dsi.fastutil.longs.Long2ObjectMap;
import java.util.Iterator;
import javax.annotation.Nonnull;
import javax.inject.Inject;
import mikera.vectorz.AVector;
import mikera.vectorz.IVector;
import mikera.vectorz.Vector;
import mikera.vectorz.impl.ImmutableVector;
import org.grouplens.lenskit.ItemScorer;
import org.grouplens.lenskit.basic.AbstractRatingPredictor;
import org.grouplens.lenskit.collections.LongUtils;
import org.grouplens.lenskit.data.dao.UserEventDAO;
import org.grouplens.lenskit.data.event.Rating;
import org.grouplens.lenskit.data.history.RatingVectorUserHistorySummarizer;
import org.grouplens.lenskit.data.history.UserHistory;
import org.grouplens.lenskit.iterative.IterationCount;
import org.grouplens.lenskit.iterative.LearningRate;
import org.grouplens.lenskit.iterative.RegularizationTerm;
import org.grouplens.lenskit.symbols.TypedSymbol;
import org.grouplens.lenskit.transform.quantize.Quantizer;
import org.grouplens.lenskit.vectors.MutableSparseVector;
import org.grouplens.lenskit.vectors.SparseVector;
import org.grouplens.lenskit.vectors.VectorEntry;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* JADX WARN: Classes with same name are omitted:
  
 */
/* loaded from: input_file:org/grouplens/lenskit/predict/ordrec/OrdRecRatingPredictor.class */
/**
 * Rating predictor implementing the OrdRec ordinal model on top of a base {@link ItemScorer}.
 *
 * <p>The base scorer's output for an item is used as the location of a logistic distribution.
 * An ordered set of per-user thresholds splits that distribution into the discrete rating levels
 * supplied by a {@link Quantizer}. On every {@code predict} call:
 * <ol>
 * <li>A fresh {@link OrdRecModel} is initialised from the quantizer's values.</li>
 * <li>The model is fitted to the user's existing ratings with per-rating gradient updates.</li>
 * <li>Each requested item is assigned the most probable rating level.</li>
 * </ol>
 */
public class OrdRecRatingPredictor extends AbstractRatingPredictor {
    private static final Logger logger = LoggerFactory.getLogger(OrdRecRatingPredictor.class);

    /**
     * Channel that receives, for each predicted item, the vector of per-level rating
     * probabilities (entry {@code i} corresponds to quantizer level {@code i}). It is populated
     * only when distribution reporting is enabled.
     */
    public static final TypedSymbol<IVector> RATING_PROBABILITY_CHANNEL =
            TypedSymbol.of(IVector.class, "org.grouplens.lenskit.predict.ordrec.RatingProbability");

    private final ItemScorer itemScorer;
    private final UserEventDAO userEventDao;
    private final Quantizer quantizer;
    private final double learningRate;
    private final double regTerm;
    private final int iterationCount;
    private final boolean reportDistribution;

    /**
     * Per-user OrdRec parameters: the first threshold {@code t1} and the log-gaps {@code beta}
     * between successive thresholds.
     *
     * <p>The thresholds are defined as follows:
     * <ul>
     * <li>For {@code 0 <= i <= beta.length()}: {@code t_i = t1 + sum_{k < i} exp(beta[k])}.
     * Because the gaps are exponentials, the thresholds are always increasing.</li>
     * <li>Thresholds below index 0 are -infinity.</li>
     * <li>Thresholds above {@code beta.length()} are +infinity.</li>
     * </ul>
     * The model then gives {@code P(level <= i | score) = 1 / (1 + exp(score - t_i))}.
     */
    private class OrdRecModel {
        /** Number of rating levels, i.e. the number of quantizer values. */
        private final int levelCount;
        /** Threshold separating level 0 from level 1. */
        private double t1;
        /** Logarithms of the gaps between consecutive thresholds; length {@code levelCount - 2}. */
        private final AVector beta;

        /**
         * Initialise the thresholds at the midpoints between consecutive quantizer values.
         *
         * @param quantizer supplies the ordered rating levels.
         * @throws IllegalArgumentException if the quantizer has fewer than two levels.
         */
        private OrdRecModel(Quantizer quantizer) {
            ImmutableVector values = quantizer.getValues();
            levelCount = values.length();
            if (levelCount < 2) {
                throw new IllegalArgumentException(
                        "OrdRec needs at least 2 rating levels, but the quantizer has " + levelCount);
            }
            t1 = (values.get(0) + values.get(1)) / 2.0d;
            beta = Vector.createLength(levelCount - 2);
            double previous = t1;
            for (int k = 1; k <= beta.length(); k++) {
                double next = (values.get(k) + values.get(k + 1)) * 0.5d;
                // NOTE(review): assumes quantizer values are strictly increasing;
                // otherwise this log gap is NaN or -Infinity.
                beta.set(k - 1, Math.log(next - previous));
                previous = next;
            }
        }

        public double getT1() {
            return t1;
        }

        public AVector getBeta() {
            return beta;
        }

        public int getLevelCount() {
            return levelCount;
        }

        /**
         * Threshold {@code t_i}.
         *
         * @param i threshold index.
         * @return -infinity for {@code i < 0}, {@code t1} for {@code i == 0}, +infinity for
         *         {@code i > beta.length()}, and otherwise {@code t1} plus the exponentiated gaps
         *         {@code beta[0..i-1]}.
         */
        public double getThreshold(int i) {
            if (i < 0) {
                return Double.NEGATIVE_INFINITY;
            }
            if (i == 0) {
                return t1;
            }
            if (i > beta.length()) {
                return Double.POSITIVE_INFINITY;
            }
            double threshold = t1;
            for (int k = 0; k < i; k++) {
                threshold += Math.exp(beta.get(k));
            }
            return threshold;
        }

        /** Probability that the rating level is at most {@code i}, given base score {@code score}. */
        public double getProbLE(double score, int i) {
            return 1.0d / (1.0d + Math.exp(score - getThreshold(i)));
        }

        /** Probability that the rating level is exactly {@code i}, given base score {@code score}. */
        public double getProbEQ(double score, int i) {
            return getProbLE(score, i) - getProbLE(score, i - 1);
        }

        /**
         * Partial derivative of threshold {@code t_level} with respect to one model parameter.
         *
         * @param level threshold index.
         * @param param 0 selects {@code t1}; {@code k + 1} selects {@code beta[k]}.
         * @param value current value of the selected parameter.
         * @return 1 for {@code t1} when {@code level >= 0}; {@code exp(value)} for {@code beta[k]}
         *         when {@code level >= k + 1}; 0 otherwise. For infinite thresholds
         *         ({@code level > beta.length()}) the result may be non-zero, but callers multiply
         *         it by {@code p * (1 - p)} with {@code p == 1}, so it contributes nothing.
         */
        private double derivateOfBeta(int level, int param, double value) {
            if (level >= 0 && param == 0) {
                return 1.0d;
            }
            if (param <= 0 || level < param) {
                return 0.0d;
            }
            return Math.exp(value);
        }

        /**
         * Fit {@code t1} and {@code beta} to a user's ratings.
         *
         * <p>Runs {@code iterationCount} passes over the ratings, updating the parameters after
         * each rating. Ratings whose item has no base score are skipped.
         *
         * @param ratings the user's ratings, keyed by item.
         * @param scores  base scores from the item scorer, keyed by item.
         */
        private void train(SparseVector ratings, SparseVector scores) {
            Vector betaStep = Vector.createLength(beta.length());
            for (int iter = 0; iter < iterationCount; iter++) {
                for (VectorEntry rating : ratings) {
                    long item = rating.getKey();
                    if (!scores.containsKey(item)) {
                        // The base scorer produced no score for this item, so there is nothing to fit.
                        continue;
                    }
                    double score = scores.get(item);
                    int level = quantizer.index(rating.getValue());
                    double probEq = getProbEQ(score, level);
                    if (!(probEq > 0.0d)) {
                        // P(level) underflowed to 0 (or is NaN). Dividing by it would push
                        // t1/beta to +-Infinity/NaN and corrupt every later prediction.
                        continue;
                    }
                    double probLe = getProbLE(score, level);
                    double probLePrev = getProbLE(score, level - 1);
                    // d/dt of the logistic CDF at thresholds t_level and t_(level-1).
                    double slope = probLe * (1.0d - probLe);
                    double slopePrev = probLePrev * (1.0d - probLePrev);
                    double scale = learningRate / probEq;

                    double t1Step = scale * (slope * derivateOfBeta(level, 0, t1)
                            - slopePrev * derivateOfBeta(level - 1, 0, t1)
                            - regTerm * t1);
                    for (int k = 0; k < beta.length(); k++) {
                        double bk = beta.get(k);
                        betaStep.set(k, scale * (slope * derivateOfBeta(level, k + 1, bk)
                                - slopePrev * derivateOfBeta(level - 1, k + 1, bk)
                                - regTerm * bk));
                    }
                    // All steps are computed from the pre-update parameters, then applied together.
                    t1 += t1Step;
                    beta.add(betaStep);
                }
            }
        }

        /**
         * Fill {@code out} with the probability of each rating level for base score {@code score}.
         *
         * @param score base score of the item.
         * @param out   destination vector; must have at least {@link #getLevelCount()} entries.
         */
        public void getProbDistribution(double score, Vector out) {
            double below = getProbLE(score, 0);
            out.set(0, below);
            for (int i = 1; i < getLevelCount(); i++) {
                double upTo = getProbLE(score, i);
                out.set(i, upTo - below);
                below = upTo;
            }
        }

        @Override
        public String toString() {
            return "OrdRecParams(t1=" + t1 + ", beta=" + beta + ")";
        }
    }

    /**
     * Create an OrdRec predictor.
     *
     * @param itemScorer         base scorer whose scores are mapped onto rating levels.
     * @param userEventDAO       source of each user's rating history.
     * @param quantizer          defines the discrete rating levels.
     * @param learningRate       step size for the parameter updates.
     * @param regTerm            regularization weight applied to {@code t1} and {@code beta}.
     * @param iterationCount     number of training passes over a user's ratings.
     * @param reportDistribution whether to publish per-item level probabilities in
     *                           {@link #RATING_PROBABILITY_CHANNEL}.
     */
    @Inject
    public OrdRecRatingPredictor(ItemScorer itemScorer, UserEventDAO userEventDAO, Quantizer quantizer,
                                 @LearningRate double learningRate,
                                 @RegularizationTerm double regTerm,
                                 @IterationCount int iterationCount,
                                 @ReportRatingDistribution boolean reportDistribution) {
        this.userEventDao = userEventDAO;
        this.itemScorer = itemScorer;
        this.quantizer = quantizer;
        this.learningRate = learningRate;
        this.regTerm = regTerm;
        this.iterationCount = iterationCount;
        this.reportDistribution = reportDistribution;
    }

    /**
     * Create an OrdRec predictor with fixed hyperparameters and distribution reporting disabled.
     *
     * <p>The hyperparameters are: learning rate 0.001, regularization 0.015, and 1000 iterations.
     */
    public OrdRecRatingPredictor(ItemScorer itemScorer, UserEventDAO userEventDAO, Quantizer quantizer) {
        this(itemScorer, userEventDAO, quantizer, 0.001d, 0.015d, 1000, false);
    }

    /**
     * Build a rating vector from the user's rating history.
     *
     * @return the user's ratings keyed by item, or {@code null} if the DAO returned no history.
     */
    private SparseVector makeUserVector(long user, UserEventDAO dao) {
        UserHistory history = dao.getEventsForUser(user, Rating.class);
        if (history == null) {
            return null;
        }
        return RatingVectorUserHistorySummarizer.makeRatingVector(history);
    }

    /**
     * Predict ratings for every key in {@code predictions}' key domain.
     *
     * <p>Each prediction is the quantizer value of the most probable level. Items the base scorer
     * cannot score are left untouched. A user with no rating history is predicted with the
     * untrained (midpoint) thresholds.
     *
     * @param user        the user to predict for.
     * @param predictions receives the predictions; if reporting is enabled, it also gets the
     *                    {@link #RATING_PROBABILITY_CHANNEL} channel.
     */
    public void predict(long user, @Nonnull MutableSparseVector predictions) {
        logger.debug("predicting {} items for {}", predictions.keyDomain().size(), user);
        OrdRecModel model = new OrdRecModel(quantizer);
        SparseVector ratings = makeUserVector(user, userEventDao);

        // Score the rated items too: they are the training data for the thresholds.
        MutableSparseVector scores = ratings == null
                ? MutableSparseVector.create(predictions.keyDomain())
                : MutableSparseVector.create(LongUtils.setUnion(ratings.keySet(), predictions.keyDomain()));
        itemScorer.score(user, scores);

        if (ratings == null) {
            logger.debug("no rating history for {}; using untrained thresholds", user);
        } else {
            model.train(ratings, scores);
        }
        logger.debug("trained parameters for {}: {}", user, model);

        Vector distribution = Vector.createLength(model.getLevelCount());
        Long2ObjectMap<IVector> probabilityChannel =
                reportDistribution ? predictions.addChannel(RATING_PROBABILITY_CHANNEL) : null;
        for (VectorEntry entry : predictions.view(VectorEntry.State.EITHER)) {
            long item = entry.getKey();
            if (!scores.containsKey(item)) {
                // No base score means no basis for a prediction; leave this entry as it was.
                continue;
            }
            model.getProbDistribution(scores.get(item), distribution);
            predictions.set(entry, quantizer.getIndexValue(distribution.maxElementIndex()));
            if (probabilityChannel != null) {
                // Store a snapshot, because 'distribution' is reused for the next item.
                probabilityChannel.put(item, distribution.immutable());
            }
        }
    }
}
