package org.apache.sysds.runtime.controlprogram.paramserv.dp;

import java.util.List;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
import org.apache.sysds.runtime.controlprogram.paramserv.dp.DataPartitionFederatedScheme;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;

/* loaded from: input_file:org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.class */
/**
 * Federated data partitioning scheme that keeps every worker's local data on that worker
 * and shuffles its rows there.
 *
 * <p>The federated features and labels are sliced per federated worker. Each worker is sent a
 * {@link shuffleDataOnFederatedWorker} UDF that permutes its local features and labels with the
 * same seeded permutation.
 */
public class ShuffleFederatedScheme extends DataPartitionFederatedScheme {

    /**
     * Federated UDF that permutes the rows of a worker-local features/labels pair.
     *
     * <p>The class name and {@code serialVersionUID} are intentionally left unchanged: instances
     * are serialized and shipped to federated workers.
     */
    private static class shuffleDataOnFederatedWorker extends FederatedUDF {
        private static final long serialVersionUID = 3228664618781333325L;
        private final int _seed;

        /**
         * @param inIDs worker-side variable IDs, in the order {@code {features, labels}}
         * @param seed  seed passed to {@link ParamservUtils#generatePermutation}
         */
        protected shuffleDataOnFederatedWorker(long[] inIDs, int seed) {
            super(inIDs);
            _seed = seed;
        }

        /**
         * Generates one permutation sized by the feature row count and applies it to both the
         * features ({@code data[0]}) and the labels ({@code data[1]}).
         */
        @Override
        public FederatedResponse execute(ExecutionContext ec, Data... data) {
            MatrixObject features = (MatrixObject) data[0];
            MatrixObject labels = (MatrixObject) data[1];

            // A single permutation is applied to both matrices so that label rows stay paired
            // with their feature rows (assumes both have the same row count).
            MatrixBlock permutation = ParamservUtils.generatePermutation(
                Math.toIntExact(features.getNumRows()), _seed);
            DataPartitionFederatedScheme.shuffle(features, permutation);
            DataPartitionFederatedScheme.shuffle(labels, permutation);
            return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
        }

        /** No lineage item is produced for this UDF. */
        @Override
        public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
            return null;
        }
    }

    /**
     * Slices the federated features and labels per worker and shuffles each worker's local rows.
     *
     * @param features federated feature matrix
     * @param labels   federated label matrix; must slice into as many partitions as {@code features}
     * @param seed     seed forwarded to every worker's permutation
     * @return the sliced features and labels, the number of partitions, and the balance metrics
     *         and weighting factors computed from the feature slices
     * @throws DMLRuntimeException if the partition counts differ, or if the shuffle UDF fails or
     *                             reports failure on any worker
     */
    @Override
    public DataPartitionFederatedScheme.Result partition(MatrixObject features, MatrixObject labels, int seed) {
        List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
        List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
        // Features and labels are paired by index below, so their slice counts must agree.
        if (pFeatures.size() != pLabels.size()) {
            throw new DMLRuntimeException("FederatedDataPartitioner ShuffleFederatedScheme: number of feature partitions ("
                + pFeatures.size() + ") does not match number of label partitions (" + pLabels.size() + ")");
        }

        DataPartitionFederatedScheme.BalanceMetrics balanceMetrics = getBalanceMetrics(pFeatures);
        List<Double> weightingFactors = getWeightingFactors(pFeatures, balanceMetrics);

        for (int i = 0; i < pFeatures.size(); i++) {
            FederatedData featuresData = pFeatures.get(i).getFedMapping().getFederatedData()[0];
            FederatedData labelsData = pLabels.get(i).getFedMapping().getFederatedData()[0];
            shuffleOnWorker(featuresData, labelsData, seed);
        }
        return new DataPartitionFederatedScheme.Result(pFeatures, pLabels, pFeatures.size(), balanceMetrics, weightingFactors);
    }

    /**
     * Sends the shuffle UDF to the worker holding {@code featuresData}, waits for the response,
     * and fails if the worker does not report success.
     */
    private static void shuffleOnWorker(FederatedData featuresData, FederatedData labelsData, int seed) {
        try {
            FederatedResponse response = featuresData.executeFederatedOperation(
                new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, featuresData.getVarID(),
                    new shuffleDataOnFederatedWorker(new long[] {featuresData.getVarID(), labelsData.getVarID()}, seed)))
                .get();
            if (!response.isSuccessful()) {
                throw new DMLRuntimeException("FederatedDataPartitioner ShuffleFederatedScheme: shuffle UDF returned fail. Federated worker error message: " + response.getErrorMessage());
            }
        }
        catch (DMLRuntimeException e) {
            // Already descriptive (e.g. the worker's error message); do not bury it in a second wrapper.
            throw e;
        }
        catch (InterruptedException e) {
            // Restore the interrupt flag so callers up the stack can still observe the interruption.
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException("FederatedDataPartitioner ShuffleFederatedScheme: interrupted while executing shuffle UDF", e);
        }
        catch (Exception e) {
            throw new DMLRuntimeException("FederatedDataPartitioner ShuffleFederatedScheme: executing shuffle UDF failed: " + e.getMessage(), e);
        }
    }
}
