/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.arbiter;

import java.beans.ConstructorProperties;
import java.util.ArrayList;
import java.util.List;
import org.deeplearning4j.arbiter.BaseNetworkSpace;
import org.deeplearning4j.arbiter.DL4JConfiguration;
import org.deeplearning4j.arbiter.layers.LayerSpace;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
import org.deeplearning4j.arbiter.util.CollectionUtils;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

/**
 * Hyperparameter search space over DL4J {@link MultiLayerConfiguration}s.
 * <p>
 * The space consists of an ordered list of layer spaces, each of which contributes a sampled number
 * of identical layers, plus the global network options handled by {@link BaseNetworkSpace}.
 * {@link #getValue(double[])} maps a candidate vector to a concrete {@link DL4JConfiguration}.
 */
public class MultiLayerSpace
extends BaseNetworkSpace<DL4JConfiguration> {
    /** Legacy CNN input size as {height, width, depth}; applied via the deprecated ListBuilder option. */
    @Deprecated
    private ParameterSpace<int[]> cnnInputSize;
    /** Layer spaces in network order. This is a private copy of the builder's list. */
    private List<LayerConf> layerSpaces;
    private ParameterSpace<InputType> inputType;
    /** Passed through unchanged to every generated {@link DL4JConfiguration}; may be null. */
    private EarlyStoppingConfiguration<MultiLayerNetwork> earlyStoppingConfiguration;
    /** Sum of numParameters() over the de-duplicated leaves of this space, computed once at construction. */
    private int numParameters;

    private MultiLayerSpace(Builder builder) {
        super(builder);
        this.cnnInputSize = builder.cnnInputSize;
        this.inputType = builder.inputType;
        this.earlyStoppingConfiguration = builder.earlyStoppingConfiguration;
        // Defensive copy: without it, adding layers to the same Builder after build() would mutate
        // this already-built space and leave the cached numParameters below stale.
        this.layerSpaces = new ArrayList<LayerConf>(builder.layerSpaces);
        // Leaves are passed through CollectionUtils.getUnique before summing, presumably so that a
        // ParameterSpace instance reachable from several places is only counted once.
        List<ParameterSpace> uniqueLeaves = CollectionUtils.getUnique(this.collectLeaves());
        int total = 0;
        for (ParameterSpace ps : uniqueLeaves) {
            total += ps.numParameters();
        }
        this.numParameters = total;
    }

    /**
     * Builds a concrete network configuration from a candidate vector.
     * <p>
     * For every {@link FeedForwardLayer} after the first one, nIn is set to the nOut of the closest
     * preceding {@link FeedForwardLayer}. Layers of other types are left untouched and carry the
     * previous nOut through.
     *
     * @param values candidate values, consumed by the leaf parameter spaces of this space
     * @return the generated configuration, bundled with this space's early stopping configuration
     *         (possibly null) and number of epochs
     * @throws UnsupportedOperationException if a layer space was added with
     *         {@code duplicateConfig == false}, which is not implemented yet
     */
    public DL4JConfiguration getValue(double[] values) {
        List<Layer> layers = this.generateLayers(values);
        NeuralNetConfiguration.Builder builder = this.randomGlobalConf(values);
        propagateNIn(layers);

        NeuralNetConfiguration.ListBuilder listBuilder = builder.list();
        for (int i = 0; i < layers.size(); ++i) {
            listBuilder.layer(i, layers.get(i));
        }
        if (this.backprop != null) {
            listBuilder.backprop(((Boolean)this.backprop.getValue(values)).booleanValue());
        }
        if (this.pretrain != null) {
            listBuilder.pretrain(((Boolean)this.pretrain.getValue(values)).booleanValue());
        }
        if (this.backpropType != null) {
            listBuilder.backpropType((BackpropType)this.backpropType.getValue(values));
        }
        if (this.tbpttFwdLength != null) {
            listBuilder.tBPTTForwardLength(((Integer)this.tbpttFwdLength.getValue(values)).intValue());
        }
        if (this.tbpttBwdLength != null) {
            listBuilder.tBPTTBackwardLength(((Integer)this.tbpttBwdLength.getValue(values)).intValue());
        }
        if (this.cnnInputSize != null) {
            listBuilder.cnnInputSize(this.cnnInputSize.getValue(values));
        }
        if (this.inputType != null) {
            listBuilder.setInputType(this.inputType.getValue(values));
        }
        MultiLayerConfiguration configuration = listBuilder.build();
        return new DL4JConfiguration(configuration, this.earlyStoppingConfiguration, this.numEpochs);
    }

    /**
     * Samples the concrete layers for every layer space, in network order.
     * Each space samples one layer and adds {@code numLayers} clones of it.
     */
    private List<Layer> generateLayers(double[] values) {
        List<Layer> layers = new ArrayList<Layer>();
        for (LayerConf c : this.layerSpaces) {
            int n = c.numLayers.getValue(values);
            if (!c.duplicateConfig) {
                throw new UnsupportedOperationException("Not yet implemented");
            }
            Layer template = (Layer)c.layerSpace.getValue(values);
            for (int i = 0; i < n; ++i) {
                layers.add(template.clone());
            }
        }
        return layers;
    }

    /**
     * Sets nIn of each {@link FeedForwardLayer} to the nOut of the closest preceding
     * {@link FeedForwardLayer}. The first such layer keeps its own nIn.
     * <p>
     * Other layer types are skipped rather than cast, so the previous nOut is carried past them.
     * The assumption is that such layers (e.g. pooling) preserve the feature count; confirm for
     * new layer types. An empty list is a no-op.
     */
    private static void propagateNIn(List<Layer> layers) {
        boolean havePreviousNOut = false;
        int lastNOut = 0;
        for (Layer layer : layers) {
            if (!(layer instanceof FeedForwardLayer)) {
                continue;
            }
            FeedForwardLayer ffl = (FeedForwardLayer)layer;
            if (havePreviousNOut) {
                ffl.setNIn(lastNOut);
            }
            lastNOut = ffl.getNOut();
            havePreviousNOut = true;
        }
    }

    /** @return the number of parameters of this space, as computed at construction time */
    public int numParameters() {
        return this.numParameters;
    }

    /**
     * Collects the leaves of the global options (from the superclass), then those of every layer
     * space (numLayers first, then the layer space itself), then the CNN input size and input type
     * when set. The result may contain duplicates.
     */
    @Override
    public List<ParameterSpace> collectLeaves() {
        List<ParameterSpace> list = super.collectLeaves();
        for (LayerConf lc : this.layerSpaces) {
            list.addAll(lc.numLayers.collectLeaves());
            list.addAll(lc.layerSpace.collectLeaves());
        }
        if (this.cnnInputSize != null) {
            list.addAll(this.cnnInputSize.collectLeaves());
        }
        if (this.inputType != null) {
            list.addAll(this.inputType.collectLeaves());
        }
        return list;
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder(super.toString());
        int i = 0;
        for (LayerConf conf : this.layerSpaces) {
            sb.append("Layer config ").append(i++).append(": (Number layers:").append(conf.numLayers).append(", duplicate: ").append(conf.duplicateConfig).append("), ").append(conf.layerSpace.toString()).append("\n");
        }
        if (this.cnnInputSize != null) {
            sb.append("cnnInputSize: ").append(this.cnnInputSize).append("\n");
        }
        if (this.inputType != null) {
            sb.append("inputType: ").append(this.inputType).append("\n");
        }
        if (this.earlyStoppingConfiguration != null) {
            sb.append("Early stopping configuration:").append(this.earlyStoppingConfiguration.toString()).append("\n");
        } else {
            sb.append("Training # epochs:").append(this.numEpochs).append("\n");
        }
        return sb.toString();
    }

    /** Fluent builder for {@link MultiLayerSpace}. */
    public static class Builder
    extends BaseNetworkSpace.Builder<Builder> {
        @Deprecated
        private ParameterSpace<int[]> cnnInputSize;
        private List<LayerConf> layerSpaces = new ArrayList<LayerConf>();
        private ParameterSpace<InputType> inputType;
        private EarlyStoppingConfiguration<MultiLayerNetwork> earlyStoppingConfiguration;

        /** Sets a fixed legacy CNN input size. */
        @Deprecated
        public Builder cnnInputSize(int height, int width, int depth) {
            return this.cnnInputSize(new FixedValue<int[]>(new int[]{height, width, depth}));
        }

        @Deprecated
        public Builder cnnInputSize(ParameterSpace<int[]> cnnInputSize) {
            this.cnnInputSize = cnnInputSize;
            return this;
        }

        /** Sets a fixed input type, passed to {@code ListBuilder.setInputType} for every candidate. */
        public Builder setInputType(InputType inputType) {
            return this.setInputType(new FixedValue<InputType>(inputType));
        }

        public Builder setInputType(ParameterSpace<InputType> inputType) {
            this.inputType = inputType;
            return this;
        }

        /** Appends a layer space that contributes exactly one layer. */
        public Builder addLayer(LayerSpace<?> layerSpace) {
            return this.addLayer(layerSpace, new FixedValue<Integer>(1), true);
        }

        /**
         * Appends a layer space.
         *
         * @param layerSpace            space the layer configuration is sampled from
         * @param numLayersDistribution space for how many layers this entry contributes
         * @param duplicateConfig       if true, all contributed layers are clones of one sampled layer;
         *                              {@code false} is currently rejected by {@code getValue}
         */
        public Builder addLayer(LayerSpace<? extends Layer> layerSpace, ParameterSpace<Integer> numLayersDistribution, boolean duplicateConfig) {
            this.layerSpaces.add(new LayerConf(layerSpace, numLayersDistribution, duplicateConfig));
            return this;
        }

        public Builder earlyStoppingConfiguration(EarlyStoppingConfiguration<MultiLayerNetwork> earlyStoppingConfiguration) {
            this.earlyStoppingConfiguration = earlyStoppingConfiguration;
            return this;
        }

        /** Builds the space. The layer list is copied, so the builder may be reused safely afterwards. */
        @Override
        public MultiLayerSpace build() {
            return new MultiLayerSpace(this);
        }
    }

    /** One entry of the layer list: a layer space plus how many layers it contributes. */
    private static class LayerConf {
        private final LayerSpace<?> layerSpace;
        private final ParameterSpace<Integer> numLayers;
        private final boolean duplicateConfig;

        @ConstructorProperties(value={"layerSpace", "numLayers", "duplicateConfig"})
        public LayerConf(LayerSpace<?> layerSpace, ParameterSpace<Integer> numLayers, boolean duplicateConfig) {
            this.layerSpace = layerSpace;
            this.numLayers = numLayers;
            this.duplicateConfig = duplicateConfig;
        }
    }
}

