/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.arbiter.task;

import java.beans.ConstructorProperties;
import java.util.concurrent.Callable;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.deeplearning4j.arbiter.DL4JConfiguration;
import org.deeplearning4j.arbiter.listener.BaseUIStatusReportingListener;
import org.deeplearning4j.arbiter.listener.UIStatusReportingListener;
import org.deeplearning4j.arbiter.optimize.api.Candidate;
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
import org.deeplearning4j.arbiter.optimize.api.TaskCreator;
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
import org.deeplearning4j.arbiter.optimize.api.evaluation.ModelEvaluator;
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
import org.deeplearning4j.arbiter.optimize.runner.Status;
import org.deeplearning4j.arbiter.optimize.runner.listener.candidate.UICandidateStatusListener;
import org.deeplearning4j.arbiter.optimize.ui.ArbiterUIServer;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.earlystopping.EarlyStoppingResult;
import org.deeplearning4j.earlystopping.trainer.EarlyStoppingTrainer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.ui.api.Component;
import org.deeplearning4j.ui.components.text.ComponentText;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

/**
 * {@link TaskCreator} that builds, trains and scores a {@link MultiLayerNetwork} for each
 * hyperparameter {@link Candidate}.
 *
 * @param <A> type of the optional additional evaluation produced by the {@link ModelEvaluator}
 */
public class MultiLayerNetworkTaskCreator<A>
implements TaskCreator<DL4JConfiguration, MultiLayerNetwork, DataSetIterator, A> {
    /** Optional evaluator applied to the trained network; {@code null} disables that stage. */
    private ModelEvaluator<MultiLayerNetwork, DataSetIterator, A> modelEvaluator;

    /**
     * Creates a task that, when called, trains and scores a network for {@code candidate}.
     *
     * @param candidate      candidate holding the network configuration, an optional
     *                       early-stopping configuration and the number of epochs
     * @param dataProvider   source of training data; also passed to evaluation and scoring
     * @param scoreFunction  function used to score the trained network
     * @param statusListener wrapped in a {@link UIStatusReportingListener} only when the
     *                       {@link ArbiterUIServer} is running at task-creation time;
     *                       otherwise unused
     * @return a task producing the {@link OptimizationResult} for this candidate
     */
    public Callable<OptimizationResult<DL4JConfiguration, MultiLayerNetwork, A>> create(Candidate<DL4JConfiguration> candidate, DataProvider<DataSetIterator> dataProvider, ScoreFunction<MultiLayerNetwork, DataSetIterator> scoreFunction, UICandidateStatusListener statusListener) {
        return new DL4JLearningTask<A>(candidate, dataProvider, scoreFunction, this.modelEvaluator, statusListener);
    }

    /**
     * @param modelEvaluator evaluator run on each trained network; may be {@code null}
     */
    @ConstructorProperties(value={"modelEvaluator"})
    public MultiLayerNetworkTaskCreator(ModelEvaluator<MultiLayerNetwork, DataSetIterator, A> modelEvaluator) {
        this.modelEvaluator = modelEvaluator;
    }

    /** Creates a task creator with no additional model evaluator. */
    public MultiLayerNetworkTaskCreator() {
    }

    /**
     * Trains one candidate network, optionally runs the additional evaluator, and scores it.
     * Progress and failures are posted to the Arbiter UI when a UI listener is present.
     */
    private static class DL4JLearningTask<A>
    implements Callable<OptimizationResult<DL4JConfiguration, MultiLayerNetwork, A>> {
        private final Candidate<DL4JConfiguration> candidate;
        private final DataProvider<DataSetIterator> dataProvider;
        private final ScoreFunction<MultiLayerNetwork, DataSetIterator> scoreFunction;
        private final ModelEvaluator<MultiLayerNetwork, DataSetIterator, A> modelEvaluator;
        /** UI reporting listener; {@code null} when the Arbiter UI server was not running. */
        private final BaseUIStatusReportingListener<MultiLayerNetwork> dl4jListener;

        public DL4JLearningTask(Candidate<DL4JConfiguration> candidate, DataProvider<DataSetIterator> dataProvider, ScoreFunction<MultiLayerNetwork, DataSetIterator> scoreFunction, ModelEvaluator<MultiLayerNetwork, DataSetIterator, A> modelEvaluator, UICandidateStatusListener listener) {
            this.candidate = candidate;
            this.dataProvider = dataProvider;
            this.scoreFunction = scoreFunction;
            this.modelEvaluator = modelEvaluator;
            this.dl4jListener = ArbiterUIServer.isRunning() ? new UIStatusReportingListener(listener) : null;
        }

        /**
         * Trains, evaluates and scores the candidate network.
         *
         * <p>Training uses early stopping when the candidate supplies an
         * {@link EarlyStoppingConfiguration}. Otherwise the network is fitted for
         * {@code getNumEpochs()} passes over the training data, resetting the iterator after
         * each pass. Failures during the evaluation or scoring stages do not abort the task.
         * They are reported to the UI (if present) and the corresponding result field is left
         * {@code null}.
         *
         * @return the optimization result. Its model is {@code null} if early stopping yielded
         *         no best model. Its score is {@code null} if scoring was skipped or failed.
         * @throws Exception exceptions from network construction or training propagate;
         *         early-stopping training failures are reported to the UI before being rethrown
         */
        @Override
        public OptimizationResult<DL4JConfiguration, MultiLayerNetwork, A> call() throws Exception {
            DL4JConfiguration config = this.candidate.getValue();
            MultiLayerNetwork net = new MultiLayerNetwork(config.getMultiLayerConfiguration());
            net.init();
            if (this.dl4jListener != null) {
                // Register only a real listener. A null element in the listener array would
                // presumably cause an NPE when the network notifies its listeners; confirm
                // against MultiLayerNetwork.setListeners.
                net.setListeners(new IterationListener[]{this.dl4jListener});
            }

            DataSetIterator trainData = this.dataProvider.trainData(this.candidate.getDataParameters());
            EarlyStoppingConfiguration esConfig = config.getEarlyStoppingConfiguration();
            EarlyStoppingResult<MultiLayerNetwork> esResult = null;
            if (esConfig != null) {
                EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConfig, net, trainData, this.dl4jListener);
                try {
                    esResult = trainer.fit();
                    // The trained model returned is the best one found, which may be null.
                    net = esResult.getBestModel();
                }
                catch (Exception e) {
                    report(Status.Failed, null, failureText("Unexpected exception during model training\n", e));
                    throw e;
                }
                reportTermination(esResult);
            } else {
                int nEpochs = config.getNumEpochs();
                for (int i = 0; i < nEpochs; ++i) {
                    net.fit(trainData);
                    trainData.reset();
                }
                report(Status.Complete, null);
            }

            // The additional evaluator applies to both training paths. It is skipped only when
            // early stopping ended in error or there is no model to evaluate.
            boolean trainingFailed = esResult != null
                    && esResult.getTerminationReason() == EarlyStoppingResult.TerminationReason.Error;
            A additionalEvaluation = null;
            if (net != null && !trainingFailed) {
                additionalEvaluation = evaluate(net, esResult);
            }

            Double score = computeScore(net, esResult);
            return new OptimizationResult<DL4JConfiguration, MultiLayerNetwork, A>(
                    this.candidate, net, score, this.candidate.getIndex(), additionalEvaluation);
        }

        /** Posts a UI status report if a UI listener is present; otherwise does nothing. */
        private void report(Status status, EarlyStoppingResult<MultiLayerNetwork> esResult, Component... components) {
            if (this.dl4jListener != null) {
                this.dl4jListener.postReport(status, esResult, components);
            }
        }

        /** Builds the two text components (message plus stack trace) used in failure reports. */
        private static Component[] failureText(String message, Exception e) {
            return new Component[]{
                    new ComponentText(message, null),
                    new ComponentText(ExceptionUtils.getStackTrace(e), null)};
        }

        /** Maps the early-stopping termination reason to a UI status report. */
        private void reportTermination(EarlyStoppingResult<MultiLayerNetwork> esResult) {
            switch (esResult.getTerminationReason()) {
                case Error:
                    report(Status.Failed, esResult);
                    break;
                case IterationTerminationCondition:
                case EpochTerminationCondition:
                    report(Status.Complete, esResult);
                    break;
                default:
                    // Other termination reasons (if any) are not reported.
                    break;
            }
        }

        /**
         * Runs the optional model evaluator. Returns {@code null} if there is no evaluator or if
         * it throws; a thrown exception is reported to the UI rather than propagated.
         */
        private A evaluate(MultiLayerNetwork net, EarlyStoppingResult<MultiLayerNetwork> esResult) {
            if (this.modelEvaluator == null) {
                return null;
            }
            try {
                return this.modelEvaluator.evaluateModel(net, this.dataProvider);
            }
            catch (Exception e) {
                report(Status.Failed, esResult, failureText("Failed during additional evaluation stage\n", e));
                return null;
            }
        }

        /**
         * Scores the network with the configured score function. Returns {@code null} when
         * there is no model, or when scoring throws; the exception is reported to the UI, not
         * propagated.
         */
        private Double computeScore(MultiLayerNetwork net, EarlyStoppingResult<MultiLayerNetwork> esResult) {
            if (net == null) {
                report(Status.Complete, esResult, new ComponentText("No best model available; cannot calculate model score", null));
                return null;
            }
            try {
                return this.scoreFunction.score(net, this.dataProvider, this.candidate.getDataParameters());
            }
            catch (Exception e) {
                report(Status.Failed, esResult, failureText("Failed during score calculation stage\n", e));
                return null;
            }
        }
    }
}

