package org.tensorflow.framework.op.math;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import org.tensorflow.Operand;
import org.tensorflow.Session;
import org.tensorflow.framework.op.linalg.MatMul;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Scope;
import org.tensorflow.op.core.AssertThat;
import org.tensorflow.op.core.Concat;
import org.tensorflow.op.core.Constant;
import org.tensorflow.op.core.Gather;
import org.tensorflow.op.core.Range;
import org.tensorflow.op.core.Rank;
import org.tensorflow.op.core.ReduceProd;
import org.tensorflow.op.core.Reshape;
import org.tensorflow.op.core.Select;
import org.tensorflow.op.core.SetDiff1d;
import org.tensorflow.op.core.Slice;
import org.tensorflow.op.core.Stack;
import org.tensorflow.op.dtypes.Cast;
import org.tensorflow.op.linalg.Transpose;
import org.tensorflow.op.math.Add;
import org.tensorflow.op.math.GreaterEqual;
import org.tensorflow.op.math.Less;
import org.tensorflow.op.math.Sub;
import org.tensorflow.types.TBfloat16;
import org.tensorflow.types.TFloat16;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.family.TFloating;
import org.tensorflow.types.family.TNumber;

/* loaded from: input_file:org/tensorflow/framework/op/math/TensorDot.class */
/**
 * Builds graph operations computing {@code tensordot}: the contraction of two tensors {@code a}
 * and {@code b} over matching sets of axes.
 *
 * <p>Each operand is transposed so that its contracted axes are grouped together and then reshaped
 * into a rank-2 matrix. The two matrices are multiplied with {@link MatMul#matmul}. The product is
 * reshaped to the free (non-contracted) dimensions of {@code a} followed by those of {@code b}.
 *
 * <p>Axes supplied as an {@link Operand} are turned into Java values by {@link #getIntArray} when
 * the operand rank is known. In eager mode this reads the tensor directly; in graph mode it runs a
 * {@link Session} on the graph.
 */
public class TensorDot {

    /** Holds the outputs of {@link #tensordotReshape}. */
    private static final class ReshapeResult<T extends TNumber> {
        /** The operand transposed and reshaped into a rank-2 matrix. */
        final Operand<T> reshaped;
        /** Sizes of the free (non-contracted) dimensions, as a graph value. */
        final Operand<? extends TNumber> freeDims;
        /**
         * Statically known sizes of the free dimensions, or {@code null} when the operand's rank is
         * unknown. Entries may be {@link Shape#UNKNOWN_SIZE} when only the rank is known.
         */
        final long[] freeDimsStatic;

        ReshapeResult(Operand<T> reshaped, Operand<? extends TNumber> freeDims, long[] freeDimsStatic) {
            this.reshaped = reshaped;
            this.freeDims = freeDims;
            this.freeDimsStatic = freeDimsStatic;
        }
    }

    /**
     * Transposes and reshapes {@code a} into the matrix fed to the matmul at the core of tensordot.
     *
     * @param scope   current scope
     * @param a       the operand to reshape
     * @param axes    the axes of {@code a} to contract; negative values count back from the rank
     * @param flipped if {@code false} the matrix is {@code [prod(free), prod(axes)]}, otherwise
     *                {@code [prod(axes), prod(free)]}
     * @return the reshaped operand plus the dynamic and static sizes of its free dimensions
     */
    private static <T extends TNumber> ReshapeResult<T> tensordotReshape(
            Scope scope, Operand<T> a, Operand<TInt32> axes, boolean flipped) {
        Shape shape = a.shape();
        return shape.hasUnknownDimension()
                ? dynamicReshape(scope, a, axes, flipped, shape)
                : staticReshape(scope, a, axes, flipped, shape);
    }

    /** {@link #tensordotReshape} for operands whose shape is fully known at graph-build time. */
    private static <T extends TNumber> ReshapeResult<T> staticReshape(
            Scope scope, Operand<T> a, Operand<TInt32> axes, boolean flipped, Shape shape) {
        long[] dims = shape.asArray();
        if (dims == null) {
            dims = new long[0];
        }
        int rank = dims.length;
        int[] contracted = normalizeAxes(getIntArray(scope, axes), rank);
        int[] free = complementAxes(contracted, rank);

        long[] freeDims = new long[free.length];
        long freeProduct = 1;
        for (int i = 0; i < free.length; i++) {
            freeDims[i] = dims[free[i]];
            freeProduct *= freeDims[i];
        }
        long contractedProduct = 1;
        for (int axis : contracted) {
            contractedProduct *= dims[axis];
        }

        int[] perm;
        Shape matrixShape;
        if (flipped) {
            perm = concat(contracted, free);
            matrixShape = Shape.of(contractedProduct, freeProduct);
        } else {
            perm = concat(free, contracted);
            matrixShape = Shape.of(freeProduct, contractedProduct);
        }

        // perm always covers all `rank` axes; skip the Transpose when it is the identity.
        Operand<T> transposed =
                isIdentity(perm) ? a : Transpose.create(scope, a, Constant.vectorOf(scope, perm));
        Operand<T> reshaped = transposed.shape().equals(matrixShape)
                ? transposed
                : Reshape.create(scope, transposed, Constant.vectorOf(scope, matrixShape.asArray()));
        return new ReshapeResult<>(reshaped, Constant.vectorOf(scope, freeDims), freeDims);
    }

    /**
     * {@link #tensordotReshape} for operands with at least one unknown dimension (or unknown rank).
     * Free/contracted sizes are gathered from the runtime shape of {@code a}.
     */
    private static <T extends TNumber> ReshapeResult<T> dynamicReshape(
            Scope scope, Operand<T> a, Operand<TInt32> axes, boolean flipped, Shape shape) {
        Constant<TInt32> zero = Constant.scalarOf(scope, 0);
        Constant<TInt32> one = Constant.scalarOf(scope, 1);
        Constant<TInt32> lastAxis = Constant.scalarOf(scope, -1);
        Operand<TInt32> runtimeShape = org.tensorflow.op.core.Shape.create(scope, a);

        Operand<TInt32> contractedAxes;
        Operand<TInt32> freeAxes;
        long[] freeDimsStatic = null;
        if (shape.numDimensions() != Shape.UNKNOWN_SIZE) {
            // Rank is known: resolve the axis lists in Java, keeping whatever sizes are known.
            long[] dims = shape.asArray();
            if (dims == null) {
                dims = new long[0];
            }
            int rank = dims.length;
            int[] contracted = normalizeAxes(getIntArray(scope, axes), rank);
            int[] free = complementAxes(contracted, rank);
            freeDimsStatic = new long[free.length];
            for (int i = 0; i < free.length; i++) {
                freeDimsStatic[i] = dims[free[i]];
            }
            contractedAxes = Constant.vectorOf(scope, contracted);
            freeAxes = Constant.vectorOf(scope, free);
        } else {
            // Rank is unknown: normalize negative axes and compute the free axes in the graph.
            Operand<TInt32> rank = Rank.create(scope, a);
            contractedAxes = Select.create(scope, GreaterEqual.create(scope, axes, zero), axes,
                    Add.create(scope, axes, rank));
            freeAxes = SetDiff1d.create(scope, Range.create(scope, zero, rank, one), contractedAxes).out();
        }

        Operand<TInt32> freeDims = Gather.create(scope, runtimeShape, freeAxes, zero);
        Operand<TInt32> contractedDims = Gather.create(scope, runtimeShape, contractedAxes, zero);
        Operand<TInt32> freeProduct = ReduceProd.create(scope, freeDims, lastAxis);
        Operand<TInt32> contractedProduct = ReduceProd.create(scope, contractedDims, lastAxis);

        Operand<TInt32> perm;
        Operand<TInt32> matrixShape;
        if (flipped) {
            perm = Concat.create(scope, Arrays.asList(contractedAxes, freeAxes), zero);
            matrixShape = Stack.create(scope, Arrays.asList(contractedProduct, freeProduct));
        } else {
            perm = Concat.create(scope, Arrays.asList(freeAxes, contractedAxes), zero);
            matrixShape = Stack.create(scope, Arrays.asList(freeProduct, contractedProduct));
        }
        Operand<T> reshaped = Reshape.create(scope, Transpose.create(scope, a, perm), matrixShape);
        return new ReshapeResult<>(reshaped, freeDims, freeDimsStatic);
    }

    /**
     * Maps negative axes to {@code axis + rank} and validates the result.
     *
     * @throws IllegalArgumentException if an axis falls outside {@code [-rank, rank)}
     */
    private static int[] normalizeAxes(int[] axes, int rank) {
        int[] normalized = new int[axes.length];
        for (int i = 0; i < axes.length; i++) {
            int axis = axes[i] < 0 ? axes[i] + rank : axes[i];
            if (axis < 0 || axis >= rank) {
                throw new IllegalArgumentException(String.format(
                        "Axis %d is out of range for a tensor of rank %d", axes[i], rank));
            }
            normalized[i] = axis;
        }
        return normalized;
    }

    /** Returns, in ascending order, the axes in {@code [0, rank)} that are not in {@code axes}. */
    private static int[] complementAxes(int[] axes, int rank) {
        boolean[] used = new boolean[rank];
        for (int axis : axes) {
            used[axis] = true;
        }
        List<Integer> free = new ArrayList<>();
        for (int i = 0; i < rank; i++) {
            if (!used[i]) {
                free.add(i);
            }
        }
        return free.stream().mapToInt(Integer::intValue).toArray();
    }

    /** Returns {@code first} followed by {@code second}. */
    private static int[] concat(int[] first, int[] second) {
        int[] result = Arrays.copyOf(first, first.length + second.length);
        System.arraycopy(second, 0, result, first.length, second.length);
        return result;
    }

    /** Returns whether {@code perm[i] == i} for every index. */
    private static boolean isIdentity(int[] perm) {
        for (int i = 0; i < perm.length; i++) {
            if (perm[i] != i) {
                return false;
            }
        }
        return true;
    }

    /** Returns whether {@code dims} is non-null and contains no unknown (negative) sizes. */
    private static boolean isFullyKnown(long[] dims) {
        if (dims == null) {
            return false;
        }
        for (long dim : dims) {
            if (dim < 0) {
                return false;
            }
        }
        return true;
    }

    /**
     * Reads the values of an integer operand into a Java array. In eager mode the tensor is read
     * directly; otherwise a {@link Session} is opened on the scope's environment to fetch it.
     */
    private static int[] getIntArray(Scope scope, Operand<TInt32> operand) {
        List<Integer> values = new ArrayList<>();
        if (scope.env().isEager()) {
            operand.asTensor().scalars().forEach(s -> values.add(s.getInt()));
        } else {
            try (Session session = new Session(scope.env());
                    TInt32 tensor = (TInt32) session.runner().fetch(operand).run().get(0)) {
                tensor.scalars().forEach(s -> values.add(s.getInt()));
            }
        }
        return values.stream().mapToInt(Integer::intValue).toArray();
    }

    /** Packs the axis lists for {@code a} and {@code b} into a two-element array. */
    @SuppressWarnings("unchecked")
    private static Operand<TInt32>[] axesPair(Operand<TInt32> aAxes, Operand<TInt32> bAxes) {
        return (Operand<TInt32>[]) new Operand<?>[] {aAxes, bAxes};
    }

    /**
     * Axes for an integer count {@code n}: the last {@code n} axes of {@code a} and the first
     * {@code n} axes of {@code b}.
     *
     * @throws IllegalArgumentException if {@code axes} is negative or exceeds the static rank of
     *     {@code a}
     */
    private static <T extends TNumber> Operand<TInt32>[] tensordotAxes(Scope scope, Operand<T> a, int axes) {
        Shape shape = a.asOutput().shape();
        if (axes < 0) {
            throw new IllegalArgumentException("'axis' must be at least 0.");
        }
        int rank = shape.numDimensions();
        if (rank == Shape.UNKNOWN_SIZE) {
            Operand<TInt32> dynamicRank = Rank.create(scope, a);
            Constant<TInt32> count = Constant.scalarOf(scope, axes);
            Constant<TInt32> one = Constant.scalarOf(scope, 1);
            Constant<TInt32> zero = Constant.scalarOf(scope, 0);
            // axes == rank (full contraction) is valid, matching the static check below.
            Scope checked = scope.withControlDependencies(Collections.singletonList(AssertThat.create(scope,
                    GreaterEqual.create(scope, dynamicRank, count),
                    Arrays.<Operand<?>>asList(Constant.scalarOf(scope,
                            "'axes' must not be larger than the number of dimensions of tensor "), dynamicRank))));
            return axesPair(
                    Range.create(checked, Sub.create(scope, dynamicRank, count), dynamicRank, one),
                    Range.create(checked, zero, count, one));
        }
        if (axes > rank) {
            throw new IllegalArgumentException(String.format(
                    "'axis' must not be larger than the number of dimensions of tensor %s.", Integer.valueOf(rank)));
        }
        int[] aAxes = new int[axes];
        int[] bAxes = new int[axes];
        for (int i = 0; i < axes; i++) {
            aAxes[i] = rank - axes + i;
            bAxes[i] = i;
        }
        return axesPair(Constant.vectorOf(scope, aAxes), Constant.vectorOf(scope, bAxes));
    }

    /** Axes given as {@code {aAxis, bAxis}}: one contracted axis per operand. */
    private static <T extends TNumber> Operand<TInt32>[] tensordotAxes(Scope scope, Operand<T> a, int[] axes) {
        if (axes.length != 2) {
            throw new IllegalArgumentException("'axes' must have length 1 or 2, provided with " + axes.length);
        }
        return axesPair(Constant.vectorOf(scope, new int[] {axes[0]}), Constant.vectorOf(scope, new int[] {axes[1]}));
    }

    /** Axes given as {@code {aAxes, bAxes}}; both lists must have the same length. */
    private static <T extends TNumber> Operand<TInt32>[] tensordotAxes(Scope scope, Operand<T> a, int[][] axes) {
        if (axes.length != 2) {
            throw new IllegalArgumentException("'axes' must have length 1 or 2, provided with " + axes.length);
        }
        int[] aAxes = axes[0];
        int[] bAxes = axes[1];
        if (aAxes.length != bAxes.length) {
            throw new IllegalArgumentException(String.format(
                    "Different number of contraction axes 'a' and 'b', %d != %d",
                    Integer.valueOf(aAxes.length), Integer.valueOf(bAxes.length)));
        }
        return axesPair(Constant.vectorOf(scope, aAxes), Constant.vectorOf(scope, bAxes));
    }

    /**
     * Axes given as a tensor whose first dimension has size 2. The tensor may be {@code [2]} (one axis
     * per operand) or {@code [2, k]} (k axes per operand). Row 0 gives the axes of {@code a} and row 1
     * those of {@code b}, each flattened to a vector.
     */
    private static <T extends TNumber> Operand<TInt32>[] tensordotAxes(Scope scope, Operand<T> a, Operand<TInt32> axes) {
        Constant<TInt32> zero = Constant.scalarOf(scope, 0);
        Constant<TInt32> one = Constant.scalarOf(scope, 1);
        Constant<TInt32> flatten = Constant.vectorOf(scope, new int[] {-1});
        // Slice needs 1-D begin/size; gathering row i along axis 0 is equivalent to axes[i].
        return axesPair(
                Reshape.create(scope, Gather.create(scope, axes, zero, zero), flatten),
                Reshape.create(scope, Gather.create(scope, axes, one, zero), flatten));
    }

    /** Contracts the last {@code axes} dimensions of {@code a} with the first {@code axes} of {@code b}. */
    public static <T extends TFloating> Operand<T> tensordot(Scope scope, Operand<T> a, Operand<T> b, int axes) {
        Operand<TInt32>[] pair = tensordotAxes(scope, a, axes);
        return tensordot(scope, a, b, pair[0], pair[1]);
    }

    /** Contracts {@code a} and {@code b} over axes given as a {@code [2]} or {@code [2, k]} tensor. */
    public static <T extends TFloating> Operand<T> tensordot(Scope scope, Operand<T> a, Operand<T> b, Operand<TInt32> axes) {
        Operand<TInt32>[] pair = tensordotAxes(scope, a, axes);
        return tensordot(scope, a, b, pair[0], pair[1]);
    }

    /** Contracts axis {@code axes[0]} of {@code a} with axis {@code axes[1]} of {@code b}. */
    public static <T extends TFloating> Operand<T> tensordot(Scope scope, Operand<T> a, Operand<T> b, int[] axes) {
        Operand<TInt32>[] pair = tensordotAxes(scope, a, axes);
        return tensordot(scope, a, b, pair[0], pair[1]);
    }

    /** Contracts axes {@code axes[0]} of {@code a} with axes {@code axes[1]} of {@code b}. */
    public static <T extends TFloating> Operand<T> tensordot(Scope scope, Operand<T> a, Operand<T> b, int[][] axes) {
        Operand<TInt32>[] pair = tensordotAxes(scope, a, axes);
        return tensordot(scope, a, b, pair[0], pair[1]);
    }

    /**
     * Contracts {@code aAxis} of {@code a} with {@code bAxis} of {@code b}.
     *
     * @return a tensor whose shape is the free dimensions of {@code a} followed by those of {@code b}
     * @throws IllegalArgumentException if {@code a} is {@code TBfloat16}/{@code TFloat16}, or if
     *     {@code a} and {@code b} have different types
     */
    public static <T extends TFloating> Operand<T> tensordot(
            Scope scope, Operand<T> a, Operand<T> b, Operand<TInt32> aAxis, Operand<TInt32> bAxis) {
        if (a.type().equals(TBfloat16.class) || a.type().equals(TFloat16.class)) {
            throw new IllegalArgumentException(String.format(
                    "Operand 'a' must be either TFloat32 or TFloat64 DataType, 'a' is a %s DataType",
                    a.type().getSimpleName()));
        }
        if (!a.type().equals(b.type())) {
            throw new IllegalArgumentException(String.format(
                    "Operands a and b must be the same data type, a is %s DataType, b is %s DataType",
                    a.type().getSimpleName(), b.type().getSimpleName()));
        }
        ReshapeResult<T> aResult = tensordotReshape(scope, a, aAxis, false);
        ReshapeResult<T> bResult = tensordotReshape(scope, b, bAxis, true);
        Operand<T> product = MatMul.matmul(scope, aResult.reshaped, bResult.reshaped);

        if (isFullyKnown(aResult.freeDimsStatic) && isFullyKnown(bResult.freeDimsStatic)) {
            long[] aFree = aResult.freeDimsStatic;
            long[] bFree = bResult.freeDimsStatic;
            long[] outDims = Arrays.copyOf(aFree, aFree.length + bFree.length);
            System.arraycopy(bFree, 0, outDims, aFree.length, bFree.length);
            if (!product.shape().hasUnknownDimension() && product.shape().equals(Shape.of(outDims))) {
                return product;
            }
            return Reshape.create(scope, product, Constant.vectorOf(scope, outDims));
        }
        // Some free dimension (or the rank) is only known at run time: build the shape in the graph.
        // The two free-dim vectors may differ in integer type, so cast both to TInt32 before Concat.
        Operand<TInt32> aFreeDims = Cast.create(scope, aResult.freeDims, TInt32.class);
        Operand<TInt32> bFreeDims = Cast.create(scope, bResult.freeDims, TInt32.class);
        return Reshape.create(scope, product,
                Concat.create(scope, Arrays.asList(aFreeDims, bFreeDims), Constant.scalarOf(scope, 0)));
    }
}
