package org.linqs.psl.reasoner.sgd;

import java.util.Iterator;
import org.linqs.psl.config.Config;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.reasoner.Reasoner;
import org.linqs.psl.reasoner.sgd.term.SGDObjectiveTerm;
import org.linqs.psl.reasoner.term.TermStore;
import org.linqs.psl.reasoner.term.VariableTermStore;
import org.linqs.psl.util.IteratorUtils;
import org.linqs.psl.util.MathUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A reasoner that minimizes a collection of {@link SGDObjectiveTerm}s held in a {@link VariableTermStore}.
 *
 * <p>Each iteration calls {@link SGDObjectiveTerm#minimize} on every term in the store (passing the
 * 1-based iteration number and the shared variable-value array), then recomputes the mean objective
 * over all terms. Optimization stops after {@code sgd.maxiterations} iterations or, when
 * {@code sgd.objectivebreak} is enabled, once two consecutive objective values are equal within
 * {@code sgd.tolerance} (as judged by {@code MathUtils.equals}). When the loop finishes, the variable
 * values are written back via {@code syncAtoms()}.
 */
public class SGDReasoner implements Reasoner {
    private static final Logger log = LoggerFactory.getLogger(SGDReasoner.class);

    public static final String CONFIG_PREFIX = "sgd";

    /** Maximum number of passes over the terms. */
    public static final String MAX_ITER_KEY = "sgd.maxiterations";
    public static final int MAX_ITER_DEFAULT = 200;

    /** Whether to stop early once the objective stops changing (within {@link #OBJ_TOL_KEY}). */
    public static final String OBJECTIVE_BREAK_KEY = "sgd.objectivebreak";
    public static final boolean OBJECTIVE_BREAK_DEFAULT = true;

    /** Tolerance used when comparing consecutive objective values for early stopping. */
    public static final String OBJ_TOL_KEY = "sgd.tolerance";
    public static final float OBJ_TOL_DEFAULT = 1.0E-5f;

    // NOTE(review): not read by this class; presumably consumed by the SGD terms — confirm.
    public static final String LEARNING_RATE_KEY = "sgd.learningrate";
    public static final float LEARNING_RATE_DEFAULT = 1.0f;

    /** Whether to log a CSV-style objective row after every iteration. */
    public static final String PRINT_OBJECTIVE = "sgd.printobj";
    public static final boolean PRINT_OBJECTIVE_DEFAULT = true;

    /** Whether to also log the objective before the first iteration (only when objectives are printed). */
    public static final String PRINT_INITIAL_OBJECTIVE_KEY = "sgd.printinitialobj";
    public static final boolean PRINT_INITIAL_OBJECTIVE_DEFAULT = false;

    private int maxIter = Config.getInt(MAX_ITER_KEY, MAX_ITER_DEFAULT);
    private final boolean objectiveBreak = Config.getBoolean(OBJECTIVE_BREAK_KEY, OBJECTIVE_BREAK_DEFAULT);
    private final boolean printObj = Config.getBoolean(PRINT_OBJECTIVE, PRINT_OBJECTIVE_DEFAULT);
    private final boolean printInitialObj =
            Config.getBoolean(PRINT_INITIAL_OBJECTIVE_KEY, PRINT_INITIAL_OBJECTIVE_DEFAULT);
    private final float tolerance = Config.getFloat(OBJ_TOL_KEY, OBJ_TOL_DEFAULT);

    /**
     * Returns the maximum number of iterations {@link #optimize} will run.
     *
     * @return the iteration cap
     */
    public int getMaxIter() {
        return maxIter;
    }

    /**
     * Overrides the maximum number of iterations read from {@code sgd.maxiterations}.
     * A value below 1 means {@link #optimize} performs no iterations.
     *
     * @param i the new iteration cap
     */
    public void setMaxIter(int i) {
        maxIter = i;
    }

    /**
     * Runs the SGD loop over every term in the given store and syncs the resulting values to the atoms.
     *
     * @param termStore the store to optimize; must be a {@link VariableTermStore} of {@link SGDObjectiveTerm}s
     * @throws IllegalArgumentException if {@code termStore} is null or not a {@link VariableTermStore}
     */
    @Override // org.linqs.psl.reasoner.Reasoner
    public void optimize(TermStore termStore) {
        if (!(termStore instanceof VariableTermStore)) {
            // Guard the message construction so a null store yields the documented exception, not an NPE.
            String found = (termStore == null) ? "null" : termStore.getClass().getName();
            throw new IllegalArgumentException("SGDReasoner requires an VariableTermStore (found " + found + ").");
        }

        // The element types are not checkable at runtime; SGD term stores are assumed to hold SGDObjectiveTerms.
        @SuppressWarnings("unchecked")
        VariableTermStore<SGDObjectiveTerm, RandomVariableAtom> variableTermStore =
                (VariableTermStore<SGDObjectiveTerm, RandomVariableAtom>) termStore;

        // Fetched once and mutated in place by the terms.
        // NOTE(review): assumes the store does not reallocate this array during optimization — confirm.
        float[] variableValues = variableTermStore.getVariableValues();

        // Sentinels: the loop condition always admits iteration 1, so these never trigger an early break.
        float objective = -1.0f;
        float oldObjective = Float.POSITIVE_INFINITY;

        if (printObj) {
            // Header and initial row use the same level as the per-iteration rows so the CSV output is coherent.
            log.info("objective:Iterations,Time(ms),Objective");
            if (printInitialObj) {
                objective = computeObjective(variableTermStore, variableValues);
                log.info("objective:{},{},{}", 0, 0, objective);
            }
        }

        // Cumulative time spent in term minimization only (objective evaluation is excluded).
        long elapsedNanos = 0L;
        int iteration = 1;
        while (iteration <= maxIter
                && (!objectiveBreak || iteration == 1 || !MathUtils.equals(objective, oldObjective, tolerance))) {
            // nanoTime is monotonic, unlike currentTimeMillis, so it is safe for measuring elapsed time.
            long start = System.nanoTime();
            for (SGDObjectiveTerm term : variableTermStore) {
                term.minimize(iteration, variableValues);
            }
            elapsedNanos += System.nanoTime() - start;

            oldObjective = objective;
            objective = computeObjective(variableTermStore, variableValues);

            if (printObj) {
                log.info("objective:{},{},{}", iteration, elapsedNanos / 1_000_000L, objective);
            }

            iteration++;
        }

        variableTermStore.syncAtoms();

        log.info("Optimization completed in {} iterations. Objective.: {}", iteration - 1, objective);
        log.debug("Optimized with {} variables and {} terms.",
                variableTermStore.getNumVariables(), variableTermStore.size());
    }

    /**
     * Computes the mean objective value over all terms in the store.
     *
     * <p>When the store reports itself as loaded, its no-write iterator is used; otherwise the regular
     * iterator is used.
     *
     * @param variableTermStore the store whose terms are evaluated
     * @param variableValues the current variable values passed to each term's {@code evaluate}
     * @return the sum of term objectives divided by the number of terms, or {@code 0} for an empty store
     *         (avoids a NaN that would defeat the convergence check)
     */
    public float computeObjective(VariableTermStore<SGDObjectiveTerm, RandomVariableAtom> variableTermStore,
            float[] variableValues) {
        Iterator<SGDObjectiveTerm> terms = variableTermStore.isLoaded()
                ? variableTermStore.noWriteIterator()
                : variableTermStore.iterator();

        // Accumulate in double: summing many float terms in float loses precision that the tolerance check relies on.
        double total = 0.0;
        for (SGDObjectiveTerm term : IteratorUtils.newIterable(terms)) {
            total += term.evaluate(variableValues);
        }

        // Size is read after iteration, matching the original order (the store may load terms while iterating).
        int termCount = variableTermStore.size();
        if (termCount == 0) {
            return 0.0f;
        }
        return (float) (total / termCount);
    }

    /** This reasoner holds no resources of its own, so there is nothing to release. */
    @Override // org.linqs.psl.reasoner.Reasoner
    public void close() {
    }
}
