package gov.sandia.cognition.learning.algorithm.factor.machine;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationReferences;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.SparseVectorFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorEntry;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.util.ArgumentChecker;
import gov.sandia.cognition.util.DefaultNamedValue;
import gov.sandia.cognition.util.NamedValue;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Random;

/**
 * Trains a factorization machine by alternating least squares (ALS).
 * <p>
 * Each {@link #step()} sweeps once over the bias, the linear weights and every
 * factor entry. It replaces each parameter in turn with the closed-form
 * minimizer of the L2-regularized squared error while all other parameters are
 * held fixed.
 * <p>
 * A residual vector {@code e[i] = y_i - f(x_i)} is updated after every
 * single-parameter change, so each update already sees the effect of the
 * previous ones. A step reports that learning should continue while the sum of
 * absolute parameter changes it made is at least {@link #getMinChange()}.
 */
@PublicationReferences(references = {
    @PublicationReference(
        title = "Fast Context-aware Recommendations with Factorization Machines",
        author = {"Steffen Rendle", "Zeno Gantner", "Christoph Freudenthaler", "Lars Schmidt-Thieme"},
        year = 2011,
        type = PublicationType.Conference,
        publication = "Proceeding of the 34th international ACM SIGIR conference on Research and development in Information Retrieval (SIGIR)",
        url = "http://www.inf.uni-konstanz.de/~rendle/pdf/Rendle2011-CARS.pdf"),
    @PublicationReference(
        title = "Factorization Machines with libFM",
        author = {"Steffen Rendle"},
        year = 2012,
        type = PublicationType.Journal,
        publication = "ACM Transactions on Intelligent Systems Technology",
        url = "http://www.csie.ntu.edu.tw/~b97053/paper/Factorization%20Machines%20with%20libFM.pdf",
        notes = {"Algorithm 2: Alternating Least Squares (ALS)"})
})
public class FactorizationMachineAlternatingLeastSquares extends AbstractFactorizationMachineLearner {

    /** Default value for {@link #getMinChange()}. */
    public static final double DEFAULT_MIN_CHANGE = 1.0E-5d;

    /** Stopping threshold on the total absolute parameter change of one step. */
    protected double minChange;

    /** Number of training examples; set by {@link #initializeAlgorithm()}. */
    protected transient int dataSize;

    /** Random-access copy of the training data; released by {@link #cleanupAlgorithm()}. */
    protected transient ArrayList<? extends InputOutputPair<? extends Vector, Double>> dataList;

    /**
     * Column-wise (transposed) view of the training inputs.
     * <p>
     * Element {@code j} is a sparse vector of length {@link #dataSize} holding
     * feature {@code j}'s value in every example. Released by
     * {@link #cleanupAlgorithm()}.
     */
    protected transient ArrayList<Vector> inputsTransposed;

    /** Sum of the absolute parameter changes made during the most recent step. */
    protected double totalChange;

    /** Sum of squared residuals after the most recent step. */
    protected double totalError;

    /**
     * Creates a learner with the default parameters and a fresh {@link Random}.
     */
    public FactorizationMachineAlternatingLeastSquares() {
        this(10, 0.0d, 0.001d, 0.01d, 0.01d, 100, DEFAULT_MIN_CHANGE, new Random());
    }

    /**
     * Creates a learner.
     * <p>
     * NOTE(review): the first six arguments and {@code random} are forwarded
     * unchanged, in order, to the {@code AbstractFactorizationMachineLearner}
     * constructor. The names used here follow the presumed order of that
     * constructor (factor count, bias/weight/factor regularization, seed scale,
     * maximum iterations). Confirm against the superclass.
     *
     * @param factorCount presumably the number of latent factors
     * @param biasRegularization presumably the L2 regularization on the bias
     * @param weightRegularization presumably the L2 regularization on the linear weights
     * @param factorRegularization presumably the L2 regularization on the factors
     * @param seedScale presumably the scale of the random factor initialization
     * @param maxIterations presumably the maximum number of learning iterations
     * @param minChange stopping threshold on total parameter change per step;
     *        rejected by {@code ArgumentChecker.assertIsNonNegative} when negative
     * @param random random number generator passed to the superclass
     */
    public FactorizationMachineAlternatingLeastSquares(
        final int factorCount,
        final double biasRegularization,
        final double weightRegularization,
        final double factorRegularization,
        final double seedScale,
        final int maxIterations,
        final double minChange,
        final Random random) {
        super(factorCount, biasRegularization, weightRegularization,
            factorRegularization, seedScale, maxIterations, random);
        setMinChange(minChange);
    }

    /**
     * Prepares the training data for ALS.
     * <p>
     * This caches the example count, copies the data into a random-access list,
     * and builds a sparse transposed view of the inputs so each feature's values
     * across all examples can be read as a single vector.
     *
     * @return {@code false} if the superclass initialization fails or there is
     *         no training data; {@code true} otherwise
     */
    @Override
    public boolean initializeAlgorithm() {
        if (!super.initializeAlgorithm()) {
            return false;
        }
        this.dataSize = ((Collection) this.data).size();
        if (this.dataSize <= 0) {
            return false;
        }
        this.dataList = CollectionUtil.asArrayList((Iterable) this.data);

        final SparseVectorFactory sparseFactory = VectorFactory.getSparseDefault();
        this.inputsTransposed = new ArrayList<>(this.dimensionality);
        for (int j = 0; j < this.dimensionality; j++) {
            this.inputsTransposed.add(sparseFactory.createVector(this.dataSize));
        }
        for (int i = 0; i < this.dataSize; i++) {
            for (final VectorEntry entry : this.dataList.get(i).getInput()) {
                // Store only non-zeros so the transposed columns stay sparse.
                if (entry.getValue() != 0.0d) {
                    this.inputsTransposed.get(entry.getIndex()).set(i, entry.getValue());
                }
            }
        }
        return true;
    }

    /**
     * Performs one ALS sweep over all enabled parameter groups.
     *
     * @return {@code true} to keep learning, which happens while the total
     *         absolute parameter change in this step is at least
     *         {@link #minChange}
     */
    @Override
    protected boolean step() {
        this.totalChange = 0.0d;

        // residuals[i] = y_i - f(x_i) under the current model. Every helper below
        // keeps this vector in sync with the parameters it changes.
        final Vector residuals = this.computeResidualVector();

        if (isBiasEnabled()) {
            this.solveBiasInPlace(residuals);
        }
        if (isWeightsEnabled()) {
            this.solveWeightsInPlace(residuals);
        }
        if (isFactorsEnabled()) {
            this.solveFactorsInPlace(residuals);
        }

        this.totalError = residuals.norm2Squared();
        return this.totalChange >= this.minChange;
    }

    /** Computes {@code y_i - f(x_i)} for every training example under the current model. */
    private Vector computeResidualVector() {
        final Vector residuals = VectorFactory.getDenseDefault().createVector(this.dataSize);
        for (int i = 0; i < this.dataSize; i++) {
            final InputOutputPair<? extends Vector, Double> example = this.dataList.get(i);
            residuals.set(i, example.getOutput().doubleValue()
                - this.result.evaluateAsDouble(example.getInput()));
        }
        return residuals;
    }

    /** Replaces the bias with its regularized least-squares optimum and updates the residuals. */
    private void solveBiasInPlace(final Vector residuals) {
        final double oldBias = this.result.getBias();
        final double newBias = ((oldBias * this.dataSize) + residuals.sum())
            / (this.dataSize + this.biasRegularization);
        this.result.setBias(newBias);
        // Each prediction moves by (newBias - oldBias), so each residual moves by
        // the opposite amount.
        final double delta = oldBias - newBias;
        for (int i = 0; i < this.dataSize; i++) {
            residuals.increment(i, delta);
        }
        this.totalChange += Math.abs(delta);
    }

    /** Solves each linear weight in turn (coordinate-wise) and updates the residuals. */
    private void solveWeightsInPlace(final Vector residuals) {
        final Vector weights = this.result.getWeights();
        for (int j = 0; j < this.dimensionality; j++) {
            final double oldWeight = weights.getElement(j);
            final Vector column = this.inputsTransposed.get(j);
            final double columnNormSquared = column.norm2Squared();
            // A feature that never occurs has no data term. The guard avoids 0/0
            // when the regularization is zero; for positive regularization the
            // formula would also give 0.
            final double newWeight = columnNormSquared == 0.0d
                ? 0.0d
                : ((oldWeight * columnNormSquared) + column.dot(residuals))
                    / (columnNormSquared + this.weightRegularization);
            weights.set(j, newWeight);
            final double delta = oldWeight - newWeight;
            residuals.scaledPlusEquals(delta, column);
            this.totalChange += Math.abs(delta);
        }
        this.result.setWeights(weights);
    }

    /** Solves each factor entry in turn (factor-major order) and updates the residuals. */
    private void solveFactorsInPlace(final Vector residuals) {
        // Rows of the factor matrix are factors; columns are input dimensions.
        final Matrix factors = this.result.getFactors();
        for (int f = 0; f < this.factorCount; f++) {
            // q[i] = <x_i, v_f>. It is maintained incrementally as entries of
            // row f change.
            final Vector q = VectorFactory.getDefault().createVector(this.dataSize);
            final Vector factorRow = factors.getRow(f);
            for (int i = 0; i < this.dataSize; i++) {
                q.set(i, this.dataList.get(i).getInput().dot(factorRow));
            }
            for (int j = 0; j < this.dimensionality; j++) {
                final double oldFactor = factors.get(f, j);
                final Vector column = this.inputsTransposed.get(j);
                // h[i] = x_ij * (q[i] - v_fj * x_ij): the coefficient multiplying
                // this factor entry in example i's prediction (libFM ALS derivation).
                final Vector h = column.dotTimes(q);
                h.scaledMinusEquals(oldFactor, column.dotTimes(column));
                final double hNormSquared = h.norm2Squared();
                final double newFactor = hNormSquared == 0.0d
                    ? 0.0d
                    : ((oldFactor * hNormSquared) + h.dotProduct(residuals))
                        / (hNormSquared + this.factorRegularization);
                factors.set(f, j, newFactor);
                final double delta = oldFactor - newFactor;
                residuals.scaledPlusEquals(delta, h);
                q.scaledPlusEquals(-delta, column);
                this.totalChange += Math.abs(delta);
            }
        }
        this.result.setFactors(factors);
    }

    /**
     * Releases the per-run copies of the training data.
     * <p>
     * {@link #dataSize}, {@link #totalError} and {@link #totalChange} are kept,
     * so {@link #getObjective()} still works afterwards.
     */
    @Override
    protected void cleanupAlgorithm() {
        this.dataList = null;
        this.inputsTransposed = null;
    }

    /**
     * Gets the total change from the most recent step.
     *
     * @return the sum of absolute parameter changes made during the most recent step
     */
    public double getTotalChange() {
        return this.totalChange;
    }

    /**
     * Gets the total error from the most recent step.
     *
     * @return the sum of squared residuals after the most recent step
     */
    public double getTotalError() {
        return this.totalError;
    }

    /**
     * Computes the L2 regularization penalty of the current model.
     * <p>
     * The penalty is {@code λ_b·b² + λ_w·||w||² + λ_v·||V||_F²}, with the weight
     * and factor terms included only when the model has them. The bias term is
     * squared so that it matches the penalty implied by the bias update in
     * {@link #step()}.
     *
     * @return the penalty, or {@code 0} if there is no model yet
     */
    public double getRegularizationPenalty() {
        if (this.result == null) {
            return 0.0d;
        }
        final double bias = this.result.getBias();
        // Bug fix: the bias penalty was previously λ_b·b (not squared), which is
        // inconsistent with the ALS bias update and negative for negative biases.
        double penalty = this.biasRegularization * bias * bias;
        if (this.result.hasWeights()) {
            penalty += this.weightRegularization * this.result.getWeights().norm2Squared();
        }
        if (this.result.hasFactors()) {
            penalty += this.factorRegularization * this.result.getFactors().normFrobeniusSquared();
        }
        return penalty;
    }

    /**
     * Computes the objective from the error cached by the most recent step.
     * <p>
     * The value is the mean squared error plus half the regularization penalty.
     * <p>
     * NOTE(review): the error is scaled by 1/n but the penalty by 1/2, so this
     * is not exactly the quantity the ALS updates minimize
     * ({@code Σe² + penalty}). It is kept as-is for compatibility.
     *
     * @return the cached objective value
     */
    public double getObjective() {
        return (getTotalError() / Math.max(1, this.dataSize)) + (0.5d * getRegularizationPenalty());
    }

    /**
     * Recomputes the objective by evaluating the model on every training example.
     * <p>
     * The scaling is the same as in {@link #getObjective()}. If the training
     * data has already been released by {@link #cleanupAlgorithm()}, or there is
     * no model, this falls back to the cached {@link #getObjective()} instead of
     * throwing.
     *
     * @return the objective value
     */
    public double computeObjective() {
        if (this.dataList == null || this.result == null) {
            return getObjective();
        }
        double squaredError = 0.0d;
        for (int i = 0; i < this.dataSize; i++) {
            final InputOutputPair<? extends Vector, Double> example = this.dataList.get(i);
            final double residual = example.getOutput().doubleValue()
                - this.result.evaluateAsDouble(example.getInput());
            squaredError += residual * residual;
        }
        return (squaredError / Math.max(1, this.dataSize)) + (0.5d * getRegularizationPenalty());
    }

    /**
     * Gets the current performance of the learner.
     *
     * @return the cached objective, named {@code "objective"}
     */
    public NamedValue<? extends Number> getPerformance() {
        return DefaultNamedValue.create("objective", Double.valueOf(getObjective()));
    }

    /**
     * Gets the stopping threshold.
     *
     * @return the minimum total parameter change a step must make for learning to continue
     */
    public double getMinChange() {
        return this.minChange;
    }

    /**
     * Sets the stopping threshold.
     *
     * @param d the minimum total parameter change per step; rejected by
     *        {@code ArgumentChecker.assertIsNonNegative} when negative
     */
    public void setMinChange(double d) {
        ArgumentChecker.assertIsNonNegative("minChange", d);
        this.minChange = d;
    }
}
