package gov.sandia.cognition.learning.algorithm.perceptron;

import gov.sandia.cognition.learning.function.categorization.DefaultKernelBinaryCategorizer;
import gov.sandia.cognition.learning.function.categorization.LinearBinaryCategorizer;
import gov.sandia.cognition.learning.function.kernel.Kernel;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.util.DefaultWeightedValue;

/**
 * Abstract base for online binary-categorizer learners whose learned object is
 * a linear combination of the examples seen so far. The same update rule is
 * applied to two representations:
 * <ul>
 * <li>a {@link LinearBinaryCategorizer}, where the combination is folded into
 * an explicit weight vector; and</li>
 * <li>a {@link DefaultKernelBinaryCategorizer}, where each example is stored
 * with its own coefficient.</li>
 * </ul>
 *
 * <p>Every call to {@code update} performs, in order:
 * <ol>
 * <li><b>decay</b>: the existing weights (and the bias, when
 * {@link #isUpdateBias()}) are multiplied by {@code computeDecay(...)}
 * (default {@code 1.0}, i.e. no decay; {@code 0.0} clears everything);</li>
 * <li><b>update</b>: the input is added with coefficient
 * {@code computeUpdate(...) * actual}, where {@code actual} is {@code +1} for a
 * {@code true} label and {@code -1} for {@code false}; the bias receives
 * {@code actual * update};</li>
 * <li><b>rescaling</b>: the weights (and bias, when updated) are multiplied by
 * {@code computeRescaling(...)} (default {@code 1.0}, i.e. no rescaling).</li>
 * </ol>
 * Subclasses must supply {@code computeUpdate} for both representations and
 * may override the decay, rescaling and initialization hooks.
 */
public abstract class AbstractLinearCombinationOnlineLearner extends AbstractKernelizableBinaryCategorizerOnlineLearner {

    /** Whether the bias term is adjusted along with the weights. */
    protected boolean updateBias;

    /**
     * Creates a learner that uses the default vector factory.
     *
     * @param updateBias whether the bias term should be updated
     */
    public AbstractLinearCombinationOnlineLearner(boolean updateBias) {
        this(updateBias, VectorFactory.getDefault());
    }

    /**
     * Creates a learner.
     *
     * @param updateBias whether the bias term should be updated
     * @param vectorFactory factory used to create the initial weight vector
     */
    public AbstractLinearCombinationOnlineLearner(boolean updateBias, VectorFactory<?> vectorFactory) {
        super(vectorFactory);
        // Goes through the (overridable) setter on purpose so that any subclass
        // override of setUpdateBias still observes the initial value.
        setUpdateBias(updateBias);
    }

    /**
     * Applies one decay / update / rescaling step to an explicit linear
     * categorizer. If the categorizer has no weights yet, {@link
     * #initialize(LinearBinaryCategorizer, Vector, boolean)} is called first.
     *
     * @param target the categorizer to update in place
     * @param input the example input
     * @param output the example label
     */
    @Override
    public void update(LinearBinaryCategorizer target, Vector input, boolean output) {
        Vector weights = target.getWeights();
        if (weights == null) {
            // First example: let the overridable initializer create the weights.
            initialize(target, input, output);
            weights = target.getWeights();
        }

        final double prediction = target.evaluateAsDouble(input);
        final double actual = output ? 1.0d : -1.0d;
        final double update = computeUpdate(target, input, output, prediction);
        final double decay = computeDecay(target, input, output, prediction, update);

        if (decay != 1.0d) {
            if (decay == 0.0d) {
                weights.zero();
            } else {
                weights.scaleEquals(decay);
            }
        }

        if (update != 0.0d) {
            if (update != 1.0d) {
                weights.plusEquals(input.scale(update * actual));
            } else if (output) {
                // Unit step: add/subtract the input directly rather than going
                // through scale(), which presumably builds a scaled copy.
                weights.plusEquals(input);
            } else {
                weights.minusEquals(input);
            }
        }

        target.setWeights(weights);
        if (this.updateBias) {
            target.setBias((target.getBias() * decay) + (actual * update));
        }

        final double rescaling = computeRescaling(target, input, output, prediction, update, decay);
        if (rescaling != 1.0d) {
            weights.scaleEquals(rescaling);
            target.setWeights(weights);
            if (this.updateBias) {
                target.setBias(target.getBias() * rescaling);
            }
        }
    }

    /**
     * Creates an empty kernel categorizer over the given kernel.
     *
     * @param kernel the kernel to use
     * @param <InputType> the input type
     * @return a new {@link DefaultKernelBinaryCategorizer}
     */
    @Override
    public <InputType> DefaultKernelBinaryCategorizer<InputType> createInitialLearnedObject(Kernel<? super InputType> kernel) {
        return new DefaultKernelBinaryCategorizer<>(kernel);
    }

    /**
     * Applies one decay / update / rescaling step to a kernel categorizer.
     * If the categorizer holds no examples yet, {@link
     * #initialize(DefaultKernelBinaryCategorizer, Object, boolean)} is called
     * first. A non-zero update stores {@code input} with coefficient
     * {@code update * actual}.
     *
     * @param target the categorizer to update in place
     * @param input the example input
     * @param output the example label
     * @param <InputType> the input type
     */
    @Override
    public <InputType> void update(DefaultKernelBinaryCategorizer<InputType> target, InputType input, boolean output) {
        if (target.getExamples().isEmpty()) {
            initialize(target, input, output);
        }

        final double prediction = target.evaluateAsDouble(input);
        final double actual = output ? 1.0d : -1.0d;
        final double update = computeUpdate(target, input, output, prediction);
        final double decay = computeDecay(target, input, output, prediction, update);

        if (decay != 1.0d) {
            if (decay == 0.0d) {
                target.getExamples().clear();
            } else {
                scaleExampleWeights(target, decay);
            }
        }

        if (update != 0.0d) {
            target.add(input, update * actual);
        }

        if (this.updateBias) {
            target.setBias((target.getBias() * decay) + (actual * update));
        }

        final double rescaling = computeRescaling(target, input, output, prediction, update, decay);
        if (rescaling != 1.0d) {
            scaleExampleWeights(target, rescaling);
            if (this.updateBias) {
                target.setBias(target.getBias() * rescaling);
            }
        }
    }

    /** Multiplies the coefficient of every stored example by {@code factor}. */
    private static <InputType> void scaleExampleWeights(DefaultKernelBinaryCategorizer<InputType> target, double factor) {
        for (DefaultWeightedValue<InputType> example : target.getExamples()) {
            example.setWeight(factor * example.getWeight());
        }
    }

    /**
     * Initializes a linear categorizer that has no weights yet by giving it a
     * fresh vector (presumably all zeros) with the input's dimensionality,
     * created by this learner's vector factory.
     *
     * @param target the categorizer to initialize
     * @param input the first example input
     * @param output the first example label
     */
    protected void initialize(LinearBinaryCategorizer target, Vector input, boolean output) {
        target.setWeights(getVectorFactory().createVector(input.getDimensionality()));
    }

    /**
     * Computes the step size for this example; the input is added to the
     * weights with coefficient {@code update * (output ? +1 : -1)}. A value of
     * {@code 0} means no update.
     *
     * @param target the categorizer being updated
     * @param input the example input
     * @param output the example label
     * @param predicted {@code target.evaluateAsDouble(input)} before the update
     * @return the update step size
     */
    protected abstract double computeUpdate(LinearBinaryCategorizer target, Vector input, boolean output, double predicted);

    /**
     * Computes the factor applied to the existing weights (and bias) before
     * the update is added. Defaults to {@code 1.0} (no decay).
     *
     * @param target the categorizer being updated
     * @param input the example input
     * @param output the example label
     * @param predicted {@code target.evaluateAsDouble(input)} before the update
     * @param update the value returned by {@code computeUpdate}
     * @return the decay factor
     */
    protected double computeDecay(LinearBinaryCategorizer target, Vector input, boolean output, double predicted, double update) {
        return 1.0d;
    }

    /**
     * Computes the factor applied to the weights (and bias) after the update.
     * Defaults to {@code 1.0} (no rescaling).
     *
     * @param target the categorizer being updated
     * @param input the example input
     * @param output the example label
     * @param predicted {@code target.evaluateAsDouble(input)} before the update
     * @param update the value returned by {@code computeUpdate}
     * @param decay the value returned by {@code computeDecay}
     * @return the rescaling factor
     */
    protected double computeRescaling(LinearBinaryCategorizer target, Vector input, boolean output, double predicted, double update, double decay) {
        return 1.0d;
    }

    /**
     * Hook called before the first example is added to an empty kernel
     * categorizer. Does nothing by default.
     *
     * @param target the categorizer to initialize
     * @param input the first example input
     * @param output the first example label
     * @param <InputType> the input type
     */
    protected <InputType> void initialize(DefaultKernelBinaryCategorizer<InputType> target, InputType input, boolean output) {
    }

    /**
     * Computes the coefficient magnitude for this example; the example is
     * stored with coefficient {@code update * (output ? +1 : -1)}. A value of
     * {@code 0} means the example is not added.
     *
     * @param target the categorizer being updated
     * @param input the example input
     * @param output the example label
     * @param predicted {@code target.evaluateAsDouble(input)} before the update
     * @param <InputType> the input type
     * @return the update step size
     */
    protected abstract <InputType> double computeUpdate(DefaultKernelBinaryCategorizer<InputType> target, InputType input, boolean output, double predicted);

    /**
     * Computes the factor applied to existing example coefficients (and bias)
     * before the update; {@code 0} clears all stored examples. Defaults to
     * {@code 1.0} (no decay).
     *
     * @param target the categorizer being updated
     * @param input the example input
     * @param output the example label
     * @param predicted {@code target.evaluateAsDouble(input)} before the update
     * @param update the value returned by {@code computeUpdate}
     * @param <InputType> the input type
     * @return the decay factor
     */
    protected <InputType> double computeDecay(DefaultKernelBinaryCategorizer<InputType> target, InputType input, boolean output, double predicted, double update) {
        return 1.0d;
    }

    /**
     * Computes the factor applied to all example coefficients (and bias)
     * after the update. Defaults to {@code 1.0} (no rescaling).
     *
     * @param target the categorizer being updated
     * @param input the example input
     * @param output the example label
     * @param predicted {@code target.evaluateAsDouble(input)} before the update
     * @param update the value returned by {@code computeUpdate}
     * @param decay the value returned by {@code computeDecay}
     * @param <InputType> the input type
     * @return the rescaling factor
     */
    protected <InputType> double computeRescaling(DefaultKernelBinaryCategorizer<InputType> target, InputType input, boolean output, double predicted, double update, double decay) {
        return 1.0d;
    }

    /**
     * Returns whether the bias term is updated.
     *
     * @return {@code true} if the bias is updated
     */
    public boolean isUpdateBias() {
        return this.updateBias;
    }

    /**
     * Sets whether the bias term is updated.
     *
     * @param updateBias {@code true} to update the bias
     */
    protected void setUpdateBias(boolean updateBias) {
        this.updateBias = updateBias;
    }
}
