package org.tensorflow.framework.utils;

import java.util.ArrayList;
import java.util.Arrays;
import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.Session;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Scope;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.TInt64;
import org.tensorflow.types.TUint8;
import org.tensorflow.types.family.TIntegral;

/**
 * Static helpers for turning integral operands/tensors into Java arrays and {@link Shape}s, and
 * for collapsing trailing dimensions of a {@link Shape}.
 */
public class ShapeUtils {

    /**
     * Builds a {@link Shape} whose dimension sizes are the values held by {@code operand}.
     *
     * @param scope the scope whose execution environment is used to evaluate {@code operand}
     * @param operand an integral operand holding the dimension sizes
     * @param <T> the integral tensor type of {@code operand}
     * @return a shape with one dimension per value in {@code operand}
     * @throws IllegalArgumentException if the operand's tensor type is not TInt32, TInt64 or TUint8
     */
    public static <T extends TIntegral> Shape toShape(Scope scope, Operand<T> operand) {
        return Shape.of(getLongArray(scope, operand));
    }

    /**
     * Evaluates an int32 operand and returns its values as an {@code int[]}.
     *
     * @param scope the scope whose execution environment is used to evaluate {@code operand}
     * @param operand the int32 operand to read
     * @return the operand's values, in the order produced by its scalar iteration
     */
    public static int[] getIntArray(Scope scope, Operand<TInt32> operand) {
        // The values originate from an int32 tensor, so narrowing back to int is lossless.
        return Arrays.stream(getLongArray(scope, operand)).mapToInt(value -> (int) value).toArray();
    }

    /**
     * Evaluates an integral operand and returns its values widened to {@code long}.
     *
     * <p>In eager mode the operand's tensor is read directly. Otherwise a new {@link Session} is
     * opened on the scope's graph, the operand is fetched, and both the session and the fetched
     * tensor are closed before returning.
     *
     * @param scope the scope whose execution environment is used to evaluate {@code operand}
     * @param operand the integral operand to read
     * @param <T> the integral tensor type of {@code operand}
     * @return the operand's values as longs
     * @throws IllegalArgumentException if the tensor type is not TInt32, TInt64 or TUint8
     */
    public static <T extends TIntegral> long[] getLongArray(Scope scope, Operand<T> operand) {
        if (scope.env().isEager()) {
            return getLongArray(operand.asTensor());
        }
        // A non-eager environment is a Graph; Session is constructed from the Graph.
        // try-with-resources closes the tensor, then the session, and keeps any close()
        // failure as a suppressed exception instead of masking the primary one.
        try (Session session = new Session((Graph) scope.env());
                TIntegral tensor = (TIntegral) session.runner().fetch(operand).run().get(0)) {
            return getLongArray(tensor);
        }
    }

    /**
     * Reads every scalar of an integral tensor and returns the values widened to {@code long}.
     *
     * @param tensor the tensor to read; must be a TInt32, TInt64 or TUint8
     * @param <T> the integral tensor type
     * @return the tensor's values as longs, in scalar iteration order
     * @throws IllegalArgumentException if {@code tensor} is null or of another integral type
     */
    public static <T extends TIntegral> long[] getLongArray(T tensor) {
        ArrayList<Long> values = new ArrayList<>();
        if (tensor instanceof TInt32) {
            ((TInt32) tensor).scalars().forEach(scalar -> values.add((long) scalar.getInt()));
        } else if (tensor instanceof TInt64) {
            ((TInt64) tensor).scalars().forEach(scalar -> values.add(scalar.getLong()));
        } else if (tensor instanceof TUint8) {
            // uint8 is stored in a signed Java byte; widen it as unsigned so 128..255 stay positive.
            ((TUint8) tensor).scalars()
                    .forEach(scalar -> values.add(Byte.toUnsignedLong(scalar.getObject())));
        } else {
            throw new IllegalArgumentException("the data type must be an integer type");
        }
        return values.stream().mapToLong(Long::longValue).toArray();
    }

    /**
     * Collapses the trailing dimensions of {@code shape} so that the result has {@code axis}
     * dimensions.
     *
     * <p>{@code axis} is first normalized modulo the rank of {@code shape} (negative values count
     * from the end). The first {@code axis - 1} dimensions are kept, and the last dimension of the
     * result is the product of dimensions {@code axis - 1} through the end. If any of the merged
     * dimensions is unknown, the merged dimension is {@link Shape#UNKNOWN_SIZE}.
     *
     * @param shape the shape to reduce
     * @param axis the rank of the resulting shape, normalized modulo the rank of {@code shape}
     * @return the reduced shape, or {@link Shape#unknown()} if {@code shape} is unknown
     * @throws IllegalArgumentException if {@code shape} is a scalar, or {@code axis} normalizes to 0
     */
    public static Shape reduce(Shape shape, int axis) {
        long[] dims = shape.asArray();
        if (dims == null) {
            return Shape.unknown();
        }
        int rank = dims.length;
        if (rank == 0) {
            throw new IllegalArgumentException("cannot reduce a scalar shape");
        }
        int newRank = axis % rank;
        if (newRank < 0) {
            newRank += rank;
        }
        if (newRank == 0) {
            throw new IllegalArgumentException(
                    "axis " + axis + " normalizes to 0 for a shape of rank " + rank
                            + "; the reduced shape must keep at least one dimension");
        }
        long[] reduced = Arrays.copyOf(dims, newRank);
        long merged = 1L;
        for (int d = newRank - 1; d < rank; d++) {
            if (dims[d] == Shape.UNKNOWN_SIZE) {
                // A product with an unknown factor is itself unknown.
                merged = Shape.UNKNOWN_SIZE;
                break;
            }
            merged *= dims[d];
        }
        reduced[newRank - 1] = merged;
        return Shape.of(reduced);
    }
}
