package org.apache.sysds.runtime.instructions.spark.functions;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.lib.CLALibDecompress;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.util.UtilFunctions;
import scala.Tuple2;

/* loaded from: input_file:org/apache/sysds/runtime/instructions/spark/functions/ExtractBlockForBinaryReblock.class */
/**
 * Spark flat-map function used by binary reblock: it takes one matrix block that is
 * partitioned with the input block size and emits the pieces of it that fall into each
 * overlapping block of the output block size, keyed by the output block index.
 *
 * <p>Each emitted output block has the full dimensions of its target block but only holds
 * the cells contributed by this single input block; combining the partial blocks that share
 * a key is left to the caller.
 */
public class ExtractBlockForBinaryReblock implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
    private static final long serialVersionUID = -762987655085029215L;

    /** Number of rows of the overall matrix. */
    private final long rlen;
    /** Number of columns of the overall matrix. */
    private final long clen;
    /** Block size of the incoming blocks. */
    private final int in_blen;
    /** Block size of the emitted blocks. */
    private final int out_blen;

    /**
     * Creates the function from the input and output data characteristics.
     *
     * @param dataCharacteristics  characteristics of the input; supplies the matrix
     *                             dimensions and the input block size
     * @param dataCharacteristics2 characteristics of the output; only its block size is read
     * @throws DMLRuntimeException if either block size is not strictly positive
     */
    public ExtractBlockForBinaryReblock(DataCharacteristics dataCharacteristics, DataCharacteristics dataCharacteristics2) {
        this.rlen = dataCharacteristics.getRows();
        this.clen = dataCharacteristics.getCols();
        this.in_blen = dataCharacteristics.getBlocksize();
        this.out_blen = dataCharacteristics2.getBlocksize();
        if (this.in_blen <= 0 || this.out_blen <= 0) {
            throw new DMLRuntimeException("Block sizes unknown:" + this.in_blen + ", " + this.out_blen);
        }
    }

    /**
     * Splits one input block into its contributions to the overlapping output blocks.
     *
     * @param tuple2 pair of input block index and input block
     * @return iterator over (output block index, partial output block) pairs; the input pair
     *         itself when both block sizes are equal; empty when the input block is empty
     * @throws Exception declared by the Spark function interface
     */
    public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Tuple2<MatrixIndexes, MatrixBlock> tuple2) throws Exception {
        // Same blocking on both sides: nothing to re-partition.
        if (this.in_blen == this.out_blen) {
            return Collections.singletonList(tuple2).iterator();
        }

        final MatrixIndexes ixIn = tuple2._1();
        MatrixBlock in = tuple2._2();

        // An empty input block contributes to no output block, so skip all index math and
        // output allocations up front.
        if (in.isEmptyBlock(false)) {
            return Collections.emptyIterator();
        }

        // Global cell range covered by the input block (first and last cell, inclusive).
        final long inRowStart = UtilFunctions.computeCellIndex(ixIn.getRowIndex(), this.in_blen, 0);
        final long inRowEnd = getEndGlobalIndex(ixIn.getRowIndex(), true, true);
        final long inColStart = UtilFunctions.computeCellIndex(ixIn.getColumnIndex(), this.in_blen, 0);
        final long inColEnd = getEndGlobalIndex(ixIn.getColumnIndex(), true, false);

        // Range of output block indexes that the input block overlaps.
        final long outRowBlockFirst = UtilFunctions.computeBlockIndex(inRowStart, this.out_blen);
        final long outRowBlockLast = UtilFunctions.computeBlockIndex(inRowEnd, this.out_blen);
        final long outColBlockFirst = UtilFunctions.computeBlockIndex(inColStart, this.out_blen);
        final long outColBlockLast = UtilFunctions.computeBlockIndex(inColEnd, this.out_blen);

        // When the output block size is a multiple of the input block size, the input block
        // lies entirely inside one output block and can be copied as a whole.
        final boolean aligned = this.out_blen % this.in_blen == 0;

        // The cell-wise copy reads individual values, so decompress once before the loops.
        if (!aligned && in instanceof CompressedMatrixBlock) {
            in = CompressedMatrixBlock.getUncompressed(in);
        }

        final ArrayList<Tuple2<MatrixIndexes, MatrixBlock>> result = new ArrayList<>();
        for (long outRowBlock = outRowBlockFirst; outRowBlock <= outRowBlockLast; outRowBlock++) {
            for (long outColBlock = outColBlockFirst; outColBlock <= outColBlockLast; outColBlock++) {
                final MatrixIndexes ixOut = new MatrixIndexes(outRowBlock, outColBlock);
                final MatrixBlock out = new MatrixBlock(
                    UtilFunctions.computeBlockSize(this.rlen, outRowBlock, this.out_blen),
                    UtilFunctions.computeBlockSize(this.clen, outColBlock, this.out_blen),
                    true);

                // Intersection of the input block and this output block in global cell indexes.
                final long rowStart = Math.max(UtilFunctions.computeCellIndex(outRowBlock, this.out_blen, 0), inRowStart);
                final long rowEnd = Math.min(getEndGlobalIndex(outRowBlock, false, true), inRowEnd);
                final long colStart = Math.max(UtilFunctions.computeCellIndex(outColBlock, this.out_blen, 0), inColStart);
                final long colEnd = Math.min(getEndGlobalIndex(outColBlock, false, false), inColEnd);

                // Position of the intersection's first cell inside the input and output block.
                final int inRowOffset = UtilFunctions.computeCellInBlock(rowStart, this.in_blen);
                final int inColOffset = UtilFunctions.computeCellInBlock(colStart, this.in_blen);
                final int outRowOffset = UtilFunctions.computeCellInBlock(rowStart, this.out_blen);
                final int outColOffset = UtilFunctions.computeCellInBlock(colStart, this.out_blen);

                if (!aligned) {
                    // General case: copy the intersection cell by cell.
                    final int rowSpan = (int) (rowEnd - rowStart);
                    final int colSpan = (int) (colEnd - colStart);
                    for (int r = 0; r <= rowSpan; r++) {
                        for (int c = 0; c <= colSpan; c++) {
                            out.appendValue(outRowOffset + r, outColOffset + c,
                                in.quickGetValue(inRowOffset + r, inColOffset + c));
                        }
                    }
                } else if (in instanceof CompressedMatrixBlock) {
                    // Aligned + compressed: decompress straight into the shifted position.
                    out.allocateSparseRowsBlock(false);
                    CLALibDecompress.decompressTo((CompressedMatrixBlock) in, out,
                        outRowOffset - inRowOffset, outColOffset - inColOffset, 1, true);
                } else {
                    // Aligned + uncompressed: append the whole input block at its offset; the
                    // partial output block then holds exactly the input's non-zeros.
                    out.appendToSparse(in, outRowOffset, outColOffset);
                    out.setNonZeros(in.getNonZeros());
                }
                result.add(new Tuple2<>(ixOut, out));
            }
        }
        return result.iterator();
    }

    /**
     * Computes the global index of the last cell of a block along one dimension.
     *
     * @param blockIndex index of the block along the chosen dimension
     * @param inBlock    {@code true} to use the input block size, {@code false} for the output block size
     * @param isRow      {@code true} for the row dimension (bounded by {@code rlen}),
     *                   {@code false} for the column dimension (bounded by {@code clen})
     * @return global cell index of the block's last cell in that dimension
     */
    private long getEndGlobalIndex(long blockIndex, boolean inBlock, boolean isRow) {
        final long dimLength = isRow ? this.rlen : this.clen;
        final int blen = inBlock ? this.in_blen : this.out_blen;
        // computeBlockSize yields the actual extent of the (possibly truncated) boundary block.
        final int lastCellInBlock = UtilFunctions.computeBlockSize(dimLength, blockIndex, blen) - 1;
        return UtilFunctions.computeCellIndex(blockIndex, blen, lastCellInBlock);
    }
}
