package org.linqs.psl.reasoner;

import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.linqs.psl.application.learning.weight.TrainingMap;
import org.linqs.psl.config.Options;
import org.linqs.psl.evaluation.EvaluationInstance;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.atom.ObservedAtom;
import org.linqs.psl.model.predicate.DeepPredicate;
import org.linqs.psl.reasoner.term.ReasonerTerm;
import org.linqs.psl.reasoner.term.TermStore;
import org.linqs.psl.reasoner.term.streaming.StreamingTermStore;
import org.linqs.psl.util.Logger;
import org.linqs.psl.util.MathUtils;
import org.linqs.psl.util.Parallel;

/**
 * Base class for reasoners: components that optimize the objective defined by the terms held in a
 * {@link TermStore}.
 *
 * <p>Subclasses implement {@link #optimize(TermStore, List, TrainingMap)}. They can use the shared
 * helpers here for:
 * <ul>
 *   <li>logging at the start and end of optimization;</li>
 *   <li>convergence checks ({@link #breakOptimization});</li>
 *   <li>objective computation (serial for streaming term stores, block-parallel otherwise);</li>
 *   <li>gradient computation;</li>
 *   <li>per-iteration evaluation.</li>
 * </ul>
 *
 * @param <T> the type of term held by the term stores this reasoner optimizes
 */
public abstract class Reasoner<T extends ReasonerTerm> {
    private static final Logger log = Logger.getLogger(Reasoner.class);

    /**
     * Parallel passes split the terms into fixed-size blocks so that there are roughly this many
     * blocks per thread (see {@link #computeBlockSize(long)}).
     */
    private static final int BLOCKS_PER_THREAD = 4;

    /**
     * Iteration limit, scaled by {@link #budget} in {@link #breakOptimization}.
     * Not initialized here (defaults to 0); presumably set by subclasses.
     */
    protected int maxIterations;

    /** Fraction of {@link #maxIterations} that may be used; see {@link #setBudget(double)}. */
    protected double budget = 1.0d;
    protected boolean evaluate = Options.REASONER_EVALUATE.getBoolean();
    protected boolean runFullIterations = Options.REASONER_RUN_FULL_ITERATIONS.getBoolean();
    protected boolean objectiveBreak = Options.REASONER_OBJECTIVE_BREAK.getBoolean();
    protected float objectiveTolerance = Options.REASONER_OBJECTIVE_TOLERANCE.getFloat();
    protected boolean variableMovementBreak = Options.REASONER_VARIABLE_MOVEMENT_BREAK.getBoolean();
    protected float variableMovementTolerance = Options.REASONER_VARIABLE_MOVEMENT_TOLERANCE.getFloat();
    protected float variableMovementNorm = Options.REASONER_VARIABLE_MOVEMENT_NORM.getFloat();

    /** Snapshot of the variable values taken by the last movement check in {@link #breakOptimization}. */
    protected float[] prevVariableValues = null;

    /** Per-block scratch gradient buffers, reused across calls to {@link #parallelComputeGradient}. */
    protected float[][] workerRVAtomGradients = null;
    protected float[][] workerDeepGradients = null;

    /**
     * Parallel worker that accumulates, for one block of terms, the partial derivatives of every
     * active non-constraint term with respect to each of its non-observed atoms.
     * Partials for atoms whose predicate is a {@link DeepPredicate} go into the deep-gradient row.
     * All other partials go into the random-variable row.
     */
    public static class GradientWorker extends Parallel.Worker<Long> {
        private final TermStore<? extends ReasonerTerm> termStore;
        private final int blockSize;
        private final float[] variableValues;
        private final GroundAtom[] variableAtoms;
        private final float[][] rvAtomGradients;
        private final float[][] deepAtomGradients;

        /**
         * @param termStore the terms to differentiate
         * @param rvAtomGradients one output row per block for non-deep atoms; each row is indexed by atom index
         * @param deepAtomGradients one output row per block for deep atoms; each row is indexed by atom index
         * @param blockSize the number of consecutive terms handled per block
         */
        public GradientWorker(TermStore<? extends ReasonerTerm> termStore, float[][] rvAtomGradients,
                float[][] deepAtomGradients, int blockSize) {
            this.termStore = termStore;
            this.variableValues = termStore.getVariableValues();
            this.variableAtoms = termStore.getVariableAtoms();
            this.rvAtomGradients = rvAtomGradients;
            this.deepAtomGradients = deepAtomGradients;
            this.blockSize = blockSize;
        }

        @Override
        public Object clone() {
            return new GradientWorker(termStore, rvAtomGradients, deepAtomGradients, blockSize);
        }

        /**
         * Zeroes this block's output rows, then accumulates the partials of the block's terms into them.
         */
        @Override // org.linqs.psl.util.Parallel.Worker
        public void work(long blockIndex, Long ignored) {
            int block = (int) blockIndex;
            float[] rvGradients = rvAtomGradients[block];
            float[] deepGradients = deepAtomGradients[block];
            Arrays.fill(rvGradients, 0.0f);
            Arrays.fill(deepGradients, 0.0f);

            long numTerms = termStore.size();
            for (int offset = 0; offset < blockSize; offset++) {
                int termIndex = (int) ((blockIndex * blockSize) + offset);
                if (termIndex >= numTerms) {
                    // The last block may be only partially filled.
                    return;
                }

                ReasonerTerm term = termStore.get(termIndex);
                if (!term.isActive() || term.isConstraint()) {
                    continue;
                }

                int[] atomIndexes = term.getAtomIndexes();
                float innerPotential = term.computeInnerPotential(variableValues);
                for (int local = 0; local < term.size(); local++) {
                    int atomIndex = atomIndexes[local];
                    GroundAtom atom = variableAtoms[atomIndex];
                    if (atom instanceof ObservedAtom) {
                        continue;
                    }

                    float[] target = (atom.getPredicate() instanceof DeepPredicate) ? deepGradients : rvGradients;
                    target[atomIndex] += term.computeVariablePartial(local, innerPotential);
                }
            }
        }
    }

    /** The value of an objective evaluation: the summed objective and the violated-constraint count. */
    public static class ObjectiveResult {
        /** Sum of {@code evaluate} over all active non-constraint terms. */
        public float objective;
        /** Number of active constraint terms whose value was judged violated. */
        public long violatedConstraints;

        public ObjectiveResult(float objective, long violatedConstraints) {
            this.objective = objective;
            this.violatedConstraints = violatedConstraints;
        }
    }

    /**
     * Parallel worker that, for one block of terms, writes the summed objective into
     * {@code objectives[block]} and the violated-constraint count into {@code violatedConstraints[block]}.
     */
    public static class ObjectiveWorker extends Parallel.Worker<Long> {
        private final TermStore<? extends ReasonerTerm> termStore;
        private final int blockSize;
        private final float[] variableValues;
        private final float[] objectives;
        private final int[] violatedConstraints;

        public ObjectiveWorker(TermStore<? extends ReasonerTerm> termStore, float[] objectives,
                int[] violatedConstraints, int blockSize) {
            this.termStore = termStore;
            this.variableValues = termStore.getVariableValues();
            this.objectives = objectives;
            this.violatedConstraints = violatedConstraints;
            this.blockSize = blockSize;
        }

        @Override
        public Object clone() {
            return new ObjectiveWorker(termStore, objectives, violatedConstraints, blockSize);
        }

        @Override // org.linqs.psl.util.Parallel.Worker
        public void work(long blockIndex, Long ignored) {
            int numTerms = (int) termStore.size();
            float objective = 0.0f;
            int violated = 0;

            for (int offset = 0; offset < blockSize; offset++) {
                int termIndex = (int) ((blockIndex * blockSize) + offset);
                if (termIndex >= numTerms) {
                    // The last block may be only partially filled.
                    break;
                }

                ReasonerTerm term = termStore.get(termIndex);
                if (!term.isActive()) {
                    continue;
                }

                float value = term.evaluate(variableValues);
                if (!term.isConstraint()) {
                    objective += value;
                } else if (isViolatedConstraint(value)) {
                    violated++;
                }
            }

            objectives[(int) blockIndex] = objective;
            violatedConstraints[(int) blockIndex] = violated;
        }
    }

    /** Optimizes without per-iteration evaluation: delegates with {@code null} evaluations and training map. */
    public double optimize(TermStore<T> termStore) {
        return optimize(termStore, null, null);
    }

    /**
     * Optimizes the objective defined by {@code termStore}.
     *
     * @param termStore the terms and variables to optimize
     * @param evaluations evaluations to run via {@link #evaluate}; may be {@code null}
     * @param trainingMap passed to each evaluation; may be {@code null}
     * @return a value defined by the concrete reasoner (presumably the final objective; confirm in subclasses)
     */
    public abstract double optimize(TermStore<T> termStore, List<EvaluationInstance> evaluations, TrainingMap trainingMap);

    /** Releases reasoner resources. This base implementation does nothing. */
    public void close() {
    }

    /**
     * Sets the fraction of {@link #maxIterations} that optimization may use.
     * {@link #breakOptimization} stops once the iteration exceeds {@code (int) (maxIterations * budget)}.
     */
    public void setBudget(double budget) {
        this.budget = budget;
    }

    /** Logs the problem size and, at trace level, the objective before the first iteration. */
    protected void initForOptimization(TermStore<T> termStore) {
        log.debug("Performing optimization with {} variables and {} terms.", termStore.getVariableCounts(), Long.valueOf(termStore.size()));
        if (log.isTraceEnabled()) {
            // Streaming stores cannot be indexed randomly, so they use the serial path.
            ObjectiveResult objective = (termStore instanceof StreamingTermStore)
                    ? computeObjective(termStore)
                    : parallelComputeObjective(termStore);
            log.trace("Iteration {} -- Objective: {}, Violated Constraints: {}, Total Optimization Time: {}, Total Number of Iterations: {}.",
                    0, Float.valueOf(objective.objective), Long.valueOf(objective.violatedConstraints), 0, 0);
        }
    }

    /**
     * Syncs the term store and logs the final objective, timing and iteration count.
     * The value returned by {@code termStore.sync()} is logged as the movement of variables from their initial state.
     */
    protected void optimizationComplete(TermStore<T> termStore, ObjectiveResult objective, long totalTime, int iterations) {
        float movement = (float) termStore.sync();
        log.info("Final Objective: {}, Violated Constraints: {}, Total Optimization Time: {}, Total Number of Iterations: {}",
                Float.valueOf(objective.objective), Long.valueOf(objective.violatedConstraints), Long.valueOf(totalTime), Integer.valueOf(iterations));
        log.debug("Movement of variables from initial state: {}", Float.valueOf(movement));
    }

    /**
     * Decides whether optimization should stop after {@code iteration}.
     *
     * <p>The checks run in this order:
     * <ol>
     *   <li>Stop if the budgeted iteration limit is exceeded.</li>
     *   <li>Otherwise never stop when running full iterations, or while constraints are violated.</li>
     *   <li>Then stop if the objective changed by at most the tolerance (when enabled).</li>
     *   <li>Finally stop if the p-norm of the variable movement since the last check is below the
     *       tolerance (when enabled).</li>
     * </ol>
     *
     * @param iteration the current iteration number
     * @param termStore the store whose variable values are checked for movement
     * @param objective the current objective; may be {@code null}
     * @param oldObjective the previous objective; may be {@code null}
     * @return {@code true} if optimization should stop
     */
    protected boolean breakOptimization(int iteration, TermStore<T> termStore, ObjectiveResult objective, ObjectiveResult oldObjective) {
        if (iteration > (int) (maxIterations * budget)) {
            log.trace("Breaking optimization. Max iterations exceeded.");
            return true;
        }

        if (runFullIterations) {
            return false;
        }

        if (objective != null && objective.violatedConstraints > 0) {
            return false;
        }

        if (objectiveBreak && objective != null && oldObjective != null
                && MathUtils.equals(objective.objective, oldObjective.objective, objectiveTolerance)) {
            log.trace("Breaking optimization. Objective change: {} below tolerance: {}.",
                    Float.valueOf(Math.abs(objective.objective - oldObjective.objective)), Float.valueOf(objectiveTolerance));
            return true;
        }

        if (!variableMovementBreak) {
            return false;
        }

        float[] variableValues = termStore.getVariableValues();
        if (prevVariableValues != null) {
            float[] movement = new float[prevVariableValues.length];
            for (int i = 0; i < movement.length; i++) {
                movement[i] = prevVariableValues[i] - variableValues[i];
            }

            float movementNorm = MathUtils.pNorm(movement, variableMovementNorm);
            if (movementNorm < variableMovementTolerance) {
                log.trace("Breaking optimization. Movement of variables: {} below tolerance: {}.",
                        Float.valueOf(movementNorm), Float.valueOf(variableMovementTolerance));
                return true;
            }
        }

        prevVariableValues = Arrays.copyOf(variableValues, variableValues.length);
        return false;
    }

    /** Computes the gradient of the objective at the current variable values; see {@link #parallelComputeGradient}. */
    public void computeOptimalValueGradient(TermStore<T> termStore, float[] rvAtomGradient, float[] deepAtomGradient) {
        parallelComputeGradient(termStore, rvAtomGradient, deepAtomGradient);
    }

    /**
     * Computes the partials of all active non-constraint terms in parallel, in blocks of terms.
     *
     * <p>Each output array is zeroed, then receives the sum of the per-block partials over the
     * first {@code termStore.getNumVariables()} entries.
     *
     * @param termStore the terms to differentiate
     * @param rvAtomGradient output for non-deep, non-observed atoms, indexed by atom index
     * @param deepAtomGradient output for atoms of a {@link DeepPredicate}, indexed by atom index
     */
    public void parallelComputeGradient(TermStore<? extends ReasonerTerm> termStore, float[] rvAtomGradient, float[] deepAtomGradient) {
        int blockSize = computeBlockSize(termStore.size());
        int numBlocks = computeNumBlocks(termStore.size(), blockSize);

        if (!workerGradientsFit(numBlocks, rvAtomGradient.length, deepAtomGradient.length)) {
            workerRVAtomGradients = new float[numBlocks][rvAtomGradient.length];
            workerDeepGradients = new float[numBlocks][deepAtomGradient.length];
        }

        Parallel.count(numBlocks, new GradientWorker(termStore, workerRVAtomGradients, workerDeepGradients, blockSize));

        Arrays.fill(rvAtomGradient, 0.0f);
        Arrays.fill(deepAtomGradient, 0.0f);

        long numVariables = termStore.getNumVariables();
        for (int block = 0; block < numBlocks; block++) {
            float[] rvBlock = workerRVAtomGradients[block];
            float[] deepBlock = workerDeepGradients[block];
            for (int variable = 0; variable < numVariables; variable++) {
                rvAtomGradient[variable] += rvBlock[variable];
                deepAtomGradient[variable] += deepBlock[variable];
            }
        }
    }

    /**
     * Zeroes gradient components that point past the [0, 1] bounds of the corresponding value.
     * A component is zeroed when the value equals 0 and the gradient is positive, or when the value
     * equals 1 and the gradient is negative. This assumes the caller steps against the gradient.
     *
     * @param values current values; must be at least as long as {@code gradient}
     * @param gradient gradient to clip in place
     */
    protected void clipGradient(float[] values, float[] gradient) {
        for (int i = 0; i < gradient.length; i++) {
            boolean pushesBelowZero = MathUtils.equals(values[i], 0.0f) && gradient[i] > 0.0f;
            boolean pushesAboveOne = MathUtils.equals(values[i], 1.0f) && gradient[i] < 0.0f;
            if (pushesBelowZero || pushesAboveOne) {
                gradient[i] = 0.0f;
            }
        }
    }

    /**
     * Computes the objective serially by iterating the store.
     * This path is used for {@link StreamingTermStore}s, which the parallel path does not support.
     */
    protected ObjectiveResult computeObjective(TermStore<T> termStore) {
        float objective = 0.0f;
        long violatedConstraints = 0;
        float[] variableValues = termStore.getVariableValues();

        Iterator<T> terms = termStore.iterator();
        while (terms.hasNext()) {
            T term = terms.next();
            if (!term.isActive()) {
                continue;
            }

            float value = term.evaluate(variableValues);
            if (!term.isConstraint()) {
                objective += value;
            } else if (isViolatedConstraint(value)) {
                violatedConstraints++;
            }
        }

        return new ObjectiveResult(objective, violatedConstraints);
    }

    /**
     * Computes the objective in parallel over blocks of terms.
     * Must not be called with a {@link StreamingTermStore}.
     */
    protected ObjectiveResult parallelComputeObjective(TermStore<T> termStore) {
        assert !(termStore instanceof StreamingTermStore);

        int blockSize = computeBlockSize(termStore.size());
        int numBlocks = computeNumBlocks(termStore.size(), blockSize);

        float[] blockObjectives = new float[numBlocks];
        int[] blockViolatedConstraints = new int[numBlocks];
        Parallel.count(numBlocks, new ObjectiveWorker(termStore, blockObjectives, blockViolatedConstraints, blockSize));

        float objective = 0.0f;
        long violatedConstraints = 0;
        for (int block = 0; block < numBlocks; block++) {
            objective += blockObjectives[block];
            violatedConstraints += blockViolatedConstraints[block];
        }

        return new ObjectiveResult(objective, violatedConstraints);
    }

    /**
     * Syncs the store and runs every evaluation, logging each result with the iteration number.
     * Does nothing when evaluation is disabled or there is no training map or evaluation.
     */
    protected void evaluate(TermStore<T> termStore, int iteration, List<EvaluationInstance> evaluations, TrainingMap trainingMap) {
        if (!evaluate || trainingMap == null || evaluations == null || evaluations.isEmpty()) {
            return;
        }

        termStore.sync();
        for (EvaluationInstance evaluation : evaluations) {
            evaluation.compute(trainingMap);
            log.info("Iteration {} -- {}.", Integer.valueOf(iteration), evaluation.getOutput());
        }
    }

    /**
     * Returns true if {@code workerRVAtomGradients} and {@code workerDeepGradients} can be reused
     * for {@code numBlocks} blocks with rows of the given lengths.
     * Rows are only inspected when at least one block is needed, which avoids indexing an empty cache.
     */
    private boolean workerGradientsFit(int numBlocks, int rvLength, int deepLength) {
        if (workerRVAtomGradients == null || workerDeepGradients == null) {
            return false;
        }

        if (workerRVAtomGradients.length < numBlocks || workerDeepGradients.length < numBlocks) {
            return false;
        }

        return numBlocks == 0
                || (workerRVAtomGradients[0].length >= rvLength && workerDeepGradients[0].length >= deepLength);
    }

    /** Number of consecutive terms per parallel block; always at least 1. */
    private static int computeBlockSize(long numTerms) {
        return (int) ((numTerms / (Parallel.getNumThreads() * BLOCKS_PER_THREAD)) + 1);
    }

    /**
     * Number of blocks needed to cover {@code numTerms} terms.
     * Uses exact integer ceiling division so that the final, partial block is never dropped.
     */
    private static int computeNumBlocks(long numTerms, int blockSize) {
        return (int) ((numTerms + blockSize - 1) / blockSize);
    }

    /**
     * Shared violated-constraint predicate, so the serial and parallel objective paths agree.
     * A value within {@code MathUtils.isZero}'s tolerance of zero counts as satisfied.
     */
    private static boolean isViolatedConstraint(float constraintValue) {
        return !MathUtils.isZero(constraintValue);
    }
}
