package org.tensorflow;

import java.util.concurrent.atomic.AtomicReferenceArray;
import org.bytedeco.javacpp.PointerScope;
import org.tensorflow.internal.c_api.TFE_Op;
import org.tensorflow.internal.c_api.TFE_TensorHandle;
import org.tensorflow.internal.c_api.TF_Status;
import org.tensorflow.internal.c_api.TF_Tensor;
import org.tensorflow.internal.c_api.global.tensorflow;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.proto.DataType;
import org.tensorflow.types.family.TType;

/**
 * An {@link Operation} that has already been executed eagerly by an {@link EagerSession}.
 *
 * <p>Holds the native {@code TFE_Op} handle of the executed operation and the native handles of its
 * outputs. Output {@link Tensor}s are resolved lazily from those handles on first access and cached
 * in an {@link AtomicReferenceArray}; concurrent first accesses race via compare-and-set, and every
 * caller ends up with the single instance that won.
 *
 * <p>NOTE(review): this source was recovered with JADX, which reports that the access modifiers of
 * this class and of several members ({@code getUnsafeNativeHandle}, {@code shape}, {@code dtype},
 * {@code tensor}, the constructor) were widened from package-private to {@code public}. They are
 * kept {@code public} here so code compiled against this version keeps working.
 */
public class EagerOperation extends AbstractOperation {
    /** Session that executed this operation; also owns the resolved output tensors. */
    private final EagerSession session;
    /** Operation type. */
    private final String type;
    /** Operation name. */
    private final String name;
    /** Lazily resolved output tensors, one slot per output index ({@code null} until resolved). */
    private final AtomicReferenceArray<Tensor> outputTensors;
    /** Native handle of the executed operation. */
    private final TFE_Op opHandle;
    /** Native handles of the operation outputs, indexed by output index. */
    private final TFE_TensorHandle[] outputHandles;

    /**
     * Wraps an eagerly executed operation.
     *
     * @param session session that executed the operation
     * @param opHandle native handle of the executed operation
     * @param outputHandles native handles of the outputs, one per output index; stored as-is, not
     *     copied
     * @param type operation type
     * @param name operation name
     */
    public EagerOperation(
            EagerSession session,
            TFE_Op opHandle,
            TFE_TensorHandle[] outputHandles,
            String type,
            String name) {
        this.session = session;
        this.type = type;
        this.name = name;
        this.opHandle = opHandle;
        this.outputHandles = outputHandles;
        this.outputTensors = new AtomicReferenceArray<>(outputHandles.length);
    }

    @Override
    public String name() {
        return name;
    }

    @Override
    public String type() {
        return type;
    }

    @Override
    public EagerSession env() {
        return session;
    }

    /** Returns the number of outputs, i.e. the number of native output handles held. */
    @Override
    public int numOutputs() {
        return outputHandles.length;
    }

    /**
     * Returns the length of the output list named {@code argName}, as reported by
     * {@code TFE_OpGetOutputLength}.
     *
     * @throws IllegalStateException if the native op handle is null (session closed)
     */
    @Override
    public int outputListLength(String argName) {
        return outputListLength(opHandle, argName);
    }

    /**
     * Returns the length of the input list named {@code argName}, as reported by
     * {@code TFE_OpGetInputLength}.
     *
     * @throws IllegalStateException if the native op handle is null (session closed)
     */
    @Override
    public int inputListLength(String argName) {
        return inputListLength(opHandle, argName);
    }

    /**
     * Returns the raw native handle of output {@code outputIndex} without any session-state check.
     *
     * @throws ArrayIndexOutOfBoundsException if {@code outputIndex} is out of range
     */
    @Override
    public TFE_TensorHandle getUnsafeNativeHandle(int outputIndex) {
        return outputHandles[outputIndex];
    }

    /**
     * Hashes on the native op address.
     *
     * <p>A null {@code opHandle} hashes to 0 instead of throwing, since {@link #equals} also
     * tolerates a null handle.
     */
    @Override
    public int hashCode() {
        // NOTE(review): equals() delegates to opHandle.equals(); this assumes equal Pointers share
        // the same address (true for JavaCPP Pointer) — TODO confirm if the handle type changes.
        return opHandle == null ? 0 : Long.hashCode(opHandle.address());
    }

    /**
     * Two eager operations are equal when they belong to the same session instance and wrap equal,
     * non-null native op handles. Operations whose handle is null or has been nulled out are equal
     * only to themselves.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof EagerOperation)) {
            return false;
        }
        EagerOperation other = (EagerOperation) obj;
        if (session != other.session) {
            return false;
        }
        if (opHandle == null || other.opHandle == null) {
            return false;
        }
        if (opHandle.isNull() || other.opHandle.isNull()) {
            return false;
        }
        return opHandle.equals(other.opHandle);
    }

    /**
     * Returns the shape of output {@code outputIndex}.
     *
     * <p>Uses the cached tensor when it has already been resolved; otherwise queries the native
     * handle dimension by dimension without resolving the tensor.
     *
     * @throws IllegalStateException if the output handle is null (session closed)
     */
    @Override
    public Shape shape(int outputIndex) {
        Tensor cached = outputTensors.get(outputIndex);
        if (cached != null) {
            return cached.shape();
        }
        TFE_TensorHandle handle = getUnsafeNativeHandle(outputIndex);
        long[] dims = new long[numDims(handle)];
        for (int d = 0; d < dims.length; d++) {
            dims[d] = dim(handle, d);
        }
        return Shape.of(dims);
    }

    /**
     * Returns the data type of output {@code outputIndex}, from the cached tensor if resolved or
     * from the native handle otherwise.
     *
     * @throws IllegalStateException if the output handle is null (session closed)
     */
    @Override
    public DataType dtype(int outputIndex) {
        Tensor cached = outputTensors.get(outputIndex);
        if (cached != null) {
            return cached.dataType();
        }
        return DataType.forNumber(dataType(getUnsafeNativeHandle(outputIndex)));
    }

    /**
     * Returns the tensor for output {@code outputIndex}, resolving and caching it on first access.
     *
     * @throws IllegalStateException if the output handle is null (session closed)
     */
    @Override
    public Tensor tensor(int outputIndex) {
        Tensor cached = outputTensors.get(outputIndex);
        return cached != null ? cached : resolveTensor(outputIndex);
    }

    /**
     * Resolves output {@code outputIndex} optimistically without locking.
     *
     * <p>If another thread publishes a tensor for the same slot first, this thread's copy is
     * detached from the session and the published one is returned instead.
     */
    private Tensor resolveTensor(int outputIndex) {
        Tensor resolved = resolveTensorHandle(getUnsafeNativeHandle(outputIndex), session);
        if (outputTensors.compareAndSet(outputIndex, null, resolved)) {
            return resolved;
        }
        // NOTE(review): the losing copy is only detached from the session here, not explicitly
        // closed; its native memory is presumably reclaimed by its deallocator — TODO confirm.
        session.detach(resolved.asRawTensor().nativeHandle());
        return outputTensors.get(outputIndex);
    }

    /** Throws if {@code op} is null or points to a null native address. */
    private static void requireOp(TFE_Op op) {
        if (op == null || op.isNull()) {
            throw new IllegalStateException("Eager session has been closed");
        }
    }

    /** Throws if {@code handle} is null or points to a null native address. */
    private static void requireTensorHandle(TFE_TensorHandle handle) {
        if (handle == null || handle.isNull()) {
            throw new IllegalStateException("Eager session has been closed");
        }
    }

    /**
     * Resolves a native tensor handle into a typed tensor owned by {@code session}.
     *
     * <p>The temporary {@code TF_Status} is created inside a {@link PointerScope}, which
     * try-with-resources closes exactly once, even if closing fails.
     */
    private static Tensor resolveTensorHandle(TFE_TensorHandle handle, EagerSession session) {
        requireTensorHandle(handle);
        try (PointerScope scope = new PointerScope()) {
            TF_Status status = TF_Status.newStatus();
            TF_Tensor nativeTensor = tensorflow.TFE_TensorHandleResolve(handle, status);
            status.throwExceptionIfNotOK();
            TType typed = RawTensor.fromHandle(nativeTensor.withDeallocator(), session).asTypedTensor();
            return typed;
        }
    }

    /** Queries {@code TFE_OpGetOutputLength} for {@code argName}, throwing on a non-OK status. */
    private static int outputListLength(TFE_Op op, String argName) {
        requireOp(op);
        try (PointerScope scope = new PointerScope()) {
            TF_Status status = TF_Status.newStatus();
            int length = tensorflow.TFE_OpGetOutputLength(op, argName, status);
            status.throwExceptionIfNotOK();
            return length;
        }
    }

    /** Queries {@code TFE_OpGetInputLength} for {@code argName}, throwing on a non-OK status. */
    private static int inputListLength(TFE_Op op, String argName) {
        requireOp(op);
        try (PointerScope scope = new PointerScope()) {
            TF_Status status = TF_Status.newStatus();
            int length = tensorflow.TFE_OpGetInputLength(op, argName, status);
            status.throwExceptionIfNotOK();
            return length;
        }
    }

    /** Returns the numeric data type code of {@code handle} via {@code TFE_TensorHandleDataType}. */
    private static int dataType(TFE_TensorHandle handle) {
        requireTensorHandle(handle);
        return tensorflow.TFE_TensorHandleDataType(handle);
    }

    /** Returns the rank of {@code handle} via {@code TFE_TensorHandleNumDims}. */
    private static int numDims(TFE_TensorHandle handle) {
        requireTensorHandle(handle);
        try (PointerScope scope = new PointerScope()) {
            TF_Status status = TF_Status.newStatus();
            int rank = tensorflow.TFE_TensorHandleNumDims(handle, status);
            status.throwExceptionIfNotOK();
            return rank;
        }
    }

    /** Returns the size of dimension {@code dimIndex} of {@code handle} via {@code TFE_TensorHandleDim}. */
    private static long dim(TFE_TensorHandle handle, int dimIndex) {
        requireTensorHandle(handle);
        try (PointerScope scope = new PointerScope()) {
            TF_Status status = TF_Status.newStatus();
            long size = tensorflow.TFE_TensorHandleDim(handle, dimIndex, status);
            status.throwExceptionIfNotOK();
            return size;
        }
    }
}
