package gov.sandia.cognition.learning.algorithm.pca;

import gov.sandia.cognition.algorithm.MeasurablePerformanceAlgorithm;
import gov.sandia.cognition.annotation.CodeReview;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner;
import gov.sandia.cognition.learning.function.vector.MultivariateDiscriminant;
import gov.sandia.cognition.math.MultivariateStatisticsUtil;
import gov.sandia.cognition.math.Ring;
import gov.sandia.cognition.math.RingAccumulator;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.util.DefaultNamedValue;
import gov.sandia.cognition.util.NamedValue;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;

/**
 * Estimates the leading principal components of a collection of vectors with
 * Sanger's Generalized Hebbian Algorithm (GHA).
 *
 * <p>During initialization the data is copied and mean-centered. Component
 * {@code i} starts as the {@code i}-th standard basis vector. Each call to
 * {@link #step()} sweeps once over the centered data and, for every sample
 * {@code x}, applies Sanger's rule
 * {@code w_i += eta * y_i * (x - sum_{j <= i} y_j * w_j)}, where
 * {@code y_j = w_j . x} is computed from the weights as they stood before that
 * sample's update. Iteration continues while the summed {@code norm2()} of the
 * per-sweep component changes, divided by the learning rate, exceeds
 * {@code minChange} and is finite. On cleanup each component is normalized to
 * unit length and the result is wrapped in a
 * {@link PrincipalComponentsAnalysisFunction} together with the data mean.</p>
 *
 * <p>The method names {@code mo1clone} and {@code m78getResult} are decompiler
 * renames of {@code clone} and {@code getResult}. They are kept unchanged,
 * presumably to match the identically renamed superclass and interface.</p>
 */
@CodeReview(reviewer = {"Kevin R. Dixon"}, date = "2008-07-23", changesNeeded = false, comments = {"Added PublicationReference to Sanger's master's thesis.", "Minor changes to javadoc.", "Looks fine."})
@PublicationReference(author = {"Terrence D. Sanger"}, title = "Optimal Unsupervised Learning in a Single-Layer Linear Feedforward Neural Network", type = PublicationType.Thesis, year = 1989, url = "http://ece-classweb.ucsd.edu/winter06/ece173/documents/Sanger%201989%20--%20Optimal%20Unsupervised%20Learning%20in%20a%20Single-layer%20Linear%20FeedforwardNN.pdf")
public class GeneralizedHebbianAlgorithm extends AbstractAnytimeBatchLearner<Collection<Vector>, PrincipalComponentsAnalysisFunction> implements PrincipalComponentsAnalysis, MeasurablePerformanceAlgorithm {

    /** Name under which {@link #getPerformance()} reports the change value. */
    public static final String PERFORMANCE_NAME = "Change";

    /** Hebbian step size (eta); validated to lie in (0, 1]. */
    private double learningRate;

    /** Number of components to learn; validated to be positive. */
    private int numComponents;

    /** Learned projection, or null until {@link #cleanupAlgorithm()} publishes one. */
    private PrincipalComponentsAnalysisFunction result;

    /** Current, not yet normalized, component estimates (one vector per component). */
    private ArrayList<Vector> components;

    /** Mean of the training data; subtracted from every sample before learning. */
    private Vector mean;

    /** Convergence threshold on the normalized per-sweep change; non-negative. */
    private double minChange;

    /** Normalized change measured by the most recent {@link #step()}. */
    private transient double change;

    /**
     * Creates a new Generalized Hebbian Algorithm learner.
     *
     * @param numComponents number of principal components to learn; must be positive
     * @param learningRate Hebbian step size; must be in (0, 1]
     * @param maxIterations passed to the superclass constructor (presumably the iteration cap)
     * @param minChange convergence threshold on the normalized change; must be non-negative
     * @throws IllegalArgumentException if one of the validated arguments is out of range
     */
    public GeneralizedHebbianAlgorithm(int numComponents, double learningRate, int maxIterations, double minChange) {
        super(maxIterations);
        setNumComponents(numComponents);
        setLearningRate(learningRate);
        setMinChange(minChange);
        setResult(null);
    }

    /**
     * Creates a copy of this learner with its data, result, mean and component
     * estimates cloned, so the copy can keep learning independently.
     *
     * <p>Decompiler rename of {@code clone}.</p>
     *
     * @return the cloned learner
     */
    @Override
    public AbstractAnytimeBatchLearner<Collection<Vector>, PrincipalComponentsAnalysisFunction> mo1clone() {
        final GeneralizedHebbianAlgorithm copy = (GeneralizedHebbianAlgorithm) super.mo1clone();
        copy.setData(ObjectUtil.cloneSmartElementsAsArrayList(getData()));
        copy.setResult((PrincipalComponentsAnalysisFunction) ObjectUtil.cloneSafe(m78getResult()));
        copy.mean = ObjectUtil.cloneSafe(this.mean);
        // step() mutates the component vectors in place, so the copy must not
        // share them with this instance (the super clone may be shallow).
        if (this.components != null) {
            copy.components = new ArrayList<>(this.components.size());
            for (Vector component : this.components) {
                copy.components.add(component.clone());
            }
        }
        return copy;
    }

    /**
     * Copies and mean-centers the training data, then seeds the components
     * with standard basis vectors.
     *
     * @return false if there is no data to learn from; true otherwise
     * @throws IllegalArgumentException if more components are requested than
     *         the data dimensionality
     */
    @Override
    protected boolean initializeAlgorithm() {
        // Drop state from any previous run so cleanup cannot publish stale components.
        this.components = null;
        this.change = 0.0d;

        final Collection<Vector> input = getData();
        if (input == null || input.isEmpty()) {
            return false;
        }

        // Work on private copies: the centering below mutates the vectors in place.
        setData(ObjectUtil.cloneSmartElementsAsArrayList(input));

        final int componentCount = getNumComponents();
        final int dimensionality = getData().iterator().next().getDimensionality();
        if (componentCount > dimensionality) {
            throw new IllegalArgumentException("Number of EigenVectors must be <= dimension of Vectors");
        }

        this.mean = MultivariateStatisticsUtil.computeMean(getData());
        for (Vector sample : getData()) {
            sample.minusEquals(this.mean);
        }

        this.components = new ArrayList<>(componentCount);
        for (int i = 0; i < componentCount; i++) {
            final Vector basis = VectorFactory.getDefault().createVector(dimensionality);
            basis.setElement(i, 1.0d);
            this.components.add(basis);
        }
        return true;
    }

    /**
     * Normalizes the learned components to unit length and publishes them,
     * together with the data mean, as the result.
     *
     * <p>If initialization produced no components (for example, no data), the
     * result is set to null.</p>
     */
    @Override
    protected void cleanupAlgorithm() {
        if (this.components == null || this.components.isEmpty()) {
            setResult(null);
            return;
        }

        // Shape the matrix from what was actually learned, not from the current
        // numComponents setting (which may have been changed since initialization).
        final int rowCount = this.components.size();
        final int dimensionality = this.components.get(0).getDimensionality();
        final Matrix projection = MatrixFactory.getDefault().createMatrix(rowCount, dimensionality);
        for (int i = 0; i < rowCount; i++) {
            projection.setRow(i, this.components.get(i).unitVector());
        }
        setResult(new PrincipalComponentsAnalysisFunction(this.mean, new MultivariateDiscriminant(projection)));
    }

    /**
     * Performs one sweep of Sanger's rule over the centered data.
     *
     * <p>For each sample {@code x}, the outputs {@code y_i = w_i . x} are
     * computed first. A running residual {@code r = x - sum_{j <= i} y_j * w_j}
     * is then built one component at a time, and each {@code w_i} is moved by
     * {@code eta * y_i * r}.</p>
     *
     * @return true while the normalized change exceeds {@code minChange} and is
     *         neither NaN nor infinite
     */
    @Override
    protected boolean step() {
        final int componentCount = this.components.size();

        // Snapshot the components so the total movement in this sweep can be measured.
        final ArrayList<Vector> previous = new ArrayList<>(componentCount);
        for (Vector component : this.components) {
            previous.add(component.clone());
        }

        final double eta = getLearningRate();
        final double[] outputs = new double[componentCount];
        for (Vector sample : getData()) {
            // All outputs come from the weights as they were before this sample's update.
            for (int i = 0; i < componentCount; i++) {
                outputs[i] = this.components.get(i).dotProduct(sample);
            }

            // Subtract each y_j * w_j exactly once. Re-summing every j <= i for
            // each component (without resetting) would double-count the lower ones.
            final Vector residual = sample.clone();
            for (int i = 0; i < componentCount; i++) {
                final Vector weight = this.components.get(i);
                residual.minusEquals(weight.scale(outputs[i]));
                weight.plusEquals(residual.scale(eta * outputs[i]));
            }
        }

        double totalChange = 0.0d;
        for (int i = 0; i < componentCount; i++) {
            totalChange += this.components.get(i).minus(previous.get(i)).norm2();
        }

        // Normalize by eta so the threshold does not depend on the step size.
        final double normalizedChange = Math.abs(totalChange / eta);
        this.change = normalizedChange;
        return normalizedChange > getMinChange()
            && !Double.isNaN(normalizedChange)
            && !Double.isInfinite(normalizedChange);
    }

    /**
     * Returns the Hebbian step size.
     *
     * @return the learning rate, in (0, 1]
     */
    public double getLearningRate() {
        return this.learningRate;
    }

    /**
     * Sets the Hebbian step size.
     *
     * @param learningRate new learning rate; must be in (0, 1]
     * @throws IllegalArgumentException if the value is outside (0, 1]
     */
    public void setLearningRate(double learningRate) {
        if (learningRate <= 0.0d || learningRate > 1.0d) {
            throw new IllegalArgumentException("LearningRate must be (0,1]");
        }
        this.learningRate = learningRate;
    }

    /**
     * Returns the convergence threshold on the normalized per-sweep change.
     *
     * @return the minimum change, non-negative
     */
    public double getMinChange() {
        return this.minChange;
    }

    /**
     * Sets the convergence threshold on the normalized per-sweep change.
     *
     * @param minChange new threshold; must be non-negative
     * @throws IllegalArgumentException if the value is negative
     */
    public void setMinChange(double minChange) {
        if (minChange < 0.0d) {
            throw new IllegalArgumentException("minChange must be greater than or equal to zero");
        }
        this.minChange = minChange;
    }

    /**
     * Returns the number of components to learn.
     *
     * @return the component count, positive
     */
    @Override
    public int getNumComponents() {
        return this.numComponents;
    }

    /**
     * Sets the number of components to learn.
     *
     * @param numComponents new component count; must be positive
     * @throws IllegalArgumentException if the value is not positive
     */
    public void setNumComponents(int numComponents) {
        if (numComponents <= 0) {
            throw new IllegalArgumentException("Number of components must be > 0");
        }
        this.numComponents = numComponents;
    }

    /**
     * Returns the learned projection (decompiler rename of {@code getResult}).
     *
     * @return the result, or null if learning has not completed
     */
    @Override
    public PrincipalComponentsAnalysisFunction m78getResult() {
        return this.result;
    }

    /**
     * Replaces the learned projection.
     *
     * @param result new result; may be null
     */
    protected void setResult(PrincipalComponentsAnalysisFunction result) {
        this.result = result;
    }

    /**
     * Returns the normalized change measured by the most recent step.
     *
     * @return the last change value
     */
    public double getChange() {
        return this.change;
    }

    /**
     * Reports the most recent change under {@link #PERFORMANCE_NAME}.
     *
     * @return the named change value
     */
    public NamedValue<Double> getPerformance() {
        return DefaultNamedValue.create(PERFORMANCE_NAME, Double.valueOf(getChange()));
    }
}
