package org.linqs.psl.application.learning.weight.gradient;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.linqs.psl.application.inference.InferenceApplication;
import org.linqs.psl.application.learning.weight.WeightLearningApplication;
import org.linqs.psl.config.Options;
import org.linqs.psl.database.AtomStore;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.predicate.DeepPredicate;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.WeightedRule;
import org.linqs.psl.reasoner.InitialValue;
import org.linqs.psl.reasoner.term.ReasonerTerm;
import org.linqs.psl.reasoner.term.TermState;
import org.linqs.psl.util.Logger;
import org.linqs.psl.util.MathUtils;

/* loaded from: input_file:org/linqs/psl/application/learning/weight/gradient/GradientDescent.class */
public abstract class GradientDescent extends WeightLearningApplication {
    // Class-wide logger.
    private static final Logger log = Logger.getLogger(GradientDescent.class);
    // Update variant; parsed in the constructor from Options.WLA_GRADIENT_DESCENT_EXTENSION.
    protected GDExtension gdExtension;
    // Index of each mutable rule within mutableRules (filled in the constructor).
    protected Map<WeightedRule, Integer> ruleIndexMap;
    // Gradient of the objective with respect to each rule weight; one entry per mutable rule.
    protected float[] weightGradient;
    // Atom-level gradient buffers. All four are null until postInitGroundModel() allocates them with
    // one entry per train atom value. Within this class only deepAtomGradient is read
    // (atomGradientStep(), computeGradientNorm()); presumably subclasses fill them -- TODO confirm.
    protected float[] rvAtomGradient;
    protected float[] deepAtomGradient;
    protected float[] MAPRVAtomGradient;
    protected float[] MAPDeepAtomGradient;
    // Warm-start snapshots (term state + atom values) loaded before and refreshed after each
    // MAP inference run in computeMAPStateWithWarmStart(), kept separately for train and validation.
    protected TermState[] trainMAPTermState;
    protected float[] trainMAPAtomValueState;
    protected TermState[] validationMAPTermState;
    protected float[] validationMAPAtomValueState;
    // When true the constructor requires runValidation to be true as well.
    protected boolean saveBestValidationWeights;
    // Rule weights recorded whenever runValidationEvaluation() observes a new best metric.
    protected float[] bestValidationWeights;
    // Latest and best validation metric seen so far (package-private; both start at -infinity).
    double currentValidationEvaluationMetric;
    double bestValidationEvaluationMetric;
    // Step size returned by computeStepSize(); divided by (iteration + 1) when scaleStepSize is set.
    protected float baseStepSize;
    protected boolean scaleStepSize;
    // Clipping threshold and the p used for the p-norm in clipWeightGradient().
    protected float maxGradientMagnitude;
    protected float maxGradientNorm;
    // The p used for the p-norm in computeGradientDescentNorm().
    protected float stoppingGradientNorm;
    // Enables the clipWeightGradient() call in doLearn().
    protected boolean clipWeightGradient;
    // Stopping controls consulted by breakOptimization().
    protected int maxNumSteps;
    // When true, only the maxNumSteps check can stop the optimization.
    protected boolean runFullIterations;
    // Stop when parameterMovement (total movement of the last gradientStep()) is within tolerance of 0.
    protected boolean movementBreak;
    protected float parameterMovement;
    protected float movementTolerance;
    // Stop when computeGradientNorm() is within tolerance of 0.
    protected boolean normBreak;
    protected float normTolerance;
    // Stop when the two objective values passed to breakOptimization() are within tolerance.
    protected boolean objectiveBreak;
    protected float objectiveTolerance;
    // Coefficients of the three regularization terms in computeRegularization().
    protected float l2Regularization;
    protected float logRegularization;
    protected float entropyRegularization;

    /**
     * Variant of the rule-weight update performed by {@code weightGradientStep}.
     * The declaration order of the constants is unchanged.
     */
    public enum GDExtension {
        /** Multiplicative exponentiated-gradient update, renormalised so the weights sum to one. */
        MIRROR_DESCENT,

        /** Additive gradient step followed by a projection of the weights onto the simplex. */
        PROJECTED_GRADIENT,

        /** Plain additive gradient step without any simplex handling. */
        NONE
    }

    public GradientDescent(List<Rule> list, Database database, Database database2, Database database3, Database database4, boolean z) {
        super(list, database, database2, database3, database4, Boolean.valueOf(z));
        this.gdExtension = GDExtension.valueOf(Options.WLA_GRADIENT_DESCENT_EXTENSION.getString().toUpperCase());
        this.ruleIndexMap = new HashMap(this.mutableRules.size());
        for (int i = 0; i < this.mutableRules.size(); i++) {
            this.ruleIndexMap.put(this.mutableRules.get(i), Integer.valueOf(i));
        }
        this.weightGradient = new float[this.mutableRules.size()];
        this.rvAtomGradient = null;
        this.deepAtomGradient = null;
        this.MAPRVAtomGradient = null;
        this.MAPDeepAtomGradient = null;
        this.trainMAPTermState = null;
        this.trainMAPAtomValueState = null;
        this.validationMAPTermState = null;
        this.validationMAPAtomValueState = null;
        this.saveBestValidationWeights = Options.WLA_GRADIENT_DESCENT_SAVE_BEST_VALIDATION_WEIGHTS.getBoolean();
        this.bestValidationWeights = null;
        this.currentValidationEvaluationMetric = Double.NEGATIVE_INFINITY;
        this.bestValidationEvaluationMetric = Double.NEGATIVE_INFINITY;
        if (this.saveBestValidationWeights && !this.runValidation) {
            throw new IllegalArgumentException("If saveBestValidationWeights is true, then runValidation must also be true.");
        }
        this.baseStepSize = Options.WLA_GRADIENT_DESCENT_STEP_SIZE.getFloat();
        this.scaleStepSize = Options.WLA_GRADIENT_DESCENT_SCALE_STEP.getBoolean();
        this.clipWeightGradient = Options.WLA_GRADIENT_DESCENT_CLIP_GRADIENT.getBoolean();
        this.maxGradientMagnitude = Options.WLA_GRADIENT_DESCENT_MAX_GRADIENT.getFloat();
        this.maxGradientNorm = Options.WLA_GRADIENT_DESCENT_MAX_GRADIENT_NORM.getFloat();
        this.maxNumSteps = Options.WLA_GRADIENT_DESCENT_NUM_STEPS.getInt();
        this.runFullIterations = Options.WLA_GRADIENT_DESCENT_RUN_FULL_ITERATIONS.getBoolean();
        this.movementBreak = Options.WLA_GRADIENT_DESCENT_MOVEMENT_BREAK.getBoolean();
        this.parameterMovement = Float.POSITIVE_INFINITY;
        this.movementTolerance = Options.WLA_GRADIENT_DESCENT_MOVEMENT_TOLERANCE.getFloat();
        this.normBreak = Options.WLA_GRADIENT_DESCENT_NORM_BREAK.getBoolean();
        this.normTolerance = Options.WLA_GRADIENT_DESCENT_NORM_TOLERANCE.getFloat();
        this.objectiveBreak = Options.WLA_GRADIENT_DESCENT_OBJECTIVE_BREAK.getBoolean();
        this.objectiveTolerance = Options.WLA_GRADIENT_DESCENT_OBJECTIVE_TOLERANCE.getFloat();
        this.stoppingGradientNorm = Options.WLA_GRADIENT_DESCENT_STOPPING_GRADIENT_NORM.getFloat();
        this.l2Regularization = Options.WLA_GRADIENT_DESCENT_L2_REGULARIZATION.getFloat();
        this.logRegularization = Options.WLA_GRADIENT_DESCENT_LOG_REGULARIZATION.getFloat();
        this.entropyRegularization = Options.WLA_GRADIENT_DESCENT_ENTROPY_REGULARIZATION.getFloat();
    }

    /**
     * Validates the validation setup, switches both inference applications to
     * {@code InitialValue.ATOM}, snapshots their term state and atom values for warm starts, and
     * allocates the atom-gradient buffers (one entry per train atom value).
     *
     * <p>Decompiler note: the access modifier was changed from {@code protected}.
     *
     * @throws IllegalArgumentException if validation is enabled but no evaluator is set
     * @throws IllegalStateException if validation is enabled but the validation atom store is empty
     */
    @Override // org.linqs.psl.application.learning.weight.WeightLearningApplication
    public void postInitGroundModel() {
        super.postInitGroundModel();

        if (this.runValidation) {
            if (this.evaluation == null) {
                throw new IllegalArgumentException("If validation is being run, then an evaluator must be specified for predicates.");
            }
            if (this.validationInferenceApplication.getDatabase().getAtomStore().size() <= 0) {
                throw new IllegalStateException("If validation is being run, then validation data must be provided in the runtime.json file.");
            }
        }

        // Inference starts from the current atom values so the saved snapshots act as warm starts.
        this.trainInferenceApplication.setInitialValue(InitialValue.ATOM);
        this.validationInferenceApplication.setInitialValue(InitialValue.ATOM);

        this.trainMAPTermState = this.trainInferenceApplication.getTermStore().saveState();
        this.validationMAPTermState = this.validationInferenceApplication.getTermStore().saveState();

        float[] trainAtomValues = this.trainInferenceApplication.getDatabase().getAtomStore().getAtomValues();
        this.trainMAPAtomValueState = trainAtomValues.clone();

        float[] validationAtomValues = this.validationInferenceApplication.getDatabase().getAtomStore().getAtomValues();
        this.validationMAPAtomValueState = validationAtomValues.clone();

        final int trainAtomCount = trainAtomValues.length;
        this.rvAtomGradient = new float[trainAtomCount];
        this.deepAtomGradient = new float[trainAtomCount];
        this.MAPRVAtomGradient = new float[trainAtomCount];
        this.MAPDeepAtomGradient = new float[trainAtomCount];
    }

    /**
     * Prepares a learning run: rescales the weights onto the simplex for the two simplex-based
     * extensions, asks every deep predicate for a prediction, seeds the best-validation snapshot
     * with the current weights and resets both validation metrics to negative infinity.
     */
    protected void initForLearning() {
        switch (this.gdExtension) {
            case MIRROR_DESCENT:
            case PROJECTED_GRADIENT:
                simplexScaleWeights();
                break;
            default:
                // NONE: the weights are used as they are.
                break;
        }

        for (DeepPredicate deepPredicate : this.deepPredicates) {
            deepPredicate.predictDeepModel(true);
        }

        final int ruleCount = this.mutableRules.size();
        float[] snapshot = new float[ruleCount];
        this.bestValidationWeights = snapshot;
        for (int ruleIndex = 0; ruleIndex < ruleCount; ruleIndex++) {
            snapshot[ruleIndex] = this.mutableRules.get(ruleIndex).getWeight();
        }

        this.currentValidationEvaluationMetric = Double.NEGATIVE_INFINITY;
        this.bestValidationEvaluationMetric = Double.NEGATIVE_INFINITY;
    }

    /**
     * Runs the gradient-descent loop until {@link #breakOptimization(int, float, float)} asks to
     * stop, then reports the final model and saves and closes every deep predicate.
     *
     * <p>Each iteration optionally evaluates the train MAP state (trace logging only) and the
     * validation MAP state, computes the iteration statistics, loss and gradients, optionally
     * clips the weight gradient, and finally takes one gradient step.
     *
     * <p>Two defects of the previous version are fixed here:
     * <ul>
     *   <li>the objective-based stopping test received the current loss twice, so it compared the
     *       loss with itself; it now receives the loss of the previous iteration (positive
     *       infinity on the first one);</li>
     *   <li>with {@code saveBestValidationWeights} the rules were reset to a snapshot of the
     *       initial weights; they are now reset to {@code bestValidationWeights}, which
     *       {@code runValidationEvaluation()} keeps up to date.</li>
     * </ul>
     */
    @Override // org.linqs.psl.application.learning.weight.WeightLearningApplication
    protected void doLearn() {
        log.info("Gradient Descent Weight Learning Start.");
        initForLearning();

        boolean finished = false;
        // No previous objective exists before the first iteration.
        float previousLoss = Float.POSITIVE_INFINITY;
        long totalTimeMs = 0;
        int iteration = 0;
        while (!finished) {
            long iterationStart = System.currentTimeMillis();

            log.trace("Model:");
            for (WeightedRule rule : this.mutableRules) {
                log.trace("{}", rule);
            }

            // Train evaluation costs a MAP inference run, so it is only done for trace logging.
            if (log.isTraceEnabled() && this.evaluation != null) {
                runMAPEvaluation();
                log.trace("MAP State Training Evaluation Metric: {}", Double.valueOf(this.evaluation.getNormalizedRepMetric()));
            }
            if (this.runValidation) {
                runValidationEvaluation();
                log.debug("Current MAP State Validation Evaluation Metric: {}", Double.valueOf(this.currentValidationEvaluationMetric));
            }

            computeIterationStatistics();
            float loss = computeTotalLoss();
            computeTotalWeightGradient();
            computeTotalAtomGradient();
            if (this.clipWeightGradient) {
                clipWeightGradient();
            }
            gradientStep(iteration);

            long iterationEnd = System.currentTimeMillis();
            totalTimeMs += iterationEnd - iterationStart;

            finished = breakOptimization(iteration, loss, previousLoss);
            previousLoss = loss;

            log.trace("Iteration {} -- Weight Learning Objective: {}, Gradient Magnitude: {}, Parameter Movement: {}, Iteration Time: {}", Integer.valueOf(iteration), Float.valueOf(loss), Float.valueOf(computeGradientNorm()), Float.valueOf(this.parameterMovement), Long.valueOf(iterationEnd - iterationStart));
            iteration++;
        }
        log.info("Gradient Descent Weight Learning Finished.");

        if (this.saveBestValidationWeights) {
            // Reset the rules to the weights that achieved the best validation metric.
            for (int ruleIndex = 0; ruleIndex < this.mutableRules.size(); ruleIndex++) {
                this.mutableRules.get(ruleIndex).setWeight(this.bestValidationWeights[ruleIndex]);
            }
            // The weights changed, so invalidate the MAP-state flags exactly as
            // weightGradientStep() does after an update.
            this.inTrainingMAPState = false;
            this.inValidationMAPState = false;
        }

        if (this.evaluation != null) {
            runMAPEvaluation();
            log.info("Final MAP State Evaluation Metric: {}", Double.valueOf(this.evaluation.getNormalizedRepMetric()));
        }
        if (this.runValidation) {
            runValidationEvaluation();
            log.info("Final MAP State Validation Evaluation Metric: {}", Double.valueOf(this.evaluation.getNormalizedRepMetric()));
        }
        log.info("Final Model {} ", this.mutableRules);
        log.info("Final Weight Learning Loss: {}, Final Gradient Magnitude: {}, Total optimization time: {}", Float.valueOf(computeTotalLoss()), Float.valueOf(computeGradientNorm()), Long.valueOf(totalTimeMs));

        for (DeepPredicate deepPredicate : this.deepPredicates) {
            deepPredicate.saveDeepModel();
            deepPredicate.close();
        }
    }

    /**
     * Runs warm-started MAP inference on the training application, computes the evaluation on the
     * training map and then lets every deep predicate evaluate and re-predict.
     * All predicates are evaluated before any of them predicts again.
     */
    protected void runMAPEvaluation() {
        log.trace("Running MAP Inference.");
        computeMAPStateWithWarmStart(this.trainInferenceApplication, this.trainMAPTermState, this.trainMAPAtomValueState);
        this.evaluation.compute(this.trainingMap);

        for (DeepPredicate deepPredicate : this.deepPredicates) {
            deepPredicate.evalDeepModel();
        }
        for (DeepPredicate deepPredicate : this.deepPredicates) {
            deepPredicate.predictDeepModel(true);
        }
    }

    /**
     * Evaluates the current weights on the validation data and tracks the best result.
     *
     * <p>Every deep predicate is temporarily switched to its validation model, warm-started MAP
     * inference is run on the validation application, and the normalized metric is stored in
     * {@code currentValidationEvaluationMetric}. A strictly better metric replaces the best one and
     * copies the current rule weights into {@code bestValidationWeights}. Finally each deep
     * predicate is switched back to its regular model and predicts again.
     */
    protected void runValidationEvaluation() {
        final int predicateCount = this.deepPredicates.size();

        // Swap in the validation deep models.
        for (int p = 0; p < predicateCount; p++) {
            DeepPredicate predicate = this.deepPredicates.get(p);
            predicate.setDeepModel(this.validationDeepModelPredicates.get(p));
            predicate.predictDeepModel(false);
        }

        log.trace("Running Validation Inference.");
        computeMAPStateWithWarmStart(this.validationInferenceApplication, this.validationMAPTermState, this.validationMAPAtomValueState);
        this.evaluation.compute(this.validationMap);
        this.currentValidationEvaluationMetric = this.evaluation.getNormalizedRepMetric();

        boolean improved = this.currentValidationEvaluationMetric > this.bestValidationEvaluationMetric;
        if (improved) {
            this.bestValidationEvaluationMetric = this.currentValidationEvaluationMetric;
            int ruleCount = this.mutableRules.size();
            for (int r = 0; r < ruleCount; r++) {
                this.bestValidationWeights[r] = this.mutableRules.get(r).getWeight();
            }
            log.debug("New Best Validation Model: {}", this.mutableRules);
        }
        log.debug("MAP State Best Validation Evaluation Metric: {}", Double.valueOf(this.bestValidationEvaluationMetric));

        // Swap the regular deep models back in.
        for (int p = 0; p < predicateCount; p++) {
            DeepPredicate predicate = this.deepPredicates.get(p);
            predicate.setDeepModel(this.deepModelPredicates.get(p));
            predicate.predictDeepModel(true);
        }
    }

    /**
     * Decides whether the optimization loop should stop.
     *
     * <p>The checks run in this order: iteration cap, {@code runFullIterations} (which disables
     * every remaining check), parameter movement, gradient norm, objective change.
     *
     * <p>NOTE(review): the cap uses {@code i > maxNumSteps} on what doLearn() passes as a
     * zero-based index after the step, so more than {@code maxNumSteps} steps can run -- confirm
     * whether that is intended.
     *
     * @param i index of the iteration that just finished
     * @param f first objective value of the comparison
     * @param f2 second objective value of the comparison
     * @return {@code true} when learning should stop
     */
    protected boolean breakOptimization(int i, float f, float f2) {
        if (i > this.maxNumSteps) {
            log.trace("Breaking Weight Learning. Reached maximum number of iterations: {}", Integer.valueOf(this.maxNumSteps));
            return true;
        }

        if (this.runFullIterations) {
            return false;
        }

        boolean movementConverged = this.movementBreak && MathUtils.equals(this.parameterMovement, 0.0f, this.movementTolerance);
        if (movementConverged) {
            log.trace("Breaking Weight Learning. Parameter Movement: {} is within tolerance: {}", Float.valueOf(this.parameterMovement), Float.valueOf(this.movementTolerance));
            return true;
        }

        boolean normConverged = this.normBreak && MathUtils.equals(computeGradientNorm(), 0.0f, this.normTolerance);
        if (normConverged) {
            log.trace("Breaking Weight Learning. Gradient norm: {} is within tolerance: {}", Float.valueOf(computeGradientNorm()), Float.valueOf(this.normTolerance));
            return true;
        }

        boolean objectiveConverged = this.objectiveBreak && MathUtils.equals(f, f2, this.objectiveTolerance);
        if (objectiveConverged) {
            log.trace("Breaking Weight Learning. Objective change: {} is within tolerance: {}", Float.valueOf(Math.abs(f - f2)), Float.valueOf(this.objectiveTolerance));
            return true;
        }

        return false;
    }

    /**
     * Rescales {@code weightGradient} in place so that its {@code maxGradientNorm}-norm equals
     * {@code maxGradientMagnitude} whenever the norm exceeds that limit; otherwise does nothing.
     */
    private void clipWeightGradient() {
        float gradientNorm = MathUtils.pNorm(this.weightGradient, this.maxGradientNorm);

        // Written as a negated '>' so that a NaN norm is left untouched, as before.
        if (!(gradientNorm > this.maxGradientMagnitude)) {
            return;
        }

        log.trace("Clipping gradient. Original gradient magnitude: {} exceeds limit: {} in L_{} space.", Float.valueOf(gradientNorm), Float.valueOf(this.maxGradientMagnitude), Float.valueOf(this.maxGradientNorm));
        int ruleCount = this.mutableRules.size();
        for (int r = 0; r < ruleCount; r++) {
            float rescaled = (this.maxGradientMagnitude * this.weightGradient[r]) / gradientNorm;
            this.weightGradient[r] = rescaled;
        }
    }

    /**
     * Takes one full step and records the total movement in {@code parameterMovement}:
     * rule weights first, then internal parameters, then the deep-predicate atoms.
     *
     * @param i index of the current iteration
     */
    protected void gradientStep(int i) {
        this.parameterMovement = 0.0f;
        this.parameterMovement = this.parameterMovement + weightGradientStep(i);
        this.parameterMovement = this.parameterMovement + internalParameterGradientStep(i);
        this.parameterMovement = this.parameterMovement + atomGradientStep();
    }

    /**
     * Hook for stepping parameters other than rule weights and atoms.
     * This base implementation has no such parameters and reports no movement.
     *
     * @param i index of the current iteration (unused here)
     * @return always {@code 0.0f}
     */
    protected float internalParameterGradientStep(int i) {
        return 0.0f;
    }

    /**
     * Updates every rule weight according to {@code gdExtension} and reports how far they moved.
     *
     * <p>Mirror descent multiplies each weight by {@code exp(-step * gradient)} and renormalises;
     * the other variants subtract {@code step * gradient}, and projected gradient additionally
     * projects the result onto the simplex. Both MAP-state flags are cleared afterwards because
     * the weights changed.
     *
     * <p>Decompiler note: the access modifier was changed from {@code protected}.
     *
     * @param i index of the current iteration, used to compute the step size
     * @return sum of the absolute weight changes
     */
    public float weightGradientStep(int i) {
        final int ruleCount = this.mutableRules.size();

        float[] previousWeights = new float[ruleCount];
        for (int r = 0; r < ruleCount; r++) {
            previousWeights[r] = this.mutableRules.get(r).getWeight();
        }

        final float stepSize = computeStepSize(i);
        switch (this.gdExtension) {
            case MIRROR_DESCENT: {
                final float negatedStep = (-1.0f) * stepSize;

                float normalizer = 0.0f;
                for (int r = 0; r < ruleCount; r++) {
                    double scaled = this.mutableRules.get(r).getWeight() * Math.exp(negatedStep * this.weightGradient[r]);
                    normalizer = (float) (normalizer + scaled);
                }
                for (int r = 0; r < ruleCount; r++) {
                    WeightedRule rule = this.mutableRules.get(r);
                    double scaled = rule.getWeight() * Math.exp(negatedStep * this.weightGradient[r]);
                    rule.setWeight((float) (scaled / normalizer));
                }
                break;
            }
            case PROJECTED_GRADIENT:
            default:
                for (int r = 0; r < ruleCount; r++) {
                    WeightedRule rule = this.mutableRules.get(r);
                    rule.setWeight(rule.getWeight() - (this.weightGradient[r] * stepSize));
                }
                if (this.gdExtension == GDExtension.PROJECTED_GRADIENT) {
                    simplexProjectWeights();
                }
                break;
        }

        this.inTrainingMAPState = false;
        this.inValidationMAPState = false;

        float movement = 0.0f;
        for (int r = 0; r < ruleCount; r++) {
            movement += Math.abs(previousWeights[r] - this.mutableRules.get(r).getWeight());
        }
        return movement;
    }

    /**
     * Fits every deep predicate with {@code deepAtomGradient} and immediately asks it to predict
     * again, accumulating the value each prediction returns.
     *
     * <p>Decompiler note: the access modifier was changed from {@code protected}.
     *
     * @return sum of the values returned by the predictions
     */
    public float atomGradientStep() {
        float movement = 0.0f;
        Iterator<DeepPredicate> predicates = this.deepPredicates.iterator();
        while (predicates.hasNext()) {
            DeepPredicate predicate = predicates.next();
            predicate.fitDeepPredicate(this.deepAtomGradient);
            movement += predicate.predictDeepModel(true);
        }
        return movement;
    }

    /**
     * Returns the step size for an iteration: {@code baseStepSize}, divided by
     * {@code i + 1} when {@code scaleStepSize} is enabled.
     *
     * <p>Decompiler note: the access modifier was changed from {@code protected}.
     *
     * @param i zero-based iteration index
     * @return the step size to use
     */
    public float computeStepSize(int i) {
        if (!this.scaleStepSize) {
            return this.baseStepSize;
        }
        return this.baseStepSize / (i + 1);
    }

    /**
     * Returns the stopping measure for the current gradients: the extension-specific weight
     * gradient norm plus the 2-norm of {@code deepAtomGradient}.
     *
     * <p>Decompiler note: the access modifier was changed from {@code protected}.
     *
     * @return combined gradient norm
     */
    public float computeGradientNorm() {
        float weightNorm;
        switch (this.gdExtension) {
            case MIRROR_DESCENT:
                weightNorm = computeMirrorDescentNorm();
                break;
            case PROJECTED_GRADIENT:
                weightNorm = computeProjectedGradientDescentNorm();
                break;
            default:
                weightNorm = computeGradientDescentNorm();
                break;
        }

        float deepAtomNorm = MathUtils.pNorm(this.deepAtomGradient, 2.0f);
        return weightNorm + deepAtomNorm;
    }

    /**
     * Stopping measure for mirror descent.
     *
     * <p>The weight gradient is mapped to a distribution {@code q_i = exp(g_i) / sum_j exp(g_j)}
     * and the method returns {@code sum_i q_i * log(q_i * n)} with {@code n} the number of rules.
     *
     * <p>The decompiled body referenced an undefined register ({@code r0}) inside the logarithm;
     * it is restored here as the mapped gradient that the same statement already computes.
     *
     * @return the mirror-descent norm of the current weight gradient
     */
    private float computeMirrorDescentNorm() {
        final int ruleCount = this.mutableRules.size();

        float exponentiatedGradientSum = 0.0f;
        for (int i = 0; i < ruleCount; i++) {
            exponentiatedGradientSum = (float) (exponentiatedGradientSum + Math.exp(this.weightGradient[i]));
        }

        float norm = 0.0f;
        for (int i = 0; i < ruleCount; i++) {
            float mappedGradient = ((float) Math.exp(this.weightGradient[i])) / exponentiatedGradientSum;
            norm += mappedGradient * ((float) Math.log(mappedGradient * ruleCount));
        }
        return norm;
    }

    /**
     * Stopping measure for projected gradient descent.
     *
     * <p>Without log regularization, a gradient component is dropped when its weight sits exactly
     * at 0 with a positive gradient or exactly at 1 with a negative gradient. The remaining
     * "active" components (all of them when log regularization is on, otherwise those not within
     * 1e-8 of zero) are mapped to {@code q_i = exp(g_i) / sum_active exp(g_j)} and the method
     * returns {@code sum_active q_i * log(q_i * activeCount)}.
     *
     * <p>The decompiled body referenced an undefined register ({@code r0}) inside the logarithm;
     * it is restored here as the mapped gradient that the same statement already computes.
     *
     * @return the projected-gradient norm of the current weight gradient
     */
    private float computeProjectedGradientDescentNorm() {
        final boolean hasLogRegularization = this.logRegularization != 0.0f;
        final int ruleCount = this.mutableRules.size();

        float[] clippedGradient = this.weightGradient.clone();
        int activeCount = 0;
        for (int i = 0; i < clippedGradient.length; i++) {
            float weight = this.mutableRules.get(i).getWeight();
            float gradient = this.weightGradient[i];

            boolean pinnedAtZero = !hasLogRegularization && MathUtils.equalsStrict(weight, 0.0f) && gradient > 0.0f;
            boolean pinnedAtOne = !pinnedAtZero && !hasLogRegularization && MathUtils.equalsStrict(weight, 1.0f) && gradient < 0.0f;
            if (pinnedAtZero || pinnedAtOne) {
                clippedGradient[i] = 0.0f;
                continue;
            }

            if (hasLogRegularization || !MathUtils.isZero(clippedGradient[i], 1.0E-8d)) {
                activeCount++;
            }
        }

        float exponentiatedGradientSum = 0.0f;
        for (int i = 0; i < ruleCount; i++) {
            if (hasLogRegularization || !MathUtils.isZero(clippedGradient[i], 1.0E-8d)) {
                exponentiatedGradientSum = (float) (exponentiatedGradientSum + Math.exp(this.weightGradient[i]));
            }
        }

        float norm = 0.0f;
        for (int i = 0; i < ruleCount; i++) {
            if (hasLogRegularization || !MathUtils.isZero(clippedGradient[i], 1.0E-8d)) {
                float mappedGradient = ((float) Math.exp(this.weightGradient[i])) / exponentiatedGradientSum;
                norm += mappedGradient * ((float) Math.log(mappedGradient * activeCount));
            }
        }
        return norm;
    }

    /**
     * Stopping measure for plain gradient descent: the {@code stoppingGradientNorm}-norm of the
     * weight gradient, ignoring components whose weight equals zero (per
     * {@code MathUtils.equals}) while the gradient is not less than or equal to zero.
     *
     * @return the norm of the filtered weight gradient
     */
    private float computeGradientDescentNorm() {
        float[] effectiveGradient = this.weightGradient.clone();
        for (int r = 0; r < effectiveGradient.length; r++) {
            boolean weightAtZero = MathUtils.equals(this.mutableRules.get(r).getWeight(), 0.0f);
            // Negated '<=' (not '>') so a NaN gradient on a zero weight is still dropped.
            if (weightAtZero && !(this.weightGradient[r] <= 0.0f)) {
                effectiveGradient[r] = 0.0f;
            }
        }
        return MathUtils.pNorm(effectiveGradient, this.stoppingGradientNorm);
    }

    /**
     * Projects the rule weights onto the simplex in place.
     *
     * <p>The weights are sorted, scanned from largest to smallest to find the shift
     * {@code (partialSum - 1) / k} of the last prefix whose shift stays below its smallest
     * member, and every weight is then replaced by {@code max(0, weight - shift)}.
     */
    public void simplexProjectWeights() {
        final int ruleCount = this.mutableRules.size();

        float[] ascendingWeights = new float[ruleCount];
        for (int r = 0; r < ruleCount; r++) {
            ascendingWeights[r] = this.mutableRules.get(r).getWeight();
        }
        Arrays.sort(ascendingWeights);

        float acceptedSum = 0.0f;
        float shift = 0.0f;
        for (int k = 1; k <= ruleCount; k++) {
            float kthLargest = ascendingWeights[ruleCount - k];
            float candidateSum = acceptedSum + kthLargest;
            float candidateShift = (candidateSum - 1.0f) / k;
            if (candidateShift >= kthLargest) {
                break;
            }
            acceptedSum = candidateSum;
            shift = candidateShift;
        }

        Iterator<WeightedRule> rules = this.mutableRules.iterator();
        while (rules.hasNext()) {
            WeightedRule rule = rules.next();
            rule.setWeight(Math.max(0.0f, rule.getWeight() - shift));
        }
    }

    /**
     * Divides every rule weight by the sum of all rule weights.
     *
     * <p>NOTE(review): a zero sum is not guarded against and would yield non-finite weights.
     */
    private void simplexScaleWeights() {
        float weightSum = 0.0f;
        for (WeightedRule rule : this.mutableRules) {
            weightSum += rule.getWeight();
        }

        Iterator<WeightedRule> rules = this.mutableRules.iterator();
        while (rules.hasNext()) {
            WeightedRule rule = rules.next();
            rule.setWeight(rule.getWeight() / weightSum);
        }
    }

    /**
     * Runs MAP inference starting from a previously saved state and refreshes that state.
     *
     * <p>The saved term state is loaded, the saved values of all non-fixed atoms are written back
     * and synced, inference is run, and both snapshots are then overwritten with the new state.
     *
     * <p>Decompiler note: the access modifier was changed from {@code protected}.
     *
     * @param inferenceApplication application to run inference on
     * @param termStateArr saved term state; loaded first and overwritten afterwards
     * @param fArr saved atom values; read for non-fixed atoms and overwritten afterwards
     */
    public void computeMAPStateWithWarmStart(InferenceApplication inferenceApplication, TermState[] termStateArr, float[] fArr) {
        inferenceApplication.getTermStore().loadState(termStateArr);

        AtomStore store = inferenceApplication.getDatabase().getAtomStore();
        float[] liveValues = store.getAtomValues();
        for (int atomIndex = 0; atomIndex < store.size(); atomIndex++) {
            if (store.getAtom(atomIndex).isFixed()) {
                continue;
            }
            liveValues[atomIndex] = fArr[atomIndex];
        }
        store.sync();

        computeMAPState(inferenceApplication);

        inferenceApplication.getTermStore().saveState(termStateArr);
        // Fetched again after inference instead of reusing liveValues.
        float[] inferredValues = inferenceApplication.getDatabase().getAtomStore().getAtomValues();
        System.arraycopy(inferredValues, 0, fArr, 0, inferredValues.length);
    }

    /**
     * Fills {@code fArr} with the summed incompatibility per mutable rule, evaluated on the
     * current train atom values. Terms whose rule is not weighted, or is not present in
     * {@code ruleIndexMap}, are skipped.
     *
     * <p>Decompiler note: the access modifier was changed from {@code protected}.
     *
     * @param fArr output buffer indexed like {@code mutableRules}; zeroed before accumulation
     */
    public void computeCurrentIncompatibility(float[] fArr) {
        Arrays.fill(fArr, 0.0f);

        float[] atomValues = this.trainInferenceApplication.getDatabase().getAtomStore().getAtomValues();
        for (Iterator termIterator = this.trainInferenceApplication.getTermStore().iterator(); termIterator.hasNext(); ) {
            ReasonerTerm term = (ReasonerTerm) termIterator.next();
            if (!(term.getRule() instanceof WeightedRule)) {
                continue;
            }

            Integer ruleIndex = this.ruleIndexMap.get((WeightedRule) term.getRule());
            if (ruleIndex == null) {
                continue;
            }

            fArr[ruleIndex.intValue()] += term.evaluateIncompatibility(atomValues);
        }
    }

    /**
     * Hook called once per iteration by doLearn(), after the optional evaluations and before the
     * loss and gradients are computed. Presumably gathers whatever the loss and gradient
     * computations need -- confirm against the subclasses.
     */
    protected abstract void computeIterationStatistics();

    /**
     * Computes the full objective: the learning loss plus the regularization, logging both parts
     * at trace level.
     *
     * @return learning loss plus regularization
     */
    protected float computeTotalLoss() {
        final float learningLoss = computeLearningLoss();
        final float regularization = computeRegularization();

        log.trace("Learning Loss: {}, Regularization: {}", Float.valueOf(learningLoss), Float.valueOf(regularization));

        float totalLoss = learningLoss + regularization;
        return totalLoss;
    }

    /**
     * Returns the learning-loss part of the objective; computeTotalLoss() adds the value of
     * computeRegularization() to it.
     */
    protected abstract float computeLearningLoss();

    /**
     * Sums the regularization over all mutable rules. For a weight {@code w}, with
     * {@code logW = max(log(w), log(1e-8))}, each rule contributes
     * {@code l2 * w^2 - logCoefficient * logW + entropyCoefficient * w * logW}.
     *
     * @return total regularization value
     */
    protected float computeRegularization() {
        final double logFloor = Math.log(1.0E-8d);

        float total = 0.0f;
        for (WeightedRule rule : this.mutableRules) {
            float weight = rule.getWeight();
            float clampedLogWeight = (float) Math.max(Math.log(weight), logFloor);

            float l2Term = this.l2Regularization * ((float) Math.pow(weight, 2.0d));
            float logTerm = this.logRegularization * clampedLogWeight;
            float entropyTerm = this.entropyRegularization * weight * clampedLogWeight;

            total += (l2Term - logTerm) + entropyTerm;
        }
        return total;
    }

    /**
     * Rebuilds {@code weightGradient} from scratch: zeroes it, then adds the learning-loss
     * gradient followed by the regularization gradient.
     */
    protected void computeTotalWeightGradient() {
        for (int r = 0; r < this.weightGradient.length; r++) {
            this.weightGradient[r] = 0.0f;
        }

        addLearningLossWeightGradient();
        addRegularizationWeightGradient();
    }

    /**
     * Called by computeTotalWeightGradient() right after weightGradient has been zeroed and before
     * the regularization gradient is added. Implementations are presumably expected to accumulate
     * the learning-loss gradient into weightGradient -- confirm against the subclasses.
     */
    protected abstract void addLearningLossWeightGradient();

    /**
     * Adds the gradient of the regularization to {@code weightGradient}. For a weight {@code w},
     * with {@code wFloor = max(w, 1e-8)}, entry {@code i} is increased by
     * {@code 2 * l2 * w - logCoefficient / wFloor + entropyCoefficient * (log(wFloor) + 1)}.
     *
     * <p>The decompiled body read the accumulator through undefined registers
     * ({@code r0[r1]}), the residue of a compound array assignment; it is restored here as
     * {@code weightGradient[i] += ...}.
     */
    protected void addRegularizationWeightGradient() {
        for (int i = 0; i < this.mutableRules.size(); i++) {
            float weight = this.mutableRules.get(i).getWeight();
            // The floor keeps both the logarithm and the division finite for zero weights.
            double flooredWeight = Math.max(weight, 1.0E-8d);
            float logWeight = (float) Math.log(flooredWeight);

            this.weightGradient[i] += (((2.0f * this.l2Regularization) * weight) - (this.logRegularization / flooredWeight))
                    + (this.entropyRegularization * (logWeight + 1.0f));
        }
    }

    /**
     * Called once per iteration by doLearn(), after computeTotalWeightGradient() and before the
     * gradient step. Presumably fills the atom-gradient buffers (e.g. deepAtomGradient, which
     * atomGradientStep() consumes) -- confirm against the subclasses.
     */
    protected abstract void computeTotalAtomGradient();
}
