package org.apache.sysds.runtime.instructions.spark;

import java.util.ArrayList;
import java.util.Iterator;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.SPInstruction;
import org.apache.sysds.runtime.instructions.spark.data.PartitionedBroadcast;
import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysds.runtime.matrix.data.LibMatrixAgg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.util.DataConverter;
import org.apache.sysds.runtime.util.UtilFunctions;
import scala.Tuple2;

/* loaded from: input_file:org/apache/sysds/runtime/instructions/spark/CumulativeOffsetSPInstruction.class */
/**
 * Spark instruction that applies a cumulative aggregate (selected by opcode) to each
 * block of the first input, seeded with a per-block offset row taken from the second input.
 *
 * <p>The offset row for a data block is obtained either through a broadcast lookup
 * ({@link RDDCumSplitLookupFunction}) or by splitting the offset matrix into single-row
 * blocks and joining them with the data ({@link RDDCumSplitFunction}).
 */
public class CumulativeOffsetSPInstruction extends BinarySPInstruction {
    /** Unary cumulative operator derived from the opcode; {@code null} for an unrecognized opcode. */
    private final UnaryOperator _uop;
    /** True only for the {@code bcumoff+*} opcode; the output then has a single column. */
    private final boolean _cumsumprod;
    /** Offset value used for the first row block. */
    private final double _initValue;
    /** Whether a broadcast-based lookup of the offsets was requested. */
    private final boolean _broadcast;

    /**
     * Applies the cumulative operator to a data block, using the paired single-row
     * block as the initial offset vector.
     */
    private static class RDDCumOffsetFunction implements Function<Tuple2<MatrixBlock, MatrixBlock>, MatrixBlock> {
        private static final long serialVersionUID = -5804080263258064743L;
        private final UnaryOperator _uop;
        private final boolean _cumsumprod;

        public RDDCumOffsetFunction(UnaryOperator unaryOperator, boolean z) {
            this._uop = unaryOperator;
            this._cumsumprod = z;
        }

        /**
         * @param tuple2 pair of (data block, offset row block)
         * @return the result of {@code LibMatrixAgg.cumaggregateUnaryMatrix} for the data block
         */
        @Override
        public MatrixBlock call(Tuple2<MatrixBlock, MatrixBlock> tuple2) throws Exception {
            MatrixBlock data = tuple2._1();
            // Output has one column for cumsumprod, otherwise as many columns as the data block.
            int outCols = this._cumsumprod ? 1 : data.getNumColumns();
            MatrixBlock result = new MatrixBlock(data.getNumRows(), outCols, false);
            MatrixBlock offsetRow = tuple2._2();
            boolean isCumsum = ((Builtin) this._uop.fn).bFunc == Builtin.BuiltinCode.CUMSUM;
            double[] offsets = DataConverter.convertToDoubleVector(offsetRow, false, isCumsum);
            return LibMatrixAgg.cumaggregateUnaryMatrix(data, result, this._uop, offsets);
        }
    }

    /**
     * Splits each block of the offset matrix into single-row blocks that are re-keyed so
     * that they can be joined with the data blocks they act as offsets for.
     */
    private static class RDDCumSplitFunction implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = -8407407527406576965L;
        private final double _initValue;
        private final int _blen;
        private final long _lastRowBlockIndex;

        /**
         * @param d initial offset value emitted for row index 1
         * @param j number of rows of the offset matrix
         * @param i block size
         */
        public RDDCumSplitFunction(double d, long j, int i) {
            this._initValue = d;
            this._blen = i;
            // Divide in floating point: with integer division the quotient is truncated
            // before ceil() sees it, which yields a too-small index whenever the row
            // count is not a multiple of the block size.
            this._lastRowBlockIndex = (long) Math.ceil((double) j / i);
        }

        /**
         * Emits one single-row block per input row, keyed by
         * {@code (rowOffset + rowInBlock + 2, columnIndex)}. For the first row block an
         * additional row keyed by row index 1 carries the initial value; the last row of
         * the last row block is not emitted.
         */
        @Override
        public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Tuple2<MatrixIndexes, MatrixBlock> tuple2) throws Exception {
            ArrayList<Tuple2<MatrixIndexes, MatrixBlock>> out = new ArrayList<>();
            MatrixIndexes ix = tuple2._1();
            MatrixBlock blk = tuple2._2();
            long rowOffset = (ix.getRowIndex() - 1) * this._blen;
            boolean firstRowBlock = ix.getRowIndex() == 1;
            boolean lastRowBlock = ix.getRowIndex() == this._lastRowBlockIndex;
            int numCols = blk.getNumColumns();

            if (firstRowBlock) {
                MatrixIndexes initIx = new MatrixIndexes(1L, ix.getColumnIndex());
                MatrixBlock initBlk = new MatrixBlock(1, numCols, blk.isInSparseFormat());
                // NOTE(review): this guard presumably means "initial value is non-zero"
                // (the referenced constant is not visible here) - confirm it equals 0.
                if (this._initValue != DataExpression.DEFAULT_DELIM_FILL_VALUE) {
                    for (int c = 0; c < numCols; c++) {
                        initBlk.appendValue(0, c, this._initValue);
                    }
                }
                out.add(new Tuple2<>(initIx, initBlk));
            }

            int numRows = blk.getNumRows();
            for (int r = 0; r < numRows; r++) {
                // Skip the very last row of the last row block.
                if (lastRowBlock && r == numRows - 1) {
                    continue;
                }
                MatrixIndexes rowIx = new MatrixIndexes(rowOffset + r + 2, ix.getColumnIndex());
                MatrixBlock rowBlk = new MatrixBlock(1, numCols, blk.isInSparseFormat());
                blk.slice(r, r, 0, numCols - 1, rowBlk);
                out.add(new Tuple2<>(rowIx, rowBlk));
            }
            return out.iterator();
        }
    }

    /**
     * Pairs each data block with its offset row looked up from a partitioned broadcast
     * of the offset matrix.
     */
    private static class RDDCumSplitLookupFunction implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock>> {
        private static final long serialVersionUID = -2785629043886477479L;
        private final PartitionedBroadcast<MatrixBlock> _pbc;
        private final double _initValue;
        private final int _blen;

        /**
         * @param partitionedBroadcast broadcast of the offset matrix
         * @param d initial offset value used for row block 1
         * @param j number of rows of the offset matrix (not used by this function)
         * @param i block size
         */
        public RDDCumSplitLookupFunction(PartitionedBroadcast<MatrixBlock> partitionedBroadcast, double d, long j, int i) {
            this._pbc = partitionedBroadcast;
            this._initValue = d;
            this._blen = i;
        }

        /**
         * @param tuple2 indexed data block
         * @return the same index paired with (data block, offset row block)
         */
        @Override
        public Tuple2<MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock>> call(Tuple2<MatrixIndexes, MatrixBlock> tuple2) throws Exception {
            MatrixIndexes ix = tuple2._1();
            MatrixBlock blk = tuple2._2();
            long offsetBlockIx = UtilFunctions.computeBlockIndex(ix.getRowIndex() - 1, this._blen);
            int offsetRowInBlock = UtilFunctions.computeCellInBlock(ix.getRowIndex() - 1, this._blen);

            MatrixBlock offsetRow;
            if (ix.getRowIndex() == 1) {
                // The first row block has no predecessor; use a constant row of the initial value.
                offsetRow = new MatrixBlock(1, blk.getNumColumns(), this._initValue);
            } else {
                // NOTE(review): slice2 looks like a decompiler-renamed overload of slice - confirm.
                offsetRow = this._pbc.getBlock((int) offsetBlockIx, (int) ix.getColumnIndex())
                        .slice2(offsetRowInBlock, offsetRowInBlock);
            }
            return new Tuple2<>(ix, new Tuple2<>(blk, offsetRow));
        }
    }

    private CumulativeOffsetSPInstruction(Operator operator, CPOperand cPOperand, CPOperand cPOperand2, CPOperand cPOperand3, double d, boolean z, String str, String str2) {
        super(SPInstruction.SPType.CumsumOffset, operator, cPOperand, cPOperand2, cPOperand3, str, str2);
        String unaryOpcode = toUnaryOpcode(str);
        // An unrecognized opcode leaves the operator unset, exactly as before.
        this._uop = (unaryOpcode != null) ? new UnaryOperator(Builtin.getBuiltinFnObject(unaryOpcode)) : null;
        this._cumsumprod = "bcumoff+*".equals(str);
        this._initValue = d;
        this._broadcast = z;
    }

    /**
     * Maps the opcode of this instruction to the name of the unary cumulative builtin.
     *
     * @param opcode instruction opcode, may be {@code null}
     * @return the builtin name, or {@code null} if the opcode is not recognized
     */
    private static String toUnaryOpcode(String opcode) {
        if ("bcumoffk+".equals(opcode)) {
            return "ucumk+";
        }
        if ("bcumoff*".equals(opcode)) {
            return "ucum*";
        }
        if ("bcumoff+*".equals(opcode)) {
            return "ucumk+*";
        }
        if ("bcumoffmin".equals(opcode)) {
            return "ucummin";
        }
        if ("bcumoffmax".equals(opcode)) {
            return "ucummax";
        }
        return null;
    }

    /**
     * Parses an instruction string with five operand fields following the opcode:
     * input1, input2, output, initial value, and the broadcast flag.
     *
     * @param str the full instruction string
     * @return the parsed instruction
     */
    public static CumulativeOffsetSPInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        InstructionUtils.checkNumFields(parts, 5);
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand out = new CPOperand(parts[3]);
        double initValue = Double.parseDouble(parts[4]);
        boolean broadcast = Boolean.parseBoolean(parts[5]);
        return new CumulativeOffsetSPInstruction(null, in1, in2, out, initValue, broadcast, parts[0], str);
    }

    /**
     * Builds the output RDD by pairing every data block with its offset row and applying
     * the cumulative operator, then registers output characteristics, the RDD handle and lineage.
     *
     * @param executionContext must be a {@link SparkExecutionContext}
     */
    @Override // org.apache.sysds.runtime.instructions.spark.SPInstruction, org.apache.sysds.runtime.instructions.Instruction
    public void processInstruction(ExecutionContext executionContext) {
        SparkExecutionContext sec = (SparkExecutionContext) executionContext;
        DataCharacteristics dataDc = sec.getDataCharacteristics(this.input1.getName());
        DataCharacteristics offsetDc = sec.getDataCharacteristics(this.input2.getName());
        long offsetRows = offsetDc.getRows();
        int blen = offsetDc.getBlocksize();
        JavaPairRDD<MatrixIndexes, MatrixBlock> data = sec.getBinaryMatrixBlockRDDHandleForVariable(this.input1.getName());

        // Broadcast lookup is used only if requested and the data is not hash partitioned.
        boolean useBroadcast = this._broadcast && !SparkUtils.isHashPartitioned(data);

        JavaPairRDD<MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock>> joined;
        if (useBroadcast) {
            PartitionedBroadcast<MatrixBlock> offsets = sec.getBroadcastForVariable(this.input2.getName());
            joined = data.mapToPair(new RDDCumSplitLookupFunction(offsets, this._initValue, offsetRows, blen));
        } else {
            JavaPairRDD<MatrixIndexes, MatrixBlock> offsetRowBlocks = sec
                    .getBinaryMatrixBlockRDDHandleForVariable(this.input2.getName())
                    .flatMapToPair(new RDDCumSplitFunction(this._initValue, offsetRows, blen));
            joined = data.join(offsetRowBlocks);
        }
        JavaPairRDD<MatrixIndexes, MatrixBlock> result = joined.mapValues(new RDDCumOffsetFunction(this._uop, this._cumsumprod));

        if (this._cumsumprod) {
            // cumsumprod yields a single output column with the row count of the data input.
            sec.getDataCharacteristics(this.output.getName()).set(dataDc.getRows(), 1L, dataDc.getBlocksize(), dataDc.getBlocksize());
        } else {
            updateUnaryOutputDataCharacteristics(sec);
        }
        sec.setRDDHandleForVariable(this.output.getName(), result);
        sec.addLineageRDD(this.output.getName(), this.input1.getName());
        sec.addLineage(this.output.getName(), this.input2.getName(), useBroadcast);
    }

    /** @return the initial offset value used for the first row block */
    public double getInitValue() {
        return this._initValue;
    }

    /** @return whether a broadcast-based offset lookup was requested */
    public boolean getBroadcast() {
        return this._broadcast;
    }
}
