package ai.libs.jaicore.ml.weka.classification.singlelabel.timeseries.learner.shapelets;

import ai.libs.jaicore.ml.classification.singlelabel.timeseries.dataset.TimeSeriesDataset2;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.learner.ASimplifiedTSClassifier;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.util.MathUtil;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.util.TimeSeriesUtil;
import ai.libs.jaicore.ml.weka.classification.singlelabel.timeseries.learner.shapelets.LearnShapeletsLearningAlgorithm;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.aeonbits.owner.ConfigCache;
import org.aeonbits.owner.ConfigFactory;
import org.api4.java.ai.ml.core.exception.PredictionException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/libs/jaicore/ml/weka/classification/singlelabel/timeseries/learner/shapelets/LearnShapeletsClassifier.class */
/**
 * Learn Shapelets time series classifier.
 *
 * <p>Hyperparameters are stored in an {@code ILearnShapeletsLearningAlgorithmConfig} owned by this instance. The
 * learned parameters ({@code s}, {@code w}, {@code w0}, {@code c}) are injected through the public setters. The
 * algorithm returned by {@link #m14getLearningAlgorithm(TimeSeriesDataset2)} receives this classifier, presumably so
 * it can call those setters after training; confirm in {@link LearnShapeletsLearningAlgorithm}.
 */
public class LearnShapeletsClassifier extends ASimplifiedTSClassifier<Integer> {
    private static final Logger LOGGER = LoggerFactory.getLogger(LearnShapeletsClassifier.class);

    /**
     * Last argument passed to {@code LearnShapeletsLearningAlgorithm.calculateMHat} during prediction.
     * NOTE(review): presumably the soft-minimum precision parameter (alpha) of Learn Shapelets; confirm it matches the
     * value used by the learning algorithm during training.
     */
    private static final double SOFT_MIN_ALPHA = -30.0d;

    /** Learned shapelets; first index is the scale, second the shapelet index within that scale. */
    private double[][][] s;
    /** Learned weights, indexed as {@code w[class][scale][shapelet]}. */
    private double[][][] w;
    /** Learned per-class bias terms, indexed by class. */
    private double[] w0;
    /** Number of classes the model distinguishes. */
    private int c;
    /** Hyperparameter configuration; a private (non-cached) instance per classifier. */
    private final LearnShapeletsLearningAlgorithm.ILearnShapeletsLearningAlgorithmConfig config;

    /**
     * Creates a classifier with {@code gamma = 0.5}.
     *
     * @param numShapelets value for the {@code numshapelets} property
     * @param learningRate value for the {@code learningrate} property
     * @param regularization value for the {@code regularization} property
     * @param scaleR value for the {@code scaler} property
     * @param minShapeLengthPercentage value for the {@code relativeminshapeletlength} property
     * @param maxIter value for the {@code maxiter} property
     * @param seed value for the {@code seed} property
     */
    public LearnShapeletsClassifier(int numShapelets, double learningRate, double regularization, int scaleR, double minShapeLengthPercentage, int maxIter, int seed) {
        this(numShapelets, learningRate, regularization, scaleR, minShapeLengthPercentage, maxIter, 0.5d, seed);
    }

    /**
     * Creates a classifier with the given hyperparameters.
     *
     * @param numShapelets value for the {@code numshapelets} property
     * @param learningRate value for the {@code learningrate} property
     * @param regularization value for the {@code regularization} property
     * @param scaleR value for the {@code scaler} property
     * @param minShapeLengthPercentage value for the {@code relativeminshapeletlength} property
     * @param maxIter value for the {@code maxiter} property
     * @param gamma value for the {@code gamma} property
     * @param seed value for the {@code seed} property
     */
    public LearnShapeletsClassifier(int numShapelets, double learningRate, double regularization, int scaleR, double minShapeLengthPercentage, int maxIter, double gamma, int seed) {
        // ConfigFactory.create (not ConfigCache.getOrCreate): the cache hands out one shared instance per interface,
        // so the setProperty calls below would leak into every other classifier built from the same config class.
        this.config = ConfigFactory.create(LearnShapeletsLearningAlgorithm.ILearnShapeletsLearningAlgorithmConfig.class);
        this.config.setProperty("numshapelets", "" + numShapelets);
        this.config.setProperty("regularization", "" + regularization);
        this.config.setProperty("scaler", "" + scaleR);
        this.config.setProperty("relativeminshapeletlength", "" + minShapeLengthPercentage);
        this.config.setProperty("seed", "" + seed);
        this.config.setProperty("maxiter", "" + maxIter);
        this.config.setProperty("learningrate", "" + learningRate);
        this.config.setProperty("gamma", "" + gamma);
    }

    /**
     * Sets the {@code estimatek} configuration property.
     *
     * @param z the new value
     */
    public void setEstimateK(boolean z) {
        this.config.setProperty("estimatek", "" + z);
    }

    /** @return the learned shapelets (internal array, not a copy) */
    public double[][][] getS() {
        return this.s;
    }

    /** @param dArr the learned shapelets to store (stored by reference) */
    public void setS(double[][][] dArr) {
        this.s = dArr;
    }

    /** @return the learned weights (internal array, not a copy) */
    public double[][][] getW() {
        return this.w;
    }

    /** @param dArr the learned weights to store (stored by reference) */
    public void setW(double[][][] dArr) {
        this.w = dArr;
    }

    /** @return the learned bias terms (internal array, not a copy) */
    public double[] getW0() {
        return this.w0;
    }

    /** @param dArr the learned bias terms to store (stored by reference) */
    public void setW0(double[] dArr) {
        this.w0 = dArr;
    }

    /** @param i the number of classes */
    public void setC(int i) {
        this.c = i;
    }

    /**
     * Sets the {@code minshapeletlength} configuration property.
     *
     * @param i the new value
     */
    public void setMinShapeLength(int i) {
        this.config.setProperty("minshapeletlength", "" + i);
    }

    /**
     * Predicts the class index of a single univariate time series.
     *
     * <p>The series is z-normalized. Then for every class {@code i} the score
     * {@code sigmoid(w0[i] + sum over scale r and shapelet k of calculateMHat(...) * w[i][r][k])} is computed, and the
     * class with the largest score is returned. On ties the smallest class index wins.
     *
     * @param dArr the time series to classify
     * @return the predicted class index in {@code [0, c)}
     * @throws PredictionException if the model has not been trained or has no classes
     */
    public Integer m16predict(double[] dArr) throws PredictionException {
        if (!isTrained()) {
            throw new PredictionException("Model has not been built before!");
        }
        if (this.c <= 0) {
            throw new PredictionException("Model has no classes to predict (c = " + this.c + ")!");
        }
        final double[] instance = TimeSeriesUtil.zNormalize(dArr, false);
        int bestClass = -1;
        double bestScore = 0.0d;
        for (int i = 0; i < this.c; i++) {
            double logit = this.w0[i];
            for (int r = 0; r < this.config.scaleR(); r++) {
                for (int k = 0; k < this.s[r].length; k++) {
                    logit += LearnShapeletsLearningAlgorithm.calculateMHat(this.s, this.config.minShapeletLength(), r, instance, k, instance.length, SOFT_MIN_ALPHA) * this.w[i][r][k];
                }
            }
            final double score = MathUtil.sigmoid(logit);
            // Double.compare (rather than '>') keeps the ordering of the Comparator-based max this replaces:
            // NaN ranks highest, and a strict comparison keeps the first (lowest-index) class on ties.
            if (bestClass < 0 || Double.compare(score, bestScore) > 0) {
                bestClass = i;
                bestScore = score;
            }
        }
        return bestClass;
    }

    /**
     * Predicts a (possibly multivariate) instance using only its first time series.
     *
     * @param list the time series of the instance; only {@code list.get(0)} is used
     * @return the predicted class index
     * @throws PredictionException if the model has not been trained
     * @throws IllegalArgumentException if {@code list} is null or empty
     */
    public Integer predict(List<double[]> list) throws PredictionException {
        if (list == null || list.isEmpty()) {
            throw new IllegalArgumentException("At least one time series must be given for prediction!");
        }
        // Only warn when there is actually more than one series being ignored.
        if (list.size() > 1) {
            LOGGER.warn("Dataset to be predicted is multivariate but only first time series (univariate) will be considered.");
        }
        return m16predict(list.get(0));
    }

    /**
     * Predicts every instance of a dataset using only its first time series (attribute 0).
     *
     * @param timeSeriesDataset2 the dataset to predict
     * @return the predicted class indices, in instance order
     * @throws PredictionException if the model has not been trained
     * @throws IllegalArgumentException if the dataset's first value matrix is null
     */
    public List<Integer> predict(TimeSeriesDataset2 timeSeriesDataset2) throws PredictionException {
        if (!isTrained()) {
            throw new PredictionException("Model has not been built before!");
        }
        if (timeSeriesDataset2.isMultivariate()) {
            LOGGER.warn("Dataset to be predicted is multivariate but only first time series (univariate) will be considered.");
        }
        double[][] valuesOrNull = timeSeriesDataset2.getValuesOrNull(0);
        if (valuesOrNull == null) {
            throw new IllegalArgumentException("Dataset matrix of the instances to be predicted must not be null!");
        }
        List<Integer> predictions = new ArrayList<>(valuesOrNull.length);
        LOGGER.debug("Starting prediction...");
        for (double[] series : valuesOrNull) {
            predictions.add(m16predict(series));
        }
        LOGGER.debug("Finished prediction.");
        return predictions;
    }

    /**
     * Creates the learning algorithm that trains this classifier on the given dataset.
     *
     * @param timeSeriesDataset2 the training data
     * @return a new algorithm bound to this classifier's config, the classifier itself, and the dataset
     */
    public LearnShapeletsLearningAlgorithm m14getLearningAlgorithm(TimeSeriesDataset2 timeSeriesDataset2) {
        return new LearnShapeletsLearningAlgorithm(this.config, this, timeSeriesDataset2);
    }

    /**
     * Raw-typed bridge (decompiler artifact) delegating to {@link #predict(List)}.
     *
     * @param list the time series of the instance; elements are expected to be {@code double[]}
     * @return the predicted class index
     * @throws PredictionException if the model has not been trained
     */
    @SuppressWarnings("unchecked")
    public Object m15predict(List list) throws PredictionException {
        return predict((List<double[]>) list);
    }
}
