package be.ac.ulg.montefiore.run.jahmm.apps.sample;

import be.ac.ulg.montefiore.run.jahmm.Hmm;
import be.ac.ulg.montefiore.run.jahmm.Observation;
import be.ac.ulg.montefiore.run.jahmm.ObservationDiscrete;
import be.ac.ulg.montefiore.run.jahmm.OpdfDiscrete;
import be.ac.ulg.montefiore.run.jahmm.OpdfDiscreteFactory;
import be.ac.ulg.montefiore.run.jahmm.draw.GenericHmmDrawerDot;
import be.ac.ulg.montefiore.run.jahmm.learn.BaumWelchLearner;
import be.ac.ulg.montefiore.run.jahmm.toolbox.KullbackLeiblerDistanceCalculator;
import be.ac.ulg.montefiore.run.jahmm.toolbox.MarkovGenerator;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

/* loaded from: input_file:be/ac/ulg/montefiore/run/jahmm/apps/sample/SimpleExample.class */
/**
 * Sample application for the jahmm library.
 *
 * <p>It builds a reference two-state HMM over {@link Packet} observations and draws observation
 * sequences from it. Starting from an initial guess, it then re-estimates a model with Baum-Welch,
 * printing the Kullback-Leibler distance to the reference model before each iteration. It prints
 * the learnt model and the probability it assigns to the sequence OK, OK, LOSS. Finally, it writes
 * the learnt model to a Graphviz "dot" file.
 */
public class SimpleExample {

    /** Number of Baum-Welch iterations performed by {@link #main(String[])}. */
    private static final int LEARNING_ITERATIONS = 10;

    /** Number of sequences produced by {@link #generateSequences(Hmm)}. */
    private static final int NB_SEQUENCES = 200;

    /** Length of each sequence produced by {@link #generateSequences(Hmm)}. */
    private static final int SEQUENCE_LENGTH = 100;

    /** Output file for the learnt HMM, written by {@link #main(String[])}. */
    private static final String DOT_FILE = "learntHmm.dot";

    /** Possible packet reception outcomes, used as discrete observation values. */
    public enum Packet {
        OK,
        LOSS;

        /**
         * Wraps this value as a discrete observation.
         *
         * @return a new {@code ObservationDiscrete} holding this constant
         */
        public ObservationDiscrete<Packet> observation() {
            return new ObservationDiscrete<>(this);
        }
    }

    /**
     * Runs the sample: generate data, learn a model, report results and draw the learnt HMM.
     *
     * @param args ignored
     * @throws IOException if the dot file cannot be written
     */
    public static void main(String[] args) throws IOException {
        Hmm<ObservationDiscrete<Packet>> referenceHmm = buildHmm();
        List<List<ObservationDiscrete<Packet>>> sequences = generateSequences(referenceHmm);

        BaumWelchLearner learner = new BaumWelchLearner();
        Hmm<ObservationDiscrete<Packet>> learntHmm = buildInitHmm();
        KullbackLeiblerDistanceCalculator distanceCalculator =
                new KullbackLeiblerDistanceCalculator();

        // The distance is printed before each re-estimation step, so iteration 0
        // reports the distance of the untrained initial guess.
        for (int i = 0; i < LEARNING_ITERATIONS; i++) {
            System.out.println("Distance at iteration " + i + ": "
                    + distanceCalculator.distance(learntHmm, referenceHmm));
            learntHmm = learner.iterate(learntHmm, sequences);
        }
        System.out.println("Resulting HMM:\n" + learntHmm);

        ObservationDiscrete<Packet> packetOk = Packet.OK.observation();
        ObservationDiscrete<Packet> packetLoss = Packet.LOSS.observation();
        List<ObservationDiscrete<Packet>> testSequence = new ArrayList<>(3);
        testSequence.add(packetOk);
        testSequence.add(packetOk);
        testSequence.add(packetLoss);
        System.out.println("Sequence probability: " + learntHmm.probability(testSequence));

        new GenericHmmDrawerDot().write(learntHmm, DOT_FILE);
    }

    /**
     * Builds the reference HMM that the sample learns to approximate.
     *
     * @return a new two-state HMM over {@link Packet} observations
     */
    static Hmm<ObservationDiscrete<Packet>> buildHmm() {
        return buildPacketHmm(
                new double[] {0.95d, 0.05d},
                new double[][] {{0.95d, 0.05d}, {0.2d, 0.8d}},
                new double[][] {{0.95d, 0.05d}, {0.1d, 0.9d}});
    }

    /**
     * Builds the initial guess used as the starting point of Baum-Welch learning.
     *
     * @return a new two-state HMM over {@link Packet} observations
     */
    static Hmm<ObservationDiscrete<Packet>> buildInitHmm() {
        return buildPacketHmm(
                new double[] {0.5d, 0.5d},
                new double[][] {{0.8d, 0.2d}, {0.1d, 0.9d}},
                new double[][] {{0.8d, 0.2d}, {0.2d, 0.8d}});
    }

    /**
     * Generates {@value #NB_SEQUENCES} observation sequences of length {@value #SEQUENCE_LENGTH}.
     *
     * @param hmm the model to sample from
     * @param <O> the observation type of the model
     * @return a new mutable list of generated sequences
     */
    static <O extends Observation> List<List<O>> generateSequences(Hmm<O> hmm) {
        return generateSequences(hmm, NB_SEQUENCES, SEQUENCE_LENGTH);
    }

    /**
     * Generates observation sequences by sampling from {@code hmm} with a {@code MarkovGenerator}.
     *
     * @param hmm the model to sample from
     * @param nbSequences number of sequences to generate; must not be negative
     * @param sequenceLength length passed to {@code MarkovGenerator.observationSequence}
     * @param <O> the observation type of the model
     * @return a new mutable list holding {@code nbSequences} generated sequences
     * @throws IllegalArgumentException if {@code nbSequences} is negative
     */
    static <O extends Observation> List<List<O>> generateSequences(
            Hmm<O> hmm, int nbSequences, int sequenceLength) {
        if (nbSequences < 0) {
            throw new IllegalArgumentException("nbSequences must be >= 0: " + nbSequences);
        }
        MarkovGenerator<O> generator = new MarkovGenerator<>(hmm);
        List<List<O>> sequences = new ArrayList<>(nbSequences);
        for (int i = 0; i < nbSequences; i++) {
            sequences.add(generator.observationSequence(sequenceLength));
        }
        return sequences;
    }

    /**
     * Creates an HMM over {@link Packet} observations and sets all its parameters.
     *
     * @param pi initial state probabilities; its length is the number of states
     * @param emissions per-state emission probabilities, indexed like {@link Packet} constants
     * @param transitions transition probabilities, {@code transitions[i][j]} being state i to j
     * @return the configured HMM
     */
    private static Hmm<ObservationDiscrete<Packet>> buildPacketHmm(
            double[] pi, double[][] emissions, double[][] transitions) {
        int nbStates = pi.length;
        Hmm<ObservationDiscrete<Packet>> hmm =
                new Hmm<>(nbStates, new OpdfDiscreteFactory<>(Packet.class));
        for (int i = 0; i < nbStates; i++) {
            hmm.setPi(i, pi[i]);
            hmm.setOpdf(i, new OpdfDiscrete<>(Packet.class, emissions[i]));
            for (int j = 0; j < nbStates; j++) {
                hmm.setAij(i, j, transitions[i][j]);
            }
        }
        return hmm;
    }
}
