package org.arbiter.deeplearning4j.layers;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.arbiter.optimize.api.ParameterSpace;
import org.arbiter.optimize.parameter.FixedValue;
import org.arbiter.util.CollectionUtils;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.weights.WeightInit;

/* loaded from: input_file:org/arbiter/deeplearning4j/layers/LayerSpace.class */
/**
 * Base class describing the hyperparameter search space shared by all layer types.
 *
 * <p>Each configurable option is held as a {@link ParameterSpace}; an option left {@code null}
 * is simply skipped when leaves are collected, when a layer builder is configured and when the
 * space is rendered as text. This space is never a leaf itself (see {@link #isLeaf()}).
 *
 * @param <L> the concrete {@link Layer} type produced by this space
 */
public abstract class LayerSpace<L extends Layer> implements ParameterSpace<L> {
    protected ParameterSpace<String> activationFunction;
    protected ParameterSpace<WeightInit> weightInit;
    protected ParameterSpace<Double> biasInit;
    protected ParameterSpace<Distribution> dist;
    protected ParameterSpace<Double> learningRate;
    protected ParameterSpace<Map<Integer, Double>> learningRateAfter;
    protected ParameterSpace<Double> lrScoreBasedDecay;
    protected ParameterSpace<Double> l1;
    protected ParameterSpace<Double> l2;
    protected ParameterSpace<Double> dropOut;
    protected ParameterSpace<Double> momentum;
    protected ParameterSpace<Map<Integer, Double>> momentumAfter;
    protected ParameterSpace<Updater> updater;
    protected ParameterSpace<Double> rho;
    protected ParameterSpace<Double> rmsDecay;
    protected ParameterSpace<GradientNormalization> gradientNormalization;
    protected ParameterSpace<Double> gradientNormalizationThreshold;

    /**
     * Fluent builder for the options common to every {@link LayerSpace}.
     *
     * <p>Every option has two setters: one taking a plain value, which is wrapped in a
     * {@link FixedValue}, and one taking an arbitrary {@link ParameterSpace}.
     *
     * @param <T> the concrete builder type returned by the fluent setters; each setter returns
     *            {@code this} cast to {@code T}, so {@code T} must be the builder's own type
     */
    public static abstract class Builder<T> {
        protected ParameterSpace<String> activationFunction;
        protected ParameterSpace<WeightInit> weightInit;
        protected ParameterSpace<Double> biasInit;
        protected ParameterSpace<Distribution> dist;
        protected ParameterSpace<Double> learningRate;
        protected ParameterSpace<Map<Integer, Double>> learningRateAfter;
        protected ParameterSpace<Double> lrScoreBasedDecay;
        protected ParameterSpace<Double> l1;
        protected ParameterSpace<Double> l2;
        protected ParameterSpace<Double> dropOut;
        protected ParameterSpace<Double> momentum;
        protected ParameterSpace<Map<Integer, Double>> momentumAfter;
        protected ParameterSpace<Updater> updater;
        protected ParameterSpace<Double> rho;
        protected ParameterSpace<Double> rmsDecay;
        protected ParameterSpace<GradientNormalization> gradientNormalization;
        protected ParameterSpace<Double> gradientNormalizationThreshold;

        /**
         * Returns this builder typed as {@code T} so the fluent setters can chain.
         * The single unchecked cast for the whole builder lives here.
         */
        @SuppressWarnings("unchecked")
        private T self() {
            return (T) this;
        }

        /** Sets a fixed activation function. */
        public T activation(String str) {
            return activation(new FixedValue<>(str));
        }

        /** Sets the search space for the activation function. */
        public T activation(ParameterSpace<String> parameterSpace) {
            this.activationFunction = parameterSpace;
            return self();
        }

        /** Sets a fixed weight initialization scheme. */
        public T weightInit(WeightInit weightInit) {
            return weightInit(new FixedValue<>(weightInit));
        }

        /** Sets the search space for the weight initialization scheme. */
        public T weightInit(ParameterSpace<WeightInit> parameterSpace) {
            this.weightInit = parameterSpace;
            return self();
        }

        /** Sets a fixed bias initialization value. */
        public T biasInit(double d) {
            return biasInit(new FixedValue<>(Double.valueOf(d)));
        }

        /** Sets the search space for the bias initialization value. */
        public T biasInit(ParameterSpace<Double> parameterSpace) {
            this.biasInit = parameterSpace;
            return self();
        }

        /** Sets a fixed distribution. */
        public T dist(Distribution distribution) {
            return dist(new FixedValue<>(distribution));
        }

        /** Sets the search space for the distribution. */
        public T dist(ParameterSpace<Distribution> parameterSpace) {
            this.dist = parameterSpace;
            return self();
        }

        /** Sets a fixed learning rate. */
        public T learningRate(double d) {
            return learningRate(new FixedValue<>(Double.valueOf(d)));
        }

        /** Sets the search space for the learning rate. */
        public T learningRate(ParameterSpace<Double> parameterSpace) {
            this.learningRate = parameterSpace;
            return self();
        }

        /** Sets a fixed learning rate schedule map. */
        public T learningRateAfter(Map<Integer, Double> map) {
            return learningRateAfter(new FixedValue<>(map));
        }

        /** Sets the search space for the learning rate schedule map. */
        public T learningRateAfter(ParameterSpace<Map<Integer, Double>> parameterSpace) {
            this.learningRateAfter = parameterSpace;
            return self();
        }

        /** Sets a fixed score-based learning rate decay rate. */
        public T learningRateScoreBasedDecayRate(double d) {
            return learningRateScoreBasedDecayRate(new FixedValue<>(Double.valueOf(d)));
        }

        /** Sets the search space for the score-based learning rate decay rate. */
        public T learningRateScoreBasedDecayRate(ParameterSpace<Double> parameterSpace) {
            this.lrScoreBasedDecay = parameterSpace;
            return self();
        }

        /** Sets a fixed L1 coefficient. */
        public T l1(double d) {
            return l1(new FixedValue<>(Double.valueOf(d)));
        }

        /** Sets the search space for the L1 coefficient. */
        public T l1(ParameterSpace<Double> parameterSpace) {
            this.l1 = parameterSpace;
            return self();
        }

        /** Sets a fixed L2 coefficient. */
        public T l2(double d) {
            return l2(new FixedValue<>(Double.valueOf(d)));
        }

        /** Sets the search space for the L2 coefficient. */
        public T l2(ParameterSpace<Double> parameterSpace) {
            this.l2 = parameterSpace;
            return self();
        }

        /** Sets a fixed dropout value. */
        public T dropOut(double d) {
            return dropOut(new FixedValue<>(Double.valueOf(d)));
        }

        /** Sets the search space for the dropout value. */
        public T dropOut(ParameterSpace<Double> parameterSpace) {
            this.dropOut = parameterSpace;
            return self();
        }

        /** Sets a fixed momentum. */
        public T momentum(double d) {
            return momentum(new FixedValue<>(Double.valueOf(d)));
        }

        /** Sets the search space for the momentum. */
        public T momentum(ParameterSpace<Double> parameterSpace) {
            this.momentum = parameterSpace;
            return self();
        }

        /** Sets a fixed momentum schedule map. */
        public T momentumAfter(Map<Integer, Double> map) {
            return momentumAfter(new FixedValue<>(map));
        }

        /** Sets the search space for the momentum schedule map. */
        public T momentumAfter(ParameterSpace<Map<Integer, Double>> parameterSpace) {
            this.momentumAfter = parameterSpace;
            return self();
        }

        /** Sets a fixed updater. */
        public T updater(Updater updater) {
            return updater(new FixedValue<>(updater));
        }

        /** Sets the search space for the updater. */
        public T updater(ParameterSpace<Updater> parameterSpace) {
            this.updater = parameterSpace;
            return self();
        }

        /** Sets a fixed rho value. */
        public T rho(double d) {
            return rho(new FixedValue<>(Double.valueOf(d)));
        }

        /** Sets the search space for the rho value. */
        public T rho(ParameterSpace<Double> parameterSpace) {
            this.rho = parameterSpace;
            return self();
        }

        /** Sets a fixed RMS decay value. */
        public T rmsDecay(double d) {
            return rmsDecay(new FixedValue<>(Double.valueOf(d)));
        }

        /** Sets the search space for the RMS decay value. */
        public T rmsDecay(ParameterSpace<Double> parameterSpace) {
            this.rmsDecay = parameterSpace;
            return self();
        }

        /** Sets a fixed gradient normalization strategy. */
        public T gradientNormalization(GradientNormalization gradientNormalization) {
            return gradientNormalization(new FixedValue<>(gradientNormalization));
        }

        /** Sets the search space for the gradient normalization strategy. */
        public T gradientNormalization(ParameterSpace<GradientNormalization> parameterSpace) {
            this.gradientNormalization = parameterSpace;
            return self();
        }

        /** Sets a fixed gradient normalization threshold. */
        public T gradientNormalizationThreshold(double d) {
            return gradientNormalizationThreshold(new FixedValue<>(Double.valueOf(d)));
        }

        /** Sets the search space for the gradient normalization threshold. */
        public T gradientNormalizationThreshold(ParameterSpace<Double> parameterSpace) {
            this.gradientNormalizationThreshold = parameterSpace;
            return self();
        }

        /** Builds the layer space described by this builder. */
        public abstract <E extends LayerSpace> E build();
    }

    /**
     * Copies every option from the builder into this space. Options never set on the builder
     * stay {@code null}.
     *
     * @param builder the builder holding the configured parameter spaces
     */
    public LayerSpace(Builder<?> builder) {
        this.activationFunction = builder.activationFunction;
        this.weightInit = builder.weightInit;
        this.biasInit = builder.biasInit;
        this.dist = builder.dist;
        this.learningRate = builder.learningRate;
        this.learningRateAfter = builder.learningRateAfter;
        this.lrScoreBasedDecay = builder.lrScoreBasedDecay;
        this.l1 = builder.l1;
        this.l2 = builder.l2;
        this.dropOut = builder.dropOut;
        this.momentum = builder.momentum;
        this.momentumAfter = builder.momentumAfter;
        this.updater = builder.updater;
        this.rho = builder.rho;
        this.rmsDecay = builder.rmsDecay;
        this.gradientNormalization = builder.gradientNormalization;
        this.gradientNormalizationThreshold = builder.gradientNormalizationThreshold;
    }

    /**
     * Gathers the leaves of every non-null option, in field declaration order.
     *
     * @return a new, mutable list containing the leaves of all configured options
     */
    public List<ParameterSpace> collectLeaves() {
        List<ParameterSpace> leaves = new ArrayList<>();
        addLeaves(leaves, this.activationFunction);
        addLeaves(leaves, this.weightInit);
        addLeaves(leaves, this.biasInit);
        addLeaves(leaves, this.dist);
        addLeaves(leaves, this.learningRate);
        addLeaves(leaves, this.learningRateAfter);
        addLeaves(leaves, this.lrScoreBasedDecay);
        addLeaves(leaves, this.l1);
        addLeaves(leaves, this.l2);
        addLeaves(leaves, this.dropOut);
        addLeaves(leaves, this.momentum);
        addLeaves(leaves, this.momentumAfter);
        addLeaves(leaves, this.updater);
        addLeaves(leaves, this.rho);
        addLeaves(leaves, this.rmsDecay);
        addLeaves(leaves, this.gradientNormalization);
        addLeaves(leaves, this.gradientNormalizationThreshold);
        return leaves;
    }

    /** Appends the leaves of {@code space} to {@code target}; a null space contributes nothing. */
    private static void addLeaves(List<ParameterSpace> target, ParameterSpace<?> space) {
        if (space != null) {
            target.addAll(space.collectLeaves());
        }
    }

    /**
     * Returns {@code CollectionUtils.countUnique} applied to the current leaves.
     *
     * <p>The count is computed on every call rather than cached in a field initializer: an
     * initializer runs before the constructor body has assigned any option, so it would only
     * ever see null fields (and would call an overridable method on a partially constructed
     * object).
     *
     * @return the value reported by {@code CollectionUtils.countUnique(collectLeaves())}
     */
    public int numParameters() {
        return CollectionUtils.countUnique(collectLeaves());
    }

    /** Always {@code false}: a layer space is a composite of other parameter spaces. */
    public boolean isLeaf() {
        return false;
    }

    /**
     * Not supported for this non-leaf space.
     *
     * @throws UnsupportedOperationException always
     */
    public void setIndices(int... iArr) {
        throw new UnsupportedOperationException("Cannot set indices for non-leaf parameter space");
    }

    /**
     * Applies the value of every non-null option to the given layer builder.
     *
     * @param builder the layer builder to configure
     * @param dArr    the values passed to each option's {@code getValue}
     */
    public void setLayerOptionsBuilder(Layer.Builder builder, double[] dArr) {
        if (this.activationFunction != null) {
            builder.activation(this.activationFunction.getValue(dArr));
        }
        if (this.weightInit != null) {
            builder.weightInit(this.weightInit.getValue(dArr));
        }
        if (this.biasInit != null) {
            builder.biasInit(this.biasInit.getValue(dArr).doubleValue());
        }
        if (this.dist != null) {
            builder.dist(this.dist.getValue(dArr));
        }
        if (this.learningRate != null) {
            builder.learningRate(this.learningRate.getValue(dArr).doubleValue());
        }
        if (this.learningRateAfter != null) {
            builder.learningRateSchedule(this.learningRateAfter.getValue(dArr));
        }
        if (this.lrScoreBasedDecay != null) {
            // Must target the decay-rate setter; routing this through learningRate(...)
            // would overwrite the learning rate applied just above.
            builder.learningRateScoreBasedDecayRate(this.lrScoreBasedDecay.getValue(dArr).doubleValue());
        }
        if (this.l1 != null) {
            builder.l1(this.l1.getValue(dArr).doubleValue());
        }
        if (this.l2 != null) {
            builder.l2(this.l2.getValue(dArr).doubleValue());
        }
        if (this.dropOut != null) {
            builder.dropOut(this.dropOut.getValue(dArr).doubleValue());
        }
        if (this.momentum != null) {
            builder.momentum(this.momentum.getValue(dArr).doubleValue());
        }
        if (this.momentumAfter != null) {
            builder.momentumAfter(this.momentumAfter.getValue(dArr));
        }
        if (this.updater != null) {
            builder.updater(this.updater.getValue(dArr));
        }
        if (this.rho != null) {
            builder.rho(this.rho.getValue(dArr).doubleValue());
        }
        if (this.rmsDecay != null) {
            builder.rmsDecay(this.rmsDecay.getValue(dArr).doubleValue());
        }
        if (this.gradientNormalization != null) {
            builder.gradientNormalization(this.gradientNormalization.getValue(dArr));
        }
        if (this.gradientNormalizationThreshold != null) {
            builder.gradientNormalizationThreshold(this.gradientNormalizationThreshold.getValue(dArr).doubleValue());
        }
    }

    /** Renders the configured options separated by {@code ", "}. */
    @Override
    public String toString() {
        return toString(", ");
    }

    /**
     * Renders every non-null option as {@code "name: value"}, joined by the given delimiter.
     *
     * @param str the delimiter placed between options
     * @return the rendered options with no trailing delimiter
     */
    public String toString(String str) {
        StringBuilder sb = new StringBuilder();
        appendOption(sb, "activationFunction", this.activationFunction, str);
        appendOption(sb, "weightInit", this.weightInit, str);
        appendOption(sb, "biasInit", this.biasInit, str);
        appendOption(sb, "dist", this.dist, str);
        appendOption(sb, "learningRate", this.learningRate, str);
        appendOption(sb, "learningRateAfter", this.learningRateAfter, str);
        appendOption(sb, "lrScoreBasedDecay", this.lrScoreBasedDecay, str);
        appendOption(sb, "l1", this.l1, str);
        appendOption(sb, "l2", this.l2, str);
        appendOption(sb, "dropOut", this.dropOut, str);
        appendOption(sb, "momentum", this.momentum, str);
        appendOption(sb, "momentumAfter", this.momentumAfter, str);
        appendOption(sb, "updater", this.updater, str);
        appendOption(sb, "rho", this.rho, str);
        appendOption(sb, "rmsDecay", this.rmsDecay, str);
        appendOption(sb, "gradientNormalization", this.gradientNormalization, str);
        // Last option: no delimiter is appended after it.
        if (this.gradientNormalizationThreshold != null) {
            sb.append("gradientNormalizationThreshold: ").append(this.gradientNormalizationThreshold);
        }
        // Drop the delimiter left behind when the last option is absent.
        String text = sb.toString();
        return text.endsWith(str) ? text.substring(0, text.lastIndexOf(str)) : text;
    }

    /** Appends {@code "label: value"} followed by the delimiter when {@code value} is non-null. */
    private static void appendOption(StringBuilder sb, String label, Object value, String delimiter) {
        if (value != null) {
            sb.append(label).append(": ").append(value).append(delimiter);
        }
    }
}
