package org.apache.sysds.runtime.instructions.spark;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.spark.Partitioner;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.rdd.PartitionPruningRDD;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.lops.LeftIndex;
import org.apache.sysds.lops.RightIndex;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.instructions.spark.data.LazyIterableIterator;
import org.apache.sysds.runtime.instructions.spark.data.PartitionedBroadcast;
import org.apache.sysds.runtime.instructions.spark.functions.IsBlockInRange;
import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.lineage.LineageItemUtils;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.util.IndexRange;
import org.apache.sysds.runtime.util.UtilFunctions;
import scala.Tuple2;
import scala.reflect.ClassManifestFactory;
import scala.runtime.AbstractFunction1;

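/* loaded from: input_file:org/apache/sysds/runtime/instructions/spark/MatrixIndexingSPInstruction.class */
/**
 * Spark instruction for matrix indexing. Depending on the opcode it performs right indexing
 * ({@code RightIndex.OPCODE}) or left indexing ({@code LeftIndex.OPCODE} / {@code "mapLeftIndex"})
 * on block-partitioned matrices held as {@code JavaPairRDD<MatrixIndexes, MatrixBlock>}.
 */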
public class MatrixIndexingSPInstruction extends IndexingSPInstruction {
    private final LeftIndex.LixCacheType _type;

    /**
     * Partition-wise left indexing where one operand is available as a partitioned broadcast
     * and the other is streamed as RDD blocks.
     *
     * <p>With {@code LixCacheType.LEFT} each incoming RDD block is written into the broadcast
     * block that has the same block index. Otherwise a slice of the broadcast is written into
     * the incoming RDD block; for {@code LixCacheType.RIGHT}, RDD blocks outside the index
     * range are passed through unchanged.
     */
    public static class LeftIndexPartitionFunction implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock>>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 1757075506076838258L;
        private final PartitionedBroadcast<MatrixBlock> _binput;
        private final IndexRange _ixrange;
        private final LeftIndex.LixCacheType _type;
        private final int _blen;

        /**
         * @param partitionedBroadcast broadcast operand
         * @param indexRange           index range of the left-indexing operation
         * @param lixCacheType         which side of the operation is the broadcast
         * @param dataCharacteristics  source of the block size
         */
        public LeftIndexPartitionFunction(PartitionedBroadcast<MatrixBlock> partitionedBroadcast, IndexRange indexRange, LeftIndex.LixCacheType lixCacheType, DataCharacteristics dataCharacteristics) {
            this._binput = partitionedBroadcast;
            this._ixrange = indexRange;
            this._type = lixCacheType;
            this._blen = dataCharacteristics.getBlocksize();
        }

        /** Wraps the partition iterator into a lazily evaluated left-indexing iterator. */
        @Override
        public LazyIterableIterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> it) throws Exception {
            return new LeftIndexPartitionIterator(it);
        }

        /** Lazy iterator that applies the left-indexing operation block by block. */
        public class LeftIndexPartitionIterator extends LazyIterableIterator<Tuple2<MatrixIndexes, MatrixBlock>> {
            public LeftIndexPartitionIterator(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> it) {
                super(it);
            }

            @Override // org.apache.sysds.runtime.instructions.spark.data.LazyIterableIterator
            public Tuple2<MatrixIndexes, MatrixBlock> computeNext(Tuple2<MatrixIndexes, MatrixBlock> tuple2) throws Exception {
                MatrixIndexes ix = tuple2._1();

                // Broadcast right-hand side: blocks outside the index range are not modified.
                if (_type == LeftIndex.LixCacheType.RIGHT && !UtilFunctions.isInBlockRange(ix, _blen, _ixrange)) {
                    return tuple2;
                }

                if (_type == LeftIndex.LixCacheType.LEFT) {
                    // Broadcast left-hand side: write the incoming block into the broadcast
                    // block with the same block index.
                    MatrixBlock rhsBlock = tuple2._2();
                    int rl = UtilFunctions.computeCellInBlock(_ixrange.rowStart, _blen);
                    int ru = (int) Math.min(_ixrange.rowEnd, rl + rhsBlock.getNumRows()) - 1;
                    int cl = UtilFunctions.computeCellInBlock(_ixrange.colStart, _blen);
                    int cu = (int) Math.min(_ixrange.colEnd, cl + rhsBlock.getNumColumns()) - 1;

                    MatrixBlock lhsBlock = _binput.getBlock((int) ix.getRowIndex(), (int) ix.getColumnIndex());
                    MatrixBlock updated = lhsBlock.leftIndexingOperations(rhsBlock, rl, ru, cl, cu, new MatrixBlock(), MatrixObject.UpdateType.COPY);
                    return new Tuple2<>(ix, updated);
                }

                // Part of the index range (global cell coordinates) covered by this block.
                long lhsRowLower = Math.max(_ixrange.rowStart, ((ix.getRowIndex() - 1) * _blen) + 1);
                long lhsRowUpper = Math.min(_ixrange.rowEnd, ix.getRowIndex() * _blen);
                long lhsColLower = Math.max(_ixrange.colStart, ((ix.getColumnIndex() - 1) * _blen) + 1);
                long lhsColUpper = Math.min(_ixrange.colEnd, ix.getColumnIndex() * _blen);

                // Matching coordinates inside the broadcast, relative to the range start.
                long rhsRowLower = (lhsRowLower - _ixrange.rowStart) + 1;
                long rhsRowUpper = rhsRowLower + (lhsRowUpper - lhsRowLower);
                long rhsColLower = (lhsColLower - _ixrange.colStart) + 1;
                long rhsColUpper = rhsColLower + (lhsColUpper - lhsColLower);
                MatrixBlock rhsSlice = _binput.slice(rhsRowLower, rhsRowUpper, rhsColLower, rhsColUpper, new MatrixBlock());

                // Block-local target coordinates for the write.
                int localRowLower = UtilFunctions.computeCellInBlock(lhsRowLower, _blen);
                int localRowUpper = UtilFunctions.computeCellInBlock(lhsRowUpper, _blen);
                int localColLower = UtilFunctions.computeCellInBlock(lhsColLower, _blen);
                int localColUpper = UtilFunctions.computeCellInBlock(lhsColUpper, _blen);
                MatrixBlock updated = tuple2._2().leftIndexingOperations(rhsSlice, localRowLower, localRowUpper, localColLower, localColUpper, new MatrixBlock(), MatrixObject.UpdateType.COPY);
                return new Tuple2<>(ix, updated);
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/sysds/runtime/instructions/spark/MatrixIndexingSPInstruction$PartitionPruningFunction.class */
    public static class PartitionPruningFunction extends AbstractFunction1<Object, Object> implements Serializable {
        private static final long serialVersionUID = -9114299718258329951L;
        private HashSet<Integer> _filterFlags;

        public PartitionPruningFunction(HashSet<Integer> hashSet) {
            this._filterFlags = null;
            this._filterFlags = hashSet;
        }

        /* renamed from: apply, reason: merged with bridge method [inline-methods] */
        public Boolean m815apply(Object obj) {
            return Boolean.valueOf(this._filterFlags.contains(obj));
        }
    }

    /**
     * Slices a single block against the index range and returns the first
     * (index, block) pair produced by the slice operation.
     */
    public static class SliceBlock2 implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 7481889252529447770L;
        private IndexRange _ixrange;
        private int _blen;

        /**
         * @param indexRange          index range to slice
         * @param dataCharacteristics source of the block size
         */
        public SliceBlock2(IndexRange indexRange, DataCharacteristics dataCharacteristics) {
            _blen = dataCharacteristics.getBlocksize();
            _ixrange = indexRange;
        }

        public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock> tuple2) throws Exception {
            MatrixIndexes ix = tuple2._1();
            MatrixValue block = tuple2._2();
            IndexedMatrixValue input = new IndexedMatrixValue(ix, block);
            ArrayList<IndexedMatrixValue> slices = OperationsOnMatrixValues.performSlice(input, _ixrange, _blen);
            // Only the first produced slice is returned.
            IndexedMatrixValue first = slices.get(0);
            return SparkUtils.fromIndexedMatrixBlock(first);
        }
    }

    /**
     * Partition-wise slicing of blocks where every input block is expected to yield
     * exactly one output block (checked by an assertion).
     */
    public static class SliceBlockPartitionFunction implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock>>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = -8111291718258309968L;
        private IndexRange _ixrange;
        private int _blen;

        /**
         * @param indexRange          index range to slice
         * @param dataCharacteristics source of the block size
         */
        public SliceBlockPartitionFunction(IndexRange indexRange, DataCharacteristics dataCharacteristics) {
            this._ixrange = indexRange;
            this._blen = dataCharacteristics.getBlocksize();
        }

        /** Wraps the partition iterator into a lazily evaluated slicing iterator. */
        @Override
        public LazyIterableIterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> it) throws Exception {
            return new SliceBlockPartitionIterator(it);
        }

        /** Lazy iterator that slices one block per call. */
        public class SliceBlockPartitionIterator extends LazyIterableIterator<Tuple2<MatrixIndexes, MatrixBlock>> {
            public SliceBlockPartitionIterator(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> it) {
                super(it);
            }

            @Override // org.apache.sysds.runtime.instructions.spark.data.LazyIterableIterator
            public Tuple2<MatrixIndexes, MatrixBlock> computeNext(Tuple2<MatrixIndexes, MatrixBlock> tuple2) throws Exception {
                IndexedMatrixValue input = SparkUtils.toIndexedMatrixBlock(tuple2);
                ArrayList<IndexedMatrixValue> slices = OperationsOnMatrixValues.performSlice(input, _ixrange, _blen);
                // Plain assert statement; the compiler generates the assertion-status
                // bookkeeping that was previously spelled out by hand.
                assert slices.size() == 1;
                return SparkUtils.fromIndexedMatrixBlock(slices.get(0));
            }
        }
    }

    /**
     * Slices one input block against the index range and emits every
     * (index, block) pair produced by the slice operation.
     */
    public static class SliceMultipleBlocks implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 5733886476413136826L;
        private final IndexRange _ixrange;
        private final int _blen;

        /**
         * @param indexRange          index range to slice
         * @param dataCharacteristics source of the block size
         */
        public SliceMultipleBlocks(IndexRange indexRange, DataCharacteristics dataCharacteristics) {
            _blen = dataCharacteristics.getBlocksize();
            _ixrange = indexRange;
        }

        public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Tuple2<MatrixIndexes, MatrixBlock> tuple2) throws Exception {
            IndexedMatrixValue input = SparkUtils.toIndexedMatrixBlock(tuple2);
            ArrayList<IndexedMatrixValue> slices = OperationsOnMatrixValues.performSlice(input, _ixrange, _blen);
            return SparkUtils.fromIndexedMatrixBlock(slices).iterator();
        }
    }

    /**
     * Shifts the blocks of the right-hand-side matrix of a left-indexing operation
     * via {@code OperationsOnMatrixValues.performShift} and emits the resulting blocks.
     */
    private static class SliceRHSForLeftIndexing implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 5724800998701216440L;
        private IndexRange _ixrange;
        private int _blen;
        private long _rlen;
        private long _clen;

        /**
         * @param indexRange          index range of the left-indexing operation
         * @param dataCharacteristics source of rows, columns and block size passed to the shift
         */
        public SliceRHSForLeftIndexing(IndexRange indexRange, DataCharacteristics dataCharacteristics) {
            this._ixrange = indexRange;
            this._rlen = dataCharacteristics.getRows();
            this._clen = dataCharacteristics.getCols();
            this._blen = dataCharacteristics.getBlocksize();
        }

        public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Tuple2<MatrixIndexes, MatrixBlock> tuple2) throws Exception {
            IndexedMatrixValue input = SparkUtils.toIndexedMatrixBlock(tuple2);
            // performShift appends its results to the supplied output list.
            ArrayList<IndexedMatrixValue> shifted = new ArrayList<>();
            OperationsOnMatrixValues.performShift(input, this._ixrange, this._blen, this._rlen, this._clen, shifted);
            return SparkUtils.fromIndexedMatrixBlock(shifted).iterator();
        }
    }

    /**
     * Maps one input block to one output block: computes the block-local part of the
     * index range, re-indexes the block relative to the range start, and either reuses
     * the block as is (range covers it completely) or slices it.
     */
    public static class SliceSingleBlock implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = -6724027136506200924L;
        private final IndexRange _ixrange;
        private final int _blen;

        /**
         * @param indexRange          index range to slice
         * @param dataCharacteristics source of the block size
         */
        public SliceSingleBlock(IndexRange indexRange, DataCharacteristics dataCharacteristics) {
            _blen = dataCharacteristics.getBlocksize();
            _ixrange = indexRange;
        }

        public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock> tuple2) throws Exception {
            MatrixIndexes inIx = tuple2._1();
            MatrixBlock inBlock = tuple2._2();

            // Global cell index of the block's first row/column (offset 0 within the block).
            long blockRowOrigin = UtilFunctions.computeCellIndex(inIx.getRowIndex(), _blen, 0);
            long blockColOrigin = UtilFunctions.computeCellIndex(inIx.getColumnIndex(), _blen, 0);

            // Block-local bounds of the requested range, clipped to the block.
            int lastRow = inBlock.getNumRows() - 1;
            int lastCol = inBlock.getNumColumns() - 1;
            int rowLower = (int) (_ixrange.rowStart < blockRowOrigin ? 0L : _ixrange.rowStart - blockRowOrigin);
            int colLower = (int) (_ixrange.colStart < blockColOrigin ? 0L : _ixrange.colStart - blockColOrigin);
            int rowUpper = (int) Math.min(lastRow, _ixrange.rowEnd - blockRowOrigin);
            int colUpper = (int) Math.min(lastCol, _ixrange.colEnd - blockColOrigin);

            // Output block index relative to the start of the index range.
            long outRowIx = inIx.getRowIndex() - ((_ixrange.rowStart - 1) / _blen);
            long outColIx = inIx.getColumnIndex() - ((_ixrange.colStart - 1) / _blen);
            MatrixIndexes outIx = new MatrixIndexes(outRowIx, outColIx);

            boolean coversWholeBlock = rowLower == 0 && rowUpper == lastRow && colLower == 0 && colUpper == lastCol;
            if (coversWholeBlock) {
                // No slicing needed; the input block is passed on unchanged.
                return new Tuple2<>(outIx, inBlock);
            }
            return new Tuple2<>(outIx, inBlock.slice(rowLower, rowUpper, colLower, colUpper, new MatrixBlock()));
        }
    }

    /* loaded from: input_file:org/apache/sysds/runtime/instructions/spark/MatrixIndexingSPInstruction$ZeroOutLHS.class */
    private static class ZeroOutLHS implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = -3581795160948484261L;
        private boolean _complement;
        private IndexRange _ixrange;
        private int _blen;

        public ZeroOutLHS(boolean z, IndexRange indexRange, DataCharacteristics dataCharacteristics) {
            this._complement = false;
            this._ixrange = null;
            this._blen = -1;
            this._complement = z;
            this._ixrange = indexRange;
            this._blen = dataCharacteristics.getBlocksize();
            this._blen = dataCharacteristics.getBlocksize();
        }

        public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock> tuple2) throws Exception {
            if (!UtilFunctions.isInBlockRange((MatrixIndexes) tuple2._1(), this._blen, this._ixrange)) {
                return tuple2;
            }
            IndexRange selectedRangeForZeroOut = UtilFunctions.getSelectedRangeForZeroOut(new IndexedMatrixValue((MatrixIndexes) tuple2._1, (MatrixValue) tuple2._2), this._blen, this._ixrange);
            if (selectedRangeForZeroOut.rowStart == -1 && selectedRangeForZeroOut.rowEnd == -1 && selectedRangeForZeroOut.colStart == -1 && selectedRangeForZeroOut.colEnd == -1) {
                throw new Exception("Error while getting range for zero-out");
            }
            return new Tuple2<>((MatrixIndexes) tuple2._1, ((MatrixBlock) tuple2._2).zeroOutOperations((MatrixValue) new MatrixBlock(), selectedRangeForZeroOut, this._complement));
        }
    }

    /**
     * Creates an indexing instruction that carries a Spark aggregation type; the
     * left-indexing cache type is fixed to {@code LixCacheType.NONE}.
     * All other arguments are forwarded unchanged to the superclass constructor.
     * Presumably this is the variant used for right indexing — TODO confirm against callers.
     *
     * Decompiler note: access modifier was changed from protected.
     */
    public MatrixIndexingSPInstruction(CPOperand cPOperand, CPOperand cPOperand2, CPOperand cPOperand3, CPOperand cPOperand4, CPOperand cPOperand5, CPOperand cPOperand6, AggBinaryOp.SparkAggType sparkAggType, String str, String str2) {
        super(cPOperand, cPOperand2, cPOperand3, cPOperand4, cPOperand5, cPOperand6, sparkAggType, str, str2);
        this._type = LeftIndex.LixCacheType.NONE;
    }

    /**
     * Creates an indexing instruction with an explicit left-indexing cache type.
     * The seven operands and the two strings are forwarded unchanged to the superclass
     * constructor; only {@code lixCacheType} is stored in this class.
     * Presumably this is the variant used for left indexing — TODO confirm against callers.
     *
     * Decompiler note: access modifier was changed from protected.
     */
    public MatrixIndexingSPInstruction(CPOperand cPOperand, CPOperand cPOperand2, CPOperand cPOperand3, CPOperand cPOperand4, CPOperand cPOperand5, CPOperand cPOperand6, CPOperand cPOperand7, LeftIndex.LixCacheType lixCacheType, String str, String str2) {
        super(cPOperand, cPOperand2, cPOperand3, cPOperand4, cPOperand5, cPOperand6, cPOperand7, str, str2);
        this._type = lixCacheType;
    }

    /**
     * Executes the indexing instruction.
     *
     * <p>Right indexing: updates the output characteristics and then either produces an
     * in-memory matrix (single-block or multi-block lookup) or registers an RDD for the
     * general case. Left indexing: validates that the index range matches the dimensions
     * of the second input, then either applies a broadcast-based partition function
     * ({@code "mapLeftIndex"}) or zeroes out the target range, shifts the right-hand side
     * and merges both by key.
     *
     * @param executionContext execution context; must be a {@link SparkExecutionContext}
     * @throws DMLRuntimeException on an unknown opcode, unknown dimensions of the second
     *                             input, or an index range that does not match its dimensions
     */
    @Override // org.apache.sysds.runtime.instructions.spark.SPInstruction, org.apache.sysds.runtime.instructions.Instruction
    public void processInstruction(ExecutionContext executionContext) {
        SparkExecutionContext sec = (SparkExecutionContext) executionContext;
        String opcode = getOpcode();

        // Index bounds as given by the scalar operands.
        long rl = executionContext.getScalarInput(this.rowLower).getLongValue();
        long ru = executionContext.getScalarInput(this.rowUpper).getLongValue();
        long cl = executionContext.getScalarInput(this.colLower).getLongValue();
        long cu = executionContext.getScalarInput(this.colUpper).getLongValue();
        IndexRange ixrange = new IndexRange(rl, ru, cl, cu);

        if (opcode.equalsIgnoreCase(RightIndex.OPCODE)) {
            DataCharacteristics mcIn = sec.getDataCharacteristics(this.input1.getName());
            DataCharacteristics mcOut = sec.getDataCharacteristics(this.output.getName());
            mcOut.set((ru - rl) + 1, (cu - cl) + 1, mcIn.getBlocksize(), mcIn.getBlocksize());
            mcOut.setNonZerosBound(Math.min(mcOut.getLength(), mcIn.getNonZerosBound()));
            checkValidOutputDimensions(mcOut);

            JavaPairRDD<MatrixIndexes, MatrixBlock> in = sec.getBinaryMatrixBlockRDDHandleForVariable(this.input1.getName());
            if (isSingleBlockLookup(mcIn, ixrange)) {
                sec.setMatrixOutput(this.output.getName(), singleBlockIndexing(in, mcIn, mcOut, ixrange));
            } else if (isMultiBlockLookup(in, mcIn, mcOut, ixrange)) {
                sec.setMatrixOutput(this.output.getName(), multiBlockIndexing(in, mcIn, mcOut, ixrange));
            } else {
                // General case: the result stays distributed as an RDD.
                sec.setRDDHandleForVariable(this.output.getName(), generalCaseRightIndexing(in, mcIn, mcOut, ixrange, this._aggType));
                sec.addLineageRDD(this.output.getName(), this.input1.getName());
            }
            return;
        }

        boolean mapLeftIndex = opcode.equalsIgnoreCase("mapLeftIndex");
        if (!opcode.equalsIgnoreCase(LeftIndex.OPCODE) && !mapLeftIndex) {
            throw new DMLRuntimeException("Invalid opcode (" + opcode + ") encountered in MatrixIndexingSPInstruction.");
        }

        // With LixCacheType.LEFT the roles swap: input2 is the RDD, input1 the broadcast.
        boolean leftIsBroadcast = this._type == LeftIndex.LixCacheType.LEFT;
        String rddVar = leftIsBroadcast ? this.input2.getName() : this.input1.getName();
        String bcVar = leftIsBroadcast ? this.input1.getName() : this.input2.getName();
        JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec.getBinaryMatrixBlockRDDHandleForVariable(rddVar);

        // The output has the dimensions and block size of the first input.
        DataCharacteristics mcOut = sec.getDataCharacteristics(this.output.getName());
        DataCharacteristics mcLeft = executionContext.getDataCharacteristics(this.input1.getName());
        mcOut.set(mcLeft.getRows(), mcLeft.getCols(), mcLeft.getBlocksize(), mcLeft.getBlocksize());
        checkValidOutputDimensions(mcOut);

        DataCharacteristics mcRight = executionContext.getDataCharacteristics(this.input2.getName());
        if (!mcRight.dimsKnown()) {
            throw new DMLRuntimeException("The right input matrix dimensions are not specified for MatrixIndexingSPInstruction");
        }
        if ((ru - rl) + 1 != mcRight.getRows() || (cu - cl) + 1 != mcRight.getCols()) {
            throw new DMLRuntimeException("Invalid index range of leftindexing: [" + rl + ":" + ru + "," + cl + ":" + cu + "] vs [" + mcRight.getRows() + "x" + mcRight.getCols() + "].");
        }

        PartitionedBroadcast<MatrixBlock> broadcast = null;
        JavaPairRDD<MatrixIndexes, MatrixBlock> in2 = null;
        JavaPairRDD<MatrixIndexes, MatrixBlock> out;
        if (mapLeftIndex) {
            broadcast = sec.getBroadcastForVariable(bcVar);
            out = in1.mapPartitionsToPair(new LeftIndexPartitionFunction(broadcast, ixrange, this._type, mcOut), true);
        } else {
            // Zero out the target range, shift the right-hand side into place, merge by key.
            JavaPairRDD<MatrixIndexes, MatrixBlock> zeroedLeft = in1.mapToPair(new ZeroOutLHS(false, ixrange, mcLeft));
            in2 = sec.getBinaryMatrixBlockRDDHandleForVariable(this.input2.getName()).flatMapToPair(new SliceRHSForLeftIndexing(ixrange, mcLeft));
            out = RDDAggregateUtils.mergeByKey(zeroedLeft.union(in2));
        }

        sec.setRDDHandleForVariable(this.output.getName(), out);
        sec.addLineageRDD(this.output.getName(), rddVar);
        if (broadcast != null) {
            sec.addLineageBroadcast(this.output.getName(), bcVar);
        }
        if (in2 != null) {
            sec.addLineageRDD(this.output.getName(), this.input2.getName());
        }
    }

    /**
     * Performs right indexing with an in-memory result, using either the single-block
     * or the multi-block lookup strategy.
     *
     * @throws DMLRuntimeException if neither lookup strategy applies
     */
    public static MatrixBlock inmemoryIndexing(JavaPairRDD<MatrixIndexes, MatrixBlock> javaPairRDD, DataCharacteristics dataCharacteristics, DataCharacteristics dataCharacteristics2, IndexRange indexRange) {
        MatrixBlock result;
        if (isSingleBlockLookup(dataCharacteristics, indexRange)) {
            result = singleBlockIndexing(javaPairRDD, dataCharacteristics, dataCharacteristics2, indexRange);
        } else if (isMultiBlockLookup(javaPairRDD, dataCharacteristics, dataCharacteristics2, indexRange)) {
            result = multiBlockIndexing(javaPairRDD, dataCharacteristics, dataCharacteristics2, indexRange);
        } else {
            throw new DMLRuntimeException("Incorrect usage of inmemoryIndexing");
        }
        return result;
    }

    /**
     * Right indexing over multiple blocks with an in-memory result. Enumerates the block
     * indexes touched by the range, restricts the scan to the partitions holding them,
     * filters and slices the blocks, and collects them into one matrix block.
     *
     * @param javaPairRDD          input blocks; must have a partitioner (see {@link #isMultiBlockLookup})
     * @param dataCharacteristics  input characteristics (block size)
     * @param dataCharacteristics2 output characteristics (rows, columns, block size)
     * @param indexRange           requested index range
     * @return the collected result block
     */
    private static MatrixBlock multiBlockIndexing(JavaPairRDD<MatrixIndexes, MatrixBlock> javaPairRDD, DataCharacteristics dataCharacteristics, DataCharacteristics dataCharacteristics2, IndexRange indexRange) {
        int blen = dataCharacteristics.getBlocksize();
        long firstRowBlock = UtilFunctions.computeBlockIndex(indexRange.rowStart, blen);
        long lastRowBlock = UtilFunctions.computeBlockIndex(indexRange.rowEnd, blen);
        long firstColBlock = UtilFunctions.computeBlockIndex(indexRange.colStart, blen);
        long lastColBlock = UtilFunctions.computeBlockIndex(indexRange.colEnd, blen);

        // All block indexes covered by the range, row-major.
        List<MatrixIndexes> requiredBlocks = new ArrayList<>();
        for (long r = firstRowBlock; r <= lastRowBlock; r++) {
            for (long c = firstColBlock; c <= lastColBlock; c++) {
                requiredBlocks.add(new MatrixIndexes(r, c));
            }
        }

        JavaPairRDD<MatrixIndexes, MatrixBlock> sliced = createPartitionPruningRDD(javaPairRDD, requiredBlocks)
                .filter(new IsBlockInRange(indexRange.rowStart, indexRange.rowEnd, indexRange.colStart, indexRange.colEnd, dataCharacteristics2))
                .mapToPair(new SliceBlock2(indexRange, dataCharacteristics2));

        return SparkExecutionContext.toMatrixBlock(sliced, (int) dataCharacteristics2.getRows(), (int) dataCharacteristics2.getCols(), dataCharacteristics2.getBlocksize(), -1L);
    }

    /**
     * Right indexing within a single block: looks up the block containing the range start
     * and returns it unchanged if it already has the output dimensions, otherwise a slice.
     *
     * @throws DMLRuntimeException if the lookup does not return exactly one block
     */
    private static MatrixBlock singleBlockIndexing(JavaPairRDD<MatrixIndexes, MatrixBlock> javaPairRDD, DataCharacteristics dataCharacteristics, DataCharacteristics dataCharacteristics2, IndexRange indexRange) {
        long rowBlock = UtilFunctions.computeBlockIndex(indexRange.rowStart, dataCharacteristics.getBlocksize());
        long colBlock = UtilFunctions.computeBlockIndex(indexRange.colStart, dataCharacteristics.getBlocksize());
        List<MatrixBlock> found = javaPairRDD.lookup(new MatrixIndexes(rowBlock, colBlock));
        if (found.size() != 1) {
            throw new DMLRuntimeException("Block lookup returned " + found.size() + " blocks (expected 1).");
        }

        MatrixBlock block = found.get(0);
        boolean sameDims = block.getNumRows() == dataCharacteristics2.getRows() && block.getNumColumns() == dataCharacteristics2.getCols();
        MatrixBlock result;
        if (sameDims) {
            result = block;
        } else {
            result = block.slice(
                    UtilFunctions.computeCellInBlock(indexRange.rowStart, dataCharacteristics.getBlocksize()),
                    UtilFunctions.computeCellInBlock(indexRange.rowEnd, dataCharacteristics.getBlocksize()),
                    UtilFunctions.computeCellInBlock(indexRange.colStart, dataCharacteristics.getBlocksize()),
                    UtilFunctions.computeCellInBlock(indexRange.colEnd, dataCharacteristics.getBlocksize()),
                    new MatrixBlock());
        }
        result.examSparsity();
        return result;
    }

    /**
     * Right indexing with a distributed result. Chooses between three plans:
     * partition-preserving slicing, filter plus one-to-one slicing (optionally coalesced),
     * or filter plus one-to-many slicing followed by a merge by key.
     */
    private static JavaPairRDD<MatrixIndexes, MatrixBlock> generalCaseRightIndexing(JavaPairRDD<MatrixIndexes, MatrixBlock> javaPairRDD, DataCharacteristics dataCharacteristics, DataCharacteristics dataCharacteristics2, IndexRange indexRange, AggBinaryOp.SparkAggType sparkAggType) {
        if (isPartitioningPreservingRightIndexing(dataCharacteristics, indexRange)) {
            return javaPairRDD.mapPartitionsToPair(new SliceBlockPartitionFunction(indexRange, dataCharacteristics2), true);
        }

        boolean singleBlockSlicing = sparkAggType == AggBinaryOp.SparkAggType.NONE || OptimizerUtils.isIndexingRangeBlockAligned(indexRange, dataCharacteristics);
        JavaPairRDD<MatrixIndexes, MatrixBlock> inRange = javaPairRDD.filter(new IsBlockInRange(indexRange.rowStart, indexRange.rowEnd, indexRange.colStart, indexRange.colEnd, dataCharacteristics2));

        if (!singleBlockSlicing) {
            // Slices of different input blocks may share an output index and are merged.
            return RDDAggregateUtils.mergeByKey(inRange.flatMapToPair(new SliceMultipleBlocks(indexRange, dataCharacteristics2)));
        }

        JavaPairRDD<MatrixIndexes, MatrixBlock> sliced = inRange.mapToPair(new SliceSingleBlock(indexRange, dataCharacteristics2));
        int preferredPartitions = SparkUtils.getNumPreferredPartitions(dataCharacteristics2);
        boolean muchFewerPreferred = 1.4d * ((double) preferredPartitions) < ((double) javaPairRDD.getNumPartitions());
        if (muchFewerPreferred && !SparkUtils.isHashPartitioned(javaPairRDD)) {
            sliced = sliced.coalesce(preferredPartitions);
        }
        return sliced;
    }

    /**
     * Verifies that the output dimensions are known.
     *
     * @throws DMLRuntimeException if {@code dataCharacteristics.dimsKnown()} is false
     */
    private static void checkValidOutputDimensions(DataCharacteristics dataCharacteristics) {
        if (dataCharacteristics.dimsKnown()) {
            return;
        }
        throw new DMLRuntimeException("MatrixIndexingSPInstruction: The updated output dimensions are invalid: " + dataCharacteristics);
    }

    /**
     * Returns true if the range selects all rows of a matrix with at most one column block,
     * or all columns of a matrix with at most one row block.
     *
     * NOTE(review): {@code dimsKnown()} guards only the first case; this looks like an
     * operator-precedence artifact — confirm whether it should guard both.
     */
    private static boolean isPartitioningPreservingRightIndexing(DataCharacteristics dataCharacteristics, IndexRange indexRange) {
        boolean allRowsSingleColBlock = dataCharacteristics.dimsKnown()
                && indexRange.rowStart == 1
                && indexRange.rowEnd == dataCharacteristics.getRows()
                && dataCharacteristics.getCols() <= ((long) dataCharacteristics.getBlocksize());
        if (allRowsSingleColBlock) {
            return true;
        }
        return indexRange.colStart == 1
                && indexRange.colEnd == dataCharacteristics.getCols()
                && dataCharacteristics.getRows() <= ((long) dataCharacteristics.getBlocksize());
    }

    /**
     * Returns true if the start and end of the index range fall into the same block,
     * both row-wise and column-wise.
     */
    public static boolean isSingleBlockLookup(DataCharacteristics dataCharacteristics, IndexRange indexRange) {
        boolean sameRowBlock = UtilFunctions.computeBlockIndex(indexRange.rowStart, dataCharacteristics.getBlocksize())
                == UtilFunctions.computeBlockIndex(indexRange.rowEnd, dataCharacteristics.getBlocksize());
        if (!sameRowBlock) {
            return false;
        }
        return UtilFunctions.computeBlockIndex(indexRange.colStart, dataCharacteristics.getBlocksize())
                == UtilFunctions.computeBlockIndex(indexRange.colEnd, dataCharacteristics.getBlocksize());
    }

    /**
     * Returns true if all four conditions hold, checked in this order: the input is hash
     * partitioned, its estimated partitioned size exceeds the data memory budget, the index
     * range is block aligned, and the estimated output size is below half the local budget.
     */
    public static boolean isMultiBlockLookup(JavaPairRDD<?, ?> javaPairRDD, DataCharacteristics dataCharacteristics, DataCharacteristics dataCharacteristics2, IndexRange indexRange) {
        if (!SparkUtils.isHashPartitioned(javaPairRDD)) {
            return false;
        }
        if (!(((double) OptimizerUtils.estimatePartitionedSizeExactSparsity(dataCharacteristics)) > SparkExecutionContext.getDataMemoryBudget(true, true))) {
            return false;
        }
        if (!OptimizerUtils.isIndexingRangeBlockAligned(indexRange, dataCharacteristics)) {
            return false;
        }
        return ((double) OptimizerUtils.estimateSize(dataCharacteristics2)) < OptimizerUtils.getLocalMemBudget() / 2.0d;
    }

    /**
     * Wraps the input in a partition-pruning RDD that only exposes the partitions
     * to which the input's partitioner assigns the given block indexes.
     *
     * @param javaPairRDD input blocks; its underlying RDD must have a partitioner
     * @param list        block indexes that are required
     */
    private static JavaPairRDD<MatrixIndexes, MatrixBlock> createPartitionPruningRDD(JavaPairRDD<MatrixIndexes, MatrixBlock> javaPairRDD, List<MatrixIndexes> list) {
        // Ids of the partitions that hold at least one required block.
        HashSet<Integer> requiredPartitions = new HashSet<>();
        Partitioner partitioner = javaPairRDD.rdd().partitioner().get();
        for (MatrixIndexes key : list) {
            requiredPartitions.add(partitioner.getPartition(key));
        }

        PartitionPruningRDD<Tuple2<MatrixIndexes, MatrixBlock>> pruned =
                PartitionPruningRDD.create(javaPairRDD.rdd(), new PartitionPruningFunction(requiredPartitions));
        return new JavaPairRDD<>(pruned, ClassManifestFactory.fromClass(MatrixIndexes.class), ClassManifestFactory.fromClass(MatrixBlock.class));
    }

    /**
     * Builds the lineage entry for the output: the output name paired with a lineage item
     * for this opcode over the three inputs and the four index-bound operands.
     */
    @Override // org.apache.sysds.runtime.instructions.spark.ComputationSPInstruction, org.apache.sysds.runtime.lineage.LineageTraceable
    public Pair<String, LineageItem> getLineageItem(ExecutionContext executionContext) {
        String outputName = this.output.getName();
        LineageItem item = new LineageItem(
                getOpcode(),
                LineageItemUtils.getLineage(executionContext, this.input1, this.input2, this.input3, this.rowLower, this.rowUpper, this.colLower, this.colUpper));
        return Pair.of(outputName, item);
    }
}
