package org.apache.sysds.runtime.instructions.spark;

import org.apache.commons.lang3.tuple.Pair;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.sysds.common.Types;
import org.apache.sysds.lops.MapMultChain;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.SPInstruction;
import org.apache.sysds.runtime.instructions.spark.data.PartitionedBroadcast;
import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.lineage.LineageItemUtils;
import org.apache.sysds.runtime.lineage.LineageTraceable;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.operators.Operator;
import scala.Tuple2;

/* loaded from: input_file:org/apache/sysds/runtime/instructions/spark/MapmmChainSPInstruction.class */
/**
 * Spark instruction for the fused matrix-multiplication chain operator ({@link MapMultChain#OPCODE}).
 *
 * <p>Input X is consumed as a blocked pair RDD. Input v (and, for chain types other than
 * {@code XtXv}, input w) is read from a {@link PartitionedBroadcast}. Each block of X yields a
 * partial result via {@link MatrixBlock#chainMatrixMultOperations}. The partial results are then
 * summed with {@link RDDAggregateUtils#sumStable} into one {@link MatrixBlock}, which is bound to
 * the output variable.
 */
public class MapmmChainSPInstruction extends SPInstruction implements LineageTraceable {
    /** Chain variant; {@code XtXv} uses two inputs, every other variant also reads {@link #_input3}. */
    private final MapMultChain.ChainType _chainType;
    /** Input X, read as a blocked RDD. */
    private final CPOperand _input1;
    /** Input v, read as a broadcast. */
    private final CPOperand _input2;
    /** Input w, read as a broadcast; {@code null} when the instruction was parsed in two-input form. */
    private final CPOperand _input3;
    /** Output variable that receives the aggregated result block. */
    private final CPOperand _output;

    /**
     * Per-block function for the {@code XtXv} chain type. It applies the chain to a single block
     * of X, using block (1,1) of the broadcast v.
     */
    private static class RDDMapMMChainFunction implements Function<MatrixBlock, MatrixBlock> {
        private static final long serialVersionUID = 8197406787010296291L;

        private final PartitionedBroadcast<MatrixBlock> _pmV;

        public RDDMapMMChainFunction(PartitionedBroadcast<MatrixBlock> bV) {
            _pmV = bV;
        }

        /**
         * @param blkX one block of X
         * @return the partial chain result for this block
         */
        @Override
        public MatrixBlock call(MatrixBlock blkX) {
            // Only block (1,1) of v is ever read; presumably v fits in a single block for this
            // operator -- TODO confirm against the compiler's mapmmchain selection rules.
            MatrixBlock blkV = _pmV.getBlock(1, 1);
            return blkX.chainMatrixMultOperations(blkV, null, new MatrixBlock(), MapMultChain.ChainType.XtXv);
        }
    }

    /**
     * Per-block function for the three-input chain types. It applies the chain to a single
     * indexed block of X, using block (1,1) of v and the w block selected by X's row-block index.
     */
    private static class RDDMapMMChainFunction2 implements Function<Tuple2<MatrixIndexes, MatrixBlock>, MatrixBlock> {
        private static final long serialVersionUID = -7926980450209760212L;

        private final PartitionedBroadcast<MatrixBlock> _pmV;
        private final PartitionedBroadcast<MatrixBlock> _pmW;
        private final MapMultChain.ChainType _chainType;

        public RDDMapMMChainFunction2(PartitionedBroadcast<MatrixBlock> bV,
            PartitionedBroadcast<MatrixBlock> bW, MapMultChain.ChainType type) {
            _pmV = bV;
            _pmW = bW;
            _chainType = type;
        }

        /**
         * @param arg (block index, block) pair of X
         * @return the partial chain result for this block
         */
        @Override
        public MatrixBlock call(Tuple2<MatrixIndexes, MatrixBlock> arg) {
            MatrixBlock blkV = _pmV.getBlock(1, 1);
            MatrixIndexes ix = arg._1();
            // w is looked up by the row-block index of this X block (presumably w is row-aligned with X).
            MatrixBlock blkW = _pmW.getBlock((int) ix.getRowIndex(), 1);
            return arg._2().chainMatrixMultOperations(blkV, blkW, new MatrixBlock(), _chainType);
        }
    }

    /** Two-input form (X, v); {@code _input3} is left {@code null}. */
    private MapmmChainSPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out,
        MapMultChain.ChainType type, String opcode, String istr) {
        this(op, in1, in2, null, out, type, opcode, istr);
    }

    /** Three-input form (X, v, w). */
    private MapmmChainSPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand in3,
        CPOperand out, MapMultChain.ChainType type, String opcode, String istr) {
        super(SPInstruction.SPType.MAPMMCHAIN, op, opcode, istr);
        _input1 = in1;
        _input2 = in2;
        _input3 = in3;
        _output = out;
        _chainType = type;
    }

    /**
     * Parses a mapmmchain instruction string.
     *
     * <p>Two layouts are accepted:
     * <ul>
     *   <li>{@code opcode, in1, in2, out, chainType}</li>
     *   <li>{@code opcode, in1, in2, in3, out, chainType}</li>
     * </ul>
     *
     * @param str the full instruction string
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is not {@link MapMultChain#OPCODE}, or if the
     *         two-input layout names a chain type other than {@code XtXv} (which needs a third input)
     */
    public static MapmmChainSPInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        // Expected to restrict parts to 5 or 6 entries (opcode + 4 or 5 fields).
        InstructionUtils.checkNumFields(parts, 4, 5);
        String opcode = parts[0];
        if (!opcode.equalsIgnoreCase(MapMultChain.OPCODE)) {
            throw new DMLRuntimeException("MapmmChainSPInstruction.parseInstruction():: Unknown opcode " + opcode);
        }

        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);

        if (parts.length == 5) {
            CPOperand out = new CPOperand(parts[3]);
            MapMultChain.ChainType type = MapMultChain.ChainType.valueOf(parts[4]);
            // Fail fast here. Otherwise processInstruction would dereference the missing third
            // input and fail later with a NullPointerException.
            if (type != MapMultChain.ChainType.XtXv) {
                throw new DMLRuntimeException("MapmmChainSPInstruction.parseInstruction():: Chain type "
                    + type + " requires a third input operand.");
            }
            return new MapmmChainSPInstruction(null, in1, in2, out, type, opcode, str);
        }

        CPOperand in3 = new CPOperand(parts[3]);
        CPOperand out = new CPOperand(parts[4]);
        MapMultChain.ChainType type = MapMultChain.ChainType.valueOf(parts[5]);
        return new MapmmChainSPInstruction(null, in1, in2, in3, out, type, opcode, str);
    }

    /**
     * Executes the chain over the blocks of X and binds the summed result to the output variable.
     *
     * @param ec execution context; cast to {@link SparkExecutionContext}
     */
    @Override // org.apache.sysds.runtime.instructions.spark.SPInstruction, org.apache.sysds.runtime.instructions.Instruction
    public void processInstruction(ExecutionContext ec) {
        SparkExecutionContext sec = (SparkExecutionContext) ec;
        JavaPairRDD<MatrixIndexes, MatrixBlock> inX = sec.getBinaryMatrixBlockRDDHandleForVariable(_input1.getName());
        PartitionedBroadcast<MatrixBlock> inV = sec.getBroadcastForVariable(_input2.getName());

        MatrixBlock out;
        if (_chainType == MapMultChain.ChainType.XtXv) {
            // Block indexes are not needed for XtXv, so only the block values are mapped.
            JavaRDD<MatrixBlock> partials = inX.values().map(new RDDMapMMChainFunction(inV));
            out = RDDAggregateUtils.sumStable(partials);
        } else {
            // Remaining chain types also need w, which is selected per X row block.
            PartitionedBroadcast<MatrixBlock> inW = sec.getBroadcastForVariable(_input3.getName());
            JavaRDD<MatrixBlock> partials = inX.map(new RDDMapMMChainFunction2(inV, inW, _chainType));
            out = RDDAggregateUtils.sumStable(partials);
        }

        sec.setMatrixOutput(_output.getName(), out);
    }

    /**
     * Builds the lineage item for the output, from the inputs plus the chain type as a literal.
     *
     * @param ec execution context used to resolve input lineage
     * @return pair of output variable name and its lineage item
     */
    @Override // org.apache.sysds.runtime.lineage.LineageTraceable
    public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
        // NOTE(review): the chain-type name is a string but is encoded as an INT64 literal. This is
        // kept unchanged so existing lineage traces stay stable; confirm whether STRING was intended.
        CPOperand chainT = new CPOperand(_chainType.name(), Types.ValueType.INT64, Types.DataType.SCALAR, true);
        // _input3 may be null for the two-input form; passed through as before.
        return Pair.of(_output.getName(), new LineageItem(getOpcode(),
            LineageItemUtils.getLineage(ec, _input1, _input2, _input3, chainT)));
    }
}
