package org.apache.sysds.runtime.controlprogram.paramserv.dp;

import java.util.List;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
import org.apache.sysds.runtime.controlprogram.paramserv.dp.DataPartitionFederatedScheme;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;

/* loaded from: input_file:org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.class */
/**
 * Federated data partitioning scheme that brings every federated partition to the row count
 * of the largest partition.
 *
 * <p>The features and labels matrices are sliced into one {@link MatrixObject} per federated
 * partition, the largest feature partition's row count is determined, and a UDF is sent to each
 * worker to replicate its local features and labels up to that row count.
 */
public class ReplicateToMaxFederatedScheme extends DataPartitionFederatedScheme {

    /**
     * UDF executed on a federated worker. If the first input matrix has fewer rows than the
     * configured maximum, a replication matrix is generated and applied to both input matrices.
     */
    private static class replicateDataOnFederatedWorker extends FederatedUDF {
        private static final long serialVersionUID = -6930898456315100587L;
        private final int _seed;
        private final int _max_rows;

        /**
         * @param inIDs   variable IDs of the UDF inputs (features first, labels second, as passed by
         *                {@code partition})
         * @param seed    seed forwarded to the replication-matrix generator
         * @param maxRows target row count for the worker-local matrices
         */
        protected replicateDataOnFederatedWorker(long[] inIDs, int seed, int maxRows) {
            super(inIDs);
            this._seed = seed;
            this._max_rows = maxRows;
        }

        /**
         * Replicates the two input matrices when the first one has fewer than {@code _max_rows} rows.
         *
         * @param executionContext execution context supplied by the federated worker (unused here)
         * @param dataArr          UDF inputs; element 0 and element 1 must be {@link MatrixObject}s
         *                         (presumably features and labels, in the order of the input IDs)
         * @return a response of type {@code SUCCESS}
         * @throws ArithmeticException if the row count of the first matrix does not fit into an int
         */
        @Override // org.apache.sysds.runtime.controlprogram.federated.FederatedUDF
        public FederatedResponse execute(ExecutionContext executionContext, Data... dataArr) {
            MatrixObject features = (MatrixObject) dataArr[0];
            MatrixObject labels = (MatrixObject) dataArr[1];
            long currentRows = features.getNumRows();
            if (currentRows < this._max_rows) {
                int rows = Math.toIntExact(currentRows);
                // The same replication matrix is applied to both inputs so that features and
                // labels are treated identically.
                MatrixBlock replicationMatrix =
                    ParamservUtils.generateReplicationMatrix(this._max_rows - rows, rows, this._seed);
                DataPartitionFederatedScheme.replicateTo(features, replicationMatrix);
                DataPartitionFederatedScheme.replicateTo(labels, replicationMatrix);
            }
            return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
        }

        /**
         * No lineage information is provided for this UDF.
         *
         * @return always {@code null}
         */
        @Override // org.apache.sysds.runtime.lineage.LineageTraceable
        public Pair<String, LineageItem> getLineageItem(ExecutionContext executionContext) {
            return null;
        }
    }

    /**
     * Slices the federated features and labels, then asks every worker to replicate its partition
     * up to the row count of the largest feature partition.
     *
     * @param matrixObject  federated features matrix
     * @param matrixObject2 federated labels matrix
     * @param i             seed forwarded to the replication UDF
     * @return the partitioned features and labels, the number of partitions, the balance metrics
     *         computed after replication, and the weighting factors computed before replication
     * @throws DMLRuntimeException if a worker reports failure or executing the UDF fails
     */
    @Override // org.apache.sysds.runtime.controlprogram.paramserv.dp.DataPartitionFederatedScheme
    public DataPartitionFederatedScheme.Result partition(MatrixObject matrixObject, MatrixObject matrixObject2, int i) {
        List<MatrixObject> featureSlices = sliceFederatedMatrix(matrixObject);
        List<MatrixObject> labelSlices = sliceFederatedMatrix(matrixObject2);
        // Weighting factors are derived from the partition sizes before replication changes them.
        List<Double> weightingFactors = getWeightingFactors(featureSlices, getBalanceMetrics(featureSlices));

        int maxRows = 0;
        for (MatrixObject slice : featureSlices) {
            long sliceRows = slice.getNumRows();
            if (sliceRows > maxRows) {
                maxRows = Math.toIntExact(sliceRows);
            }
        }

        for (int idx = 0; idx < featureSlices.size(); idx++) {
            MatrixObject featureSlice = featureSlices.get(idx);
            MatrixObject labelSlice = labelSlices.get(idx);
            // Only the first federated data entry of each slice is used.
            FederatedData featuresData = featureSlice.getFedMapping().getFederatedData()[0];
            FederatedData labelsData = labelSlice.getFedMapping().getFederatedData()[0];
            try {
                long[] inputIDs = new long[]{featuresData.getVarID(), labelsData.getVarID()};
                FederatedRequest request = new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
                    featuresData.getVarID(), new replicateDataOnFederatedWorker(inputIDs, i, maxRows));
                boolean succeeded = featuresData.executeFederatedOperation(request).get().isSuccessful();
                if (!succeeded) {
                    throw new DMLRuntimeException("FederatedDataPartitioner ReplicateFederatedScheme: replicate UDF returned fail");
                }
                // Reflect the replicated size in the coordinator-side metadata.
                featureSlice.updateDataCharacteristics(featureSlice.getDataCharacteristics().setRows(maxRows));
                labelSlice.updateDataCharacteristics(labelSlice.getDataCharacteristics().setRows(maxRows));
            } catch (DMLRuntimeException e) {
                // Already a descriptive runtime error; do not wrap it a second time.
                throw e;
            } catch (Exception e) {
                if (e instanceof InterruptedException) {
                    // Restore the interrupt flag cleared by the blocking get().
                    Thread.currentThread().interrupt();
                }
                throw new DMLRuntimeException("FederatedDataPartitioner ReplicateFederatedScheme: executing replicate UDF failed: " + e.getMessage(), e);
            }
        }
        return new DataPartitionFederatedScheme.Result(featureSlices, labelSlices, featureSlices.size(), getBalanceMetrics(featureSlices), weightingFactors);
    }
}
