package gov.sandia.cognition.learning.algorithm.regression;

import gov.sandia.cognition.algorithm.MeasurablePerformanceAlgorithm;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.algorithm.AbstractAnytimeSupervisedBatchLearner;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.kernel.Kernel;
import gov.sandia.cognition.learning.function.scalar.KernelScalarFunction;
import gov.sandia.cognition.util.DefaultNamedValue;
import gov.sandia.cognition.util.DefaultWeightedValue;
import gov.sandia.cognition.util.NamedValue;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashMap;

/**
 * Iterative, kernel-based regression learner that produces a {@link KernelScalarFunction}.
 *
 * <p>Every non-null training pair carries a weight, initially zero (no entry in the supports
 * map). Each call to {@link #step()} sweeps over the training data. For each pair, the residual
 * (target output minus the current prediction) is computed. Pairs whose absolute residual is at
 * least {@link #getMinSensitivity() minSensitivity} get their weight adjusted. Any weight change
 * is also added to the bias of the learned function. Pairs with a non-zero weight are the
 * weighted examples ("supports") of the result.
 *
 * <p>{@code step()} returns {@code false} once a sweep changes no weight. Presumably this signals
 * the anytime base learner to stop, as does the iteration limit handed to the superclass
 * constructor.
 *
 * <p>NOTE(review): this class was recovered by decompilation. The member names
 * {@code mo1clone} and {@code m96getResult} are decompiler renames of {@code clone} and
 * {@code getResult}. They are kept as-is so that callers and overrides elsewhere still match.
 *
 * @param <InputType> type of the inputs of the learned function
 */
@PublicationReference(author = {"John Shawe-Taylor", "Nello Cristianini"}, title = "Kernel Methods for Pattern Analysis", type = PublicationType.Book, year = 2004, url = "http://www.kernel-methods.net/")
/* loaded from: input_file:gov/sandia/cognition/learning/algorithm/regression/KernelBasedIterativeRegression.class */
public class KernelBasedIterativeRegression<InputType> extends AbstractAnytimeSupervisedBatchLearner<InputType, Double, KernelScalarFunction<InputType>> implements MeasurablePerformanceAlgorithm {

    /** Iteration limit passed to the superclass by the constructors that do not take one. */
    public static final int DEFAULT_MAX_ITERATIONS = 100;

    /** Minimum sensitivity used by the constructors that do not take one. */
    public static final double DEFAULT_MIN_SENSITIVITY = 10.0d;

    /** Kernel given to the learned function and used to normalize each weight update. */
    private Kernel<? super InputType> kernel;

    /** Residuals whose absolute value is below this threshold leave an example's weight unchanged. */
    private double minSensitivity;

    /** The function being learned (or last learned). */
    private KernelScalarFunction<InputType> result;

    /**
     * Number of weights changed by the most recent {@link #step()}. Right after initialization it
     * instead holds the number of valid (non-null) training pairs.
     */
    private int errorCount;

    /**
     * Learning-time map from each training pair to its weighted support entry. Insertion order is
     * preserved. It is {@code null} before initialization and after cleanup.
     */
    private transient LinkedHashMap<InputOutputPair<? extends InputType, Double>, DefaultWeightedValue<InputType>> supportsMap;

    /**
     * Creates a learner with no kernel, the default minimum sensitivity and the default
     * iteration limit.
     */
    public KernelBasedIterativeRegression() {
        this(null);
    }

    /**
     * Creates a learner with the given kernel, the default minimum sensitivity and the default
     * iteration limit.
     *
     * @param kernel kernel for the learned function (may be {@code null}, to be set later)
     */
    public KernelBasedIterativeRegression(Kernel<? super InputType> kernel) {
        this(kernel, DEFAULT_MIN_SENSITIVITY);
    }

    /**
     * Creates a learner with the given kernel and minimum sensitivity and the default
     * iteration limit.
     *
     * @param kernel kernel for the learned function
     * @param minSensitivity non-negative residual threshold below which weights are not updated
     * @throws IllegalArgumentException if {@code minSensitivity} is negative or NaN
     */
    public KernelBasedIterativeRegression(Kernel<? super InputType> kernel, double minSensitivity) {
        this(kernel, minSensitivity, DEFAULT_MAX_ITERATIONS);
    }

    /**
     * Creates a learner with all parameters given explicitly.
     *
     * @param kernel kernel for the learned function
     * @param minSensitivity non-negative residual threshold below which weights are not updated
     * @param maxIterations iteration limit handed to the superclass constructor
     * @throws IllegalArgumentException if {@code minSensitivity} is negative or NaN
     */
    public KernelBasedIterativeRegression(Kernel<? super InputType> kernel, double minSensitivity, int maxIterations) {
        super(maxIterations);
        setKernel(kernel);
        setMinSensitivity(minSensitivity);
        setResult(null);
        setErrorCount(0);
        setSupportsMap(null);
    }

    /**
     * Creates a copy of this learner. The kernel, the result and the supports map are cloned
     * through {@link ObjectUtil}.
     *
     * <p>NOTE(review): if this is called mid-learning, the cloned result may not view the cloned
     * supports map. Confirm how KernelScalarFunction clones its examples.
     *
     * @return a copy of this learner
     */
    @Override // gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner
    @SuppressWarnings("unchecked")
    /* renamed from: clone */
    public KernelBasedIterativeRegression<InputType> mo1clone() {
        final KernelBasedIterativeRegression<InputType> copy =
            (KernelBasedIterativeRegression<InputType>) super.mo1clone();
        copy.setKernel((Kernel<? super InputType>) ObjectUtil.cloneSmart(getKernel()));
        copy.setResult((KernelScalarFunction<InputType>) ObjectUtil.cloneSafe(m96getResult()));
        copy.setSupportsMap(
            (LinkedHashMap<InputOutputPair<? extends InputType, Double>, DefaultWeightedValue<InputType>>)
                ObjectUtil.cloneSmart(getSupportsMap()));
        return copy;
    }

    /**
     * Prepares a learning run.
     *
     * @return {@code false} if there is no data or no non-null training pair; otherwise
     *     {@code true}, after creating an empty supports map and a zero-bias result
     */
    @Override // gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner
    protected boolean initializeAlgorithm() {
        if (getData() == null) {
            return false;
        }

        // Only non-null pairs take part in learning.
        int validCount = 0;
        for (InputOutputPair<? extends InputType, Double> example : getData()) {
            if (example != null) {
                validCount++;
            }
        }
        if (validCount <= 0) {
            return false;
        }

        setErrorCount(validCount);
        setSupportsMap(new LinkedHashMap<>());
        // The result is handed the live values() view of the supports map. Presumably it keeps
        // that reference, so the weight changes made in step() are seen when it is evaluated.
        // cleanupAlgorithm() later swaps in a standalone copy.
        setResult(new KernelScalarFunction<>(getKernel(), getSupportsMap().values(), 0.0d));
        return true;
    }

    /**
     * Performs one sweep over the training data.
     *
     * <p>With exactly one training element, the result is reduced to a constant. All examples
     * are cleared and the bias is set to that element's output. The method then returns
     * {@code false}.
     *
     * @return {@code true} if at least one weight changed during this sweep
     */
    @Override // gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner
    protected boolean step() {
        setErrorCount(0);

        if (getData().size() == 1) {
            // initializeAlgorithm() guarantees at least one non-null pair, so this one is non-null.
            final InputOutputPair<? extends InputType, Double> only = getData().iterator().next();
            m96getResult().getExamples().clear();
            m96getResult().setBias(only.getOutput());
            return false;
        }

        for (InputOutputPair<? extends InputType, Double> example : getData()) {
            if (example != null) {
                updateWeight(example);
            }
        }
        return getErrorCount() > 0;
    }

    /**
     * Applies the per-example update to one non-null training pair.
     *
     * <p>If the change is non-zero, this method also increments the error count, adjusts the
     * result's bias by the change, and adds, removes or re-weights the pair's support entry.
     */
    private void updateWeight(final InputOutputPair<? extends InputType, Double> example) {
        final InputType input = example.getInput();
        final double error = example.getOutput() - this.result.evaluate(input);

        final DefaultWeightedValue<InputType> support = this.supportsMap.get(example);
        final double oldWeight = (support == null) ? 0.0d : support.getWeight();

        double newWeight = oldWeight;
        if (Math.abs(error) >= this.minSensitivity) {
            // Shift the residual toward zero by minSensitivity. The direction follows the sign of
            // the current weight, or the sign of the residual for a pair that has no weight yet.
            final boolean positive = (oldWeight == 0.0d) ? (error >= 0.0d) : (oldWeight > 0.0d);
            double delta = positive ? error - this.minSensitivity : error + this.minSensitivity;

            // Normalize by the kernel of the input with itself (skipped when that is zero).
            final double selfSimilarity = this.kernel.evaluate(input, input);
            if (selfSimilarity != 0.0d) {
                delta /= selfSimilarity;
            }
            newWeight = oldWeight + delta;

            // A weight may not change sign in a single update; clamp it to zero instead.
            if (oldWeight * newWeight < 0.0d) {
                newWeight = 0.0d;
            }
        }

        final double change = newWeight - oldWeight;
        if (change != 0.0d) {
            setErrorCount(getErrorCount() + 1);
            final double newBias = this.result.getBias() + change;
            if (support == null) {
                this.supportsMap.put(example, new DefaultWeightedValue<>(input, newWeight));
            } else if (newWeight == 0.0d) {
                // A zero weight means the pair is no longer a support.
                this.supportsMap.remove(example);
            } else {
                support.setWeight(newWeight);
            }
            this.result.setBias(newBias);
        }
    }

    /**
     * Finishes a learning run. The result gets its own list copy of the current supports, and
     * the learning-time supports map is released.
     */
    @Override // gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner
    protected void cleanupAlgorithm() {
        if (getSupportsMap() != null) {
            m96getResult().setExamples(new ArrayList<>(getSupportsMap().values()));
            setSupportsMap(null);
        }
    }

    /**
     * Gets the kernel.
     *
     * @return the kernel (may be {@code null})
     */
    public Kernel<? super InputType> getKernel() {
        return this.kernel;
    }

    /**
     * Sets the kernel.
     *
     * @param kernel the kernel for the learned function
     */
    public void setKernel(Kernel<? super InputType> kernel) {
        this.kernel = kernel;
    }

    /**
     * Gets the function being learned, or the one learned by the last run.
     *
     * @return the result, or {@code null} if no learning has been started
     */
    /* renamed from: getResult, reason: merged with bridge method [inline-methods] */
    public KernelScalarFunction<InputType> m96getResult() {
        return this.result;
    }

    /**
     * Sets the result function.
     *
     * @param result the result function
     */
    protected void setResult(KernelScalarFunction<InputType> result) {
        this.result = result;
    }

    /**
     * Gets the number of weights changed by the most recent step.
     *
     * @return the error count
     */
    public int getErrorCount() {
        return this.errorCount;
    }

    /**
     * Sets the error count.
     *
     * @param errorCount the error count
     */
    protected void setErrorCount(int errorCount) {
        this.errorCount = errorCount;
    }

    /**
     * Gets the learning-time supports map.
     *
     * @return the supports map, or {@code null} outside of a learning run
     */
    protected LinkedHashMap<InputOutputPair<? extends InputType, Double>, DefaultWeightedValue<InputType>> getSupportsMap() {
        return this.supportsMap;
    }

    /**
     * Sets the learning-time supports map.
     *
     * @param supportsMap the supports map
     */
    protected void setSupportsMap(LinkedHashMap<InputOutputPair<? extends InputType, Double>, DefaultWeightedValue<InputType>> supportsMap) {
        this.supportsMap = supportsMap;
    }

    /**
     * Gets the minimum sensitivity.
     *
     * @return the residual threshold below which weights are not updated
     */
    public double getMinSensitivity() {
        return this.minSensitivity;
    }

    /**
     * Sets the minimum sensitivity.
     *
     * @param minSensitivity non-negative residual threshold below which weights are not updated
     * @throws IllegalArgumentException if {@code minSensitivity} is negative or NaN
     */
    public void setMinSensitivity(double minSensitivity) {
        // NaN must be rejected explicitly: "NaN < 0" is false, and a NaN threshold would make
        // every "|error| >= minSensitivity" test false, silently disabling learning.
        if (Double.isNaN(minSensitivity) || minSensitivity < 0.0d) {
            throw new IllegalArgumentException("minSensitivity must be non-negative.");
        }
        this.minSensitivity = minSensitivity;
    }

    /**
     * Reports the current error count as the performance measure.
     *
     * @return a named value {@code "error count"} holding {@link #getErrorCount()}
     */
    @Override
    public NamedValue<Integer> getPerformance() {
        return new DefaultNamedValue<>("error count", getErrorCount());
    }
}
