package gov.sandia.cognition.learning.algorithm.factor.machine;

import gov.sandia.cognition.algorithm.MeasurablePerformanceAlgorithm;
import gov.sandia.cognition.learning.algorithm.AbstractAnytimeSupervisedBatchLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.util.ArgumentChecker;
import gov.sandia.cognition.util.Randomized;
import java.util.Random;

/* loaded from: input_file:gov/sandia/cognition/learning/algorithm/factor/machine/AbstractFactorizationMachineLearner.class */
/**
 * Base class for learners that fit a {@link FactorizationMachine} to
 * (vector input, double output) training data.
 *
 * <p>Holds the shared hyper-parameters (factor count, which model terms are
 * enabled, per-term regularization, random seed scale, random source) and
 * performs the common model initialization. Subclasses supply the actual
 * training steps.
 */
public abstract class AbstractFactorizationMachineLearner extends AbstractAnytimeSupervisedBatchLearner<Vector, Double, FactorizationMachine> implements Randomized, MeasurablePerformanceAlgorithm {

    /** Default number of latent factors. */
    public static final int DEFAULT_FACTOR_COUNT = 10;

    /** Default for whether the bias term is enabled. */
    public static final boolean DEFAULT_BIAS_ENABLED = true;

    /** Default for whether the linear weight terms are enabled. */
    public static final boolean DEFAULT_WEIGHTS_ENABLED = true;

    /** Default regularization applied to the bias term. */
    public static final double DEFAULT_BIAS_REGULARIZATION = 0.0d;

    /** Default regularization applied to the linear weights. */
    public static final double DEFAULT_WEIGHT_REGULARIZATION = 0.001d;

    /** Default regularization applied to the factor matrix. */
    public static final double DEFAULT_FACTOR_REGULARIZATION = 0.01d;

    /** Default standard-deviation multiplier for the initial random factors. */
    public static final double DEFAULT_SEED_SCALE = 0.01d;

    /** Default maximum number of learning iterations. */
    public static final int DEFAULT_MAX_ITERATIONS = 100;

    /** Whether the bias term is enabled. */
    protected boolean biasEnabled;

    /** Whether the linear weight terms are enabled. */
    protected boolean weightsEnabled;

    /** Number of latent factors; zero disables the factor terms. */
    protected int factorCount;

    /** Non-negative regularization for the bias term. */
    protected double biasRegularization;

    /** Non-negative regularization for the linear weights. */
    protected double weightRegularization;

    /** Non-negative regularization for the factor matrix. */
    protected double factorRegularization;

    /** Non-negative scale applied to Gaussian draws when seeding factors. */
    protected double seedScale;

    /** Random source used to seed the factor matrix. */
    protected Random random;

    /** Model being learned; created by {@link #initializeAlgorithm()}. */
    protected transient FactorizationMachine result;

    /** Input dimensionality of the current training data. */
    protected transient int dimensionality;

    /**
     * Creates a learner with all default hyper-parameters and a new,
     * unseeded {@link Random}.
     */
    public AbstractFactorizationMachineLearner() {
        this(DEFAULT_FACTOR_COUNT, DEFAULT_BIAS_REGULARIZATION,
            DEFAULT_WEIGHT_REGULARIZATION, DEFAULT_FACTOR_REGULARIZATION,
            DEFAULT_SEED_SCALE, DEFAULT_MAX_ITERATIONS, new Random());
    }

    /**
     * Creates a learner with the given hyper-parameters. Bias and linear
     * weight terms are enabled.
     *
     * @param factorCount number of latent factors; must be non-negative
     * @param biasRegularization bias regularization; must be non-negative
     * @param weightRegularization weight regularization; must be non-negative
     * @param factorRegularization factor regularization; must be non-negative
     * @param seedScale scale for initial random factors; must be non-negative
     * @param maxIterations maximum number of learning iterations
     * @param random random source used to seed the factor matrix
     */
    public AbstractFactorizationMachineLearner(int factorCount, double biasRegularization, double weightRegularization, double factorRegularization, double seedScale, int maxIterations, Random random) {
        super(maxIterations);
        // Setters are used so the ArgumentChecker validation applies.
        setFactorCount(factorCount);
        setBiasEnabled(DEFAULT_BIAS_ENABLED);
        setWeightsEnabled(DEFAULT_WEIGHTS_ENABLED);
        setBiasRegularization(biasRegularization);
        setWeightRegularization(weightRegularization);
        setFactorRegularization(factorRegularization);
        setSeedScale(seedScale);
        setRandom(random);
    }

    /**
     * Creates the initial model from the training data.
     *
     * <p>The bias starts at 0. The weight vector is a newly created dense
     * vector of the input dimensionality. The factor matrix is
     * {@code factorCount x dimensionality}, with each entry set to
     * {@code seedScale * random.nextGaussian()}. It is {@code null} when
     * {@code factorCount <= 0}.
     *
     * <p>NOTE(review): the decompiler reported this method as originally
     * {@code protected}; it is kept {@code public} so existing callers do
     * not break.
     *
     * @return always {@code true}
     */
    @Override // gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner
    public boolean initializeAlgorithm() {
        this.dimensionality = DatasetUtil.getInputDimensionality(this.data);
        final Vector weights =
            VectorFactory.getDenseDefault().createVector(this.dimensionality);

        final Matrix factors;
        if (this.factorCount <= 0) {
            factors = null;
        } else {
            factors = MatrixFactory.getDenseDefault().createMatrix(
                this.factorCount, this.dimensionality);
            // Outer loop over dimensions, inner over factors: this fixes the
            // order of Gaussian draws, so seeded runs remain reproducible.
            for (int dim = 0; dim < this.dimensionality; dim++) {
                for (int factor = 0; factor < this.factorCount; factor++) {
                    factors.setElement(factor, dim,
                        this.seedScale * this.random.nextGaussian());
                }
            }
        }
        this.result = new FactorizationMachine(0.0d, weights, factors);
        return true;
    }

    /**
     * Returns the model currently being learned.
     *
     * @return the current model, or {@code null} before initialization
     */
    @Override // gov.sandia.cognition.algorithm.AnytimeAlgorithm
    public FactorizationMachine getResult() {
        return this.result;
    }

    /**
     * Alias for {@link #getResult()}, kept for callers that use the
     * decompiler-generated name.
     *
     * @return the current model, or {@code null} before initialization
     */
    public FactorizationMachine getResult2() {
        return getResult();
    }

    /**
     * Returns the number of latent factors.
     *
     * @return the factor count
     */
    public int getFactorCount() {
        return this.factorCount;
    }

    /**
     * Sets the number of latent factors. Zero disables the factor terms.
     *
     * @param factorCount the factor count; must be non-negative
     */
    public void setFactorCount(int factorCount) {
        ArgumentChecker.assertIsNonNegative("factorCount", factorCount);
        this.factorCount = factorCount;
    }

    /**
     * Returns whether the bias term is enabled.
     *
     * @return true if the bias is enabled
     */
    public boolean isBiasEnabled() {
        return this.biasEnabled;
    }

    /**
     * Sets whether the bias term is enabled.
     *
     * @param biasEnabled true to enable the bias
     */
    public void setBiasEnabled(boolean biasEnabled) {
        this.biasEnabled = biasEnabled;
    }

    /**
     * Returns whether the linear weight terms are enabled.
     *
     * @return true if the weights are enabled
     */
    public boolean isWeightsEnabled() {
        return this.weightsEnabled;
    }

    /**
     * Sets whether the linear weight terms are enabled.
     *
     * @param weightsEnabled true to enable the weights
     */
    public void setWeightsEnabled(boolean weightsEnabled) {
        this.weightsEnabled = weightsEnabled;
    }

    /**
     * Returns whether factor terms are enabled, i.e. whether
     * {@link #getFactorCount()} is positive.
     *
     * @return true if the factor count is greater than zero
     */
    public boolean isFactorsEnabled() {
        return getFactorCount() > 0;
    }

    /**
     * Returns the bias regularization.
     *
     * @return the bias regularization
     */
    public double getBiasRegularization() {
        return this.biasRegularization;
    }

    /**
     * Sets the bias regularization.
     *
     * @param biasRegularization the bias regularization; must be non-negative
     */
    public void setBiasRegularization(double biasRegularization) {
        ArgumentChecker.assertIsNonNegative("biasRegularization", biasRegularization);
        this.biasRegularization = biasRegularization;
    }

    /**
     * Returns the linear weight regularization.
     *
     * @return the weight regularization
     */
    public double getWeightRegularization() {
        return this.weightRegularization;
    }

    /**
     * Sets the linear weight regularization.
     *
     * @param weightRegularization the weight regularization; must be non-negative
     */
    public void setWeightRegularization(double weightRegularization) {
        ArgumentChecker.assertIsNonNegative("weightRegularization", weightRegularization);
        this.weightRegularization = weightRegularization;
    }

    /**
     * Returns the factor regularization.
     *
     * @return the factor regularization
     */
    public double getFactorRegularization() {
        return this.factorRegularization;
    }

    /**
     * Sets the factor regularization.
     *
     * @param factorRegularization the factor regularization; must be non-negative
     */
    public void setFactorRegularization(double factorRegularization) {
        ArgumentChecker.assertIsNonNegative("factorRegularization", factorRegularization);
        this.factorRegularization = factorRegularization;
    }

    /**
     * Returns the scale applied to the Gaussian draws that seed the factors.
     *
     * @return the seed scale
     */
    public double getSeedScale() {
        return this.seedScale;
    }

    /**
     * Sets the scale applied to the Gaussian draws that seed the factors.
     *
     * @param seedScale the seed scale; must be non-negative
     */
    public void setSeedScale(double seedScale) {
        ArgumentChecker.assertIsNonNegative("seedScale", seedScale);
        this.seedScale = seedScale;
    }

    /**
     * Returns the random source used to seed the factor matrix.
     *
     * @return the random source
     */
    @Override // gov.sandia.cognition.util.Randomized
    public Random getRandom() {
        return this.random;
    }

    /**
     * Sets the random source used to seed the factor matrix. It is not
     * null-checked here; it is dereferenced in
     * {@link #initializeAlgorithm()} when {@code factorCount > 0}.
     *
     * @param random the random source
     */
    @Override // gov.sandia.cognition.util.Randomized
    public void setRandom(Random random) {
        this.random = random;
    }
}
