package ai.libs.jaicore.ml.core.filter;

import ai.libs.jaicore.basic.reconstruction.ReconstructionInstruction;
import ai.libs.jaicore.basic.reconstruction.ReconstructionUtil;
import ai.libs.jaicore.ml.core.dataset.splitter.ReproducibleSplit;
import ai.libs.jaicore.ml.core.filter.sampling.inmemory.ASamplingAlgorithm;
import ai.libs.jaicore.ml.core.filter.sampling.inmemory.factories.interfaces.ISamplingAlgorithmFactory;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Random;
import org.api4.java.ai.ml.core.dataset.IDataset;
import org.api4.java.ai.ml.core.dataset.splitter.IDatasetSplitter;
import org.api4.java.ai.ml.core.dataset.splitter.IFoldSizeConfigurableRandomDatasetSplitter;
import org.api4.java.ai.ml.core.dataset.splitter.SplitFailedException;
import org.api4.java.ai.ml.core.exception.DatasetCreationException;
import org.api4.java.common.control.ILoggingCustomizable;
import org.api4.java.common.reconstruction.IReconstructible;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/libs/jaicore/ml/core/filter/FilterBasedDatasetSplitter.class */
public class FilterBasedDatasetSplitter<D extends IDataset<?>> implements IDatasetSplitter<D>, IFoldSizeConfigurableRandomDatasetSplitter<D>, ILoggingCustomizable {
    private final ISamplingAlgorithmFactory<D, ?> samplerFactory;
    private double relSampleSize;
    private Random random;
    private Logger logger = LoggerFactory.getLogger(FilterBasedDatasetSplitter.class);

    public FilterBasedDatasetSplitter(ISamplingAlgorithmFactory<D, ?> iSamplingAlgorithmFactory, double d, Random random) {
        this.samplerFactory = iSamplingAlgorithmFactory;
        this.relSampleSize = d;
        this.random = random;
    }

    public List<D> split(D d) throws SplitFailedException, InterruptedException {
        return split(d, this.random, this.relSampleSize);
    }

    public int getNumberOfFoldsPerSplit() {
        return 2;
    }

    public List<D> split(D d, Random random, double... dArr) throws SplitFailedException, InterruptedException {
        return getSplit(d, this.samplerFactory, random.nextLong(), this.logger, dArr);
    }

    public static <D extends IDataset<?>> List<D> getSplit(D d, ISamplingAlgorithmFactory<D, ?> iSamplingAlgorithmFactory, long j, List<Double> list) throws InterruptedException, SplitFailedException {
        return list.size() > 1 ? getSplit(d, iSamplingAlgorithmFactory, j, list.get(0).doubleValue(), list.get(1).doubleValue()) : getSplit(d, iSamplingAlgorithmFactory, j, list.get(0).doubleValue());
    }

    public static <D extends IDataset<?>> List<D> getSplit(D d, ISamplingAlgorithmFactory<D, ?> iSamplingAlgorithmFactory, long j, double... dArr) throws InterruptedException, SplitFailedException {
        return getSplit(d, iSamplingAlgorithmFactory, j, LoggerFactory.getLogger(FilterBasedDatasetSplitter.class), dArr);
    }

    /* JADX WARN: Type inference failed for: r0v19, types: [ai.libs.jaicore.ml.core.filter.sampling.inmemory.ASamplingAlgorithm, org.api4.java.ai.ml.core.filter.unsupervised.sampling.ISamplingAlgorithm] */
    /* JADX WARN: Type inference failed for: r8v19, types: [double[], java.lang.Object[]] */
    public static <D extends IDataset<?>> List<D> getSplit(D d, ISamplingAlgorithmFactory<D, ?> iSamplingAlgorithmFactory, long j, Logger logger, double... dArr) throws InterruptedException, SplitFailedException {
        Objects.requireNonNull(d);
        if (d.isEmpty()) {
            throw new IllegalArgumentException("Cannot split empty dataset.");
        }
        if (dArr.length > 2 || (dArr.length == 2 && dArr[0] + dArr[1] != 1.0d)) {
            throw new IllegalArgumentException("Invalid fold size specification " + Arrays.toString(dArr));
        }
        if ((d instanceof IReconstructible) && !(iSamplingAlgorithmFactory instanceof IReconstructible)) {
            throw new IllegalStateException("Given data is reproducible and so should the splitters, but the sampler factory used to create the sampling algorithm is not reproducible.");
        }
        int round = (int) Math.round(d.size() * dArr[0]);
        logger.info("Drawing 2-fold split with size {} for the first fold.", Integer.valueOf(round));
        ?? algorithm = iSamplingAlgorithmFactory.getAlgorithm(round, d, new Random(j));
        if (algorithm instanceof ILoggingCustomizable) {
            ((ILoggingCustomizable) algorithm).setLoggerName(logger.getName() + ".sampler");
        }
        try {
            IDataset nextSample = algorithm.nextSample();
            logger.debug("Sample for first fold completed, now computing the complement to fill the second fold.");
            IDataset complementOfLastSample = algorithm.getComplementOfLastSample();
            logger.info("Fold creation completed. Adding reconstruction information.");
            if (!(d instanceof IReconstructible)) {
                logger.info("Sampling-based split completed, returning two folds of sizes {} and {}.", Integer.valueOf(nextSample.size()), Integer.valueOf(complementOfLastSample.size()));
                return Arrays.asList(nextSample, complementOfLastSample);
            }
            if (!ReconstructionUtil.areInstructionsNonEmptyIfReconstructibilityClaimed(d)) {
                logger.info("Not making the split reproducible since the original data is not reproducible.");
                return Arrays.asList(nextSample, complementOfLastSample);
            }
            ArrayList arrayList = new ArrayList();
            for (double d2 : dArr) {
                arrayList.add(Double.valueOf(d2));
            }
            List instructions = ((IReconstructible) d).getConstructionPlan().getInstructions();
            IReconstructible iReconstructible = (IReconstructible) nextSample;
            Objects.requireNonNull(iReconstructible);
            instructions.forEach(iReconstructible::addInstruction);
            ((IReconstructible) nextSample).addInstruction(new ReconstructionInstruction(FilterBasedDatasetSplitter.class.getName(), "getFoldOfSplit", new Class[]{IDataset.class, ISamplingAlgorithmFactory.class, Long.TYPE, Integer.TYPE, List.class}, new Object[]{"this", iSamplingAlgorithmFactory, Long.valueOf(j), 0, arrayList}));
            IReconstructible iReconstructible2 = (IReconstructible) complementOfLastSample;
            Objects.requireNonNull(iReconstructible2);
            instructions.forEach(iReconstructible2::addInstruction);
            ((IReconstructible) complementOfLastSample).addInstruction(new ReconstructionInstruction(FilterBasedDatasetSplitter.class.getName(), "getFoldOfSplit", new Class[]{IDataset.class, ISamplingAlgorithmFactory.class, Long.TYPE, Integer.TYPE, List.class}, new Object[]{"this", iSamplingAlgorithmFactory, Long.valueOf(j), 1, arrayList}));
            ReconstructionUtil.requireNonEmptyInstructionsIfReconstructibilityClaimed(nextSample);
            ReconstructionUtil.requireNonEmptyInstructionsIfReconstructibilityClaimed(complementOfLastSample);
            ReconstructionInstruction reconstructionInstruction = new ReconstructionInstruction(FilterBasedDatasetSplitter.class.getName(), "getSplit", new Class[]{IDataset.class, ISamplingAlgorithmFactory.class, Long.TYPE, List.class}, new Object[]{"this", iSamplingAlgorithmFactory, Long.valueOf(j), Arrays.asList(new double[]{dArr})});
            logger.info("Sampling-based split completed, returning two folds of sizes {} and {}.", Integer.valueOf(nextSample.size()), Integer.valueOf(complementOfLastSample.size()));
            return new ReproducibleSplit(reconstructionInstruction, d, nextSample, complementOfLastSample);
        } catch (DatasetCreationException e) {
            throw new SplitFailedException(e);
        }
    }

    public static <D extends IDataset<?>> D getFoldOfSplit(D d, ISamplingAlgorithmFactory<D, ?> iSamplingAlgorithmFactory, long j, int i, List<Double> list) throws InterruptedException, SplitFailedException {
        return (D) getSplit(d, iSamplingAlgorithmFactory, j, list).get(i);
    }

    public String getLoggerName() {
        return this.logger.getName();
    }

    public void setLoggerName(String str) {
        this.logger = LoggerFactory.getLogger(str);
    }
}
