package org.apache.sysds.runtime.instructions.spark;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.sysds.common.Types;
import org.apache.sysds.lops.BinaryM;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.data.TensorBlock;
import org.apache.sysds.runtime.data.TensorIndexes;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.SPInstruction;
import org.apache.sysds.runtime.instructions.spark.data.PartitionedBroadcast;
import org.apache.sysds.runtime.instructions.spark.functions.MatrixMatrixBinaryOpFunction;
import org.apache.sysds.runtime.instructions.spark.functions.MatrixScalarUnaryFunction;
import org.apache.sysds.runtime.instructions.spark.functions.MatrixVectorBinaryOpPartitionFunction;
import org.apache.sysds.runtime.instructions.spark.functions.OuterVectorBinaryOpFunction;
import org.apache.sysds.runtime.instructions.spark.functions.ReblockTensorFunction;
import org.apache.sysds.runtime.instructions.spark.functions.ReplicateTensorFunction;
import org.apache.sysds.runtime.instructions.spark.functions.ReplicateVectorFunction;
import org.apache.sysds.runtime.instructions.spark.functions.TensorTensorBinaryOpFunction;
import org.apache.sysds.runtime.instructions.spark.functions.TensorTensorBinaryOpPartitionFunction;
import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MetaDataUtils;

/**
 * Base class of the Spark binary instructions.
 * <p>
 * Provides the instruction parser that dispatches to the concrete subclass for the operand data
 * types, shared execution helpers for matrix-matrix (join-based), matrix-vector (broadcast-based),
 * matrix-scalar and tensor-tensor operations, and the dimension/blocksize sanity checks used by
 * binary and append instructions.
 */
public abstract class BinarySPInstruction extends ComputationSPInstruction {

    /**
     * Creates a binary Spark instruction; all arguments are forwarded to the superclass.
     *
     * @param type   Spark instruction type
     * @param op     operator applied to the inputs
     * @param in1    first input operand
     * @param in2    second input operand
     * @param out    output operand
     * @param opcode instruction opcode
     * @param istr   instruction string
     */
    public BinarySPInstruction(SPInstruction.SPType type, Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
        super(type, op, in1, in2, out, opcode, istr);
    }

    /**
     * Parses a Spark binary instruction string and creates the matching concrete instruction.
     * <p>
     * Instructions starting with {@code "SPARK°map"} are parsed as broadcast (map-side)
     * instructions whose vector type is read from field 5; all others are parsed by
     * {@link #parseBinaryInstruction(String, CPOperand, CPOperand, CPOperand)}.
     *
     * @param str instruction string
     * @return the concrete instruction, or {@code null} if the operand data-type combination is not
     *         handled here (e.g. scalar-scalar or scalar-frame)
     * @throws DMLRuntimeException if a tensor is combined with a non-tensor, non-matrix operand
     */
    public static BinarySPInstruction parseInstruction(String str) {
        CPOperand in1 = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN);
        CPOperand in2 = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN);
        CPOperand out = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN);
        String opcode;
        boolean isBroadcast = false;
        BinaryM.VectorType vectorType = null;

        if (str.startsWith("SPARK°map")) {
            // broadcast variant: operands in fields 1-3, vector type in field 5 (field 4 is not read here)
            String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
            InstructionUtils.checkNumFields(parts, 5);
            opcode = parts[0];
            in1.split(parts[1]);
            in2.split(parts[2]);
            out.split(parts[3]);
            vectorType = BinaryM.VectorType.valueOf(parts[5]);
            isBroadcast = true;
        } else {
            opcode = parseBinaryInstruction(str, in1, in2, out);
        }

        Types.DataType dt1 = in1.getDataType();
        Types.DataType dt2 = in2.getDataType();
        Operator operator = InstructionUtils.parseExtendedBinaryOrBuiltinOperator(opcode, in1, in2);

        if (dt1 == Types.DataType.MATRIX || dt2 == Types.DataType.MATRIX) {
            if (dt1 == Types.DataType.MATRIX && dt2 == Types.DataType.MATRIX) {
                if (isBroadcast) {
                    return new BinaryMatrixBVectorSPInstruction(operator, in1, in2, out, vectorType, opcode, str);
                }
                return new BinaryMatrixMatrixSPInstruction(operator, in1, in2, out, opcode, str);
            }
            if (dt1 == Types.DataType.FRAME && dt2 == Types.DataType.MATRIX) {
                return new BinaryFrameMatrixSPInstruction(operator, in1, in2, out, opcode, str);
            }
            // NOTE(review): every other combination with one matrix operand (including
            // matrix-frame and tensor-matrix) is treated as matrix-scalar — confirm this is intended.
            return new BinaryMatrixScalarSPInstruction(operator, in1, in2, out, opcode, str);
        }

        if (dt1 == Types.DataType.TENSOR || dt2 == Types.DataType.TENSOR) {
            if (dt1 == Types.DataType.TENSOR && dt2 == Types.DataType.TENSOR) {
                if (isBroadcast) {
                    return new BinaryTensorTensorBroadcastSPInstruction(operator, in1, in2, out, opcode, str);
                }
                return new BinaryTensorTensorSPInstruction(operator, in1, in2, out, opcode, str);
            }
            throw new DMLRuntimeException("Tensor binary operation not yet implemented for tensor-scalar, or tensor-matrix");
        }

        if (dt1 == Types.DataType.FRAME && dt2 == Types.DataType.FRAME) {
            return new BinaryFrameFrameSPInstruction(operator, in1, in2, out, opcode, str);
        }
        if (dt1 == Types.DataType.FRAME && dt2 == Types.DataType.SCALAR) {
            return new BinaryFrameScalarSPInstruction(operator, in1, in2, out, opcode, str);
        }
        // unsupported data-type combination (e.g. scalar-scalar, scalar-frame)
        return null;
    }

    /**
     * Splits an instruction string with two inputs and one output into its operands.
     *
     * @param str instruction string
     * @param in1 receives the first input operand (field 1)
     * @param in2 receives the second input operand (field 2)
     * @param out receives the output operand (field 3)
     * @return the opcode (field 0)
     */
    public static String parseBinaryInstruction(String str, CPOperand in1, CPOperand in2, CPOperand out) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        InstructionUtils.checkNumFields(parts, 3);
        in1.split(parts[1]);
        in2.split(parts[2]);
        out.split(parts[3]);
        return parts[0];
    }

    /**
     * Splits an instruction string with three inputs and one output into its operands.
     *
     * @param str instruction string
     * @param in1 receives the first input operand (field 1)
     * @param in2 receives the second input operand (field 2)
     * @param in3 receives the third input operand (field 3)
     * @param out receives the output operand (field 4)
     * @return the opcode (field 0)
     */
    public static String parseBinaryInstruction(String str, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        InstructionUtils.checkNumFields(parts, 4);
        in1.split(parts[1]);
        in2.split(parts[2]);
        in3.split(parts[3]);
        out.split(parts[4]);
        return parts[0];
    }

    /**
     * Executes a matrix-matrix binary operation by joining the two block RDDs on their block
     * indexes, replicating vector inputs first so they match the other input's block layout.
     *
     * @param executionContext execution context; must be a {@link SparkExecutionContext}
     * @throws DMLRuntimeException if the input dimensions are unknown or incompatible
     */
    public void processMatrixMatrixBinaryInstruction(ExecutionContext executionContext) {
        SparkExecutionContext sec = (SparkExecutionContext) executionContext;

        // sanity-check input dimensions, then set up the output characteristics
        checkMatrixMatrixBinaryCharacteristics(sec);
        updateBinaryOutputDataCharacteristics(sec);

        String name1 = input1.getName();
        String name2 = input2.getName();
        String outName = output.getName();
        JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec.getBinaryMatrixBlockRDDHandleForVariable(name1);
        JavaPairRDD<MatrixIndexes, MatrixBlock> in2 = sec.getBinaryMatrixBlockRDDHandleForVariable(name2);
        DataCharacteristics mc1 = sec.getDataCharacteristics(name1);
        DataCharacteristics mc2 = sec.getDataCharacteristics(name2);
        DataCharacteristics mcOut = sec.getDataCharacteristics(outName);
        BinaryOperator bop = (BinaryOperator) _optr;

        // replicate vector inputs so each block of the other input finds a join partner
        boolean rowVector = mc2.getRows() == 1 && mc1.getRows() > 1;
        long numRepLeft = getNumReplicas(mc1, mc2, true);
        long numRepRight = getNumReplicas(mc1, mc2, false);
        if (numRepLeft > 1) {
            in1 = in1.flatMapToPair(new ReplicateVectorFunction(false, numRepLeft));
        }
        if (numRepRight > 1) {
            in2 = in2.flatMapToPair(new ReplicateVectorFunction(rowVector, numRepRight));
        }

        // join on block index and apply the operator to each pair of blocks
        JavaPairRDD<?, ?> out = in1.join(in2, getJoinNumPartitions(in1, in2, mcOut))
            .mapValues(new MatrixMatrixBinaryOpFunction(bop));

        sec.setRDDHandleForVariable(outName, out);
        sec.addLineageRDD(outName, name1);
        sec.addLineageRDD(outName, name2);
    }

    /**
     * Executes a tensor-tensor binary operation by joining the two block RDDs on their block
     * indexes. If the second tensor has fewer dimensions, it is reblocked to the first tensor's
     * dimensionality. The second input is then replicated along every dimension it lacks or has
     * size 1 (while the first input is larger).
     *
     * @param executionContext execution context; must be a {@link SparkExecutionContext}
     * @throws DMLRuntimeException if the input dimensions are unknown or incompatible
     */
    public void processTensorTensorBinaryInstruction(ExecutionContext executionContext) {
        SparkExecutionContext sec = (SparkExecutionContext) executionContext;

        checkTensorTensorBinaryCharacteristics(sec);
        updateBinaryTensorOutputDataCharacteristics(sec);

        String name1 = input1.getName();
        String name2 = input2.getName();
        String outName = output.getName();
        JavaPairRDD<TensorIndexes, TensorBlock> in1 = sec.getBinaryTensorBlockRDDHandleForVariable(name1);
        JavaPairRDD<TensorIndexes, TensorBlock> in2 = sec.getBinaryTensorBlockRDDHandleForVariable(name2);
        DataCharacteristics dc1 = sec.getDataCharacteristics(name1);
        DataCharacteristics dc2 = sec.getDataCharacteristics(name2);
        DataCharacteristics dcOut = sec.getDataCharacteristics(outName);
        BinaryOperator bop = (BinaryOperator) _optr;

        if (dc2.getNumDims() < dc1.getNumDims()) {
            in2 = in2.flatMapToPair(new ReblockTensorFunction(dc1.getNumDims(), dc1.getBlocksize()));
        }
        for (int dim = 0; dim < dc1.getNumDims(); dim++) {
            long numReps = getNumDimReplicas(dc1, dc2, dim);
            if (numReps > 1) {
                in2 = in2.flatMapToPair(new ReplicateTensorFunction(dim, numReps));
            }
        }

        JavaPairRDD<?, ?> out = in1.join(in2, getJoinNumPartitions(in1, in2, dcOut))
            .mapValues(new TensorTensorBinaryOpFunction(bop));

        sec.setRDDHandleForVariable(outName, out);
        sec.addLineageRDD(outName, name1);
        sec.addLineageRDD(outName, name2);
    }

    /**
     * Executes a matrix-vector binary operation where the first input is an RDD and the second
     * is a partitioned broadcast. A column vector (rows &gt; 1, cols == 1) combined with a row
     * vector (rows == 1, cols &gt; 1) uses the outer-operation function. All other combinations
     * are processed per partition with the given vector type.
     *
     * @param executionContext execution context; must be a {@link SparkExecutionContext}
     * @param vectorType       vector type passed to the partition function
     * @throws DMLRuntimeException if the input dimensions are unknown or incompatible
     */
    public void processMatrixBVectorBinaryInstruction(ExecutionContext executionContext, BinaryM.VectorType vectorType) {
        SparkExecutionContext sec = (SparkExecutionContext) executionContext;
        checkMatrixMatrixBinaryCharacteristics(sec);

        String rddVar = input1.getName();
        String bcastVar = input2.getName();
        JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec.getBinaryMatrixBlockRDDHandleForVariable(rddVar);
        PartitionedBroadcast<MatrixBlock> in2 = sec.getBroadcastForVariable(bcastVar);
        DataCharacteristics mc1 = sec.getDataCharacteristics(rddVar);
        DataCharacteristics mc2 = sec.getDataCharacteristics(bcastVar);
        BinaryOperator bop = (BinaryOperator) _optr;

        boolean isOuter = mc1.getRows() > 1 && mc1.getCols() == 1
            && mc2.getRows() == 1 && mc2.getCols() > 1;
        JavaPairRDD<?, ?> out;
        if (isOuter) {
            out = in1.flatMapToPair(new OuterVectorBinaryOpFunction(bop, in2));
        } else {
            out = in1.mapPartitionsToPair(new MatrixVectorBinaryOpPartitionFunction(bop, in2, vectorType), true);
        }

        updateBinaryOutputDataCharacteristics(sec);
        sec.setRDDHandleForVariable(output.getName(), out);
        sec.addLineageRDD(output.getName(), rddVar);
        sec.addLineageBroadcast(output.getName(), bcastVar);
    }

    /**
     * Executes a tensor-tensor binary operation where the first input is an RDD and the second
     * is a partitioned broadcast; the broadcast input's size-1 dimensions are flagged for
     * replication.
     *
     * @param executionContext execution context; must be a {@link SparkExecutionContext}
     * @throws DMLRuntimeException if the input dimensions are unknown or incompatible
     */
    public void processTensorTensorBroadcastBinaryInstruction(ExecutionContext executionContext) {
        SparkExecutionContext sec = (SparkExecutionContext) executionContext;
        checkTensorTensorBinaryCharacteristics(sec);

        String rddVar = input1.getName();
        String bcastVar = input2.getName();
        JavaPairRDD<TensorIndexes, TensorBlock> in1 = sec.getBinaryTensorBlockRDDHandleForVariable(rddVar);
        // give the broadcast input's characteristics the RDD input's blocksize
        // (setBlocksize is called on the object returned by the context — presumably in place)
        DataCharacteristics dc2 = sec.getDataCharacteristics(bcastVar)
            .setBlocksize(sec.getDataCharacteristics(rddVar).getBlocksize());
        PartitionedBroadcast<TensorBlock> in2 = sec.getBroadcastForTensorVariable(bcastVar);
        BinaryOperator bop = (BinaryOperator) _optr;

        boolean[] replicateDim = new boolean[dc2.getNumDims()];
        for (int dim = 0; dim < replicateDim.length; dim++) {
            replicateDim[dim] = dc2.getDim(dim) == 1;
        }
        JavaPairRDD<?, ?> out = in1.mapPartitionsToPair(new TensorTensorBinaryOpPartitionFunction(bop, in2, replicateDim), true);

        updateBinaryTensorOutputDataCharacteristics(sec);
        sec.setRDDHandleForVariable(output.getName(), out);
        sec.addLineageRDD(output.getName(), rddVar);
        sec.addLineageBroadcast(output.getName(), bcastVar);
    }

    /**
     * Executes a matrix-scalar binary operation. The matrix may be either input; the other input
     * is read as a scalar and bound as the constant of the scalar operator.
     *
     * @param executionContext execution context; must be a {@link SparkExecutionContext}
     */
    public void processMatrixScalarBinaryInstruction(ExecutionContext executionContext) {
        SparkExecutionContext sec = (SparkExecutionContext) executionContext;

        boolean matrixFirst = input1.getDataType() == Types.DataType.MATRIX;
        String rddVar = matrixFirst ? input1.getName() : input2.getName();
        CPOperand scalar = matrixFirst ? input2 : input1;

        JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec.getBinaryMatrixBlockRDDHandleForVariable(rddVar);
        ScalarOperator sop = (ScalarOperator) _optr;
        double constant = executionContext.getScalarInput(scalar).getDoubleValue();
        JavaPairRDD<?, ?> out = in1.mapValues(new MatrixScalarUnaryFunction(sop.setConstant(constant)));

        updateUnaryOutputDataCharacteristics(sec, rddVar, output.getName());
        sec.setRDDHandleForVariable(output.getName(), out);
        sec.addLineageRDD(output.getName(), rddVar);
    }

    /**
     * Infers the output characteristics of a matrix-multiply-like operation (rows of input 1 by
     * columns of input 2) if the output dimensions are not yet known.
     *
     * @param sparkExecutionContext Spark execution context
     * @param checkCommonDim        whether input 1's columns must equal input 2's rows
     * @return the (possibly updated) output characteristics
     * @throws DMLRuntimeException if the output must be inferred but the input dimensions are
     *                             unknown, the blocksizes differ, or the common dimension mismatches
     */
    public DataCharacteristics updateBinaryMMOutputDataCharacteristics(SparkExecutionContext sparkExecutionContext, boolean checkCommonDim) {
        DataCharacteristics mc1 = sparkExecutionContext.getDataCharacteristics(input1.getName());
        DataCharacteristics mc2 = sparkExecutionContext.getDataCharacteristics(input2.getName());
        DataCharacteristics mcOut = sparkExecutionContext.getDataCharacteristics(output.getName());
        if (mcOut.dimsKnown()) {
            return mcOut;
        }
        if (!mc1.dimsKnown() || !mc2.dimsKnown()) {
            throw new DMLRuntimeException("The output dimensions are not specified and cannot be inferred from inputs.");
        }
        if (mc1.getBlocksize() != mc2.getBlocksize()) {
            throw new DMLRuntimeException("Incompatible block sizes for BinarySPInstruction.");
        }
        if (checkCommonDim && mc1.getCols() != mc2.getRows()) {
            throw new DMLRuntimeException("Incompatible dimensions for BinarySPInstruction");
        }
        mcOut.set(mc1.getRows(), mc2.getCols(), mc1.getBlocksize(), mc1.getBlocksize());
        return mcOut;
    }

    /**
     * Updates the output characteristics of an append via
     * {@code MetaDataUtils.updateAppendDataCharacteristics}, and sets the output non-zero count
     * to the sum of the inputs' counts if it is unknown and both input counts are known.
     *
     * @param sparkExecutionContext Spark execution context
     * @param cbind                 append direction flag forwarded to {@code MetaDataUtils}
     */
    public void updateBinaryAppendOutputDataCharacteristics(SparkExecutionContext sparkExecutionContext, boolean cbind) {
        DataCharacteristics mc1 = sparkExecutionContext.getDataCharacteristics(input1.getName());
        DataCharacteristics mc2 = sparkExecutionContext.getDataCharacteristics(input2.getName());
        DataCharacteristics mcOut = sparkExecutionContext.getDataCharacteristics(output.getName());
        MetaDataUtils.updateAppendDataCharacteristics(mc1, mc2, mcOut, cbind);
        if (!mcOut.nnzKnown() && mc1.nnzKnown() && mc2.nnzKnown()) {
            mcOut.setNonZeros(mc1.getNonZeros() + mc2.getNonZeros());
        }
    }

    /**
     * Returns how many times one input's blocks must be replicated before the block-wise join.
     * <ul>
     * <li>Left input: if it is a single column (outer operation), it is replicated once per
     * column block of the right input.</li>
     * <li>Right input: a row vector (with a multi-row left input) is replicated once per row
     * block of the left input.</li>
     * <li>Right input: a column vector (with a multi-column left input) is replicated once per
     * column block of the left input.</li>
     * </ul>
     *
     * @param mc1  characteristics of the left input
     * @param mc2  characteristics of the right input
     * @param left {@code true} to compute replicas of the left input, {@code false} for the right
     * @return number of replicas, 1 if no replication is needed
     */
    protected long getNumReplicas(DataCharacteristics mc1, DataCharacteristics mc2, boolean left) {
        if (left) {
            return mc1.getCols() == 1 ? mc2.getNumColBlocks() : 1L;
        }
        if (mc2.getRows() == 1 && mc1.getRows() > 1) {
            return mc1.getNumRowBlocks();
        }
        if (mc2.getCols() == 1 && mc1.getCols() > 1) {
            // the column vector itself has a single column block; it must be replicated
            // across the column blocks of the LEFT input, not its own
            return mc1.getNumColBlocks();
        }
        return 1L;
    }

    /**
     * Returns how many times the right tensor must be replicated along one dimension. This is
     * the left tensor's block count in that dimension when the right tensor either lacks the
     * dimension or has size 1 there while the left tensor is larger. Otherwise it is 1.
     *
     * @param dc1 characteristics of the left tensor
     * @param dc2 characteristics of the right tensor
     * @param dim dimension index
     * @return number of replicas along {@code dim}
     */
    protected long getNumDimReplicas(DataCharacteristics dc1, DataCharacteristics dc2, int dim) {
        if (dim >= dc2.getNumDims() || (dc2.getDim(dim) == 1 && dc1.getDim(dim) > 1)) {
            return dc1.getNumBlocks(dim);
        }
        return 1L;
    }

    /**
     * Checks that the inputs of a matrix-matrix binary operation have known and compatible
     * dimensions and equal blocksizes. Accepted shapes are:
     * <ul>
     * <li>equal dimensions;</li>
     * <li>a column vector with matching rows;</li>
     * <li>a row vector with matching columns;</li>
     * <li>a single-column left input with a single-row right input (outer).</li>
     * </ul>
     *
     * @param sparkExecutionContext Spark execution context
     * @throws DMLRuntimeException if any check fails
     */
    protected void checkMatrixMatrixBinaryCharacteristics(SparkExecutionContext sparkExecutionContext) {
        DataCharacteristics mc1 = sparkExecutionContext.getDataCharacteristics(input1.getName());
        DataCharacteristics mc2 = sparkExecutionContext.getDataCharacteristics(input2.getName());
        String dims = "[" + mc1.getRows() + "x" + mc1.getCols() + " vs " + mc2.getRows() + "x" + mc2.getCols() + "]";

        if (!mc1.dimsKnown() || !mc2.dimsKnown()) {
            throw new DMLRuntimeException("Unknown dimensions matrix-matrix binary operations: " + dims);
        }

        boolean sameDims = mc1.getRows() == mc2.getRows() && mc1.getCols() == mc2.getCols();
        boolean colVector = mc1.getRows() == mc2.getRows() && mc2.getCols() == 1;
        boolean rowVector = mc1.getCols() == mc2.getCols() && mc2.getRows() == 1;
        boolean outer = mc1.getCols() == 1 && mc2.getRows() == 1;
        if (!(sameDims || colVector || rowVector || outer)) {
            throw new DMLRuntimeException("Dimensions mismatch matrix-matrix binary operations: " + dims);
        }

        if (mc1.getBlocksize() != mc2.getBlocksize()) {
            throw new DMLRuntimeException("Blocksize mismatch matrix-matrix binary operations: [" + mc1.getBlocksize() + "x" + mc1.getBlocksize() + " vs " + mc2.getBlocksize() + "x" + mc2.getBlocksize() + "]");
        }
    }

    /**
     * Checks that the inputs of a tensor-tensor binary operation have known dimensions. The right
     * tensor must have no more dimensions than the left. Each of its dimensions must either equal
     * the left tensor's dimension or be 1.
     *
     * @param sparkExecutionContext Spark execution context
     * @throws DMLRuntimeException if any check fails
     */
    protected void checkTensorTensorBinaryCharacteristics(SparkExecutionContext sparkExecutionContext) {
        DataCharacteristics dc1 = sparkExecutionContext.getDataCharacteristics(input1.getName());
        DataCharacteristics dc2 = sparkExecutionContext.getDataCharacteristics(input2.getName());
        if (!dc1.dimsKnown() || !dc2.dimsKnown()) {
            throw new DMLRuntimeException("Unknown dimensions tensor-tensor binary operations");
        }
        boolean mismatch = dc1.getNumDims() < dc2.getNumDims();
        for (int dim = 0; !mismatch && dim < dc2.getNumDims(); dim++) {
            mismatch = dc1.getDim(dim) != dc2.getDim(dim) && dc2.getDim(dim) != 1;
        }
        if (mismatch) {
            throw new DMLRuntimeException("Dimensions mismatch tensor-tensor binary operations");
        }
    }

    /**
     * Validates the inputs of an append instruction.
     *
     * @param sparkExecutionContext Spark execution context
     * @param cbind                 {@code true} for cbind (rows must match), {@code false} for
     *                              rbind (columns must match)
     * @param checkSingleBlk        if set, the combined column count must fit within one blocksize
     * @param checkAligned          if set, the first input's appended dimension must be a multiple
     *                              of the blocksize
     * @throws DMLRuntimeException if any check fails
     */
    public void checkBinaryAppendInputCharacteristics(SparkExecutionContext sparkExecutionContext, boolean cbind, boolean checkSingleBlk, boolean checkAligned) {
        DataCharacteristics mc1 = sparkExecutionContext.getDataCharacteristics(input1.getName());
        DataCharacteristics mc2 = sparkExecutionContext.getDataCharacteristics(input2.getName());
        if (!mc1.dimsKnown() || !mc2.dimsKnown()) {
            throw new DMLRuntimeException("The dimensions unknown for inputs");
        }
        if (cbind && mc1.getRows() != mc2.getRows()) {
            throw new DMLRuntimeException("The number of rows of inputs should match for append-cbind instruction");
        }
        if (!cbind && mc1.getCols() != mc2.getCols()) {
            throw new DMLRuntimeException("The number of columns of inputs should match for append-rbind instruction");
        }
        if (mc1.getBlocksize() != mc2.getBlocksize()) {
            throw new DMLRuntimeException("The block sizes do not match for input matrices");
        }
        if (checkSingleBlk && mc1.getCols() + mc2.getCols() > mc1.getBlocksize()) {
            throw new DMLRuntimeException("Output must have at most one column block");
        }
        if (checkAligned) {
            long appendedDim = cbind ? mc1.getCols() : mc1.getRows();
            if (appendedDim % mc1.getBlocksize() != 0) {
                throw new DMLRuntimeException("Input matrices are not aligned to blocksize boundaries. Wrong append selected");
            }
        }
    }

    /**
     * Chooses the partition count for joining two block RDDs. If an input is already
     * hash-partitioned, its partition count is reused (the first input takes precedence).
     * Otherwise the sum of both inputs' partition counts is used, capped at twice the preferred
     * partition count for the output.
     */
    private static int getJoinNumPartitions(JavaPairRDD<?, ?> in1, JavaPairRDD<?, ?> in2, DataCharacteristics dcOut) {
        if (SparkUtils.isHashPartitioned(in1)) {
            return in1.getNumPartitions();
        }
        if (SparkUtils.isHashPartitioned(in2)) {
            return in2.getNumPartitions();
        }
        return Math.min(in1.getNumPartitions() + in2.getNumPartitions(), 2 * SparkUtils.getNumPreferredPartitions(dcOut));
    }
}
