package org.linqs.psl.reasoner.sgd;

import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import org.linqs.psl.application.learning.weight.TrainingMap;
import org.linqs.psl.config.Options;
import org.linqs.psl.evaluation.EvaluationInstance;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.reasoner.Reasoner;
import org.linqs.psl.reasoner.sgd.term.SGDObjectiveTerm;
import org.linqs.psl.reasoner.term.TermStore;
import org.linqs.psl.util.ArrayUtils;
import org.linqs.psl.util.Logger;
import org.linqs.psl.util.MathUtils;

/**
 * A reasoner that minimizes the objective defined by {@link SGDObjectiveTerm}s using projected
 * stochastic gradient descent.
 *
 * <p>Every active term takes a step on the variables it touches as soon as it is visited. After
 * each step the variable value is clamped to the interval [0, 1], and fixed atoms are never
 * modified. The per-variable step can be rescaled by an {@link SGDExtension}, and the base
 * learning rate follows an {@link SGDLearningSchedule}.
 *
 * <p>Terms move the variables while a pass is still in progress. Because of that, the objective
 * and full gradient reported for a pass are measured against a snapshot of the values taken at
 * the end of the previous pass. The best snapshot seen is what gets written back to the term
 * store when optimization finishes.
 */
public class SGDReasoner extends Reasoner<SGDObjectiveTerm> {
    private static final Logger log = Logger.getLogger(SGDReasoner.class);

    /** Small constant added to the AdaGrad/Adam denominators to avoid dividing by zero. */
    private static final float EPSILON = 1.0E-8f;

    /** Whether optimization may stop early once the gradient norm is within tolerance of zero. */
    private final boolean firstOrderBreak;
    /** Tolerance used when comparing the gradient norm against zero. */
    private final float firstOrderTolerance;
    /** The {@code p} passed to {@link MathUtils#pNorm} when measuring the gradient. */
    private final float firstOrderNorm;

    /**
     * Full gradient at the previous iterate, accumulated during the current pass.
     * Allocated after the first pass of {@link #optimize} and released in {@link #optimizationComplete}.
     */
    private float[] prevGradient;

    private final float adamBeta1;
    private final float adamBeta2;
    /** AdaGrad state: per-variable running sum of squared partial derivatives. */
    private float[] accumulatedGradientSquares;
    /** Adam state: per-variable exponential moving average of the partial derivative. */
    private float[] accumulatedGradientMean;
    /** Adam state: per-variable exponential moving average of the squared partial derivative. */
    private float[] accumulatedGradientVariance;

    private final float initialLearningRate;
    /** Exponent applied to the iteration number by the STEPDECAY schedule. */
    private final float learningRateInverseScaleExp;
    /** If true, a term's inner potential is recomputed after each variable it moves. */
    private final boolean coordinateStep;
    private final SGDLearningSchedule learningSchedule;
    private final SGDExtension sgdExtension;

    /** Optional per-variable adaptive scaling applied to each gradient step. */
    public enum SGDExtension {
        /** Plain SGD: the step is the learning rate times the partial derivative. */
        NONE,
        /** AdaGrad: scale by the inverse square root of the accumulated squared partials. */
        ADAGRAD,
        /** Adam: use bias-corrected moving averages of the partial and of its square. */
        ADAM
    }

    /** How the base learning rate changes across iterations. */
    public enum SGDLearningSchedule {
        /** Always use the initial learning rate. */
        CONSTANT,
        /** Divide the initial learning rate by {@code iteration ^ inverseScaleExp}. */
        STEPDECAY
    }

    /**
     * Creates a reasoner configured from the SGD entries in {@link Options}.
     *
     * @throws IllegalArgumentException if the configured learning schedule or extension does not
     *     name a constant of {@link SGDLearningSchedule} / {@link SGDExtension}
     */
    public SGDReasoner() {
        this.maxIterations = Options.SGD_MAX_ITER.getInt();

        this.firstOrderBreak = Options.SGD_FIRST_ORDER_BREAK.getBoolean();
        this.firstOrderTolerance = Options.SGD_FIRST_ORDER_THRESHOLD.getFloat();
        this.firstOrderNorm = Options.SGD_FIRST_ORDER_NORM.getFloat();

        this.initialLearningRate = Options.SGD_LEARNING_RATE.getFloat();
        this.learningRateInverseScaleExp = Options.SGD_INVERSE_TIME_EXP.getFloat();
        // Locale.ROOT keeps enum parsing independent of the JVM default locale
        // (e.g. Turkish casing rules map 'i' to a dotted capital I).
        this.learningSchedule = SGDLearningSchedule.valueOf(
                Options.SGD_LEARNING_SCHEDULE.getString().toUpperCase(Locale.ROOT));
        this.coordinateStep = Options.SGD_COORDINATE_STEP.getBoolean();

        this.sgdExtension = SGDExtension.valueOf(
                Options.SGD_EXTENSION.getString().toUpperCase(Locale.ROOT));
        this.adamBeta1 = Options.SGD_ADAM_BETA_1.getFloat();
        this.adamBeta2 = Options.SGD_ADAM_BETA_2.getFloat();

        this.prevGradient = null;
        this.accumulatedGradientSquares = null;
        this.accumulatedGradientMean = null;
        this.accumulatedGradientVariance = null;
    }

    /**
     * Runs SGD passes until {@link #breakOptimization} signals a stop, then writes the best
     * iterate found back into the term store.
     *
     * <p>Pass 1 only moves the variables. Every later pass also evaluates the objective and
     * accumulates the full gradient at the snapshot taken at the end of the previous pass.
     *
     * @param termStore the terms and variables to optimize; its variable values are updated in place
     * @param evaluations forwarded unchanged to the per-iteration {@code evaluate} hook
     * @param trainingMap forwarded unchanged to the per-iteration {@code evaluate} hook
     * @return the objective of the iterate written back to the term store (the lowest one seen)
     */
    @Override // org.linqs.psl.reasoner.Reasoner
    public double optimize(TermStore<SGDObjectiveTerm> termStore, List<EvaluationInstance> evaluations, TrainingMap trainingMap) {
        termStore.initForOptimization();
        initForOptimization(termStore);

        // Objective of the snapshot measured in the previous pass (infinite until one exists).
        float oldObjective = Float.POSITIVE_INFINITY;
        // Variable values at the start of the current pass; objective and gradient are measured here.
        float[] snapshotValues = null;

        float bestObjective = Float.POSITIVE_INFINITY;
        float[] bestValues = null;

        long totalTime = 0;
        boolean stop = false;
        int iteration = 1;

        while (!stop) {
            long startTime = System.currentTimeMillis();
            float objective = 0.0f;
            float learningRate = calculateAnnealedLearningRate(iteration);

            if (iteration > 1) {
                Arrays.fill(prevGradient, 0.0f);
            }

            Iterator<SGDObjectiveTerm> terms = termStore.iterator();
            while (terms.hasNext()) {
                SGDObjectiveTerm term = terms.next();
                if (!term.isActive()) {
                    continue;
                }

                if (iteration > 1) {
                    objective += term.evaluate(snapshotValues);
                    addTermGradient(term, prevGradient, snapshotValues, termStore.getVariableAtoms());
                }

                variableUpdate(term, termStore, iteration, learningRate);
            }

            evaluate(termStore, iteration, evaluations, trainingMap);

            if (iteration == 1) {
                // First pass: nothing was measured yet, just allocate the snapshot buffers.
                float[] currentValues = termStore.getVariableValues();
                prevGradient = new float[currentValues.length];
                snapshotValues = Arrays.copyOf(currentValues, currentValues.length);
                bestValues = Arrays.copyOf(currentValues, currentValues.length);
            } else {
                // NOTE(review): inherited hook; presumably projects the gradient against the
                // variable bounds at the snapshot — confirm in Reasoner.
                clipGradient(snapshotValues, prevGradient);
                stop = breakOptimization(iteration, termStore,
                        new Reasoner.ObjectiveResult(objective, 0L),
                        new Reasoner.ObjectiveResult(oldObjective, 0L));

                // The objective just computed belongs to the snapshot, not to the current values.
                if (objective < bestObjective) {
                    bestObjective = objective;
                    System.arraycopy(snapshotValues, 0, bestValues, 0, bestValues.length);
                }

                System.arraycopy(termStore.getVariableValues(), 0, snapshotValues, 0, snapshotValues.length);
                oldObjective = objective;
            }

            long endTime = System.currentTimeMillis();
            totalTime += endTime - startTime;

            if (iteration > 1) {
                // Reported as iteration - 1 because the measured values are the previous pass's output.
                log.trace("Iteration {} -- Objective: {}, Violated Constraints: 0, Gradient Norm: {}, Iteration Time: {}, Total Optimization Time: {}",
                        iteration - 1, objective, MathUtils.pNorm(prevGradient, firstOrderNorm),
                        endTime - startTime, totalTime);
            }

            iteration++;
        }

        // The final pass's output was never measured inside the loop; check it explicitly.
        Reasoner.ObjectiveResult finalObjective = computeObjective(termStore);
        if (finalObjective.objective < bestObjective) {
            bestObjective = finalObjective.objective;
            bestValues = snapshotValues;
        }

        float[] variableValues = termStore.getVariableValues();
        System.arraycopy(bestValues, 0, variableValues, 0, variableValues.length);

        optimizationComplete(termStore, new Reasoner.ObjectiveResult(bestObjective, 0L), totalTime, iteration - 1);
        return bestObjective;
    }

    /**
     * Allocates the per-variable state needed by the configured {@link SGDExtension}.
     * The arrays are sized by the term store's unobserved variable count.
     *
     * @throws IllegalArgumentException if the extension is not handled
     */
    @Override // org.linqs.psl.reasoner.Reasoner
    protected void initForOptimization(TermStore<SGDObjectiveTerm> termStore) {
        super.initForOptimization(termStore);

        switch (this.sgdExtension) {
            case NONE:
                return;
            case ADAGRAD:
                accumulatedGradientSquares = new float[termStore.getVariableCounts().unobserved];
                return;
            case ADAM: {
                int unobservedCount = termStore.getVariableCounts().unobserved;
                accumulatedGradientMean = new float[unobservedCount];
                accumulatedGradientVariance = new float[unobservedCount];
                return;
            }
            default:
                throw new IllegalArgumentException(String.format("Unsupported SGD Extensions: '%s'", this.sgdExtension));
        }
    }

    /** Finishes via the parent class, then drops all per-run gradient and extension state. */
    @Override // org.linqs.psl.reasoner.Reasoner
    protected void optimizationComplete(TermStore<SGDObjectiveTerm> termStore, Reasoner.ObjectiveResult objectiveResult, long totalTime, int iterations) {
        super.optimizationComplete(termStore, objectiveResult, totalTime, iterations);

        prevGradient = null;
        accumulatedGradientSquares = null;
        accumulatedGradientMean = null;
        accumulatedGradientVariance = null;
    }

    /**
     * Decides whether to stop optimizing.
     *
     * <p>Stops if the parent class says so. Otherwise, unless full iterations are forced, stops
     * when first-order breaking is enabled, no constraints are violated, and the p-norm of
     * {@link #prevGradient} is within {@code firstOrderTolerance} of zero.
     */
    @Override // org.linqs.psl.reasoner.Reasoner
    protected boolean breakOptimization(int iteration, TermStore<SGDObjectiveTerm> termStore, Reasoner.ObjectiveResult objective, Reasoner.ObjectiveResult oldObjective) {
        if (super.breakOptimization(iteration, termStore, objective, oldObjective)) {
            return true;
        }

        if (this.runFullIterations) {
            return false;
        }

        if (objective != null && objective.violatedConstraints > 0) {
            return false;
        }

        if (!firstOrderBreak) {
            return false;
        }

        // Computed once and reused for both the check and the log message.
        float gradientNorm = MathUtils.pNorm(prevGradient, firstOrderNorm);
        if (!MathUtils.equals(gradientNorm, 0.0f, firstOrderTolerance)) {
            return false;
        }

        log.trace("Breaking optimization. Gradient magnitude: {} below tolerance: {}.", gradientNorm, firstOrderTolerance);
        return true;
    }

    /**
     * Adds the term's partial derivatives, evaluated at {@code variableValues}, into
     * {@code gradient} for every non-fixed atom the term touches.
     */
    private void addTermGradient(SGDObjectiveTerm term, float[] gradient, float[] variableValues, GroundAtom[] variableAtoms) {
        int size = term.size();
        int[] atomIndexes = term.getAtomIndexes();
        float innerPotential = term.computeInnerPotential(variableValues);

        for (int i = 0; i < size; i++) {
            int atomIndex = atomIndexes[i];
            if (variableAtoms[atomIndex].isFixed()) {
                continue;
            }
            gradient[atomIndex] += term.computeVariablePartial(i, innerPotential);
        }
    }

    /**
     * Returns the base learning rate for the given (1-based) iteration.
     *
     * @throws IllegalArgumentException if the schedule is not handled
     */
    private float calculateAnnealedLearningRate(int iteration) {
        switch (learningSchedule) {
            case CONSTANT:
                return initialLearningRate;
            case STEPDECAY:
                return initialLearningRate / ((float) Math.pow(iteration, learningRateInverseScaleExp));
            default:
                throw new IllegalArgumentException(String.format("Illegal value found for SGD learning schedule: '%s'", learningSchedule));
        }
    }

    /**
     * Takes one gradient step, in place, on every non-fixed variable of {@code term}, clamping
     * each result to [0, 1].
     *
     * <p>With {@code coordinateStep} enabled, the inner potential is recomputed after each
     * variable moves. Later partials in the same term then see the updated values.
     */
    private void variableUpdate(SGDObjectiveTerm term, TermStore<SGDObjectiveTerm> termStore, int iteration, float learningRate) {
        GroundAtom[] variableAtoms = termStore.getVariableAtoms();
        float[] variableValues = termStore.getVariableValues();
        int size = term.size();
        int[] atomIndexes = term.getAtomIndexes();

        float innerPotential = term.computeInnerPotential(variableValues);
        for (int i = 0; i < size; i++) {
            int atomIndex = atomIndexes[i];
            if (variableAtoms[atomIndex].isFixed()) {
                continue;
            }

            float partial = term.computeVariablePartial(i, innerPotential);
            float step = computeVariableStep(atomIndex, iteration, learningRate, partial);
            variableValues[atomIndex] = Math.max(0.0f, Math.min(1.0f, variableValues[atomIndex] - step));

            if (coordinateStep) {
                innerPotential = term.computeInnerPotential(variableValues);
            }
        }
    }

    /**
     * Computes the amount to subtract from one variable, updating the extension's per-variable
     * state for that variable as a side effect.
     *
     * @param variableIndex index of the variable in the term store's value array
     * @param iteration 1-based iteration number (used for Adam bias correction)
     * @param learningRate base learning rate for this iteration
     * @param partial partial derivative of the current term with respect to the variable
     * @throws IllegalArgumentException if the extension is not handled
     */
    private float computeVariableStep(int variableIndex, int iteration, float learningRate, float partial) {
        switch (sgdExtension) {
            case NONE:
                return partial * learningRate;
            case ADAGRAD:
                // NOTE(review): presumably grows the array so variableIndex is valid; the state is
                // sized by the unobserved count but indexed by atom index — confirm in ArrayUtils.
                accumulatedGradientSquares = ArrayUtils.ensureCapacity(accumulatedGradientSquares, variableIndex);
                accumulatedGradientSquares[variableIndex] += partial * partial;
                return partial * (learningRate / ((float) Math.sqrt(accumulatedGradientSquares[variableIndex] + EPSILON)));
            case ADAM: {
                accumulatedGradientMean = ArrayUtils.ensureCapacity(accumulatedGradientMean, variableIndex);
                accumulatedGradientMean[variableIndex] = (adamBeta1 * accumulatedGradientMean[variableIndex]) + ((1.0f - adamBeta1) * partial);

                accumulatedGradientVariance = ArrayUtils.ensureCapacity(accumulatedGradientVariance, variableIndex);
                accumulatedGradientVariance[variableIndex] = (adamBeta2 * accumulatedGradientVariance[variableIndex]) + ((1.0f - adamBeta2) * partial * partial);

                float biasCorrectedMean = accumulatedGradientMean[variableIndex] / (1.0f - ((float) Math.pow(adamBeta1, iteration)));
                float biasCorrectedVariance = accumulatedGradientVariance[variableIndex] / (1.0f - ((float) Math.pow(adamBeta2, iteration)));
                return biasCorrectedMean * (learningRate / (((float) Math.sqrt(biasCorrectedVariance)) + EPSILON));
            }
            default:
                throw new IllegalArgumentException(String.format("Unsupported SGD Extensions: '%s'", sgdExtension));
        }
    }
}
