/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.parameterserver.updater;

import java.util.HashMap;
import java.util.Map;
import org.nd4j.aeron.ipc.NDArrayHolder;
import org.nd4j.aeron.ipc.NDArrayMessage;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.parameterserver.updater.BaseParameterUpdater;
import org.nd4j.parameterserver.updater.storage.UpdateStorage;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.databind.ObjectMapper;

/**
 * Parameter updater that works synchronously: it reports itself as non-async and signals that
 * results should be replicated once a fixed number of worker updates has been accumulated.
 *
 * <p>Every incoming update is recorded in the update storage and then added in place to the
 * array obtained from the ndarray holder. The addition covers either the whole array or a single
 * tensor slice along the requested dimensions.
 */
public class SynchronousParameterUpdater extends BaseParameterUpdater {
    /** Serializer for {@link #toJson()}; reused across calls rather than created per call. */
    private static final ObjectMapper objectMapper = new ObjectMapper();

    /** Number of worker updates that make up one synchronous pass. */
    private final int workers;

    /**
     * @param updateStorage storage that records every received update
     * @param ndArrayHolder holder of the array that updates are accumulated into
     * @param workers number of updates required per pass
     */
    public SynchronousParameterUpdater(UpdateStorage updateStorage, NDArrayHolder ndArrayHolder, int workers) {
        super(updateStorage, ndArrayHolder);
        this.workers = workers;
    }

    /**
     * @param updateStorage storage that records every received update
     * @param workers number of updates required per pass
     */
    public SynchronousParameterUpdater(UpdateStorage updateStorage, int workers) {
        super(updateStorage);
        this.workers = workers;
    }

    /**
     * @param workers number of updates required per pass
     */
    public SynchronousParameterUpdater(int workers) {
        this.workers = workers;
    }

    /**
     * @return the configured number of workers, i.e. updates required for one pass
     */
    @Override
    public int requiredUpdatesForPass() {
        return this.workers;
    }

    /**
     * @return always {@code false}; this updater is synchronous
     */
    @Override
    public boolean isAsync() {
        return false;
    }

    /**
     * Reports the updater's state.
     *
     * @return a map with keys {@code "workers"} (configured worker count) and
     *     {@code "accumulatedUpdates"} (current {@code numUpdates()})
     */
    @Override
    public Map<String, Number> status() {
        Map<String, Number> ret = new HashMap<>();
        ret.put("workers", this.workers);
        ret.put("accumulatedUpdates", this.numUpdates());
        return ret;
    }

    /**
     * Serializes {@link #status()} to a JSON string.
     *
     * @return the JSON representation of the status map
     * @throws RuntimeException wrapping any {@link JsonProcessingException} from serialization
     */
    @Override
    public String toJson() {
        try {
            return objectMapper.writeValueAsString(this.status());
        } catch (JsonProcessingException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * @return {@code true} once at least {@code workers} updates have been accumulated
     */
    @Override
    public boolean shouldReplicate() {
        // ">=" rather than "==": if the count ever overshoots (e.g. updates arriving together),
        // an exact-equality check would never fire again.
        return this.numUpdates() >= this.workers;
    }

    /**
     * Records the message in the update storage, then applies its array to the held parameters.
     *
     * <p>A dimensions array of exactly {@code {-1}} marks a whole-array update. Any other
     * dimensions select the tensor at {@code message.getIndex()} along those dimensions.
     *
     * @param message the incoming update
     */
    @Override
    public void update(NDArrayMessage message) {
        this.updateStorage.addUpdate(message);
        INDArray arr = message.getArr();
        int[] dimensions = message.getDimensions();
        boolean whole = dimensions.length == 1 && dimensions[0] == -1;
        if (whole) {
            this.update(arr, this.ndArrayHolder.get());
        } else {
            this.partialUpdate(arr, this.ndArrayHolder.get(), message.getIndex(), dimensions);
        }
    }

    /**
     * Adds {@code arr} in place to the tensor of {@code result} at index {@code idx} along
     * {@code dimensions}.
     *
     * @throws ArithmeticException if {@code idx} does not fit in an {@code int}
     */
    @Override
    public void partialUpdate(INDArray arr, INDArray result, long idx, int... dimensions) {
        // toIntExact fails loudly instead of silently wrapping to a wrong slice index.
        result.tensorAlongDimension(Math.toIntExact(idx), dimensions).addi(arr);
    }

    /**
     * Adds {@code arr} in place to the whole of {@code result}.
     */
    @Override
    public void update(INDArray arr, INDArray result) {
        result.addi(arr);
    }
}

