package com.wcohen.ss.abbvGapsHmm;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/wcohen/ss/abbvGapsHmm/AbbvGapsHMM.class */
/**
 * Hidden Markov model over the states in {@link States}, parameterised by one probability per
 * {@link Transitions} constant and one per {@link Emissions} constant.
 *
 * <p>Parameters are kept in lists indexed by the ordinal of the corresponding enum constant.
 * They can be loaded from / saved to a tab-separated parameter file, supplied by the caller via
 * {@link #setStartingParams(List, List)}, or estimated by {@link #train(List, List)}, which
 * alternates {@link #expectationStep(Acronym, String)} over the corpus with
 * {@link #maximizationStep()} until the summed parameter change drops below a threshold or the
 * iteration cap is reached.
 *
 * <p>The class keeps mutable state in plain fields and performs no synchronisation.
 */
public class AbbvGapsHMM {
    private static final Logger LOG = LoggerFactory.getLogger(AbbvGapsHMM.class);

    /** Maximum number of expectation/maximization rounds performed by {@link #trainCorpus}. */
    private static final int MAX_ITERATIONS = 300;

    /** Training stops once the summed absolute parameter change of a round is below this value. */
    private static final double CHANGE_THRESHOLD = 0.01d;

    /** Initial value given to every parameter that has not been supplied by the caller. */
    private static final double DEFAULT_PARAM = 0.5d;

    /** Exponent applied to the current parameter value when it is used as a smoothing term. */
    private static final double SMOOTHING_EXPONENT = 1.0d;

    /** Name prefixes of the transition groups that are normalised together, in lookup order. */
    private static final String[] TRANSITION_GROUP_PREFIXES = {"t_DL_", "t_M_", "t_D_", "t_S_"};

    /** Name prefixes of the emission groups that are normalised together, in lookup order. */
    private static final String[] EMISSION_GROUP_PREFIXES = {"e_DL_", "e_M_", "e_D_"};

    /** Matrix produced by the forward evaluator during the last expectation step. */
    Matrix3D _alpha;

    /** Matrix produced by the backward evaluator during the last expectation step. */
    Matrix3D _beta;

    /** Path of the parameter file used by load/save; {@code null} disables persistence. */
    private String _modelParamsFile;

    /** Path of the last successfully loaded document-frequency file, or {@code null}. */
    private String _tfIdfDataFile = null;

    /** Minimum document frequency a word needs in order to be kept by {@link #setTfIdfData}. */
    private final double _dfWordThreshold = 0.2d;

    /** Word to document-frequency map; {@code null} until {@link #setTfIdfData} succeeds. */
    private Map<String, Double> _commonWordDF = null;

    /** Accumulated transition counts, indexed by {@code Transitions.ordinal()}. */
    List<Double> _transitionCounters = new ArrayList<Double>();

    /** Accumulated emission counts, indexed by {@code Emissions.ordinal()}. */
    List<Double> _emissionCounters = new ArrayList<Double>();

    /** Transition parameters, indexed by {@code Transitions.ordinal()}. */
    List<Double> _transitionParams = new ArrayList<Double>();

    /** Emission parameters, indexed by {@code Emissions.ordinal()}. */
    List<Double> _emissionParams = new ArrayList<Double>();

    /** True once {@link #setStartingParams(List, List)} has supplied parameters. */
    boolean _externalySet = false;

    /** Start probability per state, indexed by {@code States.ordinal()}; lazily initialised. */
    List<Double> _stateStartProb = null;

    /**
     * Emission parameters. The declaration order is significant: ordinals index the parameter
     * and counter lists and define the line order of the parameter file.
     */
    public enum Emissions {
        e_DL_alphaNumeric_to_none,
        e_DL_nonAlphaNumeric_to_none,
        e_DL_word_to_none,
        e_D_alphaNumeric_to_none,
        e_D_word_to_none,
        e_D_none_to_nonAlphaNumeric,
        e_M_partialWord_to_letter,
        e_M_word_to_firstLetter,
        e_M_letter_to_letter,
        e_M_nonAlphaNumeric_to_none,
        e_M_commonWordDeletion,
        e_M_AND_to_symbol,
        e_M_one_to_1,
        e_M_two_to_2,
        e_M_three_to_3,
        e_M_four_to_4,
        e_M_five_to_5,
        e_M_six_to_6,
        e_M_seven_to_7,
        e_M_eight_to_8,
        e_M_nine_to_9,
        e_M_Silver_Ag,
        e_M_Gold_Au,
        e_M_Copper_Cu,
        e_M_Iron_Fe,
        e_M_Mercury_Hg,
        e_M_Potassium_K,
        e_M_Sodium_Na,
        e_M_Lead_Pb,
        e_M_Antimony_Sb,
        e_M_Tin_Sn,
        e_M_Tungsten_W,
        e_END_end
    }

    /** Model states. Ordinals index {@link #_stateStartProb}. */
    public enum States {
        S,
        DL,
        M,
        D,
        END
    }

    /**
     * Transition parameters. The declaration order is significant: ordinals index the parameter
     * and counter lists and define the line order of the parameter file.
     */
    public enum Transitions {
        t_DL_in,
        t_DL_to_M,
        t_M_in,
        t_M_to_D,
        t_M_to_END,
        t_D_in,
        t_D_to_M,
        t_D_to_END,
        t_S_to_M,
        t_S_to_DL
    }

    /** Creates a model without a parameter file; nothing is loaded or saved. */
    public AbbvGapsHMM() {
        this._modelParamsFile = null;
    }

    /**
     * Creates a model bound to a parameter file.
     *
     * @param str path of the parameter file, or {@code null} for none
     */
    public AbbvGapsHMM(String str) {
        this._modelParamsFile = str;
    }

    /**
     * Creates a model bound to a parameter file.
     *
     * @param str path of the parameter file, or {@code null} for none
     * @param z   currently ignored; kept so existing callers keep compiling
     */
    public AbbvGapsHMM(String str, boolean z) {
        this._modelParamsFile = str;
    }

    /**
     * @return the live emission parameter list (not a copy), indexed by
     *         {@code Emissions.ordinal()}
     */
    public List<Double> getEmmisionParams() {
        return this._emissionParams;
    }

    /**
     * @return the live transition parameter list (not a copy), indexed by
     *         {@code Transitions.ordinal()}
     */
    public List<Double> getTransitionParams() {
        return this._transitionParams;
    }

    /** @return {@code true} once a document-frequency file has been set via {@link #setTfIdfData} */
    public boolean useTDIDF() {
        return this._tfIdfDataFile != null;
    }

    /**
     * Looks up the stored document frequency of a word.
     *
     * @param str the word to look up
     * @return the stored value, or {@code null} when the word is not stored or no
     *         document-frequency data has been loaded
     */
    public Double getDF(String str) {
        // Guard against lookups before setTfIdfData(): the map is null until then.
        if (this._commonWordDF == null) {
            return null;
        }
        return this._commonWordDF.get(str);
    }

    /**
     * Loads word document frequencies from a text file with one {@code word value} pair per
     * line, separated by a single space. Only words whose value is at least the internal
     * threshold are kept. Blank lines are skipped. The model state is updated only after the
     * whole file has been read successfully.
     *
     * @param str path of the document-frequency file
     * @throws IOException           if the file cannot be read or a line has no value column
     * @throws NumberFormatException if a value column is not a parsable number
     */
    public void setTfIdfData(String str) throws IOException {
        Map<String, Double> loaded = new HashMap<String, Double>();
        try (BufferedReader reader = new BufferedReader(new FileReader(str))) {
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.trim().isEmpty()) {
                    continue;
                }
                String[] fields = line.split(" ");
                if (fields.length < 2) {
                    throw new IOException("Malformed document-frequency line in " + str + ": '" + line + "'");
                }
                double frequency = Double.parseDouble(fields[1]);
                // Double.compare keeps the ordering semantics of Double.compareTo (NaN sorts highest).
                if (Double.compare(frequency, this._dfWordThreshold) >= 0) {
                    loaded.put(fields[0], Double.valueOf(frequency));
                }
            }
        }
        this._tfIdfDataFile = str;
        this._commonWordDF = loaded;
    }

    /**
     * Lazily builds the start-probability list: 1.0 for {@link States#S}, 0.0 for every other
     * state. Does nothing if the list already exists.
     */
    protected void initStartProbs() {
        if (this._stateStartProb != null) {
            return;
        }
        List<Double> startProbs = new ArrayList<Double>();
        for (States state : States.values()) {
            startProbs.add(Double.valueOf(state == States.S ? 1.0d : 0.0d));
        }
        this._stateStartProb = startProbs;
    }

    /**
     * Sets the parameter file used by {@link #loadModelParams()} and {@link #saveModelParams()}.
     *
     * @param str path of the parameter file, or {@code null} to disable persistence
     */
    public void setParamFile(String str) {
        this._modelParamsFile = str;
    }

    /**
     * Loads the parameters from the parameter file if possible, otherwise trains on the corpus.
     *
     * @param list  corpus passed to {@link #trainCorpus} when loading fails
     * @param list2 passed through to {@link #trainCorpus}
     * @return {@code true} if parameters were loaded or training completed
     */
    public boolean train(List<List<Acronym>> list, List<Map<String, String>> list2) {
        if (loadModelParams()) {
            return true;
        }
        return trainCorpus(list, list2);
    }

    /**
     * Either trains on the corpus or loads the parameters from the parameter file.
     *
     * @param list  corpus passed to {@link #trainCorpus} when {@code z} is true
     * @param list2 passed through to {@link #trainCorpus} when {@code z} is true
     * @param z     {@code true} to train, {@code false} to only load the parameter file
     * @return the result of the chosen operation
     */
    public boolean train(List<List<Acronym>> list, List<Map<String, String>> list2, boolean z) {
        return z ? trainCorpus(list, list2) : loadModelParams();
    }

    /**
     * Replaces the current parameters with caller-supplied values and marks them as externally
     * set, so that {@link #initModelParamsAndCounters()} keeps them.
     *
     * @param list  emission parameters, indexed by {@code Emissions.ordinal()}
     * @param list2 transition parameters, indexed by {@code Transitions.ordinal()}
     */
    public void setStartingParams(List<Double> list, List<Double> list2) {
        // Snapshot first: the arguments may be the live lists returned by the getters, in which
        // case clearing before copying would discard the values.
        List<Double> emissions = new ArrayList<Double>(list);
        List<Double> transitions = new ArrayList<Double>(list2);
        this._emissionParams.clear();
        this._emissionParams.addAll(emissions);
        this._transitionParams.clear();
        this._transitionParams.addAll(transitions);
        this._externalySet = true;
    }

    /**
     * Resets all counters to 0.0 and prepares the parameter lists for training. Unless the
     * parameters were supplied through {@link #setStartingParams(List, List)}, they are reset
     * to the default value; externally supplied lists are kept and only padded with the default
     * value if they are shorter than the corresponding enum. The {@code e_END_end} emission is
     * always set to 1.0.
     */
    public void initModelParamsAndCounters() {
        int emissionCount = Emissions.values().length;
        int transitionCount = Transitions.values().length;

        resetCounters(this._emissionCounters, emissionCount);
        resetCounters(this._transitionCounters, transitionCount);

        if (!this._externalySet) {
            this._emissionParams.clear();
            this._transitionParams.clear();
        }
        padWithDefault(this._emissionParams, emissionCount);
        padWithDefault(this._transitionParams, transitionCount);

        this._emissionParams.set(Emissions.e_END_end.ordinal(), Double.valueOf(1.0d));
    }

    /** Empties {@code counters} and refills it with {@code size} zeros. */
    private static void resetCounters(List<Double> counters, int size) {
        counters.clear();
        for (int i = 0; i < size; i++) {
            counters.add(Double.valueOf(0.0d));
        }
    }

    /** Appends the default parameter value until {@code params} holds at least {@code size} entries. */
    private static void padWithDefault(List<Double> params, int size) {
        while (params.size() < size) {
            params.add(Double.valueOf(DEFAULT_PARAM));
        }
    }

    /**
     * Estimates the parameters from a corpus. Each round runs an expectation step for every
     * acronym followed by one maximization step, and the loop ends when the summed parameter
     * change is below the convergence threshold or the iteration cap is reached. The resulting
     * parameters are written with {@link #saveModelParams()}.
     *
     * @param list  corpus: a list of acronym lists
     * @param list2 currently unused; every expectation step is called with a {@code null} label
     * @return always {@code true}
     */
    protected boolean trainCorpus(List<List<Acronym>> list, List<Map<String, String>> list2) {
        initModelParamsAndCounters();
        LOG.debug("Training abbreviations...");

        int iteration = 0;
        boolean finished = false;
        while (!finished) {
            for (List<Acronym> acronyms : list) {
                for (Acronym acronym : acronyms) {
                    expectationStep(acronym, null);
                }
            }
            Double change = maximizationStep();
            iteration++;
            LOG.debug("step {}...", Integer.valueOf(iteration));

            if (iteration >= MAX_ITERATIONS) {
                LOG.warn("Training stopped after {} iterations with final change: {}", Integer.valueOf(iteration), change);
                finished = true;
            }
            if (Double.compare(change.doubleValue(), CHANGE_THRESHOLD) < 0) {
                LOG.debug("Training converged in {} iterations.", Integer.valueOf(iteration));
                finished = true;
            }
        }
        saveModelParams();
        return true;
    }

    /**
     * Accumulates the transition and emission counters for one acronym using the backward,
     * forward and expectation evaluators with the current parameters.
     *
     * @param acronym the training example
     * @param str     not used by this implementation
     */
    protected void expectationStep(Acronym acronym, String str) {
        AbbvGapsHmmBackwardsEvaluator backward = new AbbvGapsHmmBackwardsEvaluator(this);
        backward.backwardEvaluate(acronym, this._transitionParams, this._emissionParams);
        this._beta = backward.getEvalMatrix();
        // Skip the example when the backward value at (0, 0, S) is exactly zero
        // (presumably the example has no probability mass under the current model - TODO confirm).
        if (this._beta.at(0, 0, States.S.ordinal()) == 0.0d) {
            return;
        }
        AbbvGapsHmmForwardEvaluator forward = new AbbvGapsHmmForwardEvaluator(this);
        forward.forwardEvaluate(acronym, this._transitionParams, this._emissionParams);
        this._alpha = forward.getEvalMatrix();

        AbbvGapsHmmExpectationEvaluator expectation = new AbbvGapsHmmExpectationEvaluator(this);
        expectation.expectationEvaluate(acronym, this._transitionCounters, this._emissionCounters, this._transitionParams, this._emissionParams, this._alpha, this._beta);
        // The evaluator hands back the counter lists to use from now on.
        this._transitionCounters = expectation.getTransitionCounters();
        this._emissionCounters = expectation.getEmissionCounters();
    }

    /**
     * Runs the backward Viterbi evaluator on an acronym with the current parameters.
     *
     * @param acronym the acronym to evaluate
     * @return the container produced by the evaluator
     */
    public AbbreviationAlignmentContainer<Emissions, States> viterbi(Acronym acronym) {
        return new AbbvGapsHmmBackwardsViterbiEvaluator(this).backwardViterbiEvaluate(acronym, this._transitionParams, this._emissionParams);
    }

    /**
     * Writes the current parameters to the parameter file as {@code name<TAB>value} lines,
     * emissions first and transitions second, each section preceded by a {@code #} header line.
     * Does nothing when no parameter file is set. I/O failures are logged, not thrown.
     */
    public void saveModelParams() {
        if (this._modelParamsFile == null) {
            return;
        }
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(this._modelParamsFile))) {
            writer.write("# Emmisions\n");
            for (Emissions emission : Emissions.values()) {
                writer.write(emission.toString() + "\t" + this._emissionParams.get(emission.ordinal()) + "\n");
            }
            writer.write("# Transitions\n");
            for (Transitions transition : Transitions.values()) {
                writer.write(transition.toString() + "\t" + this._transitionParams.get(transition.ordinal()) + "\n");
            }
        } catch (IOException e) {
            LOG.error("Failed to save model parameters to " + this._modelParamsFile, e);
        }
    }

    /**
     * Reads the parameters from the parameter file. Lines starting with {@code #} and blank
     * lines are skipped; every other line must be {@code name<TAB>value}. Values are assigned
     * purely by position: the first {@code Emissions.values().length} values are emissions, the
     * rest are transitions. The current parameters are replaced only when the file was read
     * completely and contained exactly the expected number of values; on any failure they are
     * left untouched.
     *
     * @return {@code true} if the parameters were loaded, {@code false} if no file is set, the
     *         file does not exist, cannot be read, is malformed, or has the wrong number of
     *         values
     */
    public boolean loadModelParams() {
        if (this._modelParamsFile == null || !new File(this._modelParamsFile).exists()) {
            return false;
        }
        int emissionCount = Emissions.values().length;
        int transitionCount = Transitions.values().length;
        List<Double> emissions = new ArrayList<Double>(emissionCount);
        List<Double> transitions = new ArrayList<Double>(transitionCount);

        try (BufferedReader reader = new BufferedReader(new FileReader(this._modelParamsFile))) {
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.startsWith("#") || line.trim().isEmpty()) {
                    continue;
                }
                String[] columns = line.split("\t");
                if (columns.length < 2) {
                    LOG.warn("Malformed line in model parameter file {}: '{}'", this._modelParamsFile, line);
                    return false;
                }
                Double value = Double.valueOf(Double.parseDouble(columns[1]));
                if (emissions.size() < emissionCount) {
                    emissions.add(value);
                } else {
                    transitions.add(value);
                }
            }
        } catch (IOException | NumberFormatException e) {
            LOG.warn("Failed to load model parameters from " + this._modelParamsFile, e);
            return false;
        }

        // Transitions are only filled after all emissions, so this check covers both lists.
        if (transitions.size() != transitionCount) {
            return false;
        }
        this._emissionParams.clear();
        this._emissionParams.addAll(emissions);
        this._transitionParams.clear();
        this._transitionParams.addAll(transitions);
        return true;
    }

    /**
     * Re-estimates all parameters from the current counters.
     *
     * @return the summed absolute change over all emission and transition parameters
     */
    protected Double maximizationStep() {
        double emissionChange = maximizationStepForEmissions().doubleValue();
        double transitionChange = maximizationStepForTransitions().doubleValue();
        return Double.valueOf(emissionChange + transitionChange);
    }

    /**
     * Re-estimates the transition parameters: transitions whose names share one of the
     * {@code t_DL_}, {@code t_M_}, {@code t_D_} or {@code t_S_} prefixes are normalised so that
     * each group sums to one.
     *
     * @return the summed absolute change of the transition parameters
     */
    protected Double maximizationStepForTransitions() {
        return Double.valueOf(renormalizeByGroup(Transitions.values(), TRANSITION_GROUP_PREFIXES, this._transitionCounters, this._transitionParams));
    }

    /**
     * Returns the smoothed count for one parameter: its counter plus its current parameter
     * value raised to the smoothing exponent.
     *
     * @param i     parameter index
     * @param list  counters
     * @param list2 current parameter values
     * @return the smoothed count
     */
    protected double smoothCounter(int i, List<Double> list, List<Double> list2) {
        return list.get(i).doubleValue() + Math.pow(list2.get(i).doubleValue(), SMOOTHING_EXPONENT);
    }

    /**
     * Normalises a smoothed count by its group total.
     *
     * @param d  smoothed count of one parameter
     * @param d2 sum of the smoothed counts of the parameter's group
     * @return {@code d / d2}, or 0.0 when the group total is exactly zero
     */
    protected double getNewStateVal(double d, double d2) {
        if (d2 == 0.0d) {
            return 0.0d;
        }
        return d / d2;
    }

    /**
     * Re-estimates the emission parameters: emissions whose names share one of the
     * {@code e_DL_}, {@code e_M_} or {@code e_D_} prefixes are normalised so that each group
     * sums to one; every other emission (i.e. {@code e_END_end}) is set to 1.0.
     *
     * @return the summed absolute change of the emission parameters
     */
    protected Double maximizationStepForEmissions() {
        return Double.valueOf(renormalizeByGroup(Emissions.values(), EMISSION_GROUP_PREFIXES, this._emissionCounters, this._emissionParams));
    }

    /**
     * Shared maximization routine. Sums the smoothed counts per name-prefix group, then
     * overwrites each parameter with its smoothed count divided by its group total. Parameters
     * whose name matches none of the prefixes are set to 1.0.
     *
     * @param values        enum constants whose ordinals index {@code counters} and {@code params}
     * @param groupPrefixes name prefixes defining the groups; the first matching prefix wins
     * @param counters      accumulated counts
     * @param params        parameter list, updated in place
     * @return the summed absolute difference between old and new parameter values
     */
    private double renormalizeByGroup(Enum<?>[] values, String[] groupPrefixes, List<Double> counters, List<Double> params) {
        int[] groupOf = new int[values.length];
        double[] groupTotals = new double[groupPrefixes.length];
        for (int i = 0; i < values.length; i++) {
            groupOf[i] = groupIndexOf(values[i].name(), groupPrefixes);
            if (groupOf[i] >= 0) {
                groupTotals[groupOf[i]] += smoothCounter(i, counters, params);
            }
        }

        double totalChange = 0.0d;
        for (int i = 0; i < values.length; i++) {
            // params[i] is still the old value here, so the smoothed count matches the first pass.
            double updated = groupOf[i] >= 0
                    ? getNewStateVal(smoothCounter(i, counters, params), groupTotals[groupOf[i]])
                    : 1.0d;
            totalChange += Math.abs(params.get(i).doubleValue() - updated);
            params.set(i, Double.valueOf(updated));
        }
        return totalChange;
    }

    /** Returns the index of the first prefix that {@code name} starts with, or -1 if none does. */
    private static int groupIndexOf(String name, String[] prefixes) {
        for (int g = 0; g < prefixes.length; g++) {
            if (name.startsWith(prefixes[g])) {
                return g;
            }
        }
        return -1;
    }
}
