/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.arbiter.layers;

import java.util.List;
import org.deeplearning4j.arbiter.layers.FeedForwardLayerSpace;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
import org.deeplearning4j.nn.conf.layers.BaseOutputLayer;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.nd4j.linalg.lossfunctions.LossFunctions;

/**
 * Abstract hyperparameter space for output layers.
 *
 * <p>Extends {@link FeedForwardLayerSpace} with an optional loss-function parameter space. When no
 * loss-function space is configured ({@code lossFunction == null}), the layer builder's loss
 * function is left untouched and the space contributes no extra leaves.
 *
 * @param <L> the {@link BaseOutputLayer} subtype this space describes
 */
public abstract class BaseOutputLayerSpace<L extends BaseOutputLayer>
        extends FeedForwardLayerSpace<L> {

    /** Parameter space for the loss function; {@code null} means "not configured". */
    protected ParameterSpace<LossFunctions.LossFunction> lossFunction;

    /**
     * Creates the space from a builder, copying the builder's loss-function space (possibly
     * {@code null}).
     *
     * @param builder builder holding the configured hyperparameter spaces; a raw or any concrete
     *     builder subtype is accepted
     */
    protected BaseOutputLayerSpace(Builder<?> builder) {
        super(builder);
        // With Builder<?> the field keeps its declared generic type, so no unchecked conversion.
        this.lossFunction = builder.lossFunction;
    }

    /**
     * Applies the superclass's feed-forward options to {@code builder}. If a loss-function space is
     * configured, it then sets the builder's loss function to the value that space yields for
     * {@code values}.
     *
     * <p>NOTE(review): there is no {@code @Override} here. The super call resolves to a variant
     * taking a {@code FeedForwardLayer.Builder}, so this appears to be an overload. It is only
     * selected when the argument's static type is a {@code BaseOutputLayer.Builder}.
     *
     * @param builder the output-layer builder to configure
     * @param values the hyperparameter vector used to pick concrete values from each space
     */
    protected void setLayerOptionsBuilder(BaseOutputLayer.Builder builder, double[] values) {
        super.setLayerOptionsBuilder(builder, values);
        if (lossFunction != null) {
            builder.lossFunction(lossFunction.getValue(values));
        }
    }

    /**
     * Returns the leaf parameter spaces collected by the superclass. When a loss-function space is
     * configured, its leaves are appended.
     *
     * <p>NOTE(review): this appends to the list returned by {@code super.collectLeaves()}. That
     * assumes the superclass returns a fresh, mutable list — TODO confirm.
     *
     * @return all leaf parameter spaces of this layer space
     */
    @Override
    public List<ParameterSpace> collectLeaves() {
        List<ParameterSpace> leaves = super.collectLeaves();
        if (lossFunction != null) {
            leaves.addAll(lossFunction.collectLeaves());
        }
        return leaves;
    }

    /**
     * Builder for {@link BaseOutputLayerSpace} subclasses, adding loss-function configuration to the
     * feed-forward builder.
     *
     * @param <T> the type returned from setters for chaining; expected to be the concrete builder
     *     subclass itself (not enforced by the bound)
     */
    public static abstract class Builder<T>
            extends FeedForwardLayerSpace.Builder<T> {

        /** Loss-function space to configure; stays {@code null} unless a setter is called. */
        protected ParameterSpace<LossFunctions.LossFunction> lossFunction;

        /**
         * Uses a single fixed loss function by wrapping it in a {@link FixedValue} space.
         *
         * @param lossFunction the loss function to use
         * @return this builder, for chaining
         */
        public T lossFunction(LossFunctions.LossFunction lossFunction) {
            return lossFunction(new FixedValue<>(lossFunction));
        }

        /**
         * Sets the parameter space from which the loss function is chosen.
         *
         * @param lossFunction the loss-function space; {@code null} clears it
         * @return this builder, for chaining
         */
        public T lossFunction(ParameterSpace<LossFunctions.LossFunction> lossFunction) {
            this.lossFunction = lossFunction;
            // Self-type idiom: assumes T is the concrete builder subclass itself.
            @SuppressWarnings("unchecked")
            T self = (T) this;
            return self;
        }
    }
}

