package org.linqs.psl.reasoner.admm;

import java.util.Iterator;
import org.linqs.psl.config.Config;
import org.linqs.psl.model.rule.GroundRule;
import org.linqs.psl.model.rule.WeightedGroundRule;
import org.linqs.psl.reasoner.Reasoner;
import org.linqs.psl.reasoner.admm.term.ADMMObjectiveTerm;
import org.linqs.psl.reasoner.admm.term.ADMMTermStore;
import org.linqs.psl.reasoner.admm.term.LinearConstraintTerm;
import org.linqs.psl.reasoner.admm.term.LocalVariable;
import org.linqs.psl.reasoner.term.TermStore;
import org.linqs.psl.util.MathUtils;
import org.linqs.psl.util.Parallel;
import org.linqs.psl.util.RandUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link Reasoner} that minimizes an objective with consensus ADMM
 * (alternating direction method of multipliers).
 *
 * <p>Every {@link ADMMObjectiveTerm} owns {@link LocalVariable} copies of the global variables it
 * touches. One iteration:
 * <ol>
 *   <li>updates each term's Lagrange multipliers and minimizes the term locally
 *       ({@link TermWorker}); then</li>
 *   <li>sets every global (consensus) variable to the average of its local copies (each shifted by
 *       its Lagrange multiplier divided by the step size), clipped to [0, 1], while accumulating
 *       the primal/dual residuals ({@link VariableWorker}).</li>
 * </ol>
 * Iteration stops when both residuals fall below their tolerances, when the objective stops
 * changing between checks (if {@code objectivebreak} is set), or after {@code maxiterations}.
 *
 * <p>Settings are read from {@link Config} under the {@code admmreasoner.*} keys declared below.
 */
public class ADMMReasoner implements Reasoner {
    private static final Logger log = LoggerFactory.getLogger(ADMMReasoner.class);

    /** Prefix shared by every config key of this reasoner. */
    public static final String CONFIG_PREFIX = "admmreasoner";

    /** Maximum number of ADMM iterations. */
    public static final String MAX_ITER_KEY = CONFIG_PREFIX + ".maxiterations";
    public static final int MAX_ITER_DEFAULT = 25000;

    /** Every how many iterations the objective is computed (for logging / the objective break). */
    public static final String COMPUTE_PERIOD_KEY = CONFIG_PREFIX + ".computeperiod";
    public static final int COMPUTE_PERIOD_DEFAULT = 50;

    /** ADMM step size, also used as the augmented-Lagrangian penalty coefficient. */
    public static final String STEP_SIZE_KEY = CONFIG_PREFIX + ".stepsize";
    public static final float STEP_SIZE_DEFAULT = 1.0f;

    /** Absolute tolerance of the residual stopping criteria; must be positive. */
    public static final String EPSILON_ABS_KEY = CONFIG_PREFIX + ".epsilonabs";
    public static final float EPSILON_ABS_DEFAULT = 1.0E-5f;

    /** Relative tolerance of the residual stopping criteria; must be positive. */
    public static final String EPSILON_REL_KEY = CONFIG_PREFIX + ".epsilonrel";
    public static final float EPSILON_REL_DEFAULT = 0.001f;

    /** Whether to stop once two consecutive objective computations are equal. */
    public static final String OBJECTIVE_BREAK_KEY = CONFIG_PREFIX + ".objectivebreak";
    public static final boolean OBJECTIVE_BREAK_DEFAULT = true;

    /** How consensus values are initialized; the name of an {@link InitialValue} (case-insensitive). */
    public static final String INITIAL_CONSENSUS_VALUE_KEY = CONFIG_PREFIX + ".initialconsensusvalue";
    public static final String INITIAL_CONSENSUS_VALUE_DEFAULT = InitialValue.RANDOM.toString();

    /** How local variables are initialized; the name of an {@link InitialValue} (case-insensitive). */
    public static final String INITIAL_LOCAL_VALUE_KEY = CONFIG_PREFIX + ".initiallocalvalue";
    public static final String INITIAL_LOCAL_VALUE_DEFAULT = InitialValue.RANDOM.toString();

    // Consensus values are clipped to this box after every averaging step.
    private static final float LOWER_BOUND = 0.0f;
    private static final float UPPER_BOUND = 1.0f;

    private int maxIter = Config.getInt(MAX_ITER_KEY, MAX_ITER_DEFAULT);
    private final float stepSize = Config.getFloat(STEP_SIZE_KEY, STEP_SIZE_DEFAULT);
    private int computePeriod = Config.getInt(COMPUTE_PERIOD_KEY, COMPUTE_PERIOD_DEFAULT);
    private boolean objectiveBreak = Config.getBoolean(OBJECTIVE_BREAK_KEY, OBJECTIVE_BREAK_DEFAULT);
    private float epsilonAbs = Config.getFloat(EPSILON_ABS_KEY, EPSILON_ABS_DEFAULT);
    private float epsilonRel;

    // Per-iteration quantities. They are reset at the start of every iteration, summed by the
    // VariableWorkers through the synchronized updateIterationVariables(), then finalized.
    private float primalRes;
    private float epsilonPrimal;
    private float dualRes;
    private float epsilonDual;
    private float axNorm;
    private float ayNorm;
    private float bzNorm;
    private float lagrangePenalty;
    private float augmentedLagrangePenalty;

    // Current consensus (global) variable values, indexed by global variable id.
    private float[] consensusValues;
    private int termBlockSize;
    private int variableBlockSize;

    /** Strategies for initializing consensus or local variable values. */
    public enum InitialValue {
        ZERO,
        RANDOM,
        ATOM
    }

    /** Objective value over the non-constraint terms plus the number of violated constraints. */
    public static class ObjectiveResult {
        public final float objective;
        public final int violatedConstraints;

        public ObjectiveResult(float objective, int violatedConstraints) {
            this.objective = objective;
            this.violatedConstraints = violatedConstraints;
        }
    }

    /**
     * Processes one contiguous block of terms: updates each term's Lagrange multipliers and then
     * minimizes it against the current consensus values.
     */
    public class TermWorker extends Parallel.Worker<Integer> {
        private final ADMMTermStore termStore;
        private final int blockSize;

        public TermWorker(ADMMTermStore termStore, int blockSize) {
            this.termStore = termStore;
            this.blockSize = blockSize;
        }

        @Override
        public Object clone() {
            return new TermWorker(termStore, blockSize);
        }

        @Override // org.linqs.psl.util.Parallel.Worker
        public void work(int blockIndex, Integer ignored) {
            int start = blockIndex * blockSize;
            int end = Math.min(start + blockSize, termStore.size());
            for (int termIndex = start; termIndex < end; termIndex++) {
                ADMMObjectiveTerm term = termStore.get(termIndex);
                term.updateLagrange(stepSize, consensusValues);
                term.minimize(stepSize, consensusValues);
            }
        }
    }

    /**
     * Processes one contiguous block of global variables. It recomputes each consensus value from
     * that variable's local copies and accumulates the block's contribution to the residuals and
     * penalties.
     *
     * <p>Blocks cover disjoint index ranges, so writes to {@code consensusValues} never overlap.
     * The sums are merged into the reasoner via the synchronized
     * {@code updateIterationVariables}.
     */
    public class VariableWorker extends Parallel.Worker<Integer> {
        private final ADMMTermStore termStore;
        private final int blockSize;

        public VariableWorker(ADMMTermStore termStore, int blockSize) {
            this.termStore = termStore;
            this.blockSize = blockSize;
        }

        @Override
        public Object clone() {
            return new VariableWorker(termStore, blockSize);
        }

        @Override // org.linqs.psl.util.Parallel.Worker
        public void work(int blockIndex, Integer ignored) {
            int start = blockIndex * blockSize;
            int end = Math.min(start + blockSize, termStore.getNumGlobalVariables());

            float primalResInc = 0.0f;
            float dualResInc = 0.0f;
            float axNormInc = 0.0f;
            float bzNormInc = 0.0f;
            float ayNormInc = 0.0f;
            float lagrangePenaltyInc = 0.0f;
            float augmentedLagrangePenaltyInc = 0.0f;

            for (int variableIndex = start; variableIndex < end; variableIndex++) {
                int numLocals = termStore.getLocalVariables(variableIndex).size();

                // Sum of local values, each shifted by its scaled Lagrange multiplier.
                float total = 0.0f;
                for (int j = 0; j < numLocals; j++) {
                    LocalVariable local = termStore.getLocalVariables(variableIndex).get(j);
                    total += local.getValue() + (local.getLagrange() / stepSize);
                    axNormInc += local.getValue() * local.getValue();
                    ayNormInc += local.getLagrange() * local.getLagrange();
                }

                float newConsensus = Math.max(Math.min(total / numLocals, UPPER_BOUND), LOWER_BOUND);

                // The change in the consensus value drives the dual residual.
                float change = consensusValues[variableIndex] - newConsensus;
                dualResInc += change * change * numLocals;
                bzNormInc += newConsensus * newConsensus * numLocals;
                consensusValues[variableIndex] = newConsensus;

                for (int j = 0; j < numLocals; j++) {
                    LocalVariable local = termStore.getLocalVariables(variableIndex).get(j);
                    float residual = local.getValue() - newConsensus;
                    primalResInc += residual * residual;
                    lagrangePenaltyInc += local.getLagrange() * residual;
                    augmentedLagrangePenaltyInc = (float) (augmentedLagrangePenaltyInc
                            + 0.5d * stepSize * ((double) residual * residual));
                }
            }

            updateIterationVariables(primalResInc, dualResInc, axNormInc, bzNormInc, ayNormInc,
                    lagrangePenaltyInc, augmentedLagrangePenaltyInc);
        }
    }

    /**
     * Creates a reasoner configured from {@link Config}.
     *
     * @throws IllegalArgumentException if the absolute or relative epsilon is not positive, or the
     *     compute period is not positive
     */
    public ADMMReasoner() {
        if (epsilonAbs <= 0.0f) {
            throw new IllegalArgumentException("Property admmreasoner.epsilonabs must be positive.");
        }

        epsilonRel = Config.getFloat(EPSILON_REL_KEY, EPSILON_REL_DEFAULT);
        if (epsilonRel <= 0.0f) {
            throw new IllegalArgumentException("Property admmreasoner.epsilonrel must be positive.");
        }

        // Used as a modulus in optimize(); zero would fail mid-optimization.
        if (computePeriod <= 0) {
            throw new IllegalArgumentException("Property admmreasoner.computeperiod must be positive.");
        }
    }

    public int getMaxIter() {
        return maxIter;
    }

    public void setMaxIter(int maxIter) {
        this.maxIter = maxIter;
    }

    public float getEpsilonRel() {
        return epsilonRel;
    }

    public void setEpsilonRel(float epsilonRel) {
        this.epsilonRel = epsilonRel;
    }

    public float getEpsilonAbs() {
        return epsilonAbs;
    }

    public void setEpsilonAbs(float epsilonAbs) {
        this.epsilonAbs = epsilonAbs;
    }

    /**
     * Returns the Lagrangian penalty from the most recent iteration.
     *
     * <p>This is the sum over all local variables of {@code lagrange * (local - consensus)}.
     */
    public float getLagrangianPenalty() {
        return lagrangePenalty;
    }

    /**
     * Returns the augmented-Lagrangian penalty from the most recent iteration.
     *
     * <p>This is the sum over all local variables of {@code 0.5 * stepSize * (local - consensus)^2}.
     */
    public float getAugmentedLagrangianPenalty() {
        return augmentedLagrangePenalty;
    }

    /**
     * Optimizes using the initial-value strategies configured under
     * {@link #INITIAL_CONSENSUS_VALUE_KEY} and {@link #INITIAL_LOCAL_VALUE_KEY}.
     */
    @Override // org.linqs.psl.reasoner.Reasoner
    public void optimize(TermStore termStore) {
        InitialValue consensusInit = InitialValue.valueOf(
                Config.getString(INITIAL_CONSENSUS_VALUE_KEY, INITIAL_CONSENSUS_VALUE_DEFAULT).toUpperCase());
        InitialValue localInit = InitialValue.valueOf(
                Config.getString(INITIAL_LOCAL_VALUE_KEY, INITIAL_LOCAL_VALUE_DEFAULT).toUpperCase());
        optimize(termStore, consensusInit, localInit);
    }

    /**
     * Runs ADMM until convergence, an unchanged objective, or the iteration limit. It then writes
     * the consensus values back into the term store.
     *
     * @param termStore must be an {@link ADMMTermStore}
     * @param initialConsensusValue how to initialize the consensus values
     * @param initialLocalValue how to reset the local variables
     * @throws IllegalArgumentException if {@code termStore} is not an {@link ADMMTermStore}
     */
    public void optimize(TermStore termStore, InitialValue initialConsensusValue, InitialValue initialLocalValue) {
        if (!(termStore instanceof ADMMTermStore)) {
            throw new IllegalArgumentException("ADMMReasoner requires an ADMMTermStore (found " + termStore.getClass().getName() + ").");
        }

        ADMMTermStore admmTermStore = (ADMMTermStore) termStore;
        admmTermStore.resetLocalVairables(initialLocalValue);

        int numTerms = admmTermStore.size();
        int numGlobalVariables = admmTermStore.getNumGlobalVariables();
        log.debug("Performing optimization with {} variables and {} terms.", numGlobalVariables, numTerms);

        initConsensusValues(admmTermStore, initialConsensusValue);

        // Roughly four blocks of work per thread.
        int numThreads = Parallel.getNumThreads();
        termBlockSize = (numTerms / (numThreads * 4)) + 1;
        variableBlockSize = (numGlobalVariables / (numThreads * 4)) + 1;

        // Floating-point division: with integer division the ceiling would be a no-op and the
        // final partial block (the trailing terms/variables) would never be processed.
        int numTermBlocks = (int) Math.ceil(numTerms / (double) termBlockSize);
        int numVariableBlocks = (int) Math.ceil(numGlobalVariables / (double) variableBlockSize);

        // Absolute part of both stopping tolerances.
        float epsilonAbsTerm = (float) (Math.sqrt(admmTermStore.getNumLocalVariables()) * epsilonAbs);

        ObjectiveResult objective = null;
        ObjectiveResult oldObjective = null;

        if (log.isTraceEnabled()) {
            objective = computeObjective(admmTermStore);
            log.trace("Iteration {} -- Objective: {}, Feasible: {}.",
                    0, objective.objective, objective.violatedConstraints == 0);
        }

        int iteration = 1;
        // oldObjective is only non-null after objective has been (re)computed, so dereferencing
        // objective below is safe.
        while ((iteration == 1 || primalRes > epsilonPrimal || dualRes > epsilonDual)
                && !(objectiveBreak && oldObjective != null && MathUtils.equals(objective.objective, oldObjective.objective))
                && iteration <= maxIter) {
            resetIterationVariables();

            // NOTE(review): assumes Parallel.count() returns only after all blocks finish, so the
            // variable phase sees the results of the term phase.
            Parallel.count(numTermBlocks, new TermWorker(admmTermStore, termBlockSize));
            Parallel.count(numVariableBlocks, new VariableWorker(admmTermStore, variableBlockSize));

            primalRes = (float) Math.sqrt(primalRes);
            dualRes = (float) (stepSize * Math.sqrt(dualRes));

            epsilonPrimal = (float) (epsilonAbsTerm + (epsilonRel * Math.max(Math.sqrt(axNorm), Math.sqrt(bzNorm))));
            epsilonDual = (float) (epsilonAbsTerm + (epsilonRel * Math.sqrt(ayNorm)));

            if (iteration % computePeriod == 0) {
                if (objectiveBreak) {
                    oldObjective = objective;
                    objective = computeObjective(admmTermStore);
                    log.trace("Iteration {} -- Objective: {}, Feasible: {}, Primal: {}, Dual: {}, Epsilon Primal: {}, Epsilon Dual: {}.",
                            iteration, objective.objective, objective.violatedConstraints == 0,
                            primalRes, dualRes, epsilonPrimal, epsilonDual);
                } else {
                    log.trace("Iteration {} -- Primal: {}, Dual: {}, Epsilon Primal: {}, Epsilon Dual: {}.",
                            iteration, primalRes, dualRes, epsilonPrimal, epsilonDual);
                }
            }

            iteration++;
        }

        ObjectiveResult finalObjective = computeObjective(admmTermStore);
        if (finalObjective.violatedConstraints > 0) {
            log.warn("No feasible solution found. {} constraints violated.", finalObjective.violatedConstraints);
        }

        log.info("Optimization completed in {} iterations. Objective: {}, Feasible: {}, Primal res.: {}, Dual res.: {}",
                iteration - 1, finalObjective.objective, finalObjective.violatedConstraints == 0,
                primalRes, dualRes);

        admmTermStore.updateVariables(consensusValues);
    }

    @Override // org.linqs.psl.reasoner.Reasoner
    public void close() {
    }

    /**
     * Computes the incompatibility of {@code groundRule} at the current local values of its terms.
     *
     * <p>The local values of the rule's terms are copied into {@code values}, which is pushed into
     * the term store, and the rule's incompatibility is evaluated. The consensus values are then
     * restored. Entries of {@code values} not touched by the rule's terms keep whatever the caller
     * put there (zero if {@code values} was null).
     *
     * <p>NOTE(review): relies on {@code consensusValues} from a prior {@link #optimize} call.
     *
     * @param groundRule the rule to evaluate; cast to {@link WeightedGroundRule}
     * @param termStore the store holding the rule's terms
     * @param values scratch buffer of global-variable values (overwritten), or null to allocate one
     * @return the rule's incompatibility at the local values
     */
    public double getDualIncompatibility(GroundRule groundRule, ADMMTermStore termStore, float[] values) {
        if (values == null) {
            values = new float[termStore.getNumGlobalVariables()];
        }

        assert values.length == consensusValues.length;

        Iterator<ADMMObjectiveTerm> terms = termStore.getTerms(groundRule).iterator();
        while (terms.hasNext()) {
            for (LocalVariable local : terms.next().getVariables()) {
                values[local.getGlobalId()] = local.getValue();
            }
        }

        termStore.updateVariables(values);
        double incompatibility = ((WeightedGroundRule) groundRule).getIncompatibility();
        termStore.updateVariables(consensusValues);

        return incompatibility;
    }

    // Allocates consensusValues and fills it according to the requested strategy.
    private void initConsensusValues(ADMMTermStore termStore, InitialValue initialValue) {
        consensusValues = new float[termStore.getNumGlobalVariables()];

        if (initialValue == InitialValue.ZERO) {
            // A freshly allocated float[] is already zero-filled.
            return;
        }

        if (initialValue == InitialValue.RANDOM) {
            for (int i = 0; i < consensusValues.length; i++) {
                consensusValues[i] = RandUtils.nextFloat();
            }
        } else if (initialValue == InitialValue.ATOM) {
            termStore.getAtomValues(consensusValues);
        } else {
            throw new IllegalStateException("Unknown initial consensus value: " + initialValue);
        }
    }

    // Sums the non-constraint terms at the consensus values and counts constraint terms that
    // evaluate to a positive value (i.e. are violated).
    private ObjectiveResult computeObjective(ADMMTermStore termStore) {
        float objective = 0.0f;
        int violatedConstraints = 0;

        Iterator<ADMMObjectiveTerm> terms = termStore.iterator();
        while (terms.hasNext()) {
            ADMMObjectiveTerm term = terms.next();
            if (term instanceof LinearConstraintTerm) {
                if (term.evaluate(consensusValues) > 0.0f) {
                    violatedConstraints++;
                }
            } else {
                objective += term.evaluate(consensusValues);
            }
        }

        return new ObjectiveResult(objective, violatedConstraints);
    }

    // Zeroes every per-iteration accumulator before the workers run.
    private void resetIterationVariables() {
        primalRes = 0.0f;
        dualRes = 0.0f;
        axNorm = 0.0f;
        ayNorm = 0.0f;
        bzNorm = 0.0f;
        lagrangePenalty = 0.0f;
        augmentedLagrangePenalty = 0.0f;
    }

    /**
     * Adds one worker's partial sums to the per-iteration accumulators.
     * Synchronized because VariableWorkers may call this concurrently.
     */
    public synchronized void updateIterationVariables(float primalResInc, float dualResInc,
            float axNormInc, float bzNormInc, float ayNormInc,
            float lagrangePenaltyInc, float augmentedLagrangePenaltyInc) {
        primalRes += primalResInc;
        dualRes += dualResInc;
        axNorm += axNormInc;
        ayNorm += ayNormInc;
        bzNorm += bzNormInc;
        lagrangePenalty += lagrangePenaltyInc;
        augmentedLagrangePenalty += augmentedLagrangePenaltyInc;
    }
}
