package org.apache.sysds.runtime.controlprogram.paramserv.dp;

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
import org.apache.sysds.runtime.controlprogram.paramserv.dp.DataPartitionSparkScheme;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.DataConverter;

/* loaded from: input_file:org/apache/sysds/runtime/controlprogram/paramserv/dp/SparkDataPartitioner.class */
/**
 * Spark-side data partitioner for the parameter server.
 *
 * <p>The constructor selects the {@link DataPartitionSparkScheme} that matches the requested
 * {@link Statement.PSScheme}. It then broadcasts the auxiliary data that scheme needs: a per-row
 * worker indicator vector, one or more global row permutations, or both.
 * {@link #doPartitioning} delegates the actual partitioning to the selected scheme.
 */
public class SparkDataPartitioner implements Serializable {
    private static final long serialVersionUID = 6841548626711057448L;

    /** Scheme selected in the constructor; never null once construction succeeds. */
    private final DataPartitionSparkScheme _scheme;

    /**
     * Creates a partitioner for {@code scheme} and broadcasts the data the scheme requires.
     *
     * <p>NOTE(review): the decompiler reports this constructor was originally {@code protected};
     * it is kept {@code public} so existing callers keep compiling.
     *
     * @param scheme     parameter-server partitioning scheme
     * @param sec        Spark execution context used to create the broadcasts
     * @param numEntries number of rows to partition (length of indicator/permutation vectors)
     * @param numWorkers number of workers the rows are distributed across
     * @throws IllegalArgumentException if {@code scheme} is not one of the supported schemes
     */
    public SparkDataPartitioner(Statement.PSScheme scheme, SparkExecutionContext sec, int numEntries, int numWorkers) {
        switch (scheme) {
            case DISJOINT_CONTIGUOUS:
                _scheme = new DCSparkScheme();
                createDCIndicator(sec, numWorkers, numEntries);
                break;
            case DISJOINT_ROUND_ROBIN:
                _scheme = new DRRSparkScheme();
                createDRIndicator(sec, numWorkers, numEntries);
                break;
            case DISJOINT_RANDOM:
                // One global shuffle, then contiguous blocks of the shuffled order per worker.
                _scheme = new DRSparkScheme();
                createGlobalPermutations(sec, numEntries, 1);
                createDCIndicator(sec, numWorkers, numEntries);
                break;
            case OVERLAP_RESHUFFLE:
                // One independent permutation per worker.
                _scheme = new ORSparkScheme();
                createGlobalPermutations(sec, numEntries, numWorkers);
                break;
            default:
                // Fail fast instead of leaving _scheme null and failing later in doPartitioning.
                throw new IllegalArgumentException(
                    "SparkDataPartitioner: unsupported partition scheme '" + scheme + "'.");
        }
    }

    /**
     * Broadcasts a round-robin worker indicator in which row {@code r} is assigned to worker
     * {@code r % numWorkers}.
     */
    private void createDRIndicator(SparkExecutionContext sec, int numWorkers, int numEntries) {
        double[] indicator = IntStream.range(0, numEntries)
            .mapToDouble(row -> row % numWorkers)
            .toArray();
        _scheme.setWorkerIndicator(sec.getBroadcastForMatrixObject(
            ParamservUtils.newMatrixObject(DataConverter.convertToMatrixBlock(indicator, true))));
    }

    /**
     * Broadcasts a contiguous worker indicator.
     *
     * <p>Rows are cut into consecutive batches of {@code ceil(numEntries / numWorkers)}.
     * Worker {@code w} owns {@code [w * batch, (w + 1) * batch)}, clamped to {@code numEntries}.
     * When there are more workers than batches, the surplus workers receive no rows.
     */
    private void createDCIndicator(SparkExecutionContext sec, int numWorkers, int numEntries) {
        double[] indicator = new double[numEntries];
        // Floating-point division is required. With int/int the quotient is truncated, Math.ceil
        // does nothing, and the trailing remainder rows stay assigned to worker 0.
        int batchSize = (int) Math.ceil((double) numEntries / numWorkers);
        // Worker 0 owns [0, batchSize) through the array's zero initialisation.
        for (int worker = 1; worker < numWorkers; worker++) {
            // Clamp begin too: with a ceiling batch size, worker * batchSize can exceed numEntries.
            // Without the clamp, Arrays.fill would throw for fromIndex > toIndex.
            int begin = (int) Math.min((long) batchSize * worker, numEntries);
            int end = Math.min(begin + batchSize, numEntries);
            Arrays.fill(indicator, begin, end, worker);
        }
        _scheme.setWorkerIndicator(sec.getBroadcastForMatrixObject(
            ParamservUtils.newMatrixObject(DataConverter.convertToMatrixBlock(indicator, true))));
    }

    /**
     * Broadcasts {@code numPerms} seeded row permutations of length {@code numEntries}.
     * Permutation {@code p} uses seed {@code ParamservUtils.SEED + p}.
     *
     * <p>Each broadcast vector is the inverse of the sampled order: entry {@code v - 1} holds the
     * position at which value {@code v} was sampled.
     */
    private void createGlobalPermutations(SparkExecutionContext sec, int numEntries, int numPerms) {
        _scheme.setGlobalPermutation(IntStream.range(0, numPerms).mapToObj(perm -> {
            // Assumed to be a sample without replacement of the 1-based values 1..numEntries,
            // i.e. a permutation (hence the "- 1" below). Confirm against
            // MatrixBlock.sampleOperations.
            double[] sample = MatrixBlock
                .sampleOperations(numEntries, numEntries, false, ParamservUtils.SEED + perm)
                .getDenseBlockValues();
            double[] inverse = new double[numEntries];
            for (int pos = 0; pos < sample.length; pos++) {
                inverse[(int) sample[pos] - 1] = pos;
            }
            return sec.getBroadcastForMatrixObject(
                ParamservUtils.newMatrixObject(DataConverter.convertToMatrixBlock(inverse, true)));
        }).collect(Collectors.toList()));
    }

    /**
     * Partitions one input block pair with the scheme chosen at construction time.
     *
     * @param numWorkers number of workers
     * @param features   first input block (presumably the feature rows)
     * @param labels     second input block (presumably the label rows)
     * @param rowId      row/block index passed through to the scheme; it must fit in an int
     * @return the scheme's partitioning result
     * @throws ArithmeticException if {@code rowId} does not fit in an int (previously it was
     *                             silently truncated)
     */
    public DataPartitionSparkScheme.Result doPartitioning(int numWorkers, MatrixBlock features, MatrixBlock labels, long rowId) {
        return _scheme.doPartitioning(numWorkers, Math.toIntExact(rowId), features, labels);
    }
}
