package org.tensorflow.framework.initializers;

import org.tensorflow.Operand;
import org.tensorflow.framework.utils.CastHelper;
import org.tensorflow.framework.utils.ShapeUtils;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Ops;
import org.tensorflow.op.linalg.Qr;
import org.tensorflow.types.TInt64;
import org.tensorflow.types.family.TFloating;

/* loaded from: input_file:org/tensorflow/framework/initializers/Orthogonal.class */
/**
 * Initializer that generates an orthogonal matrix, scaled by a multiplicative gain, and reshapes
 * it to the requested shape.
 *
 * <p>The requested shape (rank at least 2) is viewed as a {@code numRows x numCols} matrix, where
 * {@code numCols} is the last dimension and {@code numRows} is the product of all the other
 * dimensions. A random normal matrix of size {@code max(numRows, numCols) x min(numRows, numCols)}
 * is drawn with a stateless random op seeded by {@code {seed, 0}}. It is QR-decomposed, and each
 * column of {@code Q} is multiplied by the sign of the matching diagonal entry of {@code R}. The
 * result is transposed when {@code numRows < numCols}, multiplied by {@code gain}, and finally
 * reshaped to the requested shape.
 *
 * @param <T> the floating-point type of the generated tensor
 */
public class Orthogonal<T extends TFloating> extends BaseInitializer<T> {
    /** Gain used by {@link #Orthogonal(long)}. */
    public static final double GAIN_DEFAULT = 1.0d;

    /** Multiplicative factor applied to the orthogonal matrix. */
    private final double gain;

    /** Seed passed (as {@code {seed, 0}}) to the stateless random normal op. */
    private final long seed;

    /**
     * Creates an Orthogonal initializer with a gain of {@link #GAIN_DEFAULT}.
     *
     * @param seed the seed for the stateless random normal draw
     */
    public Orthogonal(long seed) {
        this(GAIN_DEFAULT, seed);
    }

    /**
     * Creates an Orthogonal initializer.
     *
     * @param gain multiplicative factor applied to the orthogonal matrix
     * @param seed the seed for the stateless random normal draw
     */
    public Orthogonal(double gain, long seed) {
        this.gain = gain;
        this.seed = seed;
    }

    /**
     * Builds the ops that produce the scaled orthogonal tensor.
     *
     * @param ops the TensorFlow Ops
     * @param operand the requested shape of the tensor
     * @param cls the floating-point type of the tensor
     * @return an operand of the requested shape holding the scaled orthogonal values
     * @throws IllegalArgumentException if the requested shape has fewer than two dimensions
     */
    @Override // org.tensorflow.framework.initializers.Initializer
    public Operand<T> call(Ops ops, Operand<TInt64> operand, Class<T> cls) {
        Shape shape = ShapeUtils.toShape(ops.scope(), operand);
        int rank = shape.numDimensions();
        if (rank < 2) {
            throw new IllegalArgumentException(
                    "The tensor to initialize must be at least two-dimensional, got " + rank);
        }

        // View the requested shape as a matrix: all leading dimensions collapse into the rows.
        long numRows = 1;
        for (int i = 0; i < rank - 1; i++) {
            numRows *= shape.size(i);
        }
        long numCols = shape.size(rank - 1);

        // Sample in the tall orientation (rows >= cols); transposed back below if needed.
        Shape flatShape = Shape.of(Math.max(numRows, numCols), Math.min(numRows, numCols));
        Operand<T> random =
                ops.random.statelessRandomNormal(
                        ops.constant(flatShape), ops.constant(new long[] {seed, 0}), cls);
        Qr<T> qr = ops.linalg.qr(random, Qr.fullMatrices(false));

        // Multiply each column of Q by the sign of the matching diagonal entry of R.
        Operand<T> diag =
                ops.linalg.matrixDiagPart(
                        qr.r(), ops.constant(0), CastHelper.cast(ops, ops.constant(0), cls));
        Operand<T> q = ops.math.mul(qr.q(), ops.math.sign(diag));
        if (numRows < numCols) {
            // Transpose takes perm as a required input, so pass the 2-D permutation explicitly
            // rather than null.
            q = ops.linalg.transpose(q, ops.constant(new int[] {1, 0}));
        }

        Operand<T> scaled = ops.math.mul(q, CastHelper.cast(ops, ops.constant(gain), cls));
        // q is (numRows, numCols); restore the caller's requested (possibly higher-rank) shape.
        return ops.reshape(scaled, operand);
    }
}
