package net.digital_alexandria.lvm4j.hmm;

import java.util.Collections;
import java.util.List;
import net.digital_alexandria.lvm4j.edges.ArcFactory;
import net.digital_alexandria.lvm4j.edges.WeightedArc;
import net.digital_alexandria.lvm4j.math.linalg.Combinatorial;
import net.digital_alexandria.lvm4j.nodes.HMMNode;
import net.digital_alexandria.lvm4j.nodes.LatentHMMNode;
import net.digital_alexandria.lvm4j.nodes.NodeFactory;
import net.digital_alexandria.lvm4j.structs.Pair;
import net.digital_alexandria.lvm4j.structs.Triple;
import net.digital_alexandria.lvm4j.util.File;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:net/digital_alexandria/lvm4j/hmm/HMMFactory.class */
/**
 * Factory that assembles {@link HMM} objects: it creates the latent state and observation
 * nodes, wires up transition and emission arcs (all with an initial weight of 0.0) and, when
 * the parsed parameters carry them, applies start, transition and emission probabilities.
 *
 * <p>A single shared instance is handed out by {@link #instance()}; the constructor is private.
 */
public final class HMMFactory {
    private static final Logger _LOGGER = LoggerFactory.getLogger(HMMFactory.class);
    private static final ArcFactory _ARC_FACTORY = ArcFactory.instance();
    private static final NodeFactory _NODE_FACTORY = NodeFactory.instance();
    // Shared factory instance; stays null until instance() is called for the first time.
    private static HMMFactory _factory;

    private HMMFactory() {
    }

    /**
     * Returns the shared factory, creating (and logging the creation of) it on first use.
     * No synchronization is performed here.
     *
     * @return the shared {@code HMMFactory}
     */
    public static HMMFactory instance() {
        if (_factory != null) {
            return _factory;
        }
        _LOGGER.info("Instantiating HMMFactory");
        _factory = new HMMFactory();
        return _factory;
    }

    /**
     * Builds an HMM from the parameters obtained by handing {@code str} to
     * {@code File.parseXML}.
     *
     * @param str argument forwarded to {@code File.parseXML} (presumably the location of an
     *            XML description of the model - TODO confirm against {@code File})
     * @return the newly built HMM
     */
    public HMM hmm(String str) {
        HMM model = new HMM();
        init(model, str);
        return model;
    }

    /**
     * Builds an HMM from a state alphabet, an observation alphabet and a model order.
     * All arcs are created with weight 0.0; no probabilities are set.
     *
     * @param cArr  symbols from which the state labels are generated
     * @param cArr2 observation symbols
     * @param i     order of the model
     * @return the newly built HMM
     */
    public HMM hmm(char[] cArr, char[] cArr2, int i) {
        HMM model = new HMM();
        init(model, cArr, cArr2, i);
        return model;
    }

    /** Parses {@code source} via {@code File.parseXML} and initializes {@code hmm} from the result. */
    private void init(HMM hmm, String source) {
        init(hmm, File.parseXML(source));
    }

    /**
     * Initializes the structure of {@code hmm} from {@code params} and, if the parameters
     * report training values, copies those probabilities onto the nodes and arcs.
     */
    private void init(HMM hmm, HMMParams params) {
        init(hmm, params.states(), params.observations(), params.order());
        if (!params.hasTrainingParams()) {
            return;
        }
        initTrainingParams(
            hmm,
            params.emissionProbabilities(),
            params.transitionProbabilities(),
            params.startProbabilities());
    }

    /**
     * Stores the order on {@code hmm}, expands the state symbols into state labels through
     * {@code Combinatorial.combinatorial} and builds nodes and arcs from them.
     */
    private void init(HMM hmm, char[] states, char[] observations, int order) {
        hmm.order = order;
        List<String> stateLabels = Combinatorial.combinatorial(states, hmm.order);
        init(hmm, stateLabels, observations);
    }

    /** Creates the nodes first, then the transition arcs, then the emission arcs. */
    private void init(HMM hmm, List<String> stateLabels, char[] observations) {
        addStates(hmm, stateLabels);
        addObservations(hmm, observations);
        addTransitions(hmm);
        addEmissions(hmm);
    }

    /**
     * Applies trained parameters to an already wired HMM.
     *
     * @param hmm         model to update
     * @param emissions   (state label, observation label, weight) triples for emission arcs
     * @param transitions (source label, sink label, weight) triples for transition arcs
     * @param starts      (state label, probability) pairs for starting probabilities
     */
    private void initTrainingParams(
            HMM hmm,
            List<Triple<String, String, Double>> emissions,
            List<Triple<String, String, Double>> transitions,
            List<Pair<String, Double>> starts) {
        for (Pair<String, Double> start : starts) {
            String label = start.getFirst();
            double probability = start.getSecond();
            // Every state whose label matches receives the starting probability.
            for (LatentHMMNode<Character, String> state : hmm.STATES) {
                if (state.state().equals(label)) {
                    state.startingProbability(probability);
                }
            }
        }
        // Transitions are processed before emissions.
        for (Triple<String, String, Double> transition : transitions) {
            setUpWeights(transition, hmm.TRANSITIONS);
        }
        for (Triple<String, String, Double> emission : emissions) {
            setUpWeights(emission, hmm.EMISSIONS);
        }
    }

    /**
     * Sets the weight of every arc in {@code arcs} whose source and sink state labels equal
     * the first and second element of {@code entry}. Arcs that do not match are left untouched.
     */
    private void setUpWeights(Triple<String, String, Double> entry, List<WeightedArc> arcs) {
        String sourceLabel = entry.getFirst();
        String sinkLabel = entry.getSecond();
        double weight = entry.getThird();
        for (WeightedArc arc : arcs) {
            HMMNode source = (HMMNode) arc.source();
            HMMNode sink = (HMMNode) arc.sink();
            if (!((String) source.state()).equals(sourceLabel)) {
                continue;
            }
            if (((String) sink.state()).equals(sinkLabel)) {
                arc.weight(weight);
            }
        }
    }

    /**
     * Sorts {@code labels} in place (shorter labels first, equal lengths in
     * {@link String#compareTo} order) and adds one latent node per label to {@code hmm.STATES}.
     * Each node is created from the label's last character, the label itself and its position
     * in the sorted list.
     */
    private void addStates(HMM hmm, List<String> labels) {
        Collections.sort(labels, (left, right) -> {
            int byLength = Integer.compare(left.length(), right.length());
            return byLength != 0 ? byLength : left.compareTo(right);
        });
        int index = 0;
        for (String label : labels) {
            Character lastSymbol = Character.valueOf(label.charAt(label.length() - 1));
            hmm.STATES.add(_NODE_FACTORY.newLatentHMMNode(lastSymbol, label, index));
            index++;
        }
    }

    /**
     * Adds one observation node per symbol to {@code hmm.OBSERVATIONS}; each node is created
     * from the symbol, its one-character string form and its array position.
     */
    private void addObservations(HMM hmm, char[] symbols) {
        int index = 0;
        for (char symbol : symbols) {
            hmm.OBSERVATIONS.add(
                _NODE_FACTORY.newHMMNode(Character.valueOf(symbol), String.valueOf(symbol), index));
            index++;
        }
    }

    /** Offers every ordered pair of states (including a state with itself) to {@code addTransition}. */
    private void addTransitions(HMM hmm) {
        for (LatentHMMNode<Character, String> source : hmm.STATES) {
            for (LatentHMMNode<Character, String> sink : hmm.STATES) {
                addTransition(hmm, source, sink);
            }
        }
    }

    /**
     * Adds a zero-weight transition arc from {@code source} to {@code sink} when their labels
     * are compatible, and does nothing otherwise. With {@code s} and {@code t} the label
     * lengths of source and sink, an arc is created only if all of the following hold:
     * <ul>
     *   <li>{@code s <= t};</li>
     *   <li>{@code t == s} or {@code t == s + 1};</li>
     *   <li>if {@code s == t}, then {@code s >= hmm.order};</li>
     *   <li>the source part (the whole label if {@code s < order}, else the label without its
     *       first character) equals the sink part (the first {@code s} characters if
     *       {@code t < order}, else the label without its last character).</li>
     * </ul>
     */
    private void addTransition(HMM hmm, LatentHMMNode<Character, String> source, HMMNode<Character, String> sink) {
        String sourceLabel = source.state();
        String sinkLabel = sink.state();
        int sourceLength = sourceLabel.length();
        int sinkLength = sinkLabel.length();

        if (sourceLength > sinkLength) {
            return;
        }
        // Labels of equal length only connect once they have reached the model order.
        if (sourceLength == sinkLength && sourceLength < hmm.order) {
            return;
        }
        // The sink label may be at most one character longer than the source label.
        if (sourceLength != sinkLength && sourceLength + 1 != sinkLength) {
            return;
        }

        String sourcePart = sourceLength < hmm.order
            ? sourceLabel
            : sourceLabel.substring(1, sourceLength);
        String sinkPart = sinkLength < hmm.order
            ? sinkLabel.substring(0, sourceLength)
            : sinkLabel.substring(0, sinkLength - 1);
        if (!sourcePart.equals(sinkPart)) {
            return;
        }

        WeightedArc arc = _ARC_FACTORY.weightedArc(source, sink, 0.0d);
        hmm.TRANSITIONS.add(arc);
        source.addTransition(arc);
    }

    /** Connects every state to every observation through {@code addEmission}. */
    private void addEmissions(HMM hmm) {
        for (LatentHMMNode<Character, String> state : hmm.STATES) {
            hmm.OBSERVATIONS.forEach(observation -> addEmission(hmm, state, observation));
        }
    }

    /**
     * Creates a zero-weight emission arc from {@code state} to {@code observation}, records it
     * in {@code hmm.EMISSIONS} and registers it on the state node.
     */
    private void addEmission(HMM hmm, LatentHMMNode state, HMMNode observation) {
        WeightedArc arc = _ARC_FACTORY.weightedArc(state, observation, 0.0d);
        hmm.EMISSIONS.add(arc);
        state.addEmission(arc);
    }
}
