package org.apache.sysds.runtime.controlprogram.paramserv.dp;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.common.Types;
import org.apache.sysds.lops.compile.Dag;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.meta.MetaDataFormat;

/**
 * Base class for federated data partitioning schemes in the parameter server
 * ({@code paramserv.dp}).
 *
 * <p>Subclasses implement {@link #partition}. This class provides shared helpers to:
 * <ul>
 *   <li>slice a row-federated matrix into one matrix object per federated range;</li>
 *   <li>summarize how evenly rows are spread over those slices;</li>
 *   <li>apply a left-multiplied transformation matrix (shuffle, replicate, subsample)
 *       to a matrix object's data.</li>
 * </ul>
 */
public abstract class DataPartitionFederatedScheme {

    /** Row-count statistics (min / average / max) over a list of partition slices. */
    public static final class BalanceMetrics {
        /** Smallest number of rows of any slice. */
        public final long _minRows;
        /**
         * Average number of rows per slice. When produced by {@code getBalanceMetrics}, this is
         * the total row count divided by the slice count using integer division.
         */
        public final long _avgRows;
        /** Largest number of rows of any slice. */
        public final long _maxRows;

        /**
         * @param j  minimum rows per slice
         * @param j2 average rows per slice
         * @param j3 maximum rows per slice
         */
        public BalanceMetrics(long j, long j2, long j3) {
            this._minRows = j;
            this._avgRows = j2;
            this._maxRows = j3;
        }
    }

    /** Immutable holder for the outcome of {@link #partition}. */
    public static final class Result {
        /** Partitioned feature matrices. */
        public final List<MatrixObject> _pFeatures;
        /** Partitioned label matrices. */
        public final List<MatrixObject> _pLabels;
        /** Number of workers the data was partitioned for. */
        public final int _workerNum;
        /** Row-count balance statistics of the partitions. */
        public final BalanceMetrics _balanceMetrics;
        /** Per-partition weighting factors. */
        public final List<Double> _weightingFactors;

        /**
         * @param list           partitioned features
         * @param list2          partitioned labels
         * @param i              number of workers
         * @param balanceMetrics row-count balance statistics
         * @param list3          per-partition weighting factors
         */
        public Result(List<MatrixObject> list, List<MatrixObject> list2, int i, BalanceMetrics balanceMetrics, List<Double> list3) {
            this._pFeatures = list;
            this._pLabels = list2;
            this._workerNum = i;
            this._balanceMetrics = balanceMetrics;
            this._weightingFactors = list3;
        }
    }

    /**
     * Partitions the given federated data.
     *
     * @param matrixObject  presumably the feature matrix (TODO confirm against implementations)
     * @param matrixObject2 presumably the label matrix (TODO confirm against implementations)
     * @param i             presumably the number of workers (TODO confirm against implementations)
     * @return the partitioning result
     */
    public abstract Result partition(MatrixObject matrixObject, MatrixObject matrixObject2, int i);

    /**
     * Splits a row-federated matrix into one matrix object per federated range.
     *
     * <p>Each slice receives BINARY-format metadata sized to its range. It also receives a
     * single-entry, row-typed federation map that reuses the input's federation ID and points
     * only at that range's federated data.
     *
     * @param matrixObject the row-federated matrix to slice
     * @return a synchronized list containing one slice per federated range
     * @throws DMLRuntimeException if {@code matrixObject} is not row-federated
     */
    public static List<MatrixObject> sliceFederatedMatrix(MatrixObject matrixObject) {
        if (!matrixObject.isFederated(FederationMap.FType.ROW)) {
            throw new DMLRuntimeException("Federated data partitioner: currently only supports row federated data");
        }
        // The callback below is run via forEachParallel, so appends are guarded by a synchronized list.
        // NOTE(review): if callbacks really run concurrently, slices are appended in completion
        // order, which may differ from the federation map's range order. Confirm that callers
        // pairing feature and label slices by index do not rely on a stable order.
        List<MatrixObject> slices = Collections.synchronizedList(new ArrayList<>());
        matrixObject.getFedMapping().forEachParallel((federatedRange, federatedData) -> {
            MatrixObject slice = new MatrixObject(matrixObject.getValueType(), Dag.getNextUniqueVarname(Types.DataType.MATRIX));
            slice.setMetaData(new MetaDataFormat(
                new MatrixCharacteristics(federatedRange.getSize(0), federatedRange.getSize(1)),
                Types.FileFormat.BINARY));
            // Single-entry (mutable) mapping: this slice references only its own federated range.
            slice.setFedMapping(new FederationMap(matrixObject.getFedMapping().getID(),
                new ArrayList<>(Collections.singletonList(Pair.of(federatedRange, federatedData)))));
            slice.getFedMapping().setType(FederationMap.FType.ROW);
            slices.add(slice);
            return null;
        });
        return slices;
    }

    /**
     * Computes the minimum, average, and maximum row count over the given slices.
     *
     * @param list the slices; may be {@code null} or empty
     * @return the metrics. The average uses integer division. A {@code null} or empty list
     *         yields all-zero metrics.
     */
    public static BalanceMetrics getBalanceMetrics(List<MatrixObject> list) {
        if (list == null || list.isEmpty()) {
            return new BalanceMetrics(0L, 0L, 0L);
        }
        long minRows = Long.MAX_VALUE;
        long maxRows = Long.MIN_VALUE;
        long totalRows = 0;
        for (MatrixObject slice : list) {
            long rows = slice.getNumRows();
            minRows = Math.min(minRows, rows);
            maxRows = Math.max(maxRows, rows);
            totalRows += rows;
        }
        return new BalanceMetrics(minRows, totalRows / list.size(), maxRows);
    }

    /**
     * Computes one weighting factor per slice: the slice's row count divided by the average
     * row count in {@code balanceMetrics}.
     *
     * <p>The division is done in floating point. Integer ({@code long / long}) division would
     * truncate every factor to a whole number. If {@code balanceMetrics._avgRows} is 0, the
     * factors follow IEEE-754 rules (NaN or Infinity) rather than throwing.
     *
     * @param list           the slices, in the order the factors should be returned
     * @param balanceMetrics metrics supplying the average row count
     * @return a mutable list with one factor per slice
     */
    public static List<Double> getWeightingFactors(List<MatrixObject> list, BalanceMetrics balanceMetrics) {
        List<Double> weights = new ArrayList<>(list.size());
        for (MatrixObject slice : list) {
            weights.add((double) slice.getNumRows() / balanceMetrics._avgRows);
        }
        return weights;
    }

    /**
     * Replaces the data of {@code matrixObject} with {@code matrixBlock %*% data}, for example
     * to reorder rows with a permutation matrix.
     *
     * @param matrixObject the matrix object to update in place
     * @param matrixBlock  the left-hand matrix of the product
     */
    public static void shuffle(MatrixObject matrixObject, MatrixBlock matrixBlock) {
        MatrixBlock current = matrixObject.acquireReadAndRelease();
        matrixObject.acquireModify(multiplyLeft(matrixBlock, current));
        matrixObject.release();
    }

    /**
     * Appends {@code matrixBlock %*% data} to the existing data of {@code matrixObject}, so
     * the selected rows are added again. The result is written back in place.
     *
     * @param matrixObject the matrix object to update in place
     * @param matrixBlock  the left-hand matrix selecting the rows to append
     */
    public static void replicateTo(MatrixObject matrixObject, MatrixBlock matrixBlock) {
        // Read the current data once and reuse it for both the product and the append.
        MatrixBlock current = matrixObject.acquireReadAndRelease();
        MatrixBlock replicatedRows = multiplyLeft(matrixBlock, current);
        // cbind=false: presumably a row-wise (rbind) append. Confirm against MatrixBlock.append.
        matrixObject.acquireModify(current.append(replicatedRows, new MatrixBlock(), false));
        matrixObject.release();
    }

    /**
     * Replaces the data of {@code matrixObject} with {@code matrixBlock %*% data}, for example
     * to keep only the rows chosen by a selection matrix. The computation is the same as
     * {@link #shuffle}.
     *
     * @param matrixObject the matrix object to update in place
     * @param matrixBlock  the left-hand matrix of the product
     */
    public static void subsampleTo(MatrixObject matrixObject, MatrixBlock matrixBlock) {
        MatrixBlock current = matrixObject.acquireReadAndRelease();
        matrixObject.acquireModify(multiplyLeft(matrixBlock, current));
        matrixObject.release();
    }

    /** Returns {@code left %*% right}, using the local parallelism for the matrix multiply. */
    private static MatrixBlock multiplyLeft(MatrixBlock left, MatrixBlock right) {
        return left.aggregateBinaryOperations(left, right, new MatrixBlock(),
            InstructionUtils.getMatMultOperator(InfrastructureAnalyzer.getLocalParallelism()));
    }
}
