package org.fbk.cit.hlt.core.lsa;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.text.DecimalFormat;
import java.util.Iterator;
import java.util.Map;
import org.apache.commons.configuration.tree.DefaultExpressionEngine;
import org.apache.log4j.Logger;
import org.apache.log4j.PropertyConfigurator;
import org.fbk.cit.hlt.core.lsa.Vocabulary;
import org.fbk.cit.hlt.core.lsa.io.DenseBinaryMatrixFileReader;
import org.fbk.cit.hlt.core.lsa.io.DenseTextVectorFileReader;
import org.fbk.cit.hlt.core.math.DoubleVector;
import org.fbk.cit.hlt.thewikimachine.util.StringTable;
import org.fbk.cit.hlt.thewikimachine.xmldump.util.ParsedPageLink;

/* loaded from: input_file:org/fbk/cit/hlt/core/lsa/DoubleLSA.class */
/**
 * Latent Semantic Analysis (LSA) model whose term matrix is held as dense
 * {@code double} arrays.
 *
 * <p>Loading a model reads a term index, a document index, a
 * document-frequency vocabulary, the vector {@code Sk} and the matrix
 * {@code Uk}. Every row of {@code Uk} is then multiplied element-wise by
 * {@code Sk}. Optionally each row is also weighted by its term's idf and/or
 * scaled to unit Euclidean length.
 *
 * <p>Terms and bags of words can then be mapped into the latent space and
 * compared with the cosine measure.
 */
public class DoubleLSA {
    /** Term matrix: row {@code r} is the vector of the term whose index is {@code r}. */
    protected double[][] Uk;
    /** Values read from the "-S" file; column {@code c} of {@code Uk} is multiplied by {@code Sk[c]}. */
    protected double[] Sk;
    /** Per-term weight {@code log2(documentCount / df)}, indexed by term index. */
    protected double[] Iidf;
    /** Maps term strings to row indexes of {@code Uk}. */
    protected Index termIndex;
    /** Index of the document labels read from the "-col" file. */
    protected Index documentIndex;
    /** Number of entries in {@link #documentIndex}. */
    protected int documentNumber;
    /** Number of dimensions requested when the model was loaded. */
    private int dim;
    static Logger logger = Logger.getLogger(DoubleLSA.class.getName());
    protected static DecimalFormat df = new DecimalFormat("000,000,000.#");
    public static final double LOG2 = Math.log(2.0d);

    /**
     * Loads a model from explicit files, without row normalization.
     *
     * @param file  matrix file read into {@code Uk} (the "-Ut" file)
     * @param file2 vector file read into {@code Sk} (the "-S" file)
     * @param file3 term index (the "-row" file, UTF-8)
     * @param file4 document index (the "-col" file, UTF-8)
     * @param file5 document-frequency vocabulary (the "-df" file, UTF-8)
     * @param i     number of dimensions to read
     * @param z     if true, weight every term row by its idf
     * @throws IOException if any model file cannot be read
     */
    public DoubleLSA(File file, File file2, File file3, File file4, File file5, int i, boolean z) throws IOException {
        this(file, file2, file3, file4, file5, i, z, false);
    }

    /**
     * Loads a model from explicit files.
     *
     * @param file  matrix file read into {@code Uk} (the "-Ut" file)
     * @param file2 vector file read into {@code Sk} (the "-S" file)
     * @param file3 term index (the "-row" file, UTF-8)
     * @param file4 document index (the "-col" file, UTF-8)
     * @param file5 document-frequency vocabulary (the "-df" file, UTF-8)
     * @param i     number of dimensions to read
     * @param z     if true, weight every term row by its idf
     * @param z2    if true, scale every term row to unit Euclidean length
     * @throws IOException if any model file cannot be read
     */
    public DoubleLSA(File file, File file2, File file3, File file4, File file5, int i, boolean z, boolean z2) throws IOException {
        init(file, file2, file3, file4, file5, i, z, z2);
    }

    /**
     * Loads a model whose files share the prefix {@code str}, without row normalization.
     *
     * @param str root path; the suffixes "-Ut", "-S", "-row", "-col" and "-df" are appended
     * @param i   number of dimensions to read
     * @param z   if true, weight every term row by its idf
     * @throws IOException if any model file cannot be read
     */
    public DoubleLSA(String str, int i, boolean z) throws IOException {
        this(str, i, z, false);
    }

    /**
     * Loads a model whose files share the prefix {@code str}.
     *
     * @param str root path; the suffixes "-Ut", "-S", "-row", "-col" and "-df" are appended
     * @param i   number of dimensions to read
     * @param z   if true, weight every term row by its idf
     * @param z2  if true, scale every term row to unit Euclidean length
     * @throws IOException if any model file cannot be read
     */
    public DoubleLSA(String str, int i, boolean z, boolean z2) throws IOException {
        init(new File(str + "-Ut"), new File(str + "-S"), new File(str + "-row"), new File(str + "-col"), new File(str + "-df"), i, z, z2);
    }

    /** Opens {@code file} as UTF-8 text; the caller is responsible for closing the reader. */
    private static InputStreamReader openUtf8(File file) throws IOException {
        return new InputStreamReader(new FileInputStream(file), "UTF-8");
    }

    /**
     * Reads all model files and applies the rescaling, optional idf weighting
     * and optional normalization.
     *
     * <p>Every reader opened here is closed, even on failure.
     */
    private void init(File file, File file2, File file3, File file4, File file5, int i, boolean z, boolean z2) throws IOException {
        this.dim = i;

        logger.info("reading term index from " + file3 + "...");
        this.termIndex = new Index();
        InputStreamReader termReader = openUtf8(file3);
        try {
            this.termIndex.read(termReader);
        } finally {
            termReader.close();
        }

        logger.info("reading document index from " + file4 + "...");
        this.documentIndex = new Index();
        InputStreamReader documentReader = openUtf8(file4);
        try {
            this.documentIndex.read(documentReader);
        } finally {
            documentReader.close();
        }
        int size = this.documentIndex.itemSet().size();
        this.documentNumber = size;
        logger.info(size + " documents");

        logger.info("reading term frequency from " + file5 + "...");
        Vocabulary vocabulary = new Vocabulary();
        InputStreamReader vocabularyReader = openUtf8(file5);
        try {
            vocabulary.read(vocabularyReader);
        } finally {
            vocabularyReader.close();
        }
        createIdf(vocabulary, size);

        logger.info("reading S matrix from " + file2 + "...");
        this.Sk = new DenseTextVectorFileReader(file2, i).readDouble();
        logger.info("Sk[" + this.Sk.length + DefaultExpressionEngine.DEFAULT_ATTRIBUTE_END);
        logger.info("reading Uk matrix from " + file + "...");
        this.Uk = new DenseBinaryMatrixFileReader(file, i).readDouble(true);
        rescale();
        print("Uk rescaled");
        if (z) {
            idf();
            print("Uk idf");
        }
        if (z2) {
            normalize();
            print("Uk norm");
        }
    }

    /** Returns the number of terms in the term index. */
    public int termCount() {
        return this.termIndex.size();
    }

    /** Returns the number of documents in the document index. */
    public int documentCount() {
        return this.documentNumber;
    }

    /** Returns the number of dimensions requested when the model was loaded. */
    public int getDimension() {
        return this.dim;
    }

    /** Returns the number of columns of {@code Uk}, or 0 when it has no rows. */
    private int columnCount() {
        return this.Uk.length == 0 ? 0 : this.Uk[0].length;
    }

    /**
     * Logs {@code str} and dumps {@code Uk} to standard output.
     *
     * <p>A matrix smaller than 50x50 is printed in full. A larger one is
     * abbreviated to its first and last (up to) three rows and columns.
     */
    protected void print(String str) {
        logger.info("\n" + str);
        int rows = this.Uk.length;
        if (rows == 0) {
            return;
        }
        if (rows < 50 && this.Uk[0].length < 50) {
            for (int r = 0; r < rows; r++) {
                for (int c = 0; c < this.Uk[r].length; c++) {
                    if (c != 0) {
                        System.out.print(" ");
                    }
                    System.out.print(this.Uk[r][c]);
                }
                System.out.print("\n");
            }
            return;
        }
        int head = Math.min(3, rows);
        for (int r = 0; r < head; r++) {
            printRowSummary(this.Uk[r]);
        }
        System.out.print("...\n");
        // Math.max guards matrices that are wide but have fewer than three rows.
        for (int r = Math.max(0, rows - 3); r < rows; r++) {
            printRowSummary(this.Uk[r]);
        }
    }

    /** Prints the first and last (up to) three values of {@code row}, tab separated. */
    private void printRowSummary(double[] row) {
        int cols = row.length;
        int head = Math.min(3, cols);
        for (int c = 0; c < head; c++) {
            if (c != 0) {
                System.out.print(StringTable.HORIZONTAL_TABULATION);
            }
            System.out.print(row[c]);
        }
        System.out.print("\t...\t");
        int tailStart = Math.max(0, cols - 3);
        for (int c = tailStart; c < cols; c++) {
            // separator only between tail values; "\t...\t" already precedes the first
            if (c != tailStart) {
                System.out.print(StringTable.HORIZONTAL_TABULATION);
            }
            System.out.print(row[c]);
        }
        System.out.print("\n");
    }

    /** Returns the base-2 logarithm of {@code d}. */
    public double log2(double d) {
        return Math.log(d) / LOG2;
    }

    /**
     * Returns the internal row of {@code Uk} for {@code str}, or {@code null}
     * if the term is unknown.
     *
     * <p>The array is not copied, so modifying it modifies the model.
     */
    public double[] getVector(String str) {
        int i = this.termIndex.get(str);
        logger.debug(str + " " + i);
        if (i == -1) {
            return null;
        }
        return this.Uk[i];
    }

    /**
     * Fills {@link #Iidf} with {@code log2(i / df)} for every vocabulary term
     * that is present in the term index.
     *
     * @param vocabulary document frequencies per term
     * @param i          total number of documents
     */
    private void createIdf(Vocabulary vocabulary, int i) {
        long currentTimeMillis = System.currentTimeMillis();
        logger.info("creating idf matrix...");
        // Iidf is addressed by term-index position, so it must also cover every indexed term.
        this.Iidf = new double[Math.max(vocabulary.entrySet().size(), this.termIndex.size())];
        for (Map.Entry entry : vocabulary.entrySet()) {
            String str = (String) entry.getKey();
            int index = this.termIndex.get(str);
            if (index < 0) {
                logger.debug("term not in index, idf skipped: " + str);
                continue;
            }
            Vocabulary.TermFrequency termFrequency = (Vocabulary.TermFrequency) entry.getValue();
            // Divide in floating point; integral division would truncate N/df before the log.
            this.Iidf[index] = log2((double) i / termFrequency.get());
        }
        logger.info("took " + (System.currentTimeMillis() - currentTimeMillis) + " ms");
    }

    /** Multiplies column {@code c} of every row of {@code Uk} by {@code Sk[c]}, in place. */
    private void rescale() {
        long currentTimeMillis = System.currentTimeMillis();
        int cols = columnCount();
        logger.info("rescale: Uk[" + this.Uk.length + " X " + cols + "] * Sk[" + cols + " X " + cols + DefaultExpressionEngine.DEFAULT_ATTRIBUTE_END);
        for (double[] row : this.Uk) {
            for (int c = 0; c < row.length; c++) {
                row[c] *= this.Sk[c];
            }
        }
        logger.info("took " + (System.currentTimeMillis() - currentTimeMillis) + " ms");
    }

    /**
     * Scales every row of {@code Uk} to unit Euclidean length, in place.
     *
     * <p>All-zero rows are left unchanged instead of being turned into NaN.
     */
    private void normalize() {
        long currentTimeMillis = System.currentTimeMillis();
        logger.info("normalize: IN[" + this.Iidf.length + " X " + this.Iidf.length + "] * Uk[" + this.Uk.length + " X " + columnCount() + DefaultExpressionEngine.DEFAULT_ATTRIBUTE_END);
        for (double[] row : this.Uk) {
            double sumOfSquares = 0.0d;
            for (double value : row) {
                sumOfSquares += value * value;
            }
            if (sumOfSquares == 0.0d) {
                continue;
            }
            double norm = Math.sqrt(sumOfSquares);
            for (int c = 0; c < row.length; c++) {
                row[c] /= norm;
            }
        }
        logger.info("took " + (System.currentTimeMillis() - currentTimeMillis) + " ms");
    }

    /** Returns an iterator over the indexed terms. */
    public Iterator<String> terms() {
        return this.termIndex.itemSet().iterator();
    }

    /** Returns an iterator over the indexed documents. */
    public Iterator<String> documents() {
        return this.documentIndex.itemSet().iterator();
    }

    /** Returns the idf of {@code str}, or 0 if the term is not indexed. */
    public double getIdf(String str) {
        int i = this.termIndex.get(str);
        if (i == -1) {
            return 0.0d;
        }
        return this.Iidf[i];
    }

    /** Multiplies every row {@code r} of {@code Uk} by {@code Iidf[r]}, in place. */
    private void idf() {
        long currentTimeMillis = System.currentTimeMillis();
        logger.info("idf: Iidf[" + this.Iidf.length + " X " + this.Iidf.length + "] * Uk[" + this.Uk.length + " X " + columnCount() + DefaultExpressionEngine.DEFAULT_ATTRIBUTE_END);
        for (int r = 0; r < this.Uk.length; r++) {
            double weight = this.Iidf[r];
            double[] row = this.Uk[r];
            for (int c = 0; c < row.length; c++) {
                row[c] *= weight;
            }
        }
        logger.info("took " + (System.currentTimeMillis() - currentTimeMillis) + " ms");
    }

    /**
     * Returns the index of {@code str}, or -1 if it is unknown.
     *
     * <p>The {@code throws} clause is kept for source compatibility; this
     * method does not currently throw.
     */
    public int termIndex(String str) throws TermNotFoundException {
        return this.termIndex.get(str);
    }

    /**
     * Returns a copy of the latent vector of {@code str}.
     *
     * @throws TermNotFoundException if the term is not indexed
     */
    public DoubleVector mapTerm(String str) throws TermNotFoundException {
        int i = this.termIndex.get(str);
        if (i == -1) {
            throw new TermNotFoundException(str);
        }
        return new DoubleVector(this.Uk[i]);
    }

    /**
     * Builds a sparse term-space vector for {@code bow}.
     *
     * <p>Each known term gets weight {@code 1 + log10(frequency)}, multiplied
     * by its idf when {@code z} is true. Unknown terms are stored with index
     * -1 and weight 0.
     *
     * @param bow bag of words to map
     * @param z   whether to apply idf weighting
     * @return parallel arrays of term indexes and weights
     */
    public DoubleVector mapDocument(BOW bow, boolean z) {
        int[] indexes = new int[bow.size()];
        double[] weights = new double[bow.size()];
        int k = 0;
        Iterator<String> it = bow.termSet().iterator();
        while (it.hasNext()) {
            String term = it.next();
            int index = this.termIndex.get(term);
            indexes[k] = index;
            if (index != -1) {
                double weight = 1.0d + Math.log10(bow.getFrequency(term));
                if (z) {
                    weight *= this.Iidf[index];
                }
                weights[k] = weight;
            } else {
                weights[k] = 0.0d;
            }
            k++;
        }
        return new DoubleVector(indexes, weights);
    }

    /** Same as {@link #mapDocument(BOW, boolean)} with idf weighting enabled. */
    public DoubleVector mapDocument(BOW bow) {
        return mapDocument(bow, true);
    }

    /**
     * Projects a sparse term-space vector into the latent space.
     *
     * <p>The result is the sum of {@code Uk} rows weighted by the vector's
     * values. Entries with a negative index (unknown terms from
     * {@link #mapDocument(BOW, boolean)}) are skipped.
     */
    public DoubleVector mapPseudoDocument(DoubleVector doubleVector) {
        int cols = columnCount();
        DoubleVector result = new DoubleVector(cols);
        double[] sums = result.values;
        // Term-major order walks each Uk row contiguously; each column still
        // accumulates its terms in the same order as before.
        for (int j = 0; j < doubleVector.length(); j++) {
            int row = doubleVector.indexes[j];
            if (row < 0) {
                continue;
            }
            double weight = doubleVector.values[j];
            double[] termVector = this.Uk[row];
            for (int c = 0; c < cols; c++) {
                sums[c] += termVector[c] * weight;
            }
        }
        return result;
    }

    /**
     * Returns the cosine similarity of the latent vectors of two terms.
     *
     * @throws TermNotFoundException if either term is not indexed
     */
    public double compare(String str, String str2) throws TermNotFoundException {
        DoubleVector mapTerm = mapTerm(str);
        DoubleVector mapTerm2 = mapTerm(str2);
        return mapTerm.dot(mapTerm2) / Math.sqrt(mapTerm.dot(mapTerm) * mapTerm2.dot(mapTerm2));
    }

    /** Returns the cosine similarity of two bags of words in the latent space. */
    public double compare(BOW bow, BOW bow2) {
        DoubleVector mapPseudoDocument = mapPseudoDocument(mapDocument(bow));
        DoubleVector mapPseudoDocument2 = mapPseudoDocument(mapDocument(bow2));
        return mapPseudoDocument.dot(mapPseudoDocument2) / Math.sqrt(mapPseudoDocument.dot(mapPseudoDocument) * mapPseudoDocument2.dot(mapPseudoDocument2));
    }

    /**
     * Reads queries from standard input until end of input.
     *
     * <p>A line with two tab-separated texts logs their latent-space cosine
     * and, in parentheses, their term-space cosine. Any other line is treated
     * as a single term, and its 20 most similar terms are logged.
     *
     * @throws IOException if reading standard input fails
     */
    public void interactive() throws IOException {
        // One reader for the whole session: re-wrapping System.in per line can drop buffered input.
        BufferedReader console = new BufferedReader(new InputStreamReader(System.in));
        while (true) {
            logger.info("\nPlease write a query and type <return> to continue (CTRL C to exit):");
            String str = console.readLine();
            if (str == null) {
                return;
            }
            if (str.contains(StringTable.HORIZONTAL_TABULATION)) {
                String[] split = str.split(StringTable.HORIZONTAL_TABULATION);
                if (split.length < 2) {
                    logger.warn("expected two tab-separated texts, got: " + str);
                    continue;
                }
                long nanoTime = System.nanoTime();
                BOW bow = new BOW(split[0].toLowerCase().replaceAll("category:", StringTable.LOW_LINE).split("[_ ]"));
                BOW bow2 = new BOW(split[1].toLowerCase().replaceAll("category:", StringTable.LOW_LINE).split("[_ ]"));
                DoubleVector mapDocument = mapDocument(bow);
                DoubleVector mapDocument2 = mapDocument(bow2);
                DoubleVector mapPseudoDocument = mapPseudoDocument(mapDocument);
                DoubleVector mapPseudoDocument2 = mapPseudoDocument(mapDocument2);
                // term-space cosine: both operands live in the same (term) space
                double dot = mapDocument.dot(mapDocument2) / Math.sqrt(mapDocument.dot(mapDocument) * mapDocument2.dot(mapDocument2));
                double dot2 = mapPseudoDocument.dot(mapPseudoDocument2) / Math.sqrt(mapPseudoDocument.dot(mapPseudoDocument) * mapPseudoDocument2.dot(mapPseudoDocument2));
                long nanoTime2 = System.nanoTime();
                logger.info("bow1:" + bow);
                logger.info("bow2:" + bow2);
                logger.info("time required " + df.format(nanoTime2 - nanoTime) + " ns");
                logger.info("<\"" + split[0] + "\",\"" + split[1] + "\"> = " + dot2 + ParsedPageLink.START_SUFFIX_PATTERN + dot + ")");
            } else {
                try {
                    String lowerCase = str.toLowerCase();
                    logger.debug("query " + lowerCase);
                    long nanoTime3 = System.nanoTime();
                    ScoreTermMap scoreTermMap = new ScoreTermMap(lowerCase, 20);
                    DoubleVector mapTerm = mapTerm(lowerCase);
                    double selfDot = mapTerm.dot(mapTerm);
                    Iterator<String> terms = terms();
                    while (terms.hasNext()) {
                        String next = terms.next();
                        DoubleVector mapTerm2 = mapTerm(next);
                        scoreTermMap.put(mapTerm.dot(mapTerm2) / Math.sqrt(selfDot * mapTerm2.dot(mapTerm2)), next);
                    }
                    long nanoTime4 = System.nanoTime();
                    logger.info(scoreTermMap.toString());
                    logger.info("time required " + df.format(nanoTime4 - nanoTime3) + " ns");
                } catch (TermNotFoundException e) {
                    logger.error(e);
                }
            }
        }
    }

    /**
     * Command-line entry point: loads a model and starts {@link #interactive()}.
     *
     * <p>Arguments: input threshold size dim idf (see {@link #getHelp()}).
     */
    public static void main(String[] strArr) throws Exception {
        String property = System.getProperty("log-config");
        if (property == null) {
            property = "log-config.txt";
        }
        long currentTimeMillis = System.currentTimeMillis();
        PropertyConfigurator.configure(property);
        if (strArr.length != 5) {
            logger.info(getHelp());
            System.exit(1);
            return;
        }
        File file = new File(strArr[0] + "-Ut");
        File file2 = new File(strArr[0] + "-S");
        File file3 = new File(strArr[0] + "-row");
        File file4 = new File(strArr[0] + "-col");
        File file5 = new File(strArr[0] + "-df");
        // threshold and size are only validated here; they are not used yet
        Double.parseDouble(strArr[1]);
        Integer.parseInt(strArr[2]);
        new DoubleLSA(file, file2, file3, file4, file5, Integer.parseInt(strArr[3]), Boolean.parseBoolean(strArr[4])).interactive();
        logger.info("term similarity calculated in " + (System.currentTimeMillis() - currentTimeMillis) + " ms");
    }

    /** Returns the command-line usage text. */
    private static String getHelp() {
        StringBuilder sb = new StringBuilder();
        sb.append("Usage: java -cp dist/jcore.jar -mx2G org.fbk.cit.hlt.core.lsa.DoubleLSA input threshold size dim idf\n\n");
        sb.append("Arguments:\n");
        sb.append("\tinput\t\t-> root of files from which to read the model\n");
        sb.append("\tthreshold\t-> similarity threshold\n");
        sb.append("\tsize\t\t-> number of similar terms to return\n");
        sb.append("\tdim\t\t-> number of dimensions\n");
        sb.append("\tidf\t\t-> if true rescale using the idf\n");
        return sb.toString();
    }
}
