package gov.sandia.cognition.learning.function.vector;

import gov.sandia.cognition.learning.algorithm.gradient.GradientDescendable;
import gov.sandia.cognition.learning.function.scalar.AtanFunction;
import gov.sandia.cognition.math.DifferentiableUnivariateScalarFunction;
import gov.sandia.cognition.math.matrix.DifferentiableVectorFunction;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.Vector;

/* loaded from: input_file:gov/sandia/cognition/learning/function/vector/DifferentiableGeneralizedLinearModel.class */
/**
 * A {@link GeneralizedLinearModel} whose squashing function is a
 * {@link DifferentiableVectorFunction}, which lets the model expose both its
 * derivative with respect to the input ({@link #differentiate(Vector)}) and
 * its derivative with respect to the parameters
 * ({@link #computeParameterGradient(Vector)}).
 *
 * <p>Both derivatives are formed as a matrix product: the squashing
 * function's derivative, evaluated at the discriminant's output for the given
 * input, times the corresponding derivative of the discriminant.
 */
public class DifferentiableGeneralizedLinearModel
        extends GeneralizedLinearModel
        implements GradientDescendable, DifferentiableVectorFunction {

    /**
     * Creates a model by delegating to
     * {@link #DifferentiableGeneralizedLinearModel(int, int, DifferentiableUnivariateScalarFunction)}
     * with the arguments {@code 1}, {@code 1} and a new {@link AtanFunction}.
     */
    public DifferentiableGeneralizedLinearModel() {
        this(1, 1, new AtanFunction());
    }

    /**
     * Creates a model around a new {@code MultivariateDiscriminant(i, i2)} and
     * an element-wise wrapper of the given scalar function.
     *
     * @param i first argument handed to the {@link MultivariateDiscriminant}
     *        constructor (presumably the input dimensionality; TODO confirm
     *        against MultivariateDiscriminant)
     * @param i2 second argument handed to the {@link MultivariateDiscriminant}
     *        constructor (presumably the output dimensionality; TODO confirm
     *        against MultivariateDiscriminant)
     * @param differentiableUnivariateScalarFunction scalar function that is
     *        wrapped in an {@link ElementWiseDifferentiableVectorFunction} and
     *        used as the squashing function
     */
    public DifferentiableGeneralizedLinearModel(
            int i,
            int i2,
            DifferentiableUnivariateScalarFunction differentiableUnivariateScalarFunction) {
        this(
                new MultivariateDiscriminant(i, i2),
                new ElementWiseDifferentiableVectorFunction(differentiableUnivariateScalarFunction));
    }

    /**
     * Creates a model from an existing discriminant and a differentiable
     * vector squashing function; both are passed straight to the superclass
     * constructor.
     *
     * @param multivariateDiscriminant discriminant to use
     * @param differentiableVectorFunction squashing function to use
     */
    public DifferentiableGeneralizedLinearModel(
            MultivariateDiscriminant multivariateDiscriminant,
            DifferentiableVectorFunction differentiableVectorFunction) {
        super(multivariateDiscriminant, differentiableVectorFunction);
    }

    /**
     * Creates a model from an existing discriminant and a scalar function,
     * which is wrapped in an {@link ElementWiseDifferentiableVectorFunction}
     * before being used as the squashing function.
     *
     * @param multivariateDiscriminant discriminant to use
     * @param differentiableUnivariateScalarFunction scalar function to wrap
     */
    public DifferentiableGeneralizedLinearModel(
            MultivariateDiscriminant multivariateDiscriminant,
            DifferentiableUnivariateScalarFunction differentiableUnivariateScalarFunction) {
        this(
                multivariateDiscriminant,
                new ElementWiseDifferentiableVectorFunction(differentiableUnivariateScalarFunction));
    }

    /**
     * Copy constructor; the copying itself is done by the superclass
     * constructor that receives the other model.
     *
     * @param differentiableGeneralizedLinearModel model to copy from
     */
    public DifferentiableGeneralizedLinearModel(
            DifferentiableGeneralizedLinearModel differentiableGeneralizedLinearModel) {
        super(differentiableGeneralizedLinearModel);
    }

    /**
     * Returns the superclass's squashing function narrowed to
     * {@link DifferentiableVectorFunction}.
     *
     * @return the squashing function, cast to its differentiable type
     */
    @Override
    public DifferentiableVectorFunction getSquashingFunction() {
        final DifferentiableVectorFunction squashing =
                (DifferentiableVectorFunction) super.getSquashingFunction();
        return squashing;
    }

    /**
     * Computes the derivative of the model output with respect to its
     * parameters for the given input.
     *
     * @param vector input at which to evaluate the gradient
     * @return the squashing function's derivative at the discriminant's output
     *         for {@code vector}, multiplied by the discriminant's parameter
     *         gradient at {@code vector}
     */
    @Override
    public Matrix computeParameterGradient(Vector vector) {
        // Outer factor: derivative of the squashing function at the
        // discriminant's output for this input.
        final Matrix squashingDerivative =
                getSquashingFunction().differentiate(getDiscriminant().evaluate(vector));

        // Inner factor: derivative of the discriminant w.r.t. its parameters.
        final Matrix discriminantParameterGradient =
                getDiscriminant().computeParameterGradient(vector);

        return squashingDerivative.times(discriminantParameterGradient);
    }

    /**
     * Returns a copy of this model, obtained from the superclass's
     * {@code mo0clone()} and narrowed to this class. The decompiler renamed
     * this method from {@code clone}.
     *
     * @return the superclass-produced copy, cast to this type
     */
    @Override
    public DifferentiableGeneralizedLinearModel mo0clone() {
        final DifferentiableGeneralizedLinearModel copy =
                (DifferentiableGeneralizedLinearModel) super.mo0clone();
        return copy;
    }

    /**
     * Computes the derivative of the model output with respect to the input.
     *
     * @param vector input at which to differentiate
     * @return the squashing function's derivative at the discriminant's output
     *         for {@code vector}, multiplied by the discriminant's derivative
     *         at {@code vector}
     */
    @Override
    public Matrix differentiate(Vector vector) {
        // Outer factor: derivative of the squashing function at the
        // discriminant's output for this input.
        final Matrix squashingDerivative =
                getSquashingFunction().differentiate(getDiscriminant().evaluate(vector));

        // Inner factor: derivative of the discriminant w.r.t. its input.
        final Matrix discriminantDerivative =
                getDiscriminant().differentiate(vector);

        return squashingDerivative.times(discriminantDerivative);
    }
}
