package org.linqs.psl.reasoner.dcd;

import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.linqs.psl.application.learning.weight.TrainingMap;
import org.linqs.psl.config.Options;
import org.linqs.psl.evaluation.statistics.Evaluator;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.atom.ObservedAtom;
import org.linqs.psl.model.predicate.StandardPredicate;
import org.linqs.psl.reasoner.Reasoner;
import org.linqs.psl.reasoner.dcd.term.DCDObjectiveTerm;
import org.linqs.psl.reasoner.term.TermStore;
import org.linqs.psl.reasoner.term.VariableTermStore;
import org.linqs.psl.util.IteratorUtils;
import org.linqs.psl.util.Logger;
import org.linqs.psl.util.MathUtils;

/* loaded from: input_file:org/linqs/psl/reasoner/dcd/DCDReasoner.class */
/**
 * Reasoner that optimizes the objective by updating a per-term Lagrange multiplier
 * (dual coordinate descent).
 *
 * <p>One iteration sweeps every term in the store once. For each term it takes a projected
 * coordinate step on the term's multiplier, then moves the term's unobserved variables by the
 * corresponding amount. Variable values are clamped to [0, 1] either after every single step or
 * once per sweep, controlled by {@code Options.DCD_TRUNCATE_EVERY_STEP}.
 */
public class DCDReasoner extends Reasoner {
    private static final Logger log = Logger.getLogger(DCDReasoner.class);

    private final int maxIterations = Options.DCD_MAX_ITER.getInt();
    // The constant C: rule weights are multiplied by it, and term values are divided by it
    // when the objective is summed.
    private final float c = Options.DCD_C.getFloat();
    private final boolean truncateEveryStep = Options.DCD_TRUNCATE_EVERY_STEP.getBoolean();

    /**
     * Runs DCD sweeps over the store's terms until {@link #breakOptimization} says to stop.
     * Afterwards it computes the final objective and calls {@code syncAtoms()} on the store.
     *
     * @param termStore must be a {@link VariableTermStore} holding {@link DCDObjectiveTerm}s
     * @param list evaluators forwarded to {@code evaluate(...)} after every iteration
     * @param trainingMap forwarded to {@code evaluate(...)} after every iteration
     * @param set predicates forwarded to {@code evaluate(...)} after every iteration
     * @return the objective evaluated at the final variable values
     * @throws IllegalArgumentException if {@code termStore} is not a {@link VariableTermStore}
     */
    @Override // org.linqs.psl.reasoner.Reasoner
    public double optimize(TermStore termStore, List<Evaluator> list, TrainingMap trainingMap, Set<StandardPredicate> set) {
        if (!(termStore instanceof VariableTermStore)) {
            throw new IllegalArgumentException("DCDReasoner requires an VariableTermStore (found " + termStore.getClass().getName() + ").");
        }

        // Generic element types cannot be checked at runtime; the store is expected to hold DCD terms.
        @SuppressWarnings("unchecked")
        VariableTermStore<DCDObjectiveTerm, GroundAtom> store = (VariableTermStore<DCDObjectiveTerm, GroundAtom>) termStore;
        store.initForOptimization();

        long termCount = 0;
        double oldObjective = Double.POSITIVE_INFINITY;
        // Snapshot of the variable values at the end of the previous iteration. The objective
        // accumulated during a sweep is measured against this snapshot, so it lags by one iteration.
        float[] oldVariableValues = null;
        long totalTime = 0;
        boolean done = false;
        int iteration = 1;

        while (!done) {
            long start = System.currentTimeMillis();

            termCount = 0;
            double objective = 0.0;

            Iterator<DCDObjectiveTerm> terms = store.iterator();
            while (terms.hasNext()) {
                DCDObjectiveTerm term = terms.next();
                // No snapshot exists during the first sweep, so that sweep contributes no objective.
                if (iteration > 1) {
                    objective += term.evaluate(oldVariableValues) / c;
                }
                termCount++;
                variableUpdate(term, store);
            }

            if (!truncateEveryStep) {
                clipVariableValues(store);
            }

            evaluate(store, iteration, list, trainingMap, set);
            store.iterationComplete();

            done = breakOptimization(iteration, objective, oldObjective, termCount);

            if (iteration == 1) {
                float[] currentValues = store.getVariableValues();
                oldVariableValues = Arrays.copyOf(currentValues, currentValues.length);
            } else {
                System.arraycopy(store.getVariableValues(), 0, oldVariableValues, 0, oldVariableValues.length);
                oldObjective = objective;
            }

            long end = System.currentTimeMillis();
            totalTime += end - start;

            // The objective just computed belongs to the previous iteration's values, hence "iteration - 1".
            if (iteration > 1 && log.isTraceEnabled()) {
                log.trace("Iteration {} -- Objective: {}, Normalized Objective: {}, Iteration Time: {}, Total Optimization Time: {}",
                        Integer.valueOf(iteration - 1), Double.valueOf(objective), Double.valueOf(objective / termCount),
                        Long.valueOf(end - start), Long.valueOf(totalTime));
            }

            iteration++;
        }

        // 'iteration' has already been advanced past the last completed sweep.
        int iterationsRun = iteration - 1;

        double finalObjective = computeObjective(store);
        double movement = store.syncAtoms();

        log.info("Final Objective: {}, Final Normalized Objective: {}, Total Optimization Time: {}, Total Number of Iterations: {}",
                Double.valueOf(finalObjective), Double.valueOf(finalObjective / termCount), Long.valueOf(totalTime), Integer.valueOf(iterationsRun));
        log.debug("Movement of variables from initial state: {}", Double.valueOf(movement));
        log.debug("Optimized with {} variables and {} terms.", Integer.valueOf(store.getNumRandomVariables()), Long.valueOf(termCount));

        return finalObjective;
    }

    /**
     * Clamps the first {@code getNumVariables()} entries of the store's variable values into
     * [0, 1], in place.
     */
    private static void clipVariableValues(VariableTermStore<DCDObjectiveTerm, GroundAtom> store) {
        float[] values = store.getVariableValues();
        for (int i = 0; i < store.getNumVariables(); i++) {
            values[i] = Math.max(0.0f, Math.min(1.0f, values[i]));
        }
    }

    /**
     * Decides whether optimization should stop after the given iteration.
     *
     * <p>Stops unconditionally once {@code iteration} exceeds {@code (int) (maxIterations * budget)}.
     * Otherwise it stops only when objective-based stopping is enabled, full iterations are not
     * forced, and the per-term objective equals the previous one within {@code tolerance}.
     *
     * @param iteration the 1-based iteration that just finished
     * @param objective objective accumulated during this iteration
     * @param oldObjective objective recorded at the previous comparison
     * @param termCount number of terms visited in this iteration, used to normalize both objectives
     */
    private boolean breakOptimization(int iteration, double objective, double oldObjective, long termCount) {
        if (iteration > (int) (maxIterations * budget)) {
            return true;
        }

        if (runFullIterations || !objectiveBreak) {
            return false;
        }

        return MathUtils.equals(objective / termCount, oldObjective / termCount, (double) tolerance);
    }

    /**
     * Sums {@code term.evaluate(values) / c} over all terms at the store's current variable values.
     * Uses the store's no-write iterator when the store reports itself as loaded, and its regular
     * iterator otherwise.
     */
    private double computeObjective(VariableTermStore<DCDObjectiveTerm, GroundAtom> store) {
        Iterator<DCDObjectiveTerm> termIterator = store.isLoaded() ? store.noWriteIterator() : store.iterator();
        Iterable<DCDObjectiveTerm> terms = IteratorUtils.newIterable(termIterator);

        double objective = 0.0;
        for (DCDObjectiveTerm term : terms) {
            objective += term.evaluate(store.getVariableValues()) / c;
        }
        return objective;
    }

    /**
     * Performs one coordinate step for a single term.
     *
     * <p>The rule weight is scaled by C. A squared term gets an unbounded multiplier, and its
     * gradient is shifted by {@code lagrange / (2 * weight)}. Any other term has its multiplier
     * capped at the scaled weight.
     */
    private void variableUpdate(DCDObjectiveTerm term, VariableTermStore<DCDObjectiveTerm, GroundAtom> store) {
        GroundAtom[] atoms = store.getVariableAtoms();
        float[] values = store.getVariableValues();
        float adjustedWeight = term.getRule().getWeight() * c;
        float gradient = term.computeGradient(values);

        if (term.isSquared()) {
            float shiftedGradient = gradient + (term.getLagrange() / (2.0f * adjustedWeight));
            updateLagrange(term, shiftedGradient, adjustedWeight, Float.POSITIVE_INFINITY, values, atoms);
        } else {
            updateLagrange(term, gradient, adjustedWeight, adjustedWeight, values, atoms);
        }
    }

    /**
     * Updates one term's Lagrange multiplier with a projected gradient step, then applies the
     * matching change to the term's non-observed variables.
     *
     * @param term the term being updated
     * @param gradient gradient with respect to the term's multiplier
     * @param adjustedWeight the term's rule weight multiplied by C
     * @param upperBound upper bound on the multiplier (the lower bound is 0)
     * @param values variable values, updated in place
     * @param atoms atoms parallel to {@code values}; {@link ObservedAtom}s are never changed
     */
    private void updateLagrange(DCDObjectiveTerm term, float gradient, float adjustedWeight, float upperBound,
            float[] values, GroundAtom[] atoms) {
        // Project the gradient onto the box [0, upperBound]. At the lower bound, only a step that
        // raises the multiplier counts; at a finite upper bound, only a step that lowers it counts.
        float projectedGradient = gradient;
        if (MathUtils.isZero(term.getLagrange())) {
            projectedGradient = Math.min(0.0f, gradient);
        }
        if (MathUtils.equals(upperBound, adjustedWeight) && MathUtils.equals(term.getLagrange(), adjustedWeight)) {
            projectedGradient = Math.max(0.0f, gradient);
        }

        // The multiplier is already optimal inside its box.
        if (MathUtils.isZero(projectedGradient)) {
            return;
        }

        float oldLagrange = term.getLagrange();
        int[] variableIndexes = term.getVariableIndexes();
        float[] coefficients = term.getCoefficients();

        // The step itself uses the raw gradient, clamped to [0, upperBound].
        term.setLagrange(Math.min(upperBound, Math.max(0.0f, term.getLagrange() - (gradient / term.getQii()))));
        float lagrangeDelta = term.getLagrange() - oldLagrange;

        for (int i = 0; i < term.size(); i++) {
            int index = variableIndexes[i];
            if (atoms[index] instanceof ObservedAtom) {
                continue;
            }

            float newValue = values[index] - (lagrangeDelta * coefficients[i]);
            if (truncateEveryStep) {
                newValue = Math.max(0.0f, Math.min(1.0f, newValue));
            }
            values[index] = newValue;
        }
    }

    /** This reasoner holds no resources of its own, so there is nothing to release. */
    @Override // org.linqs.psl.reasoner.Reasoner
    public void close() {
    }
}
