package org.linqs.psl.reasoner.term;

import org.linqs.psl.database.AtomStore;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.WeightedRule;
import org.linqs.psl.reasoner.function.FunctionComparator;
import org.linqs.psl.util.MathUtils;

/* loaded from: input_file:org/linqs/psl/reasoner/term/ReasonerTerm.class */
/**
 * A generic reasoner term built from a {@link Hyperplane}: either a linear constraint or one of four
 * loss forms (linear, hinge, squared linear, squared hinge) over atom values.
 *
 * <p>Every form is built on the term's inner potential,
 * {@code sum_i(coefficients[i] * values[atomIndexes[i]]) - constant} for {@code i < size()}.
 * The concrete {@link TermType} is decided once, at the end of construction.
 */
public class ReasonerTerm {
    /**
     * Absolute slack allowed when checking whether a constraint's inner potential satisfies its comparator.
     */
    private static final double CONSTRAINT_TOLERANCE = 0.005;

    /**
     * The kind of this term.
     * Assigned at the end of the constructor: a field initializer would run before the constructor body
     * and would always observe a null comparator and false squared/hinge flags.
     */
    public final TermType termType;

    /** Non-null only for constraint terms; the comparison of the inner potential against zero. */
    protected FunctionComparator comparator;
    /** The rule that produced this term; may be null. */
    protected Rule rule;
    /** Atom index of each variable in this term; only the first {@link #size} entries are used. */
    protected int[] atomIndexes;
    /** Number of variables in this term. */
    protected short size;
    /** Coefficient of each variable; only the first {@link #size} entries are read by this class. */
    protected float[] coefficients;
    /** Constant subtracted from the weighted sum when computing the inner potential. */
    protected float constant;
    /** Whether a (non-constraint) loss is squared. */
    protected boolean squared;
    /** Whether a (non-constraint) loss is hinged at zero, i.e. max(0, inner potential). */
    protected boolean hinge;

    /**
     * The concrete form of a term, selected from the comparator and the squared/hinge flags.
     */
    public enum TermType {
        LinearConstraintTerm,
        LinearLossTerm,
        HingeLossTerm,
        SquaredLinearLossTerm,
        SquaredHingeLossTerm
    }

    /**
     * Build a term from a hyperplane.
     *
     * @param hyperplane source of the coefficients, variables (atoms), and constant.
     * @param rule the rule that produced this term; may be null (then the term is treated as
     *     unweighted and always active).
     * @param squared whether the loss is squared; does not affect the type when {@code comparator} is non-null.
     * @param hinge whether the loss is hinged; does not affect the type when {@code comparator} is non-null.
     * @param functionComparator if non-null, this term becomes a linear constraint using this comparator.
     * @throws IllegalArgumentException if the hyperplane has more variables than fit in a {@code short}.
     */
    public ReasonerTerm(Hyperplane hyperplane, Rule rule, boolean squared, boolean hinge, FunctionComparator functionComparator) {
        this.rule = rule;
        this.comparator = functionComparator;
        this.squared = squared;
        this.hinge = hinge;

        int hyperplaneSize = hyperplane.size();
        // The size is stored as a short; a larger value would silently wrap.
        if (hyperplaneSize > Short.MAX_VALUE) {
            throw new IllegalArgumentException(
                    "Hyperplane has too many variables for a term: " + hyperplaneSize + " (max " + Short.MAX_VALUE + ").");
        }
        this.size = (short) hyperplaneSize;

        this.coefficients = hyperplane.getCoefficients();
        this.constant = hyperplane.getConstant();

        this.atomIndexes = new int[this.size];
        GroundAtom[] variables = hyperplane.getVariables();
        for (int i = 0; i < this.size; i++) {
            this.atomIndexes[i] = variables[i].getIndex();
        }

        // Must come last: the type depends on comparator, squared, and hinge being set.
        this.termType = getTermType();
    }

    /**
     * Decide the term type from the comparator and the squared/hinge flags.
     * A non-null comparator always wins; otherwise the two flags select one of the four loss forms.
     */
    private TermType getTermType() {
        if (this.comparator != null) {
            return TermType.LinearConstraintTerm;
        }

        if (this.squared) {
            return this.hinge ? TermType.SquaredHingeLossTerm : TermType.SquaredLinearLossTerm;
        }

        return this.hinge ? TermType.HingeLossTerm : TermType.LinearLossTerm;
    }

    /** @return the number of variables in this term. */
    public int size() {
        return this.size;
    }

    /** @return the rule that produced this term, possibly null. */
    public Rule getRule() {
        return this.rule;
    }

    /**
     * @return the rule's weight, or {@link Float#POSITIVE_INFINITY} if there is no rule or the rule is unweighted.
     */
    public float getWeight() {
        if (this.rule == null || !this.rule.isWeighted()) {
            return Float.POSITIVE_INFINITY;
        }

        return ((WeightedRule) this.rule).getWeight();
    }

    /** @return true if there is no rule, otherwise whatever the rule reports via {@code isActive()}. */
    public boolean isActive() {
        if (this.rule != null) {
            return this.rule.isActive();
        }

        return true;
    }

    /** @return the constant subtracted in the inner potential. */
    public float getConstant() {
        return this.constant;
    }

    /** @return the internal atom-index array (not a copy); only the first {@link #size()} entries are used. */
    public int[] getAtomIndexes() {
        return this.atomIndexes;
    }

    /** @return the internal coefficient array (not a copy); only the first {@link #size()} entries are used. */
    public float[] getCoefficients() {
        return this.coefficients;
    }

    /** @return true if this term is a linear constraint rather than a loss. */
    public boolean isConstraint() {
        return this.termType == TermType.LinearConstraintTerm;
    }

    /**
     * Evaluate this term's contribution to the objective.
     *
     * @param atomValues values of all atoms, indexed by atom index.
     * @return for a constraint, 0 if satisfied and {@link Float#POSITIVE_INFINITY} otherwise;
     *     for a loss, {@code getWeight() * evaluateIncompatibility(atomValues)}.
     */
    public float evaluate(float[] atomValues) {
        float incompatibility = evaluateIncompatibility(atomValues);

        if (isConstraint()) {
            return (incompatibility > 0.0f) ? Float.POSITIVE_INFINITY : 0.0f;
        }

        return getWeight() * incompatibility;
    }

    /**
     * Evaluate this term without its weight, dispatching on {@link #termType}.
     *
     * @param atomValues values of all atoms, indexed by atom index.
     * @return the unweighted incompatibility (0 or +Infinity for constraints).
     * @throws IllegalStateException if the term type is not recognized.
     */
    public float evaluateIncompatibility(float[] atomValues) {
        switch (this.termType) {
            case LinearConstraintTerm:
                return evaluateConstraint(atomValues);
            case LinearLossTerm:
                return evaluateLinearLoss(atomValues);
            case HingeLossTerm:
                return evaluateHingeLoss(atomValues);
            case SquaredLinearLossTerm:
                return evaluateSquaredLinearLoss(atomValues);
            case SquaredHingeLossTerm:
                return evaluateSquaredHingeLoss(atomValues);
            default:
                throw new IllegalStateException("Unknown term type.");
        }
    }

    /**
     * Check the constraint (inner potential compared against zero) with an absolute tolerance.
     *
     * @return 0 if the constraint holds within {@link #CONSTRAINT_TOLERANCE}, otherwise +Infinity.
     * @throws IllegalStateException if the comparator is not EQ, LTE, or GTE.
     */
    protected float evaluateConstraint(float[] atomValues) {
        double innerPotential = computeInnerPotential(atomValues);

        if (this.comparator.equals(FunctionComparator.EQ)) {
            return MathUtils.isZero(innerPotential, CONSTRAINT_TOLERANCE) ? 0.0f : Float.POSITIVE_INFINITY;
        }

        if (this.comparator.equals(FunctionComparator.LTE)) {
            return (innerPotential <= CONSTRAINT_TOLERANCE) ? 0.0f : Float.POSITIVE_INFINITY;
        }

        if (this.comparator.equals(FunctionComparator.GTE)) {
            return (innerPotential >= -CONSTRAINT_TOLERANCE) ? 0.0f : Float.POSITIVE_INFINITY;
        }

        throw new IllegalStateException("Unknown comparison function.");
    }

    /** @return the inner potential itself. */
    protected float evaluateLinearLoss(float[] atomValues) {
        return computeInnerPotential(atomValues);
    }

    /** @return {@code max(0, innerPotential)}. */
    protected float evaluateHingeLoss(float[] atomValues) {
        return Math.max(0.0f, computeInnerPotential(atomValues));
    }

    /** @return {@code innerPotential^2}. */
    protected float evaluateSquaredLinearLoss(float[] atomValues) {
        float innerPotential = computeInnerPotential(atomValues);
        return innerPotential * innerPotential;
    }

    /** @return {@code max(0, innerPotential)^2}. */
    public float evaluateSquaredHingeLoss(float[] atomValues) {
        float hinged = Math.max(0.0f, computeInnerPotential(atomValues));
        return hinged * hinged;
    }

    /**
     * Compute {@code sum_i(coefficients[i] * atomValues[atomIndexes[i]]) - constant} over this term's variables.
     *
     * @param atomValues values of all atoms, indexed by atom index.
     */
    public float computeInnerPotential(float[] atomValues) {
        float sum = 0.0f;
        for (int i = 0; i < this.size; i++) {
            sum += this.coefficients[i] * atomValues[this.atomIndexes[i]];
        }

        return sum - this.constant;
    }

    /**
     * Compute the weighted partial derivative of this term with respect to one of its variables.
     *
     * <p>NOTE(review): for a constraint coming from an unweighted rule, {@link #getWeight()} is +Infinity,
     * so the constraint case yields +/-Infinity (or NaN for a zero coefficient). Confirm callers expect this.
     *
     * @param index position of the variable within this term (0 <= index < size()), not a global atom index.
     * @param innerPotential the term's inner potential at the current point, as supplied by the caller.
     * @throws IllegalStateException if the term type is not recognized.
     */
    public float computeVariablePartial(int index, float innerPotential) {
        switch (this.termType) {
            case LinearConstraintTerm:
                return getWeight() * computeLinearConstraintPartial(index);
            case LinearLossTerm:
                return getWeight() * computeLinearLossPartial(index);
            case HingeLossTerm:
                return getWeight() * computeHingeLossPartial(index, innerPotential);
            case SquaredLinearLossTerm:
                return getWeight() * computeSquaredLinearLossPartial(index, innerPotential);
            case SquaredHingeLossTerm:
                return getWeight() * computeSquaredHingeLossPartial(index, innerPotential);
            default:
                throw new IllegalStateException("Unknown term type.");
        }
    }

    /** Unweighted partial of the constraint's inner potential: the variable's coefficient. */
    protected float computeLinearConstraintPartial(int index) {
        return this.coefficients[index];
    }

    /** Unweighted partial of a linear loss: the variable's coefficient. */
    protected float computeLinearLossPartial(int index) {
        return this.coefficients[index];
    }

    /** Unweighted partial of a hinge loss: 0 when the hinge is inactive, otherwise the coefficient. */
    protected float computeHingeLossPartial(int index, float innerPotential) {
        if (innerPotential <= 0.0f) {
            return 0.0f;
        }

        return this.coefficients[index];
    }

    /** Unweighted partial of a squared linear loss: {@code 2 * innerPotential * coefficient}. */
    protected float computeSquaredLinearLossPartial(int index, float innerPotential) {
        return 2.0f * innerPotential * this.coefficients[index];
    }

    /** Unweighted partial of a squared hinge loss: 0 when the hinge is inactive, else {@code 2 * innerPotential * coefficient}. */
    public float computeSquaredHingeLossPartial(int index, float innerPotential) {
        if (innerPotential <= 0.0f) {
            return 0.0f;
        }

        return 2.0f * innerPotential * this.coefficients[index];
    }

    /** Restore per-term state; a no-op here (presumably overridden by stateful subclasses — confirm). */
    public void loadState(TermState termState) {
    }

    /** @return a fresh, empty {@link TermState}; this base class keeps no per-term state. */
    public TermState saveState() {
        return new TermState();
    }

    /** Save per-term state into an existing object; a no-op here. */
    public void saveState(TermState termState) {
    }

    @Override
    public String toString() {
        return toString(null);
    }

    /**
     * Render this term as {@code weight * max(0.0, (c0 * x0) + ... - constant)} (or with a plain
     * parenthesis when not hinged), followed by {@code " ^2"} when squared.
     *
     * @param atomStore if non-null, atom values are printed; otherwise atom indexes are printed.
     */
    public String toString(AtomStore atomStore) {
        StringBuilder sb = new StringBuilder();
        sb.append(getWeight());
        sb.append(" * ");

        if (this.hinge) {
            sb.append("max(0.0, ");
        } else {
            sb.append("(");
        }

        for (int i = 0; i < this.size; i++) {
            sb.append("(");
            sb.append(this.coefficients[i]);

            if (atomStore == null) {
                sb.append(" * <index:");
                sb.append(this.atomIndexes[i]);
                sb.append(">)");
            } else {
                sb.append(" * ");
                sb.append(atomStore.getAtomValue(this.atomIndexes[i]));
                sb.append(")");
            }

            if (i != this.size - 1) {
                sb.append(" + ");
            }
        }

        sb.append(" - ");
        sb.append(this.constant);
        sb.append(")");

        if (this.squared) {
            sb.append(" ^2");
        }

        return sb.toString();
    }
}
