package org.apache.sysds.runtime.instructions.spark;

import java.util.List;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.spark.AppendGSPInstruction;
import org.apache.sysds.runtime.instructions.spark.SPInstruction;
import org.apache.sysds.runtime.instructions.spark.functions.MapInputSignature;
import org.apache.sysds.runtime.instructions.spark.functions.MapJoinSignature;
import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.lineage.LineageItemUtils;
import org.apache.sysds.runtime.lineage.LineageTraceable;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.operators.SimpleOperator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.util.UtilFunctions;
import scala.Tuple2;

/* loaded from: input_file:org/apache/sysds/runtime/instructions/spark/BuiltinNarySPInstruction.class */
/**
 * Spark instruction for n-ary builtin operations over a variable number of inputs.
 *
 * <p>Two families of opcodes are handled by {@link #processInstruction(ExecutionContext)}:
 * <ul>
 *   <li>{@code cbind} / {@code rbind}: every input RDD is shifted, padded to the output block
 *       dimensions, unioned, and merged by block index;</li>
 *   <li>{@code nmin} / {@code nmax} / {@code n+}: all matrix inputs are joined by block index and
 *       combined, together with the scalar inputs, via {@code MatrixBlock.naryOperations}.</li>
 * </ul>
 * Any other opcode is rejected with an {@link UnsupportedOperationException}.
 */
public class BuiltinNarySPInstruction extends SPInstruction implements LineageTraceable {
    /** Opcodes handled by the joined min/max/plus branch of {@link #processInstruction(ExecutionContext)}. */
    private static final String[] MIN_MAX_ADD_OPCODES = {"nmin", "nmax", "n+"};

    private final CPOperand[] inputs;
    private final CPOperand output;

    /**
     * Applies the n-ary min/max/plus operator to the array of aligned matrix blocks of one block
     * index, together with the scalar inputs captured at construction time.
     */
    private static class MinMaxAddFunction implements Function<MatrixBlock[], MatrixBlock> {
        private static final long serialVersionUID = -4227447915387484397L;
        private final SimpleOperator _op;
        private final ScalarObject[] _scalars;

        /**
         * @param opcode  instruction opcode; {@code "n+"} selects the plus function object, any other
         *                opcode selects the builtin named by the opcode without its first character
         *                (e.g. {@code "nmin"} -> {@code "min"})
         * @param scalars scalar inputs passed to every {@code naryOperations} call
         */
        public MinMaxAddFunction(String opcode, List<ScalarObject> scalars) {
            _scalars = scalars.toArray(new ScalarObject[0]);
            _op = new SimpleOperator(opcode.equals("n+") ? Plus.getPlusFnObject() : Builtin.getBuiltinFnObject(opcode.substring(1)));
        }

        /**
         * @param blocks matrix blocks to combine
         * @return the result of {@code MatrixBlock.naryOperations} over the blocks and the scalars
         */
        @Override
        public MatrixBlock call(MatrixBlock[] blocks) throws Exception {
            return MatrixBlock.naryOperations(_op, blocks, _scalars, new MatrixBlock());
        }
    }

    /**
     * Pads a block so that its dimensions match the block size computed from the output data
     * characteristics for the block's index. Blocks that already match are passed through unchanged.
     */
    public static class PadBlocksFunction implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 1291358959908299855L;
        private final DataCharacteristics _mcOut;

        /**
         * @param dataCharacteristics characteristics (rows, columns, block size) of the output matrix
         */
        public PadBlocksFunction(DataCharacteristics dataCharacteristics) {
            _mcOut = dataCharacteristics;
        }

        /**
         * @param tuple2 pair of block index and block
         * @return the input pair itself if the block already has the expected dimensions, otherwise
         *         a new pair with the same index and the padded block
         */
        @Override
        public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock> tuple2) throws Exception {
            MatrixIndexes ix = tuple2._1();
            MatrixBlock block = tuple2._2();
            int expectedRows = UtilFunctions.computeBlockSize(_mcOut.getRows(), ix.getRowIndex(), _mcOut.getBlocksize());
            int expectedCols = UtilFunctions.computeBlockSize(_mcOut.getCols(), ix.getColumnIndex(), _mcOut.getBlocksize());

            // pass-through: nothing to pad
            if (expectedRows == block.getNumRows() && expectedCols == block.getNumColumns()) {
                return tuple2;
            }

            // NOTE(review): only one dimension is padded per call (rows take precedence); this
            // presumably relies on cbind/rbind shifting blocks along a single dimension - confirm.
            if (expectedRows > block.getNumRows()) {
                // append the missing rows (last argument false)
                block = block.append(new MatrixBlock(expectedRows - block.getNumRows(), expectedCols, true), new MatrixBlock(), false);
            } else if (expectedCols > block.getNumColumns()) {
                // append the missing columns (last argument true)
                block = block.append(new MatrixBlock(expectedRows, expectedCols - block.getNumColumns(), true), new MatrixBlock(), true);
            }
            return new Tuple2<>(ix, block);
        }
    }

    /**
     * @param cPOperandArr input operands
     * @param cPOperand    output operand
     * @param str          opcode
     * @param str2         full instruction string
     */
    protected BuiltinNarySPInstruction(CPOperand[] cPOperandArr, CPOperand cPOperand, String str, String str2) {
        super(SPInstruction.SPType.BuiltinNary, str, str2);
        this.inputs = cPOperandArr;
        this.output = cPOperand;
    }

    /**
     * Parses an instruction string into a {@link BuiltinNarySPInstruction}. The first part is the
     * opcode, the last part is the output operand, and all parts in between are input operands.
     *
     * @param str instruction string
     * @return the parsed instruction
     * @throws IllegalArgumentException if the string has fewer than two parts (no opcode/output)
     */
    public static BuiltinNarySPInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        if (parts.length < 2) {
            throw new IllegalArgumentException("Invalid nary builtin instruction, expected at least an opcode and an output operand: " + str);
        }
        String opcode = parts[0];
        CPOperand outputOperand = new CPOperand(parts[parts.length - 1]);
        CPOperand[] inputOperands = new CPOperand[parts.length - 2];
        for (int i = 1; i < parts.length - 1; i++) {
            inputOperands[i - 1] = new CPOperand(parts[i]);
        }
        return new BuiltinNarySPInstruction(inputOperands, outputOperand, opcode, str);
    }

    /**
     * Builds the output RDD for the instruction's opcode, registers it together with the output data
     * characteristics under the output variable name, and records lineage from the output to every
     * non-scalar input.
     *
     * @param executionContext execution context; cast to {@link SparkExecutionContext}
     * @throws UnsupportedOperationException if the opcode is not one of
     *         {@code cbind, rbind, nmin, nmax, n+}
     * @throws IllegalStateException if a min/max/plus opcode has no matrix input
     */
    @Override // org.apache.sysds.runtime.instructions.spark.SPInstruction, org.apache.sysds.runtime.instructions.Instruction
    public void processInstruction(ExecutionContext executionContext) {
        SparkExecutionContext sec = (SparkExecutionContext) executionContext;
        String opcode = getOpcode();
        JavaPairRDD<MatrixIndexes, MatrixBlock> out;
        DataCharacteristics mcOut;

        if (opcode.equals("cbind") || opcode.equals("rbind")) {
            boolean cbind = opcode.equals("cbind");
            mcOut = computeAppendOutputDataCharacteristics(sec, this.inputs, cbind);
            out = appendInputs(sec, mcOut, cbind);
        } else if (ArrayUtils.contains(MIN_MAX_ADD_OPCODES, opcode)) {
            mcOut = computeMinMaxOutputDataCharacteristics(sec, this.inputs);
            out = combineMatrixInputs(sec, opcode);
        } else {
            // fail fast instead of registering a null RDD / null characteristics for the output
            throw new UnsupportedOperationException("Unsupported nary builtin spark opcode: " + opcode);
        }

        // set output characteristics and RDD, then add lineage for all non-scalar inputs
        sec.getDataCharacteristics(this.output.getName()).set(mcOut);
        sec.setRDDHandleForVariable(this.output.getName(), out);
        for (CPOperand input : this.inputs) {
            if (!input.isScalar()) {
                sec.addLineageRDD(this.output.getName(), input.getName());
            }
        }
    }

    /** Shifts and pads every input RDD, unions them, and merges the blocks by index (cbind/rbind). */
    private JavaPairRDD<MatrixIndexes, MatrixBlock> appendInputs(SparkExecutionContext sec, DataCharacteristics mcOut, boolean cbind) {
        // accumulated characteristics of the inputs processed so far; handed to ShiftMatrix for
        // each input before being updated with that input's characteristics
        MatrixCharacteristics mcAccumulated = new MatrixCharacteristics(0L, 0L, mcOut.getBlocksize(), 0L);
        JavaPairRDD<MatrixIndexes, MatrixBlock> out = null;
        for (CPOperand input : this.inputs) {
            DataCharacteristics mcIn = sec.getDataCharacteristics(input.getName());
            JavaPairRDD<MatrixIndexes, MatrixBlock> shifted = sec.getBinaryMatrixBlockRDDHandleForVariable(input.getName())
                .flatMapToPair(new AppendGSPInstruction.ShiftMatrix(mcAccumulated, mcIn, cbind))
                .mapToPair(new PadBlocksFunction(mcOut));
            out = (out != null) ? out.union(shifted) : shifted;
            updateAppendDataCharacteristics(mcIn, mcAccumulated, cbind);
        }
        return RDDAggregateUtils.mergeByKey(out, SparkUtils.getNumPreferredPartitions(mcOut), false);
    }

    /** Joins all matrix inputs by block index and applies the n-ary min/max/plus operator per index. */
    private JavaPairRDD<MatrixIndexes, MatrixBlock> combineMatrixInputs(SparkExecutionContext sec, String opcode) {
        List<ScalarObject> scalars = sec.getScalarInputs(this.inputs);
        JavaPairRDD<MatrixIndexes, MatrixBlock[]> joined = null;
        for (CPOperand input : this.inputs) {
            if (!input.getDataType().isMatrix()) {
                continue;
            }
            JavaPairRDD<MatrixIndexes, MatrixBlock> in = sec.getBinaryMatrixBlockRDDHandleForVariable(input.getName());
            if (joined == null) {
                joined = in.mapValues(new MapInputSignature());
            } else {
                joined = joined.join(in).mapValues(new MapJoinSignature());
            }
        }
        if (joined == null) {
            throw new IllegalStateException("Nary builtin spark instruction '" + opcode + "' requires at least one matrix input.");
        }
        return joined.mapValues(new MinMaxAddFunction(opcode, scalars));
    }

    /**
     * Computes the output characteristics of cbind/rbind by folding all inputs into an initially
     * empty (0 x 0, 0 non-zeros) characteristics object that uses the block size of the first input.
     */
    private static DataCharacteristics computeAppendOutputDataCharacteristics(SparkExecutionContext sparkExecutionContext, CPOperand[] cPOperandArr, boolean z) {
        MatrixCharacteristics mcOut = new MatrixCharacteristics(0L, 0L, sparkExecutionContext.getDataCharacteristics(cPOperandArr[0].getName()).getBlocksize(), 0L);
        for (CPOperand input : cPOperandArr) {
            updateAppendDataCharacteristics(sparkExecutionContext.getDataCharacteristics(input.getName()), mcOut, z);
        }
        return mcOut;
    }

    /**
     * Updates {@code out} in place as if {@code in} were appended to it.
     *
     * @param in    characteristics of the appended input
     * @param out   accumulated characteristics, modified by this call
     * @param cbind if true, columns are summed and rows take the maximum; otherwise rows are summed
     *              and columns take the maximum
     */
    private static void updateAppendDataCharacteristics(DataCharacteristics in, DataCharacteristics out, boolean cbind) {
        long rows = cbind ? Math.max(out.getRows(), in.getRows()) : out.getRows() + in.getRows();
        long cols = cbind ? out.getCols() + in.getCols() : Math.max(out.getCols(), in.getCols());
        out.setDimension(rows, cols);
        // the non-zero count stays/becomes unknown (-1) once either side is unknown
        boolean nnzKnown = out.getNonZeros() != -1 && in.dimsKnown(true);
        out.setNonZeros(nnzKnown ? out.getNonZeros() + in.getNonZeros() : -1L);
    }

    /**
     * Computes the output characteristics of nmin/nmax/n+: the maximum row and column counts over
     * all matrix inputs, with the block size of the last matrix input visited.
     */
    private static DataCharacteristics computeMinMaxOutputDataCharacteristics(SparkExecutionContext sparkExecutionContext, CPOperand[] cPOperandArr) {
        MatrixCharacteristics mcOut = new MatrixCharacteristics();
        for (CPOperand input : cPOperandArr) {
            if (input.getDataType().isMatrix()) {
                DataCharacteristics mcIn = sparkExecutionContext.getDataCharacteristics(input.getName());
                mcOut.setRows(Math.max(mcOut.getRows(), mcIn.getRows()));
                mcOut.setCols(Math.max(mcOut.getCols(), mcIn.getCols()));
                mcOut.setBlocksize(mcIn.getBlocksize());
            }
        }
        return mcOut;
    }

    /**
     * @param executionContext execution context used to look up the lineage of the inputs
     * @return pair of the output variable name and a lineage item built from the opcode and the
     *         lineage of all inputs
     */
    @Override // org.apache.sysds.runtime.lineage.LineageTraceable
    public Pair<String, LineageItem> getLineageItem(ExecutionContext executionContext) {
        return Pair.of(this.output.getName(), new LineageItem(getOpcode(), LineageItemUtils.getLineage(executionContext, this.inputs)));
    }
}
