package org.apache.sysds.runtime.instructions.spark;

import java.io.Serializable;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.SPInstruction;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import scala.Tuple2;

/* loaded from: input_file:org/apache/sysds/runtime/instructions/spark/TernarySPInstruction.class */
/**
 * Spark instruction that applies a {@link TernaryOperator} to three operands.
 *
 * <p>Each operand is either a matrix, read as a
 * {@code JavaPairRDD<MatrixIndexes, MatrixBlock>}, or a non-matrix value that is read
 * as a scalar and wrapped in a {@link MatrixBlock} built from its double value.
 * Matrix operands are joined on their {@link MatrixIndexes} keys and the operator is
 * applied to the aligned blocks; scalar operands are shipped inside the function
 * object.
 *
 * <p>The nested function classes are named after the operand layout: the i-th letter
 * is {@code M} when the i-th operand is a matrix and {@code S} when it is a scalar
 * (for example {@code MSM}: matrix, scalar, matrix).
 */
public class TernarySPInstruction extends ComputationSPInstruction {

    private TernarySPInstruction(TernaryOperator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String instString) {
        super(SPInstruction.SPType.Ternary, op, in1, in2, in3, out, opcode, instString);
    }

    /**
     * Builds a ternary Spark instruction from its string form.
     *
     * <p>The string is split with
     * {@link InstructionUtils#getInstructionPartsWithValueType(String)}; part 0 is the
     * opcode, parts 1-3 are the inputs and part 4 is the output.
     *
     * @param str the instruction string
     * @return the parsed instruction
     */
    public static TernarySPInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        TernaryOperator op = InstructionUtils.parseTernaryOperator(opcode);
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand in3 = new CPOperand(parts[3]);
        CPOperand out = new CPOperand(parts[4]);
        return new TernarySPInstruction(op, in1, in2, in3, out, opcode, str);
    }

    /**
     * Executes the ternary operation and registers the resulting RDD, its data
     * characteristics and its lineage for the output variable.
     *
     * @param ec the execution context; it is cast to {@link SparkExecutionContext}
     * @throws IllegalStateException if none of the three operands is a matrix, since
     *         there is then no RDD to drive the computation
     */
    @Override // org.apache.sysds.runtime.instructions.spark.SPInstruction, org.apache.sysds.runtime.instructions.Instruction
    public void processInstruction(ExecutionContext ec) {
        SparkExecutionContext sec = (SparkExecutionContext) ec;
        boolean isM1 = this.input1.isMatrix();
        boolean isM2 = this.input2.isMatrix();
        boolean isM3 = this.input3.isMatrix();

        // Matrix operands are resolved to RDD handles first, then the scalar operands
        // are wrapped into blocks; exactly one of (inX, mX) is non-null per operand.
        JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = matrixHandleOrNull(sec, this.input1);
        JavaPairRDD<MatrixIndexes, MatrixBlock> in2 = matrixHandleOrNull(sec, this.input2);
        JavaPairRDD<MatrixIndexes, MatrixBlock> in3 = matrixHandleOrNull(sec, this.input3);
        MatrixBlock m1 = scalarBlockOrNull(ec, this.input1);
        MatrixBlock m2 = scalarBlockOrNull(ec, this.input2);
        MatrixBlock m3 = scalarBlockOrNull(ec, this.input3);
        TernaryOperator op = (TernaryOperator) this._optr;

        final JavaPairRDD<MatrixIndexes, MatrixBlock> out;
        if (isM1 && !isM2 && !isM3) {
            out = in1.mapValues(new TernaryFunctionMSS(op, m1, m2, m3));
        } else if (!isM1 && isM2 && !isM3) {
            out = in2.mapValues(new TernaryFunctionSMS(op, m1, m2, m3));
        } else if (!isM1 && !isM2 && isM3) {
            out = in3.mapValues(new TernaryFunctionSSM(op, m1, m2, m3));
        } else if (isM1 && isM2 && !isM3) {
            out = in1.join(in2).mapValues(new TernaryFunctionMMS(op, m1, m2, m3));
        } else if (isM1 && !isM2 && isM3) {
            out = in1.join(in3).mapValues(new TernaryFunctionMSM(op, m1, m2, m3));
        } else if (!isM1 && isM2 && isM3) {
            out = in2.join(in3).mapValues(new TernaryFunctionSMM(op, m1, m2, m3));
        } else if (isM1 && isM2 && isM3) {
            out = in1.join(in2).join(in3).mapValues(new TernaryFunctionMMM(op, m1, m2, m3));
        } else {
            // All operands are scalars: every RDD handle is null, so fail with a clear
            // message instead of a NullPointerException on the first join.
            throw new IllegalStateException("Ternary spark instruction requires at least one matrix input, but '"
                + this.input1.getName() + "', '" + this.input2.getName() + "' and '" + this.input3.getName()
                + "' are all non-matrix operands.");
        }

        updateTernaryOutputDataCharacteristics(sec);
        sec.setRDDHandleForVariable(this.output.getName(), out);

        // Record every matrix input as a lineage parent of the output RDD.
        for (CPOperand in : new CPOperand[]{this.input1, this.input2, this.input3}) {
            if (in.isMatrix()) {
                sec.addLineageRDD(this.output.getName(), in.getName());
            }
        }
    }

    /**
     * Copies data characteristics from the matrix inputs to the output.
     *
     * <p>Only matrix inputs whose dimensions are known are used. They are visited in
     * the order input1, input2, input3 and each one overwrites the output
     * characteristics, so the last qualifying input wins.
     *
     * @param sec the Spark execution context holding the variables' characteristics
     */
    protected void updateTernaryOutputDataCharacteristics(SparkExecutionContext sec) {
        DataCharacteristics dcOut = sec.getDataCharacteristics(this.output.getName());
        for (CPOperand in : new CPOperand[]{this.input1, this.input2, this.input3}) {
            if (!in.isMatrix()) {
                continue;
            }
            DataCharacteristics dcIn = sec.getDataCharacteristics(in.getName());
            if (dcIn.dimsKnown()) {
                dcOut.set(dcIn);
            }
        }
    }

    /** Returns the binary-block RDD handle of a matrix operand, or {@code null} for a non-matrix operand. */
    private static JavaPairRDD<MatrixIndexes, MatrixBlock> matrixHandleOrNull(SparkExecutionContext sec, CPOperand operand) {
        if (!operand.isMatrix()) {
            return null;
        }
        return sec.getBinaryMatrixBlockRDDHandleForVariable(operand.getName());
    }

    /** Returns a block built from the scalar operand's double value, or {@code null} for a matrix operand. */
    private static MatrixBlock scalarBlockOrNull(ExecutionContext ec, CPOperand operand) {
        if (operand.isMatrix()) {
            return null;
        }
        return new MatrixBlock(ec.getScalarInput(operand).getDoubleValue());
    }

    /**
     * Serializable base of all ternary functions. It holds the operator and the
     * blocks of the scalar operands; {@code _mX} is {@code null} when operand X is a
     * matrix, in which case its blocks arrive through the RDD instead.
     */
    public abstract static class TernaryFunction implements Serializable {
        private static final long serialVersionUID = 8345737737972434426L;
        protected final TernaryOperator _op;
        protected final MatrixBlock _m1;
        protected final MatrixBlock _m2;
        protected final MatrixBlock _m3;

        public TernaryFunction(TernaryOperator op, MatrixBlock m1, MatrixBlock m2, MatrixBlock m3) {
            this._op = op;
            this._m1 = m1;
            this._m2 = m2;
            this._m3 = m3;
        }
    }

    /** Operand 1 is a matrix (RDD value); operands 2 and 3 are scalar blocks. */
    private static class TernaryFunctionMSS extends TernaryFunction implements Function<MatrixBlock, MatrixBlock> {
        private static final long serialVersionUID = 1L;

        public TernaryFunctionMSS(TernaryOperator op, MatrixBlock m1, MatrixBlock m2, MatrixBlock m3) {
            super(op, m1, m2, m3);
        }

        @Override
        public MatrixBlock call(MatrixBlock in1) throws Exception {
            return in1.ternaryOperations(this._op, this._m2, this._m3, new MatrixBlock());
        }
    }

    /** Operand 2 is a matrix (RDD value); operands 1 and 3 are scalar blocks. */
    private static class TernaryFunctionSMS extends TernaryFunction implements Function<MatrixBlock, MatrixBlock> {
        private static final long serialVersionUID = 1L;

        public TernaryFunctionSMS(TernaryOperator op, MatrixBlock m1, MatrixBlock m2, MatrixBlock m3) {
            super(op, m1, m2, m3);
        }

        @Override
        public MatrixBlock call(MatrixBlock in2) throws Exception {
            return this._m1.ternaryOperations(this._op, in2, this._m3, new MatrixBlock());
        }
    }

    /** Operand 3 is a matrix (RDD value); operands 1 and 2 are scalar blocks. */
    private static class TernaryFunctionSSM extends TernaryFunction implements Function<MatrixBlock, MatrixBlock> {
        private static final long serialVersionUID = 1L;

        public TernaryFunctionSSM(TernaryOperator op, MatrixBlock m1, MatrixBlock m2, MatrixBlock m3) {
            super(op, m1, m2, m3);
        }

        @Override
        public MatrixBlock call(MatrixBlock in3) throws Exception {
            return this._m1.ternaryOperations(this._op, this._m2, in3, new MatrixBlock());
        }
    }

    /** Operands 1 and 2 are matrices (joined pair); operand 3 is a scalar block. */
    private static class TernaryFunctionMMS extends TernaryFunction implements Function<Tuple2<MatrixBlock, MatrixBlock>, MatrixBlock> {
        private static final long serialVersionUID = 1L;

        public TernaryFunctionMMS(TernaryOperator op, MatrixBlock m1, MatrixBlock m2, MatrixBlock m3) {
            super(op, m1, m2, m3);
        }

        @Override
        public MatrixBlock call(Tuple2<MatrixBlock, MatrixBlock> pair) throws Exception {
            MatrixBlock in1 = pair._1();
            MatrixBlock in2 = pair._2();
            return in1.ternaryOperations(this._op, in2, this._m3, new MatrixBlock());
        }
    }

    /** Operands 1 and 3 are matrices (joined pair); operand 2 is a scalar block. */
    private static class TernaryFunctionMSM extends TernaryFunction implements Function<Tuple2<MatrixBlock, MatrixBlock>, MatrixBlock> {
        private static final long serialVersionUID = 1L;

        public TernaryFunctionMSM(TernaryOperator op, MatrixBlock m1, MatrixBlock m2, MatrixBlock m3) {
            super(op, m1, m2, m3);
        }

        @Override
        public MatrixBlock call(Tuple2<MatrixBlock, MatrixBlock> pair) throws Exception {
            MatrixBlock in1 = pair._1();
            MatrixBlock in3 = pair._2();
            return in1.ternaryOperations(this._op, this._m2, in3, new MatrixBlock());
        }
    }

    /** Operands 2 and 3 are matrices (joined pair); operand 1 is a scalar block. */
    private static class TernaryFunctionSMM extends TernaryFunction implements Function<Tuple2<MatrixBlock, MatrixBlock>, MatrixBlock> {
        private static final long serialVersionUID = 1L;

        public TernaryFunctionSMM(TernaryOperator op, MatrixBlock m1, MatrixBlock m2, MatrixBlock m3) {
            super(op, m1, m2, m3);
        }

        @Override
        public MatrixBlock call(Tuple2<MatrixBlock, MatrixBlock> pair) throws Exception {
            MatrixBlock in2 = pair._1();
            MatrixBlock in3 = pair._2();
            return this._m1.ternaryOperations(this._op, in2, in3, new MatrixBlock());
        }
    }

    /** All three operands are matrices; the value is the nested pair produced by two joins. */
    private static class TernaryFunctionMMM extends TernaryFunction implements Function<Tuple2<Tuple2<MatrixBlock, MatrixBlock>, MatrixBlock>, MatrixBlock> {
        private static final long serialVersionUID = 1L;

        public TernaryFunctionMMM(TernaryOperator op, MatrixBlock m1, MatrixBlock m2, MatrixBlock m3) {
            super(op, m1, m2, m3);
        }

        @Override
        public MatrixBlock call(Tuple2<Tuple2<MatrixBlock, MatrixBlock>, MatrixBlock> nested) throws Exception {
            // The generic signature already fixes the element types; no raw casts needed.
            MatrixBlock in1 = nested._1()._1();
            MatrixBlock in2 = nested._1()._2();
            MatrixBlock in3 = nested._2();
            return in1.ternaryOperations(this._op, in2, in3, new MatrixBlock());
        }
    }
}
