package ai.djl.training;

import ai.djl.Device;
import ai.djl.ndarray.NDArray;
import ai.djl.training.optimizer.Optimizer;
import java.util.Arrays;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * An in-process {@link ParameterServer} that sums the gradients pushed for a parameter (one array
 * per device) onto the device of the first gradient. It then applies the summed gradient to every
 * weight copy through the supplied {@link Optimizer}.
 *
 * <p>Ownership: gradient arrays handed to {@link #push} are owned by this server from then on.
 * They are closed when replaced by a later push, after they are consumed by {@link #pull}, or when
 * the server is {@link #close() closed}.
 */
public class LocalParameterServer implements ParameterServer {

    private final Optimizer optimizer;

    /** Pending (pushed but not yet pulled) gradients, keyed by parameter id. */
    private final Map<String, NDArray[]> gradMap = new ConcurrentHashMap<>();

    /**
     * Creates a local parameter server.
     *
     * @param optimizer the optimizer used to update weights with the aggregated gradient
     */
    public LocalParameterServer(Optimizer optimizer) {
        this.optimizer = optimizer;
    }

    /**
     * No-op: a local server keeps no separate copy of the parameter values.
     *
     * @param parameterId the parameter id
     * @param value the initial parameter values (ignored)
     */
    @Override // ai.djl.training.ParameterServer
    public void init(String parameterId, NDArray[] value) {}

    /**
     * Stores the gradients for {@code parameterId}, replacing and closing any gradients that were
     * pushed earlier but never pulled.
     *
     * @param parameterId the parameter id
     * @param grads the gradients, one array per device; ownership passes to this server
     * @param priority unused by this implementation
     */
    @Override // ai.djl.training.ParameterServer
    public void push(String parameterId, NDArray[] grads, int priority) {
        NDArray[] previous = gradMap.put(parameterId, grads);
        if (previous != null) {
            closeAll(previous);
        }
    }

    /**
     * Sums the pushed gradients onto the first gradient's device, updates every array in {@code
     * weights} with that sum via the optimizer, and then closes the consumed gradients.
     *
     * @param parameterId the parameter id
     * @param weights the weight arrays to update, possibly on different devices
     * @param priority unused by this implementation
     * @throws IllegalStateException if no (or an empty set of) gradients was pushed for {@code
     *     parameterId}
     */
    @Override // ai.djl.training.ParameterServer
    public void pull(String parameterId, NDArray[] weights, int priority) {
        // Remove atomically: consumed gradients are closed below and must not stay in the map,
        // otherwise a later push would close them a second time.
        NDArray[] grads = gradMap.remove(parameterId);
        if (grads == null || grads.length == 0) {
            throw new IllegalStateException("No gradients pushed for parameter: " + parameterId);
        }
        try {
            NDArray gradSum = grads[0];
            Device firstDevice = gradSum.getDevice();
            // Reduce every other device's gradient into grads[0] (in place).
            for (int i = 1; i < grads.length; i++) {
                try (NDArray gradCopy = grads[i].asInDevice(firstDevice, true)) {
                    gradSum.addi(gradCopy);
                }
            }
            // Apply the reduced gradient to each weight, copying it to the weight's device if needed.
            for (NDArray weight : weights) {
                if (weight.getDevice().equals(firstDevice)) {
                    optimizer.update(parameterId, weight, gradSum);
                } else {
                    try (NDArray gradSumCopy = gradSum.asInDevice(weight.getDevice(), true)) {
                        optimizer.update(parameterId, weight, gradSumCopy);
                    }
                }
            }
        } finally {
            // Release the consumed gradients even if a reduction or an update failed.
            closeAll(grads);
        }
    }

    /** Closes any gradients that were pushed but never pulled. */
    @Override // ai.djl.training.ParameterServer, java.lang.AutoCloseable
    public void close() {
        for (String parameterId : gradMap.keySet()) {
            NDArray[] pending = gradMap.remove(parameterId);
            if (pending != null) {
                closeAll(pending);
            }
        }
    }

    /** Closes every array in {@code arrays}. */
    private static void closeAll(NDArray[] arrays) {
        for (NDArray array : arrays) {
            array.close();
        }
    }
}
