package ai.libs.mlplan.core;

import ai.libs.jaicore.components.model.ComponentInstance;
import ai.libs.jaicore.ml.core.learner.ASupervisedLearner;
import java.util.ArrayList;
import java.util.List;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance;
import org.api4.java.ai.ml.core.evaluation.IPrediction;
import org.api4.java.ai.ml.core.evaluation.IPredictionBatch;
import org.api4.java.ai.ml.core.exception.PredictionException;
import org.api4.java.ai.ml.core.exception.TrainingException;
import org.api4.java.ai.ml.core.learner.ISupervisedLearner;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Decorator around an {@link ISupervisedLearner} that records how long training and prediction take.
 *
 * <p>All recorded times are in milliseconds:
 * <ul>
 * <li>every call to {@link #fit(ILabeledDataset)} appends the elapsed training time to {@link #getFitTimes()};</li>
 * <li>every single-instance {@link #predict(ILabeledInstance)} appends its elapsed time to
 * {@link #getInstancePredictionTimesInMS()};</li>
 * <li>every batch {@link #predict(ILabeledInstance[])} appends its elapsed time to
 * {@link #getBatchPredictionTimesInMS()} and, for non-empty batches, the rounded average time per instance to
 * {@link #getInstancePredictionTimesInMS()}.</li>
 * </ul>
 * A time is only recorded when the wrapped call returns normally.
 *
 * <p>The wrapper additionally carries the {@link ComponentInstance} describing the learner, optional predicted
 * induction/inference times, and an optional score.
 *
 * <p>NOTE(review): the time lists are plain {@link ArrayList}s returned directly by the getters, and no
 * synchronization is performed; confirm that callers do not use a single wrapper from several threads.
 */
public class TimeTrackingLearnerWrapper extends ASupervisedLearner<ILabeledInstance, ILabeledDataset<? extends ILabeledInstance>, IPrediction, IPredictionBatch> implements ITimeTrackingLearner {
    private static final Logger LOGGER = LoggerFactory.getLogger(TimeTrackingLearnerWrapper.class);

    /** Number of nanoseconds in one millisecond. */
    private static final long NANOS_PER_MILLI = 1_000_000L;

    private final ISupervisedLearner<ILabeledInstance, ILabeledDataset<? extends ILabeledInstance>> wrappedSLClassifier;
    private final ComponentInstance ci;
    private Double score;
    private Double predictedInductionTime = null;
    private Double predictedInferenceTime = null;
    private final List<Long> fitTimes = new ArrayList<>();
    private final List<Long> batchPredictTimes = new ArrayList<>();
    private final List<Long> perInstancePredictionTimes = new ArrayList<>();

    /**
     * Stopwatch that starts when it is constructed.
     *
     * <p>It uses the monotonic {@link System#nanoTime()} clock, so wall-clock adjustments cannot produce negative
     * or skewed durations. It is a static nested class because it needs no reference to the enclosing wrapper.
     */
    static final class TimeTracker {
        private final long startNanos;

        private TimeTracker() {
            this.startNanos = System.nanoTime();
        }

        /**
         * Measures the time since construction.
         *
         * @return milliseconds elapsed since this tracker was created, truncated to whole milliseconds
         */
        public long stop() {
            return (System.nanoTime() - this.startNanos) / NANOS_PER_MILLI;
        }
    }

    /**
     * Creates a wrapper that tracks times for the given learner.
     *
     * @param componentInstance the component instance describing the wrapped learner
     * @param iSupervisedLearner the learner to delegate all training and prediction calls to
     */
    public TimeTrackingLearnerWrapper(ComponentInstance componentInstance, ISupervisedLearner<ILabeledInstance, ILabeledDataset<? extends ILabeledInstance>> iSupervisedLearner) {
        this.ci = componentInstance;
        this.wrappedSLClassifier = iSupervisedLearner;
    }

    /**
     * Trains the wrapped learner and records the training time in {@link #getFitTimes()}.
     *
     * @param iLabeledDataset the training data
     * @throws TrainingException if the wrapped learner fails to train (no time is recorded)
     * @throws InterruptedException if training is interrupted
     */
    public void fit(ILabeledDataset<? extends ILabeledInstance> iLabeledDataset) throws TrainingException, InterruptedException {
        TimeTracker timeTracker = new TimeTracker();
        this.wrappedSLClassifier.fit(iLabeledDataset);
        this.fitTimes.add(timeTracker.stop());
    }

    /**
     * Predicts a single instance with the wrapped learner and records the prediction time.
     *
     * @param iLabeledInstance the instance to predict
     * @return the wrapped learner's prediction
     * @throws PredictionException if the wrapped learner fails to predict (no time is recorded)
     * @throws InterruptedException if prediction is interrupted
     */
    public IPrediction predict(ILabeledInstance iLabeledInstance) throws PredictionException, InterruptedException {
        TimeTracker timeTracker = new TimeTracker();
        IPrediction prediction = this.wrappedSLClassifier.predict(iLabeledInstance);
        this.perInstancePredictionTimes.add(timeTracker.stop());
        return prediction;
    }

    /**
     * Predicts a batch of instances with the wrapped learner and records the batch time.
     *
     * <p>For a non-empty batch, the rounded average time per instance is also recorded. An empty batch adds no
     * per-instance entry, because no instance was predicted.
     *
     * @param iLabeledInstanceArr the instances to predict
     * @return the wrapped learner's batch prediction
     * @throws PredictionException if the wrapped learner fails to predict (no time is recorded)
     * @throws InterruptedException if prediction is interrupted
     */
    public IPredictionBatch predict(ILabeledInstance[] iLabeledInstanceArr) throws PredictionException, InterruptedException {
        TimeTracker timeTracker = new TimeTracker();
        IPredictionBatch predictions = this.wrappedSLClassifier.predict(iLabeledInstanceArr);
        long elapsed = timeTracker.stop();
        this.batchPredictTimes.add(elapsed);
        // Guard against division by zero: an empty batch previously threw ArithmeticException *after* the
        // predictions had been computed, discarding them.
        if (iLabeledInstanceArr.length > 0) {
            this.perInstancePredictionTimes.add(averagePerInstanceMillis(elapsed, iLabeledInstanceArr.length));
        }
        return predictions;
    }

    /**
     * Computes the average time per instance, rounded to the nearest millisecond.
     *
     * <p>The division is done in floating point so that rounding actually takes place (integer division would
     * always truncate).
     *
     * @param totalMillis total elapsed milliseconds for the batch
     * @param numInstances number of instances in the batch; must be positive
     * @return the rounded per-instance time in milliseconds
     */
    static long averagePerInstanceMillis(long totalMillis, int numInstances) {
        return Math.round((double) totalMillis / numInstances);
    }

    /** @return the live list of recorded training times in milliseconds, in call order */
    @Override // ai.libs.mlplan.core.ITimeTrackingLearner
    public List<Long> getFitTimes() {
        return this.fitTimes;
    }

    /** @return the live list of recorded batch prediction times in milliseconds, in call order */
    @Override // ai.libs.mlplan.core.ITimeTrackingLearner
    public List<Long> getBatchPredictionTimesInMS() {
        return this.batchPredictTimes;
    }

    /**
     * Returns the per-instance prediction times.
     *
     * @return the live list of per-instance prediction times in milliseconds: single-instance prediction times
     *         and per-batch averages, in call order
     */
    @Override // ai.libs.mlplan.core.ITimeTrackingLearner
    public List<Long> getInstancePredictionTimesInMS() {
        return this.perInstancePredictionTimes;
    }

    /** @return the component instance supplied at construction */
    @Override // ai.libs.mlplan.core.ITimeTrackingLearner
    public ComponentInstance getComponentInstance() {
        return this.ci;
    }

    /**
     * Parses and stores a predicted induction time.
     *
     * <p>If {@code str} is null or not a valid double, a warning is logged and the previous value is kept.
     *
     * @param str textual representation of the predicted induction time
     */
    @Override // ai.libs.mlplan.core.ITimeTrackingLearner
    public void setPredictedInductionTime(String str) {
        Double parsed = tryParseDouble(str, "induction");
        if (parsed != null) {
            this.predictedInductionTime = parsed;
        }
    }

    /**
     * Parses and stores a predicted inference time.
     *
     * <p>If {@code str} is null or not a valid double, a warning is logged and the previous value is kept.
     *
     * @param str textual representation of the predicted inference time
     */
    @Override // ai.libs.mlplan.core.ITimeTrackingLearner
    public void setPredictedInferenceTime(String str) {
        Double parsed = tryParseDouble(str, "inference");
        if (parsed != null) {
            this.predictedInferenceTime = parsed;
        }
    }

    /**
     * Parses {@code raw} as a double.
     *
     * <p>On failure it logs a warning naming the kind of time and returns {@code null}.
     *
     * @param raw the text to parse; may be null
     * @param kind label used in the warning, e.g. "induction" or "inference"
     * @return the parsed value, or {@code null} if {@code raw} is null or malformed
     */
    private static Double tryParseDouble(String raw, String kind) {
        if (raw == null) {
            LOGGER.warn("Could not parse double from provided {} time {}.", kind, raw);
            return null;
        }
        try {
            return Double.valueOf(raw);
        } catch (NumberFormatException e) {
            LOGGER.warn("Could not parse double from provided {} time {}.", kind, raw, e);
            return null;
        }
    }

    /** @return the predicted induction time, or {@code null} if none was successfully set */
    @Override // ai.libs.mlplan.core.ITimeTrackingLearner
    public Double getPredictedInductionTime() {
        return this.predictedInductionTime;
    }

    /** @return the predicted inference time, or {@code null} if none was successfully set */
    @Override // ai.libs.mlplan.core.ITimeTrackingLearner
    public Double getPredictedInferenceTime() {
        return this.predictedInferenceTime;
    }

    /**
     * Stores a score for the wrapped learner.
     *
     * @param d the score; a {@code null} argument is ignored and the previous score is kept
     */
    @Override // ai.libs.mlplan.core.ITimeTrackingLearner
    public void setScore(Double d) {
        if (d != null) {
            this.score = d;
        }
    }

    /** @return the last non-null score set, or {@code null} if none */
    @Override // ai.libs.mlplan.core.ITimeTrackingLearner
    public Double getScore() {
        return this.score;
    }

    /** @return the wrapped learner that all calls are delegated to */
    @Override // ai.libs.mlplan.core.ITimeTrackingLearner
    public ISupervisedLearner<ILabeledInstance, ILabeledDataset<? extends ILabeledInstance>> getLearner() {
        return this.wrappedSLClassifier;
    }

    @Override
    public String toString() {
        return getClass().getName() + " -> " + this.wrappedSLClassifier.toString();
    }
}
