package org.deeplearning4j.arbiter.task;

import java.beans.ConstructorProperties;
import java.util.concurrent.Callable;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.deeplearning4j.arbiter.GraphConfiguration;
import org.deeplearning4j.arbiter.listener.UIGraphStatusReportingListener;
import org.deeplearning4j.arbiter.optimize.api.Candidate;
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
import org.deeplearning4j.arbiter.optimize.api.TaskCreator;
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
import org.deeplearning4j.arbiter.optimize.api.evaluation.ModelEvaluator;
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
import org.deeplearning4j.arbiter.optimize.runner.Status;
import org.deeplearning4j.arbiter.optimize.runner.listener.candidate.UICandidateStatusListener;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.earlystopping.EarlyStoppingResult;
import org.deeplearning4j.earlystopping.trainer.EarlyStoppingGraphTrainer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.ui.api.Component;
import org.deeplearning4j.ui.components.text.ComponentText;
import org.deeplearning4j.ui.components.text.style.StyleText;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

/* loaded from: input_file:org/deeplearning4j/arbiter/task/ComputationGraphTaskCreator.class */
/**
 * {@link TaskCreator} that produces tasks which train, optionally evaluate, and score a
 * {@link ComputationGraph} built from a single candidate {@link GraphConfiguration}.
 *
 * <p>Training uses early stopping when the candidate's configuration supplies an
 * {@link EarlyStoppingConfiguration}; otherwise the network is fit for
 * {@link GraphConfiguration#getNumEpochs()} passes over the training data. Progress and failures
 * are reported through a {@link UIGraphStatusReportingListener}.
 *
 * @param <A> type of the additional evaluation result produced by the {@link ModelEvaluator}
 */
public class ComputationGraphTaskCreator<A> implements TaskCreator<GraphConfiguration, ComputationGraph, DataSetIterator, A> {

    /** Evaluator run after training for an additional result; may be {@code null} (skipped). */
    private final ModelEvaluator<ComputationGraph, DataSetIterator, A> modelEvaluator;

    /**
     * @param modelEvaluator evaluator producing the additional result; {@code null} to skip that stage
     */
    @ConstructorProperties({"modelEvaluator"})
    public ComputationGraphTaskCreator(ModelEvaluator<ComputationGraph, DataSetIterator, A> modelEvaluator) {
        this.modelEvaluator = modelEvaluator;
    }

    /**
     * Creates a task that trains and scores the given candidate when called.
     *
     * @param candidate                 candidate network configuration (plus data parameters)
     * @param dataProvider              source of training data
     * @param scoreFunction             function used to score the trained network
     * @param uICandidateStatusListener UI listener that receives progress/failure reports
     * @return a {@link Callable} yielding the optimization result for {@code candidate}
     */
    @Override
    public Callable<OptimizationResult<GraphConfiguration, ComputationGraph, A>> create(
            Candidate<GraphConfiguration> candidate,
            DataProvider<DataSetIterator> dataProvider,
            ScoreFunction<ComputationGraph, DataSetIterator> scoreFunction,
            UICandidateStatusListener uICandidateStatusListener) {
        return new GraphLearningTask<>(candidate, dataProvider, scoreFunction, this.modelEvaluator, uICandidateStatusListener);
    }

    /** Trains, optionally evaluates, and scores one candidate computation graph. */
    private static class GraphLearningTask<A> implements Callable<OptimizationResult<GraphConfiguration, ComputationGraph, A>> {
        /** Shared empty varargs array for reports that carry no extra UI components. */
        private static final Component[] NO_COMPONENTS = new Component[0];

        private final Candidate<GraphConfiguration> candidate;
        private final DataProvider<DataSetIterator> dataProvider;
        private final ScoreFunction<ComputationGraph, DataSetIterator> scoreFunction;
        private final ModelEvaluator<ComputationGraph, DataSetIterator, A> modelEvaluator;
        private final UIGraphStatusReportingListener dl4jListener;

        public GraphLearningTask(Candidate<GraphConfiguration> candidate,
                                 DataProvider<DataSetIterator> dataProvider,
                                 ScoreFunction<ComputationGraph, DataSetIterator> scoreFunction,
                                 ModelEvaluator<ComputationGraph, DataSetIterator, A> modelEvaluator,
                                 UICandidateStatusListener uICandidateStatusListener) {
            this.candidate = candidate;
            this.dataProvider = dataProvider;
            this.scoreFunction = scoreFunction;
            this.modelEvaluator = modelEvaluator;
            this.dl4jListener = new UIGraphStatusReportingListener(uICandidateStatusListener);
        }

        /**
         * Runs the three stages: training, optional additional evaluation, and scoring.
         *
         * <p>A training failure is reported to the UI and rethrown. Failures in the evaluation or
         * scoring stages are reported and leave the corresponding result field {@code null}.
         *
         * @return result holding the (best) model, its score, and the additional evaluation
         * @throws Exception if training fails
         */
        @Override
        public OptimizationResult<GraphConfiguration, ComputationGraph, A> call() throws Exception {
            GraphConfiguration graphConfig = candidate.getValue();
            ComputationGraph net = new ComputationGraph(graphConfig.getConfiguration());
            net.init();
            net.setListeners(new IterationListener[]{dl4jListener});

            DataSetIterator trainData = dataProvider.trainData(candidate.getDataParameters());
            EarlyStoppingConfiguration<ComputationGraph> esConfig = graphConfig.getEarlyStoppingConfiguration();

            // Stage 1: training. Either path reports Failed to the UI before propagating.
            EarlyStoppingResult<ComputationGraph> esResult = null;
            try {
                if (esConfig != null) {
                    esResult = new EarlyStoppingGraphTrainer(esConfig, net, trainData, dl4jListener).fit();
                    net = esResult.getBestModel();
                    reportEarlyStoppingOutcome(esResult);
                } else {
                    int numEpochs = graphConfig.getNumEpochs().intValue();
                    for (int epoch = 0; epoch < numEpochs; epoch++) {
                        net.fit(trainData);
                        trainData.reset();
                    }
                    dl4jListener.postReport(Status.Complete, null, NO_COMPONENTS);
                }
            } catch (Exception e) {
                reportFailure(null, "Unexpected exception during model training\n", e);
                throw e;
            }

            // Stage 2: additional evaluation. Runs for both training paths unless early stopping
            // ended in Error, and only when there is actually a model to evaluate.
            boolean trainingSucceeded = esResult == null
                    || esResult.getTerminationReason() != EarlyStoppingResult.TerminationReason.Error;
            A additionalEvaluation = null;
            if (trainingSucceeded && net != null && modelEvaluator != null) {
                try {
                    additionalEvaluation = modelEvaluator.evaluateModel(net, dataProvider);
                } catch (Exception e) {
                    reportFailure(esResult, "Failed during additional evaluation stage\n", e);
                }
            }

            // Stage 3: scoring.
            Double score = null;
            if (net == null) {
                dl4jListener.postReport(Status.Complete, esResult,
                        new ComponentText("No best model available; cannot calculate model score", (StyleText) null));
            } else {
                try {
                    score = scoreFunction.score(net, dataProvider, candidate.getDataParameters());
                } catch (Exception e) {
                    reportFailure(esResult, "Failed during score calculation stage\n", e);
                }
            }

            return new OptimizationResult<>(candidate, net, score, candidate.getIndex(), additionalEvaluation);
        }

        /** Posts Failed or Complete to the UI depending on why early stopping terminated. */
        private void reportEarlyStoppingOutcome(EarlyStoppingResult<ComputationGraph> esResult) {
            switch (esResult.getTerminationReason()) {
                case Error:
                    dl4jListener.postReport(Status.Failed, esResult, NO_COMPONENTS);
                    break;
                case IterationTerminationCondition:
                case EpochTerminationCondition:
                    dl4jListener.postReport(Status.Complete, esResult, NO_COMPONENTS);
                    break;
                default:
                    // Any other termination reason: no report is posted here.
                    break;
            }
        }

        /** Posts a Failed report carrying a short message and the exception's stack trace. */
        private void reportFailure(EarlyStoppingResult<ComputationGraph> esResult, String message, Exception e) {
            dl4jListener.postReport(Status.Failed, esResult,
                    new ComponentText(message, (StyleText) null),
                    new ComponentText(ExceptionUtils.getStackTrace(e), (StyleText) null));
        }
    }
}
