package org.bigml.binding;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Locale;
import java.util.Map;
import org.apache.commons.math3.random.MersenneTwister;
import org.bigml.binding.resources.AbstractResource;
import org.bigml.binding.utils.Stemmer;
import org.bigml.binding.utils.StemmerInterface;
import org.bigml.binding.utils.Utils;
import org.json.simple.JSONArray;
import org.json.simple.JSONObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/bigml/binding/LocalTopicModel.class */
/**
 * Local, offline copy of a topic model resource that infers a topic distribution for input text.
 *
 * <p>The constructor reads the {@code topic_model} section of the resource JSON (termset, per-term
 * topic assignment counts, {@code alpha}, {@code beta}, seed and tokenizer options) and precomputes
 * the per-topic term probabilities {@code phi}. Inference tokenizes the text into term indices and
 * runs three sampling passes (uniform initialization, burn-in, final sampling) with a
 * {@link MersenneTwister} seeded from {@code hashed_seed}. Repeated calls on the same text therefore
 * return the same distribution.
 *
 * <p>Instances are not thread-safe: sampling writes into the shared scratch buffer {@code temp}.
 */
public class LocalTopicModel extends ModelFields implements Serializable {
    private static final long serialVersionUID = 1;

    /** Maximum number of characters kept for a single token; longer runs are cut at this length. */
    static final int MAXIMUM_TERM_LENGTH = 30;
    /** Lower bound on the number of sampling sweeps per inference. */
    static final int MIN_UPDATES = 16;
    /** Upper bound on the number of sampling sweeps per inference. */
    static final int MAX_UPDATES = 512;
    /** Target samples per topic; divided by the document length to size the sweep count. */
    static final int SAMPLES_PER_TOPIC = 128;

    // Static so the logger does not become part of this Serializable object's state.
    static final Logger logger = LoggerFactory.getLogger(LocalTopicModel.class);

    /** Stemmer chosen from the model's {@code language}; applied to termset entries and input tokens. */
    private StemmerInterface stemmer;
    /** Absolute value of the model's {@code hashed_seed}; seeds the sampler on every inference. */
    private long seed;
    /** Whether tokens keep their case; when false they are lower-cased before lookup. */
    private Boolean caseSensitive;
    /** Whether adjacent token pairs are also looked up as "a b" terms. */
    private Boolean bigrams;
    /** Number of topics, taken from the width of the first term_topic_assignments row. */
    private Integer ntopics;
    /** Per-topic scratch weights reused by the samplers (not thread-safe). */
    private Double[] temp;
    /** phi[topic][term]: smoothed probability of a term under a topic. */
    private Double[][] phi;
    /** Stemmed term -> position of the term in the model's termset. */
    private HashMap<String, Integer> termToIndex;
    /** The model's topics JSON array; entry i supplies the "name" of topic i. */
    private JSONArray topics;
    /** Document-topic smoothing hyperparameter read from the model. */
    private Double alpha;
    /** ntopics * alpha, added to the sampling normalizer. */
    private Double ktimesalpha;

    /**
     * Builds a local topic model from a topic model resource.
     *
     * @param jSONObject the resource JSON, optionally wrapped in an {@code "object"} envelope; it must
     *     contain {@code "resource"}, a finished {@code "status"} and a {@code "topic_model"} map
     * @throws Exception if a required key is missing, the model is not finished, or the
     *     term/topic assignments are empty
     */
    public LocalTopicModel(JSONObject jSONObject) throws Exception {
        this.caseSensitive = false;
        this.bigrams = false;
        JSONObject resource = jSONObject;
        if (resource.get("resource") == null) {
            throw new Exception("Cannot create the topicModel instance. Could not find the 'resource' key in the resource");
        }
        if (resource.containsKey("object") && (resource.get("object") instanceof Map)) {
            resource = (JSONObject) resource.get("object");
        }
        if (!resource.containsKey("topic_model") || !(resource.get("topic_model") instanceof Map)) {
            // Report the keys that *are* present; the previous code dereferenced a key that is
            // usually absent here and failed with a NullPointerException instead.
            throw new Exception(String.format(
                    "Cannot create the topic model instance. Could not find the 'topic_model' key in the resource:\n\n%s",
                    resource.keySet()));
        }
        JSONObject status = (JSONObject) resource.get("status");
        if (status == null || !status.containsKey("code")
                || AbstractResource.FINISHED != ((Number) status.get("code")).intValue()) {
            throw new Exception("The topicModel isn't finished yet");
        }

        JSONObject topicModel = (JSONObject) Utils.getJSONObject(resource, "topic_model");
        this.topics = (JSONArray) topicModel.get("topics");
        this.stemmer = Stemmer.getStemmer((String) topicModel.get("language"));

        JSONArray termset = (JSONArray) topicModel.get("termset");
        this.termToIndex = new HashMap<>();
        for (int i = 0; i < termset.size(); i++) {
            // NOTE(review): if two termset entries share a stem, the later index wins and the map
            // becomes smaller than the termset, while phi is sized by the map — confirm stems are unique.
            this.termToIndex.put(stem((String) termset.get(i)), Integer.valueOf(i));
        }

        JSONArray assignments = (JSONArray) topicModel.get("term_topic_assignments");
        if (assignments == null || assignments.isEmpty()) {
            throw new Exception("Cannot create the topic model instance. The 'term_topic_assignments' key is missing or empty");
        }
        // Numbers go through Number because json-simple yields Long or Double depending on the literal.
        this.seed = Math.abs(((Number) topicModel.get("hashed_seed")).longValue());
        // Missing flags default to false instead of leaving a null that later unboxes to an NPE.
        this.caseSensitive = Boolean.TRUE.equals(topicModel.get("case_sensitive"));
        this.bigrams = Boolean.TRUE.equals(topicModel.get("bigrams"));
        this.ntopics = ((JSONArray) assignments.get(0)).size();
        this.alpha = ((Number) topicModel.get("alpha")).doubleValue();
        this.ktimesalpha = this.ntopics * this.alpha;
        double beta = ((Number) topicModel.get("beta")).doubleValue();

        this.temp = new Double[this.ntopics];
        Arrays.fill(this.temp, 0.0d);
        this.phi = buildPhi(assignments, this.termToIndex.size(), beta);

        super.initialize((JSONObject) topicModel.get("fields"), null, null, null);
    }

    /**
     * Computes phi[k][w] = (n[w][k] + beta) / (N_k + nterms * beta), where n[w][k] is the
     * assignment count of term w to topic k and N_k is the sum of column k over all rows.
     */
    private Double[][] buildPhi(JSONArray assignments, int nterms, double beta) {
        int topicCount = this.ntopics;
        long[] topicTotals = new long[topicCount];
        for (Object row : assignments) {
            JSONArray counts = (JSONArray) row;
            for (int k = 0; k < topicCount; k++) {
                topicTotals[k] += ((Number) counts.get(k)).longValue();
            }
        }
        Double[][] result = new Double[topicCount][nterms];
        for (int k = 0; k < topicCount; k++) {
            double norm = topicTotals[k] + (nterms * beta);
            for (int w = 0; w < nterms; w++) {
                long count = ((Number) ((JSONArray) assignments.get(w)).get(k)).longValue();
                result[k][w] = (count + beta) / norm;
            }
        }
        return result;
    }

    /**
     * Infers the topic distribution for an input-data record.
     *
     * <p>The record is filtered to the model's fields, then the non-null field values are
     * concatenated, each followed by a single space, and passed to {@link #distributionForText}.
     *
     * @param jSONObject input data keyed by field
     * @return one map per topic with {@code "name"} and {@code "probability"} entries
     * @throws Exception propagated from input filtering or inference
     */
    public ArrayList<HashMap<String, Object>> distribution(JSONObject jSONObject) throws Exception {
        JSONObject filtered = filterInputData(jSONObject);
        StringBuilder text = new StringBuilder();
        for (Object value : filtered.values()) {
            if (value != null) {
                text.append(value).append(' ');
            }
        }
        return distributionForText(text.toString());
    }

    /**
     * Infers the topic distribution for a raw text string.
     *
     * @param str the text to analyse
     * @return one map per topic, in topic order, with the topic {@code "name"} and its
     *     {@code "probability"}
     * @throws Exception declared for API compatibility
     */
    public ArrayList<HashMap<String, Object>> distributionForText(String str) throws Exception {
        Double[] probabilities = infer(tokenize(str));
        ArrayList<HashMap<String, Object>> result = new ArrayList<>(probabilities.length);
        for (int i = 0; i < probabilities.length; i++) {
            HashMap<String, Object> entry = new HashMap<>();
            entry.put("name", (String) ((JSONObject) this.topics.get(i)).get("name"));
            entry.put("probability", probabilities[i]);
            result.add(entry);
        }
        return result;
    }

    /** Returns the stem of {@code str} according to the model's language stemmer. */
    private String stem(String str) {
        return this.stemmer.getStem(str);
    }

    /**
     * Appends the term index of the bigram "termBefore lastTerm" when bigrams are enabled, both
     * terms are present, and the stemmed bigram is part of the termset.
     */
    private void appendBigram(ArrayList<Integer> outTerms, String termBefore, String lastTerm) {
        if (!Boolean.TRUE.equals(this.bigrams) || termBefore == null || lastTerm == null) {
            return;
        }
        Integer index = this.termToIndex.get(stem(termBefore + " " + lastTerm));
        if (index != null) {
            outTerms.add(index);
        }
    }

    /**
     * Splits text into tokens and maps each known token (and known bigram) to its termset index.
     *
     * <p>A token is a run of letters/digits/apostrophes of at most {@link #MAXIMUM_TERM_LENGTH}
     * characters. When the model is not case sensitive, tokens are lower-cased with
     * {@link Locale#ROOT}, so the result does not depend on the JVM default locale.
     */
    private ArrayList<Integer> tokenize(String str) {
        ArrayList<Integer> outTerms = new ArrayList<>();
        boolean lowerCase = !Boolean.TRUE.equals(this.caseSensitive);
        String lastTerm = null;
        String termBefore = null;
        boolean spaceWasSep = false;
        int index = 0;
        int length = str.length();
        while (index < length) {
            appendBigram(outTerms, termBefore, lastTerm);
            char ch = str.charAt(index);
            // Records whether this token was preceded by a non-alphanumeric character.
            boolean sawSeparator = !Character.isLetterOrDigit(ch);
            while (!Character.isLetterOrDigit(ch) && index < length) {
                index++;
                ch = index < length ? str.charAt(index) : (char) 0;
            }
            StringBuilder buf = new StringBuilder();
            while (index < length
                    && (Character.isLetterOrDigit(ch) || ch == '\'')
                    && buf.length() < MAXIMUM_TERM_LENGTH) {
                buf.append(ch);
                index++;
                ch = index < length ? str.charAt(index) : (char) 0;
            }
            if (buf.length() > 0) {
                String term = buf.toString();
                if (lowerCase) {
                    term = term.toLowerCase(Locale.ROOT);
                }
                // The previous token only forms a bigram with this one under these flag conditions.
                termBefore = (spaceWasSep && !sawSeparator) ? lastTerm : null;
                lastTerm = term;
                if (ch == ' ' || ch == '\n') {
                    spaceWasSep = true;
                }
                Integer termIndex = this.termToIndex.get(stem(term));
                if (termIndex != null) {
                    outTerms.add(termIndex);
                }
                // Skip the character that terminated the token.
                index++;
            }
        }
        appendBigram(outTerms, termBefore, lastTerm);
        return outTerms;
    }

    /** Returns a fresh per-topic counter array filled with zeros. */
    private Integer[] zeroCounts() {
        Integer[] counts = new Integer[this.ntopics];
        Arrays.fill(counts, 0);
        return counts;
    }

    /**
     * Turns {@code temp} into running sums and draws a topic proportionally to its weight.
     *
     * <p>The index is clamped to the last topic, so a draw can never run past the end of the array.
     */
    private int drawTopic(MersenneTwister rng) {
        int topicCount = this.ntopics;
        for (int k = 1; k < topicCount; k++) {
            this.temp[k] = this.temp[k] + this.temp[k - 1];
        }
        double target = rng.nextDouble() * this.temp[this.temp.length - 1];
        int topic = 0;
        while (topic < topicCount - 1 && this.temp[topic] < target) {
            topic++;
        }
        return topic;
    }

    /**
     * Samples a topic for every term of {@code arrayList}, {@code i} times. Each term's topic is
     * weighted by phi and by the given topic {@code numArr} (assignment counts), normalized by
     * {@code d}.
     *
     * @return how many times each topic was drawn
     */
    private Integer[] sampleTopics(ArrayList<Integer> arrayList, Integer[] numArr, double d, int i, MersenneTwister mersenneTwister) {
        int topicCount = this.ntopics;
        double alphaValue = this.alpha;
        Integer[] counts = zeroCounts();
        for (int update = 0; update < i; update++) {
            for (Integer term : arrayList) {
                for (int k = 0; k < topicCount; k++) {
                    this.temp[k] = this.phi[k][term].doubleValue() * ((numArr[k].intValue() + alphaValue) / d);
                }
                counts[drawTopic(mersenneTwister)]++;
            }
        }
        return counts;
    }

    /**
     * Samples a topic for every term of {@code arrayList}, {@code i} times, weighting topics by phi
     * only (uniform prior over topic assignments). Used to initialize the sampler.
     *
     * @return how many times each topic was drawn
     */
    private Integer[] sampleUniform(ArrayList<Integer> arrayList, int i, MersenneTwister mersenneTwister) {
        int topicCount = this.ntopics;
        Integer[] counts = zeroCounts();
        for (int update = 0; update < i; update++) {
            for (Integer term : arrayList) {
                for (int k = 0; k < topicCount; k++) {
                    this.temp[k] = this.phi[k][term];
                }
                counts[drawTopic(mersenneTwister)]++;
            }
        }
        return counts;
    }

    /**
     * Infers per-topic probabilities for a document given as term indices.
     *
     * <p>The caller's list is not modified; a sorted copy is sampled instead.
     */
    private Double[] infer(ArrayList<Integer> arrayList) {
        ArrayList<Integer> document = new ArrayList<>(arrayList);
        Collections.sort(document);
        int topicCount = this.ntopics;
        int updates = 0;
        if (!document.isEmpty()) {
            int proposed = (SAMPLES_PER_TOPIC * topicCount) / document.size();
            updates = Math.min(MAX_UPDATES, Math.max(MIN_UPDATES, proposed));
        }
        // NOTE(review): the long seed is truncated to int here, unchanged from the original behaviour.
        MersenneTwister rng = new MersenneTwister(new int[]{(int) this.seed});
        double normalizer = (document.size() * updates) + this.ktimesalpha.doubleValue();

        // The order of these three passes fixes the random-number consumption order.
        Integer[] uniformCounts = sampleUniform(document, updates, rng);
        Integer[] burnInCounts = sampleTopics(document, uniformCounts, normalizer, updates, rng);
        Integer[] sampleCounts = sampleTopics(document, burnInCounts, normalizer, updates, rng);

        double alphaValue = this.alpha;
        Double[] result = new Double[topicCount];
        for (int k = 0; k < topicCount; k++) {
            result[k] = (sampleCounts[k].intValue() + alphaValue) / normalizer;
        }
        return result;
    }
}
