package org.linqs.psl.reasoner.sgd.term;

import java.nio.ByteBuffer;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.rule.AbstractRule;
import org.linqs.psl.model.rule.WeightedRule;
import org.linqs.psl.reasoner.term.Hyperplane;
import org.linqs.psl.reasoner.term.VariableTermStore;
import org.linqs.psl.reasoner.term.streaming.StreamingTerm;

/* loaded from: input_file:org/linqs/psl/reasoner/sgd/term/SGDObjectiveTerm.class */
/**
 * A single weighted objective term used by the SGD reasoner.
 *
 * <p>The term holds a linear expression {@code sum(coefficients[i] * x[variableIndexes[i]]) - constant}
 * (see {@link #dot(float[])}) and two flags that select how that expression is turned into a loss
 * (see {@link #evaluate(float[])}): linear, hinge, squared, or squared hinge.
 *
 * <p>Instances are mutable: {@link #read(ByteBuffer, ByteBuffer)} overwrites every field with values
 * previously written by {@link #writeFixedValues(ByteBuffer)}.
 */
public class SGDObjectiveTerm implements StreamingTerm {
    /** Bits written by {@link #writeFixedValues} before the per-variable entries: two flag bytes, rule hash, constant, size. */
    private static final int HEADER_BITS = Byte.SIZE + Byte.SIZE + Integer.SIZE + Float.SIZE + Short.SIZE;

    /** Bits written by {@link #writeFixedValues} per variable: one coefficient and one variable index. */
    private static final int ENTRY_BITS = Float.SIZE + Integer.SIZE;

    private boolean squared;
    private boolean hinge;
    private WeightedRule rule;
    private float constant;

    // Number of valid entries in coefficients/variableIndexes; both arrays may be longer than this.
    private short size;
    private float[] coefficients;
    private int[] variableIndexes;

    /**
     * Builds a term from a hyperplane.
     *
     * <p>The coefficient array is taken from the hyperplane by reference (it is not copied), so it
     * may be longer than {@code size()} and is shared with the hyperplane.
     *
     * @param variableTermStore store used to translate each hyperplane variable into an index
     * @param weightedRule the rule this term was created for; may be null (see {@link #getRule()})
     * @param squared whether the loss is squared
     * @param hinge whether the loss is clamped at zero from below before weighting/squaring
     * @param hyperplane source of the size, coefficients, constant and variables
     * @throws IllegalArgumentException if the hyperplane size does not fit in a {@code short},
     *     which is the width used to store and serialize the size
     */
    public SGDObjectiveTerm(VariableTermStore<SGDObjectiveTerm, GroundAtom> variableTermStore, WeightedRule weightedRule, boolean squared, boolean hinge, Hyperplane<GroundAtom> hyperplane) {
        int hyperplaneSize = hyperplane.size();
        if (hyperplaneSize < 0 || hyperplaneSize > Short.MAX_VALUE) {
            // A plain (short) cast would silently wrap and produce a corrupt or negative size.
            throw new IllegalArgumentException(
                    "Hyperplane size (" + hyperplaneSize + ") must be in [0, " + Short.MAX_VALUE + "].");
        }

        this.squared = squared;
        this.hinge = hinge;
        this.rule = weightedRule;
        this.size = (short) hyperplaneSize;
        this.coefficients = hyperplane.getCoefficients();
        this.constant = hyperplane.getConstant();

        this.variableIndexes = new int[this.size];
        GroundAtom[] variables = hyperplane.getVariables();
        for (int i = 0; i < this.size; i++) {
            this.variableIndexes[i] = variableTermStore.getVariableIndex(variables[i]);
        }
    }

    /**
     * @param i position within this term (not a store-wide index)
     * @return the variable index stored at position {@code i}
     */
    public int getVariableIndex(int i) {
        return this.variableIndexes[i];
    }

    /** @return the number of variables in this term. */
    @Override // org.linqs.psl.reasoner.term.ReasonerTerm
    public int size() {
        return this.size;
    }

    /**
     * Shifts the constant by subtracting {@code oldValue} and then adding {@code newValue}.
     *
     * @param oldValue amount removed from the constant
     * @param newValue amount added to the constant
     */
    @Override // org.linqs.psl.reasoner.term.ReasonerTerm
    public void adjustConstant(float oldValue, float newValue) {
        this.constant = (this.constant - oldValue) + newValue;
    }

    /** @return always {@code true}. */
    @Override // org.linqs.psl.reasoner.term.ReasonerTerm
    public boolean isConvex() {
        return true;
    }

    /**
     * Computes the weighted loss of this term for the given variable values.
     *
     * <ul>
     *   <li>squared and hinge: {@code weight * max(0, dot)^2}</li>
     *   <li>squared only: {@code weight * dot^2}</li>
     *   <li>hinge only: {@code weight * max(0, dot)}</li>
     *   <li>neither: {@code weight * dot}</li>
     * </ul>
     *
     * @param variableValues values indexed by the variable indexes held in this term
     * @return the weighted loss; the weight is positive infinity when there is no weighted rule
     */
    public float evaluate(float[] variableValues) {
        float dot = dot(variableValues);
        float weight = getWeight();

        if (this.squared) {
            float base = this.hinge ? Math.max(0.0f, dot) : dot;
            return weight * ((float) Math.pow(base, 2.0d));
        }

        if (this.hinge) {
            return weight * Math.max(0.0f, dot);
        }

        return weight * dot;
    }

    /**
     * Computes the partial derivative of the loss with respect to the variable at position {@code i}.
     *
     * @param i position within this term of the variable to differentiate against
     * @param dot value of the linear expression; presumably the result of {@link #dot(float[])}
     * @param weight multiplier applied to the derivative; presumably the rule weight
     * @return zero when this is a hinge term and {@code dot <= 0}; otherwise
     *     {@code weight * 2 * dot * coefficients[i]} for squared terms and
     *     {@code weight * coefficients[i]} for non-squared terms
     */
    public float computePartial(int i, float dot, float weight) {
        // An inactive hinge is flat, so it contributes no gradient.
        if (this.hinge && dot <= 0.0f) {
            return 0.0f;
        }

        if (this.squared) {
            return weight * 2.0f * dot * this.coefficients[i];
        }

        return weight * this.coefficients[i];
    }

    /**
     * Evaluates the linear expression of this term.
     *
     * @param variableValues values indexed by the variable indexes held in this term
     * @return {@code sum(coefficients[i] * variableValues[variableIndexes[i]]) - constant}
     */
    public float dot(float[] variableValues) {
        float sum = 0.0f;
        for (int i = 0; i < this.size; i++) {
            sum += this.coefficients[i] * variableValues[this.variableIndexes[i]];
        }

        return sum - this.constant;
    }

    /** @return the rule associated with this term (may be null). */
    public WeightedRule getRule() {
        return this.rule;
    }

    /**
     * @return the internal variable index array (not a copy); only the first {@link #size()}
     *     entries are meaningful and the array may be replaced by a later {@code read()}
     */
    public int[] getVariableIndexes() {
        return this.variableIndexes;
    }

    /**
     * @return the number of bytes {@link #writeFixedValues(ByteBuffer)} writes for this term:
     *     a fixed header plus one (coefficient, index) pair per variable
     */
    @Override // org.linqs.psl.reasoner.term.streaming.StreamingTerm
    public int fixedByteSize() {
        return (HEADER_BITS + (this.size * ENTRY_BITS)) / Byte.SIZE;
    }

    /**
     * Serializes this term into the buffer at its current position.
     *
     * <p>Layout: squared flag (byte), hinge flag (byte), rule hash code (int), constant (float),
     * size (short), then {@code size} pairs of coefficient (float) and variable index (int).
     *
     * @param fixedBuffer destination buffer
     */
    @Override // org.linqs.psl.reasoner.term.streaming.StreamingTerm
    public void writeFixedValues(ByteBuffer fixedBuffer) {
        fixedBuffer.put((byte) (this.squared ? 1 : 0));
        fixedBuffer.put((byte) (this.hinge ? 1 : 0));
        // The rule is stored by hash code and looked up again through AbstractRule.getRule() in read().
        fixedBuffer.putInt(this.rule.hashCode());
        fixedBuffer.putFloat(this.constant);
        fixedBuffer.putShort(this.size);

        for (int i = 0; i < this.size; i++) {
            fixedBuffer.putFloat(this.coefficients[i]);
            fixedBuffer.putInt(this.variableIndexes[i]);
        }
    }

    /**
     * Overwrites this term with values read from {@code fixedBuffer}, using the layout produced by
     * {@link #writeFixedValues(ByteBuffer)}. Existing arrays are reused when they are large enough.
     *
     * @param fixedBuffer buffer positioned at the start of a serialized term
     * @param unusedBuffer not read by this implementation
     */
    @Override // org.linqs.psl.reasoner.term.streaming.StreamingTerm
    public void read(ByteBuffer fixedBuffer, ByteBuffer unusedBuffer) {
        this.squared = fixedBuffer.get() == 1;
        this.hinge = fixedBuffer.get() == 1;
        this.rule = (WeightedRule) AbstractRule.getRule(fixedBuffer.getInt());
        this.constant = fixedBuffer.getFloat();
        this.size = fixedBuffer.getShort();

        // The two arrays can have different lengths (the constructor takes the coefficient array
        // from the hyperplane but allocates the index array at exactly size), so each one must be
        // checked on its own before being filled.
        if (this.coefficients.length < this.size) {
            this.coefficients = new float[this.size];
        }

        if (this.variableIndexes.length < this.size) {
            this.variableIndexes = new int[this.size];
        }

        for (int i = 0; i < this.size; i++) {
            this.coefficients[i] = fixedBuffer.getFloat();
            this.variableIndexes[i] = fixedBuffer.getInt();
        }
    }

    /** @return the same text as {@link #toString(VariableTermStore)} called with {@code null}. */
    @Override
    public String toString() {
        return toString(null);
    }

    /**
     * Renders this term as a human-readable expression.
     *
     * @param variableTermStore if non-null, each variable is rendered as the value the store returns
     *     for its index; if null, each variable is rendered as {@code <index:N>}
     * @return a string of the form {@code weight * max(0.0, (c * v) + ... - constant) ^2}, where the
     *     {@code max(0.0, } prefix appears only for hinge terms (otherwise a plain parenthesis is
     *     used) and the {@code ^2} suffix appears only for squared terms
     */
    public String toString(VariableTermStore<SGDObjectiveTerm, GroundAtom> variableTermStore) {
        StringBuilder builder = new StringBuilder();

        builder.append(getWeight());
        builder.append(" * ");
        builder.append(this.hinge ? "max(0.0, " : "(");

        for (int i = 0; i < this.size; i++) {
            builder.append("(");
            builder.append(this.coefficients[i]);

            if (variableTermStore == null) {
                builder.append(" * <index:");
                builder.append(this.variableIndexes[i]);
                builder.append(">)");
            } else {
                builder.append(" * ");
                builder.append(variableTermStore.getVariableValue(this.variableIndexes[i]));
                builder.append(")");
            }

            if (i != this.size - 1) {
                builder.append(" + ");
            }
        }

        builder.append(" - ");
        builder.append(this.constant);
        builder.append(")");

        if (this.squared) {
            builder.append(" ^2");
        }

        return builder.toString();
    }

    /**
     * @return the rule's weight, or positive infinity when there is no rule or the rule reports
     *     that it is not weighted
     */
    private float getWeight() {
        if (this.rule == null || !this.rule.isWeighted()) {
            return Float.POSITIVE_INFINITY;
        }

        return this.rule.getWeight();
    }
}
