package eu.monnetproject.bliss.experiments;

import eu.monnetproject.bliss.CLIOpts;
import eu.monnetproject.bliss.ParallelBinarizedReader;
import eu.monnetproject.bliss.WordMap;
import eu.monnetproject.math.sparse.SparseIntArray;
import eu.monnetproject.math.sparse.SparseRealArray;
import it.unimi.dsi.fastutil.ints.IntIterator;
import java.io.File;
import java.io.PrintStream;

/* loaded from: input_file:eu/monnetproject/bliss/experiments/PerceptronNormalization.class */
/**
 * Learns a weight for every word id on each of the two sides of a parallel corpus with a
 * perceptron-style update, then prints the learned weights as CSV.
 *
 * <p>Command-line options (declared through {@link CLIOpts}):
 * <ul>
 *   <li>{@code corpus} - the binarized parallel corpus (opened via
 *       {@code CLIOpts.openInputAsMaybeZipped});</li>
 *   <li>{@code wordMap} - the word map, used only to obtain the vocabulary size {@code W};</li>
 *   <li>{@code J} - the number of documents to handle;</li>
 *   <li>the output target returned by {@code CLIOpts.outFileOrStdout()}.</li>
 * </ul>
 *
 * <p>All weights for word ids {@code 1 .. W-1} start at {@code 1.0}. For every ordered pair of
 * distinct documents {@code (i, j)} among the first {@code J}, and for each side, the weights of
 * the words present in document {@code j} are nudged so as to shrink {@code |cross / self|}, where
 * {@code cross = sum_k f_j(k) * f_i(k) * w_k} and {@code self = sum_k f_j(k)^2 * w_k}, and
 * {@code f_d(k)} is the value stored for word {@code k} in document {@code d}'s vector. After each
 * such update the weight vector of that side is rescaled to unit L2 norm. The corpus is re-opened
 * once per outer document, so the run makes {@code J} passes over the first {@code J} documents.
 *
 * <p>Output: one line {@code w_side0,w_side1} per word id {@code 1 .. W-1}; word id 0 is skipped.
 */
public class PerceptronNormalization {

    /** Number of sides (languages) in each document pair read from the corpus. */
    private static final int SIDES = 2;

    /** Largest magnitude allowed for a single weight step; bigger steps are clipped to this. */
    private static final double MAX_STEP = 0.1d;

    public static void main(String[] strArr) throws Exception {
        final CLIOpts opts = new CLIOpts(strArr);
        final File corpusFile = opts.roFile("corpus", "The corpus");
        final File wordMapFile = opts.roFile("wordMap", "The word map");
        final int docCount = opts.intValue("J", "The number of documents to handle");
        final PrintStream out = opts.outFileOrStdout();
        if (!opts.verify(PerceptronNormalization.class)) {
            return;
        }

        final int vocabSize = WordMap.calcW(wordMapFile);
        final double[][] weights = initialWeights(vocabSize);

        final ParallelBinarizedReader outerReader =
                new ParallelBinarizedReader(CLIOpts.openInputAsMaybeZipped(corpusFile));
        try {
            for (int i = 0; i < docCount; i++) {
                // A fresh reader per outer document so document i can be paired with every j.
                final ParallelBinarizedReader innerReader =
                        new ParallelBinarizedReader(CLIOpts.openInputAsMaybeZipped(corpusFile));
                try {
                    // NOTE(review): if the corpus holds fewer than J documents, what nextFreqPair
                    // returns here is up to ParallelBinarizedReader - confirm.
                    final SparseIntArray[] docI = outerReader.nextFreqPair(vocabSize);
                    for (int j = 0; j < docCount; j++) {
                        // Always consume the next pair, even when j == i, so the inner reader
                        // stays aligned with j.
                        final SparseIntArray[] docJ = innerReader.nextFreqPair(vocabSize);
                        if (i == j) {
                            continue;
                        }
                        for (int side = 0; side < SIDES; side++) {
                            if (updateWeights(weights, side, docI, docJ)) {
                                normalizeSide(weights, side);
                            }
                        }
                    }
                } finally {
                    // Released even when an update throws (e.g. the zero-norm check).
                    innerReader.close();
                }
                System.err.print(".");  // progress: one dot per outer document
            }
            System.err.println();
        } finally {
            outerReader.close();
        }

        writeWeights(weights, out);
    }

    /**
     * Creates the {@code vocabSize x SIDES} weight matrix with every entry for word ids
     * {@code 1 .. vocabSize-1} set to {@code 1.0}; row 0 is left at {@code 0.0}.
     */
    private static double[][] initialWeights(int vocabSize) {
        final double[][] weights = new double[vocabSize][SIDES];
        for (int word = 1; word < vocabSize; word++) {
            for (int side = 0; side < SIDES; side++) {
                weights[word][side] = 1.0d;
            }
        }
        return weights;
    }

    /**
     * Applies one clipped update to the weights of {@code side} for the words occurring in
     * {@code docJ[side]}, pushing {@code |cross / self|} towards zero.
     *
     * @param weights the weight matrix, updated in place
     * @param side    which side of the document pair to update
     * @param docI    the outer document pair (provides {@code f_i})
     * @param docJ    the inner document pair (provides {@code f_j} and the set of words touched)
     * @return {@code true} if the weights of {@code side} were updated (and so must be
     *         renormalised); {@code false} if {@code self} is zero or {@code |cross / self|}
     *         overflowed to infinity, in which case nothing was changed
     */
    private static boolean updateWeights(double[][] weights, int side,
                                         SparseIntArray[] docI, SparseIntArray[] docJ) {
        double cross = 0.0d;  // sum_k f_j(k) * f_i(k) * w_k
        double self = 0.0d;   // sum_k f_j(k)^2 * w_k
        IntIterator keys = docJ[side].keySet().iterator();
        while (keys.hasNext()) {
            final int word = keys.nextInt();
            final double fj = docJ[side].doubleValue(word);
            cross += (fj * docI[side].doubleValue(word)) * weights[word][side];
            self += (fj * fj) * weights[word][side];
        }
        if (self == 0.0d) {
            return false;
        }
        final double scale = Math.abs(cross / self);
        if (Double.isInfinite(scale)) {
            return false;
        }

        // cross and self are frozen at their pre-update values; each word's step only reads that
        // word's own weight, so updating in iteration order is safe.
        keys = docJ[side].keySet().iterator();
        while (keys.hasNext()) {
            final int word = keys.nextInt();
            final double fj = docJ[side].doubleValue(word);
            final double pairProduct = fj * docI[side].doubleValue(word);
            final double selfProduct = fj * fj;
            double step = deltaValue(word, side, weights, pairProduct, selfProduct, cross, self)
                    * scale * weights[word][side];
            if (step > MAX_STEP || step < -MAX_STEP) {
                step = Math.signum(step) * MAX_STEP;
            }
            weights[word][side] += step;
        }
        return true;
    }

    /**
     * Rescales the weights of {@code side} for word ids {@code 1 .. W-1} to unit L2 norm.
     *
     * @throws RuntimeException if the norm is zero or NaN
     */
    private static void normalizeSide(double[][] weights, int side) {
        double sumOfSquares = 0.0d;
        for (int word = 1; word < weights.length; word++) {
            sumOfSquares += weights[word][side] * weights[word][side];
        }
        final double norm = Math.sqrt(sumOfSquares);
        if (norm == 0.0d || Double.isNaN(norm)) {
            throw new RuntimeException("Zero'ed the vector! " + norm);
        }
        for (int word = 1; word < weights.length; word++) {
            weights[word][side] /= norm;
        }
    }

    /**
     * Writes {@code w_side0,w_side1} for each word id {@code 1 .. W-1}, then flushes and closes
     * {@code out} (even when it is standard output).
     */
    private static void writeWeights(double[][] weights, PrintStream out) {
        for (int word = 1; word < weights.length; word++) {
            out.println(weights[word][0] + "," + weights[word][1]);
        }
        out.flush();
        out.close();
    }

    /**
     * Computes the unscaled update direction for one word's weight.
     *
     * <p>Write {@code R(w) = (pairProduct * w + crossRest) / (selfProduct * w + selfRest)}, where
     * the "rest" terms are the totals without this word's contribution. The returned value is
     * {@code sign(cross / self) * (selfProduct * crossRest - pairProduct * selfRest) / D^2} with
     * {@code D = selfProduct * w + selfRest}. That equals {@code -sign(R) * dR/dw}, so a positive
     * multiple of it moves {@code w} in the direction that shrinks {@code |R|}.
     *
     * @param word        word id whose weight is being updated
     * @param side        side index into {@code weights}
     * @param weights     current weight matrix (read only)
     * @param pairProduct {@code f_j(word) * f_i(word)}
     * @param selfProduct {@code f_j(word)^2}
     * @param cross       the full weighted cross sum
     * @param self        the full weighted self sum; must be non-zero
     * @return the update direction, or {@code 0.0} when {@code |D| <= 1e-30}
     * @throws RuntimeException if {@code self == 0}
     */
    private static double deltaValue(int word, int side, double[][] weights, double pairProduct,
                                     double selfProduct, double cross, double self) {
        final double w = weights[word][side];
        final double crossRest = cross - (pairProduct * w);
        final double selfRest = self - (selfProduct * w);
        if (self == 0.0d) {
            throw new RuntimeException("This shouldn't happen");
        }
        // Algebraically equal to `self`, but recomputed exactly as before to keep results stable.
        final double denominator = (selfProduct * w) + selfRest;
        if (Math.abs(denominator) > 1.0E-30d) {
            return ((Math.signum(cross / self) / denominator) / denominator)
                    * ((selfProduct * crossRest) - (pairProduct * selfRest));
        }
        return 0.0d;
    }
}
