package org.apache.sysds.runtime.instructions.spark;

import java.util.ArrayList;
import java.util.Iterator;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.lops.LeftIndex;
import org.apache.sysds.lops.RightIndex;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.data.LazyIterableIterator;
import org.apache.sysds.runtime.instructions.spark.data.PartitionedBroadcast;
import org.apache.sysds.runtime.instructions.spark.functions.IsFrameBlockInRange;
import org.apache.sysds.runtime.instructions.spark.utils.FrameRDDAggregateUtils;
import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysds.runtime.io.IOUtilFunctions;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysds.runtime.matrix.data.Pair;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.util.IndexRange;
import org.apache.sysds.runtime.util.UtilFunctions;
import scala.Tuple2;

/* loaded from: input_file:org/apache/sysds/runtime/instructions/spark/FrameIndexingSPInstruction.class */
public class FrameIndexingSPInstruction extends IndexingSPInstruction {

    /**
     * Partition-wise left indexing against a broadcast right-hand side (used for the
     * {@code "mapLeftIndex"} opcode).
     *
     * <p>Frame RDDs are keyed by the global 1-based row index of the first row of each block. Every
     * left-hand-side block whose rows overlap the target row range receives the matching slice of
     * the broadcast right-hand side. All other blocks are emitted unchanged under their original
     * key, which is why the caller can preserve partitioning.
     *
     * <p>NOTE: the decompiler reports this class was originally {@code private}. It is kept
     * {@code public} so the visible interface does not change.
     */
    public static class LeftIndexPartitionFunction
            implements PairFlatMapFunction<Iterator<Tuple2<Long, FrameBlock>>, Long, FrameBlock> {
        private static final long serialVersionUID = -911940376947364915L;

        /**
         * Number of right-hand-side rows sliced from the broadcast per step. This is the literal
         * the original loop used. Presumably it matches the partition block size of the broadcast.
         * TODO confirm against PartitionedBroadcast.
         */
        private static final int RHS_SLICE_BLOCKSIZE = 1000;

        private final PartitionedBroadcast<FrameBlock> _binput;
        private final IndexRange _ixrange;

        /**
         * @param binput  broadcast right-hand-side frame
         * @param ixrange global 1-based target range in the left-hand side
         * @param mcOut   output data characteristics (not used by this function)
         */
        public LeftIndexPartitionFunction(PartitionedBroadcast<FrameBlock> binput, IndexRange ixrange,
                DataCharacteristics mcOut) {
            _binput = binput;
            _ixrange = ixrange;
        }

        @Override
        public LazyIterableIterator<Tuple2<Long, FrameBlock>> call(Iterator<Tuple2<Long, FrameBlock>> it)
                throws Exception {
            return new LeftIndexPartitionIterator(it);
        }

        /** Lazily applies the left-indexing update to each block of one partition. */
        public class LeftIndexPartitionIterator extends LazyIterableIterator<Tuple2<Long, FrameBlock>> {
            public LeftIndexPartitionIterator(Iterator<Tuple2<Long, FrameBlock>> it) {
                super(it);
            }

            /**
             * Writes the overlapping part of the broadcast right-hand side into one left-hand-side block.
             *
             * @param arg (first global row index of the block, block)
             * @return the input tuple itself if the block does not overlap the row range, otherwise
             *         a new tuple with the same key and the updated block
             */
            @Override
            public Tuple2<Long, FrameBlock> computeNext(Tuple2<Long, FrameBlock> arg) throws Exception {
                final long rowKey = arg._1;
                final FrameBlock block = arg._2;
                final int numRows = block.getNumRows();
                final int numCols = block.getNumColumns();
                if (!UtilFunctions.isInFrameBlockRange(arg._1(), numRows, _ixrange)) {
                    return arg;
                }

                // Global (1-based) left-hand-side range covered by this block.
                final long lhsRl = Math.max(_ixrange.rowStart, rowKey);
                final long lhsRu = Math.min(_ixrange.rowEnd, rowKey + numRows - 1);
                final long lhsCl = Math.max(_ixrange.colStart, 1L);
                final long lhsCu = Math.min(_ixrange.colEnd, numCols);

                // Corresponding global (1-based) right-hand-side range.
                final long rhsRl = lhsRl - _ixrange.rowStart + 1;
                final long rhsRu = rhsRl + (lhsRu - lhsRl);
                final long rhsCl = lhsCl - _ixrange.colStart + 1;
                final long rhsCu = rhsCl + (lhsCu - lhsCl);

                // Local (0-based) left-hand-side range inside this block.
                final int lhsLrl = (int) (lhsRl - rowKey);
                final int lhsLru = (int) (lhsRu - rowKey);
                final int lhsLcl = (int) lhsCl - 1;
                final int lhsLcu = (int) lhsCu - 1;

                // Slice the broadcast in row chunks aligned to RHS_SLICE_BLOCKSIZE, so that each
                // slice() request covers at most one chunk.
                FrameBlock ret = block;
                long rhsRlPb = rhsRl;
                long rhsRuPb = Math.min(rhsRu, ((rhsRl - 1) / RHS_SLICE_BLOCKSIZE + 1) * RHS_SLICE_BLOCKSIZE);
                while (rhsRlPb <= rhsRuPb) {
                    FrameBlock rhsSlice = (FrameBlock) _binput.slice(rhsRlPb, rhsRuPb, rhsCl, rhsCu, new FrameBlock());
                    // Shift the local left-hand-side row bounds by how far this chunk is offset from
                    // the full right-hand-side range.
                    int lhsLrlPb = (int) (lhsLrl + (rhsRlPb - rhsRl));
                    int lhsLruPb = (int) (lhsLru + (rhsRuPb - rhsRu));
                    ret = ret.leftIndexingOperations(rhsSlice, lhsLrlPb, lhsLruPb, lhsLcl, lhsLcu, new FrameBlock());
                    rhsRlPb = rhsRuPb + 1;
                    rhsRuPb = Math.min(rhsRu, rhsRuPb + RHS_SLICE_BLOCKSIZE);
                }
                return new Tuple2<>(arg._1, ret);
            }
        }
    }

    /**
     * Right indexing for one frame block that overlaps the requested row range.
     *
     * <p>The caller filters with {@link IsFrameBlockInRange} first. The block is sliced to the
     * overlapping rows and to the requested column range. The output key is re-based so that
     * {@code rowStart} maps to row 1.
     */
    private static class SliceBlock implements PairFunction<Tuple2<Long, FrameBlock>, Long, FrameBlock> {
        private static final long serialVersionUID = -5270171193018691692L;
        private final IndexRange _ixrange;

        /**
         * @param ixrange global 1-based range to extract
         * @param mcOut   output data characteristics (not used by this function)
         */
        public SliceBlock(IndexRange ixrange, DataCharacteristics mcOut) {
            _ixrange = ixrange;
        }

        @Override
        public Tuple2<Long, FrameBlock> call(Tuple2<Long, FrameBlock> kv) throws Exception {
            final long rindex = kv._1();
            final FrameBlock in = kv._2();

            // Local 0-based row range within this block.
            final int rl = (int) ((rindex > _ixrange.rowStart) ? 0 : _ixrange.rowStart - rindex);
            final int ru = (int) ((_ixrange.rowEnd - rindex >= in.getNumRows())
                ? in.getNumRows() - 1 : _ixrange.rowEnd - rindex);

            final FrameBlock out = in.slice(rl, ru, (int) (_ixrange.colStart - 1), (int) (_ixrange.colEnd - 1),
                (CacheBlock) new FrameBlock());

            // New key relative to the start of the selected row range (1-based).
            final long rindex2 = (rindex > _ixrange.rowStart) ? rindex - _ixrange.rowStart + 1 : 1L;
            return new Tuple2<>(rindex2, out);
        }
    }

    /**
     * Right indexing that selects all rows. Each block keeps its key and is only sliced to the
     * requested column range, so the partitioning can be preserved.
     *
     * <p>NOTE: the decompiler reports this class was originally {@code private}. It is kept
     * {@code public} so the visible interface does not change.
     */
    public static class SliceBlockPartitionFunction
            implements PairFlatMapFunction<Iterator<Tuple2<Long, FrameBlock>>, Long, FrameBlock> {
        private static final long serialVersionUID = -1655390518299307588L;
        private final IndexRange _ixrange;

        /**
         * @param ixrange global 1-based range to extract (only the column bounds are used)
         * @param mcOut   output data characteristics (not used by this function)
         */
        public SliceBlockPartitionFunction(IndexRange ixrange, DataCharacteristics mcOut) {
            _ixrange = ixrange;
        }

        @Override
        public LazyIterableIterator<Tuple2<Long, FrameBlock>> call(Iterator<Tuple2<Long, FrameBlock>> it)
                throws Exception {
            return new SliceBlockPartitionIterator(it);
        }

        /** Lazily slices the columns of each block of one partition. */
        public class SliceBlockPartitionIterator extends LazyIterableIterator<Tuple2<Long, FrameBlock>> {
            public SliceBlockPartitionIterator(Iterator<Tuple2<Long, FrameBlock>> it) {
                super(it);
            }

            @Override
            public Tuple2<Long, FrameBlock> computeNext(Tuple2<Long, FrameBlock> arg) throws Exception {
                final FrameBlock in = arg._2();
                final FrameBlock out = in.slice(0, in.getNumRows() - 1,
                    ((int) _ixrange.colStart) - 1, ((int) _ixrange.colEnd) - 1, (CacheBlock) new FrameBlock());
                return new Tuple2<>(arg._1(), out);
            }
        }
    }

    /**
     * Shifts one right-hand-side block to its target position in the left-hand side, for the
     * shuffle-based left-indexing plan.
     *
     * <p>The produced blocks are unioned with the output of {@link ZeroOutLHS} and merged by key.
     */
    private static class SliceRHSForLeftIndexing
            implements PairFlatMapFunction<Tuple2<Long, FrameBlock>, Long, FrameBlock> {
        private static final long serialVersionUID = 5724800998701216440L;
        private final IndexRange _ixrange;
        private final int _blen;
        private final long _rlen;
        private final long _clen;

        /**
         * @param ixrange global 1-based target range in the left-hand side
         * @param mcLeft  data characteristics of the left-hand side
         */
        public SliceRHSForLeftIndexing(IndexRange ixrange, DataCharacteristics mcLeft) {
            _ixrange = ixrange;
            _rlen = mcLeft.getRows();
            _clen = mcLeft.getCols();
            // NOTE(review): the row-block size is taken from the left-hand side's column count. The
            // original first computed min(OptimizerUtils.getDefaultFrameSize(), rlen) and then
            // overwrote it with this value. Presumably this must stay identical to ZeroOutLHS._blen
            // so that both sides produce matching keys for mergeByKey. Change both together, if at all.
            _blen = (int) mcLeft.getCols();
        }

        @Override
        public Iterator<Tuple2<Long, FrameBlock>> call(Tuple2<Long, FrameBlock> rightKV) throws Exception {
            final Pair<Long, FrameBlock> in = SparkUtils.toIndexedFrameBlock(rightKV);
            final ArrayList<Pair<Long, FrameBlock>> out = new ArrayList<>();
            OperationsOnMatrixValues.performShift(in, _ixrange, _blen, _rlen, _clen, out);
            return SparkUtils.fromIndexedFrameBlock(out).iterator();
        }
    }

    /**
     * Re-blocks one left-hand-side block into row blocks aligned to {@code _blen} and clears the
     * cells selected by the index range. The {@code complement} flag is passed through to
     * {@code FrameBlock.zeroOutOperations}.
     */
    private static class ZeroOutLHS implements PairFlatMapFunction<Tuple2<Long, FrameBlock>, Long, FrameBlock> {
        private static final long serialVersionUID = -2672267231152496854L;
        private final boolean _complement;
        private final IndexRange _ixrange;
        private final int _blen;
        private final long _rlen;

        /**
         * @param complement forwarded to {@code zeroOutOperations}
         * @param range      global 1-based range to clear
         * @param mcLeft     data characteristics of the left-hand side
         */
        public ZeroOutLHS(boolean complement, IndexRange range, DataCharacteristics mcLeft) {
            _complement = complement;
            _ixrange = range;
            // NOTE(review): the row-block size is taken from the column count. The original first
            // assigned OptimizerUtils.getDefaultFrameSize() and then overwrote it with this value.
            // Presumably this must match SliceRHSForLeftIndexing._blen. A zero-column input would
            // make this 0 and the division in call() would fail.
            _blen = (int) mcLeft.getCols();
            _rlen = mcLeft.getRows();
        }

        @Override
        public Iterator<Tuple2<Long, FrameBlock>> call(Tuple2<Long, FrameBlock> kv) throws Exception {
            final ArrayList<Pair<Long, FrameBlock>> out = new ArrayList<>();
            // Mutable copy: rowStart is advanced as consecutive target blocks are produced.
            final IndexRange curBlockRange = new IndexRange(_ixrange.rowStart, _ixrange.rowEnd,
                _ixrange.colStart, _ixrange.colEnd);
            final long key = kv._1;
            final FrameBlock in = kv._2;

            // Global 1-based first row of the _blen-aligned target block that contains `key`.
            long gblStartRow = ((key - 1) / _blen) * _blen + 1;
            // 0-based offset inside that target block where copying starts.
            int rowStartDest = UtilFunctions.computeCellInBlock(key, _blen);
            int rowStartSrc = 0;
            while (rowStartSrc < in.getNumRows()) {
                IndexRange range = UtilFunctions.getSelectedRangeForZeroOut(new Pair<>(kv._1, in), _blen,
                    curBlockRange, gblStartRow - 1, gblStartRow);
                if (range.rowStart == -1 && range.rowEnd == -1 && range.colStart == -1 && range.colEnd == -1) {
                    throw new DMLRuntimeException("Error while getting range for zero-out");
                }
                // Number of rows in the target block (the last block may be shorter than _blen).
                int maxRows = (int) Math.min(_blen, _rlen - gblStartRow + 1);
                // Rows copied in this step: bounded by the source rows left and the target space left.
                int rowsToCopy = Math.min(Math.min(maxRows, in.getNumRows() - rowStartSrc), maxRows - rowStartDest);
                if (rowsToCopy <= 0) {
                    // Without this check the loop would never advance rowStartSrc.
                    throw new DMLRuntimeException("Invalid zero-out range: no rows left to copy for target block "
                        + "starting at row " + gblStartRow + " (rlen=" + _rlen + ", blen=" + _blen
                        + ", destOffset=" + rowStartDest + ").");
                }
                FrameBlock zeroBlk = in.zeroOutOperations(new FrameBlock(), range, _complement,
                    rowStartSrc, rowStartDest, maxRows, rowsToCopy);
                out.add(new Pair<>(Long.valueOf(gblStartRow), zeroBlk));

                curBlockRange.rowStart = gblStartRow + _blen;
                rowStartDest = UtilFunctions.computeCellInBlock(rowStartDest + rowsToCopy + 1, _blen);
                rowStartSrc += rowsToCopy;
                gblStartRow += _blen;
            }
            return SparkUtils.fromIndexedFrameBlock(out).iterator();
        }
    }

    /**
     * Creates a right-indexing instruction. All operands are forwarded unchanged to
     * {@link IndexingSPInstruction}.
     */
    protected FrameIndexingSPInstruction(CPOperand in, CPOperand rl, CPOperand ru, CPOperand cl, CPOperand cu,
            CPOperand out, AggBinaryOp.SparkAggType aggtype, String opcode, String istr) {
        super(in, rl, ru, cl, cu, out, aggtype, opcode, istr);
    }

    /**
     * Creates a left-indexing instruction. All operands are forwarded unchanged to
     * {@link IndexingSPInstruction}.
     */
    protected FrameIndexingSPInstruction(CPOperand lhsInput, CPOperand rhsInput, CPOperand rl, CPOperand ru,
            CPOperand cl, CPOperand cu, CPOperand out, String opcode, String istr) {
        super(lhsInput, rhsInput, rl, ru, cl, cu, out, opcode, istr);
    }

    /**
     * Evaluates the row and column bounds and dispatches on the opcode.
     *
     * <p>{@link RightIndex#OPCODE} selects right indexing. {@link LeftIndex#OPCODE} and
     * {@code "mapLeftIndex"} select left indexing.
     *
     * @throws DMLRuntimeException for any other opcode, invalid output dimensions, or a left-indexing
     *                             range that does not match the right-hand-side dimensions
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        final SparkExecutionContext sec = (SparkExecutionContext) ec;
        final String opcode = getOpcode();

        final long rl = ec.getScalarInput(rowLower).getLongValue();
        final long ru = ec.getScalarInput(rowUpper).getLongValue();
        final long cl = ec.getScalarInput(colLower).getLongValue();
        final long cu = ec.getScalarInput(colUpper).getLongValue();
        final IndexRange ixrange = new IndexRange(rl, ru, cl, cu);

        if (opcode.equalsIgnoreCase(RightIndex.OPCODE)) {
            processFrameRightIndexing(sec, ixrange);
        } else if (opcode.equalsIgnoreCase(LeftIndex.OPCODE) || opcode.equalsIgnoreCase("mapLeftIndex")) {
            processFrameLeftIndexing(sec, opcode, ixrange);
        } else {
            throw new DMLRuntimeException("Invalid opcode (" + opcode + ") encountered in FrameIndexingSPInstruction.");
        }
    }

    /** Right indexing: produces the output RDD, its lineage and its column schema. */
    private void processFrameRightIndexing(SparkExecutionContext sec, IndexRange ixrange) {
        final DataCharacteristics mcIn = sec.getDataCharacteristics(input1.getName());
        final DataCharacteristics mcOut = sec.getDataCharacteristics(output.getName());
        mcOut.set(ixrange.rowEnd - ixrange.rowStart + 1, ixrange.colEnd - ixrange.colStart + 1,
            mcIn.getBlocksize(), mcIn.getBlocksize());
        checkValidOutputDimensions(mcOut);

        final JavaPairRDD<Long, FrameBlock> in = sec.getFrameBinaryBlockRDDHandleForVariable(input1.getName());
        final JavaPairRDD<Long, FrameBlock> out;
        if (isPartitioningPreservingRightIndexing(mcIn, ixrange)) {
            // All rows selected: keys are unchanged, only the columns are sliced.
            out = in.mapPartitionsToPair(new SliceBlockPartitionFunction(ixrange, mcOut), true);
        } else {
            out = in.filter(new IsFrameBlockInRange(ixrange.rowStart, ixrange.rowEnd, mcOut))
                .mapToPair(new SliceBlock(ixrange, mcOut));
        }
        sec.setRDDHandleForVariable(output.getName(), out);
        sec.addLineageRDD(output.getName(), input1.getName());
        sec.getFrameObject(output.getName()).setSchema(
            sec.getFrameObject(input1.getName()).getSchema((int) ixrange.colStart, (int) ixrange.colEnd));
    }

    /**
     * Left indexing: validates the right-hand-side dimensions, then uses either the broadcast plan
     * ({@code "mapLeftIndex"}) or the shuffle plan (zero out the left-hand side, shift the right-hand
     * side, merge by key).
     */
    private void processFrameLeftIndexing(SparkExecutionContext sec, String opcode, IndexRange ixrange) {
        final JavaPairRDD<Long, FrameBlock> in1 = sec.getFrameBinaryBlockRDDHandleForVariable(input1.getName());
        PartitionedBroadcast<FrameBlock> broadcastIn2 = null;
        JavaPairRDD<Long, FrameBlock> in2 = null;
        final JavaPairRDD<Long, FrameBlock> out;

        final DataCharacteristics mcOut = sec.getDataCharacteristics(output.getName());
        final DataCharacteristics mcLeft = sec.getDataCharacteristics(input1.getName());
        mcOut.set(mcLeft.getRows(), mcLeft.getCols(), mcLeft.getBlocksize(), mcLeft.getBlocksize());
        checkValidOutputDimensions(mcOut);

        final DataCharacteristics mcRight = sec.getDataCharacteristics(input2.getName());
        if (!mcRight.dimsKnown()) {
            throw new DMLRuntimeException("The right input frame dimensions are not specified for FrameIndexingSPInstruction");
        }
        if (ixrange.rowEnd - ixrange.rowStart + 1 != mcRight.getRows()
                || ixrange.colEnd - ixrange.colStart + 1 != mcRight.getCols()) {
            throw new DMLRuntimeException("Invalid index range of leftindexing: [" + ixrange.rowStart
                + IOUtilFunctions.LIBSVM_INDEX_DELIM + ixrange.rowEnd + "," + ixrange.colStart
                + IOUtilFunctions.LIBSVM_INDEX_DELIM + ixrange.colEnd + "] vs [" + mcRight.getRows()
                + "x" + mcRight.getCols() + "].");
        }

        if (opcode.equalsIgnoreCase("mapLeftIndex")) {
            broadcastIn2 = sec.getBroadcastForFrameVariable(input2.getName());
            out = in1.mapPartitionsToPair(new LeftIndexPartitionFunction(broadcastIn2, ixrange, mcOut), true);
        } else {
            final JavaPairRDD<Long, FrameBlock> zeroedLhs = in1.flatMapToPair(new ZeroOutLHS(false, ixrange, mcLeft));
            in2 = sec.getFrameBinaryBlockRDDHandleForVariable(input2.getName())
                .flatMapToPair(new SliceRHSForLeftIndexing(ixrange, mcLeft));
            out = FrameRDDAggregateUtils.mergeByKey(zeroedLhs.union(in2));
        }

        sec.setRDDHandleForVariable(output.getName(), out);
        sec.addLineageRDD(output.getName(), input1.getName());
        if (broadcastIn2 != null) {
            sec.addLineageBroadcast(output.getName(), input2.getName());
        }
        if (in2 != null) {
            sec.addLineageRDD(output.getName(), input2.getName());
        }
    }

    /**
     * Returns true when the dimensions are known and the full row range {@code [1, rows]} is
     * selected. In that case right indexing only slices columns, and block keys stay unchanged.
     */
    private static boolean isPartitioningPreservingRightIndexing(DataCharacteristics mcIn, IndexRange ixrange) {
        return mcIn.dimsKnown() && ixrange.rowStart == 1 && ixrange.rowEnd == mcIn.getRows();
    }

    /**
     * @throws DMLRuntimeException if the computed output dimensions are not known
     */
    private static void checkValidOutputDimensions(DataCharacteristics mcOut) {
        if (!mcOut.dimsKnown()) {
            throw new DMLRuntimeException("FrameIndexingSPInstruction: The updated output dimensions are invalid: " + mcOut);
        }
    }
}
