package ai.libs.jaicore.ml.core.filter.sampling.inmemory.casecontrol;

import ai.libs.jaicore.basic.sets.Pair;
import ai.libs.jaicore.ml.core.filter.sampling.inmemory.factories.interfaces.ISamplingAlgorithmFactory;
import java.util.List;
import java.util.Objects;
import java.util.Random;
import org.api4.java.ai.ml.classification.IClassifier;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance;
import org.api4.java.ai.ml.core.filter.unsupervised.sampling.ISamplingAlgorithm;
import org.api4.java.algorithm.exceptions.AlgorithmException;
import org.api4.java.algorithm.exceptions.AlgorithmExecutionCanceledException;
import org.api4.java.algorithm.exceptions.AlgorithmTimeoutedException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for case-control-like sampling algorithms that first train a "pilot" classifier and then
 * derive per-instance acceptance thresholds from that trained pilot.
 *
 * <p>The pilot is trained on the full input dataset, unless a sub-sampling factory was supplied at
 * construction time. In that case it is trained on a sub-sample of size {@link #preSampleSize}, drawn
 * with a fixed random seed ({@code 0}). Subclasses turn the trained pilot into acceptance thresholds by
 * implementing {@code calculateAcceptanceThresholdsWithTrainedPilot}. That hook always receives the full
 * input dataset, regardless of which dataset the pilot was trained on.
 *
 * @param <D> the type of labeled dataset being sampled
 */
public abstract class APilotEstimateSampling<D extends ILabeledDataset<? extends ILabeledInstance>> extends CaseControlLikeSampling<D> {

    /** Logger of this class; replaced when {@link #setLoggerName(String)} is called. */
    private Logger logger = LoggerFactory.getLogger(APilotEstimateSampling.class);

    /** Optional sampler producing the pilot's training set; {@code null} means "train on the full input". */
    private final ISamplingAlgorithm<D> subSampler;

    /**
     * Size requested from the sub-sampler. This is only meaningful when a sub-sampling factory was given;
     * the two-argument constructor sets it to {@code 1} but never uses it.
     */
    protected int preSampleSize;

    /** Classifier that is fitted before thresholds are computed and then handed to the subclass hook. */
    private final IClassifier pilotEstimator;

    /**
     * Creates a pilot-estimate sampler whose pilot is trained on the full input dataset.
     *
     * @param input the dataset to sample from
     * @param pilotEstimator the pilot classifier; must not be {@code null}
     * @throws NullPointerException if {@code pilotEstimator} is {@code null}
     */
    protected APilotEstimateSampling(final D input, final IClassifier pilotEstimator) {
        this(input, null, 1, pilotEstimator);
    }

    /**
     * Creates a pilot-estimate sampler whose pilot is optionally trained on a sub-sample of the input.
     *
     * @param input the dataset to sample from
     * @param subSamplerFactory factory for the sub-sampler that produces the pilot's training set, or
     *        {@code null} to train the pilot on the full input
     * @param preSampleSize number of instances requested from the sub-sampler
     * @param pilotEstimator the pilot classifier; must not be {@code null}
     * @throws NullPointerException if {@code pilotEstimator} is {@code null}
     */
    protected APilotEstimateSampling(final D input, final ISamplingAlgorithmFactory<D, ?> subSamplerFactory, final int preSampleSize, final IClassifier pilotEstimator) {
        super(input);
        this.pilotEstimator = Objects.requireNonNull(pilotEstimator, "pilotEstimator");
        this.preSampleSize = preSampleSize;
        // A fixed seed keeps the pilot's training subset reproducible across runs.
        this.subSampler = subSamplerFactory != null ? subSamplerFactory.getAlgorithm(preSampleSize, input, new Random(0L)) : null;
    }

    /**
     * Fits the pilot estimator (on the sub-sample if a sub-sampler is configured, otherwise on the full
     * input) and delegates threshold computation to the subclass.
     *
     * @return the acceptance threshold computed for each instance, as produced by
     *         {@code calculateAcceptanceThresholdsWithTrainedPilot}
     */
    @Override
    public List<Pair<ILabeledInstance, Double>> computeAcceptanceThresholds() throws InterruptedException, AlgorithmTimeoutedException, AlgorithmExecutionCanceledException, AlgorithmException {
        final D input = this.getInput();
        if (this.subSampler != null) {
            final D pilotTrainingSet = this.subSampler.call();
            this.logger.info("Fitting pilot with reduced dataset of {}/{} instances.", pilotTrainingSet.size(), input.size());
            this.pilotEstimator.fit(pilotTrainingSet);
        } else {
            this.logger.info("Fitting pilot with full dataset.");
            this.pilotEstimator.fit(input);
        }
        // Thresholds are always computed over the full input, even if the pilot saw only a subset.
        return this.calculateAcceptanceThresholdsWithTrainedPilot(input, this.pilotEstimator);
    }

    /**
     * Computes an acceptance threshold for each instance of the given dataset using the already fitted
     * pilot classifier.
     *
     * @param instances the full input dataset
     * @param pilotEstimator the pilot classifier, already fitted by {@link #computeAcceptanceThresholds()}
     * @return pairs of instance and its acceptance threshold
     */
    public abstract List<Pair<ILabeledInstance, Double>> calculateAcceptanceThresholdsWithTrainedPilot(D instances, IClassifier pilotEstimator) throws InterruptedException, AlgorithmTimeoutedException, AlgorithmExecutionCanceledException, AlgorithmException;

    /**
     * Returns the pilot classifier passed at construction time.
     *
     * @return the pilot estimator (never {@code null})
     */
    public IClassifier getPilotEstimator() {
        return this.pilotEstimator;
    }

    /**
     * Sets this class's logger to {@code name} and the superclass logger to {@code name + ".ccsampling"}.
     *
     * @param name the logger name to use
     */
    @Override
    public void setLoggerName(final String name) {
        this.logger = LoggerFactory.getLogger(name);
        super.setLoggerName(name + ".ccsampling");
    }

    /**
     * Returns the name of this class's logger.
     *
     * @return the current logger name
     */
    @Override
    public String getLoggerName() {
        return this.logger.getName();
    }
}
