/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.arbiter.task;

import java.beans.ConstructorProperties;
import java.util.concurrent.Callable;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.deeplearning4j.arbiter.GraphConfiguration;
import org.deeplearning4j.arbiter.listener.UIGraphStatusReportingListener;
import org.deeplearning4j.arbiter.optimize.api.Candidate;
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
import org.deeplearning4j.arbiter.optimize.api.TaskCreator;
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
import org.deeplearning4j.arbiter.optimize.api.evaluation.ModelEvaluator;
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
import org.deeplearning4j.arbiter.optimize.runner.Status;
import org.deeplearning4j.arbiter.optimize.runner.listener.candidate.UICandidateStatusListener;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.earlystopping.EarlyStoppingResult;
import org.deeplearning4j.earlystopping.listener.EarlyStoppingListener;
import org.deeplearning4j.earlystopping.trainer.EarlyStoppingGraphTrainer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.ui.api.Component;
import org.deeplearning4j.ui.components.text.ComponentText;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

/**
 * {@link TaskCreator} whose tasks build, train and score a {@link ComputationGraph} for a single
 * optimization {@link Candidate}.
 *
 * @param <A> type of the optional additional evaluation produced by the {@link ModelEvaluator}
 */
public class ComputationGraphTaskCreator<A>
implements TaskCreator<GraphConfiguration, ComputationGraph, DataSetIterator, A> {
    /** Evaluator applied to each trained network; may be {@code null}, in which case no additional evaluation is produced. */
    private final ModelEvaluator<ComputationGraph, DataSetIterator, A> modelEvaluator;

    /**
     * @param modelEvaluator evaluator applied to each trained network after training; may be {@code null}
     */
    @ConstructorProperties(value={"modelEvaluator"})
    public ComputationGraphTaskCreator(ModelEvaluator<ComputationGraph, DataSetIterator, A> modelEvaluator) {
        this.modelEvaluator = modelEvaluator;
    }

    /**
     * Creates a task that, when called, builds a network from the candidate's configuration, trains it
     * (with early stopping if the configuration provides it, otherwise for a fixed number of epochs),
     * optionally runs this creator's {@link ModelEvaluator}, and scores the result.
     *
     * @param candidate      candidate whose {@link GraphConfiguration} describes the network
     * @param dataProvider   source of training data (and of data for evaluation/scoring)
     * @param scoreFunction  function used to score the trained network
     * @param statusListener receives status reports, via a {@link UIGraphStatusReportingListener} wrapper
     * @return a callable producing the {@link OptimizationResult} for this candidate
     */
    public Callable<OptimizationResult<GraphConfiguration, ComputationGraph, A>> create(Candidate<GraphConfiguration> candidate, DataProvider<DataSetIterator> dataProvider, ScoreFunction<ComputationGraph, DataSetIterator> scoreFunction, UICandidateStatusListener statusListener) {
        return new GraphLearningTask<A>(candidate, dataProvider, scoreFunction, this.modelEvaluator, statusListener);
    }

    /**
     * Trains, evaluates and scores the network for one candidate, posting status reports to the UI
     * listener along the way.
     */
    private static class GraphLearningTask<A>
    implements Callable<OptimizationResult<GraphConfiguration, ComputationGraph, A>> {
        private final Candidate<GraphConfiguration> candidate;
        private final DataProvider<DataSetIterator> dataProvider;
        private final ScoreFunction<ComputationGraph, DataSetIterator> scoreFunction;
        /** May be {@code null}: then no additional evaluation is performed. */
        private final ModelEvaluator<ComputationGraph, DataSetIterator, A> modelEvaluator;
        /** Used both as the network's iteration listener and as the early-stopping listener. */
        private final UIGraphStatusReportingListener dl4jListener;

        public GraphLearningTask(Candidate<GraphConfiguration> candidate, DataProvider<DataSetIterator> dataProvider, ScoreFunction<ComputationGraph, DataSetIterator> scoreFunction, ModelEvaluator<ComputationGraph, DataSetIterator, A> modelEvaluator, UICandidateStatusListener listener) {
            this.candidate = candidate;
            this.dataProvider = dataProvider;
            this.scoreFunction = scoreFunction;
            this.modelEvaluator = modelEvaluator;
            this.dl4jListener = new UIGraphStatusReportingListener(listener);
        }

        /**
         * Builds and trains the candidate network, then runs the optional additional evaluation and
         * the score function.
         *
         * @return the result; the network is {@code null} if early stopping kept no best model, and the
         *         score and/or additional evaluation are {@code null} if those stages failed or were skipped
         * @throws Exception if training itself throws (a Failed report is posted before rethrowing)
         */
        @Override
        public OptimizationResult<GraphConfiguration, ComputationGraph, A> call() throws Exception {
            GraphConfiguration config = this.candidate.getValue();
            ComputationGraph net = new ComputationGraph(config.getConfiguration());
            net.init();
            net.setListeners(new IterationListener[]{this.dl4jListener});
            DataSetIterator trainData = this.dataProvider.trainData(this.candidate.getDataParameters());

            EarlyStoppingConfiguration<ComputationGraph> esConfig = config.getEarlyStoppingConfiguration();
            EarlyStoppingResult<ComputationGraph> esResult = null;
            if (esConfig != null) {
                esResult = this.fitWithEarlyStopping(esConfig, net, trainData);
                // getBestModel() may return null; that case is handled before scoring below.
                net = esResult.getBestModel();
            } else {
                this.fitFixedEpochs(net, trainData, config.getNumEpochs());
            }

            // Skip evaluation only if early stopping terminated with an error or produced no model;
            // fixed-epoch training must also be evaluated.
            boolean trainingErrored = esResult != null
                    && esResult.getTerminationReason() == EarlyStoppingResult.TerminationReason.Error;
            A additionalEvaluation = null;
            if (!trainingErrored && net != null) {
                additionalEvaluation = this.evaluateAdditional(net, esResult);
            }

            Double score = null;
            if (net == null) {
                this.dl4jListener.postReport(Status.Complete, esResult, new Component[]{new ComponentText("No best model available; cannot calculate model score", null)});
            } else {
                score = this.computeScore(net, esResult);
            }
            return new OptimizationResult<GraphConfiguration, ComputationGraph, A>(this.candidate, net, score, this.candidate.getIndex(), additionalEvaluation);
        }

        /** Runs early-stopping training and posts a status matching its termination reason; reports and rethrows trainer exceptions. */
        private EarlyStoppingResult<ComputationGraph> fitWithEarlyStopping(EarlyStoppingConfiguration<ComputationGraph> esConfig, ComputationGraph net, DataSetIterator trainData) throws Exception {
            EarlyStoppingGraphTrainer trainer = new EarlyStoppingGraphTrainer(esConfig, net, trainData, this.dl4jListener);
            EarlyStoppingResult<ComputationGraph> esResult;
            try {
                esResult = trainer.fit();
            } catch (Exception e) {
                this.reportFailure("Unexpected exception during model training\n", null, e);
                throw e;
            }
            switch (esResult.getTerminationReason()) {
                case Error:
                    this.dl4jListener.postReport(Status.Failed, esResult, new Component[0]);
                    break;
                case IterationTerminationCondition:
                case EpochTerminationCondition:
                    this.dl4jListener.postReport(Status.Complete, esResult, new Component[0]);
                    break;
                default:
                    break;
            }
            return esResult;
        }

        /** Trains for exactly {@code nEpochs} passes over {@code trainData}; reports and rethrows training exceptions. */
        private void fitFixedEpochs(ComputationGraph net, DataSetIterator trainData, int nEpochs) throws Exception {
            try {
                for (int epoch = 0; epoch < nEpochs; ++epoch) {
                    // Reset only between epochs: the iterator is fresh for the first pass, and a
                    // reset after the last pass is unnecessary.
                    if (epoch > 0) {
                        trainData.reset();
                    }
                    net.fit(trainData);
                }
            } catch (Exception e) {
                this.reportFailure("Unexpected exception during model training\n", null, e);
                throw e;
            }
            this.dl4jListener.postReport(Status.Complete, null, new Component[0]);
        }

        /** Runs the optional model evaluator; on failure posts a Failed report and returns {@code null}. */
        private A evaluateAdditional(ComputationGraph net, EarlyStoppingResult<ComputationGraph> esResult) {
            if (this.modelEvaluator == null) {
                return null;
            }
            try {
                return this.modelEvaluator.evaluateModel(net, this.dataProvider);
            } catch (Exception e) {
                this.reportFailure("Failed during additional evaluation stage\n", esResult, e);
                return null;
            }
        }

        /** Scores {@code net}; on failure posts a Failed report and returns {@code null}. */
        private Double computeScore(ComputationGraph net, EarlyStoppingResult<ComputationGraph> esResult) {
            try {
                return this.scoreFunction.score(net, this.dataProvider, this.candidate.getDataParameters());
            } catch (Exception e) {
                this.reportFailure("Failed during score calculation stage\n", esResult, e);
                return null;
            }
        }

        /** Posts a Failed status containing {@code message} followed by the stack trace of {@code e}. */
        private void reportFailure(String message, EarlyStoppingResult<ComputationGraph> esResult, Exception e) {
            this.dl4jListener.postReport(Status.Failed, esResult, new Component[]{new ComponentText(message, null), new ComponentText(ExceptionUtils.getStackTrace(e), null)});
        }
    }
}

