package gov.sandia.cognition.learning.function.cost;

import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.gradient.GradientDescendable;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.data.TargetEstimatePair;
import gov.sandia.cognition.math.RingAccumulator;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.DefaultPair;
import java.util.Collection;
import java.util.Iterator;

/* loaded from: input_file:gov/sandia/cognition/learning/function/cost/SumSquaredErrorCostFunction.class */
/**
 * Cost function based on the weighted sum of squared errors between the
 * outputs of a function and the target outputs of a labeled dataset.
 * <p>
 * The reported cost is
 * {@code sum_i w_i * ||f(x_i) - y_i||^2 / (2 * sum_i w_i)}, where {@code w_i}
 * is the weight returned by {@link DatasetUtil#getWeight} for each pair. When
 * the total weight is zero the cost is reported as 0.
 * <p>
 * Evaluation and gradient computation are split into "partial" steps, which
 * work over {@link #getCostParameters()}, and "amalgamate" steps, which
 * combine a collection of partial results.
 */
public class SumSquaredErrorCostFunction extends AbstractParallelizableCostFunction {

    /**
     * Precomputed quantities derived from a function and a dataset: a scaled
     * sum of parameter gradients {@code J}, the product {@code J^T J}, a scaled
     * sum of (weighted error) x (parameter gradient) products {@code Jte}, and
     * the scaled cost.
     */
    public static class Cache extends AbstractCloneableSerializable {

        /** Sum of the per-sample parameter gradients, divided by the cost denominator. */
        public final Matrix J;

        /** {@code J.transpose().times(J)}. */
        public final Matrix JtJ;

        /** Sum of the weighted error vectors times the parameter gradients, divided by the cost denominator. */
        public final Vector Jte;

        /** Weighted sum of squared errors divided by the cost denominator. */
        public final double parameterCost;

        /**
         * Creates a new cache from already-computed values.
         *
         * @param J the scaled gradient sum
         * @param JtJ the product {@code J^T J}
         * @param Jte the scaled error-gradient sum
         * @param parameterCost the scaled cost
         */
        protected Cache(Matrix J, Matrix JtJ, Vector Jte, double parameterCost) {
            this.J = J;
            this.JtJ = JtJ;
            this.Jte = Jte;
            this.parameterCost = parameterCost;
        }

        /**
         * Computes a cache for the given function over the given dataset.
         * <p>
         * The "cost denominator" is {@code 2 * sum of weights}, or 1 when the
         * weights sum to zero (to avoid dividing by zero). Note that {@code J}
         * accumulates the raw (unweighted) parameter gradients, while
         * {@code Jte} uses the weighted error vectors.
         *
         * @param gradientDescendable the function to evaluate and differentiate
         * @param collection the labeled (optionally weighted) dataset
         * @return the computed cache
         * @throws IllegalArgumentException if {@code collection} is empty,
         *         since no gradient dimensions can be derived from no data
         */
        public static Cache compute(GradientDescendable gradientDescendable, Collection<? extends InputOutputPair<? extends Vector, Vector>> collection) {
            if (collection.isEmpty()) {
                throw new IllegalArgumentException("Cannot compute a Cache from an empty dataset");
            }

            RingAccumulator<Matrix> gradientSum = new RingAccumulator<Matrix>();
            RingAccumulator<Vector> errorGradientSum = new RingAccumulator<Vector>();
            double weightSum = 0.0;
            double weightedSquaredErrorSum = 0.0;
            for (InputOutputPair<? extends Vector, Vector> pair : collection) {
                Vector input = pair.getInput();

                // error = f(x) - y. The evaluated vector is modified in place;
                // NOTE(review): this assumes evaluate() returns a vector the
                // caller may mutate — confirm for all GradientDescendables.
                Vector error = gradientDescendable.evaluate(input);
                error.minusEquals(pair.getOutput());

                // Take the squared norm before weighting so the cost term is
                // w * ||e||^2 rather than w^2 * ||e||^2.
                double squaredError = error.norm2Squared();
                double weight = DatasetUtil.getWeight(pair);
                if (weight != 1.0) {
                    error.scaleEquals(weight);
                }
                weightSum += weight;
                weightedSquaredErrorSum += weight * squaredError;

                Matrix gradient = gradientDescendable.computeParameterGradient(input);
                gradientSum.accumulate(gradient);
                errorGradientSum.accumulate(error.times(gradient));
            }

            double denominator = 2.0 * weightSum;
            if (denominator == 0.0) {
                // All weights are zero: leave the sums unscaled rather than divide by zero.
                denominator = 1.0;
            }

            Matrix J = gradientSum.getSum();
            J.scaleEquals(1.0 / denominator);
            Matrix JtJ = J.transpose().times(J);
            Vector Jte = errorGradientSum.getSum();
            Jte.scaleEquals(1.0 / denominator);
            return new Cache(J, JtJ, Jte, weightedSquaredErrorSum / denominator);
        }
    }

    /**
     * Partial result of {@link #evaluatePartial}: the weighted sum of squared
     * errors (first) and twice the sum of weights (second).
     */
    private static class EvaluatePartialSSE extends DefaultPair<Double, Double> {
        public EvaluatePartialSSE(Double weightedSquaredErrorSum, Double denominator) {
            super(weightedSquaredErrorSum, denominator);
        }
    }

    /**
     * Partial result of {@link #computeParameterGradientPartial}: the
     * unscaled sum of (weighted error) x (parameter gradient) products (first)
     * and the sum of weights (second).
     */
    public static class GradientPartialSSE extends DefaultPair<Vector, Double> {
        public GradientPartialSSE(Vector gradientSum, Double weightSum) {
            super(gradientSum, weightSum);
        }
    }

    /**
     * Creates a new cost function with no dataset set.
     */
    public SumSquaredErrorCostFunction() {
        this(null);
    }

    /**
     * Creates a new cost function over the given dataset.
     *
     * @param collection the labeled (optionally weighted) dataset; may be null
     */
    public SumSquaredErrorCostFunction(Collection<? extends InputOutputPair<? extends Vector, Vector>> collection) {
        super(collection);
    }

    /**
     * Clones this cost function via the superclass clone.
     *
     * @return the clone
     */
    @Override
    /* renamed from: clone */
    public SumSquaredErrorCostFunction mo0clone() {
        return (SumSquaredErrorCostFunction) super.mo0clone();
    }

    /**
     * Computes the weighted squared-error sum of the given evaluator over
     * this cost function's dataset, without normalizing.
     *
     * @param evaluator the function to evaluate
     * @return an {@code EvaluatePartialSSE} holding the weighted squared-error
     *         sum and {@code 2 * sum of weights}
     */
    @Override
    public Object evaluatePartial(Evaluator<? super Vector, ? extends Vector> evaluator) {
        double weightedSquaredErrorSum = 0.0;
        double weightSum = 0.0;
        for (InputOutputPair<? extends Vector, Vector> pair : getCostParameters()) {
            double squaredError = pair.getOutput().euclideanDistanceSquared(evaluator.evaluate(pair.getInput()));
            double weight = DatasetUtil.getWeight(pair);
            weightSum += weight;
            weightedSquaredErrorSum += weight * squaredError;
        }
        return new EvaluatePartialSSE(weightedSquaredErrorSum, weightSum * 2.0);
    }

    /**
     * Combines partial results from {@link #evaluatePartial} into the final cost.
     *
     * @param collection the partial results ({@code EvaluatePartialSSE} instances)
     * @return the summed squared error divided by the summed denominators, or 0
     *         when the combined denominator is 0
     */
    @Override
    public Double evaluateAmalgamate(Collection<Object> collection) {
        double weightedSquaredErrorSum = 0.0;
        double denominator = 0.0;
        for (Object partial : collection) {
            EvaluatePartialSSE result = (EvaluatePartialSSE) partial;
            weightedSquaredErrorSum += result.getFirst().doubleValue();
            denominator += result.getSecond().doubleValue();
        }
        return (denominator == 0.0) ? Double.valueOf(0.0) : Double.valueOf(weightedSquaredErrorSum / denominator);
    }

    /**
     * Computes the unscaled gradient sum over this cost function's dataset:
     * the sum over pairs of {@code (w * (f(x) - y)) times dF/dParameters}.
     *
     * @param gradientDescendable the function to evaluate and differentiate
     * @return a {@link GradientPartialSSE} holding the gradient sum and the sum
     *         of weights
     */
    @Override
    public Object computeParameterGradientPartial(GradientDescendable gradientDescendable) {
        RingAccumulator<Vector> errorGradientSum = new RingAccumulator<Vector>();
        double weightSum = 0.0;
        for (InputOutputPair<? extends Vector, Vector> pair : getCostParameters()) {
            Vector input = pair.getInput();
            Vector error = gradientDescendable.evaluate(input);
            error.minusEquals(pair.getOutput());
            double weight = DatasetUtil.getWeight(pair);
            if (weight != 1.0) {
                error.scaleEquals(weight);
            }
            weightSum += weight;
            errorGradientSum.accumulate(error.times(gradientDescendable.computeParameterGradient(input)));
        }
        return new GradientPartialSSE(errorGradientSum.getSum(), weightSum);
    }

    /**
     * Combines partial results from {@link #computeParameterGradientPartial}
     * into the final gradient, scaled by {@code 1 / (2 * total weight)}.
     * <p>
     * NOTE(review): assuming {@code computeParameterGradient} returns the
     * Jacobian of the output with respect to the parameters, the exact
     * derivative of the cost reported by {@link #evaluateAmalgamate} would be
     * scaled by {@code 1 / total weight}. The existing scaling is kept for
     * compatibility; confirm the intended convention.
     *
     * @param collection the partial results ({@code GradientPartialSSE} instances)
     * @return the combined gradient; unscaled when the total weight is 0, and
     *         null when no partial contributed a gradient
     */
    @Override
    public Vector computeParameterGradientAmalgamate(Collection<Object> collection) {
        RingAccumulator<Vector> gradientSum = new RingAccumulator<Vector>();
        double weightSum = 0.0;
        for (Object partial : collection) {
            GradientPartialSSE result = (GradientPartialSSE) partial;
            // A partial over an empty subset may carry a null sum (assuming
            // RingAccumulator.getSum() is null when nothing was accumulated);
            // skip it rather than accumulate null.
            if (result.getFirst() != null) {
                gradientSum.accumulate(result.getFirst());
            }
            weightSum += result.getSecond().doubleValue();
        }
        Vector gradient = gradientSum.getSum();
        if (gradient != null && weightSum != 0.0) {
            gradient.scaleEquals(1.0 / (2.0 * weightSum));
        }
        return gradient;
    }

    /**
     * Computes the cost from already-evaluated target/estimate pairs.
     *
     * @param collection the target/estimate pairs (optionally weighted)
     * @return {@code sum w * ||target - estimate||^2 / (2 * sum w)}, or 0 when
     *         the weights sum to 0
     */
    @Override
    public Double evaluatePerformance(Collection<? extends TargetEstimatePair<? extends Vector, ? extends Vector>> collection) {
        double weightedSquaredErrorSum = 0.0;
        double weightSum = 0.0;
        for (TargetEstimatePair<? extends Vector, ? extends Vector> pair : collection) {
            double squaredError = pair.getTarget().euclideanDistanceSquared(pair.getEstimate());
            double weight = DatasetUtil.getWeight(pair);
            weightSum += weight;
            weightedSquaredErrorSum += weight * squaredError;
        }
        double denominator = weightSum * 2.0;
        return (denominator == 0.0) ? Double.valueOf(0.0) : Double.valueOf(weightedSquaredErrorSum / denominator);
    }

    // No explicit evaluatePerformance(Collection) bridge: javac generates it
    // for the generic override above, and declaring it by hand is a duplicate
    // (override-equivalent) method that does not compile.
}
