package org.linqs.psl.reasoner.term;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.linqs.psl.config.Options;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.predicate.model.ModelPredicate;
import org.linqs.psl.model.rule.GroundRule;
import org.linqs.psl.reasoner.term.Hyperplane;
import org.linqs.psl.reasoner.term.ReasonerLocalVariable;
import org.linqs.psl.reasoner.term.ReasonerTerm;
import org.linqs.psl.util.Logger;

/**
 * A {@link VariableTermStore} that keeps its variables in memory and delegates
 * term storage to a {@link MemoryTermStore}.
 *
 * Each variable is assigned an index when it is first created. That index
 * addresses two parallel arrays: one with the variable's working value and one
 * with the {@link RandomVariableAtom} backing it. Only random variable atoms
 * are tracked; observed atoms are rejected by {@link #createLocalVariable}.
 *
 * The store also records "mirror" atoms reported by the hyperplanes passed to
 * {@link #add}. For mirror atoms whose predicate is a {@link ModelPredicate},
 * the predicate is fit and run around each optimization iteration and the
 * constants of the affected terms are adjusted to match.
 *
 * Only {@link #createLocalVariable} and the internal mirror registration are
 * synchronized; no other method takes a lock.
 */
public abstract class MemoryVariableTermStore<T extends ReasonerTerm, V extends ReasonerLocalVariable> implements VariableTermStore<T, V> {
    private static final Logger log = Logger.getLogger(MemoryVariableTermStore.class);

    /** Maps each known variable to its index in the parallel arrays below. */
    private Map<V, Integer> variables;

    /** Working value of each variable, addressed by variable index. */
    private float[] variableValues;

    /** Atom backing each variable, addressed by variable index. */
    private RandomVariableAtom[] variableAtoms;

    /** Model predicates seen on mirror atoms. */
    private Set<ModelPredicate> modelPredicates;

    /** For each mirror atom, the terms it was integrated into and the coefficient used. */
    private Map<RandomVariableAtom, List<MirrorTermCoefficient>> mirrorVariables;

    private boolean variablesExternallyUpdatedFlag;
    private boolean shuffle = Options.MEMORY_VTS_SHUFFLE.getBoolean();
    private int defaultSize = Options.MEMORY_VTS_DEFAULT_SIZE.getInt();
    private MemoryTermStore<T> store = new MemoryTermStore<>();

    /**
     * Pairs a term with the coefficient a mirror atom carried in that term.
     */
    public class MirrorTermCoefficient {
        public T term;
        public float coefficient;

        public MirrorTermCoefficient(T term, float coefficient) {
            this.term = term;
            this.coefficient = coefficient;
        }
    }

    /**
     * Creates an empty store whose variable storage is sized from
     * {@code Options.MEMORY_VTS_DEFAULT_SIZE}.
     */
    public MemoryVariableTermStore() {
        ensureVariableCapacity(this.defaultSize);

        // A configured default size of zero allocates nothing above.
        // Keep the index map non-null so that size queries and iteration work on a fresh store;
        // the arrays are allocated lazily by createLocalVariable().
        if (this.variables == null) {
            this.variables = new HashMap<>();
        }

        this.modelPredicates = new HashSet<>();
        this.mirrorVariables = new HashMap<>();
        this.variablesExternallyUpdatedFlag = false;
    }

    /**
     * Returns the index assigned to the given variable.
     * The variable must already be known to this store; an unknown variable
     * results in a {@link NullPointerException} from unboxing the missing index.
     */
    @Override
    public int getVariableIndex(V variable) {
        return this.variables.get(variable).intValue();
    }

    /**
     * Returns the working value stored at the given variable index.
     */
    @Override
    public float getVariableValue(int index) {
        return this.variableValues[index];
    }

    /**
     * Returns the internal value array itself (not a copy).
     * The array may be longer than {@link #getNumVariables()}.
     */
    @Override
    public float[] getVariableValues() {
        return this.variableValues;
    }

    /**
     * Writes every working value into its backing atom.
     *
     * @return the square root of the summed squared differences between each
     *         atom's previous value and the working value written to it.
     */
    @Override
    public double syncAtoms() {
        double squaredMovement = 0.0d;

        int count = this.variables.size();
        for (int index = 0; index < count; index++) {
            RandomVariableAtom atom = this.variableAtoms[index];
            float value = this.variableValues[index];

            squaredMovement += Math.pow(atom.getValue() - value, 2.0d);
            atom.setValue(value);
        }

        return Math.sqrt(squaredMovement);
    }

    /**
     * Returns the internal atom array itself (not a copy).
     * The array may be longer than {@link #getNumVariables()}.
     */
    @Override
    public GroundAtom[] getVariableAtoms() {
        return this.variableAtoms;
    }

    /**
     * Returns the number of variables currently registered.
     */
    @Override
    public int getNumVariables() {
        return this.variables.size();
    }

    /**
     * Every variable in this store is a random variable,
     * so this is the same as {@link #getNumVariables()}.
     */
    @Override
    public int getNumRandomVariables() {
        return getNumVariables();
    }

    /**
     * Always zero: this store does not track observed atoms.
     */
    @Override
    public int getNumObservedVariables() {
        return 0;
    }

    /**
     * Always true for this store.
     */
    @Override
    public boolean isLoaded() {
        return true;
    }

    /**
     * Returns the local variable for the given atom, registering it if it has not been seen before.
     * A newly registered variable takes the next free index and is initialized with the atom's current value.
     *
     * @throws IllegalArgumentException if the atom is not a {@link RandomVariableAtom}.
     */
    @Override
    public synchronized V createLocalVariable(GroundAtom groundAtom) {
        if (!(groundAtom instanceof RandomVariableAtom)) {
            throw new IllegalArgumentException("MemoryVariableTermStores do not keep track of observed atoms (" + groundAtom + ").");
        }

        RandomVariableAtom atom = (RandomVariableAtom) groundAtom;
        V variable = convertAtomToVariable(atom);

        if (this.variables.containsKey(variable)) {
            return variable;
        }

        int index = this.variables.size();

        // clear() drops the backing arrays, and doubling an empty store requests zero capacity
        // (which ensureVariableCapacity() ignores).
        // In both cases fall back to the configured default size, with a floor of one slot.
        if (this.variableAtoms == null || index >= this.variableAtoms.length) {
            int requested = (index == 0) ? Math.max(this.defaultSize, 1) : index * 2;
            ensureVariableCapacity(requested);
        }

        this.variables.put(variable, Integer.valueOf(index));
        this.variableValues[index] = atom.getValue();
        this.variableAtoms[index] = atom;

        return variable;
    }

    /**
     * Records that the given atom appears in the given term with the given coefficient.
     * If the atom's predicate is a {@link ModelPredicate}, that predicate is remembered as well.
     */
    private synchronized void createMirrorVariable(RandomVariableAtom atom, float coefficient, T term) {
        if (atom.getPredicate() instanceof ModelPredicate) {
            this.modelPredicates.add((ModelPredicate) atom.getPredicate());
        }

        List<MirrorTermCoefficient> mirrorTerms = this.mirrorVariables.get(atom);
        if (mirrorTerms == null) {
            mirrorTerms = new ArrayList<>();
            this.mirrorVariables.put(atom, mirrorTerms);
        }

        mirrorTerms.add(new MirrorTermCoefficient(term, coefficient));
    }

    /**
     * Raises the externally-updated flag and forwards the notification to the underlying term store.
     */
    @Override
    public void variablesExternallyUpdated() {
        this.variablesExternallyUpdatedFlag = true;
        this.store.variablesExternallyUpdated();
    }

    /**
     * Returns whether {@link #variablesExternallyUpdated()} has been called
     * since the flag was last reset.
     */
    public boolean getVariablesExternallyUpdatedFlag() {
        return this.variablesExternallyUpdatedFlag;
    }

    /**
     * Lowers the externally-updated flag.
     */
    public void resetVariablesExternallyUpdatedFlag() {
        this.variablesExternallyUpdatedFlag = false;
    }

    /**
     * Makes sure the variable storage can hold at least the requested number of variables.
     *
     * A request of zero is a no-op.
     * If no variables are registered, the storage is (re)allocated at exactly the requested capacity.
     * Otherwise the storage only ever grows, and grows to at least double the current variable count.
     *
     * @throws IllegalArgumentException if the capacity is negative.
     */
    @Override
    public void ensureVariableCapacity(int capacity) {
        if (capacity < 0) {
            throw new IllegalArgumentException("Variable capacity must be non-negative. Got: " + capacity);
        }

        if (capacity == 0) {
            return;
        }

        if (this.variables == null || this.variables.size() == 0) {
            // Nothing to preserve, allocate fresh storage.
            // The map is pre-sized so that it holds the capacity at a 0.75 load factor.
            this.variables = new HashMap<>((int) Math.ceil(capacity / 0.75d));
            this.variableValues = new float[capacity];
            this.variableAtoms = new RandomVariableAtom[capacity];
            return;
        }

        int count = this.variables.size();
        if (count >= capacity) {
            return;
        }

        // Grow by at least a factor of two to amortize the copies.
        int newCapacity = Math.max(capacity, count * 2);

        Map<V, Integer> resized = new HashMap<>((int) Math.ceil(newCapacity / 0.75d));
        resized.putAll(this.variables);
        this.variables = resized;

        this.variableValues = Arrays.copyOf(this.variableValues, newCapacity);
        this.variableAtoms = Arrays.copyOf(this.variableAtoms, newCapacity);
    }

    /**
     * Returns a live view of the registered variables (the key set of the index map).
     */
    @Override
    public Iterable<V> getVariables() {
        return this.variables.keySet();
    }

    /**
     * Adds a term to the underlying store and registers a mirror entry
     * for every integrated random variable atom the hyperplane reports.
     */
    @Override
    public void add(GroundRule groundRule, T term, Hyperplane hyperplane) {
        this.store.add(groundRule, term, hyperplane);

        if (hyperplane.getIntegratedRVAs() == null) {
            return;
        }

        for (Hyperplane.IntegratedRVA integratedRVA : hyperplane.getIntegratedRVAs()) {
            createMirrorVariable(integratedRVA.atom, integratedRVA.coefficient, term);
        }
    }

    /**
     * Removes all terms, variables, model predicates, and mirror entries, and drops the value/atom arrays.
     * The externally-updated flag is left untouched.
     */
    @Override
    public void clear() {
        if (this.store != null) {
            this.store.clear();
        }

        if (this.variables != null) {
            this.variables.clear();
        }

        if (this.modelPredicates != null) {
            this.modelPredicates.clear();
        }

        if (this.mirrorVariables != null) {
            this.mirrorVariables.clear();
        }

        this.variableValues = null;
        this.variableAtoms = null;
    }

    /**
     * Overwrites every working value with the current value of its backing atom.
     */
    @Override
    public void reset() {
        int count = this.variables.size();
        for (int index = 0; index < count; index++) {
            this.variableValues[index] = this.variableAtoms[index].getValue();
        }
    }

    /**
     * Clears the store, closes and releases the underlying term store, and releases the variable index map.
     */
    @Override
    public void close() {
        clear();

        if (this.store != null) {
            this.store.close();
            this.store = null;
        }

        this.variables = null;
    }

    /**
     * Runs the initial fit of every known model predicate, then refreshes the model-backed atoms.
     */
    @Override
    public void initForOptimization() {
        initialFitModelAtoms();
        updateModelAtoms();
    }

    /**
     * Fits every known model predicate against the current variable values, then refreshes the model-backed atoms.
     */
    @Override
    public void iterationComplete() {
        fitModelAtoms();
        updateModelAtoms();
    }

    /**
     * Returns the atom stored at the given variable index.
     */
    public RandomVariableAtom getAtom(int index) {
        return this.variableAtoms[index];
    }

    /**
     * Runs every known model predicate and pushes its output into the mirror atoms it backs.
     *
     * For each mirror atom with a model predicate, the atom takes the predicate's value and
     * every term the atom was integrated into has its constant adjusted from the old
     * coefficient-scaled value to the new one.
     * The root mean squared difference between the new values and the predicate's labels is logged at trace level.
     * Does nothing when no model predicates are known.
     */
    private void updateModelAtoms() {
        if (this.modelPredicates.size() == 0) {
            return;
        }

        for (ModelPredicate predicate : this.modelPredicates) {
            predicate.runModel();
        }

        double rmse = 0.0d;
        int count = 0;

        for (Map.Entry<RandomVariableAtom, List<MirrorTermCoefficient>> entry : this.mirrorVariables.entrySet()) {
            RandomVariableAtom atom = entry.getKey();
            if (!(atom.getPredicate() instanceof ModelPredicate)) {
                continue;
            }

            ModelPredicate predicate = (ModelPredicate) atom.getPredicate();

            float oldValue = atom.getValue();
            float newValue = predicate.getValue(atom);
            atom.setValue(newValue);

            for (MirrorTermCoefficient mirrorTerm : entry.getValue()) {
                mirrorTerm.term.adjustConstant(mirrorTerm.coefficient * oldValue, mirrorTerm.coefficient * newValue);
            }

            rmse += Math.pow(newValue - predicate.getLabel(atom), 2.0d);
            count++;
        }

        if (count != 0) {
            rmse = Math.pow(rmse / count, 0.5d);
        }

        log.trace("Batch update of {} model atoms. RMSE: {}", Integer.valueOf(count), Double.valueOf(rmse));

        variablesExternallyUpdated();
    }

    /**
     * Invokes the initial fit on every known model predicate.
     */
    private void initialFitModelAtoms() {
        for (ModelPredicate predicate : this.modelPredicates) {
            predicate.initialFit();
        }
    }

    /**
     * Relabels and fits every known model predicate.
     *
     * Labels are reset, then each mirror atom with a model predicate is labeled with the
     * working value of the variable belonging to the atom's mirror, and finally each predicate is fit.
     * The mirror's variable must already be registered in this store.
     * Does nothing when no model predicates are known.
     */
    private void fitModelAtoms() {
        if (this.modelPredicates.size() == 0) {
            return;
        }

        for (ModelPredicate predicate : this.modelPredicates) {
            predicate.resetLabels();
        }

        int count = 0;
        for (RandomVariableAtom atom : this.mirrorVariables.keySet()) {
            if (!(atom.getPredicate() instanceof ModelPredicate)) {
                continue;
            }

            ModelPredicate predicate = (ModelPredicate) atom.getPredicate();

            V mirrorVariable = convertAtomToVariable(atom.getMirror());
            int mirrorIndex = this.variables.get(mirrorVariable).intValue();
            predicate.setLabel(atom, this.variableValues[mirrorIndex]);

            count++;
        }

        for (ModelPredicate predicate : this.modelPredicates) {
            predicate.fit();
        }

        log.trace("Batch fit of {} model atoms.", Integer.valueOf(count));
    }

    /**
     * Returns the term at the given position in the underlying store.
     */
    @Override
    public T get(long index) {
        return this.store.get(index);
    }

    /**
     * Returns the number of terms in the underlying store.
     */
    @Override
    public long size() {
        return this.store.size();
    }

    /**
     * Forwards the term capacity request to the underlying store.
     */
    @Override
    public void ensureCapacity(long capacity) {
        this.store.ensureCapacity(capacity);
    }

    /**
     * Returns an iterator over the stored terms.
     * When shuffling is enabled (via {@code Options.MEMORY_VTS_SHUFFLE}),
     * the underlying store is shuffled before every iteration.
     */
    @Override
    public Iterator<T> iterator() {
        if (this.shuffle) {
            this.store.shuffle();
        }

        return this.store.iterator();
    }

    /**
     * Same as {@link #iterator()}, including the optional shuffle.
     */
    @Override
    public Iterator<T> noWriteIterator() {
        return iterator();
    }

    /**
     * Builds the local variable that represents the given atom.
     * The result is used as a key in the index map, so implementations
     * presumably need to return equal variables for the same atom.
     */
    protected abstract V convertAtomToVariable(RandomVariableAtom randomVariableAtom);
}
