package ai.djl.training;

import ai.djl.Device;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.nn.Parameter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/* loaded from: input_file:ai/djl/training/ParameterStore.class */
/**
 * Caches per-device copies of {@link Parameter} arrays and, when a {@link ParameterServer} is
 * configured, pushes their values to it for updates.
 *
 * <p>Each registered {@link Device} is assigned an integer index. The constructor registers the
 * manager's device at index 0; {@link #setParameterServer} replaces that mapping with the
 * positional indices of the supplied device array. For every parameter, the cached copies are
 * stored in device-index order.
 */
public class ParameterStore {

    private final NDManager manager;
    /** Parameter id to its cached per-device arrays. */
    private final Map<String, ParameterData> parameterMap = new ConcurrentHashMap<>();
    /** Registered device to its index in each {@link ParameterData} list. */
    private final Map<Device, Integer> deviceMap = new ConcurrentHashMap<>();
    private final boolean copy;
    private ParameterServer parameterServer;

    /** The cached copies of one {@link Parameter}'s array, held in device-index order. */
    private final class ParameterData {

        private final Parameter parameter;
        private final List<NDArray> list;

        private ParameterData(Parameter parameter) {
            this.parameter = parameter;
            this.list = Collections.synchronizedList(new ArrayList<>());
        }

        private boolean isEmpty() {
            return list.isEmpty();
        }

        private void add(NDArray array) {
            list.add(array);
        }

        private NDArray get(int index) {
            return list.get(index);
        }

        private NDArray[] toArray() {
            return list.toArray(new NDArray[0]);
        }

        private boolean requireGradient() {
            return parameter.requireGradient();
        }

        /**
         * Copies the device-index-0 cached value back into the parameter's own array when that
         * array lives on a device that is not registered with this store.
         */
        private void sync() {
            NDArray array = parameter.getArray();
            if (deviceMap.containsKey(array.getDevice())) {
                return;
            }
            // Entries are never removed, so once non-empty the list stays non-empty.
            if (list.isEmpty()) {
                return;
            }
            list.get(0).copyTo(array);
        }
    }

    /**
     * Creates a store whose only registered device is {@code manager}'s device (index 0).
     *
     * @param manager the manager that newly created copies are attached to
     * @param copy when {@code true}, a copy of a parameter's array is made even if it is already
     *     on the requested device (only applies when no parameter server is set)
     */
    public ParameterStore(NDManager manager, boolean copy) {
        this.manager = manager;
        this.copy = copy;
        deviceMap.put(manager.getDevice(), 0);
    }

    /**
     * Sets the parameter server and replaces the registered devices with {@code devices}; the
     * device at position {@code i} gets index {@code i}.
     *
     * <p>The devices are validated before any state is modified, so a rejected call leaves the
     * store unchanged.
     *
     * @param parameterServer the server used by {@link #updateAllParameters()}
     * @param devices the devices to register, in index order
     * @throws IllegalArgumentException if {@code devices} contains the same device twice
     */
    public void setParameterServer(ParameterServer parameterServer, Device[] devices) {
        Map<Device, Integer> indices = new ConcurrentHashMap<>();
        for (int i = 0; i < devices.length; i++) {
            if (indices.put(devices[i], i) != null) {
                throw new IllegalArgumentException("Duplicated devices are not allowed.");
            }
        }
        this.parameterServer = parameterServer;
        deviceMap.clear();
        deviceMap.putAll(indices);
    }

    /**
     * Asks the parameter server to update every cached parameter that requires a gradient,
     * passing its copies in device-index order.
     *
     * @throws IllegalStateException if an update is needed but no parameter server has been set
     */
    public void updateAllParameters() {
        ParameterServer server = parameterServer;
        for (Map.Entry<String, ParameterData> entry : parameterMap.entrySet()) {
            ParameterData data = entry.getValue();
            if (!data.requireGradient()) {
                continue;
            }
            if (server == null) {
                throw new IllegalStateException(
                        "No ParameterServer has been set; call setParameterServer first.");
            }
            server.update(entry.getKey(), data.toArray());
        }
    }

    /**
     * Returns the copy of {@code parameter}'s array for {@code device}, creating and caching the
     * copies on first request.
     *
     * @param parameter the parameter to look up; {@code null} yields {@code null}
     * @param device a device registered with this store
     * @param training when {@code true} (and the parameter requires a gradient), a gradient is
     *     attached to a newly created local copy; not consulted when a parameter server is set
     * @return the cached array for {@code device}, or {@code null} if {@code parameter} is null
     * @throws IllegalArgumentException if {@code device} is not registered with this store
     */
    public NDArray getValue(Parameter parameter, Device device, boolean training) {
        if (parameter == null) {
            return null;
        }
        Integer registered = deviceMap.get(device);
        if (registered == null) {
            throw new IllegalArgumentException(
                    "Device " + device + " is not registered with this ParameterStore.");
        }
        int index = registered;
        String id = parameter.getId();
        ParameterData data = parameterMap.computeIfAbsent(id, key -> new ParameterData(parameter));
        // Check-then-populate must be atomic per parameter; otherwise concurrent first requests
        // could each append a full set of copies.
        synchronized (data) {
            if (data.isEmpty()) {
                NDArray array = parameter.getArray();
                ParameterServer server = parameterServer;
                if (server != null) {
                    populateForServer(data, server, id, parameter, array, index);
                } else {
                    data.add(localCopy(parameter, array, device, training));
                }
            }
        }
        return data.get(index);
    }

    /** Registers {@code array} with the server and caches one copy per registered device. */
    private void populateForServer(
            ParameterData data,
            ParameterServer server,
            String id,
            Parameter parameter,
            NDArray array,
            int index) {
        server.init(id, new NDArray[] {array});
        NDArray[] replicas = new NDArray[deviceMap.size()];
        for (Map.Entry<Device, Integer> entry : deviceMap.entrySet()) {
            Device target = entry.getKey();
            int slot = entry.getValue();
            if (slot == index && array.getDevice().equals(target)) {
                replicas[slot] = array;
            } else {
                replicas[slot] = array.toDevice(target, true);
                replicas[slot].attach(manager);
                if (parameter.requireGradient()) {
                    replicas[slot].attachGradient();
                }
            }
        }
        // Append in device-index order: ConcurrentHashMap iteration order is unspecified, but
        // copies are later looked up by index.
        for (NDArray replica : replicas) {
            data.add(replica);
        }
    }

    /** Returns {@code array} itself, or a copy on {@code device} when one is needed. */
    private NDArray localCopy(Parameter parameter, NDArray array, Device device, boolean training) {
        if (!copy && array.getDevice().equals(device)) {
            return array;
        }
        NDArray local = array.toDevice(device, true);
        local.attach(manager);
        if (parameter.requireGradient() && training) {
            local.attachGradient();
        }
        return local;
    }

    /**
     * Writes cached values back into each parameter's own array when that array lives on a
     * device not registered with this store.
     */
    public void sync() {
        for (ParameterData data : parameterMap.values()) {
            data.sync();
        }
    }
}
