package org.arbiter.deeplearning4j.layers;

import java.util.List;
import org.arbiter.deeplearning4j.layers.FeedForwardLayerSpace;
import org.arbiter.optimize.api.ParameterSpace;
import org.arbiter.optimize.parameter.FixedValue;
import org.deeplearning4j.nn.conf.layers.BaseOutputLayer;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.nd4j.linalg.lossfunctions.LossFunctions;

/* loaded from: input_file:org/arbiter/deeplearning4j/layers/BaseOutputLayerSpace.class */
/**
 * Base hyperparameter space for output-layer configurations ({@link BaseOutputLayer} subclasses).
 *
 * <p>Adds an optional loss-function parameter space on top of {@link FeedForwardLayerSpace}. When
 * no loss function has been configured the field is {@code null}; in that case it is neither
 * applied to the layer builder nor reported by {@link #collectLeaves()}.
 *
 * @param <L> the concrete output-layer configuration type described by this space
 */
public abstract class BaseOutputLayerSpace<L extends BaseOutputLayer> extends FeedForwardLayerSpace<L> {
    /** Search space for the loss function, or {@code null} if none was configured. */
    protected ParameterSpace<LossFunctions.LossFunction> lossFunction;

    /**
     * Fluent builder for output-layer spaces.
     *
     * @param <T> the builder type returned by the fluent setters; concrete subclasses are expected
     *     to bind this to their own builder type (NOTE(review): confirm in each subclass)
     */
    public static abstract class Builder<T> extends FeedForwardLayerSpace.Builder<T> {
        /** Search space for the loss function, or {@code null} if never set. */
        protected ParameterSpace<LossFunctions.LossFunction> lossFunction;

        /**
         * Uses a single, fixed loss function (no search over this hyperparameter).
         *
         * @param lossFunction the loss function to use
         * @return this builder
         */
        public T lossFunction(LossFunctions.LossFunction lossFunction) {
            return lossFunction(new FixedValue<LossFunctions.LossFunction>(lossFunction));
        }

        /**
         * Sets the search space from which the loss function is drawn.
         *
         * @param parameterSpace the loss-function parameter space
         * @return this builder
         */
        @SuppressWarnings("unchecked")
        public T lossFunction(ParameterSpace<LossFunctions.LossFunction> parameterSpace) {
            this.lossFunction = parameterSpace;
            // Self-typed builder: `this` has static type Builder<T>, not T, so a cast is
            // required. It is safe as long as subclasses bind T to their own builder type.
            return (T) this;
        }
    }

    /**
     * Creates the space from a builder, copying the configured loss-function space (may be null).
     *
     * @param builder the builder holding the configuration
     */
    public BaseOutputLayerSpace(Builder<?> builder) {
        super(builder);
        this.lossFunction = builder.lossFunction;
    }

    /**
     * Applies the feed-forward options and, if configured, the loss function sampled from
     * {@code dArr} to the given layer builder.
     *
     * @param builder the DL4J output-layer builder to configure
     * @param dArr the candidate values used to resolve each parameter space
     */
    public void setLayerOptionsBuilder(BaseOutputLayer.Builder builder, double[] dArr) {
        // Delegate the shared feed-forward options to the superclass overload.
        super.setLayerOptionsBuilder((FeedForwardLayer.Builder) builder, dArr);
        if (this.lossFunction != null) {
            builder.lossFunction((LossFunctions.LossFunction) this.lossFunction.getValue(dArr));
        }
    }

    /**
     * Returns the superclass leaves plus, if configured, the loss-function space's leaves.
     *
     * @return the list of leaf parameter spaces (the superclass's list, appended to in place)
     */
    @Override
    public List<ParameterSpace> collectLeaves() {
        List<ParameterSpace> leaves = super.collectLeaves();
        if (this.lossFunction != null) {
            leaves.addAll(this.lossFunction.collectLeaves());
        }
        return leaves;
    }
}
