package org.linqs.psl.reasoner.admm;

import java.util.List;
import org.linqs.psl.application.learning.weight.TrainingMap;
import org.linqs.psl.config.Options;
import org.linqs.psl.evaluation.EvaluationInstance;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.reasoner.Reasoner;
import org.linqs.psl.reasoner.admm.term.ADMMObjectiveTerm;
import org.linqs.psl.reasoner.admm.term.ADMMTermStore;
import org.linqs.psl.reasoner.term.TermStore;
import org.linqs.psl.util.Logger;
import org.linqs.psl.util.Parallel;

/**
 * Minimizes the objective held in an {@link ADMMTermStore} with a consensus-style alternating
 * direction method of multipliers (ADMM).
 *
 * <p>Every iteration runs two parallel phases:
 * <ol>
 *   <li><b>Term phase</b> ({@link TermWorker}): each active term first updates its Lagrange
 *       multipliers and then minimizes its local copies of its variables, both against the current
 *       consensus values.</li>
 *   <li><b>Variable phase</b> ({@link VariableWorker}): each variable's consensus value becomes the
 *       mean of {@code localValue + lagrange / stepSize} over the active terms that reference it,
 *       clipped to [0, 1]. Fixed atoms keep their current value. The primal/dual residuals and the
 *       norms used to build the stopping tolerances are accumulated during this pass.</li>
 * </ol>
 */
public class ADMMReasoner extends Reasoner<ADMMObjectiveTerm> {
    private static final Logger log = Logger.getLogger(ADMMReasoner.class);

    // Consensus values are clipped to this range unless the atom is fixed.
    private static final float LOWER_BOUND = 0.0f;
    private static final float UPPER_BOUND = 1.0f;

    // Configuration, read once from Options in the constructor.
    private final int computePeriod;
    private final float stepSize;
    private final boolean primalDualBreak;
    private final double epsilonRel;
    private final double epsilonAbs;

    // Per-iteration convergence state. The *Norm fields and the residuals are accumulated by the
    // variable workers through updateIterationVariables() and finalized in optimize().
    private double primalRes;
    private double epsilonPrimal;
    private double dualRes;
    private double epsilonDual;
    private double axNorm;
    private double ayNorm;
    private double bzNorm;

    // Number of terms / variables handled by one parallel work unit.
    private long termBlockSize;
    private long variableBlockSize;

    /**
     * Processes one block of terms. Each active term updates its Lagrange multipliers and is then
     * minimized against the consensus values.
     */
    private class TermWorker extends Parallel.Worker<Long> {
        private final ADMMTermStore termStore;
        private final long blockSize;
        private final float[] consensusValues;

        public TermWorker(ADMMTermStore termStore, long blockSize) {
            this.termStore = termStore;
            this.blockSize = blockSize;
            this.consensusValues = termStore.getVariableValues();
        }

        @Override
        public Object clone() {
            return new TermWorker(termStore, blockSize);
        }

        /**
         * Handles terms {@code [blockIndex * blockSize, (blockIndex + 1) * blockSize)}, truncated
         * at the size of the term store.
         */
        @Override
        public void work(long blockIndex, Long ignored) {
            long numTerms = termStore.size();
            long start = blockIndex * blockSize;
            long end = Math.min(start + blockSize, numTerms);

            for (long termIndex = start; termIndex < end; termIndex++) {
                ADMMObjectiveTerm term = termStore.get(termIndex);
                if (!term.isActive()) {
                    continue;
                }

                // The multiplier update uses the local values from the previous minimize() call.
                term.updateLagrange(stepSize, consensusValues);
                term.minimize(stepSize, consensusValues);
            }
        }
    }

    /**
     * Processes one block of variables. It recomputes each variable's consensus value from the
     * active terms that reference it and accumulates this block's contribution to the convergence
     * statistics. Each variable index belongs to exactly one block, so the writes to the shared
     * consensus array do not overlap between blocks.
     */
    private class VariableWorker extends Parallel.Worker<Long> {
        private final ADMMTermStore termStore;
        private final long blockSize;
        private final int numVariables;
        private final float[] consensusValues;
        private final GroundAtom[] consensusAtoms;

        public VariableWorker(ADMMTermStore termStore, long blockSize, int numVariables) {
            this.termStore = termStore;
            this.blockSize = blockSize;
            this.numVariables = numVariables;
            this.consensusValues = termStore.getVariableValues();
            this.consensusAtoms = termStore.getVariableAtoms();
        }

        @Override
        public Object clone() {
            return new VariableWorker(termStore, blockSize, numVariables);
        }

        /**
         * Handles variables {@code [blockIndex * blockSize, (blockIndex + 1) * blockSize)},
         * truncated at {@code numVariables}. The partial sums are merged into the reasoner with a
         * single synchronized call at the end.
         */
        @Override
        public void work(long blockIndex, Long ignored) {
            // Partial sums for this block (squared quantities; square roots are taken later).
            double primalResInc = 0.0;
            double dualResInc = 0.0;
            double axNormInc = 0.0;
            double bzNormInc = 0.0;
            double ayNormInc = 0.0;

            long start = blockIndex * blockSize;
            long end = Math.min(start + blockSize, numVariables);

            for (long index = start; index < end; index++) {
                int variableIndex = (int) index;

                List<ADMMTermStore.LocalRecord> localRecords = termStore.getLocalRecords(variableIndex);
                if (localRecords == null) {
                    continue;
                }

                // Sum value + lagrange / stepSize over the active local copies of this variable.
                double total = 0.0;
                int numActiveLocals = 0;
                for (ADMMTermStore.LocalRecord record : localRecords) {
                    ADMMObjectiveTerm term = termStore.get(record.termIndex);
                    if (!term.isActive()) {
                        continue;
                    }

                    float value = term.getVariableValue(record.variableIndex);
                    float lagrange = term.getVariableLagrange(record.variableIndex);
                    total += value + (lagrange / stepSize);
                    axNormInc += value * value;
                    ayNormInc += lagrange * lagrange;
                    numActiveLocals++;
                }

                if (numActiveLocals == 0) {
                    continue;
                }

                float newValue;
                if (consensusAtoms[variableIndex].isFixed()) {
                    newValue = consensusValues[variableIndex];
                } else {
                    newValue = Math.max(Math.min((float) (total / numActiveLocals), UPPER_BOUND), LOWER_BOUND);
                }

                // Dual residual: change in the consensus value, weighted by the number of local copies.
                float consensusChange = consensusValues[variableIndex] - newValue;
                dualResInc += consensusChange * consensusChange * numActiveLocals;
                bzNormInc += newValue * newValue * numActiveLocals;
                consensusValues[variableIndex] = newValue;

                // Primal residual: disagreement between each active local copy and the new consensus.
                for (ADMMTermStore.LocalRecord record : localRecords) {
                    ADMMObjectiveTerm term = termStore.get(record.termIndex);
                    if (term.isActive()) {
                        float disagreement = term.getVariableValue(record.variableIndex) - newValue;
                        primalResInc += disagreement * disagreement;
                    }
                }
            }

            updateIterationVariables(primalResInc, dualResInc, axNormInc, bzNormInc, ayNormInc);
        }
    }

    /** Reads the ADMM configuration (iterations, step size, tolerances, ...) from {@link Options}. */
    public ADMMReasoner() {
        this.maxIterations = Options.ADMM_MAX_ITER.getInt();
        this.primalDualBreak = Options.ADMM_PRIMAL_DUAL_BREAK.getBoolean();
        this.stepSize = Options.ADMM_STEP_SIZE.getFloat();
        this.computePeriod = Options.ADMM_COMPUTE_PERIOD.getInt();
        this.epsilonAbs = Options.ADMM_EPSILON_ABS.getDouble();
        this.epsilonRel = Options.ADMM_EPSILON_REL.getDouble();
    }

    /**
     * Runs ADMM iterations until {@link #breakOptimization} asks to stop.
     *
     * <p>The objective is computed every {@code computePeriod} iterations and on the iteration
     * that would stop the loop. If that objective still reports violated constraints and the
     * iteration count is at most {@code (int) (maxIterations * budget)}, the loop keeps going.
     *
     * @param termStore the terms to optimize; must be an {@link ADMMTermStore}
     * @param evaluations evaluation instances passed through to {@code evaluate(...)}
     * @param trainingMap training map passed through to {@code evaluate(...)}
     * @return the objective value from the last computed objective result
     * @throws IllegalArgumentException if {@code termStore} is not an {@link ADMMTermStore}
     */
    @Override
    public double optimize(TermStore<ADMMObjectiveTerm> termStore, List<EvaluationInstance> evaluations, TrainingMap trainingMap) {
        if (!(termStore instanceof ADMMTermStore)) {
            throw new IllegalArgumentException("ADMMReasoner requires an ADMMTermStore (found " + termStore.getClass().getName() + ").");
        }

        ADMMTermStore admmTermStore = (ADMMTermStore) termStore;
        admmTermStore.initForOptimization();
        initForOptimization(admmTermStore);

        long numTerms = admmTermStore.size();
        int numVariables = admmTermStore.getNumVariables();

        // Aim for roughly four work units per thread.
        long targetBlocks = Parallel.getNumThreads() * 4L;
        termBlockSize = numTerms / targetBlocks + 1;
        variableBlockSize = numVariables / targetBlocks + 1;

        // Exact ceiling division. The previous Math.ceil(long / long) truncated before rounding,
        // so the final partial block of terms/variables was never processed.
        long numTermBlocks = numBlocks(numTerms, termBlockSize);
        long numVariableBlocks = numBlocks(numVariables, variableBlockSize);

        // Absolute part of both stopping tolerances.
        double absoluteTolerance = Math.sqrt(admmTermStore.getNumLocalVariables()) * epsilonAbs;

        Reasoner.ObjectiveResult objective = null;
        Reasoner.ObjectiveResult oldObjective = null;
        boolean stop = false;
        long totalTime = 0;
        int iteration = 1;

        while (!stop) {
            long iterationStart = System.currentTimeMillis();

            primalRes = 0.0;
            dualRes = 0.0;
            axNorm = 0.0;
            ayNorm = 0.0;
            bzNorm = 0.0;

            // The accumulators are read right after these calls return. This assumes
            // Parallel.count blocks until every work unit has finished -- TODO confirm in Parallel.
            Parallel.count(numTermBlocks, new TermWorker(admmTermStore, termBlockSize));
            Parallel.count(numVariableBlocks, new VariableWorker(admmTermStore, variableBlockSize, numVariables));

            primalRes = Math.sqrt(primalRes);
            dualRes = stepSize * Math.sqrt(dualRes);
            epsilonPrimal = absoluteTolerance + epsilonRel * Math.max(Math.sqrt(axNorm), Math.sqrt(bzNorm));
            epsilonDual = absoluteTolerance + epsilonRel * Math.sqrt(ayNorm);

            long iterationEnd = System.currentTimeMillis();
            totalTime += iterationEnd - iterationStart;

            stop = breakOptimization(iteration, admmTermStore, objective, oldObjective);

            if (iteration % computePeriod == 0 || stop) {
                oldObjective = objective;
                objective = parallelComputeObjective(admmTermStore);

                // Keep iterating while constraints are violated, up to the budgeted iteration count.
                if (objective.violatedConstraints > 0 && iteration <= (int) (maxIterations * budget)) {
                    stop = false;
                }

                log.trace("Iteration {} -- Objective: {}, Violated Constraints: {}, Primal: {}, Dual: {}, Epsilon Primal: {}, Epsilon Dual: {}, Iteration Time: {}, Total Optimization Time: {}.",
                        Integer.valueOf(iteration), Float.valueOf(objective.objective), Long.valueOf(objective.violatedConstraints),
                        Double.valueOf(primalRes), Double.valueOf(dualRes), Double.valueOf(epsilonPrimal), Double.valueOf(epsilonDual),
                        Long.valueOf(iterationEnd - iterationStart), Long.valueOf(totalTime));

                evaluate(admmTermStore, iteration, evaluations, trainingMap);
            }

            iteration++;
        }

        optimizationComplete(admmTermStore, objective, totalTime, iteration - 1);
        return objective.objective;
    }

    /**
     * Decides whether to stop after {@code iteration}.
     *
     * <p>Returns true if the superclass criteria say to stop. Otherwise, unless
     * {@code runFullIterations} is set, it returns true only when all of the following hold:
     * primal/dual breaking is enabled, this is past the first iteration, the latest objective
     * (if any) reports no violated constraints, and both residuals are strictly below their
     * tolerances.
     */
    @Override
    public boolean breakOptimization(int iteration, TermStore<ADMMObjectiveTerm> termStore, Reasoner.ObjectiveResult objective, Reasoner.ObjectiveResult oldObjective) {
        if (super.breakOptimization(iteration, termStore, objective, oldObjective)) {
            return true;
        }

        if (runFullIterations) {
            return false;
        }

        if (objective != null && objective.violatedConstraints > 0) {
            return false;
        }

        if (!primalDualBreak || iteration <= 1) {
            return false;
        }

        if (primalRes >= epsilonPrimal || dualRes >= epsilonDual) {
            return false;
        }

        log.trace("Breaking optimization. Primal residual: {} below tolerance: {} and dual residual: {} below tolerance: {}.",
                Double.valueOf(primalRes), Double.valueOf(epsilonPrimal), Double.valueOf(dualRes), Double.valueOf(epsilonDual));
        return true;
    }

    /**
     * Adds one variable block's partial sums to the iteration accumulators. It is synchronized
     * because several variable workers call it during the same iteration.
     *
     * <p>Parameter order is (primal, dual, Ax, Bz, Ay).
     */
    public synchronized void updateIterationVariables(double primalResInc, double dualResInc, double axNormInc, double bzNormInc, double ayNormInc) {
        primalRes += primalResInc;
        dualRes += dualResInc;
        axNorm += axNormInc;
        ayNorm += ayNormInc;
        bzNorm += bzNormInc;
    }

    /**
     * Returns the number of blocks of {@code blockSize} items needed to cover {@code count}
     * items (ceiling division). {@code blockSize} must be positive.
     */
    private static long numBlocks(long count, long blockSize) {
        return (count + blockSize - 1) / blockSize;
    }
}
