package lu.uni.serval.flakime.core.instrumentation.strategies.vocabulary;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import lu.uni.serval.flakime.core.utils.Logger;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.nn.modelimport.keras.preprocessing.text.KerasTokenizer;
import org.deeplearning4j.nn.modelimport.keras.preprocessing.text.TokenizerMode;
import weka.classifiers.trees.RandomForest;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.SerializationHelper;
import weka.core.SparseInstance;

/* loaded from: input_file:lu/uni/serval/flakime/core/instrumentation/strategies/vocabulary/WekaModel.class */
/**
 * {@link Model} implementation backed by a Weka {@link RandomForest}.
 *
 * <p>Text bodies are turned into per-token count vectors with a {@link KerasTokenizer}
 * ({@link TokenizerMode#COUNT}) and stored as Weka {@link SparseInstance}s whose first
 * attribute (index 0) is the nominal class attribute {@code flakime_label}.
 *
 * <p>Typical usage: construct (or {@link #load(Logger, String) load}) a model, call
 * {@link #setData(TrainingData, Set)} to fit the tokenizer and build the dataset header,
 * call {@link #train()} when the forest was not loaded pre-trained, then query
 * {@link #computeProbability(String)}.
 */
public class WekaModel implements Model {
    private final RandomForest randomForest;
    private final Logger logger;
    /** {@code true} until the forest has been trained by {@link #train()} or was supplied pre-built. */
    private boolean trainNeededFlag;
    /** Tokenizer fitted by {@link #setData(TrainingData, Set)}; {@code null} before that call. */
    private KerasTokenizer tokenizer;
    /** Dataset built by {@link #setData(TrainingData, Set)}; {@code null} before that call. */
    private Instances trainingInstances;

    /**
     * Wraps an already built classifier; the resulting model is considered fitted.
     *
     * @param logger       logger used for progress messages
     * @param randomForest classifier to delegate to
     */
    public WekaModel(Logger logger, RandomForest randomForest) {
        this.logger = logger;
        this.randomForest = randomForest;
        this.trainNeededFlag = false;
    }

    /**
     * Creates an untrained model with a fresh random forest (seed 0, debug output enabled).
     *
     * @param logger logger used for progress messages
     * @param i      number of iterations of the forest (reported as "trees" in the training log)
     * @param i2     number of execution slots (reported as "threads" in the training log)
     */
    public WekaModel(Logger logger, int i, int i2) {
        this.logger = logger;
        this.randomForest = new RandomForest();
        this.randomForest.setNumIterations(i);
        this.randomForest.setSeed(0);
        this.randomForest.setDebug(true);
        this.randomForest.setNumExecutionSlots(i2);
        this.trainNeededFlag = true;
    }

    /**
     * Fits the tokenizer and builds the Weka dataset from the given training data.
     *
     * @param trainingData labelled entries; each entry's body becomes one instance
     * @param set          additional texts that are only used to fit the tokenizer vocabulary
     */
    @Override
    public void setData(TrainingData trainingData, Set<String> set) {
        Integer[] labels = trainingData.getEntries().stream()
                .map(entry -> entry.getLabel())
                .toArray(Integer[]::new);
        String[] bodies = trainingData.getEntries().stream()
                .map(entry -> entry.getBody())
                .toArray(String[]::new);
        this.tokenizer = createTokenizer(bodies, set.toArray(new String[0]));
        this.trainingInstances = createInstances(this.tokenizer, labels.length, bodies, labels);
    }

    /**
     * Builds the random forest on the dataset prepared by {@link #setData(TrainingData, Set)}.
     *
     * @throws IllegalStateException if {@code setData} has not been called
     * @throws Exception             if Weka fails to build the classifier
     */
    @Override
    public void train() throws Exception {
        if (this.trainingInstances == null) {
            throw new IllegalStateException("No training data available, setData must be called before train");
        }
        this.logger.info(String.format("Training Random Forest Classifier on %d threads with %d trees...",
                this.randomForest.getNumExecutionSlots(), this.randomForest.getNumIterations()));
        long start = System.nanoTime();
        this.randomForest.buildClassifier(this.trainingInstances);
        this.trainNeededFlag = false;
        double elapsedSeconds = (System.nanoTime() - start) / 1.0E9d;
        this.logger.info(String.format("Random Forest Classifier trained in %.1f seconds", elapsedSeconds));
    }

    /**
     * Returns the classifier's distribution entry at index 1 for the given text, i.e. the
     * entry for the second value of the class attribute (labels are sorted ascending).
     *
     * @param str text to classify
     * @return distribution value at class index 1
     * @throws IllegalStateException if the model is not fitted or {@code setData} has not been called
     * @throws Exception             if Weka fails to classify the instance
     */
    @Override
    public double computeProbability(String str) throws Exception {
        if (this.trainNeededFlag) {
            throw new IllegalStateException("The model is not fitted");
        }
        // A loaded model is flagged as fitted but still needs setData for the tokenizer and header.
        if (this.tokenizer == null || this.trainingInstances == null) {
            throw new IllegalStateException("No data available, setData must be called before computeProbability");
        }
        // The class value 0.0 is only a placeholder; the instance is being predicted.
        Instance instance = createSingleInstance(this.trainingInstances, str, 0.0d, this.tokenizer);
        // NOTE(review): index 1 assumes the training labels contained at least two distinct values.
        return this.randomForest.distributionForInstance(instance)[1];
    }

    /**
     * Serializes the random forest to the given file. The tokenizer and dataset are not saved.
     *
     * @param str destination file name
     * @throws Exception if serialization fails
     */
    @Override
    public void save(String str) throws Exception {
        SerializationHelper.write(str, this.randomForest);
    }

    /**
     * Reads a serialized random forest and wraps it in a model flagged as fitted.
     *
     * <p>SECURITY: this relies on Java object deserialization via {@link SerializationHelper};
     * only load files from a trusted source.
     *
     * @param logger logger used for progress and error messages
     * @param str    file name of the serialized forest
     * @return model wrapping the loaded forest
     * @throws IOException if deserialization overflows the stack
     * @throws Exception   if the file cannot be read
     */
    public static Model load(Logger logger, String str) throws Exception {
        try {
            return new WekaModel(logger, (RandomForest) SerializationHelper.read(str));
        } catch (StackOverflowError e) {
            logger.error("Stackoverflow due to insufficient stack size, please increment with -Xss10m");
            throw new IOException(e);
        }
    }

    /** Builds a dataset holding one instance per (body, label) pair. */
    private Instances createInstances(KerasTokenizer fittedTokenizer, int count, String[] bodies, Integer[] labels) {
        Instances dataset = createEmptyInstances(fittedTokenizer, count, labels);
        for (int index = 0; index < count; index++) {
            dataset.add(createSingleInstance(dataset, bodies[index], labels[index].intValue(), fittedTokenizer));
        }
        return dataset;
    }

    /** Builds one sparse instance: the class value at position 0 followed by the token counts of the text. */
    private Instance createSingleInstance(Instances dataset, String text, double classValue,
            KerasTokenizer fittedTokenizer) {
        double[] tokenCounts = fittedTokenizer.textsToMatrix(new String[] {text}, TokenizerMode.COUNT)
                .toDoubleVector();
        double[] values = ArrayUtils.insert(0, tokenCounts, classValue);
        Instance instance = new SparseInstance(1.0d, values);
        instance.setDataset(dataset);
        return instance;
    }

    /**
     * Builds an empty dataset whose attribute 0 is the nominal class attribute
     * {@code flakime_label}, followed by one attribute per word known to the tokenizer.
     */
    private Instances createEmptyInstances(KerasTokenizer fittedTokenizer, int capacity, Integer[] labels) {
        Map<Integer, String> indexWord = fittedTokenizer.getIndexWord();
        List<String> labelValues = Arrays.stream(labels)
                .distinct()
                .sorted()
                .map(String::valueOf)
                .collect(Collectors.toList());
        Attribute classAttribute = new Attribute("flakime_label", labelValues);

        // Built explicitly as an ArrayList: Weka requires that type and
        // Collectors.toList() does not guarantee it.
        ArrayList<Attribute> attributes = new ArrayList<>(indexWord.size() + 1);
        attributes.add(classAttribute);
        for (String word : indexWord.values()) {
            attributes.add(new Attribute(word));
        }

        Instances dataset = new Instances("trainData", attributes, capacity);
        dataset.setClassIndex(0);
        return dataset;
    }

    /** Fits a new tokenizer on the training bodies followed by the additional texts. */
    private KerasTokenizer createTokenizer(String[] bodies, String[] additionalTexts) {
        String[] corpus = ArrayUtils.addAll(bodies, additionalTexts);
        KerasTokenizer fittedTokenizer = new KerasTokenizer();
        fittedTokenizer.fitOnTexts(corpus);
        return fittedTokenizer;
    }
}
