package ai.libs.mlplan.metamining.similaritymeasures;

import ai.libs.jaicore.basic.sets.SetUtil;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.math.minimize.CostFunction;
import de.jungblut.math.minimize.CostGradientTuple;
import de.jungblut.math.minimize.GradientDescent;
import java.util.ArrayList;
import java.util.List;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Learns two matrices {@code U} and {@code V} for given matrices {@code X}, {@code W} and {@code R} by minimizing
 *
 * <pre>
 *   cost(U, V) = ||R - X U (W V)^T||_F^2 + mu * (||U||_F^2 + ||V||_F^2)
 * </pre>
 *
 * with repeated short probes of {@link GradientDescent}. {@code U} has one row per column of {@code X}, {@code V} has
 * one row per column of {@code W}, and both have {@code LATENT_DIMENSIONS} columns.
 *
 * <p>Instances are stateful: {@link #build(INDArray, INDArray, INDArray)} stores its inputs and the learned matrices in
 * fields. No synchronization is performed.
 */
public class F3Optimizer implements IHeterogenousSimilarityMeasureComputer {
    private static final Logger logger = LoggerFactory.getLogger(F3Optimizer.class);

    /** Step size used for the first probe and after every restart. */
    private static final double ALPHA_START = 1.0E-9d;
    /** Upper bound for the step size. */
    private static final double ALPHA_MAX = 1.0E-7d;
    /** The step size is no longer halved once it is at or below this value. */
    private static final double ALPHA_MIN = 1.0E-20d;
    /** Number of gradient-descent iterations run per probe. */
    private static final int ITERATIONS_PER_PROBE = 100;
    /** Second argument passed to the {@link GradientDescent} constructor (presumably its limit — confirm). */
    private static final int LIMIT = 1;
    /** Optimization stops as soon as the cost is at most this value. */
    private static final double MAX_DESIRED_ERROR = 0.0d;
    /** Optimization stops once more than this many consecutive probes did not lower the cost. */
    private static final int MAX_PROBES_WITHOUT_IMPROVEMENT = 10;
    /** Number of columns of U and V. */
    private static final int LATENT_DIMENSIONS = 1;
    /** Random initial entries are drawn uniformly from [-INIT_RANGE / 2, INIT_RANGE / 2). */
    private static final double INIT_RANGE = 100.0d;

    /** Weight of the Frobenius-norm regularization of U and V. */
    private final double mu;
    private INDArray R;
    private INDArray X;
    private INDArray W;
    private INDArray U;
    private INDArray V;

    /**
     * @param mu weight of the regularization term {@code mu * (||U||_F^2 + ||V||_F^2)}
     */
    public F3Optimizer(final double mu) {
        this.mu = mu;
    }

    /**
     * Learns U and V for the given matrices and stores them in this instance.
     *
     * @param x the matrix X; must have as many rows as {@code r}
     * @param w the matrix W; must have as many rows as {@code r} has columns
     * @param r the matrix R that X U (W V)^T should approximate
     * @throws IllegalArgumentException if the shapes of the matrices are incompatible
     */
    @Override // ai.libs.mlplan.metamining.similaritymeasures.IHeterogenousSimilarityMeasureComputer
    public void build(final INDArray x, final INDArray w, final INDArray r) {
        if (r.rows() != x.rows()) {
            throw new IllegalArgumentException("R has " + r.rows() + " rows but X has " + x.rows() + " rows");
        }
        if (r.columns() != w.rows()) {
            throw new IllegalArgumentException("R has " + r.columns() + " columns but W has " + w.rows() + " rows");
        }
        this.R = r;
        this.W = w;
        this.X = x;
        final int uRows = x.columns();
        final int vRows = w.columns();
        logger.debug("X = {} ({} x {})", x, x.rows(), x.columns());
        logger.debug("W = {} ({} x {})", w, w.rows(), w.columns());

        DoubleVector solution = optimize(uRows, vRows);

        // Derive U and V from the accepted solution (not from the last evaluated, possibly rejected, candidate).
        SetUtil.Pair<INDArray, INDArray> matrices = vector2matrices(solution, uRows, LATENT_DIMENSIONS, vRows, LATENT_DIMENSIONS);
        this.U = matrices.getX();
        this.V = matrices.getY();
        logger.info("Finished learning");
        logger.debug("U = {}", this.U);
        logger.debug("V = {}", this.V);
    }

    /**
     * Runs gradient-descent probes with an adaptive step size until the cost reaches {@code MAX_DESIRED_ERROR} or too
     * many consecutive probes fail to improve it.
     *
     * @param uRows number of rows of U (columns of X)
     * @param vRows number of rows of V (columns of W)
     * @return the best accepted solution vector (U entries followed by V entries)
     */
    private DoubleVector optimize(final int uRows, final int vRows) {
        DoubleVector solution = getRandomInitSolution(uRows, vRows, LATENT_DIMENSIONS);
        SetUtil.Pair<INDArray, INDArray> initial = vector2matrices(solution, uRows, LATENT_DIMENSIONS, vRows, LATENT_DIMENSIONS);
        logger.debug("randomly initialized U = {} ({} x {})", initial.getX(), uRows, LATENT_DIMENSIONS);
        logger.debug("randomly initialized V = {} ({} x {})", initial.getY(), vRows, LATENT_DIMENSIONS);
        double cost = getCost(initial.getX(), initial.getY());
        logger.debug("loss of randomly initialized U and V: {}", cost);

        final CostFunction costFunction = createCostFunction(uRows, vRows);
        boolean improvedOnce = false;
        double alpha = ALPHA_START;
        int probesWithoutImprovement = 0;
        while (cost > MAX_DESIRED_ERROR) {
            DoubleVector candidate = new GradientDescent(alpha, LIMIT).minimize(costFunction, solution, ITERATIONS_PER_PROBE, false);
            logger.debug("Produced gd solution vector {}", candidate);

            if (containsNonFiniteValue(candidate)) {
                // The probe diverged: keep the current solution and retry with a smaller step.
                alpha = halveStep(alpha);
            } else {
                double candidateCost = getCostOfSolution(candidate, uRows, vRows);
                if (candidateCost < cost) {
                    solution = candidate;
                    cost = candidateCost;
                    improvedOnce = true;
                    probesWithoutImprovement = 0;
                    alpha *= 2.0d;
                } else {
                    // Compare BEFORE discarding the candidate; a plateau grows the step, a worse result shrinks it.
                    // Every non-improving probe counts towards termination.
                    probesWithoutImprovement++;
                    if (candidateCost == cost) {
                        alpha *= 2.0d;
                    } else {
                        alpha = halveStep(alpha);
                    }
                    if (probesWithoutImprovement > MAX_PROBES_WITHOUT_IMPROVEMENT) {
                        logger.info("No further improvement, canceling");
                        break;
                    }
                }
                alpha = Math.min(alpha, ALPHA_MAX);
                logger.debug("{} (alpha = {})", cost, alpha);
            }

            if (!improvedOnce) {
                // Until some probe has improved on its start point, restart from a fresh random solution.
                solution = getRandomInitSolution(uRows, vRows, LATENT_DIMENSIONS);
                cost = getCostOfSolution(solution, uRows, vRows);
                alpha = ALPHA_START;
                logger.info("Rebooting approach with solution vector {} that has cost {}", solution, cost);
            }
        }
        return solution;
    }

    /** Creates the cost function (cost plus gradient) that {@link GradientDescent} minimizes. */
    private CostFunction createCostFunction(final int uRows, final int vRows) {
        return new CostFunction() {
            @Override
            public CostGradientTuple evaluateCost(final DoubleVector input) {
                SetUtil.Pair<INDArray, INDArray> matrices = F3Optimizer.this.vector2matrices(input, uRows, LATENT_DIMENSIONS, vRows, LATENT_DIMENSIONS);
                INDArray u = matrices.getX();
                INDArray v = matrices.getY();
                assert u.rows() == uRows && u.columns() == LATENT_DIMENSIONS : "Incorrect shape of U: (" + u.rows() + " x " + u.columns() + ") instead of (" + uRows + " x " + LATENT_DIMENSIONS + ")";
                assert v.rows() == vRows && v.columns() == LATENT_DIMENSIONS : "Incorrect shape of V: (" + v.rows() + " x " + v.columns() + ") instead of (" + vRows + " x " + LATENT_DIMENSIONS + ")";
                DoubleVector gradient = F3Optimizer.this.matrices2vector(F3Optimizer.this.getGradientAsMatrix(u, v, true), F3Optimizer.this.getGradientAsMatrix(u, v, false));
                return new CostGradientTuple(F3Optimizer.this.getCost(u, v), gradient);
            }
        };
    }

    /** Returns true if any entry of the vector is NaN or infinite, i.e. the probe diverged. */
    private static boolean containsNonFiniteValue(final DoubleVector vector) {
        for (int k = 0; k < vector.getLength(); k++) {
            double value = vector.get(k);
            if (Double.isNaN(value) || Double.isInfinite(value)) {
                return true;
            }
        }
        return false;
    }

    /** Halves the step size unless it is already at or below {@code ALPHA_MIN}. */
    private static double halveStep(final double alpha) {
        return alpha > ALPHA_MIN ? alpha / 2.0d : alpha;
    }

    /** Computes the cost of a flattened solution vector (U entries followed by V entries). */
    private double getCostOfSolution(final DoubleVector solution, final int uRows, final int vRows) {
        SetUtil.Pair<INDArray, INDArray> matrices = vector2matrices(solution, uRows, LATENT_DIMENSIONS, vRows, LATENT_DIMENSIONS);
        return getCost(matrices.getX(), matrices.getY());
    }

    /**
     * Creates a random solution vector with {@code (uRows + vRows) * rank} entries drawn uniformly from
     * [-INIT_RANGE / 2, INIT_RANGE / 2).
     */
    private DoubleVector getRandomInitSolution(final int uRows, final int vRows, final int rank) {
        double[] values = new double[(uRows + vRows) * rank];
        for (int k = 0; k < values.length; k++) {
            values[k] = (Math.random() - 0.5d) * INIT_RANGE;
        }
        return new DenseDoubleVector(values);
    }

    /**
     * Copies a vector into a new {@code rows x columns} matrix created via {@code Nd4j.create(double[], int[])}
     * (ordering is ND4J's default — presumably row-major, matching {@link #matrix2vector(INDArray)}; confirm).
     *
     * @throws IllegalArgumentException if the vector length is not {@code rows * columns}
     */
    public INDArray vector2matrix(final DoubleVector vector, final int rows, final int columns) {
        int length = vector.getLength();
        if (length != rows * columns) {
            throw new IllegalArgumentException("Vector of length " + length + " cannot be reshaped to (" + rows + " x " + columns + ")");
        }
        double[] data = new double[length];
        for (int k = 0; k < length; k++) {
            data[k] = vector.get(k);
        }
        return Nd4j.create(data, new int[] { rows, columns });
    }

    /**
     * Splits a vector into a {@code rows1 x cols1} matrix built from its first {@code rows1 * cols1} entries and a
     * {@code rows2 x cols2} matrix built from the remaining entries.
     */
    public SetUtil.Pair<INDArray, INDArray> vector2matrices(final DoubleVector vector, final int rows1, final int cols1, final int rows2, final int cols2) {
        int firstLength = rows1 * cols1;
        DoubleVector first = vector.sliceByLength(0, firstLength);
        DoubleVector second = vector.sliceByLength(firstLength, vector.getLength() - first.getLength());
        return new SetUtil.Pair<>(vector2matrix(first, rows1, cols1), vector2matrix(second, rows2, cols2));
    }

    /** Flattens a matrix row by row into a new dense vector. */
    public DoubleVector matrix2vector(final INDArray matrix) {
        int rows = matrix.rows();
        int columns = matrix.columns();
        double[] values = new double[rows * columns];
        int offset = 0;
        for (int row = 0; row < rows; row++) {
            for (int col = 0; col < columns; col++) {
                values[offset++] = matrix.getDouble(row, col);
            }
        }
        return new DenseDoubleVector(values);
    }

    /** Flattens each matrix row by row and concatenates the results in argument order. */
    public DoubleVector matrices2vector(final INDArray... matrices) {
        List<DoubleVector> parts = new ArrayList<>(matrices.length);
        int totalLength = 0;
        for (INDArray matrix : matrices) {
            DoubleVector part = matrix2vector(matrix);
            parts.add(part);
            totalLength += part.getLength();
        }
        double[] values = new double[totalLength];
        int offset = 0;
        for (DoubleVector part : parts) {
            for (int k = 0; k < part.getLength(); k++) {
                values[offset++] = part.get(k);
            }
        }
        return new DenseDoubleVector(values);
    }

    /** Computes {@code ||R - X U (W V)^T||_F^2 + mu * (||U||_F^2 + ||V||_F^2)}. */
    public double getCost(final INDArray u, final INDArray v) {
        INDArray residual = this.R.sub(this.X.mmul(u).mmul(this.W.mmul(v).transpose()));
        return getSquaredFrobeniusNorm(residual) + this.mu * (getSquaredFrobeniusNorm(u) + getSquaredFrobeniusNorm(v));
    }

    /** Returns the sum of the squared entries of the matrix. */
    public double getSquaredFrobeniusNorm(final INDArray matrix) {
        double sum = 0.0d;
        int rows = matrix.rows();
        int columns = matrix.columns();
        for (int row = 0; row < rows; row++) {
            for (int col = 0; col < columns; col++) {
                double value = matrix.getDouble(row, col);
                sum += value * value;
            }
        }
        return sum;
    }

    /**
     * Computes the gradient of {@link #getCost(INDArray, INDArray)} with respect to U ({@code wrtU == true}) or V.
     * The residual and the products XU and WV are computed once and shared by all entries.
     *
     * @return a matrix of the same shape as U (resp. V), stored as floats
     */
    public INDArray getGradientAsMatrix(final INDArray u, final INDArray v, final boolean wrtU) {
        INDArray xu = this.X.mmul(u);
        INDArray wv = this.W.mmul(v);
        INDArray residual = this.R.sub(xu.mmul(wv.transpose()));
        INDArray target = wrtU ? u : v;
        int rows = target.rows();
        int columns = target.columns();
        float[][] gradient = new float[rows][columns];
        for (int row = 0; row < rows; row++) {
            for (int col = 0; col < columns; col++) {
                gradient[row][col] = (float) partialDerivative(residual, xu, wv, u, v, row, col, wrtU);
            }
        }
        return Nd4j.create(gradient);
    }

    /**
     * Computes the partial derivative of the cost with respect to U[i][i2] ({@code z == true}) or V[i][i2]
     * ({@code z == false}). The value is accumulated in double precision and rounded to float.
     */
    public float getFirstDerivative(final INDArray u, final INDArray v, final int i, final int i2, final boolean z) {
        INDArray xu = this.X.mmul(u);
        INDArray wv = this.W.mmul(v);
        INDArray residual = this.R.sub(xu.mmul(wv.transpose()));
        return (float) partialDerivative(residual, xu, wv, u, v, i, i2, z);
    }

    /**
     * Computes one gradient entry. With E = R - (XU)(WV)^T:
     * <ul>
     * <li>dCost/dU[row][col] = -2 * sum_{a,b} E[a][b] * X[a][row] * (WV)[b][col] + 2 * mu * U[row][col]</li>
     * <li>dCost/dV[row][col] = -2 * sum_{a,b} E[a][b] * (XU)[a][col] * W[b][row] + 2 * mu * V[row][col]</li>
     * </ul>
     */
    private double partialDerivative(final INDArray residual, final INDArray xu, final INDArray wv, final INDArray u, final INDArray v, final int row, final int col, final boolean wrtU) {
        int n = residual.rows();
        int m = residual.columns();
        assert m == this.W.rows() : "W has " + this.W.rows() + " but is expected to have m = " + m + " rows";
        assert col < v.columns() : "V has only " + v.columns() + " but would have to have " + (col + 1) + " columns to proceed! I.e. deriving a derivative for t = " + col + " is not possible.";
        double sum = 0.0d;
        for (int a = 0; a < n; a++) {
            double rowFactor = wrtU ? this.X.getDouble(a, row) : xu.getDouble(a, col);
            for (int b = 0; b < m; b++) {
                double colFactor = wrtU ? wv.getDouble(b, col) : this.W.getDouble(b, row);
                sum += residual.getDouble(a, b) * rowFactor * colFactor;
            }
        }
        double regularized = wrtU ? u.getDouble(row, col) : v.getDouble(row, col);
        return -2.0d * sum + 2.0d * this.mu * regularized;
    }

    /**
     * Always returns {@code 0.0}.
     *
     * <p>NOTE(review): the matrices U and V learned by {@link #build(INDArray, INDArray, INDArray)} are not used here;
     * this looks like a placeholder — confirm the intended similarity definition before relying on it.
     */
    @Override // ai.libs.mlplan.metamining.similaritymeasures.IHeterogenousSimilarityMeasureComputer
    public double computeSimilarity(final INDArray iNDArray, final INDArray iNDArray2) {
        return 0.0d;
    }
}
