package ai.djl.mxnet.engine;

import ai.djl.mxnet.jna.JnaUtils;
import ai.djl.mxnet.jna.MxnetLibrary;
import ai.djl.mxnet.jna.NativeResource;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.training.ParameterServer;
import ai.djl.training.optimizer.Optimizer;
import com.sun.jna.Pointer;
import java.util.Arrays;

/* loaded from: input_file:ai/djl/mxnet/engine/MxParameterServer.class */
/**
 * A {@link ParameterServer} backed by a native MXNet KVStore.
 *
 * <p>Values are registered with {@link #init}, sent with {@link #push} and read back with
 * {@link #pull}. The optimizer given to the constructor is installed as the KVStore's updater
 * callback, so the native store calls back into it to apply updates.
 */
public class MxParameterServer extends NativeResource implements ParameterServer {

    /**
     * Strong reference to the updater registered with native code.
     *
     * <p>JNA does not keep callback objects reachable on its own. Without this field the callback
     * could be garbage collected while the native KVStore still holds a pointer to it, and the
     * next native invocation would call into freed memory. The field is intentionally never read.
     */
    private final OptimizerCallback callback;

    /**
     * Creates a native KVStore and installs {@code optimizer} as its updater.
     *
     * @param optimizer the optimizer the native store calls back into to apply updates
     */
    public MxParameterServer(Optimizer optimizer) {
        super(createKvStore());
        callback = new OptimizerCallback(optimizer);
        JnaUtils.parameterStoreSetUpdater(getHandle(), null, callback, null);
    }

    /**
     * Registers initial values in the store.
     *
     * @param parameterId key used for every array in {@code value}
     * @param value the arrays to register; one key entry is passed per array
     */
    public void init(String parameterId, NDArray[] value) {
        JnaUtils.parameterStoreInit(
                getHandle(), value.length, repeatKey(parameterId, value.length), new NDList(value));
    }

    /**
     * Pushes arrays to the store under a single key.
     *
     * @param parameterId key used for every array in {@code grads}
     * @param grads the arrays to push; one key entry is passed per array
     * @param priority priority forwarded unchanged to the native push call
     */
    public void push(String parameterId, NDArray[] grads, int priority) {
        JnaUtils.parameterStorePush(
                getHandle(),
                grads.length,
                repeatKey(parameterId, grads.length),
                new NDList(grads),
                priority);
    }

    /**
     * Pulls the stored value for a key into the given arrays.
     *
     * @param parameterId key used for every array in {@code weights}
     * @param weights destination arrays; one key entry is passed per array
     * @param priority priority forwarded unchanged to the native pull call
     */
    public void pull(String parameterId, NDArray[] weights, int priority) {
        JnaUtils.parameterStorePull(
                getHandle(),
                weights.length,
                repeatKey(parameterId, weights.length),
                new NDList(weights),
                priority);
    }

    /** Releases the native KVStore; calling this more than once is a no-op. */
    @Override
    public void close() {
        // getAndSet(null) ensures only the first caller sees the live handle.
        Pointer pointer = this.handle.getAndSet(null);
        if (pointer != null) {
            JnaUtils.parameterStoreClose(pointer);
        }
    }

    /** Returns an array of {@code count} copies of {@code key}, one per value sent natively. */
    private static String[] repeatKey(String key, int count) {
        String[] keys = new String[count];
        Arrays.fill(keys, key);
        return keys;
    }

    /** Creates the native KVStore; {@code "device"} is the MXNet KVStore type string used. */
    private static Pointer createKvStore() {
        return JnaUtils.parameterStoreCreate("device");
    }

    /** Adapts a DJL {@link Optimizer} to MXNet's string-keyed KVStore updater callback. */
    private static final class OptimizerCallback implements MxnetLibrary.MXKVStoreStrUpdater {

        private final Optimizer optimizer;

        OptimizerCallback(Optimizer optimizer) {
            this.optimizer = optimizer;
        }

        /**
         * Invoked from native code to apply an update for {@code parameterId}.
         *
         * <p>The two array handles are wrapped in a temporary sub-manager. They are then handed
         * to {@link Optimizer#update} as (second handle, first handle). That presumably means
         * (weight, gradient), following MXNet's {@code (key, recv, local, handle)} updater
         * convention; confirm against {@code Optimizer.update}.
         *
         * @param parameterId key of the value being updated
         * @param recv first native array handle (passed as the last argument to update)
         * @param local second native array handle (passed as the middle argument to update)
         * @param handle opaque user data pointer; unused
         */
        @Override
        public void apply(String parameterId, Pointer recv, Pointer local, Pointer handle) {
            try (MxNDManager manager = MxNDManager.getSystemManager().mo7newSubManager()) {
                MxNDArray grad = manager.create(recv);
                MxNDArray weight = manager.create(local);
                // The handles come from native code; presumably setShouldFree(false) keeps the
                // sub-manager from releasing them when it closes.
                grad.setShouldFree(false);
                weight.setShouldFree(false);
                optimizer.update(parameterId, weight, grad);
            }
        }
    }
}
