package org.apache.sysds.runtime.instructions.spark;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.stream.IntStream;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.codegen.CodegenUtils;
import org.apache.sysds.runtime.codegen.LibSpoofPrimitives;
import org.apache.sysds.runtime.codegen.SpoofCellwise;
import org.apache.sysds.runtime.codegen.SpoofMultiAggregate;
import org.apache.sysds.runtime.codegen.SpoofOperator;
import org.apache.sysds.runtime.codegen.SpoofOuterProduct;
import org.apache.sysds.runtime.codegen.SpoofRowwise;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.KahanPlus;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.spark.SPInstruction;
import org.apache.sysds.runtime.instructions.spark.data.PartitionedBroadcast;
import org.apache.sysds.runtime.instructions.spark.functions.MapInputSignature;
import org.apache.sysds.runtime.instructions.spark.functions.MapJoinSignature;
import org.apache.sysds.runtime.instructions.spark.functions.ReplicateBlockFunction;
import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import scala.Tuple2;

/**
 * Spark instruction that executes a code-generated fused operator (a "spoof" operator).
 *
 * <p>The generated operator class is dispatched on its superclass: {@link SpoofCellwise},
 * {@link SpoofMultiAggregate}, {@link SpoofOuterProduct} or {@link SpoofRowwise}; any other
 * superclass is rejected with a {@link DMLRuntimeException}. Matrix inputs that fit into the
 * remaining local and broadcast memory budgets are broadcast; all other matrix inputs are joined
 * with the main input by block index. The generated class is handed to the Spark functions as
 * bytecode and instantiated lazily inside each function on its first call.
 */
public class SpoofSPInstruction extends SPInstruction {
    /** Generated operator class, loaded on the driver. */
    private final Class<?> _class;
    /** Bytecode of {@link #_class}, passed to the Spark functions. */
    private final byte[] _classBytes;
    /** Matrix and scalar inputs, in instruction order. */
    private final CPOperand[] _in;
    /** Output operand (matrix or scalar). */
    private final CPOperand _out;

    /**
     * Applies a generated cellwise operator to every joined input block of a partition.
     * FULL_AGG emits a 1x1 partial aggregate per block; ROW_AGG re-keys the output to column
     * block 1 and COL_AGG to row block 1 (so partials can be combined by key); NO_AGG keeps the
     * input block index.
     */
    private static class CellwiseFunction extends SpoofFunction
        implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock[]>>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = -8209188316939435099L;
        /** Operator instance, created from the class bytes on the first call. */
        private SpoofCellwise _op = null;
        private final int _blen;

        public CellwiseFunction(String className, byte[] classBytes, boolean[] bcInd,
            ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, ArrayList<ScalarObject> scalars, int blen) {
            super(className, classBytes, bcInd, bcMatrices, scalars);
            _blen = blen;
        }

        @Override
        public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<MatrixIndexes, MatrixBlock[]>> arg)
            throws Exception {
            if (_op == null) {
                _op = (SpoofCellwise) CodegenUtils.createInstance(CodegenUtils.getClassSync(_className, _classBytes));
            }
            ArrayList<Tuple2<MatrixIndexes, MatrixBlock>> ret = new ArrayList<>();
            while (arg.hasNext()) {
                Tuple2<MatrixIndexes, MatrixBlock[]> tmp = arg.next();
                MatrixIndexes ixIn = tmp._1();
                MatrixBlock[] blkIn = tmp._2();
                MatrixIndexes ixOut = ixIn;
                MatrixBlock blkOut = new MatrixBlock();
                ArrayList<MatrixBlock> inputs = getAllMatrixInputs(ixIn, blkIn);
                long rix = (ixIn.getRowIndex() - 1) * _blen; // 0-based row offset of this block
                if (_op.getCellType() == SpoofCellwise.CellType.FULL_AGG) {
                    ScalarObject obj = _op.execute(inputs, _scalars, 1, rix);
                    blkOut.reset(1, 1);
                    blkOut.quickSetValue(0, 0, obj.getDoubleValue());
                } else {
                    if (_op.getCellType() == SpoofCellwise.CellType.ROW_AGG) {
                        ixOut = new MatrixIndexes(ixOut.getRowIndex(), 1L);
                    } else if (_op.getCellType() == SpoofCellwise.CellType.COL_AGG) {
                        ixOut = new MatrixIndexes(1L, ixOut.getColumnIndex());
                    }
                    blkOut = _op.execute(inputs, _scalars, blkOut, 1, rix);
                }
                ret.add(new Tuple2<>(ixOut, blkOut));
            }
            return ret.iterator();
        }
    }

    /**
     * Combines two multi-aggregate partial results in place into the first argument. A block with
     * zero rows or zero columns acts as the identity: if the first argument is empty it becomes a
     * copy of the second; if the second is empty the first is returned unchanged.
     */
    private static class MultiAggAggregateFunction implements Function2<MatrixBlock, MatrixBlock, MatrixBlock> {
        private static final long serialVersionUID = 5978731867787952513L;
        private final SpoofCellwise.AggOp[] _ops;

        public MultiAggAggregateFunction(SpoofCellwise.AggOp[] aggOps) {
            _ops = aggOps;
        }

        @Override
        public MatrixBlock call(MatrixBlock arg0, MatrixBlock arg1) throws Exception {
            if (arg0.getNumRows() <= 0 || arg0.getNumColumns() <= 0) {
                arg0.copy(arg1);
                return arg0;
            }
            if (arg1.getNumRows() <= 0 || arg1.getNumColumns() <= 0) {
                return arg0;
            }
            SpoofMultiAggregate.aggregatePartialResults(_ops, arg0, arg1);
            return arg0;
        }
    }

    /** Computes one multi-aggregate partial result per joined input block, keeping its index. */
    private static class MultiAggregateFunction extends SpoofFunction
        implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock[]>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = -5224519291577332734L;
        /** Operator instance, created from the class bytes on the first call. */
        private SpoofMultiAggregate _op = null;
        private final int _blen;

        public MultiAggregateFunction(String className, byte[] classBytes, boolean[] bcInd,
            ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, ArrayList<ScalarObject> scalars, int blen) {
            super(className, classBytes, bcInd, bcMatrices, scalars);
            _blen = blen;
        }

        @Override
        public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock[]> arg) throws Exception {
            if (_op == null) {
                _op = (SpoofMultiAggregate) CodegenUtils.createInstance(CodegenUtils.getClassSync(_className, _classBytes));
            }
            MatrixIndexes ixIn = arg._1();
            ArrayList<MatrixBlock> inputs = getAllMatrixInputs(ixIn, arg._2());
            long rix = (ixIn.getRowIndex() - 1) * _blen; // 0-based row offset of this block
            MatrixBlock blkOut = _op.execute(inputs, _scalars, new MatrixBlock(), 1, rix);
            return new Tuple2<>(ixIn, blkOut);
        }
    }

    /**
     * Applies a generated outer-product operator to every joined input block of a partition.
     * AGG_OUTER_PRODUCT emits a 1x1 partial aggregate per block; LEFT/RIGHT outer products are
     * re-keyed (see {@link #createOutputIndexes}); other types keep the input block index.
     */
    private static class OuterProductFunction extends SpoofFunction
        implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock[]>>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = -8209188316939435099L;
        /** Operator instance, created from the class bytes on the first call. */
        private SpoofOperator _op = null;

        public OuterProductFunction(String className, byte[] classBytes, boolean[] bcInd,
            ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, ArrayList<ScalarObject> scalars) {
            super(className, classBytes, bcInd, bcMatrices, scalars);
        }

        @Override
        public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<MatrixIndexes, MatrixBlock[]>> arg)
            throws Exception {
            if (_op == null) {
                _op = CodegenUtils.createInstance(CodegenUtils.getClassSync(_className, _classBytes));
            }
            ArrayList<Tuple2<MatrixIndexes, MatrixBlock>> ret = new ArrayList<>();
            while (arg.hasNext()) {
                Tuple2<MatrixIndexes, MatrixBlock[]> tmp = arg.next();
                MatrixIndexes ixIn = tmp._1();
                MatrixBlock[] blkIn = tmp._2();
                MatrixBlock blkOut = new MatrixBlock();
                // outer=true: a broadcast input at position 2 is looked up by column block index
                ArrayList<MatrixBlock> inputs = getAllMatrixInputs(ixIn, blkIn, true);
                if (((SpoofOuterProduct) _op).getOuterProdType() == SpoofOuterProduct.OutProdType.AGG_OUTER_PRODUCT) {
                    ScalarObject obj = _op.execute(inputs, _scalars, 1);
                    blkOut.reset(1, 1);
                    blkOut.quickSetValue(0, 0, obj.getDoubleValue());
                } else {
                    blkOut = _op.execute(inputs, _scalars, blkOut);
                }
                ret.add(new Tuple2<>(createOutputIndexes(ixIn, _op), blkOut));
            }
            return ret.iterator();
        }

        /**
         * Maps an input block index to its output index: LEFT_OUTER_PRODUCT to (col, 1),
         * RIGHT_OUTER_PRODUCT to (row, 1), any other type keeps the input index.
         */
        private static MatrixIndexes createOutputIndexes(MatrixIndexes ixIn, SpoofOperator spoofOp) {
            SpoofOuterProduct.OutProdType type = ((SpoofOuterProduct) spoofOp).getOuterProdType();
            if (type == SpoofOuterProduct.OutProdType.LEFT_OUTER_PRODUCT) {
                return new MatrixIndexes(ixIn.getColumnIndex(), 1L);
            }
            if (type == SpoofOuterProduct.OutProdType.RIGHT_OUTER_PRODUCT) {
                return new MatrixIndexes(ixIn.getRowIndex(), 1L);
            }
            return ixIn;
        }
    }

    /**
     * Replicates each block of an outer product's right factor to every row block of the main
     * input: a block with row index k is emitted as (i, k) for i = 1 .. ceil(len / blen), where
     * len and blen are the main input's row count and block size.
     */
    public static class ReplicateRightFactorFunction
        implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = -7295989688796126442L;
        private final long _len;
        private final long _blen;

        public ReplicateRightFactorFunction(long len, long blen) {
            _len = len;
            _blen = blen;
        }

        @Override
        public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Tuple2<MatrixIndexes, MatrixBlock> arg) throws Exception {
            MatrixIndexes ixIn = arg._1();
            MatrixBlock blkIn = arg._2();
            // integer ceil(len / blen): plain long division truncates and would drop the last,
            // partial row block (e.g. 1500 rows with blen 1000 must give 2 blocks, not 1)
            long numBlocks = (_len <= 0) ? 0 : (_len - 1) / _blen + 1;
            ArrayList<Tuple2<MatrixIndexes, MatrixBlock>> ret = new ArrayList<>();
            for (long i = 1; i <= numBlocks; i++) {
                ret.add(new Tuple2<>(new MatrixIndexes(i, ixIn.getRowIndex()), blkIn));
            }
            return ret.iterator();
        }
    }

    /**
     * Applies a generated rowwise operator to the joined input blocks of a partition. Column and
     * full aggregates accumulate the whole partition into one block emitted at index (1, 1); all
     * other row types emit one block per input block (column index 1 unless NO_AGG).
     */
    private static class RowwiseFunction extends SpoofFunction
        implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock[]>>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = -7926980450209760212L;
        private final int _blen;
        /** First length passed to {@code LibSpoofPrimitives.setupThreadLocalMemory}. */
        private final int _clen;
        /** Second length passed to {@code LibSpoofPrimitives.setupThreadLocalMemory} (-1 if unused). */
        private final int _clen2;
        /** Operator instance, created from the class bytes on the first call. */
        private SpoofRowwise _op = null;

        public RowwiseFunction(String className, byte[] classBytes, boolean[] bcInd,
            ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, ArrayList<ScalarObject> scalars,
            int blen, int clen, int clen2) {
            super(className, classBytes, bcInd, bcMatrices, scalars);
            _blen = blen;
            _clen = clen;
            _clen2 = clen2;
        }

        @Override
        public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<MatrixIndexes, MatrixBlock[]>> arg) {
            if (_op == null) {
                _op = (SpoofRowwise) CodegenUtils.createInstance(CodegenUtils.getClassSync(_className, _classBytes));
            }
            LibSpoofPrimitives.setupThreadLocalMemory(_op.getNumIntermediates(), _clen, _clen2);
            ArrayList<Tuple2<MatrixIndexes, MatrixBlock>> ret = new ArrayList<>();
            // column and full aggregates accumulate the entire partition into a single block
            boolean aggIncr = _op.getRowType().isColumnAgg() || _op.getRowType() == SpoofRowwise.RowType.FULL_AGG;
            MatrixBlock blkOut = aggIncr ? new MatrixBlock() : null;
            try {
                while (arg.hasNext()) {
                    Tuple2<MatrixIndexes, MatrixBlock[]> tmp = arg.next();
                    MatrixIndexes ixIn = tmp._1();
                    MatrixBlock[] blkIn = tmp._2();
                    long rix = (ixIn.getRowIndex() - 1) * _blen; // 0-based row offset of this block
                    blkOut = _op.execute(getAllMatrixInputs(ixIn, blkIn), _scalars,
                        aggIncr ? blkOut : new MatrixBlock(), false, aggIncr, rix);
                    if (!aggIncr) {
                        long colIx = (_op.getRowType() != SpoofRowwise.RowType.NO_AGG) ? 1L : ixIn.getColumnIndex();
                        ret.add(new Tuple2<>(new MatrixIndexes(ixIn.getRowIndex(), colIx), blkOut));
                    }
                }
            } finally {
                // release the thread-local buffers even if the operator throws
                LibSpoofPrimitives.cleanupThreadLocalMemory();
            }
            if (aggIncr) {
                blkOut.recomputeNonZeros();
                blkOut.examSparsity();
                ret.add(new Tuple2<>(new MatrixIndexes(1L, 1L), blkOut));
            }
            return ret.iterator();
        }
    }

    /**
     * Common state of the spoof Spark functions: generated class name and bytecode, the
     * per-matrix-input broadcast flags, the broadcast matrices, and the scalar inputs.
     */
    public static class SpoofFunction implements Serializable {
        private static final long serialVersionUID = 2953479427746463003L;
        /** One flag per matrix input: true if the input is broadcast, false if it comes from the RDD. */
        protected final boolean[] _bcInd;
        /** Broadcast matrix inputs, in matrix-input order. */
        protected final ArrayList<PartitionedBroadcast<MatrixBlock>> _inputs;
        protected final ArrayList<ScalarObject> _scalars;
        protected final byte[] _classBytes;
        protected final String _className;

        protected SpoofFunction(String className, byte[] classBytes, boolean[] bcInd,
            ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, ArrayList<ScalarObject> scalars) {
            _bcInd = bcInd;
            _inputs = bcMatrices;
            _scalars = scalars;
            _classBytes = classBytes;
            _className = className;
        }

        /** Same as {@link #getAllMatrixInputs(MatrixIndexes, MatrixBlock[], boolean)} with outer=false. */
        protected ArrayList<MatrixBlock> getAllMatrixInputs(MatrixIndexes ixIn, MatrixBlock[] blkIn) {
            return getAllMatrixInputs(ixIn, blkIn, false);
        }

        /**
         * Assembles all matrix inputs for one block, in matrix-input order: RDD inputs are taken
         * from {@code blkIn} in sequence, broadcast inputs are looked up by block index.
         *
         * @param ixIn  block index of the main input
         * @param blkIn joined RDD blocks for this index
         * @param outer if true, the broadcast input at position 2 is read at block
         *              (column index of {@code ixIn}, 1)
         * @return the matrix inputs for the generated operator
         */
        protected ArrayList<MatrixBlock> getAllMatrixInputs(MatrixIndexes ixIn, MatrixBlock[] blkIn, boolean outer) {
            ArrayList<MatrixBlock> ret = new ArrayList<>();
            int posRdd = 0;
            int posBc = 0;
            for (int i = 0; i < _bcInd.length; i++) {
                if (!_bcInd[i]) {
                    ret.add(blkIn[posRdd++]);
                    continue;
                }
                PartitionedBroadcast<MatrixBlock> pb = _inputs.get(posBc++);
                boolean rightFactor = outer && i == 2;
                // a broadcast with fewer row/column blocks than the requested index is read at
                // block 1 in that dimension
                long rowIx = rightFactor ? ixIn.getColumnIndex()
                    : (pb.getNumRowBlocks() >= ixIn.getRowIndex()) ? ixIn.getRowIndex() : 1L;
                long colIx = rightFactor ? 1L
                    : (pb.getNumColumnBlocks() >= ixIn.getColumnIndex()) ? ixIn.getColumnIndex() : 1L;
                ret.add(pb.getBlock((int) rowIx, (int) colIx));
            }
            return ret;
        }
    }

    private SpoofSPInstruction(Class<?> cls, byte[] classBytes, CPOperand[] in, CPOperand out,
        String opcode, String istr) {
        super(SPInstruction.SPType.SpoofFused, opcode, istr);
        _class = cls;
        _classBytes = classBytes;
        _in = in;
        _out = out;
    }

    /**
     * Parses a spoof Spark instruction string. Part 2 names the generated class, parts 3 to
     * length-3 are the input operands, and part length-2 is the output operand; the final part is
     * not read here. The opcode is part 0 followed by the operator's spoof type.
     *
     * @param str instruction string
     * @return the parsed instruction
     */
    public static SpoofSPInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        Class<?> cla = CodegenUtils.getClass(parts[2]);
        byte[] classBytes = CodegenUtils.getClassData(parts[2]);
        String opcode = parts[0] + CodegenUtils.createInstance(cla).getSpoofType();
        ArrayList<CPOperand> inputs = new ArrayList<>();
        for (int i = 3; i < parts.length - 2; i++) {
            inputs.add(new CPOperand(parts[i]));
        }
        CPOperand out = new CPOperand(parts[parts.length - 2]);
        return new SpoofSPInstruction(cla, classBytes, inputs.toArray(new CPOperand[0]), out, opcode, str);
    }

    /** @return the generated operator class executed by this instruction */
    public Class<?> getOperatorClass() {
        return _class;
    }

    /**
     * Executes the fused operator on Spark. Broadcast decisions are made first, the remaining
     * matrix inputs are joined with the main input, and broadcast/scalar inputs are collected.
     * Execution then dispatches on the operator's superclass:
     * <ul>
     * <li>cellwise: scalar outputs are aggregated on the driver; matrix outputs stay an RDD and
     * are re-aggregated by key for ROW_AGG/COL_AGG when the input spans several blocks;</li>
     * <li>multi-aggregate: per-block partials are folded into one local matrix;</li>
     * <li>outer product: scalar outputs are summed; LEFT/RIGHT outputs are summed by key;</li>
     * <li>rowwise: column/full aggregates are summed into a local result, others stay an RDD.</li>
     * </ul>
     *
     * @param ec execution context; cast to {@link SparkExecutionContext}
     * @throws DMLRuntimeException if the operator type is unsupported, or if a rowwise operator's
     *         main input has more columns than its block size
     */
    @Override // org.apache.sysds.runtime.instructions.spark.SPInstruction, org.apache.sysds.runtime.instructions.Instruction
    public void processInstruction(ExecutionContext ec) {
        SparkExecutionContext sec = (SparkExecutionContext) ec;

        // decide broadcast vs. RDD inputs and join all RDD inputs with the main input
        boolean[] bcVect = determineBroadcastInputs(sec, _in);
        boolean[] bcVect2 = getMatrixBroadcastVector(sec, _in, bcVect);
        DataCharacteristics mcIn = sec.getDataCharacteristics(_in[getMainInputIndex(_in, bcVect)].getName());
        JavaPairRDD<MatrixIndexes, MatrixBlock[]> joined =
            createJoinedInputRDD(sec, _in, bcVect, _class.getSuperclass() == SpoofOuterProduct.class);

        // collect broadcast matrices and scalar inputs, in operand order
        ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices = new ArrayList<>();
        ArrayList<ScalarObject> scalars = new ArrayList<>();
        for (int i = 0; i < _in.length; i++) {
            if (_in[i].getDataType() == Types.DataType.MATRIX && bcVect[i]) {
                bcMatrices.add(sec.getBroadcastForVariable(_in[i].getName()));
            } else if (_in[i].getDataType() == Types.DataType.SCALAR) {
                scalars.add(sec.getScalarInput(_in[i]));
            }
        }

        Class<?> opType = _class.getSuperclass();
        if (opType == SpoofCellwise.class) {
            SpoofCellwise op = (SpoofCellwise) CodegenUtils.createInstance(_class);
            AggregateOperator aggop = getAggregateOperator(op.getAggOp());
            JavaPairRDD<MatrixIndexes, MatrixBlock> out = joined.mapPartitionsToPair(new CellwiseFunction(
                _class.getName(), _classBytes, bcVect2, bcMatrices, scalars, mcIn.getBlocksize()), true);
            if (_out.getDataType() != Types.DataType.MATRIX) {
                // scalar output: combine the 1x1 per-block partials on the driver
                MatrixBlock tmp = RDDAggregateUtils.aggStable(out, aggop);
                sec.setVariable(_out.getName(), new DoubleObject(tmp.getValue(0, 0)));
                return;
            }
            boolean multiBlockAgg =
                (op.getCellType() == SpoofCellwise.CellType.ROW_AGG && mcIn.getCols() > mcIn.getBlocksize())
                || (op.getCellType() == SpoofCellwise.CellType.COL_AGG && mcIn.getRows() > mcIn.getBlocksize());
            if (multiBlockAgg) {
                long numBlocks = (op.getCellType() == SpoofCellwise.CellType.ROW_AGG)
                    ? mcIn.getNumRowBlocks() : mcIn.getNumColBlocks();
                out = RDDAggregateUtils.aggByKeyStable(out, aggop,
                    (int) Math.min(out.getNumPartitions(), numBlocks), false);
            }
            sec.setRDDHandleForVariable(_out.getName(), out);
            maintainLineageInfo(sec, _in, bcVect, _out);
            updateOutputDataCharacteristics(sec, op);
            return;
        }

        if (opType == SpoofMultiAggregate.class) {
            SpoofMultiAggregate op = (SpoofMultiAggregate) CodegenUtils.createInstance(_class);
            MatrixBlock result = joined
                .mapToPair(new MultiAggregateFunction(_class.getName(), _classBytes, bcVect2,
                    bcMatrices, scalars, mcIn.getBlocksize()))
                .values()
                .fold(new MatrixBlock(), new MultiAggAggregateFunction(op.getAggOps()));
            sec.setMatrixOutput(_out.getName(), result);
            return;
        }

        if (opType == SpoofOuterProduct.class) {
            if (_out.getDataType() != Types.DataType.MATRIX) {
                JavaPairRDD<MatrixIndexes, MatrixBlock> out = joined.mapPartitionsToPair(new OuterProductFunction(
                    _class.getName(), _classBytes, bcVect2, bcMatrices, scalars), true);
                sec.setVariable(_out.getName(), new DoubleObject(RDDAggregateUtils.sumStable(out).getValue(0, 0)));
                return;
            }
            SpoofOperator op = CodegenUtils.createInstance(_class);
            SpoofOuterProduct.OutProdType type = ((SpoofOuterProduct) op).getOuterProdType();
            updateOutputDataCharacteristics(sec, op);
            DataCharacteristics mcOut = sec.getDataCharacteristics(_out.getName());
            JavaPairRDD<MatrixIndexes, MatrixBlock> out = joined.mapPartitionsToPair(new OuterProductFunction(
                _class.getName(), _classBytes, bcVect2, bcMatrices, scalars), true);
            if (type == SpoofOuterProduct.OutProdType.LEFT_OUTER_PRODUCT
                || type == SpoofOuterProduct.OutProdType.RIGHT_OUTER_PRODUCT) {
                long numBlocks = mcOut.getNumRowBlocks() * mcOut.getNumColBlocks();
                out = RDDAggregateUtils.sumByKeyStable(out, (int) Math.min(out.getNumPartitions(), numBlocks), false);
            }
            sec.setRDDHandleForVariable(_out.getName(), out);
            maintainLineageInfo(sec, _in, bcVect, _out);
            return;
        }

        if (opType != SpoofRowwise.class) {
            throw new DMLRuntimeException("Operator " + opType + " is not supported on Spark");
        }

        // rowwise operators require every row to fit into a single column block
        if (mcIn.getCols() > mcIn.getBlocksize()) {
            throw new DMLRuntimeException("Invalid spark rowwise operator w/ ncol="
                + mcIn.getCols() + ", ncolpb=" + mcIn.getBlocksize() + ".");
        }
        SpoofRowwise op = (SpoofRowwise) CodegenUtils.createInstance(_class);
        long clen2 = op.getRowType().isConstDim2(op.getConstDim2()) ? op.getConstDim2()
            : op.getRowType().isRowTypeB1() ? sec.getDataCharacteristics(_in[1].getName()).getCols() : -1L;
        RowwiseFunction fmmc = new RowwiseFunction(_class.getName(), _classBytes, bcVect2, bcMatrices, scalars,
            mcIn.getBlocksize(), (int) mcIn.getCols(), (int) clen2);
        JavaPairRDD<MatrixIndexes, MatrixBlock> out = joined.mapPartitionsToPair(fmmc,
            op.getRowType() == SpoofRowwise.RowType.ROW_AGG || op.getRowType() == SpoofRowwise.RowType.NO_AGG);

        if (op.getRowType().isColumnAgg() || op.getRowType() == SpoofRowwise.RowType.FULL_AGG) {
            MatrixBlock tmp = RDDAggregateUtils.sumStable(out);
            if (op.getRowType().isColumnAgg()) {
                sec.setMatrixOutput(_out.getName(), tmp);
            } else {
                sec.setScalarOutput(_out.getName(), new DoubleObject(tmp.quickGetValue(0, 0)));
            }
            return;
        }
        // NOTE(review): unreachable while the ncol <= blocksize guard above is in place
        if (op.getRowType() == SpoofRowwise.RowType.ROW_AGG && mcIn.getCols() > mcIn.getBlocksize()) {
            out = RDDAggregateUtils.sumByKeyStable(out, (int) Math.min(out.getNumPartitions(), mcIn.getNumRowBlocks()), false);
        }
        sec.setRDDHandleForVariable(_out.getName(), out);
        maintainLineageInfo(sec, _in, bcVect, _out);
        updateOutputDataCharacteristics(sec, op);
    }

    /**
     * Decides, per input, whether a matrix input is broadcast. An input is broadcast if its local
     * plus partitioned size fits the remaining local budget and its partitioned size fits the
     * remaining broadcast budget; both budgets are reduced by the partitioned size of each chosen
     * input. If this would leave no RDD matrix input, the first matrix input is not broadcast.
     *
     * @return one flag per entry of {@code inputs} (always false for non-matrix inputs)
     */
    private static boolean[] determineBroadcastInputs(SparkExecutionContext sec, CPOperand[] inputs) {
        boolean[] ret = new boolean[inputs.length];
        double localBudget = OptimizerUtils.getLocalMemBudget() - CacheableData.getBroadcastSize();
        double bcBudget = SparkExecutionContext.getBroadcastMemoryBudget();
        int firstMatrix = -1;
        for (int i = 0; i < inputs.length; i++) {
            if (!inputs[i].getDataType().isMatrix()) {
                continue;
            }
            if (firstMatrix < 0) {
                firstMatrix = i;
            }
            DataCharacteristics mc = sec.getDataCharacteristics(inputs[i].getName());
            double sizeL = OptimizerUtils.estimateSizeExactSparsity(mc);
            double sizeP = OptimizerUtils.estimatePartitionedSizeExactSparsity(mc);
            ret[i] = localBudget > sizeL + sizeP && bcBudget > sizeP;
            if (ret[i]) {
                localBudget -= sizeP;
                bcBudget -= sizeP;
            }
        }
        // ensure at least one matrix input is an RDD; scalars do not count
        boolean hasRddInput = IntStream.range(0, ret.length).anyMatch(i -> inputs[i].isMatrix() && !ret[i]);
        if (!hasRddInput && firstMatrix >= 0) {
            ret[firstMatrix] = false;
        }
        return ret;
    }

    /** Projects the per-input broadcast flags onto the matrix inputs only, preserving order. */
    private static boolean[] getMatrixBroadcastVector(SparkExecutionContext sec, CPOperand[] inputs, boolean[] bcVect) {
        int numMatrices = (int) Arrays.stream(inputs).filter(in -> in.getDataType().isMatrix()).count();
        boolean[] ret = new boolean[numMatrices];
        int pos = 0;
        for (int i = 0; i < inputs.length; i++) {
            if (inputs[i].getDataType().isMatrix()) {
                ret[pos++] = bcVect[i];
            }
        }
        return ret;
    }

    /**
     * Builds an RDD of block arrays by joining every non-broadcast matrix input with the main
     * input by block index. For outer products ({@code outer}), input 2 is replicated with
     * {@link ReplicateRightFactorFunction}; otherwise an input with fewer row (or else column)
     * blocks than the main input is replicated with {@code ReplicateBlockFunction}.
     */
    private static JavaPairRDD<MatrixIndexes, MatrixBlock[]> createJoinedInputRDD(SparkExecutionContext sec,
        CPOperand[] inputs, boolean[] bcVect, boolean outer) {
        int main = getMainInputIndex(inputs, bcVect);
        DataCharacteristics mcIn = sec.getDataCharacteristics(inputs[main].getName());
        JavaPairRDD<MatrixIndexes, MatrixBlock[]> ret = sec
            .getBinaryMatrixBlockRDDHandleForVariable(inputs[main].getName())
            .mapValues(new MapInputSignature());
        for (int i = 0; i < inputs.length; i++) {
            if (i == main || !inputs[i].getDataType().isMatrix() || bcVect[i]) {
                continue;
            }
            String name = inputs[i].getName();
            JavaPairRDD<MatrixIndexes, MatrixBlock> side = sec.getBinaryMatrixBlockRDDHandleForVariable(name);
            DataCharacteristics mcSide = sec.getDataCharacteristics(name);
            if (outer && i == 2) {
                side = side.flatMapToPair(new ReplicateRightFactorFunction(mcIn.getRows(), mcIn.getBlocksize()));
            } else if (mcIn.getNumRowBlocks() > mcSide.getNumRowBlocks()) {
                side = side.flatMapToPair(new ReplicateBlockFunction(mcIn.getRows(), mcIn.getBlocksize(), false));
            } else if (mcIn.getNumColBlocks() > mcSide.getNumColBlocks()) {
                side = side.flatMapToPair(new ReplicateBlockFunction(mcIn.getCols(), mcIn.getBlocksize(), true));
            }
            ret = ret.join(side).mapValues(new MapJoinSignature());
        }
        return ret;
    }

    /** Registers every matrix input as lineage of the output, passing its broadcast flag. */
    private static void maintainLineageInfo(SparkExecutionContext sec, CPOperand[] inputs, boolean[] bcVect, CPOperand output) {
        for (int i = 0; i < inputs.length; i++) {
            if (inputs[i].getDataType().isMatrix()) {
                sec.addLineage(output.getName(), inputs[i].getName(), bcVect[i]);
            }
        }
    }

    /** @return index of the first non-broadcast matrix input, or 0 if there is none */
    private static int getMainInputIndex(CPOperand[] inputs, boolean[] bcVect) {
        for (int i = 0; i < bcVect.length; i++) {
            if (inputs[i].isMatrix() && !bcVect[i]) {
                return i;
            }
        }
        return 0;
    }

    /**
     * Updates the output data characteristics from the inputs for the operator types where the
     * output shape follows from them; other types leave the output characteristics untouched.
     * NOTE(review): the 4-argument {@code set(rows, cols, blen, blen)} calls below pass the block
     * size as the 4th argument; confirm against the {@code DataCharacteristics.set} overloads
     * that this argument is not interpreted as a non-zero count.
     */
    private void updateOutputDataCharacteristics(SparkExecutionContext sec, SpoofOperator op) {
        if (op instanceof SpoofCellwise) {
            DataCharacteristics mcIn = sec.getDataCharacteristics(_in[0].getName());
            DataCharacteristics mcOut = sec.getDataCharacteristics(_out.getName());
            SpoofCellwise.CellType type = ((SpoofCellwise) op).getCellType();
            if (type == SpoofCellwise.CellType.ROW_AGG) {
                mcOut.set(mcIn.getRows(), 1L, mcIn.getBlocksize(), mcIn.getBlocksize());
            } else if (type == SpoofCellwise.CellType.NO_AGG) {
                mcOut.set(mcIn.getRows(), mcIn.getCols(), mcIn.getBlocksize(), mcIn.getBlocksize());
            }
            return;
        }
        if (op instanceof SpoofOuterProduct) {
            DataCharacteristics mcIn1 = sec.getDataCharacteristics(_in[0].getName());
            DataCharacteristics mcIn2 = sec.getDataCharacteristics(_in[1].getName());
            DataCharacteristics mcIn3 = sec.getDataCharacteristics(_in[2].getName());
            DataCharacteristics mcOut = sec.getDataCharacteristics(_out.getName());
            SpoofOuterProduct.OutProdType type = ((SpoofOuterProduct) op).getOuterProdType();
            if (type == SpoofOuterProduct.OutProdType.CELLWISE_OUTER_PRODUCT) {
                mcOut.set(mcIn1.getRows(), mcIn1.getCols(), mcIn1.getBlocksize(), mcIn1.getBlocksize());
            } else if (type == SpoofOuterProduct.OutProdType.LEFT_OUTER_PRODUCT) {
                mcOut.set(mcIn3.getRows(), mcIn3.getCols(), mcIn3.getBlocksize(), mcIn3.getBlocksize());
            } else if (type == SpoofOuterProduct.OutProdType.RIGHT_OUTER_PRODUCT) {
                mcOut.set(mcIn2.getRows(), mcIn2.getCols(), mcIn2.getBlocksize(), mcIn2.getBlocksize());
            }
            return;
        }
        if (op instanceof SpoofRowwise) {
            DataCharacteristics mcIn = sec.getDataCharacteristics(_in[0].getName());
            DataCharacteristics mcOut = sec.getDataCharacteristics(_out.getName());
            SpoofRowwise.RowType type = ((SpoofRowwise) op).getRowType();
            if (type == SpoofRowwise.RowType.NO_AGG) {
                mcOut.set(mcIn);
            } else if (type == SpoofRowwise.RowType.ROW_AGG) {
                mcOut.set(mcIn.getRows(), 1L, mcIn.getBlocksize(), mcIn.getBlocksize());
            } else if (type == SpoofRowwise.RowType.COL_AGG) {
                mcOut.set(1L, mcIn.getCols(), mcIn.getBlocksize(), mcIn.getBlocksize());
            } else if (type == SpoofRowwise.RowType.COL_AGG_T) {
                mcOut.set(mcIn.getCols(), 1L, mcIn.getBlocksize(), mcIn.getBlocksize());
            }
        }
    }

    /** @return the output operand of this instruction */
    public CPOperand getOutput() {
        return _out;
    }

    /**
     * Maps a cellwise aggregation type to the operator used to combine partial results:
     * SUM and SUM_SQ combine with Kahan plus (initial value 0), MIN with min (initial +inf),
     * MAX with max (initial -inf).
     *
     * @param aggOp aggregation type, may be null
     * @return the aggregate operator, or null for any other (or null) aggregation type
     */
    public static AggregateOperator getAggregateOperator(SpoofCellwise.AggOp aggOp) {
        if (aggOp == SpoofCellwise.AggOp.SUM || aggOp == SpoofCellwise.AggOp.SUM_SQ) {
            return new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), Types.CorrectionLocationType.NONE);
        }
        if (aggOp == SpoofCellwise.AggOp.MIN) {
            return new AggregateOperator(Double.POSITIVE_INFINITY, Builtin.getBuiltinFnObject(Builtin.BuiltinCode.MIN), Types.CorrectionLocationType.NONE);
        }
        if (aggOp == SpoofCellwise.AggOp.MAX) {
            return new AggregateOperator(Double.NEGATIVE_INFINITY, Builtin.getBuiltinFnObject(Builtin.BuiltinCode.MAX), Types.CorrectionLocationType.NONE);
        }
        return null;
    }

    /** @return the input operands of this instruction, in instruction order */
    public CPOperand[] getInputs() {
        return _in;
    }
}
