package org.apache.sysds.runtime.controlprogram.paramserv;

import java.util.concurrent.Callable;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysds.runtime.instructions.cp.ListObject;
import org.apache.sysds.utils.Statistics;

/**
 * A parameter-server worker that runs inside the local process.
 *
 * <p>The worker iterates over its {@code _features}/{@code _labels} in mini-batches of
 * {@code _batchSize} rows for {@code _epochs} epochs, exchanging model parameters and
 * gradients with the {@link ParamServer}. Two update frequencies are supported:
 * <ul>
 *   <li>{@code BATCH}: pull the model, compute gradients and push them once per batch;</li>
 *   <li>{@code EPOCH}: pull the model once per epoch, update a local copy of the model after
 *       every batch except the last, and push the accumulated gradients once per epoch.</li>
 * </ul>
 */
public class LocalPSWorker extends PSWorker implements Callable<Void> {
    protected static final Log LOG = LogFactory.getLog(LocalPSWorker.class.getName());
    private static final long serialVersionUID = 5195390748495357295L;

    /**
     * No-arg constructor; presumably needed for serialization since the class declares a
     * {@code serialVersionUID} — TODO confirm.
     */
    public LocalPSWorker() {
    }

    /**
     * Creates a local worker; all arguments are forwarded unchanged to {@link PSWorker}.
     *
     * @param workerID  id of this worker, used for naming and when talking to the server
     * @param updFunc   name of the update function (forwarded to the superclass)
     * @param freq      update frequency ({@code BATCH} or {@code EPOCH})
     * @param epochs    number of epochs to run
     * @param batchSize number of rows per mini-batch
     * @param ec        execution context in which gradients are computed
     * @param ps        parameter server to pull models from and push gradients to
     */
    public LocalPSWorker(int workerID, String updFunc, Statement.PSFrequency freq, int epochs,
            long batchSize, ExecutionContext ec, ParamServer ps) {
        super(workerID, updFunc, freq, epochs, batchSize, ec, ps);
    }

    @Override
    public String getWorkerName() {
        return String.format("Local worker_%d", _workerID);
    }

    /**
     * Runs this worker to completion using the configured update frequency.
     *
     * @return always {@code null}
     * @throws DMLRuntimeException wrapping any failure, including a non-positive batch size or
     *     an unsupported update frequency
     */
    @Override
    public Void call() throws Exception {
        incWorkerNumber();
        try {
            if (_batchSize <= 0) {
                throw new DMLRuntimeException(String.format(
                    "%s: invalid batch size %d", getWorkerName(), _batchSize));
            }
            long dataSize = _features.getNumRows();
            // Divide in floating point so a trailing partial batch is counted; integer
            // division would truncate before ceil() and silently skip those rows.
            int totalIter = (int) Math.ceil((double) dataSize / _batchSize);
            switch (_freq) {
                case BATCH:
                    computeBatch(dataSize, totalIter);
                    break;
                case EPOCH:
                    computeEpoch(dataSize, totalIter);
                    break;
                default:
                    throw new DMLRuntimeException(String.format(
                        "%s not support update frequency %s", getWorkerName(), _freq));
            }
            if (LOG.isDebugEnabled()) {
                LOG.debug(String.format("%s: job finished.", getWorkerName()));
            }
            return null;
        } catch (Exception e) {
            throw new DMLRuntimeException(String.format("%s failed", getWorkerName()), e);
        }
    }

    /**
     * Epoch-frequency training: one pull and one push per epoch, with local model updates
     * between batches.
     *
     * @param dataSize  total number of rows in the features
     * @param totalIter number of batches per epoch
     */
    private void computeEpoch(long dataSize, int totalIter) {
        for (int epoch = 0; epoch < _epochs; epoch++) {
            ListObject params = pullModel();
            ListObject accGradients = null;
            for (int iter = 0; iter < totalIter; iter++) {
                ListObject gradients = computeGradients(params, dataSize, totalIter, epoch, iter);
                // Every batch but the last updates the local model; the last batch's
                // gradients are only accumulated and then pushed to the server.
                boolean localUpdate = iter < totalIter - 1;
                // NOTE(review): the third argument's exact meaning is defined by
                // ParamservUtils.accrueGradients; it is true only for the last batch.
                accGradients = ParamservUtils.accrueGradients(accGradients, gradients, !localUpdate);
                if (localUpdate) {
                    params = updateModel(params, gradients, epoch, iter, totalIter);
                }
                accNumBatches(1);
            }
            pushGradients(accGradients);
            ParamservUtils.cleanupListObject(_ec, Statement.PS_MODEL);
            accNumEpochs(1);
            if (LOG.isDebugEnabled()) {
                LOG.debug(String.format("%s: finished %d epoch.", getWorkerName(), epoch + 1));
            }
        }
    }

    /**
     * Applies the given gradients to the local copy of the model via the parameter server's
     * local-update function, recording the elapsed time when statistics are enabled.
     *
     * @param params    current local model
     * @param gradients gradients of the current batch
     * @param epoch     zero-based epoch index (used for logging)
     * @param iter      zero-based batch index (used for logging)
     * @param totalIter number of batches per epoch (used for logging)
     * @return the updated local model
     */
    private ListObject updateModel(ListObject params, ListObject gradients, int epoch, int iter, int totalIter) {
        Timing tUpd = DMLScript.STATISTICS ? new Timing(true) : null;
        ListObject updated = _ps.updateLocalModel(_ec, gradients, params);
        accLocalModelUpdateTime(tUpd);
        if (LOG.isDebugEnabled()) {
            // Divide by 1024 so the value matches the "kb" unit, as in pullModel/pushGradients.
            LOG.debug(String.format("%s: local global parameter [size:%d kb] updated. "
                + "[Epoch:%d  Total epoch:%d  Iteration:%d  Total iteration:%d]",
                getWorkerName(), updated.getDataSize() / 1024, epoch + 1, _epochs, iter + 1, totalIter));
        }
        return updated;
    }

    /**
     * Batch-frequency training: pull the model, compute gradients and push them for every batch.
     *
     * @param dataSize  total number of rows in the features
     * @param totalIter number of batches per epoch
     */
    private void computeBatch(long dataSize, int totalIter) {
        for (int epoch = 0; epoch < _epochs; epoch++) {
            for (int iter = 0; iter < totalIter; iter++) {
                ListObject params = pullModel();
                pushGradients(computeGradients(params, dataSize, totalIter, epoch, iter));
                ParamservUtils.cleanupListObject(_ec, Statement.PS_MODEL);
                accNumBatches(1);
            }
            accNumEpochs(1);
            if (LOG.isDebugEnabled()) {
                LOG.debug(String.format("%s: finished %d epoch.", getWorkerName(), epoch + 1));
            }
        }
    }

    /** Pulls the current global model for this worker from the parameter server. */
    private ListObject pullModel() {
        ListObject params = _ps.pull(_workerID);
        if (LOG.isDebugEnabled()) {
            LOG.debug(String.format("%s: successfully pull the global parameters [size:%d kb] from ps.",
                getWorkerName(), params.getDataSize() / 1024));
        }
        return params;
    }

    /** Pushes this worker's gradients to the parameter server. */
    private void pushGradients(ListObject gradients) {
        _ps.push(_workerID, gradients);
        if (LOG.isDebugEnabled()) {
            LOG.debug(String.format("%s: successfully push the gradients [size:%d kb] to ps.",
                getWorkerName(), gradients.getDataSize() / 1024));
        }
    }

    /**
     * Computes gradients for one mini-batch.
     *
     * <p>Binds the model and the sliced batch (1-based, inclusive row range) into the execution
     * context, runs the gradient instruction, reads the output list, and then removes the
     * batch features/labels from the context.
     *
     * @param params    model to bind as {@code Statement.PS_MODEL}
     * @param dataSize  total number of rows; caps the end of the last batch
     * @param totalIter number of batches per epoch (used for logging)
     * @param epoch     zero-based epoch index (used for logging)
     * @param iter      zero-based batch index; determines the row range
     * @return the gradients produced by the instruction
     */
    private ListObject computeGradients(ListObject params, long dataSize, int totalIter, int epoch, int iter) {
        _ec.setVariable(Statement.PS_MODEL, params);
        long begin = iter * _batchSize + 1;
        long end = Math.min((iter + 1) * _batchSize, dataSize);

        Timing tSlic = DMLScript.STATISTICS ? new Timing(true) : null;
        MatrixObject bFeatures = ParamservUtils.sliceMatrix(_features, begin, end);
        MatrixObject bLabels = ParamservUtils.sliceMatrix(_labels, begin, end);
        accBatchIndexingTime(tSlic);

        _ec.setVariable(Statement.PS_FEATURES, bFeatures);
        _ec.setVariable(Statement.PS_LABELS, bLabels);
        if (LOG.isDebugEnabled()) {
            LOG.debug(String.format("%s: got batch data [size:%d kb] of index from %d to %d [last index: %d]. "
                + "[Epoch:%d  Total epoch:%d  Iteration:%d  Total iteration:%d]", getWorkerName(),
                bFeatures.getDataSize() / 1024 + bLabels.getDataSize() / 1024, begin, end, dataSize,
                epoch + 1, _epochs, iter + 1, totalIter));
        }

        Timing tGrad = DMLScript.STATISTICS ? new Timing(true) : null;
        _inst.processInstruction(_ec);
        accGradientComputeTime(tGrad);

        ListObject gradients = _ec.getListObject(_output.getName());
        ParamservUtils.cleanupData(_ec, Statement.PS_FEATURES);
        ParamservUtils.cleanupData(_ec, Statement.PS_LABELS);
        return gradients;
    }

    @Override
    protected void incWorkerNumber() {
        if (DMLScript.STATISTICS) {
            Statistics.incWorkerNumber();
        }
    }

    @Override
    protected void accLocalModelUpdateTime(Timing time) {
        if (DMLScript.STATISTICS) {
            Statistics.accPSLocalModelUpdateTime((long) time.stop());
        }
    }

    @Override
    protected void accBatchIndexingTime(Timing time) {
        if (DMLScript.STATISTICS) {
            Statistics.accPSBatchIndexingTime((long) time.stop());
        }
    }

    @Override
    protected void accGradientComputeTime(Timing time) {
        if (DMLScript.STATISTICS) {
            Statistics.accPSGradientComputeTime((long) time.stop());
        }
    }
}
