package org.linqs.psl.reasoner.admm.term;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.reasoner.function.FunctionComparator;
import org.linqs.psl.reasoner.term.Hyperplane;
import org.linqs.psl.reasoner.term.ReasonerTerm;
import org.linqs.psl.reasoner.term.TermState;
import org.linqs.psl.util.FloatMatrix;
import org.linqs.psl.util.HashCode;

/**
 * An objective term for the ADMM reasoner.
 *
 * <p>Each term keeps its own local copy of the values of the variables it touches, plus one
 * Lagrange multiplier per variable, and minimizes its local subproblem ({@link #minimize}) against
 * the current consensus values. Which subproblem is solved depends on the term type: linear
 * constraint, linear loss, hinge loss, squared linear loss, or squared hinge loss.
 */
public class ADMMObjectiveTerm extends ReasonerTerm {
    /** Local (per-term) copies of the variable values. */
    private final float[] variableValues;

    /** One Lagrange multiplier per local variable. */
    private final float[] variableLagranges;

    /**
     * Scratch buffer used by {@link #project}. Only allocated for hinge-loss and constraint terms
     * with more than one variable; {@code null} otherwise.
     */
    private float[] consensusOptimizer;

    /**
     * The coefficient vector scaled to unit length. Allocated under the same conditions as
     * {@link #consensusOptimizer}; {@code null} otherwise.
     */
    private float[] unitNormal;

    /**
     * Euclidean norm of the coefficient vector (set alongside {@link #unitNormal}). Stored
     * explicitly so the projection never has to recover it as
     * {@code coefficients[0] / unitNormal[0]}, which is 0/0 (NaN) when the first coefficient is 0.
     */
    private float coefficientNorm;

    /**
     * Cholesky factors shared by all terms, keyed by the complete (step size, weight, coefficients)
     * configuration so that hash collisions can never hand back the wrong matrix.
     *
     * <p>Cache writes were already guarded against concurrent use, but an instance-level lock
     * cannot protect a static map shared by every term, so a concurrent map is used instead.
     * NOTE(review): entries are never evicted.
     */
    private static final Map<LowerTriangleKey, FloatMatrix> lowerTriangleCache =
            new ConcurrentHashMap<>();

    /** A snapshot of a term's local variable values and Lagrange multipliers. */
    public static final class ADMMObjectiveTermState extends TermState {
        public float[] variableValues;
        public float[] variableLagranges;

        /** Creates a snapshot holding independent copies of the given arrays. */
        public ADMMObjectiveTermState(float[] variableValues, float[] variableLagranges) {
            this.variableValues = Arrays.copyOf(variableValues, variableValues.length);
            this.variableLagranges = Arrays.copyOf(variableLagranges, variableLagranges.length);
        }
    }

    /**
     * Cache key for {@link #lowerTriangleCache}. It holds a private copy of the first {@code size}
     * coefficients, so later changes to the term's arrays cannot corrupt the cache.
     */
    private static final class LowerTriangleKey {
        private final float stepSize;
        private final float weight;
        private final float[] coefficients;
        private final int hash;

        LowerTriangleKey(float stepSize, float weight, float[] coefficients, int size) {
            this.stepSize = stepSize;
            this.weight = weight;
            this.coefficients = Arrays.copyOf(coefficients, size);
            this.hash = 31 * (31 * Float.hashCode(stepSize) + Float.hashCode(weight))
                    + Arrays.hashCode(this.coefficients);
        }

        @Override
        public boolean equals(Object other) {
            if (this == other) {
                return true;
            }
            if (!(other instanceof LowerTriangleKey)) {
                return false;
            }
            LowerTriangleKey that = (LowerTriangleKey) other;
            // Float.compare / Arrays.equals use bit-level float equality, consistent with hashCode.
            return Float.compare(stepSize, that.stepSize) == 0
                    && Float.compare(weight, that.weight) == 0
                    && Arrays.equals(coefficients, that.coefficients);
        }

        @Override
        public int hashCode() {
            return hash;
        }
    }

    /**
     * Builds a term over the variables of the given hyperplane. Local variable values start at the
     * atoms' current values and all Lagrange multipliers start at zero.
     *
     * @param hyperplane supplies the term's variables (and, via {@link ReasonerTerm}, its
     *     coefficients and constant)
     * @param rule the rule associated with this term, forwarded to {@link ReasonerTerm}
     * @param squared forwarded to {@link ReasonerTerm}; {@code true} for the squared factories
     * @param hinge forwarded to {@link ReasonerTerm}; {@code true} for the hinge factories
     * @param comparator the comparator used by constraint terms; {@code null} for loss terms
     */
    public ADMMObjectiveTerm(Hyperplane hyperplane, Rule rule, boolean squared, boolean hinge,
            FunctionComparator comparator) {
        super(hyperplane, rule, squared, hinge, comparator);

        variableValues = new float[size];
        // A freshly allocated float[] is zero-filled, so all multipliers start at 0.
        variableLagranges = new float[size];

        GroundAtom[] variables = hyperplane.getVariables();
        for (int i = 0; i < size; i++) {
            variableValues[i] = variables[i].getValue();
        }

        // Only hinge-loss and constraint terms can reach project(), which needs the unit normal.
        if (termType == ReasonerTerm.TermType.HingeLossTerm
                || termType == ReasonerTerm.TermType.LinearConstraintTerm) {
            initUnitNormal();
        }
    }

    /** Creates a linear constraint term using the given comparator. */
    public static ADMMObjectiveTerm createLinearConstraintTerm(Hyperplane hyperplane, Rule rule,
            FunctionComparator comparator) {
        return new ADMMObjectiveTerm(hyperplane, rule, false, false, comparator);
    }

    /** Creates a linear loss term. */
    public static ADMMObjectiveTerm createLinearLossTerm(Hyperplane hyperplane, Rule rule) {
        return new ADMMObjectiveTerm(hyperplane, rule, false, false, null);
    }

    /** Creates a hinge loss term. */
    public static ADMMObjectiveTerm createHingeLossTerm(Hyperplane hyperplane, Rule rule) {
        return new ADMMObjectiveTerm(hyperplane, rule, false, true, null);
    }

    /** Creates a squared linear loss term. */
    public static ADMMObjectiveTerm createSquaredLinearLossTerm(Hyperplane hyperplane, Rule rule) {
        return new ADMMObjectiveTerm(hyperplane, rule, true, false, null);
    }

    /** Creates a squared hinge loss term. */
    public static ADMMObjectiveTerm createSquaredHingeLossTerm(Hyperplane hyperplane, Rule rule) {
        return new ADMMObjectiveTerm(hyperplane, rule, true, true, null);
    }

    /**
     * Updates every Lagrange multiplier by
     * {@code y[i] += stepSize * (x[i] - consensusValues[atomIndexes[i]])}.
     *
     * @param stepSize the ADMM step size
     * @param consensusValues global consensus values, indexed through this term's atom indexes
     */
    public void updateLagrange(float stepSize, float[] consensusValues) {
        for (int i = 0; i < size; i++) {
            variableLagranges[i] += stepSize * (variableValues[i] - consensusValues[atomIndexes[i]]);
        }
    }

    /** Overwrites the local value and Lagrange multiplier of local variable {@code index}. */
    public void setLocalValue(short index, float value, float lagrange) {
        variableValues[index] = value;
        variableLagranges[index] = lagrange;
    }

    /** Returns the local value of local variable {@code index}. */
    public float getVariableValue(short index) {
        return variableValues[index];
    }

    /** Returns the Lagrange multiplier of local variable {@code index}. */
    public float getVariableLagrange(short index) {
        return variableLagranges[index];
    }

    /**
     * Minimizes this term's local subproblem and writes the result into the local variable values.
     *
     * @param stepSize the ADMM step size; Lagrange multipliers are scaled by its inverse
     * @param consensusValues global consensus values, indexed through this term's atom indexes
     * @throws IllegalStateException if the term type is not one of the five supported types
     */
    public void minimize(float stepSize, float[] consensusValues) {
        float weight = getWeight();

        switch (termType) {
            case LinearConstraintTerm:
                minimizeConstraint(stepSize, consensusValues);
                return;
            case LinearLossTerm:
                minimizeLinearLoss(stepSize, weight, consensusValues);
                return;
            case HingeLossTerm:
                minimizeHingeLoss(stepSize, weight, consensusValues);
                return;
            case SquaredLinearLossTerm:
                minimizeSquaredLinearLoss(stepSize, weight, consensusValues);
                return;
            case SquaredHingeLossTerm:
                minimizeSquaredHingeLoss(stepSize, weight, consensusValues);
                return;
            default:
                throw new IllegalStateException("Unknown term type.");
        }
    }

    /**
     * Sets every local variable to {@code consensusValues[atomIndexes[i]] - y[i] / stepSize} and
     * returns the dot product of the coefficients with the new local values.
     */
    private float stepToConsensus(float stepSize, float[] consensusValues) {
        float total = 0.0f;
        for (int i = 0; i < size; i++) {
            variableValues[i] = consensusValues[atomIndexes[i]] - variableLagranges[i] / stepSize;
            total += coefficients[i] * variableValues[i];
        }
        return total;
    }

    /**
     * Constraint term. For an inequality, keep the plain consensus step if it already satisfies
     * the constraint. Otherwise (and always for an equality), project onto the hyperplane.
     */
    private void minimizeConstraint(float stepSize, float[] consensusValues) {
        if (!comparator.equals(FunctionComparator.EQ)) {
            float total = stepToConsensus(stepSize, consensusValues);

            if (comparator.equals(FunctionComparator.LTE) && total <= constant) {
                return;
            }
            if (comparator.equals(FunctionComparator.GTE) && total >= constant) {
                return;
            }
        }

        project(stepSize, consensusValues);
    }

    /** Linear loss term: the consensus step shifted by {@code weight * coefficient / stepSize}. */
    private void minimizeLinearLoss(float stepSize, float weight, float[] consensusValues) {
        for (int i = 0; i < size; i++) {
            variableValues[i] = (consensusValues[atomIndexes[i]] - variableLagranges[i] / stepSize)
                    - weight * coefficients[i] / stepSize;
        }
    }

    /**
     * Hinge loss term. The solution is checked against three regions, in this order:
     * the flat region (plain consensus step lands at or below the constant), the linear region
     * (weighted step lands at or above the constant), and otherwise the hinge itself, which is
     * found by projecting onto the hyperplane.
     */
    private void minimizeHingeLoss(float stepSize, float weight, float[] consensusValues) {
        // Flat region.
        if (stepToConsensus(stepSize, consensusValues) <= constant) {
            return;
        }

        // Linear region.
        float total = 0.0f;
        for (int i = 0; i < size; i++) {
            variableValues[i] = (consensusValues[atomIndexes[i]] - variableLagranges[i] / stepSize)
                    - weight * coefficients[i] / stepSize;
            total += coefficients[i] * variableValues[i];
        }
        if (total >= constant) {
            return;
        }

        // On the hinge.
        project(stepSize, consensusValues);
    }

    /** Squared linear loss term: always solves the weighted quadratic system. */
    private void minimizeSquaredLinearLoss(float stepSize, float weight, float[] consensusValues) {
        minWeightedSquaredHyperplane(stepSize, weight, consensusValues);
    }

    /**
     * Squared hinge loss term. Keeps the plain consensus step when it lands at or below the
     * constant; otherwise solves the weighted quadratic system.
     */
    private void minimizeSquaredHingeLoss(float stepSize, float weight, float[] consensusValues) {
        if (stepToConsensus(stepSize, consensusValues) <= constant) {
            return;
        }

        minWeightedSquaredHyperplane(stepSize, weight, consensusValues);
    }

    /**
     * Allocates the projection buffers and precomputes the coefficient norm and unit normal.
     * Single-variable terms need neither, because {@link #project} solves them directly.
     */
    private void initUnitNormal() {
        if (size == 1) {
            consensusOptimizer = null;
            unitNormal = null;
            return;
        }

        consensusOptimizer = new float[size];
        unitNormal = new float[size];

        float sumOfSquares = 0.0f;
        for (int i = 0; i < size; i++) {
            sumOfSquares += coefficients[i] * coefficients[i];
        }
        coefficientNorm = (float) Math.sqrt(sumOfSquares);

        for (int i = 0; i < size; i++) {
            unitNormal[i] = coefficients[i] / coefficientNorm;
        }
    }

    /**
     * Projects the point {@code consensusValues[atomIndexes[i]] - y[i] / stepSize} onto the
     * hyperplane {@code coefficients . x = constant} and stores the result in the local values.
     * A single-variable term has exactly one point on the hyperplane, so that point is used
     * directly.
     */
    private void project(float stepSize, float[] consensusValues) {
        if (size == 1) {
            variableValues[0] = constant / coefficients[0];
            return;
        }

        for (int i = 0; i < size; i++) {
            consensusOptimizer[i] = consensusValues[atomIndexes[i]] - variableLagranges[i] / stepSize;
        }

        // Signed distance from the point to the hyperplane, measured along the unit normal.
        // The norm comes from the stored field; recovering it from the first coefficient would be
        // 0/0 when that coefficient is zero.
        float distance = -constant / coefficientNorm;
        for (int i = 0; i < size; i++) {
            distance += consensusOptimizer[i] * unitNormal[i];
        }

        for (int i = 0; i < size; i++) {
            variableValues[i] = consensusOptimizer[i] - distance * unitNormal[i];
        }
    }

    /**
     * Solves the linear system
     * {@code (stepSize * I + 2 * weight * c c^T) x = stepSize * z - y + 2 * weight * constant * c}
     * for the local values x. Here c is the coefficient vector, z the consensus values seen
     * through the atom indexes, and y the Lagrange multipliers.
     *
     * <p>Sizes 1 and 2 are solved in closed form. Larger systems use a cached lower-triangular
     * factor with forward and then back substitution.
     */
    private void minWeightedSquaredHyperplane(float stepSize, float weight, float[] consensusValues) {
        // Right-hand side of the system, built in place.
        for (int i = 0; i < size; i++) {
            variableValues[i] = (stepSize * consensusValues[atomIndexes[i]] - variableLagranges[i])
                    + 2.0f * weight * coefficients[i] * constant;
        }

        if (size == 1) {
            variableValues[0] /= 2.0f * weight * coefficients[0] * coefficients[0] + stepSize;
            return;
        }

        if (size == 2) {
            float rhs0 = variableValues[0];
            float rhs1 = variableValues[1];
            float c0 = coefficients[0];
            float c1 = coefficients[1];

            float a00 = 2.0f * weight * c0 * c0 + stepSize;
            float a11 = 2.0f * weight * c1 * c1 + stepSize;
            float a01 = 2.0f * weight * c0 * c1;

            // Eliminate x0 from the second equation, then back-substitute for x0.
            float x1 = (rhs1 - a01 * rhs0 / a00) / (a11 - a01 * a01 / a00);
            variableValues[0] = (rhs0 - a01 * x1) / a00;
            variableValues[1] = x1;
            return;
        }

        FloatMatrix lowerTriangle = fetchLowerTriangle(stepSize, weight);

        // Forward substitution: L w = b.
        for (int i = 0; i < size; i++) {
            float value = variableValues[i];
            for (int j = 0; j < i; j++) {
                value -= lowerTriangle.get(i, j) * variableValues[j];
            }
            variableValues[i] = value / lowerTriangle.get(i, i);
        }

        // Back substitution: L^T x = w, reading L^T(i, j) as L(j, i).
        for (int i = size - 1; i >= 0; i--) {
            float value = variableValues[i];
            for (int j = size - 1; j > i; j--) {
                value -= lowerTriangle.get(j, i) * variableValues[j];
            }
            variableValues[i] = value / lowerTriangle.get(i, i);
        }
    }

    /**
     * Returns the lower-triangular factor for this term's current configuration, computing and
     * caching it on first use. Safe to call concurrently; if two threads race, both receive the
     * instance that made it into the cache.
     */
    private FloatMatrix fetchLowerTriangle(float stepSize, float weight) {
        LowerTriangleKey key = new LowerTriangleKey(stepSize, weight, coefficients, size);

        FloatMatrix lowerTriangle = lowerTriangleCache.get(key);
        if (lowerTriangle != null) {
            return lowerTriangle;
        }

        // Compute outside of any lock; putIfAbsent decides which instance wins.
        FloatMatrix computed = computeLowerTriangle(stepSize, weight);
        FloatMatrix existing = lowerTriangleCache.putIfAbsent(key, computed);
        return (existing != null) ? existing : computed;
    }

    /**
     * Builds {@code stepSize * I + 2 * weight * c c^T} and factors it with
     * {@code choleskyDecomposition(true)}. Reads only this term's coefficients and size, so it
     * needs no synchronization.
     * NOTE(review): assumes the call leaves the lower factor L (A = L L^T) in the matrix, which is
     * how {@link #minWeightedSquaredHyperplane} reads it — TODO confirm against FloatMatrix.
     */
    private FloatMatrix computeLowerTriangle(float stepSize, float weight) {
        FloatMatrix matrix = FloatMatrix.zeroes(size, size);

        for (int row = 0; row < size; row++) {
            matrix.set(row, row, 2.0f * weight * coefficients[row] * coefficients[row] + stepSize);
            for (int col = row + 1; col < size; col++) {
                float value = 2.0f * weight * coefficients[row] * coefficients[col];
                matrix.set(row, col, value);
                matrix.set(col, row, value);
            }
        }

        matrix.choleskyDecomposition(true);
        return matrix;
    }

    /** Restores the local values and multipliers from a snapshot made by {@link #saveState()}. */
    @Override
    public void loadState(TermState termState) {
        assert termState instanceof ADMMObjectiveTermState;
        ADMMObjectiveTermState state = (ADMMObjectiveTermState) termState;

        System.arraycopy(state.variableValues, 0, variableValues, 0, variableValues.length);
        System.arraycopy(state.variableLagranges, 0, variableLagranges, 0, variableLagranges.length);
    }

    /** Returns a new snapshot holding copies of the local values and multipliers. */
    @Override
    public TermState saveState() {
        return new ADMMObjectiveTermState(variableValues, variableLagranges);
    }

    /** Copies the local values and multipliers into an existing snapshot. */
    @Override
    public void saveState(TermState termState) {
        assert termState instanceof ADMMObjectiveTermState;
        ADMMObjectiveTermState state = (ADMMObjectiveTermState) termState;

        System.arraycopy(variableValues, 0, state.variableValues, 0, variableValues.length);
        System.arraycopy(variableLagranges, 0, state.variableLagranges, 0, variableLagranges.length);
    }
}
