package org.apache.sysds.runtime.instructions.spark;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.sysds.lops.MMTSJ;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.SPInstruction;
import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.operators.Operator;
import scala.Tuple2;

/* loaded from: input_file:org/apache/sysds/runtime/instructions/spark/TsmmSPInstruction.class */
/**
 * Spark instruction for transpose-self matrix multiplication ({@code tsmm}).
 *
 * <p>Every block of the input matrix is multiplied with its own transpose via
 * {@link MatrixBlock#transposeSelfMatrixMultOperations}, using the variant selected by
 * {@link MMTSJ.MMTSJType}. The per-block partial products are then summed with
 * {@link RDDAggregateUtils#sumStable} into one local {@link MatrixBlock}, which is bound to the
 * output variable.
 *
 * <p>NOTE(review): summing per-block self-products equals the full product only for certain
 * block layouts. Presumably the compiler only emits this instruction when that holds, e.g. a
 * single column block for a left tsmm. Verify against the hop/lop selection logic.
 */
public class TsmmSPInstruction extends UnarySPInstruction {
    /** Expected instruction fields: opcode, input operand, output operand, tsmm type. */
    private static final int NUM_REQUIRED_FIELDS = 4;

    /** The tsmm variant, passed unchanged to every per-block multiplication. */
    private final MMTSJ.MMTSJType _type;

    /**
     * Spark map function that computes the transpose-self product of a single matrix block.
     * It is a static nested class so the serialized closure does not capture the enclosing
     * instruction.
     */
    private static class RDDTSMMFunction implements Function<Tuple2<MatrixIndexes, MatrixBlock>, MatrixBlock> {
        private static final long serialVersionUID = 2935770425858019666L;

        /** The tsmm variant applied to each block. */
        private final MMTSJ.MMTSJType _type;

        public RDDTSMMFunction(MMTSJ.MMTSJType type) {
            _type = type;
        }

        /**
         * Computes the transpose-self product of one block.
         *
         * @param kv block index and block; only the block value is used
         * @return a new block holding the partial product of this input block
         */
        @Override
        public MatrixBlock call(Tuple2<MatrixIndexes, MatrixBlock> kv) throws Exception {
            // The block index is ignored: all partial products are summed into one output.
            return kv._2().transposeSelfMatrixMultOperations(new MatrixBlock(), _type);
        }
    }

    private TsmmSPInstruction(Operator op, CPOperand in, CPOperand out, MMTSJ.MMTSJType type, String opcode, String istr) {
        super(SPInstruction.SPType.TSMM, op, in, out, opcode, istr);
        _type = type;
    }

    /**
     * Parses a {@code tsmm} instruction string of the form
     * {@code opcode, input, output, tsmmType} (after splitting by
     * {@link InstructionUtils#getInstructionPartsWithValueType}).
     *
     * @param str the serialized instruction
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is not {@code tsmm}, too few fields are present,
     *     or the tsmm type is not a valid {@link MMTSJ.MMTSJType}
     */
    public static TsmmSPInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        if (!opcode.equalsIgnoreCase("tsmm")) {
            throw new DMLRuntimeException("TsmmSPInstruction.parseInstruction():: Unknown opcode " + opcode);
        }
        // Fail with a descriptive error instead of an ArrayIndexOutOfBoundsException.
        if (parts.length < NUM_REQUIRED_FIELDS) {
            throw new DMLRuntimeException("TsmmSPInstruction.parseInstruction():: Invalid number of fields ("
                + parts.length + ", expected " + NUM_REQUIRED_FIELDS + ") in instruction: " + str);
        }
        CPOperand in = new CPOperand(parts[1]);
        CPOperand out = new CPOperand(parts[2]);
        MMTSJ.MMTSJType type;
        try {
            type = MMTSJ.MMTSJType.valueOf(parts[3]);
        }
        catch (IllegalArgumentException ex) {
            throw new DMLRuntimeException("TsmmSPInstruction.parseInstruction():: Unknown tsmm type '"
                + parts[3] + "' in instruction: " + str, ex);
        }
        return new TsmmSPInstruction(null, in, out, type, opcode, str);
    }

    /**
     * Computes the per-block transpose-self products on Spark, aggregates them with
     * {@link RDDAggregateUtils#sumStable}, and binds the resulting local matrix to the
     * output variable.
     *
     * @param ec the execution context; must be a {@link SparkExecutionContext}
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        SparkExecutionContext sec = (SparkExecutionContext) ec;

        // One partial product per input block.
        JavaRDD<MatrixBlock> partials = sec.getBinaryMatrixBlockRDDHandleForVariable(input1.getName())
            .map(new RDDTSMMFunction(_type));

        // Sum all partial products into a single local block.
        MatrixBlock result = RDDAggregateUtils.sumStable(partials);
        sec.setMatrixOutput(output.getName(), result);
    }
}
