package ai.libs.jaicore.ml.weka.classification.timeseries.learner.shapelets;

import ai.libs.jaicore.basic.IOwnerBasedRandomizedAlgorithmConfig;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.dataset.TimeSeriesDataset2;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.learner.ASimplifiedTSCLearningAlgorithm;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.quality.IQualityMeasure;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.shapelets.Shapelet;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.shapelets.search.AMinimumDistanceSearchStrategy;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.shapelets.search.EarlyAbandonMinimumDistanceSearchStrategy;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.util.TimeSeriesUtil;
import ai.libs.jaicore.ml.weka.classification.singlelabel.timeseries.util.WekaTimeseriesUtil;
import ai.libs.jaicore.ml.weka.classification.timeseries.learner.ensemble.EnsembleProvider;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.stream.Collectors;
import org.aeonbits.owner.Config;
import org.api4.java.ai.ml.core.exception.TrainingException;
import org.api4.java.algorithm.Timeout;
import org.api4.java.algorithm.events.IAlgorithmEvent;
import org.api4.java.algorithm.exceptions.AlgorithmException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.classifiers.Classifier;

/* loaded from: input_file:ai/libs/jaicore/ml/weka/classification/timeseries/learner/shapelets/ShapeletTransformLearningAlgorithm.class */
public class ShapeletTransformLearningAlgorithm extends ASimplifiedTSCLearningAlgorithm<Integer, ShapeletTransformTSClassifier> {
    private static final Logger logger = LoggerFactory.getLogger(ShapeletTransformLearningAlgorithm.class);
    private final IQualityMeasure qualityMeasure;
    private static final int MIN_MAX_ESTIMATION_SAMPLES = 10;
    private static final boolean USE_BIAS_CORRECTION = true;
    private AMinimumDistanceSearchStrategy minDistanceSearchStrategy;
    private static final String INTERRUPTION_MESSAGE = "Interrupted training due to timeout.";

    /* loaded from: input_file:ai/libs/jaicore/ml/weka/classification/timeseries/learner/shapelets/ShapeletTransformLearningAlgorithm$IShapeletTransformLearningAlgorithmConfig.class */
    public interface IShapeletTransformLearningAlgorithmConfig extends IOwnerBasedRandomizedAlgorithmConfig {
        public static final String K_NUMSHAPELETS = "numshapelets";
        public static final String K_NUMCLUSTERS = "numclusters";
        public static final String K_CLUSTERSHAPELETS = "clustershapelets";
        public static final String K_SHAPELETLENGTH_MIN = "minshapeletlength";
        public static final String K_SHAPELETLENGTH_MAX = "maxshapeletlength";
        public static final String K_USEHIVECOTEENSEMBLE = "usehivecoteensemble";
        public static final String K_ESTIMATESHAPELETLENGTHBORDERS = "estimateshapeletlengthborders";
        public static final String K_NUMFOLDS = "numfolds";

        @Config.DefaultValue("10")
        @Config.Key("numshapelets")
        int numShapelets();

        @Config.DefaultValue("10")
        @Config.Key(K_NUMCLUSTERS)
        int numClusters();

        @Config.DefaultValue("false")
        @Config.Key(K_CLUSTERSHAPELETS)
        boolean clusterShapelets();

        @Config.DefaultValue("3")
        @Config.Key("minshapeletlength")
        int minShapeletLength();

        @Config.Key(K_SHAPELETLENGTH_MAX)
        int maxShapeletLength();

        @Config.Key(K_USEHIVECOTEENSEMBLE)
        boolean useHIVECOTEEnsemble();

        @Config.Key(K_ESTIMATESHAPELETLENGTHBORDERS)
        boolean estimateShapeletLengthBorders();

        @Config.DefaultValue("5")
        @Config.Key("numfolds")
        int numFolds();
    }

    public ShapeletTransformLearningAlgorithm(IShapeletTransformLearningAlgorithmConfig iShapeletTransformLearningAlgorithmConfig, ShapeletTransformTSClassifier shapeletTransformTSClassifier, TimeSeriesDataset2 timeSeriesDataset2, IQualityMeasure iQualityMeasure) {
        super(iShapeletTransformLearningAlgorithmConfig, shapeletTransformTSClassifier, timeSeriesDataset2);
        this.minDistanceSearchStrategy = new EarlyAbandonMinimumDistanceSearchStrategy(true);
        this.qualityMeasure = iQualityMeasure;
    }

    /* renamed from: call, reason: merged with bridge method [inline-methods] */
    public ShapeletTransformTSClassifier m52call() throws AlgorithmException, InterruptedException {
        if (getNumCPUs() > 1) {
            logger.warn("Multithreading is not supported for LearnShapelets yet. Therefore, the number of CPUs is not considered.");
        }
        long currentTimeMillis = System.currentTimeMillis();
        TimeSeriesDataset2 timeSeriesDataset2 = (TimeSeriesDataset2) getInput();
        if (timeSeriesDataset2 == null || timeSeriesDataset2.isEmpty()) {
            throw new IllegalStateException("The time series input data must not be null or empty!");
        }
        if (timeSeriesDataset2.isMultivariate()) {
            throw new UnsupportedOperationException("Multivariate datasets are not supported.");
        }
        double[][] valuesOrNull = timeSeriesDataset2.getValuesOrNull(0);
        if (valuesOrNull == null) {
            throw new IllegalArgumentException("Value matrix must be a valid 2D matrix containing the time series values for all instances!");
        }
        int[] targets = timeSeriesDataset2.getTargets();
        int minShapeletLength = m53getConfig().minShapeletLength();
        int maxShapeletLength = m53getConfig().maxShapeletLength();
        long seed = m53getConfig().seed();
        ShapeletTransformTSClassifier shapeletTransformTSClassifier = (ShapeletTransformTSClassifier) getClassifier();
        int length = valuesOrNull[0].length;
        if (m53getConfig().estimateShapeletLengthBorders()) {
            logger.debug("Starting min max estimation.");
            int[] estimateMinMax = estimateMinMax(valuesOrNull, targets, currentTimeMillis);
            minShapeletLength = estimateMinMax[0];
            maxShapeletLength = estimateMinMax[1];
            logger.debug("Finished min max estimation. min={}, max={}", Integer.valueOf(minShapeletLength), Integer.valueOf(maxShapeletLength));
        } else if (maxShapeletLength == -1) {
            maxShapeletLength = length - 1;
        }
        if (maxShapeletLength >= length) {
            logger.debug("The maximum shapelet length was larger than the total time series length. Therefore, it will be set to time series length - 1.");
            maxShapeletLength = length - 1;
        }
        logger.debug("Starting cached shapelet selection with min={}, max={} and k={}...", new Object[]{Integer.valueOf(minShapeletLength), Integer.valueOf(maxShapeletLength), Integer.valueOf(m53getConfig().numShapelets())});
        List<Shapelet> shapeletCachedSelection = shapeletCachedSelection(valuesOrNull, minShapeletLength, maxShapeletLength, m53getConfig().numShapelets(), targets, currentTimeMillis);
        logger.debug("Finished cached shapelet selection. Extracted {} shapelets.", Integer.valueOf(shapeletCachedSelection.size()));
        if (m53getConfig().clusterShapelets()) {
            logger.debug("Starting shapelet clustering...");
            shapeletCachedSelection = clusterShapelets(shapeletCachedSelection, m53getConfig().numClusters(), currentTimeMillis);
            logger.debug("Finished shapelet clustering. Staying with {} shapelets.", Integer.valueOf(shapeletCachedSelection.size()));
        }
        shapeletTransformTSClassifier.setShapelets(shapeletCachedSelection);
        logger.debug("Transforming the training data using the extracted shapelets.");
        TimeSeriesDataset2 shapeletTransform = shapeletTransform(timeSeriesDataset2, shapeletTransformTSClassifier.getShapelets(), getTimeout(), currentTimeMillis, this.minDistanceSearchStrategy);
        logger.debug("Finished transforming the training data.");
        logger.debug("Initializing ensemble classifier...");
        try {
            Classifier provideHIVECOTEEnsembleModel = m53getConfig().useHIVECOTEEnsemble() ? EnsembleProvider.provideHIVECOTEEnsembleModel(seed) : EnsembleProvider.provideCAWPEEnsembleModel((int) seed, m53getConfig().numFolds());
            logger.debug("Initialized ensemble classifier.");
            logger.debug("Starting ensemble training...");
            try {
                WekaTimeseriesUtil.buildWekaClassifierFromSimplifiedTS(provideHIVECOTEEnsembleModel, shapeletTransform);
                logger.debug("Finished ensemble training.");
                shapeletTransformTSClassifier.setClassifier(provideHIVECOTEEnsembleModel);
                return shapeletTransformTSClassifier;
            } catch (TrainingException e) {
                throw new AlgorithmException("Could not train classifier due to a training exception.", e);
            }
        } catch (Exception e2) {
            throw new AlgorithmException("Could not train model due to ensemble exception.", e2);
        }
    }

    private int[] estimateMinMax(double[][] dArr, int[] iArr, long j) throws InterruptedException {
        int[] iArr2 = new int[2];
        long length = dArr.length;
        ArrayList arrayList = new ArrayList();
        for (int i = 0; i < MIN_MAX_ESTIMATION_SAMPLES; i++) {
            double[][] dArr2 = new double[MIN_MAX_ESTIMATION_SAMPLES][dArr[0].length];
            Random random = new Random(m53getConfig().seed());
            int[] iArr3 = new int[MIN_MAX_ESTIMATION_SAMPLES];
            for (int i2 = 0; i2 < MIN_MAX_ESTIMATION_SAMPLES; i2++) {
                int nextInt = (int) (random.nextInt() % length);
                if (nextInt < 0) {
                    nextInt = (int) (nextInt + length);
                }
                for (int i3 = 0; i3 < dArr[0].length; i3++) {
                    dArr2[i2] = Arrays.copyOf(dArr[nextInt], dArr2[i2].length);
                }
                iArr3[i2] = iArr[nextInt];
            }
            arrayList.addAll(shapeletCachedSelection(dArr2, 3, dArr[0].length, MIN_MAX_ESTIMATION_SAMPLES, iArr3, j));
        }
        Shapelet.sortByLengthAsc(arrayList);
        logger.debug("Number of shapelets found in min/max estimation: {}", Integer.valueOf(arrayList.size()));
        iArr2[0] = ((Shapelet) arrayList.get(25)).getLength();
        iArr2[1] = ((Shapelet) arrayList.get(75)).getLength();
        return iArr2;
    }

    public List<Shapelet> clusterShapelets(List<Shapelet> list, int i, long j) throws InterruptedException {
        double d;
        double findMinimumDistance;
        ArrayList arrayList = new ArrayList();
        for (Shapelet shapelet : list) {
            ArrayList arrayList2 = new ArrayList();
            arrayList2.add(shapelet);
            arrayList.add(arrayList2);
        }
        while (arrayList.size() > i) {
            if (System.currentTimeMillis() - j > getTimeout().milliseconds()) {
                throw new InterruptedException(INTERRUPTION_MESSAGE);
            }
            INDArray create = Nd4j.create(arrayList.size(), arrayList.size());
            for (int i2 = 0; i2 < arrayList.size(); i2++) {
                for (int i3 = 0; i3 < arrayList.size(); i3++) {
                    double d2 = 0.0d;
                    int size = ((List) arrayList.get(i2)).size() * ((List) arrayList.get(i3)).size();
                    for (int i4 = 0; i4 < ((List) arrayList.get(i2)).size(); i4++) {
                        for (int i5 = 0; i5 < ((List) arrayList.get(i3)).size(); i5++) {
                            Shapelet shapelet2 = (Shapelet) ((List) arrayList.get(i2)).get(i4);
                            Shapelet shapelet3 = (Shapelet) ((List) arrayList.get(i3)).get(i5);
                            if (shapelet2.getLength() > shapelet3.getLength()) {
                                d = d2;
                                findMinimumDistance = this.minDistanceSearchStrategy.findMinimumDistance(shapelet3, shapelet2.getData());
                            } else {
                                d = d2;
                                findMinimumDistance = this.minDistanceSearchStrategy.findMinimumDistance(shapelet2, shapelet3.getData());
                            }
                            d2 = d + findMinimumDistance;
                        }
                    }
                    create.putScalar(new int[]{i2, i3}, d2 / size);
                }
            }
            double d3 = Double.MAX_VALUE;
            int i6 = 0;
            int i7 = 0;
            for (int i8 = 0; i8 < create.shape()[0]; i8++) {
                for (int i9 = 0; i9 < create.shape()[1]; i9++) {
                    if (create.getDouble(i8, i9) < d3 && i8 != i9) {
                        i6 = i8;
                        i7 = i9;
                        d3 = create.getDouble(i8, i9);
                    }
                }
            }
            List list2 = (List) arrayList.get(i6);
            list2.addAll((Collection) arrayList.get(i7));
            Shapelet highestQualityShapeletInList = Shapelet.getHighestQualityShapeletInList(list2);
            if (i6 > i7) {
                arrayList.remove(i6);
                arrayList.remove(i7);
            } else {
                arrayList.remove(i7);
                arrayList.remove(i6);
            }
            arrayList.add(Arrays.asList(highestQualityShapeletInList));
        }
        return (List) arrayList.stream().flatMap((v0) -> {
            return v0.stream();
        }).collect(Collectors.toList());
    }

    /* JADX WARN: Multi-variable type inference failed */
    private List<Shapelet> shapeletCachedSelection(double[][] dArr, int i, int i2, int i3, int[] iArr, long j) throws InterruptedException {
        List arrayList = new ArrayList();
        int length = dArr.length;
        for (int i4 = 0; i4 < length; i4++) {
            if (System.currentTimeMillis() - j > getTimeout().milliseconds()) {
                throw new InterruptedException(INTERRUPTION_MESSAGE);
            }
            ArrayList arrayList2 = new ArrayList();
            for (int i5 = i; i5 < i2; i5++) {
                for (Shapelet shapelet : generateCandidates(dArr[i4], i5, i4)) {
                    double assessQuality = this.qualityMeasure.assessQuality(findDistances(shapelet, dArr), iArr);
                    shapelet.setDeterminedQuality(assessQuality);
                    arrayList2.add(new AbstractMap.SimpleEntry(shapelet, Double.valueOf(assessQuality)));
                }
            }
            sortByQualityDesc(arrayList2);
            arrayList = merge(i3, arrayList, removeSelfSimilar(arrayList2));
        }
        return (List) arrayList.stream().map((v0) -> {
            return v0.getKey();
        }).collect(Collectors.toList());
    }

    public static List<Map.Entry<Shapelet, Double>> merge(int i, List<Map.Entry<Shapelet, Double>> list, List<Map.Entry<Shapelet, Double>> list2) {
        list.addAll(list2);
        sortByQualityDesc(list);
        int size = list.size() - i;
        for (int i2 = 0; i2 < size; i2++) {
            list.remove(list.size() - 1);
        }
        return list;
    }

    private static void sortByQualityDesc(List<Map.Entry<Shapelet, Double>> list) {
        list.sort((entry, entry2) -> {
            return (-1) * ((Double) entry.getValue()).compareTo((Double) entry2.getValue());
        });
    }

    public static List<Map.Entry<Shapelet, Double>> removeSelfSimilar(List<Map.Entry<Shapelet, Double>> list) {
        ArrayList arrayList = new ArrayList();
        for (Map.Entry<Shapelet, Double> entry : list) {
            boolean z = false;
            Iterator it = arrayList.iterator();
            while (it.hasNext()) {
                if (isSelfSimilar(entry.getKey(), (Shapelet) ((Map.Entry) it.next()).getKey())) {
                    z = true;
                }
            }
            if (!z) {
                arrayList.add(entry);
            }
        }
        return arrayList;
    }

    private static boolean isSelfSimilar(Shapelet shapelet, Shapelet shapelet2) {
        return shapelet.getInstanceIndex() == shapelet2.getInstanceIndex() && shapelet.getStartIndex() < shapelet2.getStartIndex() + shapelet2.getLength() && shapelet2.getStartIndex() < shapelet.getStartIndex() + shapelet.getLength();
    }

    public List<Double> findDistances(Shapelet shapelet, double[][] dArr) {
        ArrayList arrayList = new ArrayList();
        for (double[] dArr2 : dArr) {
            arrayList.add(Double.valueOf(this.minDistanceSearchStrategy.findMinimumDistance(shapelet, dArr2)));
        }
        return arrayList;
    }

    public static Set<Shapelet> generateCandidates(double[] dArr, int i, int i2) {
        HashSet hashSet = new HashSet();
        for (int i3 = 0; i3 < (dArr.length - i) + 1; i3++) {
            hashSet.add(new Shapelet(TimeSeriesUtil.zNormalize(TimeSeriesUtil.getInterval(dArr, i3, i3 + i), true), i3, i, i2));
        }
        return hashSet;
    }

    /* JADX WARN: Type inference failed for: r0v7, types: [double[], double[][]] */
    public static TimeSeriesDataset2 shapeletTransform(TimeSeriesDataset2 timeSeriesDataset2, List<Shapelet> list, Timeout timeout, long j, AMinimumDistanceSearchStrategy aMinimumDistanceSearchStrategy) throws InterruptedException {
        if (timeSeriesDataset2.isMultivariate()) {
            throw new UnsupportedOperationException("Multivariate datasets are not supported yet!");
        }
        double[][] valuesOrNull = timeSeriesDataset2.getValuesOrNull(0);
        if (valuesOrNull == null) {
            throw new IllegalArgumentException("Time series matrix must be a valid 2d matrix!");
        }
        ?? r0 = new double[valuesOrNull.length];
        for (int i = 0; i < valuesOrNull.length; i++) {
            if (timeout != null && System.currentTimeMillis() - j > timeout.milliseconds()) {
                throw new InterruptedException(INTERRUPTION_MESSAGE);
            }
            r0[i] = shapeletTransform(valuesOrNull[i], list, aMinimumDistanceSearchStrategy);
        }
        timeSeriesDataset2.replace(0, (double[][]) r0, timeSeriesDataset2.getTimestampsOrNull(0));
        return timeSeriesDataset2;
    }

    public static double[] shapeletTransform(double[] dArr, List<Shapelet> list, AMinimumDistanceSearchStrategy aMinimumDistanceSearchStrategy) {
        double[] dArr2 = new double[list.size()];
        for (int i = 0; i < list.size(); i++) {
            dArr2[i] = aMinimumDistanceSearchStrategy.findMinimumDistance(list.get(i), dArr);
        }
        return dArr2;
    }

    public AMinimumDistanceSearchStrategy getMinDistanceSearchStrategy() {
        return this.minDistanceSearchStrategy;
    }

    public void setMinDistanceSearchStrategy(AMinimumDistanceSearchStrategy aMinimumDistanceSearchStrategy) {
        this.minDistanceSearchStrategy = aMinimumDistanceSearchStrategy;
    }

    public void registerListener(Object obj) {
        throw new UnsupportedOperationException("The operation to be performed is not supported.");
    }

    public IAlgorithmEvent nextWithException() {
        throw new UnsupportedOperationException("The operation to be performed is not supported.");
    }

    /* renamed from: getConfig, reason: merged with bridge method [inline-methods] and merged with bridge method [inline-methods] */
    public IShapeletTransformLearningAlgorithmConfig m53getConfig() {
        return super.getConfig();
    }
}
