package ai.djl.tensorflow.engine;

import ai.djl.Device;
import ai.djl.engine.EngineException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.types.DataType;
import ai.djl.util.Preconditions;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicBoolean;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.IntPointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerPointer;
import org.bytedeco.javacpp.PointerScope;
import org.tensorflow.internal.c_api.TFE_Context;
import org.tensorflow.internal.c_api.TFE_Op;
import org.tensorflow.internal.c_api.TFE_TensorHandle;
import org.tensorflow.internal.c_api.TF_Status;
import org.tensorflow.internal.c_api.global.tensorflow;

/**
 * Builds and runs a single TensorFlow eager operation ({@code TFE_Op}).
 *
 * <p>Usage is a fluent chain: create the executor for an operation name, add inputs, attributes
 * and optionally a device, then call {@link #build(int)}, {@link #buildSingletonOrThrow()} or
 * {@link #buildRawPointer(int)}. Executing always releases the op handle, so an executor runs at
 * most once; any later call throws {@link IllegalStateException} instead of handing an already
 * released native handle to TensorFlow.
 *
 * <p>Every method that allocates native temporaries (status objects, argument arrays, string
 * copies) does so inside a try-with-resources {@link PointerScope}, so those temporaries are
 * released on both the normal and the exceptional path.
 */
public final class TfOpExecutor implements AutoCloseable {

    private final TfNDManager manager;
    private final TFE_Op opHandle;
    /** Set once the op handle has been released; guards against double-free and reuse. */
    private final AtomicBoolean closed = new AtomicBoolean(false);

    /**
     * Creates a new eager op in the given context.
     *
     * @param manager the manager that will own the arrays produced by {@link #build(int)}
     * @param context the eager context the op is created in
     * @param operation the name of the TensorFlow operation to create
     */
    @SuppressWarnings({"unchecked", "try"})
    public TfOpExecutor(TfNDManager manager, TFE_Context context, String operation) {
        this.manager = manager;
        try (PointerScope ignore = new PointerScope()) {
            TF_Status status = TF_Status.newStatus();
            TFE_Op op = TFE_Op.newOp(context, operation, status);
            status.throwExceptionIfNotOK();
            // Keep the op alive after the scope closes; it is released in close().
            op.retainReference();
            this.opHandle = op;
        }
    }

    /**
     * Executes the op and wraps every output in a {@link TfNDArray} owned by this executor's
     * manager. The op handle is released afterwards.
     *
     * @param numOutputs the maximum number of outputs to allocate room for
     * @return the outputs actually produced by the op
     * @throws IllegalStateException if this executor was already closed
     */
    public NDArray[] build(int numOutputs) {
        TFE_TensorHandle[] handles = buildRawPointer(numOutputs);
        NDArray[] outputs = new NDArray[handles.length];
        for (int i = 0; i < handles.length; i++) {
            outputs[i] = new TfNDArray(manager, handles[i]);
        }
        return outputs;
    }

    /**
     * Executes the op, expecting exactly one output. The op handle is released afterwards.
     *
     * @return the single output
     * @throws IllegalArgumentException if the op did not produce exactly one output
     * @throws IllegalStateException if this executor was already closed
     */
    public NDArray buildSingletonOrThrow() {
        TFE_TensorHandle[] handles = buildRawPointer(1);
        try {
            Preconditions.checkArgument(
                    handles.length == 1,
                    "The expected size of outputs is 1 but got " + handles.length);
            return new TfNDArray(manager, handles[0]);
        } catch (IllegalArgumentException e) {
            // None of the outputs will be handed out, so free them before rethrowing.
            Arrays.stream(handles).forEach(Pointer::close);
            throw e;
        }
    }

    /**
     * Executes the op and returns the raw output tensor handles. The op handle is always released,
     * whether execution succeeds or fails.
     *
     * @param numOutputs the capacity of the output array passed to {@code TFE_Execute}; the number
     *     of outputs actually produced is read back from it
     * @return the output handles; each carries its own deallocator and an extra reference so it
     *     outlives the temporary scope
     * @throws IllegalStateException if this executor was already closed
     */
    @SuppressWarnings({"unchecked", "try"})
    public TFE_TensorHandle[] buildRawPointer(int numOutputs) {
        ensureOpen();
        try (PointerScope ignore = new PointerScope()) {
            // In: capacity of returnValues; out: number of outputs written by TFE_Execute.
            IntPointer numReturnValues = new IntPointer(1L).put(numOutputs);
            PointerPointer<TFE_TensorHandle> returnValues = new PointerPointer<>(numOutputs);
            TF_Status status = TF_Status.newStatus();
            tensorflow.TFE_Execute(opHandle, returnValues, numReturnValues, status);
            status.throwExceptionIfNotOK();

            TFE_TensorHandle[] results = new TFE_TensorHandle[numReturnValues.get()];
            for (int i = 0; i < results.length; i++) {
                // Attach a deallocator and retain, so the handle survives the scope closing.
                results[i] =
                        returnValues
                                .get(TFE_TensorHandle.class, i)
                                .withDeallocator()
                                .retainReference();
            }
            return results;
        } finally {
            close();
        }
    }

    /**
     * Appends a single input tensor to the op. If the native call fails, the op handle is
     * released before the exception propagates.
     *
     * @param input the input array; must be a {@link TfNDArray}
     * @return this executor
     * @throws IllegalStateException if this executor was already closed
     */
    @SuppressWarnings({"unchecked", "try"})
    public TfOpExecutor addInput(NDArray input) {
        ensureOpen();
        try (PointerScope ignore = new PointerScope()) {
            TF_Status status = TF_Status.newStatus();
            tensorflow.TFE_OpAddInput(
                    opHandle, (TFE_TensorHandle) ((TfNDArray) input).getHandle(), status);
            status.throwExceptionIfNotOK();
            return this;
        } catch (Exception e) {
            // Builder chains cannot reach a failed executor, so release it here (as setDevice does).
            close();
            throw e;
        }
    }

    /**
     * Appends a list-typed input to the op. If the native call fails, the op handle is released
     * before the exception propagates.
     *
     * @param inputs the input arrays; each must be a {@link TfNDArray}
     * @return this executor
     * @throws IllegalStateException if this executor was already closed
     */
    @SuppressWarnings({"unchecked", "try"})
    public TfOpExecutor addInputList(NDArray[] inputs) {
        ensureOpen();
        try (PointerScope ignore = new PointerScope()) {
            PointerPointer<TFE_TensorHandle> handles = new PointerPointer<>(inputs.length);
            for (int i = 0; i < inputs.length; i++) {
                handles.put(i, (TFE_TensorHandle) ((TfNDArray) inputs[i]).getHandle());
            }
            TF_Status status = TF_Status.newStatus();
            tensorflow.TFE_OpAddInputList(opHandle, handles, inputs.length, status);
            status.throwExceptionIfNotOK();
            return this;
        } catch (Exception e) {
            // Builder chains cannot reach a failed executor, so release it here (as setDevice does).
            close();
            throw e;
        }
    }

    /**
     * Places the op on the given device. On any failure, including an unsupported device type,
     * the op handle is released before the exception propagates.
     *
     * @param device the target device; only {@code "cpu"} and {@code "gpu"} types are supported
     * @return this executor
     * @throws EngineException if the device type is not supported
     * @throws IllegalStateException if this executor was already closed
     */
    @SuppressWarnings({"unchecked", "try"})
    public TfOpExecutor setDevice(Device device) {
        ensureOpen();
        try (PointerScope ignore = new PointerScope()) {
            String deviceName;
            if (device.getDeviceType().equals("cpu")) {
                deviceName = "/device:CPU:0";
            } else if (device.getDeviceType().equals("gpu")) {
                deviceName = "/device:GPU:" + device.getDeviceId();
            } else {
                throw new EngineException(
                        "Unknown device type to TensorFlow Engine: " + device.toString());
            }
            TF_Status status = TF_Status.newStatus();
            tensorflow.TFE_OpSetDevice(opHandle, deviceName, status);
            status.throwExceptionIfNotOK();
            return this;
        } catch (Exception e) {
            close();
            throw e;
        }
    }

    /**
     * Sets a string attribute, encoded as UTF-8.
     *
     * @param name the attribute name
     * @param value the attribute value
     * @return this executor
     * @throws IllegalStateException if this executor was already closed
     */
    @SuppressWarnings({"unchecked", "try"})
    public TfOpExecutor addParam(String name, String value) {
        ensureOpen();
        byte[] bytes = value.getBytes(StandardCharsets.UTF_8);
        try (PointerScope ignore = new PointerScope()) {
            // The native copy of the bytes is freed when the scope closes. This assumes
            // TensorFlow copies the value during the call (unchanged from before); confirm if in doubt.
            tensorflow.TFE_OpSetAttrString(opHandle, name, new BytePointer(bytes), bytes.length);
            return this;
        }
    }

    /**
     * Sets an integer attribute.
     *
     * @throws IllegalStateException if this executor was already closed
     */
    public TfOpExecutor addParam(String name, long value) {
        ensureOpen();
        tensorflow.TFE_OpSetAttrInt(opHandle, name, value);
        return this;
    }

    /**
     * Sets a float attribute.
     *
     * @throws IllegalStateException if this executor was already closed
     */
    public TfOpExecutor addParam(String name, float value) {
        ensureOpen();
        tensorflow.TFE_OpSetAttrFloat(opHandle, name, value);
        return this;
    }

    /**
     * Sets a boolean attribute; passed to TensorFlow as the byte {@code 1} or {@code 0}.
     *
     * @throws IllegalStateException if this executor was already closed
     */
    public TfOpExecutor addParam(String name, boolean value) {
        ensureOpen();
        tensorflow.TFE_OpSetAttrBool(opHandle, name, (byte) (value ? 1 : 0));
        return this;
    }

    /**
     * Sets a dtype attribute, converted with {@code TfDataType.toTf}.
     *
     * @throws IllegalStateException if this executor was already closed
     */
    public TfOpExecutor addParam(String name, DataType dataType) {
        ensureOpen();
        tensorflow.TFE_OpSetAttrType(opHandle, name, TfDataType.toTf(dataType));
        return this;
    }

    /**
     * Sets an integer-list attribute.
     *
     * @throws IllegalStateException if this executor was already closed
     */
    public TfOpExecutor addParam(String name, long[] values) {
        ensureOpen();
        tensorflow.TFE_OpSetAttrIntList(opHandle, name, values, values.length);
        return this;
    }

    /** Releases the op handle. Safe to call more than once; only the first call has an effect. */
    @Override
    public void close() {
        if (closed.getAndSet(true) || opHandle.isNull()) {
            return;
        }
        opHandle.close();
    }

    /** Rejects use of an executor whose op handle was already released. */
    private void ensureOpen() {
        if (closed.get()) {
            throw new IllegalStateException("TfOpExecutor is already closed");
        }
    }
}
