package org.tensorflow.framework.initializers;

import org.tensorflow.Operand;
import org.tensorflow.framework.losses.Losses;
import org.tensorflow.framework.utils.CastHelper;
import org.tensorflow.framework.utils.ShapeUtils;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Ops;
import org.tensorflow.op.math.Mul;
import org.tensorflow.types.TInt64;
import org.tensorflow.types.family.TFloating;

/* loaded from: input_file:org/tensorflow/framework/initializers/VarianceScaling.class */
/**
 * Initializer that draws random values whose variance is scaled to the shape of the tensor being
 * initialized.
 *
 * <p>The effective variance is {@code scale / max(1, n)}, where {@code n} is selected by the
 * {@link Mode}: the fan-in, the fan-out, or the average of the two. Samples are then drawn from
 * the configured {@link Distribution}:
 *
 * <ul>
 *   <li>{@link Distribution#TRUNCATED_NORMAL}: truncated-normal samples multiplied by
 *       {@code sqrt(variance) / 0.8796256610342398};
 *   <li>{@link Distribution#NORMAL}: normal samples multiplied by {@code sqrt(variance)};
 *   <li>{@link Distribution#UNIFORM}: uniform samples in {@code [-limit, limit)} with
 *       {@code limit = sqrt(3 * variance)}.
 * </ul>
 *
 * @param <T> the floating point type of the generated tensor
 */
public class VarianceScaling<T extends TFloating> extends BaseInitializer<T> {
    public static final double SCALE_DEFAULT = 1.0d;
    public static final Mode MODE_DEFAULT = Mode.FAN_IN;
    public static final Distribution DISTRIBUTION_DEFAULT = Distribution.TRUNCATED_NORMAL;

    /**
     * Divisor applied to the standard deviation for truncated-normal sampling. This is the same
     * constant Keras uses: the standard deviation of a unit normal truncated to [-2, 2], so that
     * the truncated samples end up with the requested variance.
     */
    private static final double TRUNCATED_NORMAL_STDDEV = 0.8796256610342398d;

    private final double scale;
    private final Mode mode;
    private final Distribution distribution;
    private final long seed;

    /** The random distribution the initial values are drawn from. */
    public enum Distribution {
        TRUNCATED_NORMAL,
        NORMAL,
        UNIFORM
    }

    /** Selects which fan value the scale is divided by. */
    public enum Mode {
        FAN_IN,
        FAN_OUT,
        FAN_AVG
    }

    /**
     * Creates an initializer using {@link #SCALE_DEFAULT}, {@link #MODE_DEFAULT} and
     * {@link #DISTRIBUTION_DEFAULT}.
     *
     * @param seed the seed for the stateless random number generation
     */
    public VarianceScaling(long seed) {
        this(SCALE_DEFAULT, MODE_DEFAULT, DISTRIBUTION_DEFAULT, seed);
    }

    /**
     * Creates an initializer.
     *
     * @param scale the scaling factor; must be strictly positive
     * @param mode which fan value the scale is divided by
     * @param distribution the random distribution to sample from
     * @param seed the seed for the stateless random number generation
     * @throws IllegalArgumentException if {@code scale} is not greater than 0
     */
    public VarianceScaling(double scale, Mode mode, Distribution distribution, long seed) {
        if (scale <= 0.0d) {
            throw new IllegalArgumentException("scale must be greater than 0, got " + scale);
        }
        this.scale = scale;
        this.mode = mode;
        this.distribution = distribution;
        this.seed = seed;
    }

    /**
     * Builds the operation that generates the initial values.
     *
     * @param ops the TensorFlow Ops used to create the operations
     * @param dims the shape of the tensor to generate
     * @param type the data type of the generated tensor
     * @return an operand holding the generated values
     */
    @Override
    public Operand<T> call(Ops ops, Operand<TInt64> dims, Class<T> type) {
        Shape shape = ShapeUtils.toShape(ops.scope(), dims);
        double variance = this.scale / fanDivisor(computeFans(shape));

        // Stateless random ops take a two-element seed; the second element is fixed at 0.
        Operand<TInt64> seeds = ops.constant(new long[] {this.seed, 0L});

        switch (this.distribution) {
            case TRUNCATED_NORMAL: {
                double stddev = Math.sqrt(variance) / TRUNCATED_NORMAL_STDDEV;
                Operand<T> samples = ops.random.statelessTruncatedNormal(dims, seeds, type);
                return ops.math.mul(samples, CastHelper.cast(ops, ops.constant(stddev), type));
            }
            case NORMAL: {
                double stddev = Math.sqrt(variance);
                Operand<T> samples = ops.random.statelessRandomNormal(dims, seeds, type);
                return ops.math.mul(samples, CastHelper.cast(ops, ops.constant(stddev), type));
            }
            case UNIFORM: {
                double limit = Math.sqrt(3.0d * variance);
                Operand<T> samples = ops.random.statelessRandomUniform(dims, seeds, type);
                // The raw samples lie in [0, 1); map them onto [-limit, limit) so the
                // distribution is centred on zero: u * (2 * limit) - limit.
                Operand<T> stretched =
                    ops.math.mul(samples, CastHelper.cast(ops, ops.constant(2.0d * limit), type));
                return ops.math.sub(stretched, CastHelper.cast(ops, ops.constant(limit), type));
            }
            default:
                // Unreachable while Distribution has exactly the three constants above.
                throw new IllegalStateException("Unsupported distribution: " + this.distribution);
        }
    }

    /**
     * Returns the value the scale is divided by for the configured mode, never less than 1.
     *
     * @param fans a two-element array {@code {fanIn, fanOut}}
     * @return {@code max(1, n)} where {@code n} is the fan value selected by the mode
     */
    private double fanDivisor(double[] fans) {
        double fanIn = fans[0];
        double fanOut = fans[1];
        double n;
        switch (this.mode) {
            case FAN_IN:
                n = fanIn;
                break;
            case FAN_OUT:
                n = fanOut;
                break;
            case FAN_AVG:
                n = (fanIn + fanOut) / 2.0d;
                break;
            default:
                // No recognised mode: leave the scale undivided.
                return 1.0d;
        }
        return Math.max(1.0d, n);
    }

    /**
     * Computes the fan-in and fan-out for a shape.
     *
     * <ul>
     *   <li>rank 0 (or no dimension array): both are 1;
     *   <li>rank 1: both are the single dimension;
     *   <li>rank 2: fan-in is the first dimension, fan-out the second;
     *   <li>rank &gt; 2: the last two dimensions are taken as input and output depth, and each
     *       is multiplied by the product of all preceding dimensions (the receptive field size).
     * </ul>
     *
     * @param shape the shape of the tensor being initialized
     * @return a two-element array {@code {fanIn, fanOut}}
     */
    private double[] computeFans(Shape shape) {
        long[] dims = shape.asArray();
        if (dims == null || dims.length < 1) {
            return new double[] {1.0d, 1.0d};
        }
        int rank = dims.length;
        if (rank == 1) {
            return new double[] {dims[0], dims[0]};
        }
        if (rank == 2) {
            return new double[] {dims[0], dims[1]};
        }
        // Only the leading (rank - 2) dimensions form the receptive field; the last two
        // dimensions are the input and output depth and must not be folded in here.
        double receptiveFieldSize = 1.0d;
        for (int i = 0; i < rank - 2; i++) {
            receptiveFieldSize *= dims[i];
        }
        double fanIn = dims[rank - 2] * receptiveFieldSize;
        double fanOut = dims[rank - 1] * receptiveFieldSize;
        return new double[] {fanIn, fanOut};
    }
}
