package org.apache.sysds.runtime.instructions.spark;

import org.antlr.v4.runtime.atn.PredictionContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.functionobjects.SwapIndex;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.SPInstruction;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.instructions.spark.functions.FilterNonEmptyBlocksFunction;
import org.apache.sysds.runtime.instructions.spark.functions.FilterNonEmptyBlocksFunction2;
import org.apache.sysds.runtime.instructions.spark.functions.ReorgMapFunction;
import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import scala.Tuple2;

/**
 * Spark instruction for the {@code cpmm} matrix multiplication of two blocked matrix RDDs.
 *
 * <p>In the general case both inputs are re-keyed by their shared inner block index (the
 * column-block index of the left input and the row-block index of the right input). They are
 * then joined on that key and multiplied block-wise. Finally the partial products are either
 * summed into a single block or summed per output block index.
 *
 * <p>A dedicated path handles a hash-partitioned left input that has a single row block,
 * multiplied with a single-column right input. That path joins on the original block indexes
 * instead of re-keying the left input.
 */
public class CpmmSPInstruction extends BinarySPInstruction {
    /** Parsed flag controlling whether empty blocks are filtered from the inputs/output. */
    private final boolean _outputEmptyBlocks;
    /**
     * Requested aggregation of the result. {@code SINGLE_BLOCK} produces one matrix block; any
     * other value produces a distributed RDD output.
     */
    private final AggBinaryOp.SparkAggType _aggtype;

    /**
     * Multiplies a joined block pair {@code (left, right)} as {@code left %*% t(right)}.
     *
     * <p>Used by the matrix-vector path, whose right-hand blocks were transposed before the
     * join. This function transposes them back before multiplying.
     */
    private static class Cpmm2MultiplyFunction implements Function<Tuple2<MatrixBlock, MatrixBlock>, MatrixBlock> {
        private static final long serialVersionUID = -3718880362385713416L;

        // Operators are created lazily on the first call rather than at construction time.
        private AggregateBinaryOperator _op = null;
        private ReorgOperator _rop = null;

        @Override
        public MatrixBlock call(Tuple2<MatrixBlock, MatrixBlock> blocks) throws Exception {
            if (_op == null) {
                _op = InstructionUtils.getMatMultOperator(1);
                _rop = new ReorgOperator(SwapIndex.getSwapIndexFnObject());
            }
            // undo the transpose applied to the right-hand block prior to the join
            MatrixBlock right = blocks._2().reorgOperations(_rop, new MatrixBlock(), 0, 0, 0);
            return OperationsOnMatrixValues.matMult(blocks._1(), right, new MatrixBlock(), _op);
        }
    }

    /**
     * Re-keys a block by its inner (join) dimension.
     *
     * <p>The key is the column-block index for the left input and the row-block index for the
     * right input. The original index is kept alongside the block so the output index can be
     * reconstructed after the join.
     */
    private static class CpmmIndexFunction implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, Long, IndexedMatrixValue> {
        private static final long serialVersionUID = -1187183128301671162L;

        /** {@code true} if applied to the left-hand input, {@code false} for the right-hand input. */
        private final boolean _left;

        public CpmmIndexFunction(boolean left) {
            _left = left;
        }

        @Override
        public Tuple2<Long, IndexedMatrixValue> call(Tuple2<MatrixIndexes, MatrixBlock> in) throws Exception {
            MatrixIndexes ix = in._1();
            long key = _left ? ix.getColumnIndex() : ix.getRowIndex();
            return new Tuple2<>(key, new IndexedMatrixValue(ix, in._2()));
        }
    }

    /**
     * Multiplies a joined pair of left/right blocks that share the same inner block index.
     *
     * <p>The partial product is emitted under the output index {@code (leftRow, rightCol)}.
     */
    private static class CpmmMultiplyFunction implements PairFunction<Tuple2<Long, Tuple2<IndexedMatrixValue, IndexedMatrixValue>>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = -2009255629093036642L;

        // Created lazily on the first call rather than at construction time.
        private AggregateBinaryOperator _op = null;

        @Override
        public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<Long, Tuple2<IndexedMatrixValue, IndexedMatrixValue>> in) throws Exception {
            if (_op == null) {
                _op = InstructionUtils.getMatMultOperator(1);
            }
            IndexedMatrixValue left = in._2()._1();
            IndexedMatrixValue right = in._2()._2();
            MatrixBlock product = OperationsOnMatrixValues.matMult(
                (MatrixBlock) left.getValue(), (MatrixBlock) right.getValue(), new MatrixBlock(), _op);
            MatrixIndexes outIx = new MatrixIndexes();
            outIx.setIndexes(left.getIndexes().getRowIndex(), right.getIndexes().getColumnIndex());
            return new Tuple2<>(outIx, product);
        }
    }

    private CpmmSPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, boolean outputEmptyBlocks,
        AggBinaryOp.SparkAggType aggtype, String opcode, String istr) {
        super(SPInstruction.SPType.CPMM, op, in1, in2, out, opcode, istr);
        _outputEmptyBlocks = outputEmptyBlocks;
        _aggtype = aggtype;
    }

    /**
     * Parses a {@code cpmm} instruction string.
     *
     * <p>Expected fields after the opcode: {@code in1, in2, out, outputEmptyBlocks, aggtype}.
     *
     * @param str the serialized instruction
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is not {@code cpmm} or fields are missing
     */
    public static CpmmSPInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        if (!opcode.equalsIgnoreCase("cpmm")) {
            throw new DMLRuntimeException("CpmmSPInstruction.parseInstruction(): Unknown opcode " + opcode);
        }
        // fail with a descriptive error instead of an ArrayIndexOutOfBoundsException
        if (parts.length < 6) {
            throw new DMLRuntimeException("CpmmSPInstruction.parseInstruction(): Expected 5 operands but got "
                + (parts.length - 1) + " in: " + str);
        }
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand out = new CPOperand(parts[3]);
        boolean outputEmptyBlocks = Boolean.parseBoolean(parts[4]);
        AggBinaryOp.SparkAggType aggtype = AggBinaryOp.SparkAggType.valueOf(parts[5]);
        return new CpmmSPInstruction(InstructionUtils.getMatMultOperator(1),
            in1, in2, out, outputEmptyBlocks, aggtype, opcode, str);
    }

    @Override
    public void processInstruction(ExecutionContext ec) {
        SparkExecutionContext sec = (SparkExecutionContext) ec;
        JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec.getBinaryMatrixBlockRDDHandleForVariable(input1.getName());
        JavaPairRDD<MatrixIndexes, MatrixBlock> in2 = sec.getBinaryMatrixBlockRDDHandleForVariable(input2.getName());
        DataCharacteristics mc1 = sec.getDataCharacteristics(input1.getName());
        DataCharacteristics mc2 = sec.getDataCharacteristics(input2.getName());

        // drop empty input blocks unless empty blocks must be preserved for a distributed output
        if (!_outputEmptyBlocks || _aggtype == AggBinaryOp.SparkAggType.SINGLE_BLOCK
            || mc1.isNoEmptyBlocks() || mc2.isNoEmptyBlocks()) {
            in1 = in1.filter(new FilterNonEmptyBlocksFunction());
            in2 = in2.filter(new FilterNonEmptyBlocksFunction());
        }

        if (SparkUtils.isHashPartitioned(in1) && mc1.getNumRowBlocks() == 1 && mc2.getCols() == 1) {
            processMatrixVector(sec, in1, in2);
        } else {
            processGeneral(sec, in1, in2, mc1, mc2);
        }
    }

    /**
     * Matrix-vector path.
     *
     * <p>Joins the hash-partitioned left input directly on its block indexes (no re-keying
     * {@code mapToPair} on {@code in1}). Before the join, the right input is transposed via
     * {@code ReorgMapFunction("r'")}. The block products are then summed into a single output
     * block.
     */
    private void processMatrixVector(SparkExecutionContext sec,
        JavaPairRDD<MatrixIndexes, MatrixBlock> in1, JavaPairRDD<MatrixIndexes, MatrixBlock> in2) {
        JavaRDD<MatrixBlock> products = in1
            .join(in2.mapToPair(new ReorgMapFunction("r'")))
            .values()
            .map(new Cpmm2MultiplyFunction())
            .filter(new FilterNonEmptyBlocksFunction2());
        sec.setMatrixOutput(output.getName(), RDDAggregateUtils.sumStable(products));
    }

    /**
     * General CPMM path.
     *
     * <p>Performs the join over the inner dimension and the block-wise multiply. The result is
     * then either aggregated into a single block or summed per output index into an RDD output.
     */
    private void processGeneral(SparkExecutionContext sec,
        JavaPairRDD<MatrixIndexes, MatrixBlock> in1, JavaPairRDD<MatrixIndexes, MatrixBlock> in2,
        DataCharacteristics mc1, DataCharacteristics mc2) {
        int numPartJoin = Math.min(getMaxParJoin(mc1, mc2),
            getPreferredParJoin(mc1, mc2, in1.getNumPartitions(), in2.getNumPartitions()));

        JavaPairRDD<Long, IndexedMatrixValue> left = in1.mapToPair(new CpmmIndexFunction(true));
        JavaPairRDD<Long, IndexedMatrixValue> right = in2.mapToPair(new CpmmIndexFunction(false));
        JavaPairRDD<MatrixIndexes, MatrixBlock> products = left
            .join(right, numPartJoin)
            .mapToPair(new CpmmMultiplyFunction());

        if (_aggtype == AggBinaryOp.SparkAggType.SINGLE_BLOCK) {
            JavaPairRDD<MatrixIndexes, MatrixBlock> nonEmpty = products.filter(new FilterNonEmptyBlocksFunction());
            sec.setMatrixOutput(output.getName(), RDDAggregateUtils.sumStable(nonEmpty));
            return;
        }

        if (!_outputEmptyBlocks || mc1.isNoEmptyBlocks() || mc2.isNoEmptyBlocks()) {
            products = products.filter(new FilterNonEmptyBlocksFunction());
        }
        sec.setRDDHandleForVariable(output.getName(), RDDAggregateUtils.sumByKeyStable(products, false));
        sec.addLineageRDD(output.getName(), input1.getName());
        sec.addLineageRDD(output.getName(), input2.getName());
        updateBinaryMMOutputDataCharacteristics(sec, true);
    }

    /**
     * Estimates the preferred number of join partitions.
     *
     * <p>The estimate is the larger of the input partition counts. If both inputs have known
     * dimensions, it is raised to the sum of their preferred partition counts when that is
     * larger. If the estimate exceeds half of the default parallelism, it is raised to at least
     * the default parallelism.
     */
    private static int getPreferredParJoin(DataCharacteristics mc1, DataCharacteristics mc2, int numPar1, int numPar2) {
        int defaultPar = SparkExecutionContext.getDefaultParallelism(true);
        int inputPar = Math.max(numPar1, numPar2);
        int preferred = inputPar;
        if (mc1.dimsKnown(true) && mc2.dimsKnown(true)) {
            int sizeBased = SparkUtils.getNumPreferredPartitions(mc1) + SparkUtils.getNumPreferredPartitions(mc2);
            preferred = Math.max(sizeBased, inputPar);
        }
        return (preferred > defaultPar / 2) ? Math.max(preferred, defaultPar) : preferred;
    }

    /**
     * Returns an upper bound on useful join partitions: the number of distinct join keys.
     *
     * <p>The bound is the number of inner-dimension blocks, taken from whichever input knows
     * it. If neither input knows it, the result is unbounded ({@link Integer#MAX_VALUE}).
     */
    private static int getMaxParJoin(DataCharacteristics mc1, DataCharacteristics mc2) {
        if (mc1.colsKnown()) {
            return saturatedToInt(mc1.getNumColBlocks());
        }
        if (mc2.rowsKnown()) {
            return saturatedToInt(mc2.getNumRowBlocks());
        }
        return Integer.MAX_VALUE;
    }

    /** Narrows a non-negative block count to int, saturating instead of wrapping on overflow. */
    private static int saturatedToInt(long value) {
        return (int) Math.min(value, Integer.MAX_VALUE);
    }
}
