package gov.sandia.cognition.learning.algorithm.svm;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.algorithm.AbstractAnytimeSupervisedBatchLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.LinearBinaryCategorizer;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.statistics.DiscreteSamplingUtil;
import gov.sandia.cognition.util.ArgumentChecker;
import gov.sandia.cognition.util.Randomized;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Random;

/**
 * Batch learner for a linear binary categorizer based on the Pegasos
 * (Primal Estimated sub-GrAdient SOlver for SVM) algorithm referenced in the
 * {@code PublicationReference} annotation below.
 *
 * <p>Each iteration draws a random sample of the training data, accumulates
 * the inputs of the sampled examples whose margin is below 1, shrinks the
 * current weights, adds the scaled accumulated update and finally rescales
 * the weights so that their 2-norm does not exceed
 * {@code 1 / sqrt(regularizationWeight)}.
 *
 * <p>The per-run working state ({@code dataList}, {@code update}, ...) is
 * {@code transient} and is rebuilt by {@link #initializeAlgorithm()}.
 */
@PublicationReference(author = {"Shai Shalev-Shwartz", "Yoram Singer", "Nathan Srebro"}, title = "Pegasos: Primal Estimated sub-GrAdient SOlver for SVM", year = 2007, type = PublicationType.Conference, publication = "Proceedings of the 24th International Conference on Machine Learning", url = "http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.74.8513")
/* loaded from: input_file:gov/sandia/cognition/learning/algorithm/svm/PrimalEstimatedSubGradient.class */
public class PrimalEstimatedSubGradient extends AbstractAnytimeSupervisedBatchLearner<Vectorizable, Boolean, LinearBinaryCategorizer> implements Randomized {
    /** Default number of examples sampled per iteration. */
    public static final int DEFAULT_SAMPLE_SIZE = 100;

    /** Default regularization weight (lambda). */
    public static final double DEFAULT_REGULARIZATION_WEIGHT = 1.0E-4d;

    /** Default maximum number of iterations. */
    public static final int DEFAULT_MAX_ITERATIONS = 10000;

    /** Requested number of examples to sample per iteration; always positive. */
    protected int sampleSize;

    /** Regularization weight (lambda); always positive. */
    protected double regularizationWeight;

    /** Random number generator used for weight initialization and sampling. */
    protected Random random;

    /** Number of training examples in the current run. */
    protected transient int dataSize;

    /** Training data of the current run as a random-access list. */
    protected transient ArrayList<? extends InputOutputPair<? extends Vectorizable, Boolean>> dataList;

    /** Dimensionality of the input vectors of the current run. */
    protected transient int dimensionality;

    /** Effective per-iteration sample size: min(dataSize, sampleSize). */
    protected transient int dataSampleSize;

    /** Scratch vector that accumulates the weight update of one iteration. */
    protected transient Vector update;

    /** The categorizer being learned. */
    protected transient LinearBinaryCategorizer result;

    /**
     * Creates a learner with the default sample size, regularization weight
     * and maximum iteration count, and a freshly created {@link Random}.
     */
    public PrimalEstimatedSubGradient() {
        this(DEFAULT_SAMPLE_SIZE, DEFAULT_REGULARIZATION_WEIGHT, DEFAULT_MAX_ITERATIONS, new Random());
    }

    /**
     * Creates a learner with the given parameters.
     *
     * @param sampleSize number of examples to sample per iteration; must be positive
     * @param regularizationWeight regularization weight (lambda); must be positive
     * @param maxIterations maximum number of iterations, passed to the superclass
     * @param random random number generator to use
     */
    public PrimalEstimatedSubGradient(int sampleSize, double regularizationWeight, int maxIterations, Random random) {
        super(maxIterations);
        setSampleSize(sampleSize);
        setRegularizationWeight(regularizationWeight);
        setRandom(random);
    }

    /**
     * Prepares a learning run: snapshots the data into a list, determines the
     * input dimensionality and the effective sample size, and creates the
     * initial categorizer with random weights and a bias of zero.
     *
     * @return {@code false} if there is no data to learn from, otherwise {@code true}
     */
    @Override // gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner
    protected boolean initializeAlgorithm() {
        if (CollectionUtil.isEmpty((Collection) this.data)) {
            // Nothing to learn from.
            return false;
        }

        this.dataSize = ((Collection) this.data).size();
        this.dataList = CollectionUtil.asArrayList((Iterable) this.data);
        this.dimensionality = DatasetUtil.getInputDimensionality((Iterable) this.data);
        // Never try to sample more examples than are available.
        this.dataSampleSize = Math.min(this.dataSize, this.sampleSize);

        VectorFactory vectorFactory = VectorFactory.getDenseDefault();

        // Draw every initial weight uniformly from [-range, range].
        double lambdaSqrt = Math.sqrt(this.regularizationWeight);
        double range = 1.0d / (this.dimensionality * lambdaSqrt);
        Vector initialWeights = vectorFactory.createUniformRandom(this.dimensionality, -range, range, this.random);

        // Rescale the initial weights to have norm exactly 1 / sqrt(lambda)
        // whenever their norm is below that bound.
        // NOTE(review): with the range chosen above the norm cannot exceed the
        // bound, so this branch is taken for practically every draw; confirm
        // that "<" (rather than ">") is the intended comparison.
        if (initialWeights.norm2() < 1.0d / lambdaSqrt) {
            initialWeights.unitVectorEquals();
            initialWeights.scaleEquals(1.0d / lambdaSqrt);
        }

        this.result = new LinearBinaryCategorizer(initialWeights, 0.0d);
        // Allocated once per run and zeroed at the start of every step.
        this.update = vectorFactory.createVector(this.dimensionality);
        return true;
    }

    /**
     * Performs one iteration: samples {@code dataSampleSize} examples without
     * replacement, accumulates the signed inputs of those with margin below 1,
     * then updates the weights and bias of the categorizer.
     *
     * <p>The step size is {@code 1 / (regularizationWeight * iteration)}; this
     * assumes the superclass has advanced {@code iteration} to at least 1
     * before calling this method (TODO confirm against the superclass).
     *
     * @return always {@code true}
     */
    @Override // gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner
    protected boolean step() {
        List<? extends InputOutputPair<? extends Vectorizable, Boolean>> sample =
            DiscreteSamplingUtil.sampleWithoutReplacement(this.random, this.dataList, this.dataSampleSize);

        double lambda = this.regularizationWeight;
        double stepSize = 1.0d / (lambda * this.iteration);

        this.update.zero();
        double biasUpdate = 0.0d;
        for (InputOutputPair<? extends Vectorizable, Boolean> example : sample) {
            boolean actual = example.getOutput().booleanValue();
            double label = actual ? 1.0d : -1.0d;
            double margin = label * this.result.evaluateAsDouble(example.getInput());
            if (margin < 1.0d) {
                // Margin violation: add the input with the sign of its label.
                Vector input = example.getInput().convertToVector();
                if (actual) {
                    this.update.plusEquals(input);
                } else {
                    this.update.minusEquals(input);
                }
                biasUpdate += label;
            }
        }

        // Shrink the current weights, then add the averaged, scaled update.
        Vector weights = this.result.getWeights();
        weights.scaleEquals(1.0d - (stepSize * lambda));
        double updateScale = stepSize / sample.size();
        this.update.scaleEquals(updateScale);
        weights.plusEquals(this.update);
        double bias = this.result.getBias() + (biasUpdate * updateScale);

        // Rescale the weights only when their norm exceeds 1 / sqrt(lambda).
        double projection = 1.0d / Math.sqrt(lambda * weights.norm2Squared());
        if (projection < 1.0d) {
            weights.scaleEquals(projection);
        }

        this.result.setWeights(weights);
        this.result.setBias(bias);
        return true;
    }

    /**
     * Releases the per-run working state; the learned result is kept.
     */
    @Override // gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner
    protected void cleanupAlgorithm() {
        this.dataList = null;
        this.update = null;
    }

    /**
     * Gets the categorizer learned so far.
     *
     * @return the current result; may be {@code null} if no run has been initialized
     */
    @Override
    public LinearBinaryCategorizer getResult() {
        return this.result;
    }

    /**
     * Alias of {@link #getResult()} under the name the decompiler assigned to
     * it; kept so that code referring to that name continues to compile.
     *
     * @return the same value as {@link #getResult()}
     */
    /* renamed from: getResult, reason: merged with bridge method [inline-methods] */
    public LinearBinaryCategorizer m124getResult() {
        return getResult();
    }

    /**
     * Gets the requested number of examples sampled per iteration.
     *
     * @return the sample size
     */
    public int getSampleSize() {
        return this.sampleSize;
    }

    /**
     * Sets the number of examples to sample per iteration.
     *
     * @param i the sample size; checked with {@code ArgumentChecker.assertIsPositive}
     */
    public void setSampleSize(int i) {
        ArgumentChecker.assertIsPositive("sampleSize", i);
        this.sampleSize = i;
    }

    /**
     * Gets the regularization weight (lambda).
     *
     * @return the regularization weight
     */
    public double getRegularizationWeight() {
        return this.regularizationWeight;
    }

    /**
     * Sets the regularization weight (lambda).
     *
     * @param d the regularization weight; checked with {@code ArgumentChecker.assertIsPositive}
     */
    public void setRegularizationWeight(double d) {
        ArgumentChecker.assertIsPositive("regularizationWeight", d);
        this.regularizationWeight = d;
    }

    /**
     * Gets the random number generator used by this learner.
     *
     * @return the random number generator
     */
    public Random getRandom() {
        return this.random;
    }

    /**
     * Sets the random number generator used by this learner.
     *
     * @param random the random number generator
     */
    public void setRandom(Random random) {
        this.random = random;
    }
}
