package org.apache.sysds.runtime.instructions.cp;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.lang3.concurrent.BasicThreadFactory;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.util.LongAccumulator;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.recompile.Recompiler;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.controlprogram.paramserv.FederatedPSControlThread;
import org.apache.sysds.runtime.controlprogram.paramserv.LocalPSWorker;
import org.apache.sysds.runtime.controlprogram.paramserv.LocalParamServer;
import org.apache.sysds.runtime.controlprogram.paramserv.ParamServer;
import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
import org.apache.sysds.runtime.controlprogram.paramserv.SparkPSBody;
import org.apache.sysds.runtime.controlprogram.paramserv.SparkPSWorker;
import org.apache.sysds.runtime.controlprogram.paramserv.dp.DataPartitionFederatedScheme;
import org.apache.sysds.runtime.controlprogram.paramserv.dp.DataPartitionLocalScheme;
import org.apache.sysds.runtime.controlprogram.paramserv.dp.FederatedDataPartitioner;
import org.apache.sysds.runtime.controlprogram.paramserv.dp.LocalDataPartitioner;
import org.apache.sysds.runtime.controlprogram.paramserv.rpc.PSRpcFactory;
import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.util.ProgramConverter;
import org.apache.sysds.utils.Statistics;

/* loaded from: input_file:org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.class */
/**
 * CP instruction implementing the {@code paramserv} builtin function, which trains a model
 * (a {@link ListObject}) with a parameter server.
 *
 * <p>{@link #processInstruction(ExecutionContext)} dispatches to one of three backends:
 * <ul>
 *   <li>federated, whenever the features or labels input is a federated matrix (regardless of
 *       the {@code mode} argument);</li>
 *   <li>{@code mode=LOCAL} (the default): workers run in a local fixed-size thread pool;</li>
 *   <li>{@code mode=REMOTE_SPARK}: workers run as Spark tasks and reach the driver-side
 *       parameter server through an RPC server.</li>
 * </ul>
 * In every mode the result of the parameter server is bound to the instruction's output
 * variable.
 */
public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruction {
    /** Mini-batch size used when the batch-size argument is absent. */
    public static final int DEFAULT_BATCH_SIZE = 64;
    /** Value used when the nbatches argument is absent. */
    public static final int DEFAULT_NBATCHES = 1;
    private static final Log LOG = LogFactory.getLog(ParamservBuiltinCPInstruction.class.getName());
    private static final Statement.PSFrequency DEFAULT_UPDATE_FREQUENCY = Statement.PSFrequency.EPOCH;
    private static final Statement.PSScheme DEFAULT_SCHEME = Statement.PSScheme.DISJOINT_CONTIGUOUS;
    private static final Statement.PSRuntimeBalancing DEFAULT_RUNTIME_BALANCING = Statement.PSRuntimeBalancing.NONE;
    private static final Statement.FederatedPSScheme DEFAULT_FEDERATED_SCHEME = Statement.FederatedPSScheme.KEEP_DATA_ON_WORKER;
    private static final Statement.PSModeType DEFAULT_MODE = Statement.PSModeType.LOCAL;
    private static final Statement.PSUpdateType DEFAULT_TYPE = Statement.PSUpdateType.ASP;
    private static final boolean DEFAULT_MODELAVG = false;
    /** Thread-name pattern of the worker pools used in local and federated mode. */
    private static final String WORKER_THREAD_NAME_PATTERN = "workers-pool-thread-%d";

    public ParamservBuiltinCPInstruction(Operator op, LinkedHashMap<String, String> paramsMap, CPOperand out, String opcode, String istr) {
        super(op, paramsMap, out, opcode, istr);
    }

    /**
     * Runs the parameter server in federated, local or Spark mode and binds its result to the
     * output variable.
     *
     * @param ec the execution context holding the inputs
     * @throws DMLRuntimeException if the requested mode is unsupported or training fails
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        // Federated inputs take precedence over the requested execution mode.
        if (ec.getMatrixObject(getParam(Statement.PS_FEATURES)).isFederated()
            || ec.getMatrixObject(getParam(Statement.PS_LABELS)).isFederated()) {
            runFederated(ec);
            return;
        }
        Statement.PSModeType mode = getPSMode();
        switch (mode) {
            case LOCAL:
                runLocally(ec, mode);
                break;
            case REMOTE_SPARK:
                runOnSpark((SparkExecutionContext) ec, mode);
                break;
            default:
                throw new DMLRuntimeException(String.format("Paramserv func: not support mode %s", mode));
        }
    }

    /** Trains with one control thread per federated worker produced by the federated partitioner. */
    private void runFederated(ExecutionContext ec) {
        if (DMLScript.STATISTICS) {
            Statistics.getPSExecutionTimer().start();
        }
        Timing tSetup = DMLScript.STATISTICS ? new Timing(true) : null;
        LOG.info("PARAMETER SERVER");
        LOG.info("[+] Running in federated mode");

        String updFunc = getParam(Statement.PS_UPDATE_FUN);
        String aggFunc = getParam(Statement.PS_AGGREGATION_FUN);
        Statement.PSUpdateType updateType = getUpdateType();
        Statement.PSFrequency freq = getFrequency();
        Statement.FederatedPSScheme federatedScheme = getFederatedScheme();
        Statement.PSRuntimeBalancing runtimeBalancing = getRuntimeBalancing();
        boolean weighting = getWeighting();
        int seed = getSeed();
        int nbatches = getNbatches();
        boolean modelAvg = getModelAvg();
        if (LOG.isInfoEnabled()) {
            LOG.info("[+] Update Type: " + updateType);
            LOG.info("[+] Frequency: " + freq);
            LOG.info("[+] Data Partitioning: " + federatedScheme);
            LOG.info("[+] Runtime Balancing: " + runtimeBalancing);
            LOG.info("[+] Weighting: " + weighting);
            LOG.info("[+] Seed: " + seed);
        }
        if (tSetup != null) {
            Statistics.accPSSetupTime((long) tSetup.stop());
        }

        // Partition the federated data; the partitioner decides the number of workers.
        Timing tPartition = DMLScript.STATISTICS ? new Timing(true) : null;
        DataPartitionFederatedScheme.Result result = new FederatedDataPartitioner(federatedScheme, seed)
            .doPartitioning(ec.getMatrixObject(getParam(Statement.PS_FEATURES)), ec.getMatrixObject(getParam(Statement.PS_LABELS)));
        int workerNum = result._workerNum;
        if (tPartition != null) {
            Statistics.accFedPSDataPartitioningTime((long) tPartition.stop());
        }
        // Every worker below reads entry i of each partition list.
        if (result._pFeatures.size() < workerNum || result._pLabels.size() < workerNum
            || result._weightingFactors.size() < workerNum) {
            throw new DMLRuntimeException("ParamservBuiltinCPInstruction: Federated data partitioning does not match threads!");
        }

        if (tSetup != null) {
            tSetup.start();
        }
        ExecutionContext newEC = ParamservUtils.createExecutionContext(ec, createVarsMap(ec), updFunc, aggFunc, -1, true);
        List<ExecutionContext> workerECs = ParamservUtils.copyExecutionContext(newEC, workerNum);
        ExecutionContext psEC = ParamservUtils.copyExecutionContext(newEC, 1).get(0);
        int numBatchesPerEpoch = getNumBatchesPerEpoch(runtimeBalancing, result._balanceMetrics);
        ParamServer ps = createPS(Statement.PSModeType.FEDERATED, aggFunc, updateType, freq, workerNum,
            ec.getListObject(getParam(Statement.PS_MODEL)), psEC, getValFunction(), numBatchesPerEpoch,
            getOptionalMatrix(ec, Statement.PS_VAL_FEATURES), getOptionalMatrix(ec, Statement.PS_VAL_LABELS),
            nbatches, modelAvg);

        int epochs = getEpochs();
        long batchSize = getBatchSize();
        List<FederatedPSControlThread> threads = IntStream.range(0, workerNum)
            .mapToObj(i -> new FederatedPSControlThread(i, updFunc, freq, runtimeBalancing, weighting, epochs,
                batchSize, numBatchesPerEpoch, workerECs.get(i), ps, nbatches, modelAvg))
            .collect(Collectors.toList());
        for (int i = 0; i < workerNum; i++) {
            FederatedPSControlThread thread = threads.get(i);
            thread.setFeatures(result._pFeatures.get(i));
            thread.setLabels(result._pLabels.get(i));
            thread.setup(result._weightingFactors.get(i).doubleValue());
        }
        if (tSetup != null) {
            Statistics.accPSSetupTime((long) tSetup.stop());
        }

        // Create the pool only now, so that setup failures above cannot leak it.
        ExecutorService pool = newWorkerPool(workerNum);
        try {
            for (Future<?> future : pool.invokeAll(threads)) {
                future.get(); // rethrows any worker failure as ExecutionException
            }
            ec.setVariable(output.getName(), ps.getResult());
            if (DMLScript.STATISTICS) {
                Statistics.accPSExecutionTime((long) Statistics.getPSExecutionTimer().stop());
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException("ParamservBuiltinCPInstruction: unknown error: ", e);
        } catch (ExecutionException e) {
            throw new DMLRuntimeException("ParamservBuiltinCPInstruction: unknown error: ", e);
        } finally {
            pool.shutdownNow();
        }
    }

    /**
     * Trains with remote Spark workers that reach a driver-side parameter server through an
     * RPC server. The RPC server is closed on every path once it has been created.
     */
    private void runOnSpark(SparkExecutionContext sec, Statement.PSModeType mode) {
        Timing tSetup = DMLScript.STATISTICS ? new Timing(true) : null;
        int workerNum = getWorkerNum(mode);
        String updFunc = getParam(Statement.PS_UPDATE_FUN);
        String aggFunc = getParam(Statement.PS_AGGREGATION_FUN);
        Statement.PSFrequency freq = getFrequency();
        int nbatches = getNbatches();
        boolean modelAvg = getModelAvg();

        // Driver-side execution context (parallelism level 1) and parameter server.
        ExecutionContext newEC = ParamservUtils.createExecutionContext(sec, createVarsMap(sec), updFunc, aggFunc, 1);
        ParamServer ps = createPS(mode, aggFunc, getUpdateType(), freq, workerNum,
            sec.getListObject(getParam(Statement.PS_MODEL)), ParamservUtils.copyExecutionContext(newEC, 1).get(0),
            nbatches, modelAvg);
        TransportServer server = PSRpcFactory.createServer(sec.getSparkContext().getConf(), (LocalParamServer) ps,
            sec.getSparkContext().getConf().get("spark.driver.host"));
        try {
            // Force CP execution types before the program is serialized for the remote workers.
            Recompiler.recompileProgramBlockHierarchy2Forced(newEC.getProgram().getProgramBlocks(), 0L, new HashSet<>(), Types.ExecType.CP);
            HashMap<String, byte[]> clsMap = new HashMap<>();
            String program = ProgramConverter.serializeSparkPSBody(new SparkPSBody(newEC), clsMap);

            // Accumulators that collect worker-side statistics.
            LongAccumulator aSetup = newAccumulator(sec, "setup");
            LongAccumulator aWorker = newAccumulator(sec, "workersNum");
            LongAccumulator aUpdate = newAccumulator(sec, "modelUpdate");
            LongAccumulator aIndex = newAccumulator(sec, "batchIndex");
            LongAccumulator aGrad = newAccumulator(sec, "gradCompute");
            LongAccumulator aRPC = newAccumulator(sec, "rpcRequest");
            LongAccumulator aBatches = newAccumulator(sec, "numBatches");
            LongAccumulator aEpochs = newAccumulator(sec, "numEpochs");
            SparkPSWorker worker = new SparkPSWorker(updFunc, aggFunc, freq, getEpochs(), getBatchSize(), program,
                clsMap, sec.getSparkContext().getConf(), server.getPort(), aSetup, aWorker, aUpdate, aIndex, aGrad,
                aRPC, aBatches, aEpochs, nbatches, modelAvg);
            if (tSetup != null) {
                Statistics.accPSSetupTime((long) tSetup.stop());
            }

            try {
                ParamservUtils.doPartitionOnSpark(sec, sec.getMatrixObject(getParam(Statement.PS_FEATURES)),
                    sec.getMatrixObject(getParam(Statement.PS_LABELS)), getScheme(), workerNum).foreach(worker);
                if (DMLScript.STATISTICS) {
                    Statistics.accPSSetupTime(aSetup.value().longValue());
                    Statistics.incWorkerNumber(aWorker.value().longValue());
                    Statistics.accPSLocalModelUpdateTime(aUpdate.value().longValue());
                    Statistics.accPSBatchIndexingTime(aIndex.value().longValue());
                    Statistics.accPSGradientComputeTime(aGrad.value().longValue());
                    Statistics.accPSRpcRequestTime(aRPC.value().longValue());
                }
                sec.setVariable(output.getName(), ps.getResult());
            } catch (Exception e) {
                throw new DMLRuntimeException("Paramserv function failed: ", e);
            }
        } finally {
            server.close();
        }
    }

    /** Registers a new named long accumulator on the Spark context. */
    private static LongAccumulator newAccumulator(SparkExecutionContext sec, String name) {
        return sec.getSparkContext().sc().longAccumulator(name);
    }

    /**
     * Trains with local worker threads. The data is partitioned first. The parameter server,
     * the worker execution contexts and the thread pool are then sized by the number of workers
     * that actually receive a partition.
     */
    private void runLocally(ExecutionContext ec, Statement.PSModeType mode) {
        if (DMLScript.STATISTICS) {
            Statistics.getPSExecutionTimer().start();
        }
        Timing tSetup = DMLScript.STATISTICS ? new Timing(true) : null;
        int workerNum = getWorkerNum(mode);
        String updFunc = getParam(Statement.PS_UPDATE_FUN);
        String aggFunc = getParam(Statement.PS_AGGREGATION_FUN);
        Statement.PSFrequency freq = getFrequency();
        Statement.PSUpdateType updateType = getUpdateType();
        Statement.PSScheme scheme = getScheme();
        int epochs = getEpochs();
        long batchSize = getBatchSize();
        int nbatches = getNbatches();
        boolean modelAvg = getModelAvg();

        // Partition before building the parameter server. If fewer partitions than requested
        // workers are produced, the surplus workers must not be started at all: they would have
        // no data, and the server would wait for contributions that never arrive.
        DataPartitionLocalScheme.Result parts = partitionLocally(scheme, ec, workerNum);
        int numParts = Math.min(parts.pFeatures.size(), parts.pLabels.size());
        if (numParts <= 0) {
            throw new DMLRuntimeException("Paramserv function: data partitioning produced no partitions.");
        }
        int numWorkers = Math.min(workerNum, numParts);
        if (numWorkers < workerNum && LOG.isWarnEnabled()) {
            LOG.warn(String.format("There is only %d batches of data but has %d workers. Hence, reset the number of workers with %d.",
                numParts, workerNum, numParts));
        }

        ExecutionContext newEC = ParamservUtils.createExecutionContext(ec, createVarsMap(ec), updFunc, aggFunc, getParLevel(numWorkers));
        List<ExecutionContext> workerECs = ParamservUtils.copyExecutionContext(newEC, numWorkers);
        ExecutionContext psEC = ParamservUtils.copyExecutionContext(newEC, 1).get(0);

        // Batches per epoch for one worker: ceil(ceil(rows / workers) / batchSize). Uses double to
        // avoid float precision loss for large row counts.
        long numRows = ec.getMatrixObject(getParam(Statement.PS_FEATURES)).getNumRows();
        int numBatchesPerEpoch = (int) Math.ceil(Math.ceil((double) numRows / numWorkers) / batchSize);
        ParamServer ps = createPS(mode, aggFunc, updateType, freq, numWorkers,
            ec.getListObject(getParam(Statement.PS_MODEL)), psEC, getValFunction(), numBatchesPerEpoch,
            getOptionalMatrix(ec, Statement.PS_VAL_FEATURES), getOptionalMatrix(ec, Statement.PS_VAL_LABELS),
            nbatches, modelAvg);

        List<LocalPSWorker> workers = IntStream.range(0, numWorkers)
            .mapToObj(i -> new LocalPSWorker(i, updFunc, freq, epochs, batchSize, workerECs.get(i), ps, nbatches, modelAvg))
            .collect(Collectors.toList());
        for (int i = 0; i < numWorkers; i++) {
            workers.get(i).setFeatures(parts.pFeatures.get(i));
            workers.get(i).setLabels(parts.pLabels.get(i));
        }
        if (tSetup != null) {
            Statistics.accPSSetupTime((long) tSetup.stop());
        }
        if (LOG.isDebugEnabled()) {
            LOG.debug(String.format("\nConfiguration of paramserv func: \nmode: %s \nworkerNum: %d \nupdate frequency: %s \nstrategy: %s \ndata partitioner: %s",
                mode, numWorkers, freq, updateType, scheme));
        }

        // Create the pool only now, so that setup failures above cannot leak it.
        ExecutorService pool = newWorkerPool(numWorkers);
        try {
            for (Future<?> future : pool.invokeAll(workers)) {
                future.get(); // rethrows any worker failure as ExecutionException
            }
            ec.setVariable(output.getName(), ps.getResult());
            if (DMLScript.STATISTICS) {
                Statistics.accPSExecutionTime((long) Statistics.getPSExecutionTimer().stop());
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException("ParamservBuiltinCPInstruction: some error occurred: ", e);
        } catch (ExecutionException e) {
            throw new DMLRuntimeException("ParamservBuiltinCPInstruction: some error occurred: ", e);
        } finally {
            pool.shutdownNow();
        }
    }

    /** Creates a fixed-size pool whose threads follow {@link #WORKER_THREAD_NAME_PATTERN}. */
    private static ExecutorService newWorkerPool(int numThreads) {
        return Executors.newFixedThreadPool(numThreads,
            new BasicThreadFactory.Builder().namingPattern(WORKER_THREAD_NAME_PATTERN).build());
    }

    /** Returns the matrix bound to the given optional parameter, or {@code null} if it is not set. */
    private MatrixObject getOptionalMatrix(ExecutionContext ec, String paramName) {
        String varName = getParam(paramName);
        return varName != null ? ec.getMatrixObject(varName) : null;
    }

    /** Builds the variable map for the worker contexts; it contains only the optional hyper-parameters. */
    private LocalVariableMap createVarsMap(ExecutionContext ec) {
        LocalVariableMap varsMap = new LocalVariableMap();
        ListObject hyperParams = getHyperParams(ec);
        if (hyperParams != null) {
            varsMap.put(Statement.PS_HYPER_PARAMS, hyperParams);
        }
        return varsMap;
    }

    /** Parses the {@code mode} argument, defaulting to {@link #DEFAULT_MODE}. */
    private Statement.PSModeType getPSMode() {
        if (!getParameterMap().containsKey("mode")) {
            return DEFAULT_MODE;
        }
        try {
            return Statement.PSModeType.valueOf(getParam("mode"));
        } catch (IllegalArgumentException e) {
            throw new DMLRuntimeException(String.format("Paramserv function: not support ps execution mode '%s'", getParam("mode")), e);
        }
    }

    /** Parses the required epochs argument, which must be strictly positive. */
    private int getEpochs() {
        int epochs = Integer.parseInt(getParam(Statement.PS_EPOCHS));
        if (epochs <= 0) {
            throw new DMLRuntimeException(String.format("Paramserv function: The argument '%s' could not be less than or equal to 0.", Statement.PS_EPOCHS));
        }
        return epochs;
    }

    /**
     * Parallelism level per worker: the local cores divided by the worker count, rounded up,
     * and at least 1.
     */
    private static int getParLevel(int workerNum) {
        // Cast to double: with int division the ceil would be a no-op (always floor).
        return Math.max((int) Math.ceil((double) getRemainingCores() / workerNum), 1);
    }

    /** Parses the update-type argument, defaulting to ASP. SSP is rejected. */
    private Statement.PSUpdateType getUpdateType() {
        if (!getParameterMap().containsKey(Statement.PS_UPDATE_TYPE)) {
            return DEFAULT_TYPE;
        }
        Statement.PSUpdateType type;
        try {
            type = Statement.PSUpdateType.valueOf(getParam(Statement.PS_UPDATE_TYPE));
        } catch (IllegalArgumentException e) {
            throw new DMLRuntimeException(String.format("Paramserv function: not support update type '%s'.", getParam(Statement.PS_UPDATE_TYPE)), e);
        }
        if (type == Statement.PSUpdateType.SSP) {
            throw new DMLRuntimeException("Paramserv function: Not support update type SSP.");
        }
        return type;
    }

    /** Parses the update-frequency argument, defaulting to EPOCH. */
    private Statement.PSFrequency getFrequency() {
        if (!getParameterMap().containsKey(Statement.PS_FREQUENCY)) {
            return DEFAULT_UPDATE_FREQUENCY;
        }
        try {
            return Statement.PSFrequency.valueOf(getParam(Statement.PS_FREQUENCY));
        } catch (IllegalArgumentException e) {
            throw new DMLRuntimeException(String.format("Paramserv function: not support '%s' update frequency.", getParam(Statement.PS_FREQUENCY)), e);
        }
    }

    /** Parses the federated runtime-balancing argument, defaulting to NONE. */
    private Statement.PSRuntimeBalancing getRuntimeBalancing() {
        if (!getParameterMap().containsKey(Statement.PS_FED_RUNTIME_BALANCING)) {
            return DEFAULT_RUNTIME_BALANCING;
        }
        try {
            return Statement.PSRuntimeBalancing.valueOf(getParam(Statement.PS_FED_RUNTIME_BALANCING));
        } catch (IllegalArgumentException e) {
            throw new DMLRuntimeException(String.format("Paramserv function: not support '%s' runtime balancing.", getParam(Statement.PS_FED_RUNTIME_BALANCING)), e);
        }
    }

    /** Returns the local parallelism reported by the infrastructure analyzer. */
    private static int getRemainingCores() {
        return InfrastructureAnalyzer.getLocalParallelism();
    }

    /**
     * Number of workers for local or Spark mode. Uses the parallelism argument if given;
     * otherwise the local core count or the Spark default parallelism.
     *
     * @throws DMLRuntimeException for other modes or a non-positive worker count
     */
    private int getWorkerNum(Statement.PSModeType mode) {
        boolean explicit = getParameterMap().containsKey(Statement.PS_PARALLELISM);
        int workerNum;
        switch (mode) {
            case LOCAL:
                workerNum = explicit ? Integer.parseInt(getParam(Statement.PS_PARALLELISM)) : getRemainingCores();
                break;
            case REMOTE_SPARK:
                workerNum = explicit ? Integer.parseInt(getParam(Statement.PS_PARALLELISM)) : SparkExecutionContext.getDefaultParallelism(true);
                break;
            default:
                throw new DMLRuntimeException("Unsupported parameter server: " + mode.name());
        }
        if (workerNum <= 0) {
            throw new DMLRuntimeException(String.format("Paramserv function: the number of workers must be positive, but was %d.", workerNum));
        }
        return workerNum;
    }

    /** Creates a parameter server without validation function or validation data. */
    private static ParamServer createPS(Statement.PSModeType mode, String aggFunc, Statement.PSUpdateType updateType, Statement.PSFrequency freq, int workerNum, ListObject model, ExecutionContext ec, int nbatches, boolean modelAvg) {
        return createPS(mode, aggFunc, updateType, freq, workerNum, model, ec, null, -1, null, null, nbatches, modelAvg);
    }

    /** Creates the parameter server; every supported mode uses a {@link LocalParamServer}. */
    private static ParamServer createPS(Statement.PSModeType mode, String aggFunc, Statement.PSUpdateType updateType, Statement.PSFrequency freq, int workerNum, ListObject model, ExecutionContext ec, String valFunc, int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject valLabels, int nbatches, boolean modelAvg) {
        switch (mode) {
            case LOCAL:
            case REMOTE_SPARK:
            case FEDERATED:
                return LocalParamServer.create(model, aggFunc, updateType, freq, ec, workerNum, valFunc, numBatchesPerEpoch, valFeatures, valLabels, nbatches, modelAvg);
            default:
                throw new DMLRuntimeException("Unsupported parameter server: " + mode.name());
        }
    }

    /** Parses the batch-size argument, defaulting to {@link #DEFAULT_BATCH_SIZE}; must be positive. */
    private long getBatchSize() {
        if (!getParameterMap().containsKey(Statement.PS_BATCH_SIZE)) {
            return DEFAULT_BATCH_SIZE;
        }
        long batchSize = Integer.parseInt(getParam(Statement.PS_BATCH_SIZE));
        if (batchSize <= 0) {
            throw new DMLRuntimeException(String.format("Paramserv function: the number of argument '%s' could not be less than or equal to 0.", Statement.PS_BATCH_SIZE));
        }
        return batchSize;
    }

    /** Returns the hyper-parameter list, or {@code null} if the argument is absent. */
    private ListObject getHyperParams(ExecutionContext ec) {
        return getParameterMap().containsKey(Statement.PS_HYPER_PARAMS)
            ? ec.getListObject(getParam(Statement.PS_HYPER_PARAMS))
            : null;
    }

    /**
     * Partitions the features and labels for {@code workerNum} local workers. The result may
     * contain fewer partitions than requested; callers must size the workers accordingly.
     */
    private DataPartitionLocalScheme.Result partitionLocally(Statement.PSScheme scheme, ExecutionContext ec, int workerNum) {
        MatrixObject features = ec.getMatrixObject(getParam(Statement.PS_FEATURES));
        MatrixObject labels = ec.getMatrixObject(getParam(Statement.PS_LABELS));
        return new LocalDataPartitioner(scheme).doPartitioning(workerNum, features.acquireReadAndRelease(), labels.acquireReadAndRelease());
    }

    /** Parses the local data-partitioning scheme, defaulting to DISJOINT_CONTIGUOUS. */
    private Statement.PSScheme getScheme() {
        if (!getParameterMap().containsKey(Statement.PS_SCHEME)) {
            return DEFAULT_SCHEME;
        }
        try {
            return Statement.PSScheme.valueOf(getParam(Statement.PS_SCHEME));
        } catch (IllegalArgumentException e) {
            throw new DMLRuntimeException(String.format("Paramserv function: not support data partition scheme '%s'", getParam(Statement.PS_SCHEME)), e);
        }
    }

    /** Parses the federated data-partitioning scheme, defaulting to KEEP_DATA_ON_WORKER. */
    private Statement.FederatedPSScheme getFederatedScheme() {
        if (!getParameterMap().containsKey(Statement.PS_SCHEME)) {
            return DEFAULT_FEDERATED_SCHEME;
        }
        try {
            return Statement.FederatedPSScheme.valueOf(getParam(Statement.PS_SCHEME));
        } catch (IllegalArgumentException e) {
            throw new DMLRuntimeException(String.format("Paramserv function in federated mode: not support data partition scheme '%s'", getParam(Statement.PS_SCHEME)), e);
        }
    }

    /**
     * Batches per epoch for federated workers. It is ceil(rows / batchSize), where rows is taken
     * from the balance metrics according to the balancing strategy:
     * min rows for CYCLE_MIN/BASELINE, max rows for CYCLE_MAX, average rows otherwise.
     */
    private int getNumBatchesPerEpoch(Statement.PSRuntimeBalancing balancing, DataPartitionFederatedScheme.BalanceMetrics metrics) {
        double rows;
        switch (balancing) {
            case CYCLE_MIN:
            case BASELINE:
                rows = metrics._minRows;
                break;
            case CYCLE_MAX:
                rows = metrics._maxRows;
                break;
            case CYCLE_AVG:
            case SCALE_BATCH:
            default:
                rows = metrics._avgRows;
                break;
        }
        return (int) Math.ceil(rows / getBatchSize());
    }

    /** Whether federated weighting is enabled (false if the argument is absent). */
    private boolean getWeighting() {
        return getParameterMap().containsKey(Statement.PS_FED_WEIGHTING) && Boolean.parseBoolean(getParam(Statement.PS_FED_WEIGHTING));
    }

    /** Returns the validation function name, or {@code null} if none is given. */
    private String getValFunction() {
        return getParameterMap().containsKey(Statement.PS_VAL_FUN) ? getParam(Statement.PS_VAL_FUN) : null;
    }

    /** Parses the {@code seed} argument, falling back to the current time truncated to int. */
    private int getSeed() {
        return getParameterMap().containsKey("seed") ? Integer.parseInt(getParam("seed")) : (int) System.currentTimeMillis();
    }

    /** Whether model averaging is enabled, defaulting to {@link #DEFAULT_MODELAVG}. */
    private boolean getModelAvg() {
        return getParameterMap().containsKey(Statement.PS_MODELAVG)
            ? Boolean.parseBoolean(getParam(Statement.PS_MODELAVG))
            : DEFAULT_MODELAVG;
    }

    /** Parses the nbatches argument, defaulting to {@link #DEFAULT_NBATCHES}. */
    private int getNbatches() {
        return getParameterMap().containsKey(Statement.PS_NBATCHES)
            ? Integer.parseInt(getParam(Statement.PS_NBATCHES))
            : DEFAULT_NBATCHES;
    }
}
