package org.linqs.psl.reasoner.term;

import java.util.Collection;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.atom.ObservedAtom;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.rule.GroundRule;
import org.linqs.psl.model.rule.UnweightedGroundRule;
import org.linqs.psl.model.rule.WeightedGroundRule;
import org.linqs.psl.reasoner.function.ConstraintTerm;
import org.linqs.psl.reasoner.function.FunctionComparator;
import org.linqs.psl.reasoner.function.FunctionTerm;
import org.linqs.psl.reasoner.function.GeneralFunction;
import org.linqs.psl.reasoner.term.ReasonerTerm;
import org.linqs.psl.util.Logger;
import org.linqs.psl.util.MathUtils;

/**
 * Converts ground rules into reasoner terms of type {@code T}.
 *
 * <p>Weighted ground rules are turned into loss terms via {@link #createLossTerm}; unweighted ground rules
 * are turned into linear constraint terms via {@link #createLinearConstraintTerm}. In both cases the rule's
 * function is first flattened into a {@link Hyperplane} (variables, coefficients and a constant).
 */
public abstract class TermGenerator<T extends ReasonerTerm> {
    // NOTE(review): not referenced anywhere in this class; kept as-is.
    private static final Logger log = Logger.getLogger(TermGenerator.class);

    /**
     * Forwarded to {@code getFunctionDefinition}/{@code getConstraintDefinition}. When false, observed atoms
     * are also kept as hyperplane variables instead of being folded into the constant.
     */
    protected boolean mergeConstants;

    public TermGenerator(boolean mergeConstants) {
        this.mergeConstants = mergeConstants;
    }

    public void setMergeConstants(boolean mergeConstants) {
        this.mergeConstants = mergeConstants;
    }

    public boolean getMergeConstants() {
        return mergeConstants;
    }

    /**
     * Builds the reasoner term(s) for a single ground rule.
     *
     * @param groundRule the rule to convert; must be a {@link WeightedGroundRule} or an {@link UnweightedGroundRule}
     * @param terms destination passed through to the abstract term factories
     * @param hyperplanes optional sink (may be null); when present, the rule's hyperplane is added once per
     *     term reported created, so it stays aligned with {@code terms}
     * @return the count reported by {@link #createLossTerm} / {@link #createLinearConstraintTerm},
     *     or 0 if the rule's function reduced to no hyperplane
     * @throws IllegalArgumentException if the ground rule is of any other type
     */
    public int createTerm(GroundRule groundRule, Collection<T> terms, Collection<Hyperplane> hyperplanes) {
        final Hyperplane hyperplane;
        final int termCount;

        if (groundRule instanceof WeightedGroundRule) {
            GeneralFunction function = ((WeightedGroundRule) groundRule).getFunctionDefinition(mergeConstants);
            hyperplane = processHyperplane(function);
            if (hyperplane == null) {
                return 0;
            }
            termCount = createLossTerm(terms, function.isNonNegative(), function.isSquared(), groundRule, hyperplane);
        } else if (groundRule instanceof UnweightedGroundRule) {
            ConstraintTerm constraint = ((UnweightedGroundRule) groundRule).getConstraintDefinition(mergeConstants);
            hyperplane = processHyperplane(constraint.getFunction());
            if (hyperplane == null) {
                return 0;
            }
            // Shift the hyperplane constant by the constraint's value.
            hyperplane.setConstant(constraint.getValue() + hyperplane.getConstant());
            termCount = createLinearConstraintTerm(terms, groundRule, hyperplane, constraint.getComparator());
        } else {
            throw new IllegalArgumentException("Unsupported ground rule: " + groundRule);
        }

        if (hyperplanes != null) {
            int added = 0;
            while (added < termCount) {
                hyperplanes.add(hyperplane);
                added++;
            }
        }
        return termCount;
    }

    /**
     * Flattens a function into a hyperplane.
     *
     * <p>The hyperplane starts with constant {@code -function.getConstant()}. Each summand is handled as follows:
     * random-variable atoms (and observed atoms when {@link #mergeConstants} is false) become variables, with
     * repeated atoms having their coefficients combined; constant summands subtract
     * {@code coefficient * value} from the hyperplane constant.
     *
     * @return the hyperplane, or null when it ends up with no variables, or when the function is non-negative
     *     and the same atom appears with coefficients whose signs do not match
     * @throws IllegalArgumentException if a summand is neither a variable atom nor constant
     */
    private Hyperplane processHyperplane(GeneralFunction function) {
        Hyperplane hyperplane = new Hyperplane(function.size(), -1.0f * function.getConstant());

        for (int index = 0; index < function.size(); index++) {
            float coefficient = function.getCoefficient(index);
            FunctionTerm term = function.getTerm(index);

            boolean isVariable = (term instanceof RandomVariableAtom)
                    || (!mergeConstants && (term instanceof ObservedAtom));

            if (!isVariable) {
                if (!term.isConstant()) {
                    throw new IllegalArgumentException("Unexpected summand: " + function + "[" + index + "] (" + term + ").");
                }
                // Move the constant summand to the other side of the hyperplane equation.
                hyperplane.setConstant(hyperplane.getConstant() - (coefficient * term.getValue()));
                continue;
            }

            GroundAtom atom = (GroundAtom) term;
            int existingIndex = hyperplane.indexOfVariable(atom);
            if (existingIndex == -1) {
                hyperplane.addTerm(atom, coefficient);
                continue;
            }

            // Repeated atom: for non-negative functions, mismatched signs abandon the whole hyperplane.
            if (function.isNonNegative() && !MathUtils.signsMatch(hyperplane.getCoefficient(existingIndex), coefficient)) {
                return null;
            }
            hyperplane.appendCoefficient(existingIndex, coefficient);
        }

        // A function with no variables leaves nothing to build a term from.
        return (hyperplane.size() == 0) ? null : hyperplane;
    }

    /**
     * Creates loss term(s) for a weighted ground rule.
     *
     * @param terms destination collection for the created term(s)
     * @param nonNegative the value of {@code GeneralFunction.isNonNegative()} for the rule's function
     * @param squared the value of {@code GeneralFunction.isSquared()} for the rule's function
     * @param groundRule the source ground rule
     * @param hyperplane the flattened function
     * @return the number of terms created (presumably the number added to {@code terms} — confirm in subclasses)
     */
    public abstract int createLossTerm(Collection<T> terms, boolean nonNegative, boolean squared, GroundRule groundRule, Hyperplane hyperplane);

    /**
     * Creates linear constraint term(s) for an unweighted ground rule.
     *
     * @param terms destination collection for the created term(s)
     * @param groundRule the source ground rule
     * @param hyperplane the flattened constraint function, constant already shifted by the constraint value
     * @param comparator the constraint's comparator
     * @return the number of terms created (presumably the number added to {@code terms} — confirm in subclasses)
     */
    public abstract int createLinearConstraintTerm(Collection<T> terms, GroundRule groundRule, Hyperplane hyperplane, FunctionComparator comparator);
}
