package org.linqs.psl.reasoner.admm;

import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.linqs.psl.application.learning.weight.TrainingMap;
import org.linqs.psl.config.Options;
import org.linqs.psl.evaluation.statistics.Evaluator;
import org.linqs.psl.model.predicate.StandardPredicate;
import org.linqs.psl.reasoner.Reasoner;
import org.linqs.psl.reasoner.admm.term.ADMMObjectiveTerm;
import org.linqs.psl.reasoner.admm.term.ADMMTermStore;
import org.linqs.psl.reasoner.admm.term.LocalVariable;
import org.linqs.psl.reasoner.term.TermStore;
import org.linqs.psl.util.Logger;
import org.linqs.psl.util.MathUtils;
import org.linqs.psl.util.Parallel;

/**
 * Reasoner that minimizes the objective stored in an {@link ADMMTermStore} using consensus ADMM
 * (alternating direction method of multipliers).
 *
 * <p>Each iteration runs two parallel phases. First, every term updates its Lagrange multipliers
 * and minimizes over its local variable copies. Second, every consensus variable is set to the
 * average of its local copies, each shifted by its multiplier divided by the step size, and clipped
 * into [LOWER_BOUND, UPPER_BOUND]. The second phase also produces the primal/dual residuals used
 * for the stopping criterion.
 */
public class ADMMReasoner extends Reasoner {
    private static final Logger log = Logger.getLogger(ADMMReasoner.class);

    // Consensus values are clipped into [LOWER_BOUND, UPPER_BOUND].
    private static final float LOWER_BOUND = 0.0f;
    private static final float UPPER_BOUND = 1.0f;

    // Residuals, norms and tolerances for the current iteration. The residual/norm fields are
    // summed by the VariableWorkers via updateIterationVariables() and finalized in optimize().
    private double primalRes;
    private double epsilonPrimal;
    private double dualRes;
    private double epsilonDual;
    private double AxNorm;
    private double AyNorm;
    private double BzNorm;

    // Lagrangian penalty terms accumulated during the most recent iteration.
    private double lagrangePenalty;
    private double augmentedLagrangePenalty;

    // Number of terms / consensus variables handled by one parallel work item.
    private long termBlockSize;
    private long variableBlockSize;

    private int maxIterations = Options.ADMM_MAX_ITER.getInt();
    private final float stepSize = Options.ADMM_STEP_SIZE.getFloat();
    private int computePeriod = Options.ADMM_COMPUTE_PERIOD.getInt();
    private double epsilonAbs = Options.ADMM_EPSILON_ABS.getDouble();
    private double epsilonRel = Options.ADMM_EPSILON_REL.getDouble();

    /**
     * Result of evaluating all terms at the current consensus values: the summed value of the
     * non-constraint terms and the number of constraint terms that evaluate above zero.
     */
    public static class ObjectiveResult {
        public final double objective;
        public final long violatedConstraints;

        public ObjectiveResult(double objective, long violatedConstraints) {
            this.objective = objective;
            this.violatedConstraints = violatedConstraints;
        }
    }

    /**
     * Parallel worker for the term phase. Work item {@code blockIndex} covers terms
     * {@code [blockIndex * blockSize, (blockIndex + 1) * blockSize)}, clipped to the store size.
     */
    private class TermWorker extends Parallel.Worker<Long> {
        private final ADMMTermStore termStore;
        private final long blockSize;
        private final float[] consensusValues;
        private final boolean useNonConvex;

        public TermWorker(ADMMTermStore termStore, long blockSize, boolean useNonConvex) {
            this.termStore = termStore;
            this.blockSize = blockSize;
            this.useNonConvex = useNonConvex;
            this.consensusValues = termStore.getConsensusValues();
        }

        @Override
        public Object clone() {
            return new TermWorker(termStore, blockSize, useNonConvex);
        }

        @Override // org.linqs.psl.util.Parallel.Worker
        public void work(long blockIndex, Long ignored) {
            long numTerms = termStore.size();

            for (long innerIndex = 0; innerIndex < blockSize; innerIndex++) {
                long termIndex = blockIndex * blockSize + innerIndex;
                if (termIndex >= numTerms) {
                    return;
                }

                // Non-convex terms are only optimized during the non-convex rounds.
                ADMMObjectiveTerm term = termStore.get(termIndex);
                if (useNonConvex || term.isConvex()) {
                    term.updateLagrange(stepSize, consensusValues);
                    term.minimize(stepSize, consensusValues);
                }
            }
        }
    }

    /**
     * Parallel worker for the consensus phase. Work item {@code blockIndex} covers one contiguous
     * block of consensus variables and reports its partial residual sums through
     * {@link #updateIterationVariables}.
     */
    private class VariableWorker extends Parallel.Worker<Long> {
        private final ADMMTermStore termStore;
        private final long blockSize;
        private final float[] consensusValues;

        public VariableWorker(ADMMTermStore termStore, long blockSize) {
            this.termStore = termStore;
            this.blockSize = blockSize;
            this.consensusValues = termStore.getConsensusValues();
        }

        @Override
        public Object clone() {
            return new VariableWorker(termStore, blockSize);
        }

        @Override // org.linqs.psl.util.Parallel.Worker
        public void work(long blockIndex, Long ignored) {
            int numVariables = termStore.getNumConsensusVariables();

            double primalResInc = 0.0;
            double dualResInc = 0.0;
            double axNormInc = 0.0;
            double bzNormInc = 0.0;
            double ayNormInc = 0.0;
            double lagrangePenaltyInc = 0.0;
            double augmentedLagrangePenaltyInc = 0.0;

            for (long innerIndex = 0; innerIndex < blockSize; innerIndex++) {
                int variableIndex = (int) (blockIndex * blockSize + innerIndex);
                if (variableIndex >= numVariables) {
                    break;
                }

                List<LocalVariable> localVariables = termStore.getLocalVariables(variableIndex);
                int numLocalVariables = localVariables.size();
                if (numLocalVariables == 0) {
                    // Without local copies the average below would be 0/0 = NaN, which would
                    // poison the consensus value and every residual. Leave it unchanged.
                    continue;
                }

                // First pass: average the local copies (shifted by multiplier / step size).
                double total = 0.0;
                for (LocalVariable localVariable : localVariables) {
                    total += localVariable.getValue() + localVariable.getLagrange() / stepSize;
                    axNormInc += localVariable.getValue() * localVariable.getValue();
                    ayNormInc += localVariable.getLagrange() * localVariable.getLagrange();
                }

                float newConsensusValue =
                        Math.max(Math.min((float) (total / numLocalVariables), UPPER_BOUND), LOWER_BOUND);

                // The change counts once per local copy mapped to this consensus variable.
                float consensusDiff = consensusValues[variableIndex] - newConsensusValue;
                dualResInc += consensusDiff * consensusDiff * numLocalVariables;
                bzNormInc += newConsensusValue * newConsensusValue * numLocalVariables;

                consensusValues[variableIndex] = newConsensusValue;

                // Second pass: primal residual and Lagrangian penalties against the new value.
                for (LocalVariable localVariable : localVariables) {
                    float localDiff = localVariable.getValue() - newConsensusValue;
                    primalResInc += localDiff * localDiff;
                    lagrangePenaltyInc += localVariable.getLagrange() * localDiff;
                    augmentedLagrangePenaltyInc += 0.5d * stepSize * Math.pow(localDiff, 2.0d);
                }
            }

            updateIterationVariables(primalResInc, dualResInc, axNormInc, bzNormInc, ayNormInc,
                    lagrangePenaltyInc, augmentedLagrangePenaltyInc);
        }
    }

    public double getEpsilonRel() {
        return epsilonRel;
    }

    public void setEpsilonRel(double epsilonRel) {
        this.epsilonRel = epsilonRel;
    }

    public double getEpsilonAbs() {
        return epsilonAbs;
    }

    public void setEpsilonAbs(double epsilonAbs) {
        this.epsilonAbs = epsilonAbs;
    }

    /** Returns the Lagrangian penalty accumulated during the most recent iteration. */
    public double getLagrangianPenalty() {
        return lagrangePenalty;
    }

    /** Returns the augmented Lagrangian penalty accumulated during the most recent iteration. */
    public double getAugmentedLagrangianPenalty() {
        return augmentedLagrangePenalty;
    }

    /**
     * Runs ADMM on {@code termStore} until {@link #breakOptimization} says to stop.
     *
     * @param termStore must be an {@link ADMMTermStore}
     * @param list evaluators forwarded to {@code evaluate} every {@code computePeriod} iterations
     * @param trainingMap forwarded to {@code evaluate}
     * @param set predicates forwarded to {@code evaluate}
     * @return the summed value of the non-constraint terms at the final consensus values
     * @throws IllegalArgumentException if {@code termStore} is not an {@link ADMMTermStore}
     */
    @Override // org.linqs.psl.reasoner.Reasoner
    public double optimize(TermStore termStore, List<Evaluator> list, TrainingMap trainingMap, Set<StandardPredicate> set) {
        if (!(termStore instanceof ADMMTermStore)) {
            throw new IllegalArgumentException("ADMMReasoner requires an ADMMTermStore (found " + termStore.getClass().getName() + ").");
        }
        ADMMTermStore admmTermStore = (ADMMTermStore) termStore;
        admmTermStore.initForOptimization();

        long numTerms = admmTermStore.size();
        int numVariables = admmTermStore.getNumConsensusVariables();
        log.debug("Performing optimization with {} variables and {} terms.", numVariables, numTerms);

        // Aim for roughly four work items per thread.
        termBlockSize = numTerms / (Parallel.getNumThreads() * 4) + 1;
        variableBlockSize = numVariables / (Parallel.getNumThreads() * 4) + 1;

        // Exact ceiling division: a trailing partial block must still be scheduled.
        long numTermBlocks = numBlocks(numTerms, termBlockSize);
        long numVariableBlocks = numBlocks(numVariables, variableBlockSize);

        double epsilonAbsTerm = Math.sqrt(admmTermStore.getNumLocalVariables()) * epsilonAbs;

        ObjectiveResult objective = null;
        ObjectiveResult oldObjective = null;

        if (log.isTraceEnabled()) {
            objective = computeObjective(admmTermStore);
            log.trace("Iteration {} -- Objective: {}, Feasible: {}.",
                    0, objective.objective, objective.violatedConstraints == 0);
        }

        int iteration = 1;
        while (true) {
            resetIterationVariables();

            boolean useNonConvex = iteration >= nonconvexPeriod && iteration % nonconvexPeriod < nonconvexRounds;

            Parallel.count(numTermBlocks, new TermWorker(admmTermStore, termBlockSize, useNonConvex));
            Parallel.count(numVariableBlocks, new VariableWorker(admmTermStore, variableBlockSize));

            primalRes = Math.sqrt(primalRes);
            dualRes = stepSize * Math.sqrt(dualRes);

            epsilonPrimal = epsilonAbsTerm + epsilonRel * Math.max(Math.sqrt(AxNorm), Math.sqrt(BzNorm));
            epsilonDual = epsilonAbsTerm + epsilonRel * Math.sqrt(AyNorm);

            if (iteration % computePeriod == 0) {
                if (objectiveBreak) {
                    oldObjective = objective;
                    objective = computeObjective(admmTermStore);
                    if (log.isTraceEnabled()) {
                        log.trace("Iteration {} -- Objective: {}, Feasible: {}, Primal: {}, Dual: {}, Epsilon Primal: {}, Epsilon Dual: {}.",
                                iteration, objective.objective, objective.violatedConstraints == 0,
                                primalRes, dualRes, epsilonPrimal, epsilonDual);
                    }
                } else if (log.isTraceEnabled()) {
                    log.trace("Iteration {} -- Primal: {}, Dual: {}, Epsilon Primal: {}, Epsilon Dual: {}.",
                            iteration, primalRes, dualRes, epsilonPrimal, epsilonDual);
                }

                evaluate(admmTermStore, iteration, list, trainingMap, set);
                admmTermStore.iterationComplete();
            }

            iteration++;

            if (breakOptimization(iteration, objective, oldObjective)) {
                // Confirm against a freshly computed objective; this also guarantees that
                // `objective` is non-null and current once the loop exits.
                objective = computeObjective(admmTermStore);
                if (breakOptimization(iteration, objective, oldObjective)) {
                    break;
                }
            }
        }

        log.info("Optimization completed in {} iterations. Objective: {}, Feasible: {}, Primal res.: {}, Dual res.: {}",
                new Object[] {iteration - 1, objective.objective, objective.violatedConstraints == 0, primalRes, dualRes});

        if (objective.violatedConstraints > 0) {
            log.warn("No feasible solution found. {} constraints violated.", objective.violatedConstraints);
        }

        admmTermStore.syncAtoms();
        return objective.objective;
    }

    /** Returns how many blocks of {@code blockSize} (assumed >= 1) are needed to cover {@code count} items. */
    private static long numBlocks(long count, long blockSize) {
        return (count + blockSize - 1) / blockSize;
    }

    /** Zeroes the per-iteration accumulators that the VariableWorkers sum into. */
    private void resetIterationVariables() {
        primalRes = 0.0;
        dualRes = 0.0;
        AxNorm = 0.0;
        AyNorm = 0.0;
        BzNorm = 0.0;
        lagrangePenalty = 0.0;
        augmentedLagrangePenalty = 0.0;
    }

    /**
     * Decides whether optimization should stop before iteration {@code iteration}.
     *
     * @param objective latest objective, may be null if it has not been computed yet
     * @param oldObjective objective from the previous compute period, may be null
     */
    private boolean breakOptimization(int iteration, ObjectiveResult objective, ObjectiveResult oldObjective) {
        // The (scaled) iteration budget always wins.
        if (iteration > (int) (maxIterations * budget)) {
            return true;
        }

        if (runFullIterations) {
            return false;
        }

        // Never stop while a known objective reports violated constraints.
        if (objective != null && objective.violatedConstraints > 0) {
            return false;
        }

        // Residuals not yet within tolerance: only an unchanged objective may stop us.
        if (iteration <= 1 || primalRes >= epsilonPrimal || dualRes >= epsilonDual) {
            return objectiveBreak && oldObjective != null
                    && MathUtils.equals(objective.objective, oldObjective.objective, tolerance);
        }

        return true;
    }

    @Override // org.linqs.psl.reasoner.Reasoner
    public void close() {
    }

    /**
     * Evaluates every term at the current consensus values: non-constraint terms are summed into
     * the objective, constraint terms evaluating above zero are counted as violated.
     */
    private ObjectiveResult computeObjective(ADMMTermStore termStore) {
        double objective = 0.0;
        long violatedConstraints = 0;
        float[] consensusValues = termStore.getConsensusValues();

        Iterator<ADMMObjectiveTerm> terms = termStore.iterator();
        while (terms.hasNext()) {
            ADMMObjectiveTerm term = terms.next();
            if (!term.isConstraint()) {
                objective += term.evaluate(consensusValues);
            } else if (term.evaluate(consensusValues) > 0.0f) {
                violatedConstraints++;
            }
        }

        return new ObjectiveResult(objective, violatedConstraints);
    }

    /**
     * Adds one VariableWorker's partial sums to the iteration accumulators. Synchronized because
     * the workers scheduled through {@code Parallel.count} may call it concurrently.
     */
    public synchronized void updateIterationVariables(double primalResInc, double dualResInc,
            double axNormInc, double bzNormInc, double ayNormInc,
            double lagrangePenaltyInc, double augmentedLagrangePenaltyInc) {
        primalRes += primalResInc;
        dualRes += dualResInc;
        AxNorm += axNormInc;
        AyNorm += ayNormInc;
        BzNorm += bzNormInc;
        lagrangePenalty += lagrangePenaltyInc;
        augmentedLagrangePenalty += augmentedLagrangePenaltyInc;
    }
}
