package org.linqs.psl.reasoner.dcd.term;

import java.nio.ByteBuffer;
import org.apache.commons.configuration2.tree.DefaultExpressionEngineSymbols;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.reasoner.term.Hyperplane;
import org.linqs.psl.reasoner.term.ReasonerTerm;
import org.linqs.psl.reasoner.term.VariableTermStore;
import org.linqs.psl.util.MathUtils;

/* loaded from: input_file:org/linqs/psl/reasoner/dcd/term/DCDObjectiveTerm.class */
public class DCDObjectiveTerm implements ReasonerTerm {
    // True if the hinge is squared: adjustedWeight * max(0, coefficients . x - constant)^2.
    private boolean squared;
    // Product of the two scalar arguments given to the constructor (the rule weight times a scale).
    private float adjustedWeight;
    // Constant of the hyperplane; the term's linear part is (coefficients . x - constant).
    private float constant;
    // Current value of this term's dual (Lagrange) variable; updated by minimize().
    private float lagrange;
    // Squared L2 norm of the first `size` coefficients; used as the step denominator in minimize().
    private float qii;
    // Number of meaningful entries in coefficients / variableIndexes.
    private short size;
    // Hyperplane coefficients; only the first `size` entries are meaningful (array may be longer).
    private float[] coefficients;
    // Indexes of this term's variables in the value arrays passed to evaluate()/minimize().
    private int[] variableIndexes;

    /**
     * Builds a term from a ground hyperplane.
     *
     * @param termStore store used to resolve each hyperplane atom to a variable index
     * @param squared whether the hinge is squared
     * @param hyperplane source of the coefficients, variables and constant; its coefficient array
     *     is stored by reference, not copied
     * @param weight the term's weight
     * @param scale multiplied with {@code weight} to form the adjusted weight
     * @throws IllegalArgumentException if the hyperplane has more than {@link Short#MAX_VALUE}
     *     variables (the size is stored and serialized as a {@code short})
     */
    public DCDObjectiveTerm(VariableTermStore<DCDObjectiveTerm, RandomVariableAtom> termStore,
            boolean squared, Hyperplane<RandomVariableAtom> hyperplane, float weight, float scale) {
        int hyperplaneSize = hyperplane.size();
        if (hyperplaneSize > Short.MAX_VALUE) {
            throw new IllegalArgumentException("Hyperplane has " + hyperplaneSize
                    + " variables, but at most " + Short.MAX_VALUE + " are supported.");
        }

        this.squared = squared;
        this.size = (short) hyperplaneSize;
        this.coefficients = hyperplane.getCoefficients();
        this.constant = hyperplane.getConstant();

        RandomVariableAtom[] variables = hyperplane.getVariables();
        this.variableIndexes = new int[this.size];
        float normSquared = 0.0f;
        for (int i = 0; i < this.size; i++) {
            this.variableIndexes[i] = termStore.getVariableIndex(variables[i]);
            normSquared += this.coefficients[i] * this.coefficients[i];
        }

        this.adjustedWeight = weight * scale;
        this.qii = normSquared;
        this.lagrange = 0.0f;
    }

    /** Returns the current value of this term's dual (Lagrange) variable. */
    public float getLagrange() {
        return this.lagrange;
    }

    /**
     * Evaluates the weighted (optionally squared) hinge of this term.
     *
     * @param variableValues values indexed by the term's variable indexes
     * @return {@code adjustedWeight * max(0, coefficients . x - constant)}, squared inside the
     *     weight when this term is squared
     */
    public float evaluate(float[] variableValues) {
        float dot = 0.0f;
        for (int i = 0; i < this.size; i++) {
            dot += this.coefficients[i] * variableValues[this.variableIndexes[i]];
        }

        float hinge = Math.max(0.0f, dot - this.constant);
        if (this.squared) {
            return this.adjustedWeight * (hinge * hinge);
        }
        return this.adjustedWeight * hinge;
    }

    /**
     * Performs one coordinate-descent step on this term's Lagrange variable and applies the
     * resulting change to {@code variableValues} in place.
     *
     * <p>Squared terms use an unbounded dual and add {@code lagrange / (2 * adjustedWeight)} to
     * the gradient; linear terms bound the dual by {@code adjustedWeight}.
     *
     * @param truthValues if true, updated variable values are clamped to [0, 1]
     * @param variableValues values indexed by the term's variable indexes; modified in place
     */
    public void minimize(boolean truthValues, float[] variableValues) {
        if (this.squared) {
            // A zero-weight squared term contributes nothing, and lagrange / (2 * 0) would be NaN
            // (or infinite), which would then propagate into every variable this term touches.
            if (this.adjustedWeight == 0.0f) {
                return;
            }
            float gradient = computeGradient(variableValues)
                    + (this.lagrange / (2.0f * this.adjustedWeight));
            minimize(truthValues, gradient, Float.POSITIVE_INFINITY, variableValues);
        } else {
            minimize(truthValues, computeGradient(variableValues), this.adjustedWeight, variableValues);
        }
    }

    @Override // org.linqs.psl.reasoner.term.ReasonerTerm
    public int size() {
        return this.size;
    }

    /** Returns {@code constant - coefficients . x} for the current variable values. */
    private float computeGradient(float[] variableValues) {
        float dot = 0.0f;
        for (int i = 0; i < this.size; i++) {
            dot += variableValues[this.variableIndexes[i]] * this.coefficients[i];
        }
        return this.constant - dot;
    }

    /**
     * Projected-gradient update of the Lagrange variable within [0, upperBound], followed by the
     * matching update of every variable in this term.
     */
    private void minimize(boolean truthValues, float gradient, float upperBound, float[] variableValues) {
        // With no nonzero coefficient the step gradient / qii is undefined (0 denominator) and the
        // term cannot move any variable; bail out instead of driving the dual to +-Infinity.
        if (this.qii == 0.0f) {
            return;
        }

        // Project the gradient onto the feasible directions at the current bound.
        float projectedGradient = gradient;
        if (MathUtils.isZero(this.lagrange)) {
            projectedGradient = Math.min(0.0f, gradient);
        }
        if (MathUtils.equals(upperBound, this.adjustedWeight)
                && MathUtils.equals(this.lagrange, this.adjustedWeight)) {
            projectedGradient = Math.max(0.0f, gradient);
        }

        // Already optimal along this coordinate.
        if (MathUtils.isZero(projectedGradient)) {
            return;
        }

        float previousLagrange = this.lagrange;
        this.lagrange = Math.min(upperBound, Math.max(0.0f, this.lagrange - (gradient / this.qii)));

        float delta = this.lagrange - previousLagrange;
        for (int i = 0; i < this.size; i++) {
            float value = variableValues[this.variableIndexes[i]] - (delta * this.coefficients[i]);
            if (truthValues) {
                value = Math.max(0.0f, Math.min(1.0f, value));
            }
            variableValues[this.variableIndexes[i]] = value;
        }
    }

    /**
     * Returns the number of bytes {@link #writeFixedValues} writes for this term: one byte for
     * the squared flag, three floats, one short, and a float + int pair per variable.
     */
    public int fixedByteSize() {
        return Byte.BYTES                                       // squared
                + Float.BYTES                                   // adjustedWeight
                + Float.BYTES                                   // constant
                + Float.BYTES                                   // qii
                + Short.BYTES                                   // size
                + this.size * (Float.BYTES + Integer.BYTES);    // coefficients + variableIndexes
    }

    /**
     * Writes every field except the Lagrange variable, in the order {@link #read} expects.
     */
    public void writeFixedValues(ByteBuffer byteBuffer) {
        byteBuffer.put((byte) (this.squared ? 1 : 0));
        byteBuffer.putFloat(this.adjustedWeight);
        byteBuffer.putFloat(this.constant);
        byteBuffer.putFloat(this.qii);
        byteBuffer.putShort(this.size);
        for (int i = 0; i < this.size; i++) {
            byteBuffer.putFloat(this.coefficients[i]);
            byteBuffer.putInt(this.variableIndexes[i]);
        }
    }

    /**
     * Overwrites this term with values read from buffers, reusing the internal arrays when they
     * are large enough.
     *
     * @param fixedBuffer buffer positioned at data laid out as by {@link #writeFixedValues}
     * @param lagrangeBuffer buffer positioned at this term's Lagrange value (one float)
     */
    public void read(ByteBuffer fixedBuffer, ByteBuffer lagrangeBuffer) {
        this.squared = fixedBuffer.get() == 1;
        this.adjustedWeight = fixedBuffer.getFloat();
        this.constant = fixedBuffer.getFloat();
        this.qii = fixedBuffer.getFloat();
        this.size = fixedBuffer.getShort();

        // The two arrays may have different lengths (coefficients can come from the hyperplane,
        // variableIndexes is sized exactly), so each must be checked on its own.
        if (this.coefficients == null || this.coefficients.length < this.size) {
            this.coefficients = new float[this.size];
        }
        if (this.variableIndexes == null || this.variableIndexes.length < this.size) {
            this.variableIndexes = new int[this.size];
        }

        for (int i = 0; i < this.size; i++) {
            this.coefficients[i] = fixedBuffer.getFloat();
            this.variableIndexes[i] = fixedBuffer.getInt();
        }
        this.lagrange = lagrangeBuffer.getFloat();
    }

    /** Renders the term as {@code w * max(0.0, (c * idx) + ... - constant)}, with {@code ^2} if squared. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(this.adjustedWeight);
        sb.append(" * max(0.0, ");
        for (int i = 0; i < this.size; i++) {
            sb.append("(");
            sb.append(this.coefficients[i]);
            sb.append(" * ");
            sb.append(this.variableIndexes[i]);
            sb.append(")");
            if (i != this.size - 1) {
                sb.append(" + ");
            }
        }
        sb.append(" - ");
        sb.append(this.constant);
        sb.append(")");
        if (this.squared) {
            sb.append("^2");
        }
        return sb.toString();
    }
}
