package gov.sandia.cognition.learning.algorithm.regression;

import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.AbstractAnytimeSupervisedBatchLearner;
import gov.sandia.cognition.learning.algorithm.SupervisedBatchLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.DefaultWeightedInputOutputPair;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.kernel.Kernel;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;

/* loaded from: input_file:gov/sandia/cognition/learning/algorithm/regression/KernelWeightedRobustRegression.class */
/**
 * Robust regression learner that repeatedly trains a wrapped supervised
 * learner on re-weighted data.
 * <p>
 * Each step trains {@code iterationLearner} on the current weighted copy of
 * the data. It then replaces every example's weight with
 * {@code kernelWeightingFunction.evaluate(actualOutput, predictedOutput)}.
 * The step reports further progress only while the mean absolute change in
 * weights exceeds {@code tolerance}. The iteration cap passed to the
 * constructor is handed to the superclass, which presumably enforces it.
 *
 * @param <InputType> type of the inputs to the learned function
 * @param <OutputType> type of the outputs of the learned function
 */
public class KernelWeightedRobustRegression<InputType, OutputType> extends AbstractAnytimeSupervisedBatchLearner<InputType, OutputType, Evaluator<? super InputType, ? extends OutputType>> {

    /** Default maximum number of iterations: {@value}. */
    public static final int DEFAULT_MAX_ITERATIONS = 100;

    /** Default tolerance on the mean absolute weight change: {@value}. */
    public static final double DEFAULT_TOLERANCE = 1.0E-5d;

    /** Function learned by the most recent step; {@code null} before any step runs. */
    private Evaluator<? super InputType, ? extends OutputType> result;

    /** Learner trained on the weighted data at every step. */
    private SupervisedBatchLearner<InputType, OutputType, ?> iterationLearner;

    /** Kernel between actual and predicted output that yields each example's new weight. */
    private Kernel<? super OutputType> kernelWeightingFunction;

    /** Strictly positive threshold on the mean absolute weight change. */
    private double tolerance;

    /** Weighted working copy of the training data; only non-null while learning. */
    private ArrayList<DefaultWeightedInputOutputPair<InputType, OutputType>> weightedData;

    /**
     * Creates a learner using {@link #DEFAULT_MAX_ITERATIONS} and
     * {@link #DEFAULT_TOLERANCE}.
     *
     * @param supervisedBatchLearner learner trained on the weighted data each step
     * @param kernel kernel used to compute example weights
     */
    public KernelWeightedRobustRegression(SupervisedBatchLearner<InputType, OutputType, ?> supervisedBatchLearner, Kernel<? super OutputType> kernel) {
        this(supervisedBatchLearner, kernel, DEFAULT_MAX_ITERATIONS, DEFAULT_TOLERANCE);
    }

    /**
     * Creates a fully specified learner.
     *
     * @param supervisedBatchLearner learner trained on the weighted data each step
     * @param kernel kernel used to compute example weights
     * @param maxIterations iteration cap forwarded to the superclass
     * @param tolerance strictly positive threshold on the mean absolute weight change
     * @throws IllegalArgumentException if {@code tolerance} is not strictly positive
     */
    public KernelWeightedRobustRegression(SupervisedBatchLearner<InputType, OutputType, ?> supervisedBatchLearner, Kernel<? super OutputType> kernel, int maxIterations, double tolerance) {
        super(maxIterations);
        setLearned(null);
        setTolerance(tolerance);
        setKernelWeightingFunction(kernel);
        setIterationLearner(supervisedBatchLearner);
    }

    /**
     * Copies the training data into a mutable weighted list. Each copy starts
     * with the weight reported by {@link DatasetUtil#getWeight}.
     *
     * @return {@code false} when there is no data to fit, otherwise {@code true}
     */
    @Override // gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner
    protected boolean initializeAlgorithm() {
        // Drop any result left over from a previous learn() call.
        this.result = null;
        if (this.data == null || this.data.isEmpty()) {
            // Nothing to fit. Returning false presumably tells the superclass
            // to skip the step loop (assumed from AbstractAnytimeBatchLearner).
            this.weightedData = null;
            return false;
        }

        this.weightedData = new ArrayList<>(this.data.size());
        for (InputOutputPair<? extends InputType, ? extends OutputType> pair : this.data) {
            this.weightedData.add(new DefaultWeightedInputOutputPair<InputType, OutputType>(
                pair.getInput(), pair.getOutput(), DatasetUtil.getWeight(pair)));
        }
        return true;
    }

    /**
     * Trains the iteration learner on the weighted data, then re-weights.
     *
     * @return {@code true} while the mean absolute weight change exceeds the tolerance
     */
    @Override // gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner
    protected boolean step() {
        this.result = this.iterationLearner.learn(this.weightedData);
        final double meanWeightChange = this.updateWeights(this.result);
        return meanWeightChange > this.tolerance;
    }

    /**
     * Releases the weighted working copy of the data. The copy is not
     * reachable through any accessor once learning has finished.
     */
    @Override // gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner
    protected void cleanupAlgorithm() {
        this.weightedData = null;
    }

    /**
     * Re-weights every example using the kernel between its actual output and
     * the output {@code evaluator} predicts for its input.
     *
     * @param evaluator function learned in the current step
     * @return mean absolute change in weight, or 0 when there are no examples
     */
    private double updateWeights(Evaluator<? super InputType, ? extends OutputType> evaluator) {
        if (this.weightedData.isEmpty()) {
            // Avoid 0 / 0 = NaN.
            return 0.0;
        }

        double totalChange = 0.0;
        for (DefaultWeightedInputOutputPair<InputType, OutputType> example : this.weightedData) {
            final double newWeight = this.kernelWeightingFunction.evaluate(
                example.getOutput(), evaluator.evaluate(example.getInput()));
            totalChange += Math.abs(newWeight - example.getWeight());
            example.setWeight(newWeight);
        }
        return totalChange / this.weightedData.size();
    }

    /**
     * @return kernel used to compute example weights
     */
    public Kernel<? super OutputType> getKernelWeightingFunction() {
        return this.kernelWeightingFunction;
    }

    /**
     * @param kernel kernel used to compute example weights
     */
    public void setKernelWeightingFunction(Kernel<? super OutputType> kernel) {
        this.kernelWeightingFunction = kernel;
    }

    /**
     * @return threshold on the mean absolute weight change
     */
    public double getTolerance() {
        return this.tolerance;
    }

    /**
     * Sets the threshold on the mean absolute weight change.
     *
     * @param tolerance strictly positive, non-NaN threshold
     * @throws IllegalArgumentException if {@code tolerance} is not strictly positive (including NaN)
     */
    public void setTolerance(double tolerance) {
        // Written as !(x > 0) so that NaN is rejected too. NaN would make every
        // "change > tolerance" comparison false and silently stop after one step.
        if (!(tolerance > 0.0d)) {
            throw new IllegalArgumentException("Tolerance must be > 0.0");
        }
        this.tolerance = tolerance;
    }

    /**
     * Overwrites the currently held learned function.
     *
     * @param evaluator new result; may be {@code null}
     */
    public void setLearned(Evaluator<? super InputType, ? extends OutputType> evaluator) {
        this.result = evaluator;
    }

    /**
     * @return function learned by the most recent step, or {@code null} if none
     */
    @Override
    public Evaluator<? super InputType, ? extends OutputType> getResult() {
        return this.result;
    }

    /**
     * Decompiler-renamed alias of {@link #getResult()}, kept for existing callers.
     *
     * @return same value as {@link #getResult()}
     */
    public Evaluator<? super InputType, ? extends OutputType> m91getResult() {
        return this.getResult();
    }

    /**
     * @return learner trained on the weighted data at every step
     */
    public SupervisedBatchLearner<InputType, OutputType, ?> getIterationLearner() {
        return this.iterationLearner;
    }

    /**
     * @param supervisedBatchLearner learner trained on the weighted data at every step
     */
    public void setIterationLearner(SupervisedBatchLearner<InputType, OutputType, ?> supervisedBatchLearner) {
        this.iterationLearner = supervisedBatchLearner;
    }
}
