package ai.libs.jaicore.ml.weka.classification.singlelabel.timeseries.learner.trees;

import ai.libs.jaicore.basic.sets.Pair;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.dataset.TimeSeriesDataset2;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.dataset.TimeSeriesFeature;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.learner.ASimplifiedTSClassifier;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.util.TimeSeriesUtil;
import ai.libs.jaicore.ml.weka.classification.singlelabel.timeseries.learner.trees.TimeSeriesBagOfFeaturesLearningAlgorithm;
import ai.libs.jaicore.ml.weka.classification.singlelabel.timeseries.util.WekaTimeseriesUtil;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.aeonbits.owner.ConfigCache;
import org.aeonbits.owner.ConfigFactory;
import org.api4.java.ai.ml.core.exception.PredictionException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.classifiers.trees.RandomForest;
import weka.core.Instances;

/* loaded from: input_file:ai/libs/jaicore/ml/weka/classification/singlelabel/timeseries/learner/trees/TimeSeriesBagOfFeaturesClassifier.class */
public class TimeSeriesBagOfFeaturesClassifier extends ASimplifiedTSClassifier<Integer> {
    private static final Logger LOGGER = LoggerFactory.getLogger(TimeSeriesBagOfFeaturesClassifier.class);

    /** Number of values written per interval (and per subsequence) into a subseries feature row. */
    private static final int FEATURES_PER_INTERVAL = 3;

    /** First-stage forest: classifies the feature row of each subsequence. */
    private RandomForest subseriesClf;
    /** Second-stage forest: classifies the histogram instance built from the first-stage probabilities. */
    private RandomForest finalClf;
    private int numClasses;
    /** Per subsequence, the list of {start, end} interval bounds (end exclusive, see the "- 1" below). */
    private int[][][] intervals;
    /** Per subsequence, its {start, end} bounds (end exclusive). */
    private int[][] subsequences;
    private final TimeSeriesBagOfFeaturesLearningAlgorithm.ITimeSeriesBagOfFeaturesConfig config;

    /**
     * Creates a classifier with default hyperparameters: 10 bins, 10 folds, z-proportion 0.1,
     * minimum interval length 5 and no z-normalization.
     *
     * @param seed value stored under the {@code "seed"} configuration key
     */
    public TimeSeriesBagOfFeaturesClassifier(int seed) {
        this(seed, 10, 10, 0.1d, 5, false);
    }

    /**
     * Creates a classifier without z-normalization.
     *
     * @param seed value stored under the {@code "seed"} configuration key
     * @param numBins number of bins used to discretize first-stage probabilities
     * @param numFolds value stored under the {@code "numfolds"} configuration key
     * @param zProp value stored under the z-proportion configuration key
     * @param minIntervalLength value stored under the minimum-interval-length configuration key
     */
    public TimeSeriesBagOfFeaturesClassifier(int seed, int numBins, int numFolds, double zProp, int minIntervalLength) {
        this(seed, numBins, numFolds, zProp, minIntervalLength, false);
    }

    /**
     * Creates a classifier with an explicit configuration.
     *
     * @param seed value stored under the {@code "seed"} configuration key
     * @param numBins number of bins used to discretize first-stage probabilities
     * @param numFolds value stored under the {@code "numfolds"} configuration key
     * @param zProp value stored under the z-proportion configuration key
     * @param minIntervalLength value stored under the minimum-interval-length configuration key
     * @param useZNormalization whether series are z-normalized before feature extraction in prediction
     */
    public TimeSeriesBagOfFeaturesClassifier(int seed, int numBins, int numFolds, double zProp, int minIntervalLength, boolean useZNormalization) {
        // Create a fresh config per classifier: ConfigCache returns one shared instance per interface,
        // so the setProperty calls below would silently reconfigure every other classifier as well.
        this.config = ConfigFactory.create(TimeSeriesBagOfFeaturesLearningAlgorithm.ITimeSeriesBagOfFeaturesConfig.class);
        this.config.setProperty("seed", String.valueOf(seed));
        // Set directly rather than through the overridable setNumBins(int) to avoid calling it from a constructor.
        this.config.setProperty(TimeSeriesBagOfFeaturesLearningAlgorithm.ITimeSeriesBagOfFeaturesConfig.K_NUMBINS, String.valueOf(numBins));
        this.config.setProperty("numfolds", String.valueOf(numFolds));
        this.config.setProperty(TimeSeriesBagOfFeaturesLearningAlgorithm.ITimeSeriesBagOfFeaturesConfig.K_ZPROP, String.valueOf(zProp));
        this.config.setProperty(TimeSeriesBagOfFeaturesLearningAlgorithm.ITimeSeriesBagOfFeaturesConfig.K_MIN_INTERVAL_LENGTH, String.valueOf(minIntervalLength));
        this.config.setProperty(TimeSeriesBagOfFeaturesLearningAlgorithm.ITimeSeriesBagOfFeaturesConfig.K_USE_ZNORMALIZATION, String.valueOf(useZNormalization));
    }

    /**
     * Predicts the class of a single univariate time series.
     *
     * <p>First, a feature row is built for every stored subsequence and classified by the subseries
     * forest. The resulting class probabilities are then discretized into histograms, turned into a
     * single instance and classified by the final forest.
     *
     * @param dArr the time series to classify
     * @return the predicted class index
     * @throws PredictionException if the model is untrained or a Weka classifier fails
     */
    public Integer m32predict(double[] dArr) throws PredictionException {
        if (!isTrained()) {
            throw new PredictionException("Model has not been built before!");
        }
        final double[] series = this.config.zNormalization() ? TimeSeriesUtil.zNormalize(dArr, true) : dArr;
        final List<String> classValues = this.classValueNames();

        // Stage 1: class probabilities for every subsequence.
        final Instances subseriesInstances = WekaTimeseriesUtil.simplifiedTimeSeriesDatasetToWekaInstances(
                TimeSeriesUtil.createDatasetForMatrix(new double[][][] { this.extractSubseriesFeatures(series) }), classValues);
        final double[][] probabilities;
        try {
            probabilities = this.subseriesClf.distributionsForInstances(subseriesInstances);
        } catch (Exception e) {
            throw new PredictionException("Cannot derive the probabilities using the subseries classifier due to an internal Weka exception.", e);
        }

        // Stage 2: histogram instance built from the discretized probabilities.
        final int numBins = this.getNumBins();
        final Pair<int[][][], int[][]> histogramsAndFreqs = TimeSeriesBagOfFeaturesLearningAlgorithm.formHistogramsAndRelativeFreqs(
                TimeSeriesBagOfFeaturesLearningAlgorithm.discretizeProbs(numBins, probabilities), 1, this.numClasses, numBins);
        final Instances finalInstances = WekaTimeseriesUtil.simplifiedTimeSeriesDatasetToWekaInstances(
                TimeSeriesUtil.createDatasetForMatrix(new double[][][] {
                        TimeSeriesBagOfFeaturesLearningAlgorithm.generateHistogramInstances(histogramsAndFreqs.getX(), histogramsAndFreqs.getY()) }),
                classValues);
        if (finalInstances.size() != 1) {
            throw new PredictionException("There should be only one instance given to the final Random Forest classifier.",
                    new IllegalStateException("There should be only one instance given to the final Random Forest classifier."));
        }
        try {
            return Integer.valueOf((int) this.finalClf.classifyInstance(finalInstances.firstInstance()));
        } catch (Exception e) {
            throw new PredictionException("Could not predict instance due to an internal Weka exception.", e);
        }
    }

    /** Returns the class labels "0" .. "numClasses - 1" used to build Weka instances. */
    private List<String> classValueNames() {
        return IntStream.range(0, this.numClasses).mapToObj(String::valueOf).collect(Collectors.toList());
    }

    /**
     * Builds one feature row per stored subsequence. Each row contains three values per interval,
     * then three values for the whole subsequence, then the subsequence's start and end index.
     */
    private double[][] extractSubseriesFeatures(final double[] series) {
        // Row width is derived from the first subsequence's interval count, as during the original training layout.
        final int rowLength = (this.intervals[0].length + 1) * FEATURES_PER_INTERVAL + 2;
        final double[][] rows = new double[this.intervals.length][rowLength];
        for (int s = 0; s < this.intervals.length; s++) {
            final double[] row = rows[s];
            final int[][] subseriesIntervals = this.intervals[s];
            for (int j = 0; j < subseriesIntervals.length; j++) {
                writeFeatures(row, j * FEATURES_PER_INTERVAL,
                        TimeSeriesFeature.getFeatures(series, subseriesIntervals[j][0], subseriesIntervals[j][1] - 1, false));
            }
            final int[] subsequence = this.subsequences[s];
            writeFeatures(row, subseriesIntervals.length * FEATURES_PER_INTERVAL,
                    TimeSeriesFeature.getFeatures(series, subsequence[0], subsequence[1] - 1, false));
            row[rowLength - 2] = subsequence[0];
            row[rowLength - 1] = subsequence[1];
        }
        return rows;
    }

    /** Copies the first three feature values into {@code row} at {@code offset}; the second value is stored squared. */
    private static void writeFeatures(final double[] row, final int offset, final double[] features) {
        row[offset] = features[0];
        row[offset + 1] = features[1] * features[1];
        row[offset + 2] = features[2];
    }

    /**
     * Predicts a multivariate instance using only its first (univariate) series.
     *
     * @param list the series of the instance; only {@code list.get(0)} is used
     * @return the predicted class index
     * @throws PredictionException if the list is null/empty, the model is untrained, or Weka fails
     */
    public Integer predict(List<double[]> list) throws PredictionException {
        if (list == null || list.isEmpty()) {
            throw new PredictionException("Cannot predict an instance without any time series.");
        }
        if (list.size() > 1) {
            LOGGER.warn("Dataset to be predicted is multivariate but only first time series (univariate) will be considered.");
        }
        return m32predict(list.get(0));
    }

    /**
     * Predicts every instance of the dataset, using only its first variable.
     *
     * @param timeSeriesDataset2 the dataset to predict
     * @return one predicted class index per instance, in dataset order
     * @throws PredictionException if the model is untrained or any single prediction fails
     */
    public List<Integer> predict(TimeSeriesDataset2 timeSeriesDataset2) throws PredictionException {
        if (!isTrained()) {
            throw new PredictionException("Model has not been built before!");
        }
        final double[][] firstVariable = timeSeriesDataset2.getValues(0);
        final List<Integer> predictions = new ArrayList<>(firstVariable.length);
        for (final double[] series : firstVariable) {
            predictions.add(m32predict(series));
        }
        return predictions;
    }

    public RandomForest getSubseriesClf() {
        return this.subseriesClf;
    }

    public void setSubseriesClf(RandomForest randomForest) {
        this.subseriesClf = randomForest;
    }

    public RandomForest getFinalClf() {
        return this.finalClf;
    }

    public void setFinalClf(RandomForest randomForest) {
        this.finalClf = randomForest;
    }

    /** Returns the number of bins read from this classifier's configuration. */
    public int getNumBins() {
        return this.config.numBins();
    }

    /** Stores the number of bins in this classifier's configuration. */
    public void setNumBins(int i) {
        this.config.setProperty(TimeSeriesBagOfFeaturesLearningAlgorithm.ITimeSeriesBagOfFeaturesConfig.K_NUMBINS, String.valueOf(i));
    }

    public int getNumClasses() {
        return this.numClasses;
    }

    public void setNumClasses(int i) {
        this.numClasses = i;
    }

    public int[][][] getIntervals() {
        return this.intervals;
    }

    public void setIntervals(int[][][] iArr) {
        this.intervals = iArr;
    }

    public int[][] getSubsequences() {
        return this.subsequences;
    }

    public void setSubsequences(int[][] iArr) {
        this.subsequences = iArr;
    }

    /**
     * Creates a learning algorithm that trains this classifier on the given dataset with this
     * classifier's configuration.
     */
    public TimeSeriesBagOfFeaturesLearningAlgorithm m30getLearningAlgorithm(TimeSeriesDataset2 timeSeriesDataset2) {
        return new TimeSeriesBagOfFeaturesLearningAlgorithm(this.config, this, timeSeriesDataset2);
    }

    public TimeSeriesBagOfFeaturesLearningAlgorithm.ITimeSeriesBagOfFeaturesConfig getConfig() {
        return this.config;
    }

    /** Raw-typed bridge (decompiler artifact) forwarding to {@link #predict(List)}. */
    @SuppressWarnings({ "unchecked", "rawtypes" })
    public /* bridge */ /* synthetic */ Object m31predict(List list) throws PredictionException {
        return predict((List<double[]>) list);
    }
}
