package org.apache.sysds.runtime.instructions.spark.utils;

import java.lang.invoke.SerializedLambda;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
import org.apache.spark.HashPartitioner;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.storage.StorageLevel;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.lops.Checkpoint;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysds.runtime.data.BasicTensorBlock;
import org.apache.sysds.runtime.data.IndexedTensorBlock;
import org.apache.sysds.runtime.data.TensorBlock;
import org.apache.sysds.runtime.data.TensorIndexes;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.instructions.spark.functions.CopyBinaryCellFunction;
import org.apache.sysds.runtime.instructions.spark.functions.CopyMatrixBlockFunction;
import org.apache.sysds.runtime.instructions.spark.functions.CopyMatrixBlockPairFunction;
import org.apache.sysds.runtime.instructions.spark.functions.CopyTensorBlockFunction;
import org.apache.sysds.runtime.instructions.spark.functions.CopyTensorBlockPairFunction;
import org.apache.sysds.runtime.instructions.spark.functions.FilterNonEmptyBlocksFunction;
import org.apache.sysds.runtime.instructions.spark.functions.RecomputeNnzFunction;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixCell;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.runtime.matrix.data.Pair;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.util.UtilFunctions;
import scala.Tuple2;

/**
 * Static helper functions used by the Spark instructions.
 *
 * <p>Covers conversions between Spark {@link Tuple2}s and SystemDS {@link Pair}/indexed-value
 * types, preferred-partition heuristics, shallow/deep RDD copies, generation of RDDs consisting
 * of empty blocks only, and distributed computation of non-zeros and data characteristics.
 */
public class SparkUtils {
    /** Storage level for temporarily persisted RDDs; aliases {@link Checkpoint#DEFAULT_STORAGE_LEVEL}. */
    public static final StorageLevel DEFAULT_TMP = Checkpoint.DEFAULT_STORAGE_LEVEL;

    // ------------------------------------------------------------------------------------
    // Conversions between Spark tuples and SystemDS pair / indexed-value types
    // ------------------------------------------------------------------------------------

    /** Wraps a (indexes, block) tuple into an {@link IndexedMatrixValue}. */
    public static IndexedMatrixValue toIndexedMatrixBlock(Tuple2<MatrixIndexes, MatrixBlock> in) {
        return new IndexedMatrixValue(in._1(), in._2());
    }

    /** Wraps the given indexes and block into an {@link IndexedMatrixValue}. */
    public static IndexedMatrixValue toIndexedMatrixBlock(MatrixIndexes ix, MatrixBlock mb) {
        return new IndexedMatrixValue(ix, mb);
    }

    /** Wraps a (indexes, block) tuple into an {@link IndexedTensorBlock}. */
    public static IndexedTensorBlock toIndexedTensorBlock(Tuple2<TensorIndexes, TensorBlock> in) {
        return new IndexedTensorBlock(in._1(), in._2());
    }

    /** Wraps the given indexes and block into an {@link IndexedTensorBlock}. */
    public static IndexedTensorBlock toIndexedTensorBlock(TensorIndexes ix, TensorBlock tb) {
        return new IndexedTensorBlock(ix, tb);
    }

    /**
     * Converts an {@link IndexedMatrixValue} into a Spark tuple.
     * The value is cast to {@link MatrixBlock}; other {@code MatrixValue} types fail with a
     * {@link ClassCastException}.
     */
    public static Tuple2<MatrixIndexes, MatrixBlock> fromIndexedMatrixBlock(IndexedMatrixValue in) {
        return new Tuple2<>(in.getIndexes(), (MatrixBlock) in.getValue());
    }

    /** Element-wise {@link #fromIndexedMatrixBlock(IndexedMatrixValue)} over a list. */
    public static List<Tuple2<MatrixIndexes, MatrixBlock>> fromIndexedMatrixBlock(List<IndexedMatrixValue> in) {
        return in.stream()
            .map(imv -> fromIndexedMatrixBlock(imv))
            .collect(Collectors.toList());
    }

    /**
     * Converts an {@link IndexedMatrixValue} into a SystemDS {@link Pair}.
     * The value is cast to {@link MatrixBlock}.
     */
    public static Pair<MatrixIndexes, MatrixBlock> fromIndexedMatrixBlockToPair(IndexedMatrixValue in) {
        return new Pair<>(in.getIndexes(), (MatrixBlock) in.getValue());
    }

    /** Element-wise {@link #fromIndexedMatrixBlockToPair(IndexedMatrixValue)} over a list. */
    public static List<Pair<MatrixIndexes, MatrixBlock>> fromIndexedMatrixBlockToPair(List<IndexedMatrixValue> in) {
        return in.stream()
            .map(imv -> fromIndexedMatrixBlockToPair(imv))
            .collect(Collectors.toList());
    }

    /** Converts a (row offset key, frame block) {@link Pair} into a Spark tuple. */
    public static Tuple2<Long, FrameBlock> fromIndexedFrameBlock(Pair<Long, FrameBlock> in) {
        return new Tuple2<>(in.getKey(), in.getValue());
    }

    /** Element-wise {@link #fromIndexedFrameBlock(Pair)} over a list. */
    public static List<Tuple2<Long, FrameBlock>> fromIndexedFrameBlock(List<Pair<Long, FrameBlock>> in) {
        return in.stream()
            .map(p -> fromIndexedFrameBlock(p))
            .collect(Collectors.toList());
    }

    /** Converts a list of (long, long) Spark tuples into SystemDS {@link Pair}s. */
    public static List<Pair<Long, Long>> toIndexedLong(List<Tuple2<Long, Long>> in) {
        return in.stream()
            .map(t -> new Pair<Long, Long>(t._1(), t._2()))
            .collect(Collectors.toList());
    }

    /** Converts a (key, frame block) Spark tuple into a SystemDS {@link Pair}. */
    public static Pair<Long, FrameBlock> toIndexedFrameBlock(Tuple2<Long, FrameBlock> in) {
        return new Pair<>(in._1(), in._2());
    }

    // ------------------------------------------------------------------------------------
    // Partitioning
    // ------------------------------------------------------------------------------------

    /** Returns true if the RDD has a partitioner and that partitioner is a {@link HashPartitioner}. */
    public static boolean isHashPartitioned(JavaPairRDD<?, ?> in) {
        return !in.rdd().partitioner().isEmpty()
            && in.rdd().partitioner().get() instanceof HashPartitioner;
    }

    /**
     * Returns the preferred number of partitions.
     *
     * <p>The number is derived from the data characteristics if they are known (including nnz,
     * see {@code dimsKnown(true)}) or if no RDD is given. Otherwise the RDD's current number of
     * partitions is reused.
     *
     * @param dc data characteristics of the data
     * @param in existing RDD, may be {@code null}
     * @return preferred number of partitions
     */
    public static int getNumPreferredPartitions(DataCharacteristics dc, JavaPairRDD<?, ?> in) {
        if (dc.dimsKnown(true) || in == null)
            return getNumPreferredPartitions(dc);
        return in.getNumPartitions();
    }

    /**
     * Returns the preferred number of partitions.
     * Empty blocks are accounted for unless the characteristics state there are none.
     */
    public static int getNumPreferredPartitions(DataCharacteristics dc) {
        return getNumPreferredPartitions(dc, !dc.isNoEmptyBlocks());
    }

    /**
     * Returns the preferred number of partitions for the given characteristics.
     *
     * <p>With unknown dimensions this is the default Spark parallelism. Otherwise it is the
     * estimated partitioned size divided by the HDFS block size, rounded up, and at least 1.
     *
     * @param dc data characteristics
     * @param outputEmptyBlocks whether empty blocks are materialized (affects the size estimate)
     * @return preferred number of partitions, at least 1 when dims are known
     */
    public static int getNumPreferredPartitions(DataCharacteristics dc, boolean outputEmptyBlocks) {
        if (!dc.dimsKnown())
            return SparkExecutionContext.getDefaultParallelism(true);
        // divide in floating point so that Math.ceil actually rounds up
        double partitionedSize = OptimizerUtils.estimatePartitionedSizeExactSparsity(dc, outputEmptyBlocks);
        double hdfsBlockSize = InfrastructureAnalyzer.getHDFSBlockSize();
        return (int) Math.max(Math.ceil(partitionedSize / hdfsBlockSize), 1.0d);
    }

    // ------------------------------------------------------------------------------------
    // Copies
    // ------------------------------------------------------------------------------------

    /** Deep copy of a binary-block matrix RDD; see {@link #copyBinaryBlockMatrix(JavaPairRDD, boolean)}. */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> copyBinaryBlockMatrix(JavaPairRDD<MatrixIndexes, MatrixBlock> in) {
        return copyBinaryBlockMatrix(in, true);
    }

    /**
     * Copies a binary-block matrix RDD.
     *
     * <p>A shallow copy ({@code deep == false}) maps the values with
     * {@code CopyMatrixBlockFunction(false)}. A deep copy uses
     * {@code CopyMatrixBlockPairFunction} via {@code mapPartitionsToPair}, because the keys are
     * accessed too; partitioning is preserved in both cases.
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> copyBinaryBlockMatrix(JavaPairRDD<MatrixIndexes, MatrixBlock> in, boolean deep) {
        if (!deep)
            return in.mapValues(new CopyMatrixBlockFunction(false));
        return in.mapPartitionsToPair(new CopyMatrixBlockPairFunction(deep), true);
    }

    /** Deep copy of a binary-block tensor RDD; see {@link #copyBinaryBlockTensor(JavaPairRDD, boolean)}. */
    public static JavaPairRDD<TensorIndexes, BasicTensorBlock> copyBinaryBlockTensor(JavaPairRDD<TensorIndexes, BasicTensorBlock> in) {
        return copyBinaryBlockTensor(in, true);
    }

    /**
     * Copies a binary-block tensor RDD.
     * This is the tensor analogue of {@link #copyBinaryBlockMatrix(JavaPairRDD, boolean)}.
     */
    public static JavaPairRDD<TensorIndexes, BasicTensorBlock> copyBinaryBlockTensor(JavaPairRDD<TensorIndexes, BasicTensorBlock> in, boolean deep) {
        if (!deep)
            return in.mapValues(new CopyTensorBlockFunction(false));
        return in.mapPartitionsToPair(new CopyTensorBlockPairFunction(deep), true);
    }

    // ------------------------------------------------------------------------------------
    // Debugging / validation
    // ------------------------------------------------------------------------------------

    /**
     * Runs {@code checkNonZeros()} and {@code checkSparseRows()} on every block of the
     * binary-block RDD bound to the given variable.
     *
     * @param varname name of the matrix variable
     * @param ec execution context, must be a {@link SparkExecutionContext}
     */
    public static void checkSparsity(String varname, ExecutionContext ec) {
        SparkExecutionContext sec = (SparkExecutionContext) ec;
        sec.getBinaryMatrixBlockRDDHandleForVariable(varname).foreach(new CheckSparsityFunction());
    }

    /**
     * Extracts the start of a Spark RDD debug-string line.
     *
     * <p>The first four characters are dropped (NOTE(review): presumably a {@code "(n) "}
     * marker, confirm against callers). Everything before the first {@code ':'} is returned.
     * Lines shorter than four characters throw {@link StringIndexOutOfBoundsException}.
     */
    public static String getStartLineFromSparkDebugInfo(String line) {
        return line.substring(4).split(":")[0];
    }

    /**
     * Computes the indentation prefix of a Spark RDD debug-string line.
     *
     * <p>The line is split at {@code '|'} and {@code "+-"}. All segments except the last are
     * re-joined with {@code '|'}. Then {@code "+- "} is appended if the line contained
     * {@code "+-"}, and {@code "|  "} otherwise.
     */
    public static String getPrefixFromSparkDebugInfo(String line) {
        String[] parts = line.split("\\||\\+-");
        // String.split drops trailing empty strings; a line made only of separators yields []
        StringBuilder prefix = new StringBuilder(parts.length > 0 ? parts[0] : "");
        for (int i = 1; i < parts.length - 1; i++)
            prefix.append('|').append(parts[i]);
        prefix.append(line.contains("+-") ? "+- " : "|  ");
        return prefix.toString();
    }

    // ------------------------------------------------------------------------------------
    // RDD construction and distributed statistics
    // ------------------------------------------------------------------------------------

    /**
     * Creates an RDD containing every block of a matrix with the given characteristics, each an
     * empty {@link MatrixBlock} of the proper (boundary-aware) block size.
     *
     * <p>Only the per-partition start offsets are parallelized; the blocks themselves are
     * generated lazily inside the tasks by {@link GenerateEmptyBlocks}.
     *
     * @param sc Spark context
     * @param dc characteristics of the output (dimensions and block size must be known)
     * @return RDD of all (initially empty) blocks
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> getEmptyBlockRDD(JavaSparkContext sc, DataCharacteristics dc) {
        final long numBlocks = dc.getNumBlocks();

        // estimated total size of all empty blocks, used to derive the degree of parallelism
        double size = (double) numBlocks * OptimizerUtils.estimateSizeEmptyBlock(
            Math.min(Math.max(dc.getRows(), 1L), dc.getBlocksize()),
            Math.min(Math.max(dc.getCols(), 1L), dc.getBlocksize()));
        double par = Math.min(4.0d * Math.max(SparkExecutionContext.getDefaultParallelism(true),
            Math.ceil(size / InfrastructureAnalyzer.getHDFSBlockSize())), numBlocks);
        // at least one slice: Spark rejects 0 partitions, and it avoids a division by zero below
        final int numPartitions = Math.max((int) par, 1);

        // blocks per partition; must round UP (floating point), otherwise trailing blocks are lost
        final long pNumBlocks = (long) Math.ceil((double) numBlocks / numPartitions);

        // start offset (0-based linear block id) of each partition
        List<Long> offsets = LongStream.iterate(0L, i -> i + pNumBlocks)
            .limit(numPartitions).boxed().collect(Collectors.toList());

        return sc.parallelize(offsets, numPartitions)
            .flatMapToPair(new GenerateEmptyBlocks(dc, pNumBlocks));
    }

    /**
     * Persists a binary-cell RDD at {@link #DEFAULT_TMP}.
     * Each cell is first copied via {@code CopyBinaryCellFunction}. An RDD already at that
     * storage level is returned unchanged.
     */
    public static JavaPairRDD<MatrixIndexes, MatrixCell> cacheBinaryCellRDD(JavaPairRDD<MatrixIndexes, MatrixCell> in) {
        if (in.getStorageLevel().equals(DEFAULT_TMP))
            return in;
        return in.mapToPair(new CopyBinaryCellFunction()).persist(DEFAULT_TMP);
    }

    /**
     * Computes data characteristics of a binary-cell RDD.
     * Rows and columns are the maximum 1-based cell indexes observed, and nnz counts cells
     * with a non-zero value. Like {@code reduce}, this fails on an empty RDD.
     */
    public static DataCharacteristics computeDataCharacteristics(JavaPairRDD<MatrixIndexes, MatrixCell> in) {
        return in.map(new AnalyzeCellDataCharacteristics())
            .reduce(new AggregateDataCharacteristics());
    }

    /** Recomputes the number of non-zeros of the RDD backing the given matrix object. */
    public static long getNonZeros(MatrixObject mo) {
        @SuppressWarnings("unchecked") // a MatrixObject's RDD handle holds binary matrix blocks
        JavaPairRDD<MatrixIndexes, MatrixBlock> in =
            (JavaPairRDD<MatrixIndexes, MatrixBlock>) mo.getRDDHandle().getRDD();
        return getNonZeros(in);
    }

    /**
     * Recomputes the number of non-zeros of a binary-block RDD.
     * Empty blocks are filtered first, then the per-partition counts from
     * {@code RecomputeNnzFunction} are summed. An RDD without partitions yields 0.
     */
    public static long getNonZeros(JavaPairRDD<MatrixIndexes, MatrixBlock> in) {
        // fold (not reduce): identical result, but does not throw on an empty RDD
        return in.filter(new FilterNonEmptyBlocksFunction())
            .values()
            .mapPartitions(new RecomputeNnzFunction())
            .fold(0L, (a, b) -> a + b);
    }

    /**
     * Reads the output matrix once via {@code acquireReadAndRelease()} when both conditions hold:
     * the size estimate from the nnz upper bound does not exceed the caching threshold, and it
     * is smaller than the size estimate from the given characteristics.
     * NOTE(review): presumably meant to pull ultra-sparse outputs into the driver; confirm.
     */
    public static void postprocessUltraSparseOutput(MatrixObject mo, DataCharacteristics mcOut) {
        long memUB = OptimizerUtils.estimateSizeExactSparsity(
            mcOut.getRows(), mcOut.getCols(), mcOut.getNonZerosBound());
        boolean withinBudget = !OptimizerUtils.exceedsCachingThreshold(mcOut.getCols(), memUB);
        if (withinBudget && memUB < OptimizerUtils.estimateSizeExactSparsity(mcOut))
            mo.acquireReadAndRelease();
    }

    // ------------------------------------------------------------------------------------
    // Spark functions (serializable, executed on the workers)
    // ------------------------------------------------------------------------------------

    /**
     * Merges two partial characteristics.
     * Rows and columns take the maximum, non-zeros are summed, and the block size comes from
     * the first argument.
     */
    private static class AggregateDataCharacteristics implements Function2<DataCharacteristics, DataCharacteristics, DataCharacteristics> {
        private static final long serialVersionUID = 4263886749699779994L;

        private AggregateDataCharacteristics() {
        }

        @Override
        public DataCharacteristics call(DataCharacteristics dc1, DataCharacteristics dc2) throws Exception {
            return new MatrixCharacteristics(
                Math.max(dc1.getRows(), dc2.getRows()),
                Math.max(dc1.getCols(), dc2.getCols()),
                dc1.getBlocksize(),
                dc1.getNonZeros() + dc2.getNonZeros());
        }
    }

    /**
     * Maps a single cell to partial characteristics.
     * Its indexes become the rows and columns, with block size 0, and nnz is 1 if the value is
     * non-zero and 0 otherwise.
     */
    private static class AnalyzeCellDataCharacteristics implements Function<Tuple2<MatrixIndexes, MatrixCell>, DataCharacteristics> {
        private static final long serialVersionUID = 8899395272683723008L;

        private AnalyzeCellDataCharacteristics() {
        }

        @Override
        public DataCharacteristics call(Tuple2<MatrixIndexes, MatrixCell> cell) throws Exception {
            MatrixIndexes ix = cell._1();
            // nnz contribution: count against 0, not against any delimiter fill value
            long nnz = (cell._2().getValue() != 0) ? 1L : 0L;
            return new MatrixCharacteristics(ix.getRowIndex(), ix.getColumnIndex(), 0, nnz);
        }
    }

    /** Invokes {@code checkNonZeros()} and {@code checkSparseRows()} on each block. */
    public static class CheckSparsityFunction implements VoidFunction<Tuple2<MatrixIndexes, MatrixBlock>> {
        private static final long serialVersionUID = 4150132775681848807L;

        private CheckSparsityFunction() {
        }

        @Override
        public void call(Tuple2<MatrixIndexes, MatrixBlock> in) throws Exception {
            MatrixBlock mb = in._2();
            mb.checkNonZeros();
            mb.checkSparseRows();
        }
    }

    /**
     * Generates the empty blocks of one partition.
     * Given a start offset, it emits the blocks whose 0-based, row-major linear block id lies in
     * {@code [offset, min(offset + pNumBlocks, numBlocks))}.
     */
    private static class GenerateEmptyBlocks implements PairFlatMapFunction<Long, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 630129586089106855L;
        private final DataCharacteristics _mc;
        private final long _pNumBlocks;

        public GenerateEmptyBlocks(DataCharacteristics mc, long pNumBlocks) {
            _mc = mc;
            _pNumBlocks = pNumBlocks;
        }

        @Override
        public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Long offset) throws Exception {
            final long numColBlocks = _mc.getNumColBlocks();
            final long end = Math.min(offset + _pNumBlocks, _mc.getNumBlocks());
            return LongStream.range(offset, end).mapToObj(bix -> {
                // 0-based linear block id -> 1-based (row, col) block indexes
                long rix = 1 + bix / numColBlocks;
                long cix = 1 + bix % numColBlocks;
                return new Tuple2<MatrixIndexes, MatrixBlock>(new MatrixIndexes(rix, cix), new MatrixBlock(
                    UtilFunctions.computeBlockSize(_mc.getRows(), rix, _mc.getBlocksize()),
                    UtilFunctions.computeBlockSize(_mc.getCols(), cix, _mc.getBlocksize()), true));
            }).iterator();
        }
    }
}
