package org.apache.sysds.runtime.privacy.propagation;

import org.apache.sysds.lops.MMTSJ;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.cp.AggregateBinaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.MMChainCPInstruction;
import org.apache.sysds.runtime.instructions.cp.MMTSJCPInstruction;
import org.apache.sysds.runtime.meta.DataCharacteristics;

/**
 * Classification of matrix-multiplication-style CP instructions used by privacy propagation.
 *
 * <p>Each {@code getAggregationType} overload looks up the data characteristics of the
 * instruction's first input. It returns {@link #NonAggregate} when the inspected dimension(s)
 * equal exactly 1, and {@link #Aggregate} in every other case.
 *
 * <p>NOTE(review): presumably "aggregate" means each output cell combines more than one input
 * value (a summed dimension larger than 1). Confirm against the privacy propagation callers.
 */
public enum OperatorType {
    /** Returned whenever the inspected dimension of the input is not exactly 1. */
    Aggregate,
    /** Returned when the inspected dimension of the input is exactly 1. */
    NonAggregate;

    /**
     * Determines the operator type of a matrix-multiplication chain instruction.
     *
     * @param inst the MMChain instruction; only its first input is inspected
     * @param ec execution context used to look up the first input's data characteristics
     * @return {@link #NonAggregate} if the first input has exactly one row and one column,
     *     otherwise {@link #Aggregate}
     */
    public static OperatorType getAggregationType(MMChainCPInstruction inst, ExecutionContext ec) {
        DataCharacteristics dc = ec.getDataCharacteristics(inst.getInputs()[0].getName());
        boolean singleCell = dc.getRows() == 1 && dc.getCols() == 1;
        if (singleCell) {
            return NonAggregate;
        }
        return Aggregate;
    }

    /**
     * Determines the operator type of a transpose-self matrix multiplication (MMTSJ) instruction.
     *
     * <p>For {@link MMTSJ.MMTSJType#LEFT} the row count of the first input is inspected. For any
     * other MMTSJ type the column count is inspected.
     *
     * @param inst the MMTSJ instruction; its first input and MMTSJ type are inspected
     * @param ec execution context used to look up the first input's data characteristics
     * @return {@link #NonAggregate} if the inspected dimension is exactly 1, otherwise
     *     {@link #Aggregate}
     */
    public static OperatorType getAggregationType(MMTSJCPInstruction inst, ExecutionContext ec) {
        DataCharacteristics dc = ec.getDataCharacteristics(inst.getInputs()[0].getName());
        // NOTE(review): presumably LEFT denotes t(X) %*% X, whose summed dimension is the rows
        // of X. Every other type sums over the columns. Confirm against MMTSJ.
        boolean left = inst.getMMTSJType() == MMTSJ.MMTSJType.LEFT;
        return fromCheckedDimension(left ? dc.getRows() : dc.getCols());
    }

    /**
     * Determines the operator type of an aggregate binary (matrix multiplication) instruction.
     *
     * <p>When {@code transposeLeft} is set, the row count of the left input is inspected.
     * Otherwise its column count is inspected.
     *
     * @param inst the aggregate binary instruction; its {@code input1} and {@code transposeLeft}
     *     flag are inspected
     * @param ec execution context used to look up the left input's data characteristics
     * @return {@link #NonAggregate} if the inspected dimension is exactly 1, otherwise
     *     {@link #Aggregate}
     */
    public static OperatorType getAggregationType(AggregateBinaryCPInstruction inst, ExecutionContext ec) {
        DataCharacteristics dc = ec.getDataCharacteristics(inst.input1.getName());
        long checked = inst.transposeLeft ? dc.getRows() : dc.getCols();
        return fromCheckedDimension(checked);
    }

    /**
     * Maps a single inspected dimension to an operator type.
     *
     * @param dim the inspected dimension size
     * @return {@link #NonAggregate} if {@code dim} is exactly 1, otherwise {@link #Aggregate}
     */
    private static OperatorType fromCheckedDimension(long dim) {
        return dim == 1 ? NonAggregate : Aggregate;
    }
}
