package gov.sandia.cognition.learning.function.vector;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.algorithm.gradient.GradientDescendable;
import gov.sandia.cognition.learning.function.scalar.AtanFunction;
import gov.sandia.cognition.math.DifferentiableUnivariateScalarFunction;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.VectorInputEvaluator;
import gov.sandia.cognition.math.matrix.VectorOutputEvaluator;
import gov.sandia.cognition.math.matrix.VectorizableVectorFunction;
import gov.sandia.cognition.util.AbstractRandomized;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.Random;

/**
 * A feedforward neural network with a single hidden layer.
 *
 * <p>The network computes {@code W2 * f(W1 * x + b1) + b2}, where {@code f} is the
 * element-wise squashing function applied to the hidden layer only; the output layer is
 * linear (no squashing).
 *
 * <p>The flattened parameter vector (see {@link #convertToVector()}) is laid out as
 * {@code [vec(W1), b1, vec(W2), b2]}, where {@code vec} is {@code Matrix.convertToVector()}.
 * NOTE(review): {@link #computeParameterGradient(Vector)} indexes matrix entries as
 * {@code column * numRows + row}, i.e. it assumes {@code Matrix.convertToVector()} is
 * column-major — TODO confirm against the Matrix implementation in use.
 */
@PublicationReference(author = {"Wikipedia"}, title = "Multilayer perceptron", type = PublicationType.WebPage, year = 2009, url = "http://en.wikipedia.org/wiki/Multilayer_perceptron")
/* loaded from: input_file:gov/sandia/cognition/learning/function/vector/ThreeLayerFeedforwardNeuralNetwork.class */
public class ThreeLayerFeedforwardNeuralNetwork extends AbstractRandomized implements VectorizableVectorFunction, VectorInputEvaluator<Vector, Vector>, VectorOutputEvaluator<Vector, Vector>, GradientDescendable {

    /** Default half-width of the uniform interval {@code [-range, range]} used to initialize weights. */
    public static final double DEFAULT_INITIALIZATION_RANGE = 0.001d;

    /** Default hidden-layer squashing function (arctangent). */
    public static final DifferentiableUnivariateScalarFunction DEFAULT_SQUASHING_FUNCTION = new AtanFunction();

    /** Default seed for the random number generator used to initialize weights. */
    public static final int DEFAULT_RANDOM_SEED = 1;

    /** Input-to-hidden weights; rows index hidden units, columns index inputs. */
    protected Matrix inputToHiddenWeights;

    /** Hidden-layer bias weights, one per hidden unit. */
    protected Vector inputToHiddenBiasWeights;

    /** Hidden-to-output weights; rows index outputs, columns index hidden units. */
    protected Matrix hiddenToOutputWeights;

    /** Output-layer bias weights, one per output. */
    protected Vector hiddenToOutputBiasWeights;

    /** Element-wise function applied to the hidden-layer activation; never null. */
    private DifferentiableUnivariateScalarFunction squashingFunction;

    /** Half-width of the uniform interval used when (re)initializing weights; never negative or NaN. */
    private double initializationRange;

    /** Creates a 1-1-1 network with the default squashing function, seed, and initialization range. */
    public ThreeLayerFeedforwardNeuralNetwork() {
        this(1, 1, 1);
    }

    /**
     * Creates a network with the given layer sizes and default squashing function, seed, and
     * initialization range.
     *
     * @param inputDimensionality number of inputs (must be >= 1)
     * @param hiddenDimensionality number of hidden units (must be >= 1)
     * @param outputDimensionality number of outputs (must be >= 1)
     */
    public ThreeLayerFeedforwardNeuralNetwork(int inputDimensionality, int hiddenDimensionality, int outputDimensionality) {
        this(inputDimensionality, hiddenDimensionality, outputDimensionality, DEFAULT_SQUASHING_FUNCTION);
    }

    /**
     * Creates a network with the given layer sizes and squashing function, using the default
     * random seed and initialization range.
     *
     * @param inputDimensionality number of inputs (must be >= 1)
     * @param hiddenDimensionality number of hidden units (must be >= 1)
     * @param outputDimensionality number of outputs (must be >= 1)
     * @param squashingFunction hidden-layer squashing function (must not be null)
     */
    public ThreeLayerFeedforwardNeuralNetwork(int inputDimensionality, int hiddenDimensionality, int outputDimensionality, DifferentiableUnivariateScalarFunction squashingFunction) {
        this(inputDimensionality, hiddenDimensionality, outputDimensionality, squashingFunction, DEFAULT_RANDOM_SEED, DEFAULT_INITIALIZATION_RANGE);
    }

    /**
     * Creates a network and draws all weights uniformly from
     * {@code [-initializationRange, initializationRange]}.
     *
     * @param inputDimensionality number of inputs (must be >= 1)
     * @param hiddenDimensionality number of hidden units (must be >= 1)
     * @param outputDimensionality number of outputs (must be >= 1)
     * @param squashingFunction hidden-layer squashing function (must not be null)
     * @param randomSeed seed for the {@link Random} used to initialize the weights
     * @param initializationRange half-width of the initialization interval (must be >= 0)
     * @throws IllegalArgumentException if any argument is out of range
     */
    public ThreeLayerFeedforwardNeuralNetwork(int inputDimensionality, int hiddenDimensionality, int outputDimensionality, DifferentiableUnivariateScalarFunction squashingFunction, int randomSeed, double initializationRange) {
        super(new Random(randomSeed));
        // The range must be set before initializeWeights(), which reads it.
        setInitializationRange(initializationRange);
        setSquashingFunction(squashingFunction);
        initializeWeights(inputDimensionality, hiddenDimensionality, outputDimensionality);
    }

    /**
     * Returns a copy of this network whose weight matrices/vectors and squashing function are
     * cloned via {@code ObjectUtil.cloneSafe}.
     *
     * @return the cloned network
     */
    @Override // gov.sandia.cognition.learning.algorithm.gradient.GradientDescendable
    /* renamed from: clone, reason: merged with bridge method [inline-methods] and merged with bridge method [inline-methods] and merged with bridge method [inline-methods] and merged with bridge method [inline-methods] */
    public ThreeLayerFeedforwardNeuralNetwork mo237clone() {
        // super.clone() is declared on a superclass, so its static type is not this class;
        // the cast is required for the assignment to compile.
        final ThreeLayerFeedforwardNeuralNetwork clone = (ThreeLayerFeedforwardNeuralNetwork) super.clone();
        clone.inputToHiddenWeights = ObjectUtil.cloneSafe(this.inputToHiddenWeights);
        clone.inputToHiddenBiasWeights = ObjectUtil.cloneSafe(this.inputToHiddenBiasWeights);
        clone.hiddenToOutputWeights = ObjectUtil.cloneSafe(this.hiddenToOutputWeights);
        clone.hiddenToOutputBiasWeights = ObjectUtil.cloneSafe(this.hiddenToOutputBiasWeights);
        clone.squashingFunction = ObjectUtil.cloneSafe(this.squashingFunction);
        return clone;
    }

    /**
     * Computes the Jacobian of the network output with respect to the flattened parameter vector.
     *
     * <p>Row {@code o} holds the partial derivatives of output {@code o}; columns follow the
     * layout of {@link #convertToVector()}.
     *
     * @param vector the input at which to evaluate the gradient
     * @return an {@code outputDimensionality x getNumParameters()} matrix
     */
    /* JADX WARN: Can't rename method to resolve collision */
    @Override // gov.sandia.cognition.learning.algorithm.gradient.ParameterGradientEvaluator
    public Matrix computeParameterGradient(Vector vector) {
        final int inputDimensionality = getInputDimensionality();
        final int hiddenDimensionality = getHiddenDimensionality();
        final int outputDimensionality = getOutputDimensionality();

        // Start offsets of each parameter group within the flattened vector [W1, b1, W2, b2].
        final int hiddenBiasOffset = inputDimensionality * hiddenDimensionality;
        final int outputWeightsOffset = hiddenBiasOffset + hiddenDimensionality;
        final int outputBiasOffset = outputWeightsOffset + hiddenDimensionality * outputDimensionality;

        final Vector hiddenActivation = evaluateHiddenLayerActivation(vector);
        final Vector squashedHiddenActivation = evaluateSquashedHiddenLayerActivation(hiddenActivation);

        // f'(a_h) for every hidden unit h, reused for every output row below.
        final double[] squashDerivatives = new double[hiddenDimensionality];
        for (int h = 0; h < hiddenDimensionality; h++) {
            squashDerivatives[h] = this.squashingFunction.differentiate(hiddenActivation.getElement(h));
        }

        final Matrix gradient = MatrixFactory.getDefault().createMatrix(outputDimensionality, getNumParameters());

        // d y_o / d W2[o][h] = f(a_h); W2 is flattened column-major (h * outputs + o).
        for (int h = 0; h < hiddenDimensionality; h++) {
            final double squashed = squashedHiddenActivation.getElement(h);
            for (int o = 0; o < outputDimensionality; o++) {
                gradient.setElement(o, outputWeightsOffset + h * outputDimensionality + o, squashed);
            }
        }

        // d y_o / d b2[o] = 1.
        for (int o = 0; o < outputDimensionality; o++) {
            gradient.setElement(o, outputBiasOffset + o, 1.0d);
        }

        // Back-propagate through the hidden layer:
        //   d y_o / d b1[h]    = W2[o][h] * f'(a_h)
        //   d y_o / d W1[h][i] = W2[o][h] * f'(a_h) * x_i   (W1 flattened column-major: i * hidden + h)
        for (int o = 0; o < outputDimensionality; o++) {
            for (int h = 0; h < hiddenDimensionality; h++) {
                final double delta = this.hiddenToOutputWeights.getElement(o, h) * squashDerivatives[h];
                gradient.setElement(o, hiddenBiasOffset + h, delta);
                for (int in = 0; in < inputDimensionality; in++) {
                    gradient.setElement(o, in * hiddenDimensionality + h, delta * vector.getElement(in));
                }
            }
        }
        return gradient;
    }

    /**
     * Flattens all weights into a single new vector laid out as {@code [vec(W1), b1, vec(W2), b2]}.
     *
     * @return a newly created parameter vector
     */
    public Vector convertToVector() {
        final Vector inputWeights = this.inputToHiddenWeights.convertToVector();
        final Vector outputWeights = this.hiddenToOutputWeights.convertToVector();
        final int totalDimensionality = inputWeights.getDimensionality()
            + this.inputToHiddenBiasWeights.getDimensionality()
            + outputWeights.getDimensionality()
            + this.hiddenToOutputBiasWeights.getDimensionality();
        final Vector parameters = VectorFactory.getDefault().createVector(totalDimensionality);

        int offset = 0;
        offset = copyInto(parameters, offset, inputWeights);
        offset = copyInto(parameters, offset, this.inputToHiddenBiasWeights);
        offset = copyInto(parameters, offset, outputWeights);
        copyInto(parameters, offset, this.hiddenToOutputBiasWeights);
        return parameters;
    }

    /**
     * Copies every element of {@code source} into {@code target} starting at {@code offset}.
     *
     * @return the index just past the last element written
     */
    private static int copyInto(final Vector target, final int offset, final Vector source) {
        final int dimensionality = source.getDimensionality();
        for (int k = 0; k < dimensionality; k++) {
            target.setElement(offset + k, source.getElement(k));
        }
        return offset + dimensionality;
    }

    /**
     * Returns the total number of weights and biases in the network.
     *
     * @return {@code in*hidden + hidden + hidden*out + out}
     */
    public int getNumParameters() {
        final int inputDimensionality = getInputDimensionality();
        final int hiddenDimensionality = getHiddenDimensionality();
        final int outputDimensionality = getOutputDimensionality();
        return inputDimensionality * hiddenDimensionality
            + hiddenDimensionality
            + hiddenDimensionality * outputDimensionality
            + outputDimensionality;
    }

    /**
     * Replaces all weights from a flattened parameter vector laid out as
     * {@code [vec(W1), b1, vec(W2), b2]} (the inverse of {@link #convertToVector()}).
     *
     * @param vector parameter vector whose dimensionality must equal {@link #getNumParameters()}
     */
    public void convertFromVector(Vector vector) {
        final int inputDimensionality = getInputDimensionality();
        final int hiddenDimensionality = getHiddenDimensionality();
        final int outputDimensionality = getOutputDimensionality();
        final int numParameters = getNumParameters();
        vector.assertDimensionalityEquals(numParameters);

        final int hiddenBiasOffset = inputDimensionality * hiddenDimensionality;
        final int outputWeightsOffset = hiddenBiasOffset + hiddenDimensionality;
        final int outputBiasOffset = outputWeightsOffset + hiddenDimensionality * outputDimensionality;

        // subVector's second argument is used here as an inclusive upper index (end - 1).
        // All slices are extracted before any field is modified.
        final Vector inputWeights = vector.subVector(0, hiddenBiasOffset - 1);
        final Vector hiddenBias = vector.subVector(hiddenBiasOffset, outputWeightsOffset - 1);
        final Vector outputWeights = vector.subVector(outputWeightsOffset, outputBiasOffset - 1);
        final Vector outputBias = vector.subVector(outputBiasOffset, numParameters - 1);

        this.inputToHiddenWeights.convertFromVector(inputWeights);
        this.inputToHiddenBiasWeights = hiddenBias;
        this.hiddenToOutputWeights.convertFromVector(outputWeights);
        this.hiddenToOutputBiasWeights = outputBias;
    }

    /**
     * Evaluates the network: {@code W2 * f(W1 * x + b1) + b2}.
     *
     * @param vector the input
     * @return the (linear) output-layer values
     */
    public Vector evaluate(Vector vector) {
        final Vector hiddenActivation = evaluateHiddenLayerActivation(vector);
        final Vector squashed = evaluateSquashedHiddenLayerActivation(hiddenActivation);
        return evaluateOutputFromSquashedHiddenLayerActivation(squashed);
    }

    /**
     * Computes the pre-squashing hidden activation {@code W1 * x + b1}.
     *
     * @param vector the input
     * @return a new vector of hidden-unit activations
     */
    protected Vector evaluateHiddenLayerActivation(Vector vector) {
        final Vector activation = this.inputToHiddenWeights.times(vector);
        activation.plusEquals(this.inputToHiddenBiasWeights);
        return activation;
    }

    /**
     * Applies the squashing function element-wise to a hidden activation.
     *
     * @param vector the pre-squashing hidden activation
     * @return the squashed hidden activation
     */
    protected Vector evaluateSquashedHiddenLayerActivation(Vector vector) {
        return ElementWiseVectorFunction.evaluate(vector, getSquashingFunction());
    }

    /**
     * Computes the output layer {@code W2 * h + b2} from a squashed hidden activation.
     *
     * @param vector the squashed hidden activation
     * @return a new vector of output values
     */
    protected Vector evaluateOutputFromSquashedHiddenLayerActivation(Vector vector) {
        final Vector output = this.hiddenToOutputWeights.times(vector);
        output.plusEquals(this.hiddenToOutputBiasWeights);
        return output;
    }

    /** Re-draws all weights at the current layer sizes from the current initialization range. */
    public void reinitializeWeights() {
        initializeWeights(getInputDimensionality(), getHiddenDimensionality(), getOutputDimensionality());
    }

    /**
     * Allocates all weight matrices/vectors at the given sizes and fills them uniformly from
     * {@code [-initializationRange, initializationRange]} using this object's random generator.
     *
     * @param inputDimensionality number of inputs (must be >= 1)
     * @param hiddenDimensionality number of hidden units (must be >= 1)
     * @param outputDimensionality number of outputs (must be >= 1)
     * @throws IllegalArgumentException if any size is less than 1
     */
    public void initializeWeights(int inputDimensionality, int hiddenDimensionality, int outputDimensionality) {
        if (inputDimensionality < 1) {
            throw new IllegalArgumentException("inputDimensionality must be >= 1");
        }
        if (hiddenDimensionality < 1) {
            throw new IllegalArgumentException("hiddenDimensionality must be >= 1");
        }
        if (outputDimensionality < 1) {
            throw new IllegalArgumentException("outputDimensionality must be >= 1");
        }
        final double range = getInitializationRange();
        final Random rng = getRandom();
        // Draw order (W1, b1, W2, b2) is kept fixed so a given seed yields the same weights.
        this.inputToHiddenWeights = MatrixFactory.getDefault().createUniformRandom(hiddenDimensionality, inputDimensionality, -range, range, rng);
        this.inputToHiddenBiasWeights = VectorFactory.getDefault().createUniformRandom(hiddenDimensionality, -range, range, rng);
        this.hiddenToOutputWeights = MatrixFactory.getDefault().createUniformRandom(outputDimensionality, hiddenDimensionality, -range, range, rng);
        this.hiddenToOutputBiasWeights = VectorFactory.getDefault().createUniformRandom(outputDimensionality, -range, range, rng);
    }

    /** @return the number of outputs (rows of the hidden-to-output weight matrix) */
    public int getOutputDimensionality() {
        return this.hiddenToOutputWeights.getNumRows();
    }

    /**
     * Changes the number of outputs. This re-randomizes ALL weights, not just the output layer.
     *
     * @param outputDimensionality new number of outputs (must be >= 1)
     */
    public void setOutputDimensionality(int outputDimensionality) {
        initializeWeights(getInputDimensionality(), getHiddenDimensionality(), outputDimensionality);
    }

    /** @return the number of hidden units (columns of the hidden-to-output weight matrix) */
    public int getHiddenDimensionality() {
        return this.hiddenToOutputWeights.getNumColumns();
    }

    /**
     * Changes the number of hidden units. This re-randomizes ALL weights.
     *
     * @param hiddenDimensionality new number of hidden units (must be >= 1)
     */
    public void setHiddenDimensionality(int hiddenDimensionality) {
        initializeWeights(getInputDimensionality(), hiddenDimensionality, getOutputDimensionality());
    }

    /** @return the number of inputs (columns of the input-to-hidden weight matrix) */
    public int getInputDimensionality() {
        return this.inputToHiddenWeights.getNumColumns();
    }

    /**
     * Changes the number of inputs. This re-randomizes ALL weights.
     *
     * @param inputDimensionality new number of inputs (must be >= 1)
     */
    public void setInputDimensionality(int inputDimensionality) {
        initializeWeights(inputDimensionality, getHiddenDimensionality(), getOutputDimensionality());
    }

    /** @return the hidden-layer squashing function (never null) */
    public DifferentiableUnivariateScalarFunction getSquashingFunction() {
        return this.squashingFunction;
    }

    /**
     * Sets the hidden-layer squashing function.
     *
     * @param squashingFunction the new function
     * @throws IllegalArgumentException if {@code squashingFunction} is null
     */
    public void setSquashingFunction(DifferentiableUnivariateScalarFunction squashingFunction) {
        if (squashingFunction == null) {
            throw new IllegalArgumentException("Squashing function cannot be null!");
        }
        this.squashingFunction = squashingFunction;
    }

    /** @return the half-width of the weight initialization interval */
    public double getInitializationRange() {
        return this.initializationRange;
    }

    /**
     * Sets the half-width of the weight initialization interval. This only affects future calls
     * to {@link #initializeWeights(int, int, int)} / {@link #reinitializeWeights()}.
     *
     * @param initializationRange the new half-width
     * @throws IllegalArgumentException if the value is negative or NaN
     */
    public void setInitializationRange(double initializationRange) {
        // Written as !(x >= 0) rather than x < 0 so that NaN is rejected too;
        // a NaN range would otherwise silently produce NaN weights.
        if (!(initializationRange >= 0.0d)) {
            throw new IllegalArgumentException("initializationRange must be >= 0.0");
        }
        this.initializationRange = initializationRange;
    }
}
