package org.fnlp.ml.nmf;

import gnu.trove.iterator.TLongFloatIterator;
import java.util.Vector;
import org.fnlp.ml.types.sv.SparseMatrix;

/* loaded from: input_file:org/fnlp/ml/nmf/Nmf.class */
/**
 * Non-negative matrix factorization of a sparse matrix {@code V} (m x n) into
 * {@code W} (m x r) times {@code H} (r x n) using multiplicative updates.
 *
 * <p>Each iteration of {@link #calc()} applies
 * <pre>
 *   R := V ./ (W H + eps)          (evaluated only at entries stored in V)
 *   H := H .* (W^T R)              (only entries stored in H are written)
 *   W := W .* (R H^T)              (only entries stored in W are written; uses the new H)
 *   W := W with each column divided by (column sum + eps)
 * </pre>
 * Entries that are not stored in the current sparse W/H are never written by an update.
 *
 * <p>NOTE(review): the progress objective is {@code l2Norm(V - W H)}, while the update rules
 * resemble Lee-Seung divergence updates with column-normalised W; confirm the intended objective.
 */
public class Nmf {
    /** Maximum number of update iterations performed by {@link #calc()}. */
    int max_iter;
    /** Stopping tolerance: {@link #calc()} returns once |objective change| <= lambda. */
    float lambda;
    /** Number of rows of {@code v} ({@code v.size()[0]}). */
    int m;
    /** Number of columns of {@code v} ({@code v.size()[1]}). */
    int n;
    /** Factorization rank: W is m x r, H is r x n. */
    int r;
    /** Added to denominators so that zero sums/products do not cause division by zero. */
    float eps = 1.0E-10f;
    /** The matrix being factorized. */
    SparseMatrix v;
    /** Left factor (m x r). */
    SparseMatrix w;
    /** Right factor (r x n). */
    SparseMatrix h;

    /**
     * Sets up a factorization of {@code matrix} and draws random initial factors.
     *
     * @param maxIter maximum number of iterations run by {@link #calc()}
     * @param lambda  tolerance on the absolute change of the objective between iterations
     * @param rank    factorization rank r
     * @param matrix  the matrix V to factorize; {@code matrix.size()} supplies {m, n}
     */
    public Nmf(int maxIter, float lambda, int rank, SparseMatrix matrix) {
        this.max_iter = maxIter;
        this.lambda = lambda;
        this.r = rank;
        int[] dims = matrix.size();
        this.m = dims[0];
        this.n = dims[1];
        this.v = matrix;
        // calc() draws fresh factors again; these only ensure w/h are set before calc() runs.
        this.w = SparseMatrix.random(new int[] {this.m, rank});
        this.h = SparseMatrix.random(new int[] {rank, this.n});
    }

    /**
     * Returns {@code l2Norm(target - left * right)}.
     *
     * <p>{@code target} is cloned before the subtraction, so it is not modified here.
     */
    float computeObjective(SparseMatrix target, SparseMatrix left, SparseMatrix right) {
        SparseMatrix product = left.mutiplyMatrix(right);
        SparseMatrix residual = target.m10clone();
        residual.minus(product);
        return residual.l2Norm();
    }

    /**
     * Computes {@code V ./ (W H + eps)} from the current {@code w} and {@code h}.
     *
     * <p>Only entries stored in {@code v} are evaluated.
     */
    private SparseMatrix ratio() {
        SparseMatrix product = this.w.mutiplyMatrix(this.h);
        SparseMatrix quotient = new SparseMatrix(new int[] {this.m, this.n});
        TLongFloatIterator it = this.v.vector.iterator();
        // Trove iteration idiom: advance exactly size() times.
        for (int k = this.v.vector.size(); k-- > 0; ) {
            it.advance();
            quotient.set(it.key(), it.value() / (product.elementAt(it.key()) + this.eps));
        }
        return quotient;
    }

    /**
     * Returns a new matrix of shape {@code dims} holding {@code factor[k] * multiplier[k]} for
     * every key {@code k} stored in {@code factor}; other keys are left unset.
     */
    private static SparseMatrix multiplyStored(SparseMatrix factor, SparseMatrix multiplier, int[] dims) {
        SparseMatrix result = new SparseMatrix(dims);
        TLongFloatIterator it = factor.vector.iterator();
        for (int k = factor.vector.size(); k-- > 0; ) {
            it.advance();
            result.set(it.key(), it.value() * multiplier.elementAt(it.key()));
        }
        return result;
    }

    /**
     * Computes the multiplicative update {@code H .* (W^T (V ./ (W H + eps)))} from the current
     * factors.
     *
     * <p>Does not assign {@code this.h}.
     *
     * @return a new r x n matrix
     */
    SparseMatrix updateH() {
        SparseMatrix numerator = this.w.trans().mutiplyMatrix(ratio());
        return multiplyStored(this.h, numerator, new int[] {this.r, this.n});
    }

    /**
     * Computes the multiplicative update {@code W .* ((V ./ (W H + eps)) H^T)} from the current
     * factors.
     *
     * <p>Does not assign {@code this.w}.
     *
     * @return a new m x r matrix, not yet column-normalised
     */
    SparseMatrix updateW() {
        SparseMatrix numerator = ratio().mutiplyMatrix(this.h.trans());
        return multiplyStored(this.w, numerator, new int[] {this.m, this.r});
    }

    /**
     * Divides every stored entry by (sum of its column + eps), in place.
     *
     * <p>The column is index 1 of {@code getIndices(key)}.
     *
     * @return the same {@code matrix} instance, after modification
     */
    SparseMatrix normalized(SparseMatrix matrix) {
        float[] columnSums = new float[matrix.size()[1]];
        TLongFloatIterator sumIt = matrix.vector.iterator();
        for (int k = matrix.vector.size(); k-- > 0; ) {
            sumIt.advance();
            columnSums[matrix.getIndices(sumIt.key())[1]] += sumIt.value();
        }
        TLongFloatIterator scaleIt = matrix.vector.iterator();
        for (int k = matrix.vector.size(); k-- > 0; ) {
            scaleIt.advance();
            // Only keys already present are written back. This assumes set() does not add or
            // remove map entries (e.g. dropping zeros) mid-iteration -- TODO confirm.
            matrix.set(scaleIt.key(), scaleIt.value() / (columnSums[matrix.getIndices(scaleIt.key())[1]] + this.eps));
        }
        return matrix;
    }

    /**
     * Runs the factorization from a fresh random start.
     *
     * <p>Each iteration updates H, then W (normalised by column), and prints the objective and
     * its change. It stops after {@code max_iter} iterations, or as soon as the absolute change
     * is {@code <= lambda}.
     */
    void calc() {
        this.w = normalized(SparseMatrix.random(new int[] {this.m, this.r}));
        this.h = SparseMatrix.random(new int[] {this.r, this.n});
        float previous = computeObjective(this.v, this.w, this.h);
        for (int iter = 1; iter <= this.max_iter; iter++) {
            this.h = updateH();
            // updateW() reads the H just computed above.
            this.w = normalized(updateW());
            float current = computeObjective(this.v, this.w, this.h);
            float change = current - previous;
            // Report the objective reached by this iteration, so the last line matches the result.
            System.out.printf("k = %d; obj=%f\t改变：%f\n", iter, current, change);
            if (Math.abs(change) <= this.lambda) {
                return;
            }
            previous = current;
        }
    }

    /** Demo: factorizes a dense 10 x 10 matrix filled with 0..99 at rank 5 and prints the run time. */
    public static void main(String[] args) {
        int[] dims = {10, 10};
        SparseMatrix matrix = new SparseMatrix(dims);
        // Index 0 ranges over dims[0] (rows) and index 1 over dims[1] (columns),
        // so non-square dims stay in range.
        for (int row = 0; row < dims[0]; row++) {
            for (int col = 0; col < dims[1]; col++) {
                matrix.set(new int[] {row, col}, row * dims[1] + col);
            }
        }
        System.out.print("矩阵初始化结束\n");
        long start = System.currentTimeMillis();
        new Nmf(1000, 1.0E-4f, 5, matrix).calc();
        System.out.println("程序共计运行 " + (System.currentTimeMillis() - start) + " 毫秒");
    }
}
