package org.apache.sysds.runtime.controlprogram.paramserv;

import java.io.IOException;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.util.LongAccumulator;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.paramserv.rpc.PSRpcCall;
import org.apache.sysds.runtime.controlprogram.paramserv.rpc.PSRpcResponse;
import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysds.runtime.instructions.cp.ListObject;

/**
 * Client-side stand-in for a parameter server that is reached over a Spark network
 * {@link TransportClient}. Each {@link #push} / {@link #pull} call is forwarded as a
 * synchronous RPC and the remote response is translated back into a local result or a
 * {@link DMLRuntimeException}.
 *
 * <p>When {@code DMLScript.STATISTICS} is enabled, the elapsed time of every RPC round trip
 * is added to the supplied {@link LongAccumulator}.
 */
public class SparkPSProxy extends ParamServer {
    // RPC request-type codes passed to PSRpcCall. NOTE(review): these presumably mirror the
    // request-type constants decoded on the server side (PSRpcObject) — confirm and reference
    // those directly if they are accessible here.
    private static final int RPC_PUSH = 1;
    private static final int RPC_PULL = 2;

    private final TransportClient _client;
    private final long _rpcTimeout;
    private final LongAccumulator _aRPC;

    /**
     * @param client      transport used to reach the remote parameter server
     * @param rpcTimeout  timeout handed to {@code TransportClient.sendRpcSync} for each request
     * @param aRPC        accumulator receiving per-request RPC time when statistics are enabled
     */
    public SparkPSProxy(TransportClient client, long rpcTimeout, LongAccumulator aRPC) {
        this._client = client;
        this._rpcTimeout = rpcTimeout;
        this._aRPC = aRPC;
    }

    /**
     * Adds the elapsed time of {@code timing} to the RPC accumulator.
     *
     * @param timing a started timer, or {@code null} if statistics were disabled when the
     *               request began (in which case nothing is recorded)
     */
    private void accRpcRequestTime(Timing timing) {
        // Check the timer itself rather than re-reading the global STATISTICS flag: the flag
        // could change between starting and stopping, which would otherwise dereference null.
        if (timing != null) {
            this._aRPC.add((long) timing.stop());
        }
    }

    /**
     * Serializes and sends one request synchronously, recording its round-trip time.
     *
     * @param type     RPC request-type code ({@link #RPC_PUSH} or {@link #RPC_PULL})
     * @param workerID id of the worker issuing the request
     * @param data     payload to send, or {@code null} when the request carries none
     * @return the decoded server response
     * @throws IOException if serializing the request or decoding the response fails
     */
    private PSRpcResponse sendRequest(int type, int workerID, ListObject data) throws IOException {
        Timing timing = DMLScript.STATISTICS ? new Timing(true) : null;
        PSRpcResponse response = new PSRpcResponse(
            this._client.sendRpcSync(new PSRpcCall(type, workerID, data).serialize(), this._rpcTimeout));
        accRpcRequestTime(timing);
        return response;
    }

    /**
     * Sends {@code gradients} from worker {@code workerID} to the remote parameter server.
     *
     * @throws DMLRuntimeException if the server reports failure or the RPC cannot be encoded/decoded
     */
    @Override
    public void push(int workerID, ListObject gradients) {
        try {
            PSRpcResponse response = sendRequest(RPC_PUSH, workerID, gradients);
            if (!response.isSuccessful()) {
                throw new DMLRuntimeException(String.format(
                    "SparkPSProxy: spark worker_%d failed to push gradients. \n%s",
                    workerID, response.getErrorMessage()));
            }
        } catch (IOException e) {
            throw new DMLRuntimeException(String.format(
                "SparkPSProxy: spark worker_%d failed to push gradients.", workerID), e);
        }
    }

    /**
     * Fetches the current model from the remote parameter server on behalf of worker
     * {@code workerID}.
     *
     * @return the model carried in the server's response
     * @throws DMLRuntimeException if the server reports failure or the RPC cannot be encoded/decoded
     */
    @Override
    public ListObject pull(int workerID) {
        try {
            PSRpcResponse response = sendRequest(RPC_PULL, workerID, null);
            if (!response.isSuccessful()) {
                throw new DMLRuntimeException(String.format(
                    "SparkPSProxy: spark worker_%d failed to pull models. \n%s",
                    workerID, response.getErrorMessage()));
            }
            return response.getResultModel();
        } catch (IOException e) {
            throw new DMLRuntimeException(String.format(
                "SparkPSProxy: spark worker_%d failed to pull models.", workerID), e);
        }
    }
}
