package ai.libs.jaicore.ml.weka.classification.timeseries.learner.shapelets;

import ai.libs.jaicore.ml.classification.singlelabel.timeseries.dataset.TimeSeriesDataset2;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.learner.ASimplifiedTSClassifier;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.quality.FStat;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.quality.IQualityMeasure;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.shapelets.Shapelet;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.shapelets.search.AMinimumDistanceSearchStrategy;
import ai.libs.jaicore.ml.weka.classification.singlelabel.timeseries.util.WekaTimeseriesUtil;
import ai.libs.jaicore.ml.weka.classification.timeseries.learner.shapelets.ShapeletTransformLearningAlgorithm;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.aeonbits.owner.ConfigCache;
import org.aeonbits.owner.ConfigFactory;
import org.api4.java.ai.ml.core.exception.PredictionException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.classifiers.Classifier;
import weka.core.Instance;
import weka.core.Instances;

/* loaded from: input_file:ai/libs/jaicore/ml/weka/classification/timeseries/learner/shapelets/ShapeletTransformTSClassifier.class */
public class ShapeletTransformTSClassifier extends ASimplifiedTSClassifier<Integer> {
    private static final Logger LOGGER = LoggerFactory.getLogger(ShapeletTransformTSClassifier.class);
    private List<Shapelet> shapelets;
    private Classifier classifier;
    private AMinimumDistanceSearchStrategy minDistanceSearchStrategy;
    private final ShapeletTransformLearningAlgorithm.IShapeletTransformLearningAlgorithmConfig config;
    private final IQualityMeasure qualityMeasure;

    public ShapeletTransformTSClassifier(int i, int i2) {
        this(i, new FStat(), i2, true);
    }

    public ShapeletTransformTSClassifier(int i, IQualityMeasure iQualityMeasure, int i2, boolean z) {
        this(i, i / 2, iQualityMeasure, i2, z, 3, 0, false, 1);
    }

    public ShapeletTransformTSClassifier(int i, int i2, IQualityMeasure iQualityMeasure, int i3, boolean z) {
        this(i, i2, iQualityMeasure, i3, z, 3, 0, false, 1);
    }

    public ShapeletTransformTSClassifier(int i, int i2, IQualityMeasure iQualityMeasure, int i3, boolean z, int i4, int i5, boolean z2, int i6) {
        this.config = ConfigCache.getOrCreate(ShapeletTransformLearningAlgorithm.IShapeletTransformLearningAlgorithmConfig.class, new Map[0]);
        this.config.setProperty("numshapelets", "" + i);
        this.config.setProperty("seed", "" + i3);
        this.config.setProperty(ShapeletTransformLearningAlgorithm.IShapeletTransformLearningAlgorithmConfig.K_CLUSTERSHAPELETS, "" + z);
        this.config.setProperty("minshapeletlength", "" + i4);
        this.config.setProperty(ShapeletTransformLearningAlgorithm.IShapeletTransformLearningAlgorithmConfig.K_SHAPELETLENGTH_MAX, "" + i5);
        this.config.setProperty(ShapeletTransformLearningAlgorithm.IShapeletTransformLearningAlgorithmConfig.K_USEHIVECOTEENSEMBLE, "" + z2);
        this.config.setProperty("numfolds", "" + i6);
        this.config.setProperty(ShapeletTransformLearningAlgorithm.IShapeletTransformLearningAlgorithmConfig.K_NUMCLUSTERS, "" + i2);
        this.qualityMeasure = iQualityMeasure;
    }

    public List<Shapelet> getShapelets() {
        return this.shapelets;
    }

    public void setShapelets(List<Shapelet> list) {
        this.shapelets = list;
    }

    public void setClassifier(Classifier classifier) {
        this.classifier = classifier;
    }

    /* renamed from: predict, reason: merged with bridge method [inline-methods] */
    public Integer m57predict(double[] dArr) throws PredictionException {
        if (!isTrained()) {
            throw new PredictionException("Model has not been built before!");
        }
        Instance simplifiedTSInstanceToWekaInstance = WekaTimeseriesUtil.simplifiedTSInstanceToWekaInstance(ShapeletTransformLearningAlgorithm.shapeletTransform(dArr, this.shapelets, this.minDistanceSearchStrategy));
        try {
            return Integer.valueOf((int) Math.round(this.classifier.classifyInstance(simplifiedTSInstanceToWekaInstance)));
        } catch (Exception e) {
            throw new PredictionException(String.format("Could not predict Weka instance %s.", simplifiedTSInstanceToWekaInstance.toString()), e);
        }
    }

    public Integer predict(List<double[]> list) throws PredictionException {
        LOGGER.warn("Dataset to be predicted is multivariate but only first time series (univariate) will be considered.");
        return m57predict(list.get(0));
    }

    public List<Integer> predict(TimeSeriesDataset2 timeSeriesDataset2) throws PredictionException {
        if (!isTrained()) {
            throw new PredictionException("Model has not been built before!");
        }
        if (timeSeriesDataset2.isMultivariate()) {
            LOGGER.warn("Dataset to be predicted is multivariate but only first time series (univariate) will be considered.");
        }
        LOGGER.debug("Transforming dataset...");
        try {
            TimeSeriesDataset2 shapeletTransform = ShapeletTransformLearningAlgorithm.shapeletTransform(timeSeriesDataset2, this.shapelets, null, -1L, this.minDistanceSearchStrategy);
            LOGGER.debug("Transformed dataset.");
            if (shapeletTransform.getValuesOrNull(0) == null) {
                throw new IllegalArgumentException("Dataset matrix of the instances to be predicted must not be null!");
            }
            LOGGER.debug("Converting time series dataset to Weka instances...");
            Instances simplifiedTimeSeriesDatasetToWekaInstances = WekaTimeseriesUtil.simplifiedTimeSeriesDatasetToWekaInstances(shapeletTransform);
            LOGGER.debug("Converted time series dataset to Weka instances.");
            LOGGER.debug("Starting prediction...");
            ArrayList arrayList = new ArrayList();
            Iterator it = simplifiedTimeSeriesDatasetToWekaInstances.iterator();
            while (it.hasNext()) {
                Instance instance = (Instance) it.next();
                try {
                    arrayList.add(Integer.valueOf((int) Math.round(this.classifier.classifyInstance(instance))));
                } catch (Exception e) {
                    throw new PredictionException(String.format("Could not predict Weka instance %s.", instance.toString()), e);
                }
            }
            LOGGER.debug("Finished prediction.");
            return arrayList;
        } catch (InterruptedException e2) {
            Thread.currentThread().interrupt();
            return new ArrayList();
        }
    }

    public AMinimumDistanceSearchStrategy getMinDistanceSearchStrategy() {
        return this.minDistanceSearchStrategy;
    }

    /* renamed from: getLearningAlgorithm, reason: merged with bridge method [inline-methods] */
    public ShapeletTransformLearningAlgorithm m55getLearningAlgorithm(TimeSeriesDataset2 timeSeriesDataset2) {
        return new ShapeletTransformLearningAlgorithm(this.config, this, timeSeriesDataset2, this.qualityMeasure);
    }

    /* renamed from: predict, reason: collision with other method in class */
    public /* bridge */ /* synthetic */ Object m56predict(List list) throws PredictionException {
        return predict((List<double[]>) list);
    }
}
