package org.linqs.psl.reasoner.sgd;

import java.util.Iterator;
import org.linqs.psl.config.Options;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.reasoner.Reasoner;
import org.linqs.psl.reasoner.sgd.term.SGDObjectiveTerm;
import org.linqs.psl.reasoner.term.TermStore;
import org.linqs.psl.reasoner.term.VariableTermStore;
import org.linqs.psl.util.IteratorUtils;
import org.linqs.psl.util.MathUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/linqs/psl/reasoner/sgd/SGDReasoner.class */
/**
 * Reasoner that optimizes a {@link VariableTermStore} of {@link SGDObjectiveTerm}s by repeatedly
 * asking every term to take a minimization step ({@link SGDObjectiveTerm#minimize}).
 *
 * <p>Optimization stops when the iteration budget ({@code maxIterations * budget}) is exhausted, or,
 * when {@code objectiveBreak} is enabled, once the objective stops changing within {@code tolerance}
 * (and, if movement watching is enabled, the mean movement is at or below the configured threshold).
 */
public class SGDReasoner extends Reasoner {
    private static final Logger log = LoggerFactory.getLogger(SGDReasoner.class);

    private final int maxIterations = Options.SGD_MAX_ITER.getInt();
    private final boolean watchMovement = Options.SGD_MOVEMENT.getBoolean();
    private final float movementThreshold = Options.SGD_MOVEMENT_THRESHOLD.getFloat();

    /**
     * Runs SGD iterations over the given store until {@link #breakOptimization} says to stop, then
     * syncs the optimized values back to the atoms.
     *
     * @param termStore the store to optimize; must be a {@link VariableTermStore}
     * @throws IllegalArgumentException if {@code termStore} is not a {@link VariableTermStore}
     */
    @Override
    public void optimize(TermStore termStore) {
        if (!(termStore instanceof VariableTermStore)) {
            throw new IllegalArgumentException("SGDReasoner requires a VariableTermStore (found " + termStore.getClass().getName() + ").");
        }

        @SuppressWarnings("unchecked")
        VariableTermStore<SGDObjectiveTerm, RandomVariableAtom> store =
                (VariableTermStore<SGDObjectiveTerm, RandomVariableAtom>) termStore;
        store.initForOptimization();

        // -1 is a placeholder until an objective has actually been computed.
        float objective = -1.0f;
        if (printInitialObj && log.isTraceEnabled()) {
            objective = computeObjective(store);
            log.trace("Iteration {} -- Objective: {}, Mean Movement: {}, Iteration Time: {}, Total Optimiztion Time: {}",
                    new Object[] {0, objective, 0.0f, 0, 0});
        }

        int iteration = 1;
        long totalTime = 0;
        float oldObjective;
        float movement;
        do {
            long start = System.currentTimeMillis();
            movement = runSgdPass(store, iteration);
            long end = System.currentTimeMillis();

            oldObjective = objective;
            objective = computeObjective(store);

            // Only the SGD pass is timed; objective evaluation is not included.
            long iterationTime = end - start;
            totalTime += iterationTime;

            if (log.isTraceEnabled()) {
                log.trace("Iteration {} -- Objective: {}, Mean Movement: {}, Iteration Time: {}, Total Optimiztion Time: {}",
                        new Object[] {iteration, objective, movement, iterationTime, totalTime});
            }

            iteration++;
            store.iterationComplete();
        } while (!breakOptimization(iteration, objective, oldObjective, movement));

        store.syncAtoms();

        log.info("Optimization completed in {} iterations. Objective: {}, Total Optimiztion Time: {}",
                new Object[] {iteration - 1, objective, totalTime});
        log.debug("Optimized with {} variables and {} terms.", store.getNumVariables(), store.size());
    }

    /**
     * Runs one pass over every term in the store, calling {@code minimize} on each with the shared
     * variable-value array (presumably updated in place by the terms; confirm in SGDObjectiveTerm).
     *
     * @param store the store whose terms are minimized
     * @param iteration the 1-based iteration number, forwarded to each term
     * @return the summed movement reported by the terms divided by the number of variables, or the
     *     raw sum if there are no variables
     */
    private float runSgdPass(VariableTermStore<SGDObjectiveTerm, RandomVariableAtom> store, int iteration) {
        float[] variableValues = store.getVariableValues();
        float movement = 0.0f;

        Iterator<SGDObjectiveTerm> terms = store.iterator();
        while (terms.hasNext()) {
            movement += terms.next().minimize(iteration, variableValues);
        }

        if (variableValues.length != 0) {
            movement /= variableValues.length;
        }
        return movement;
    }

    /**
     * Decides whether optimization should stop after the iteration that just finished.
     *
     * @param iteration the number of the next iteration (already incremented by the caller)
     * @param objective the objective after the latest iteration
     * @param oldObjective the objective before the latest iteration
     * @param movement the mean movement of the latest iteration
     * @return true to stop optimizing
     */
    private boolean breakOptimization(int iteration, float objective, float oldObjective, float movement) {
        // The iteration budget is a hard cap regardless of convergence.
        if (iteration > (int) (maxIterations * budget)) {
            return true;
        }

        // When movement is watched, keep iterating while variables are still moving too much.
        // Written as "<=" so a NaN movement counts as "not settled", matching the original check.
        boolean movementSettled = !watchMovement || movement <= movementThreshold;
        return movementSettled && objectiveBreak && MathUtils.equals(objective, oldObjective, tolerance);
    }

    /**
     * Computes the mean objective over all terms at the current variable values.
     *
     * @param termStore the store whose terms are evaluated
     * @return the sum of the term evaluations divided by the number of terms, or 0 if the store holds
     *     no terms (avoids returning NaN, which would prevent convergence from ever being detected)
     */
    public float computeObjective(VariableTermStore<SGDObjectiveTerm, RandomVariableAtom> termStore) {
        // Use the no-write iterator when the store reports it is loaded.
        Iterator<SGDObjectiveTerm> termIterator =
                termStore.isLoaded() ? termStore.noWriteIterator() : termStore.iterator();
        float[] variableValues = termStore.getVariableValues();

        float objective = 0.0f;
        for (SGDObjectiveTerm term : IteratorUtils.newIterable(termIterator)) {
            objective += term.evaluate(variableValues);
        }

        // Read the size after the pass, as before: NOTE(review) stores that load terms lazily may
        // only know their size once iterated; confirm before moving this earlier.
        int termCount = termStore.size();
        if (termCount == 0) {
            return 0.0f;
        }
        return objective / termCount;
    }

    /** Nothing to release: this reasoner holds no resources of its own. */
    @Override
    public void close() {
    }
}
