package org.tensorflow.framework.initializers;

import org.tensorflow.Operand;
import org.tensorflow.framework.utils.ShapeUtils;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Fill;
import org.tensorflow.op.dtypes.Cast;
import org.tensorflow.op.linalg.MatrixSetDiag;
import org.tensorflow.types.TInt64;
import org.tensorflow.types.family.TFloating;

/**
 * Initializer that generates a (gain-scaled) identity matrix.
 *
 * <p>Only two-dimensional shapes are supported. The generated tensor holds {@code gain} on the
 * main diagonal (its first {@code min(rows, cols)} entries) and zero everywhere else, for both
 * square and non-square shapes.
 *
 * @param <T> the floating-point data type of the generated tensor
 */
public class Identity<T extends TFloating> extends BaseInitializer<T> {
    /** The gain, {@value}, used by {@link #Identity(Ops)}. */
    public static final double GAIN_DEFAULT = 1.0d;

    /** Multiplicative factor applied to the identity matrix. */
    private final double gain;

    /**
     * Creates an Identity initializer with a gain of {@link #GAIN_DEFAULT}.
     *
     * @param ops the TensorFlow Ops, forwarded to the base initializer
     */
    public Identity(Ops ops) {
        this(ops, GAIN_DEFAULT);
    }

    /**
     * Creates an Identity initializer with the given gain.
     *
     * @param ops the TensorFlow Ops, forwarded to the base initializer
     * @param gain the factor by which the identity matrix is multiplied
     */
    public Identity(Ops ops, double gain) {
        super(ops);
        this.gain = gain;
    }

    /**
     * Builds a gain-scaled identity matrix with the requested dimensions.
     *
     * @param dims the shape of the tensor to generate; must describe exactly two dimensions
     * @param type the data type of the generated tensor
     * @return an operand holding {@code gain} on the main diagonal and zero elsewhere
     * @throws IllegalArgumentException if {@code dims} does not describe a two-dimensional shape
     * @throws ArithmeticException if a square matrix dimension does not fit in an {@code int}
     */
    @Override
    public Operand<T> call(Operand<TInt64> dims, Class<T> type) {
        Shape shape = ShapeUtils.toShape(this.tf.scope(), dims);
        if (shape.numDimensions() != 2) {
            throw new IllegalArgumentException("2D matrix required, got " + shape.numDimensions());
        }
        long rows = shape.size(0);
        long cols = shape.size(1);

        // The diagonal is a 1-D vector of ones whose length is the shorter matrix side.
        long[] diagDims = {Math.min(rows, cols)};
        Operand<T> diagOnes =
            this.tf.fill(this.tf.constant(diagDims), this.tf.dtypes.cast(this.tf.constant(1.0d), type));

        Operand<T> identity;
        if (rows == cols) {
            // The zero padding value is only needed here, so it is not built on the non-square path.
            Operand<T> zero = this.tf.dtypes.cast(this.tf.constant(0), type);
            // matrixDiag takes TInt32 operands for k / numRows / numCols, so these stay int
            // constants; toIntExact rejects sizes that would otherwise silently wrap.
            identity = this.tf.linalg.matrixDiag(
                diagOnes,
                this.tf.constant(0),
                this.tf.constant(Math.toIntExact(rows)),
                this.tf.constant(Math.toIntExact(cols)),
                zero);
        } else {
            // Non-square: start from a zero matrix of the requested shape and set its main diagonal.
            identity = this.tf.linalg.matrixSetDiag(this.tf.zeros(dims, type), diagOnes, this.tf.constant(0));
        }
        return this.tf.math.mul(identity, this.tf.dtypes.cast(this.tf.constant(this.gain), type));
    }
}
