package org.fbk.cit.hlt.core.lsa;

import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Iterator;
import java.util.Map;
import java.util.TreeSet;
import org.apache.log4j.Logger;
import org.apache.log4j.PropertyConfigurator;
import org.fbk.cit.hlt.core.math.Node;
import org.fbk.cit.hlt.thewikimachine.util.StringTable;
import org.fbk.cit.hlt.thewikimachine.xmldump.util.ParsedPageLink;

/* loaded from: input_file:org/fbk/cit/hlt/core/lsa/LSI.class */
/**
 * Concrete {@link AbstractLSI} that maps terms and bags-of-words to {@link Node} vectors and
 * compares them with cosine similarity.
 *
 * <p>The class relies on members it does not declare itself ({@code Uk}, {@code Iidf},
 * {@code termIndex}, {@code df}, {@code log2(..)}, {@code terms()}); they are presumably
 * inherited from {@link AbstractLSI}. From the way they are used here:
 * <ul>
 *   <li>{@code termIndex.get(term)} yields an int row index, and {@code -1} is treated as
 *       "term unknown";</li>
 *   <li>{@code Uk[termIndex]} is the row of values used as the vector of a term;</li>
 *   <li>{@code Iidf[termIndex]} is a per-term multiplier applied to term weights.</li>
 * </ul>
 */
public class LSI extends AbstractLSI {
    static Logger logger = Logger.getLogger(LSI.class.getName());

    /** Number of best-scoring terms kept when answering an interactive single-term query. */
    private static final int INTERACTIVE_RESULT_SIZE = 20;

    /** Delegates to the matching {@link AbstractLSI} constructor; argument meaning is defined there. */
    public LSI(String str, int i, boolean z) throws IOException {
        super(str, i, z);
    }

    /** Delegates to the matching {@link AbstractLSI} constructor; argument meaning is defined there. */
    public LSI(String str, int i, boolean z, boolean z2) throws IOException {
        super(str, i, z, z2);
    }

    /** Delegates to the matching {@link AbstractLSI} constructor; argument meaning is defined there. */
    public LSI(File file, File file2, File file3, File file4, File file5, int i, boolean z) throws IOException {
        super(file, file2, file3, file4, file5, i, z);
    }

    /** Delegates to the matching {@link AbstractLSI} constructor; argument meaning is defined there. */
    public LSI(File file, File file2, File file3, File file4, File file5, int i, boolean z, boolean z2) throws IOException {
        super(file, file2, file3, file4, file5, i, z, z2);
    }

    /**
     * Returns the vector of a single term as a dense array of nodes, one node per column of the
     * term's {@code Uk} row (node index = column, node value = cell).
     *
     * @param str the term to look up
     * @return a freshly allocated array with one node per column of the term's row
     * @throws TermNotFoundException if {@code termIndex} reports {@code -1} for the term
     */
    public Node[] mapTerm(String str) throws TermNotFoundException {
        int row = this.termIndex.get(str);
        if (row == -1) {
            throw new TermNotFoundException(str);
        }
        int width = this.Uk[row].length;
        Node[] vector = new Node[width];
        for (int col = 0; col < width; col++) {
            vector[col] = new Node(col, this.Uk[row][col]);
        }
        return vector;
    }

    /**
     * Older document mapping: builds one node per known term of the bag-of-words, weighted by
     * {@code log2(frequency)} and, when requested, additionally by {@code Iidf}. The result is
     * sorted by ascending node index. Unknown terms are skipped (and logged at debug level).
     *
     * @param bow the bag-of-words to map
     * @param z   if {@code true}, multiply each weight by the term's {@code Iidf} entry
     * @return the nodes of the known terms, ordered by ascending index
     */
    public Node[] mapDocumentOld(BOW bow, boolean z) {
        logger.info("mapDocument " + z);
        ArrayList<Node> nodes = new ArrayList<Node>();
        int position = 0; // position of the term in the iteration, used for debug output only
        for (String term : bow.termSet()) {
            int row = this.termIndex.get(term);
            if (row != -1) {
                int frequency = bow.getFrequency(term);
                double weight = log2(frequency);
                if (z) {
                    weight *= this.Iidf[row];
                }
                logger.info(StringTable.HORIZONTAL_TABULATION + term + StringTable.HORIZONTAL_TABULATION + row + "\ttf= " + frequency + ParsedPageLink.START_SUFFIX_PATTERN + log2(frequency) + ")\tidf=" + this.Iidf[row] + ParsedPageLink.START_SUFFIX_PATTERN + weight + ")");
                nodes.add(new Node(row, weight));
            } else {
                logger.debug(position + StringTable.HORIZONTAL_TABULATION + term + StringTable.HORIZONTAL_TABULATION + row);
            }
            position++;
        }
        Node[] sorted = nodes.toArray(new Node[nodes.size()]);
        Arrays.sort(sorted, new Comparator<Node>() {
            @Override
            public int compare(Node node, Node node2) {
                // Direct comparison instead of subtracting the indices: same ascending order,
                // but no possibility of integer overflow.
                if (node.index < node2.index) {
                    return -1;
                }
                return node.index > node2.index ? 1 : 0;
            }
        });
        return sorted;
    }

    /**
     * Maps a bag-of-words to nodes weighted by {@code bow.tf(term) * Iidf[termIndex]}; terms
     * unknown to {@code termIndex} are skipped.
     *
     * <p>Nodes are collected in a {@link TreeSet}, so the order of the result (and whether two
     * nodes are considered duplicates) is determined by {@code Node}'s natural ordering.
     *
     * @param bow the bag-of-words to map
     * @return the nodes of the known terms in {@code Node}'s natural order
     */
    public Node[] mapDocument(BOW bow) {
        TreeSet<Node> nodes = new TreeSet<Node>();
        for (String term : bow.termSet()) {
            int row = this.termIndex.get(term);
            if (row != -1) {
                nodes.add(new Node(row, bow.tf(term) * this.Iidf[row]));
            }
        }
        return nodes.toArray(new Node[nodes.size()]);
    }

    /**
     * Maps pre-weighted terms to nodes; the supplied weight is used unchanged and terms unknown
     * to {@code termIndex} are skipped.
     *
     * <p>Nodes are collected in a {@link TreeSet}, so the order of the result (and whether two
     * nodes are considered duplicates) is determined by {@code Node}'s natural ordering.
     *
     * @param map term to weight; values must not be {@code null} for known terms
     * @return the nodes of the known terms in {@code Node}'s natural order
     */
    public Node[] mapDocument(Map<String, Double> map) {
        TreeSet<Node> nodes = new TreeSet<Node>();
        for (Map.Entry<String, Double> entry : map.entrySet()) {
            int row = this.termIndex.get(entry.getKey());
            if (row != -1) {
                nodes.add(new Node(row, entry.getValue().doubleValue()));
            }
        }
        return nodes.toArray(new Node[nodes.size()]);
    }

    /**
     * Kept for callers of the old name; simply forwards to {@link #mapDocument(BOW)}.
     *
     * @param bow the bag-of-words to map
     * @return the result of {@link #mapDocument(BOW)}
     */
    public Node[] mapDocumentOld(BOW bow) {
        return mapDocument(bow);
    }

    /**
     * Projects a sparse term vector onto the columns of {@code Uk}: output component {@code c}
     * is the sum over the input nodes of {@code Uk[node.index][c] * node.value}.
     *
     * <p>The output width is taken from {@code Uk[0]}, so {@code Uk} must have at least one row.
     *
     * @param nodeArr sparse vector whose node indices are row indices into {@code Uk}
     * @return a dense vector with one node per column of {@code Uk[0]}
     */
    public Node[] mapPseudoDocument(Node[] nodeArr) {
        int width = this.Uk[0].length;
        Node[] projected = new Node[width];
        for (int col = 0; col < width; col++) {
            Node target = new Node(col, 0.0d);
            for (Node source : nodeArr) {
                target.value += this.Uk[source.index][col] * source.value;
            }
            projected[col] = target;
        }
        return projected;
    }

    /**
     * Cosine similarity between the vectors of two terms.
     *
     * @param str  first term
     * @param str2 second term
     * @return {@code dot(a, b) / sqrt(dot(a, a) * dot(b, b))} of the two term vectors
     * @throws TermNotFoundException if either term is unknown
     */
    public double compare(String str, String str2) throws TermNotFoundException {
        return normalizedDot(mapTerm(str), mapTerm(str2));
    }

    /**
     * Cosine similarity between two bags-of-words after mapping them with
     * {@link #mapDocument(BOW)} and projecting them with {@link #mapPseudoDocument(Node[])}.
     *
     * <p>No guard is applied to the denominator: if one projection has a zero self dot product
     * the result is whatever the floating-point division yields.
     *
     * @param bow  first bag-of-words
     * @param bow2 second bag-of-words
     * @return {@code dot(a, b) / sqrt(dot(a, a) * dot(b, b))} of the two projections
     */
    public double compare(BOW bow, BOW bow2) {
        Node[] first = mapPseudoDocument(mapDocument(bow));
        Node[] second = mapPseudoDocument(mapDocument(bow2));
        return normalizedDot(first, second);
    }

    /** Computes {@code dot(a, b) / sqrt(dot(a, a) * dot(b, b))} using {@code Node.dot}. */
    private static double normalizedDot(Node[] a, Node[] b) {
        return Node.dot(a, b) / Math.sqrt(Node.dot(a, a) * Node.dot(b, b));
    }

    /**
     * Simple read-eval-print loop on standard input.
     *
     * <p>A line containing the tabulation separator is treated as two texts to compare; any
     * other line is treated as a single term for which the best-scoring terms are listed.
     * The loop ends when standard input reaches end-of-stream.
     *
     * @throws IOException if reading from standard input fails
     */
    public void interactive() throws IOException {
        // One reader for the whole session: a new BufferedReader per iteration could discard
        // input that the previous reader had already buffered.
        BufferedReader input = new BufferedReader(new InputStreamReader(System.in));
        while (true) {
            logger.info("\nPlease write a query and type <return> to continue (CTRL C to exit):");
            String line = input.readLine();
            if (line == null) {
                // End of input (closed pipe, CTRL-D, ...): stop instead of failing with a
                // NullPointerException.
                return;
            }
            if (line.contains(StringTable.HORIZONTAL_TABULATION)) {
                String[] parts = line.split(StringTable.HORIZONTAL_TABULATION);
                if (parts.length < 2) {
                    // String.split drops trailing empty strings, so e.g. "foo<TAB>" yields a
                    // single element; reject the query rather than crash the session.
                    logger.error("expected two texts separated by a tabulation but found " + parts.length);
                    continue;
                }
                compareTexts(parts[0], parts[1]);
            } else {
                listSimilarTerms(line.toLowerCase());
            }
        }
    }

    /** Logs the intermediate vectors and both similarities (projected and raw) of two texts. */
    private void compareTexts(String left, String right) {
        long parseStart = System.nanoTime();
        BOW bow = new BOW(left.toLowerCase());
        BOW bow2 = new BOW(right.toLowerCase());
        logger.info("bow1:" + bow);
        logger.info("bow2:" + bow2);
        logger.info("parsing time " + df.format(System.nanoTime() - parseStart) + " ns");

        long mapStart = System.nanoTime();
        Node[] d1 = mapDocument(bow);
        logger.info("d1:" + Arrays.toString(d1));
        Node[] d2 = mapDocument(bow2);
        logger.info("d2:" + Arrays.toString(d2));
        Node[] pd1 = mapPseudoDocument(d1);
        logger.info("pd1:" + Arrays.toString(pd1));
        Node[] pd2 = mapPseudoDocument(d2);
        logger.info("pd2:" + Arrays.toString(pd2));
        double rawSimilarity = normalizedDot(d1, d2);
        double projectedSimilarity = normalizedDot(pd1, pd2);
        logger.info("mapping time " + df.format(System.nanoTime() - mapStart) + " ns");
        logger.info("<\"" + left + "\",\"" + right + "\"> = " + projectedSimilarity + ParsedPageLink.START_SUFFIX_PATTERN + rawSimilarity + ")");
    }

    /** Scores every term returned by {@code terms()} against the query term and logs the result. */
    private void listSimilarTerms(String query) {
        try {
            logger.debug("query " + query);
            long start = System.nanoTime();
            ScoreTermMap scoreTermMap = new ScoreTermMap(query, INTERACTIVE_RESULT_SIZE);
            Node[] queryVector = mapTerm(query);
            Iterator<String> terms = terms();
            while (terms.hasNext()) {
                String candidate = terms.next();
                scoreTermMap.put(normalizedDot(queryVector, mapTerm(candidate)), candidate);
            }
            long end = System.nanoTime();
            logger.info(scoreTermMap.toString());
            logger.info("time required " + df.format(end - start) + " ns");
        } catch (TermNotFoundException e) {
            logger.error(e);
        }
    }

    /**
     * Command-line entry point: configures log4j, loads the model files derived from the
     * {@code input} root ({@code -Ut}, {@code -S}, {@code -row}, {@code -col}, {@code -df}) and
     * starts {@link #interactive()}.
     *
     * @param strArr exactly five arguments: input, threshold, size, dim, idf (see the usage text)
     * @throws Exception on malformed numeric arguments or I/O failures
     */
    public static void main(String[] strArr) throws Exception {
        String logConfig = System.getProperty("log-config");
        if (logConfig == null) {
            logConfig = "log-config.txt";
        }
        long begin = System.currentTimeMillis();
        PropertyConfigurator.configure(logConfig);
        if (strArr.length != 5) {
            logger.info(getHelp());
            System.exit(1);
        }
        String root = strArr[0];
        File file = new File(root + "-Ut");
        File file2 = new File(root + "-S");
        File file3 = new File(root + "-row");
        File file4 = new File(root + "-col");
        File file5 = new File(root + "-df");
        // threshold and size are parsed only so that malformed values are rejected early;
        // the parsed values themselves are not used by this entry point.
        Double.parseDouble(strArr[1]);
        Integer.parseInt(strArr[2]);
        int dim = Integer.parseInt(strArr[3]);
        boolean idf = Boolean.parseBoolean(strArr[4]);
        new LSI(file, file2, file3, file4, file5, dim, idf).interactive();
        logger.info("term similarity calculated in " + (System.currentTimeMillis() - begin) + " ms");
    }

    /** Builds the usage text logged when the number of command-line arguments is wrong. */
    private static String getHelp() {
        StringBuilder help = new StringBuilder();
        help.append("Usage: java -cp dist/jcore.jar -mx2G org.fbk.cit.hlt.core.lsa.LSI input threshold size dim idf\n\n");
        help.append("Arguments:\n");
        help.append("\tinput\t\t-> root of files from which to read the model\n");
        help.append("\tthreshold\t-> similarity threshold\n");
        help.append("\tsize\t\t-> number of similar terms to return\n");
        help.append("\tdim\t\t-> number of dimensions\n");
        help.append("\tidf\t\t-> if true rescale using the idf\n");
        return help.toString();
    }
}
