package es.uam.eps.ir.ranksys.mf.plsa;

import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import cern.jet.math.Functions;
import es.uam.eps.ir.ranksys.fast.preference.FastPreferenceData;
import es.uam.eps.ir.ranksys.fast.preference.IdxPref;
import es.uam.eps.ir.ranksys.mf.Factorization;
import es.uam.eps.ir.ranksys.mf.Factorizer;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.ranksys.fast.preference.StreamsAbstractFastPreferenceData;

/* loaded from: input_file:es/uam/eps/ir/ranksys/mf/plsa/PLSAFactorizer.class */
public class PLSAFactorizer<U, I> extends Factorizer<U, I> {
    private static final Logger LOG = Logger.getLogger(PLSAFactorizer.class.getName());
    private final int numIter;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:es/uam/eps/ir/ranksys/mf/plsa/PLSAFactorizer$PLSAPreferenceData.class */
    public static class PLSAPreferenceData<U, I> extends StreamsAbstractFastPreferenceData<U, I> {
        private final FastPreferenceData<U, I> data;
        private final Long2ObjectOpenHashMap<double[]> qz;

        /* loaded from: input_file:es/uam/eps/ir/ranksys/mf/plsa/PLSAFactorizer$PLSAPreferenceData$PLSAIdxPref.class */
        public class PLSAIdxPref extends IdxPref {
            public double[] qz;

            public PLSAIdxPref(int i, double d, double[] dArr) {
                super(i, d);
                this.qz = dArr;
            }
        }

        public PLSAPreferenceData(FastPreferenceData<U, I> fastPreferenceData, int i) {
            super(fastPreferenceData, fastPreferenceData);
            this.data = fastPreferenceData;
            this.qz = new Long2ObjectOpenHashMap<>();
            fastPreferenceData.getUidxWithPreferences().forEach(i2 -> {
                fastPreferenceData.getUidxPreferences(i2).forEach(idxPref -> {
                    putQz(i2, idxPref.v1, new double[i]);
                });
            });
        }

        private double[] getQz(int i, int i2) {
            return (double[]) this.qz.get((i * this.data.numItems()) + i2);
        }

        private double[] putQz(int i, int i2, double[] dArr) {
            return (double[]) this.qz.put((i * this.data.numItems()) + i2, dArr);
        }

        public int numUsers(int i) {
            return this.data.numUsers(i);
        }

        public int numItems(int i) {
            return this.data.numItems(i);
        }

        public IntStream getUidxWithPreferences() {
            return this.data.getUidxWithPreferences();
        }

        public IntStream getIidxWithPreferences() {
            return this.data.getIidxWithPreferences();
        }

        public Stream<IdxPref> getUidxPreferences(int i) {
            return this.data.getUidxPreferences(i).map(idxPref -> {
                return new PLSAIdxPref(idxPref.v1, idxPref.v2, getQz(i, idxPref.v1));
            });
        }

        public Stream<IdxPref> getIidxPreferences(int i) {
            return this.data.getIidxPreferences(i).map(idxPref -> {
                return new PLSAIdxPref(idxPref.v1, idxPref.v2, getQz(idxPref.v1, i));
            });
        }

        public int numPreferences() {
            return this.data.numPreferences();
        }
    }

    public PLSAFactorizer(int i) {
        this.numIter = i;
    }

    @Override // es.uam.eps.ir.ranksys.mf.Factorizer
    public double error(Factorization<U, I> factorization, FastPreferenceData<U, I> fastPreferenceData) {
        DenseDoubleMatrix2D userMatrix = factorization.getUserMatrix();
        DenseDoubleMatrix2D itemMatrix = factorization.getItemMatrix();
        return fastPreferenceData.getUidxWithPreferences().parallel().mapToDouble(i -> {
            DoubleMatrix1D zMult = itemMatrix.zMult(userMatrix.viewRow(i), (DoubleMatrix1D) null);
            return fastPreferenceData.getUidxPreferences(i).mapToDouble(idxPref -> {
                return (-idxPref.v2) * zMult.getQuick(idxPref.v1);
            }).sum();
        }).sum();
    }

    @Override // es.uam.eps.ir.ranksys.mf.Factorizer
    public Factorization<U, I> factorize(int i, FastPreferenceData<U, I> fastPreferenceData) {
        Factorization<U, I> factorization = new Factorization<>(fastPreferenceData, fastPreferenceData, i, d -> {
            return Math.sqrt(1.0d / i) * Math.random();
        });
        factorize(factorization, fastPreferenceData);
        return factorization;
    }

    @Override // es.uam.eps.ir.ranksys.mf.Factorizer
    public void factorize(Factorization<U, I> factorization, FastPreferenceData<U, I> fastPreferenceData) {
        DenseDoubleMatrix2D userMatrix = factorization.getUserMatrix();
        DenseDoubleMatrix2D itemMatrix = factorization.getItemMatrix();
        IntOpenHashSet intOpenHashSet = new IntOpenHashSet(fastPreferenceData.getUidxWithPreferences().toArray());
        IntStream.range(0, userMatrix.rows()).filter(i -> {
            return !intOpenHashSet.contains(i);
        }).forEach(i2 -> {
            userMatrix.viewRow(i2).assign(0.0d);
        });
        IntOpenHashSet intOpenHashSet2 = new IntOpenHashSet(fastPreferenceData.getIidxWithPreferences().toArray());
        IntStream.range(0, itemMatrix.rows()).filter(i3 -> {
            return !intOpenHashSet2.contains(i3);
        }).forEach(i4 -> {
            itemMatrix.viewRow(i4).assign(0.0d);
        });
        PLSAPreferenceData<U, I> pLSAPreferenceData = new PLSAPreferenceData<>(fastPreferenceData, userMatrix.columns());
        for (int i5 = 0; i5 < userMatrix.columns(); i5++) {
            DoubleMatrix1D viewColumn = userMatrix.viewColumn(i5);
            viewColumn.assign(Functions.mult(1.0d / viewColumn.aggregate(Functions.plus, Functions.identity)));
        }
        itemMatrix.assign(Functions.mult(1.0d / itemMatrix.aggregate(Functions.plus, Functions.identity)));
        for (int i6 = 1; i6 <= this.numIter; i6++) {
            long nanoTime = System.nanoTime();
            expectation(userMatrix, itemMatrix, pLSAPreferenceData);
            maximization(userMatrix, itemMatrix, pLSAPreferenceData);
            int i7 = i6;
            LOG.log(Level.INFO, String.format("iteration n = %3d t = %.2fs", Integer.valueOf(i7), Double.valueOf((System.nanoTime() - nanoTime) / 1.0E9d)));
            LOG.log(Level.FINE, () -> {
                return String.format("iteration n = %3d e = %.6f", Integer.valueOf(i7), Double.valueOf(error(factorization, fastPreferenceData)));
            });
        }
    }

    private void expectation(DenseDoubleMatrix2D denseDoubleMatrix2D, DenseDoubleMatrix2D denseDoubleMatrix2D2, PLSAPreferenceData<U, I> pLSAPreferenceData) {
        pLSAPreferenceData.getUidxWithPreferences().parallel().forEach(i -> {
            pLSAPreferenceData.getUidxPreferences(i).forEach(idxPref -> {
                int i = idxPref.v1;
                double[] dArr = ((PLSAPreferenceData.PLSAIdxPref) idxPref).qz;
                for (int i2 = 0; i2 < dArr.length; i2++) {
                    dArr[i2] = denseDoubleMatrix2D2.getQuick(i, i2) * denseDoubleMatrix2D.getQuick(i, i2);
                }
                normalizeQz(dArr);
            });
        });
    }

    private void maximization(DenseDoubleMatrix2D denseDoubleMatrix2D, DenseDoubleMatrix2D denseDoubleMatrix2D2, PLSAPreferenceData<U, I> pLSAPreferenceData) {
        Int2ObjectOpenHashMap int2ObjectOpenHashMap = new Int2ObjectOpenHashMap();
        pLSAPreferenceData.getIidxWithPreferences().forEach(i -> {
        });
        denseDoubleMatrix2D.assign(0.0d);
        denseDoubleMatrix2D2.assign(0.0d);
        pLSAPreferenceData.getUidxWithPreferences().parallel().forEach(i2 -> {
            DoubleMatrix1D viewRow = denseDoubleMatrix2D.viewRow(i2);
            pLSAPreferenceData.getUidxPreferences(i2).forEach(idxPref -> {
                int i2 = idxPref.v1;
                double d = idxPref.v2;
                double[] dArr = ((PLSAPreferenceData.PLSAIdxPref) idxPref).qz;
                Lock lock = (Lock) int2ObjectOpenHashMap.get(i2);
                for (int i3 = 0; i3 < dArr.length; i3++) {
                    viewRow.setQuick(i3, viewRow.getQuick(i3) + (dArr[i3] * d));
                }
                lock.lock();
                for (int i4 = 0; i4 < dArr.length; i4++) {
                    try {
                        denseDoubleMatrix2D2.setQuick(i2, i4, denseDoubleMatrix2D2.getQuick(i2, i4) + (dArr[i4] * d));
                    } finally {
                        lock.unlock();
                    }
                }
            });
        });
        for (int i3 = 0; i3 < denseDoubleMatrix2D.columns(); i3++) {
            DoubleMatrix1D viewColumn = denseDoubleMatrix2D.viewColumn(i3);
            viewColumn.assign(Functions.mult(1.0d / viewColumn.aggregate(Functions.plus, Functions.identity)));
        }
        denseDoubleMatrix2D2.assign(Functions.mult(1.0d / denseDoubleMatrix2D2.aggregate(Functions.plus, Functions.identity)));
    }

    private static void normalizeQz(double[] dArr) {
        double d = 0.0d;
        for (double d2 : dArr) {
            d += d2;
        }
        for (int i = 0; i < dArr.length; i++) {
            int i2 = i;
            dArr[i2] = dArr[i2] / d;
        }
    }
}
