package org.apache.sysds.runtime.instructions.spark;

import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.functionobjects.IndexFunction;
import org.apache.sysds.runtime.functionobjects.ReduceAll;
import org.apache.sysds.runtime.functionobjects.ReduceCol;
import org.apache.sysds.runtime.functionobjects.ReduceRow;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.SPInstruction;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.lineage.LineageItemUtils;
import org.apache.sysds.runtime.lineage.LineageTraceable;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.meta.DataCharacteristics;

/* loaded from: input_file:org/apache/sysds/runtime/instructions/spark/ComputationSPInstruction.class */
/**
 * Base class for Spark instructions that carry up to three input operands and
 * one output operand.
 *
 * <p>Besides holding the operands, it offers helpers that fill in the output's
 * {@link DataCharacteristics} from the inputs. All of these helpers follow the
 * same contract: if the output dimensions are already known they leave the
 * output untouched; otherwise they require the dimensions of {@code input1} to
 * be known and throw a {@link DMLRuntimeException} if they are not.
 */
public abstract class ComputationSPInstruction extends SPInstruction implements LineageTraceable {
    /** Output operand; its name is what {@link #getOutputVariableName()} returns. */
    public CPOperand output;
    /** First input operand. */
    public CPOperand input1;
    /** Second input operand. */
    public CPOperand input2;
    /** Third input operand; {@code null} when the two-input constructor is used. */
    public CPOperand input3;

    /**
     * Creates an instruction with two inputs and one output; {@link #input3} is set to {@code null}.
     *
     * @param sPType    forwarded unchanged to the {@link SPInstruction} constructor
     * @param operator  forwarded unchanged to the {@link SPInstruction} constructor
     * @param cPOperand stored as {@link #input1}
     * @param cPOperand2 stored as {@link #input2}
     * @param cPOperand3 stored as {@link #output}
     * @param str       forwarded unchanged to the {@link SPInstruction} constructor
     * @param str2      forwarded unchanged to the {@link SPInstruction} constructor
     */
    public ComputationSPInstruction(SPInstruction.SPType sPType, Operator operator, CPOperand cPOperand, CPOperand cPOperand2, CPOperand cPOperand3, String str, String str2) {
        super(sPType, operator, str, str2);
        this.input1 = cPOperand;
        this.input2 = cPOperand2;
        this.input3 = null;
        this.output = cPOperand3;
    }

    /**
     * Creates an instruction with three inputs and one output.
     *
     * @param sPType    forwarded unchanged to the {@link SPInstruction} constructor
     * @param operator  forwarded unchanged to the {@link SPInstruction} constructor
     * @param cPOperand stored as {@link #input1}
     * @param cPOperand2 stored as {@link #input2}
     * @param cPOperand3 stored as {@link #input3}
     * @param cPOperand4 stored as {@link #output}
     * @param str       forwarded unchanged to the {@link SPInstruction} constructor
     * @param str2      forwarded unchanged to the {@link SPInstruction} constructor
     */
    public ComputationSPInstruction(SPInstruction.SPType sPType, Operator operator, CPOperand cPOperand, CPOperand cPOperand2, CPOperand cPOperand3, CPOperand cPOperand4, String str, String str2) {
        super(sPType, operator, str, str2);
        this.input1 = cPOperand;
        this.input2 = cPOperand2;
        this.input3 = cPOperand3;
        this.output = cPOperand4;
    }

    /**
     * Returns the name of the output operand.
     *
     * @return {@code output.getName()}
     */
    public String getOutputVariableName() {
        return this.output.getName();
    }

    /**
     * Infers the output data characteristics from {@link #input1}, using the
     * names of {@link #input1} and {@link #output} as the variables to look up.
     *
     * @param sparkExecutionContext context used to look up the data characteristics
     * @throws DMLRuntimeException if neither the output nor the input dimensions are known
     */
    public void updateUnaryOutputDataCharacteristics(SparkExecutionContext sparkExecutionContext) {
        updateUnaryOutputDataCharacteristics(sparkExecutionContext, this.input1.getName(), this.output.getName());
    }

    /**
     * Copies rows, columns and block size from the named input variable to the
     * named output variable, unless the output dimensions are already known.
     *
     * @param sparkExecutionContext context used to look up the data characteristics
     * @param str  name of the input variable
     * @param str2 name of the output variable
     * @throws DMLRuntimeException if neither the output nor the input dimensions are known
     */
    public void updateUnaryOutputDataCharacteristics(SparkExecutionContext sparkExecutionContext, String str, String str2) {
        DataCharacteristics inDc = sparkExecutionContext.getDataCharacteristics(str);
        DataCharacteristics outDc = sparkExecutionContext.getDataCharacteristics(str2);
        if (outDc.dimsKnown()) {
            return;
        }
        if (!inDc.dimsKnown()) {
            throw new DMLRuntimeException("The output dimensions are not specified and cannot be inferred from input:" + inDc.toString() + " " + outDc.toString());
        }
        // The input block size is passed for both block-size arguments.
        outDc.set(inDc.getRows(), inDc.getCols(), inDc.getBlocksize(), inDc.getBlocksize());
    }

    /**
     * Infers the output data characteristics of a binary operation, unless the
     * output dimensions are already known.
     *
     * <p>When {@link #input1} is a column vector (rows &gt; 1, cols == 1) and
     * {@link #input2} is a row vector (rows == 1, cols &gt; 1), the output gets
     * the rows of input1 and the columns of input2. In every other case the
     * output takes rows, columns and block size from input1 alone. Only the
     * dimensions of input1 are checked for being known.
     *
     * @param sparkExecutionContext context used to look up the data characteristics
     * @throws DMLRuntimeException if neither the output nor the input1 dimensions are known
     */
    public void updateBinaryOutputDataCharacteristics(SparkExecutionContext sparkExecutionContext) {
        DataCharacteristics in1Dc = sparkExecutionContext.getDataCharacteristics(this.input1.getName());
        DataCharacteristics in2Dc = sparkExecutionContext.getDataCharacteristics(this.input2.getName());
        DataCharacteristics outDc = sparkExecutionContext.getDataCharacteristics(this.output.getName());
        // Column vector combined with row vector: result spans input1's rows and input2's columns.
        boolean colVectorWithRowVector = in1Dc.getRows() > 1 && in1Dc.getCols() == 1 && in2Dc.getRows() == 1 && in2Dc.getCols() > 1;
        if (outDc.dimsKnown()) {
            return;
        }
        if (!in1Dc.dimsKnown()) {
            throw new DMLRuntimeException("The output dimensions are not specified and cannot be inferred from input:" + in1Dc.toString() + " " + in2Dc.toString() + " " + outDc.toString());
        }
        if (colVectorWithRowVector) {
            sparkExecutionContext.getDataCharacteristics(this.output.getName()).set(in1Dc.getRows(), in2Dc.getCols(), in1Dc.getBlocksize(), in2Dc.getBlocksize());
        } else {
            sparkExecutionContext.getDataCharacteristics(this.output.getName()).set(in1Dc.getRows(), in1Dc.getCols(), in1Dc.getBlocksize(), in1Dc.getBlocksize());
        }
    }

    /**
     * Infers the output data characteristics of a binary tensor operation by
     * copying those of {@link #input1}, unless the output dimensions are
     * already known.
     *
     * <p>Like the other helpers in this class, known output characteristics
     * are never overwritten. {@link #input2} is only looked up so that it can
     * be included in the error message.
     *
     * @param sparkExecutionContext context used to look up the data characteristics
     * @throws DMLRuntimeException if neither the output nor the input1 dimensions are known
     */
    public void updateBinaryTensorOutputDataCharacteristics(SparkExecutionContext sparkExecutionContext) {
        DataCharacteristics in1Dc = sparkExecutionContext.getDataCharacteristics(this.input1.getName());
        DataCharacteristics in2Dc = sparkExecutionContext.getDataCharacteristics(this.input2.getName());
        DataCharacteristics outDc = sparkExecutionContext.getDataCharacteristics(this.output.getName());
        if (outDc.dimsKnown()) {
            // Already specified: keep it, do not clobber it with input1's characteristics.
            return;
        }
        if (!in1Dc.dimsKnown()) {
            throw new DMLRuntimeException("The output dimensions are not specified and cannot be inferred from input:" + in1Dc.toString() + " " + in2Dc.toString() + " " + outDc.toString());
        }
        outDc.set(in1Dc);
    }

    /**
     * Infers the output data characteristics of a unary aggregate, unless the
     * output dimensions are already known.
     *
     * <p>The output shape depends on the index function:
     * <ul>
     *   <li>{@link ReduceAll}: 1 x 1</li>
     *   <li>{@link ReduceCol}: input rows x 1</li>
     *   <li>{@link ReduceRow}: 1 x input columns</li>
     * </ul>
     * For any other index function the output is left unchanged.
     *
     * @param sparkExecutionContext context used to look up the data characteristics
     * @param indexFunction index function that determines the output shape
     * @throws DMLRuntimeException if neither the output nor the input1 dimensions are known
     */
    public void updateUnaryAggOutputDataCharacteristics(SparkExecutionContext sparkExecutionContext, IndexFunction indexFunction) {
        DataCharacteristics inDc = sparkExecutionContext.getDataCharacteristics(this.input1.getName());
        DataCharacteristics outDc = sparkExecutionContext.getDataCharacteristics(this.output.getName());
        if (outDc.dimsKnown()) {
            return;
        }
        if (!inDc.dimsKnown()) {
            throw new DMLRuntimeException("The output dimensions are not specified and cannot be inferred from input:" + inDc.toString() + " " + outDc.toString());
        }
        if (indexFunction instanceof ReduceAll) {
            outDc.set(1L, 1L, inDc.getBlocksize(), inDc.getBlocksize());
        } else if (indexFunction instanceof ReduceCol) {
            outDc.set(inDc.getRows(), 1L, inDc.getBlocksize(), inDc.getBlocksize());
        } else if (indexFunction instanceof ReduceRow) {
            outDc.set(1L, inDc.getCols(), inDc.getBlocksize(), inDc.getBlocksize());
        }
    }

    /**
     * Builds the lineage entry for this instruction.
     *
     * @param executionContext context passed to {@link LineageItemUtils#getLineage}
     * @return a pair of the output operand's name and a new {@link LineageItem}
     *         created from this instruction's opcode and the lineage of
     *         {@link #input1}, {@link #input2} and {@link #input3}
     *         (the latter is passed as-is, even when it is {@code null})
     */
    public Pair<String, LineageItem> getLineageItem(ExecutionContext executionContext) {
        return Pair.of(this.output.getName(), new LineageItem(getOpcode(), LineageItemUtils.getLineage(executionContext, this.input1, this.input2, this.input3)));
    }
}
