package gov.sandia.cognition.learning.algorithm.svm;

import gov.sandia.cognition.algorithm.MeasurablePerformanceAlgorithm;
import gov.sandia.cognition.annotation.CodeReview;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.algorithm.AbstractAnytimeSupervisedBatchLearner;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.KernelBinaryCategorizer;
import gov.sandia.cognition.learning.function.kernel.Kernel;
import gov.sandia.cognition.util.DefaultNamedValue;
import gov.sandia.cognition.util.DefaultWeightedValue;
import gov.sandia.cognition.util.NamedValue;
import gov.sandia.cognition.util.WeightedValue;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;

/**
 * Learns a {@link KernelBinaryCategorizer} from labeled examples using the
 * successive overrelaxation (SOR) method.
 * <p>
 * Each usable training example (non-null with a non-null label) is tracked as
 * an {@link Entry} whose inherited weight is the signed value
 * {@code label * alpha}, where {@code label} is +1 or -1 and {@code alpha} is
 * clipped to {@code [0, maxWeight]}.
 * <p>
 * Each {@link #step()} does two things:
 * <ul>
 * <li>It sweeps once over all entries.</li>
 * <li>It then repeatedly sweeps over the entries that currently have a
 * non-zero weight (the support vectors).</li>
 * </ul>
 * Every weight update adds the weight change to the categorizer's bias.
 * Because the bias starts at zero, it stays equal to the sum of all signed
 * weights.
 * <p>
 * A step asks to continue only while the Euclidean norm of that step's weight
 * changes exceeds {@code minChange}. The maximum iteration count is handed to
 * the base class.
 *
 * @param <InputType> The type of input the kernel operates on.
 */
@CodeReview(reviewer = {"Kevin R. Dixon"}, date = "2008-07-23", changesNeeded = false, comments = {"Minor cosmetic to javadoc.", "Great looking code."})
@PublicationReference(author = {"Olvi L. Mangasarian", "David R. Musicant"}, title = "Successive Overrelaxation for Support Vector Machines", type = PublicationType.Journal, year = 1999, publication = "IEEE Transactions on Neural Networks", pages = {1032, 1037}, url = "ftp://ftp.cs.wisc.edu/math-prog/tech-reports/98-18.ps")
public class SuccessiveOverrelaxation<InputType> extends AbstractAnytimeSupervisedBatchLearner<InputType, Boolean, KernelBinaryCategorizer<InputType, DefaultWeightedValue<InputType>>> implements MeasurablePerformanceAlgorithm {

    /** The default maximum number of iterations, {@value}. */
    public static final int DEFAULT_MAX_ITERATIONS = 1000;

    /** The default upper bound on an example's unlabeled weight, {@value}. */
    public static final double DEFAULT_MAX_WEIGHT = 100.0d;

    /** The default overrelaxation factor, {@value}. */
    public static final double DEFAULT_OVERRELAXATION = 1.3d;

    /** The default minimum per-step change required to keep learning, {@value}. */
    public static final double DEFAULT_MIN_CHANGE = 1.0E-4d;

    /** The kernel used to compare inputs; must be set before learning. */
    protected Kernel<? super InputType> kernel;

    /** Upper clip for the unlabeled weight (alpha) of any example; positive. */
    protected double maxWeight;

    /** The overrelaxation factor; strictly between 0 and 2. */
    protected double overrelaxation;

    /** Learning continues while a step's total change exceeds this; non-negative. */
    protected double minChange;

    /** The categorizer being learned. */
    protected KernelBinaryCategorizer<InputType, DefaultWeightedValue<InputType>> result;

    /** Euclidean norm of the weight changes made during the most recent step. */
    protected double totalChange;

    /** One entry per non-null, labeled training example. */
    protected ArrayList<SuccessiveOverrelaxation<InputType>.Entry> entries;

    /**
     * Entries that currently have a non-zero weight, keyed by their training
     * example.
     * <p>
     * While learning, the result reads its examples through an unmodifiable
     * view of this map's values.
     */
    protected LinkedHashMap<InputOutputPair<? extends InputType, ? extends Boolean>, SuccessiveOverrelaxation<InputType>.Entry> supportsMap;

    /**
     * The learner-side state for a single labeled training example.
     * <p>
     * The inherited value is the example's input. The inherited weight is the
     * signed weight {@code label * alpha}.
     */
    public class Entry extends DefaultWeightedValue<InputType> implements Comparable<SuccessiveOverrelaxation<InputType>.Entry> {

        /** The training example this entry was created from. */
        protected InputOutputPair<? extends InputType, ? extends Boolean> example;

        /** The example's label. */
        protected boolean output;

        /** The label as a number: +1.0 for true, -1.0 for false. */
        protected double outputDouble;

        /** Whether this entry is currently stored in the supports map. */
        protected boolean supportInserted;

        /** Cached kernel value of the example's input with itself. */
        protected double selfKernel;

        /** The weight recorded at the start of the current step. */
        protected double previousStepWeight;

        /**
         * Creates an entry with zero weight for the given example.
         *
         * @param example A training example with a non-null output.
         *     The enclosing learner's kernel must already be set, because it
         *     is evaluated here to cache {@link #selfKernel}.
         */
        protected Entry(final InputOutputPair<? extends InputType, ? extends Boolean> example) {
            super(example.getInput(), 0.0d);
            final InputType input = example.getInput();
            this.example = example;
            this.output = example.getOutput().booleanValue();
            this.outputDouble = this.output ? 1.0d : -1.0d;
            this.supportInserted = false;
            this.selfKernel = SuccessiveOverrelaxation.this.kernel.evaluate(input, input);
            this.previousStepWeight = 0.0d;
        }

        /**
         * Gets the example's input.
         *
         * @return The input of the training example.
         */
        public InputType getInput() {
            return this.getValue();
        }

        /**
         * Gets the example's label.
         *
         * @return The label of the training example.
         */
        public boolean getOutput() {
            return this.output;
        }

        /**
         * Sets the weight from an unlabeled (alpha) value.
         *
         * @param unlabeledWeight The alpha value; the label's sign is applied.
         */
        public void setUnlabeledWeight(final double unlabeledWeight) {
            this.weight = this.output ? unlabeledWeight : -unlabeledWeight;
        }

        /**
         * Gets the weight with the label's sign removed.
         *
         * @return The unlabeled (alpha) value of this entry.
         */
        public double getUnlabeledWeight() {
            return this.output ? this.weight : -this.weight;
        }

        /**
         * Orders entries by unlabeled weight, ascending.
         * <p>
         * This ordering is used only for sorting and is not consistent with
         * {@code equals}.
         *
         * @param other The entry to compare against.
         * @return The result of comparing the two unlabeled weights.
         */
        @Override
        public int compareTo(final SuccessiveOverrelaxation<InputType>.Entry other) {
            return Double.compare(this.getUnlabeledWeight(), other.getUnlabeledWeight());
        }
    }

    /**
     * Creates a learner with no kernel and default parameters.
     * A kernel must be set before learning.
     */
    public SuccessiveOverrelaxation() {
        this(null);
    }

    /**
     * Creates a learner with the given kernel and default parameters.
     *
     * @param kernel The kernel to use.
     */
    public SuccessiveOverrelaxation(final Kernel<? super InputType> kernel) {
        this(kernel, DEFAULT_MAX_WEIGHT, DEFAULT_OVERRELAXATION, DEFAULT_MIN_CHANGE, DEFAULT_MAX_ITERATIONS);
    }

    /**
     * Creates a fully-specified learner.
     *
     * @param kernel The kernel to use.
     * @param maxWeight The upper clip on each example's unlabeled weight; must be positive.
     * @param overrelaxation The overrelaxation factor; must be in (0, 2).
     * @param minChange The minimum per-step change needed to keep learning; must be non-negative.
     * @param maxIterations The maximum number of iterations, passed to the base class.
     * @throws IllegalArgumentException If a numeric parameter is out of range or NaN.
     */
    public SuccessiveOverrelaxation(final Kernel<? super InputType> kernel, final double maxWeight, final double overrelaxation, final double minChange, final int maxIterations) {
        super(maxIterations);
        setKernel(kernel);
        setMaxWeight(maxWeight);
        setOverrelaxation(overrelaxation);
        setMinChange(minChange);
        setEntries(null);
        setResult(null);
        setTotalChange(0.0d);
        setSupportsMap(null);
    }

    /**
     * Builds one entry per labeled example and a fresh, empty categorizer.
     * <p>
     * The new categorizer has zero bias and is backed by a view of the
     * supports map.
     *
     * @return False if there is no data or every example is null; true otherwise.
     */
    @Override
    protected boolean initializeAlgorithm() {
        if (this.getData() == null) {
            return false;
        }

        // Count the non-null examples. The count checks that there is
        // something to learn from and also sizes the entry list.
        int exampleCount = 0;
        for (InputOutputPair<?, ?> example : this.getData()) {
            if (example != null) {
                exampleCount++;
            }
        }
        if (exampleCount <= 0) {
            return false;
        }

        setTotalChange(1.0d);
        setEntries(new ArrayList<>(exampleCount));
        for (InputOutputPair<? extends InputType, ? extends Boolean> example : this.getData()) {
            // Unlabeled examples cannot be trained on, so skip them.
            if (example != null && example.getOutput() != null) {
                this.entries.add(new Entry(example));
            }
        }

        setSupportsMap(new LinkedHashMap<>());
        // The result reads a live, read-only view of the supports while learning.
        setResult(new KernelBinaryCategorizer<>(getKernel(), Collections.unmodifiableCollection(getSupportsMap().values()), 0.0d));
        return true;
    }

    /**
     * Performs one SOR iteration.
     * <p>
     * It runs one sweep over every entry, then several sweeps over the support
     * vectors. Finally it records the norm of the step's weight changes.
     *
     * @return True if the total change exceeded {@code minChange}, so learning should continue.
     */
    @Override
    protected boolean step() {
        setTotalChange(0.0d);

        // Sweep all entries, largest unlabeled weight first. Each weight is
        // recorded before its update so the step's change can be measured.
        Collections.sort(this.entries, Collections.reverseOrder());
        for (Entry entry : this.entries) {
            entry.previousStepWeight = entry.getWeight();
            this.update(entry);
        }

        final int dataSize = this.entries.size();
        final int supportsSize = this.supportsMap.size();
        // NOTE(review): this evaluates to 0.5 * (n + 1) + 1 / (s + 1). The
        // (s + 1) / ((s + 1) * (s + 1)) term looks like a possible
        // operator-precedence slip. TODO: confirm the intended sweep count.
        // It is kept numerically identical here.
        final int supportSweeps = Math.max((int) ((0.5d * (dataSize + 1.0d)) + ((supportsSize + 1.0d) / ((supportsSize + 1.0d) * (supportsSize + 1.0d)))), 1);

        // Iterate over a snapshot, because update() adds to and removes from
        // supportsMap.
        final ArrayList<Entry> supports = new ArrayList<>(this.supportsMap.values());
        Collections.sort(supports);
        for (int sweep = 0; sweep < supportSweeps; sweep++) {
            for (Entry support : supports) {
                this.update(support);
            }
        }

        // The total change is the Euclidean norm of the weight deltas for this step.
        double sumSquaredChange = 0.0d;
        for (Entry entry : this.entries) {
            final double change = entry.getWeight() - entry.previousStepWeight;
            sumSquaredChange += change * change;
        }
        setTotalChange(Math.sqrt(sumSquaredChange));

        return getTotalChange() > getMinChange();
    }

    /**
     * Applies one overrelaxed, clipped update to a single entry's weight.
     * <p>
     * It also keeps the supports map and the categorizer bias in sync with
     * the new weight.
     *
     * @param entry The entry to update.
     */
    protected void update(final SuccessiveOverrelaxation<InputType>.Entry entry) {
        final InputType input = entry.getInput();
        final double label = entry.outputDouble;
        final double prediction = this.result.evaluateAsDouble(input);
        final double oldWeight = entry.getWeight();
        final double oldBias = this.result.getBias();

        // label * oldWeight is the current alpha. Take an overrelaxed step,
        // clip the result to [0, maxWeight], and re-apply the label's sign.
        final double newWeight = Math.max(0.0d, Math.min(this.maxWeight, (label * oldWeight) - ((this.overrelaxation / (entry.selfKernel + 1.0d)) * ((label * prediction) - 1.0d)))) * label;
        entry.setWeight(newWeight);

        // Keep supportsMap holding exactly the entries with non-zero weight.
        if (newWeight != 0.0d) {
            if (!entry.supportInserted) {
                this.supportsMap.put(entry.example, entry);
                entry.supportInserted = true;
            }
        } else if (entry.supportInserted) {
            this.supportsMap.remove(entry.example);
            entry.supportInserted = false;
        }

        // The bias accumulates every change in signed weight.
        this.result.setBias(oldBias + (newWeight - oldWeight));
    }

    /**
     * Gives the result its own list of copies of the support entries, then
     * discards the supports map.
     */
    @Override
    protected void cleanupAlgorithm() {
        if (getSupportsMap() != null) {
            final ArrayList<DefaultWeightedValue<InputType>> supports = new ArrayList<>(this.supportsMap.size());
            for (Entry entry : this.supportsMap.values()) {
                supports.add(new DefaultWeightedValue<InputType>(entry));
            }
            getResult().setExamples(supports);
            setSupportsMap(null);
        }
    }

    /**
     * Gets the kernel.
     *
     * @return The kernel; may be null if none has been set.
     */
    public Kernel<? super InputType> getKernel() {
        return this.kernel;
    }

    /**
     * Sets the kernel.
     *
     * @param kernel The kernel. It must be non-null by the time learning
     *     starts, because entry construction evaluates it.
     */
    public void setKernel(final Kernel<? super InputType> kernel) {
        this.kernel = kernel;
    }

    /**
     * Gets the maximum unlabeled weight.
     *
     * @return The upper clip on any example's unlabeled weight.
     */
    public double getMaxWeight() {
        return this.maxWeight;
    }

    /**
     * Sets the maximum unlabeled weight.
     *
     * @param maxWeight The upper clip on any example's unlabeled weight.
     * @throws IllegalArgumentException If maxWeight is not positive, including NaN.
     */
    public void setMaxWeight(final double maxWeight) {
        // Written as a negated range test so that NaN is rejected too.
        if (!(maxWeight > 0.0d)) {
            throw new IllegalArgumentException("maxWeight must be positive");
        }
        this.maxWeight = maxWeight;
    }

    /**
     * Gets the overrelaxation factor.
     *
     * @return The overrelaxation factor.
     */
    public double getOverrelaxation() {
        return this.overrelaxation;
    }

    /**
     * Sets the overrelaxation factor.
     *
     * @param overrelaxation The factor; must lie strictly between 0 and 2.
     * @throws IllegalArgumentException If the factor is outside (0, 2) or is NaN.
     */
    public void setOverrelaxation(final double overrelaxation) {
        // Written as a negated range test so that NaN is rejected too.
        if (!(overrelaxation > 0.0d && overrelaxation < 2.0d)) {
            throw new IllegalArgumentException("overrelaxation must be in (0.0, 2.0), exclusive.");
        }
        this.overrelaxation = overrelaxation;
    }

    /**
     * Gets the minimum change.
     *
     * @return The minimum per-step change required to keep learning.
     */
    public double getMinChange() {
        return this.minChange;
    }

    /**
     * Sets the minimum change.
     *
     * @param minChange The minimum per-step change required to keep learning.
     * @throws IllegalArgumentException If minChange is negative or NaN.
     */
    public void setMinChange(final double minChange) {
        // Zero is allowed. The negated test also rejects NaN, which would
        // otherwise stop learning after a single step.
        if (!(minChange >= 0.0d)) {
            throw new IllegalArgumentException("minChange must be non-negative");
        }
        this.minChange = minChange;
    }

    /**
     * Gets the categorizer being (or last) learned.
     *
     * @return The learned categorizer; may be null before learning.
     */
    @Override
    public KernelBinaryCategorizer<InputType, DefaultWeightedValue<InputType>> getResult() {
        return this.result;
    }

    /**
     * Sets the result categorizer.
     *
     * @param result The categorizer.
     */
    protected void setResult(final KernelBinaryCategorizer<InputType, DefaultWeightedValue<InputType>> result) {
        this.result = result;
    }

    /**
     * Gets the per-example entries.
     *
     * @return The entries; may be null before learning.
     */
    protected ArrayList<SuccessiveOverrelaxation<InputType>.Entry> getEntries() {
        return this.entries;
    }

    /**
     * Sets the per-example entries.
     *
     * @param entries The entries.
     */
    protected void setEntries(final ArrayList<SuccessiveOverrelaxation<InputType>.Entry> entries) {
        this.entries = entries;
    }

    /**
     * Gets the supports map.
     *
     * @return The supports map; null outside of learning.
     */
    protected LinkedHashMap<InputOutputPair<? extends InputType, ? extends Boolean>, SuccessiveOverrelaxation<InputType>.Entry> getSupportsMap() {
        return this.supportsMap;
    }

    /**
     * Sets the supports map.
     *
     * @param supportsMap The supports map.
     */
    protected void setSupportsMap(final LinkedHashMap<InputOutputPair<? extends InputType, ? extends Boolean>, SuccessiveOverrelaxation<InputType>.Entry> supportsMap) {
        this.supportsMap = supportsMap;
    }

    /**
     * Gets the total change of the last step.
     *
     * @return The Euclidean norm of the weight changes in the most recent step.
     */
    public double getTotalChange() {
        return this.totalChange;
    }

    /**
     * Sets the total change.
     *
     * @param totalChange The total change.
     */
    protected void setTotalChange(final double totalChange) {
        this.totalChange = totalChange;
    }

    /**
     * Reports the learner's progress.
     *
     * @return The most recent total change, under the name "change".
     */
    @Override
    public NamedValue<Double> getPerformance() {
        return new DefaultNamedValue<Double>("change", getTotalChange());
    }
}
