package org.apache.sysds.runtime.privacy.propagation;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.cp.AggregateBinaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.AggregateUnaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.BinaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.BuiltinNaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.CPInstruction;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.ComputationCPInstruction;
import org.apache.sysds.runtime.instructions.cp.CovarianceCPInstruction;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
import org.apache.sysds.runtime.instructions.cp.MultiReturnBuiltinCPInstruction;
import org.apache.sysds.runtime.instructions.cp.MultiReturnParameterizedBuiltinCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ParameterizedBuiltinCPInstruction;
import org.apache.sysds.runtime.instructions.cp.QuaternaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.SqlCPInstruction;
import org.apache.sysds.runtime.instructions.cp.UnaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
import org.apache.sysds.runtime.privacy.DMLPrivacyException;
import org.apache.sysds.runtime.privacy.PrivacyConstraint;
import org.apache.wink.json4j.JSONException;
import org.apache.wink.json4j.JSONObject;

/* loaded from: input_file:org/apache/sysds/runtime/privacy/propagation/PrivacyPropagator.class */
/**
 * Static helpers that propagate privacy constraints through runtime instructions.
 *
 * <p>The {@code preprocess*} entry points read the privacy constraints of an instruction's input
 * variables from the {@link ExecutionContext}, merge them, and attach the result to the
 * instruction and/or to its output operands or output variables. Instructions without a dedicated
 * propagation rule are rejected with a {@link DMLPrivacyException} when they already carry an
 * activated constraint. {@link #postProcessInstruction} copies activated constraints from an
 * instruction's output operands onto the corresponding variables in the execution context.
 */
public class PrivacyPropagator {

    /**
     * Attaches the privacy level named by the metadata entry {@link DataExpression#PRIVACY} to
     * {@code data}, if that entry is present and non-null.
     *
     * @param data object to annotate
     * @param jSONObject metadata that may carry a privacy level name
     * @return {@code data} itself
     * @throws JSONException propagated from reading the metadata entry
     * @throws IllegalArgumentException if the entry does not name a {@code PrivacyLevel} constant
     */
    public static Data parseAndSetPrivacyConstraint(Data data, JSONObject jSONObject) throws JSONException {
        if (!jSONObject.containsKey(DataExpression.PRIVACY)) {
            return data;
        }
        String levelName = jSONObject.getString(DataExpression.PRIVACY);
        if (levelName != null) {
            data.setPrivacyConstraints(new PrivacyConstraint(PrivacyConstraint.PrivacyLevel.valueOf(levelName)));
        }
        return data;
    }

    /** Returns whether at least one entry of {@code levels} is identical to {@code level}. */
    private static boolean anyInputHasLevel(PrivacyConstraint.PrivacyLevel[] levels, PrivacyConstraint.PrivacyLevel level) {
        return Arrays.stream(levels).anyMatch(candidate -> candidate == level);
    }

    /**
     * Computes the privacy level of an operator's output from the levels of its inputs.
     *
     * <ul>
     *   <li>any {@code Private} input makes the output {@code Private};</li>
     *   <li>otherwise an {@code Aggregate} operator yields {@code None};</li>
     *   <li>otherwise a {@code NonAggregate} operator with a {@code PrivateAggregation} input yields
     *       {@code PrivateAggregation};</li>
     *   <li>every other combination yields {@code None}.</li>
     * </ul>
     *
     * @param privacyLevelArr privacy levels of the operator inputs
     * @param operatorType kind of operator being applied
     * @return the privacy level of the output
     */
    public static PrivacyConstraint.PrivacyLevel corePropagation(PrivacyConstraint.PrivacyLevel[] privacyLevelArr, OperatorType operatorType) {
        // Private dominates regardless of the operator type.
        if (anyInputHasLevel(privacyLevelArr, PrivacyConstraint.PrivacyLevel.Private)) {
            return PrivacyConstraint.PrivacyLevel.Private;
        }
        if (operatorType == OperatorType.Aggregate) {
            return PrivacyConstraint.PrivacyLevel.None;
        }
        if (operatorType == OperatorType.NonAggregate
                && anyInputHasLevel(privacyLevelArr, PrivacyConstraint.PrivacyLevel.PrivateAggregation)) {
            return PrivacyConstraint.PrivacyLevel.PrivateAggregation;
        }
        return PrivacyConstraint.PrivacyLevel.None;
    }

    /**
     * Merges the constraints of several operator inputs using {@link #corePropagation}.
     * A {@code null} entry is treated as privacy level {@code None}.
     *
     * @param privacyConstraintArr constraints of the operator inputs; entries may be {@code null}
     * @param operatorType kind of operator being applied
     * @return a new constraint carrying the propagated privacy level (never {@code null})
     */
    public static PrivacyConstraint mergeNary(PrivacyConstraint[] privacyConstraintArr, OperatorType operatorType) {
        PrivacyConstraint.PrivacyLevel[] levels = new PrivacyConstraint.PrivacyLevel[privacyConstraintArr.length];
        for (int i = 0; i < privacyConstraintArr.length; i++) {
            PrivacyConstraint constraint = privacyConstraintArr[i];
            levels[i] = constraint != null ? constraint.getPrivacyLevel() : PrivacyConstraint.PrivacyLevel.None;
        }
        return new PrivacyConstraint(corePropagation(levels, operatorType));
    }

    /**
     * Merges two constraints, keeping the stricter privacy level.
     *
     * <p>If one side is {@code null}, the other side is returned unchanged (possibly {@code null}).
     * If both are present, the result is a new {@code Private} constraint when either is
     * {@code Private}, else a new {@code PrivateAggregation} constraint when either is
     * {@code PrivateAggregation}, else {@code null}.
     *
     * @param privacyConstraint first constraint, may be {@code null}
     * @param privacyConstraint2 second constraint, may be {@code null}
     * @return the merged constraint, or {@code null} when no privacy level applies
     */
    public static PrivacyConstraint mergeBinary(PrivacyConstraint privacyConstraint, PrivacyConstraint privacyConstraint2) {
        if (privacyConstraint == null) {
            return privacyConstraint2;
        }
        if (privacyConstraint2 == null) {
            return privacyConstraint;
        }
        PrivacyConstraint.PrivacyLevel level1 = privacyConstraint.getPrivacyLevel();
        PrivacyConstraint.PrivacyLevel level2 = privacyConstraint2.getPrivacyLevel();
        if (level1 == PrivacyConstraint.PrivacyLevel.Private || level2 == PrivacyConstraint.PrivacyLevel.Private) {
            return new PrivacyConstraint(PrivacyConstraint.PrivacyLevel.Private);
        }
        if (level1 == PrivacyConstraint.PrivacyLevel.PrivateAggregation
                || level2 == PrivacyConstraint.PrivacyLevel.PrivateAggregation) {
            return new PrivacyConstraint(PrivacyConstraint.PrivacyLevel.PrivateAggregation);
        }
        return null;
    }

    /**
     * Folds {@link #mergeBinary} over all constraints from left to right.
     *
     * @param privacyConstraintArr constraints to merge; entries may be {@code null}
     * @return the merged constraint, or {@code null} if the array is {@code null}/empty or no
     *         privacy level applies
     */
    public static PrivacyConstraint mergeNary(PrivacyConstraint[] privacyConstraintArr) {
        // Guard: the previous implementation threw ArrayIndexOutOfBoundsException on an empty array.
        if (privacyConstraintArr == null || privacyConstraintArr.length == 0) {
            return null;
        }
        PrivacyConstraint merged = privacyConstraintArr[0];
        for (int i = 1; i < privacyConstraintArr.length; i++) {
            merged = mergeBinary(merged, privacyConstraintArr[i]);
        }
        return merged;
    }

    /**
     * Entry point for privacy preprocessing of a single instruction.
     *
     * <p>Control-program instructions get fine-grained propagation. Breakpoint, Spark, GPU and
     * federated instructions are returned untouched. Any other instruction type is rejected if it
     * carries an activated constraint.
     *
     * @param instruction instruction to preprocess
     * @param executionContext context holding the input variables
     * @return the (possibly annotated) instruction
     * @throws DMLPrivacyException if an unsupported instruction has activated privacy constraints
     */
    public static Instruction preprocessInstruction(Instruction instruction, ExecutionContext executionContext) {
        switch (instruction.getType()) {
            case CONTROL_PROGRAM:
                return preprocessCPInstructionFineGrained((CPInstruction) instruction, executionContext);
            case BREAKPOINT:
            case SPARK:
            case GPU:
            case FEDERATED:
                return instruction;
            default:
                throwExceptionIfPrivacyActivated(instruction);
                return instruction;
        }
    }

    /**
     * Dispatches a control-program instruction to the propagation rule for its CP type.
     *
     * @param cPInstruction instruction to preprocess
     * @param executionContext context holding the input variables
     * @return the (possibly annotated) instruction
     */
    public static Instruction preprocessCPInstructionFineGrained(CPInstruction cPInstruction, ExecutionContext executionContext) {
        switch (cPInstruction.getCPInstructionType()) {
            case AggregateBinary:
                return preprocessAggregateBinaryType(cPInstruction, executionContext);
            case AggregateTernary:
                return preprocessTernaryCPInstruction((ComputationCPInstruction) cPInstruction, executionContext);
            case AggregateUnary:
                return preprocessAggregateUnaryCPInstruction((AggregateUnaryCPInstruction) cPInstruction, executionContext);
            case Append:
            case Binary:
                return preprocessBinaryCPInstruction((BinaryCPInstruction) cPInstruction, executionContext);
            case Builtin:
                return preprocessBuiltinType(cPInstruction, executionContext);
            case BuiltinNary:
                return preprocessBuiltinNary((BuiltinNaryCPInstruction) cPInstruction, executionContext);
            case FCall:
                return preprocessExternal((FunctionCallCPInstruction) cPInstruction, executionContext);
            case MultiReturnBuiltin:
            case MultiReturnParameterizedBuiltin:
                return preprocessMultiReturn((ComputationCPInstruction) cPInstruction, executionContext);
            case ParameterizedBuiltin:
                return preprocessParameterizedBuiltin((ParameterizedBuiltinCPInstruction) cPInstruction, executionContext);
            case Quaternary:
                return preprocessQuaternary((QuaternaryCPInstruction) cPInstruction, executionContext);
            case Reorg:
            case Unary:
                return preprocessUnaryCPInstruction((UnaryCPInstruction) cPInstruction, executionContext);
            case Ternary:
                return preprocessTernaryCPInstruction((ComputationCPInstruction) cPInstruction, executionContext);
            case Variable:
                return preprocessVariableCPInstruction((VariableCPInstruction) cPInstruction, executionContext);
            default:
                return preprocessInstructionSimple(cPInstruction, executionContext);
        }
    }

    /**
     * Handles instructions of CP type {@code AggregateBinary}.
     *
     * <p>Matrix multiplications and covariance have dedicated rules. Any other instruction of this
     * type is first checked with {@link #preprocessInstructionSimple} and then gets its input
     * constraints merged onto its output. This is the behavior of the former implicit
     * switch fall-through, now made explicit.
     */
    private static Instruction preprocessAggregateBinaryType(CPInstruction cPInstruction, ExecutionContext executionContext) {
        if (cPInstruction instanceof AggregateBinaryCPInstruction) {
            return preprocessAggregateBinaryCPInstruction((AggregateBinaryCPInstruction) cPInstruction, executionContext);
        }
        if (cPInstruction instanceof CovarianceCPInstruction) {
            return preprocessCovarianceCPInstruction((CovarianceCPInstruction) cPInstruction, executionContext);
        }
        preprocessInstructionSimple(cPInstruction, executionContext);
        return preprocessTernaryCPInstruction((ComputationCPInstruction) cPInstruction, executionContext);
    }

    /**
     * Handles instructions of CP type {@code Builtin}.
     *
     * <p>Instructions of this type are not necessarily {@link BuiltinNaryCPInstruction}s. For
     * example, they may be unary or binary instructions, so dispatch on the runtime class instead
     * of casting blindly, which previously could throw {@code ClassCastException}.
     */
    private static Instruction preprocessBuiltinType(CPInstruction cPInstruction, ExecutionContext executionContext) {
        if (cPInstruction instanceof BuiltinNaryCPInstruction) {
            return preprocessBuiltinNary((BuiltinNaryCPInstruction) cPInstruction, executionContext);
        }
        if (cPInstruction instanceof BinaryCPInstruction) {
            return preprocessBinaryCPInstruction((BinaryCPInstruction) cPInstruction, executionContext);
        }
        if (cPInstruction instanceof UnaryCPInstruction) {
            return preprocessUnaryCPInstruction((UnaryCPInstruction) cPInstruction, executionContext);
        }
        if (cPInstruction instanceof ComputationCPInstruction) {
            return preprocessTernaryCPInstruction((ComputationCPInstruction) cPInstruction, executionContext);
        }
        return preprocessInstructionSimple(cPInstruction, executionContext);
    }

    /**
     * Covariance has no propagation rule. It is rejected if the instruction itself or any of its
     * inputs carries a privacy constraint.
     */
    private static Instruction preprocessCovarianceCPInstruction(CovarianceCPInstruction covarianceCPInstruction, ExecutionContext executionContext) {
        throwExceptionIfPrivacyActivated(covarianceCPInstruction);
        for (CPOperand input : covarianceCPInstruction.getInputs()) {
            if (getInputPrivacyConstraint(executionContext, input) != null) {
                throw new DMLPrivacyException("Input of instruction " + covarianceCPInstruction + " has privacy constraints activated, but the constraints are not propagated during preprocessing of instruction.");
            }
        }
        return covarianceCPInstruction;
    }

    /**
     * Propagates constraints through a matrix multiplication.
     *
     * <p>Without fine-grained constraints, the inputs are merged as a non-aggregate operator and the
     * result is set on both the instruction and the output. With fine-grained constraints, the
     * matrix values are read and handed to {@link MatrixMultiplicationPropagatorPrivateFirst}; the
     * result is set on the output only.
     */
    private static Instruction preprocessAggregateBinaryCPInstruction(AggregateBinaryCPInstruction aggregateBinaryCPInstruction, ExecutionContext executionContext) {
        PrivacyConstraint constraint1 = getInputPrivacyConstraint(executionContext, aggregateBinaryCPInstruction.input1);
        PrivacyConstraint constraint2 = getInputPrivacyConstraint(executionContext, aggregateBinaryCPInstruction.input2);
        boolean anyConstraint = (constraint1 != null && constraint1.hasConstraints())
                || (constraint2 != null && constraint2.hasConstraints());
        if (!anyConstraint) {
            return aggregateBinaryCPInstruction;
        }
        boolean anyFineGrained = (constraint1 != null && constraint1.hasFineGrainedConstraints())
                || (constraint2 != null && constraint2.hasFineGrainedConstraints());
        PrivacyConstraint merged;
        if (anyFineGrained) {
            String name1 = aggregateBinaryCPInstruction.input1.getName();
            String name2 = aggregateBinaryCPInstruction.input2.getName();
            merged = new MatrixMultiplicationPropagatorPrivateFirst(
                    executionContext.getMatrixInput(name1), constraint1,
                    executionContext.getMatrixInput(name2), constraint2).propagate();
            // Both inputs were acquired via getMatrixInput above; release them again so they are
            // not left acquired (previously missing).
            executionContext.releaseMatrixInput(name1, name2);
        } else {
            merged = mergeNary(new PrivacyConstraint[]{constraint1, constraint2}, OperatorType.NonAggregate);
            aggregateBinaryCPInstruction.setPrivacyConstraint(merged);
        }
        aggregateBinaryCPInstruction.output.setPrivacyConstraint(merged);
        return aggregateBinaryCPInstruction;
    }

    /**
     * Merges the constraints of both inputs with {@link #mergeBinary}. If at least one input has a
     * constraint, the result is set on the instruction and its output operand.
     */
    public static Instruction preprocessBinaryCPInstruction(BinaryCPInstruction binaryCPInstruction, ExecutionContext executionContext) {
        PrivacyConstraint constraint1 = getInputPrivacyConstraint(executionContext, binaryCPInstruction.input1);
        PrivacyConstraint constraint2 = getInputPrivacyConstraint(executionContext, binaryCPInstruction.input2);
        if (constraint1 != null || constraint2 != null) {
            PrivacyConstraint merged = mergeBinary(constraint1, constraint2);
            binaryCPInstruction.setPrivacyConstraint(merged);
            binaryCPInstruction.output.setPrivacyConstraint(merged);
        }
        return binaryCPInstruction;
    }

    /**
     * Copies the input constraint onto the instruction. The output is marked {@code Private} only
     * when the input constraint reports private elements.
     */
    private static Instruction preprocessAggregateUnaryCPInstruction(AggregateUnaryCPInstruction aggregateUnaryCPInstruction, ExecutionContext executionContext) {
        PrivacyConstraint inputConstraint = getInputPrivacyConstraint(executionContext, aggregateUnaryCPInstruction.input1);
        if (inputConstraint != null) {
            aggregateUnaryCPInstruction.setPrivacyConstraint(inputConstraint);
            if (aggregateUnaryCPInstruction.output != null && inputConstraint.hasPrivateElements()) {
                aggregateUnaryCPInstruction.output.setPrivacyConstraint(new PrivacyConstraint(PrivacyConstraint.PrivacyLevel.Private));
            }
        }
        return aggregateUnaryCPInstruction;
    }

    /**
     * Fallback for instructions without a propagation rule. Rejects the instruction if it carries
     * an activated constraint and otherwise returns it unchanged.
     */
    public static Instruction preprocessInstructionSimple(Instruction instruction, ExecutionContext executionContext) {
        throwExceptionIfPrivacyActivated(instruction);
        return instruction;
    }

    /**
     * Merges the input constraints of a function call and writes them to the bound output
     * variables that already exist in the execution context.
     */
    public static Instruction preprocessExternal(FunctionCallCPInstruction functionCallCPInstruction, ExecutionContext executionContext) {
        return mergePrivacyConstraintsFromInput(functionCallCPInstruction, executionContext, functionCallCPInstruction.getInputs(), (String[]) functionCallCPInstruction.getBoundOutputParamNames().toArray(new String[0]));
    }

    /** Merges the input constraints onto every output operand of a multi-return instruction. */
    public static Instruction preprocessMultiReturn(ComputationCPInstruction computationCPInstruction, ExecutionContext executionContext) {
        return mergePrivacyConstraintsFromInput(computationCPInstruction, executionContext, computationCPInstruction.getInputs(), getOutputOperands(computationCPInstruction));
    }

    /** Merges the input constraints of a parameterized builtin onto its output operand. */
    public static Instruction preprocessParameterizedBuiltin(ParameterizedBuiltinCPInstruction parameterizedBuiltinCPInstruction, ExecutionContext executionContext) {
        return mergePrivacyConstraintsFromInput(parameterizedBuiltinCPInstruction, executionContext, parameterizedBuiltinCPInstruction.getInputs(), parameterizedBuiltinCPInstruction.getOutput());
    }

    /**
     * Merges all input constraints and writes the result to the instruction and to the named
     * output variables. Nothing happens if no input has a constraint.
     */
    private static Instruction mergePrivacyConstraintsFromInput(Instruction instruction, ExecutionContext executionContext, CPOperand[] inputs, String[] outputNames) {
        PrivacyConstraint[] inputConstraints = getInputPrivacyConstraints(executionContext, inputs);
        if (inputConstraints == null) {
            return instruction;
        }
        PrivacyConstraint merged = mergeNary(inputConstraints);
        instruction.setPrivacyConstraint(merged);
        if (outputNames != null) {
            for (String outputName : outputNames) {
                setOutputPrivacyConstraint(executionContext, merged, outputName);
            }
        }
        return instruction;
    }

    /** Single-output variant of the operand-list overload; a {@code null} output is ignored. */
    private static Instruction mergePrivacyConstraintsFromInput(Instruction instruction, ExecutionContext executionContext, CPOperand[] inputs, CPOperand output) {
        return mergePrivacyConstraintsFromInput(instruction, executionContext, inputs, getSingletonList(output));
    }

    /**
     * Merges all input constraints and writes the result to the instruction and to every non-null
     * output operand. Nothing happens if no input has a constraint.
     */
    private static Instruction mergePrivacyConstraintsFromInput(Instruction instruction, ExecutionContext executionContext, CPOperand[] inputs, List<CPOperand> outputs) {
        PrivacyConstraint[] inputConstraints = getInputPrivacyConstraints(executionContext, inputs);
        if (inputConstraints == null) {
            return instruction;
        }
        PrivacyConstraint merged = mergeNary(inputConstraints);
        instruction.setPrivacyConstraint(merged);
        if (outputs != null) {
            for (CPOperand output : outputs) {
                if (output != null) {
                    output.setPrivacyConstraint(merged);
                }
            }
        }
        return instruction;
    }

    /** Merges the input constraints of an n-ary builtin onto its output operand. */
    public static Instruction preprocessBuiltinNary(BuiltinNaryCPInstruction builtinNaryCPInstruction, ExecutionContext executionContext) {
        return mergePrivacyConstraintsFromInput(builtinNaryCPInstruction, executionContext, builtinNaryCPInstruction.getInputs(), builtinNaryCPInstruction.getOutput());
    }

    /** Merges the constraints of all four quaternary inputs onto the output operand. */
    public static Instruction preprocessQuaternary(QuaternaryCPInstruction quaternaryCPInstruction, ExecutionContext executionContext) {
        CPOperand[] inputs = {quaternaryCPInstruction.input1, quaternaryCPInstruction.input2, quaternaryCPInstruction.input3, quaternaryCPInstruction.getInput4()};
        return mergePrivacyConstraintsFromInput(quaternaryCPInstruction, executionContext, inputs, quaternaryCPInstruction.output);
    }

    /** Merges the constraints of all inputs reported by {@code getInputs()} onto the output operand. */
    public static Instruction preprocessTernaryCPInstruction(ComputationCPInstruction computationCPInstruction, ExecutionContext executionContext) {
        return mergePrivacyConstraintsFromInput(computationCPInstruction, executionContext, computationCPInstruction.getInputs(), computationCPInstruction.output);
    }

    /** Copies the constraint of the single input onto the instruction and its output operand. */
    public static Instruction preprocessUnaryCPInstruction(UnaryCPInstruction unaryCPInstruction, ExecutionContext executionContext) {
        return propagateInputPrivacy(unaryCPInstruction, executionContext, unaryCPInstruction.input1, unaryCPInstruction.output);
    }

    /**
     * Propagates constraints for variable instructions according to their opcode. {@code Read} is
     * passed through, and unlisted opcodes are rejected if they carry an activated constraint.
     */
    public static Instruction preprocessVariableCPInstruction(VariableCPInstruction variableCPInstruction, ExecutionContext executionContext) {
        switch (variableCPInstruction.getVariableOpcode()) {
            case CreateVariable:
                return propagateSecondInputPrivacy(variableCPInstruction, executionContext);
            case AssignVariable:
                return propagateInputPrivacy(variableCPInstruction, executionContext, variableCPInstruction.getInput1(), variableCPInstruction.getInput2());
            case CopyVariable:
            case MoveVariable:
            case RemoveVariableAndFile:
            case CastAsMatrixVariable:
            case CastAsFrameVariable:
            case Write:
            case SetFileName:
                return propagateFirstInputPrivacy(variableCPInstruction, executionContext);
            case RemoveVariable:
                return propagateAllInputPrivacy(variableCPInstruction, executionContext);
            case CastAsScalarVariable:
            case CastAsDoubleVariable:
            case CastAsIntegerVariable:
            case CastAsBooleanVariable:
                return propagateCastAsScalarVariablePrivacy(variableCPInstruction, executionContext);
            case Read:
                return variableCPInstruction;
            default:
                throwExceptionIfPrivacyActivated(variableCPInstruction);
                return variableCPInstruction;
        }
    }

    /**
     * Throws if the instruction itself carries a constraint that reports {@code hasConstraints()}.
     *
     * @throws DMLPrivacyException if such a constraint is attached
     */
    private static void throwExceptionIfPrivacyActivated(Instruction instruction) {
        PrivacyConstraint constraint = instruction.getPrivacyConstraint();
        if (constraint != null && constraint.hasConstraints()) {
            throw new DMLPrivacyException("Instruction " + instruction + " has privacy constraints activated, but the constraints are not propagated during preprocessing of instruction.");
        }
    }

    /** Scalar casts keep the constraint of their (first) input. */
    private static Instruction propagateCastAsScalarVariablePrivacy(VariableCPInstruction variableCPInstruction, ExecutionContext executionContext) {
        return propagateFirstInputPrivacy(variableCPInstruction, executionContext);
    }

    /** Merges the constraints of all inputs of a variable instruction onto its output operand. */
    private static Instruction propagateAllInputPrivacy(VariableCPInstruction variableCPInstruction, ExecutionContext executionContext) {
        return mergePrivacyConstraintsFromInput(variableCPInstruction, executionContext, (CPOperand[]) variableCPInstruction.getInputs().toArray(new CPOperand[0]), variableCPInstruction.getOutput());
    }

    /** Copies the constraint of the first input onto the instruction and its output operand. */
    private static Instruction propagateFirstInputPrivacy(VariableCPInstruction variableCPInstruction, ExecutionContext executionContext) {
        return propagateInputPrivacy(variableCPInstruction, executionContext, variableCPInstruction.getInput1(), variableCPInstruction.getOutput());
    }

    /** Copies the constraint of the second input onto the instruction and its output operand. */
    private static Instruction propagateSecondInputPrivacy(VariableCPInstruction variableCPInstruction, ExecutionContext executionContext) {
        return propagateInputPrivacy(variableCPInstruction, executionContext, variableCPInstruction.getInput2(), variableCPInstruction.getOutput());
    }

    /**
     * Copies the constraint of {@code input} onto the instruction and onto {@code output} (if
     * non-null). Nothing happens when the input has no constraint.
     */
    private static Instruction propagateInputPrivacy(Instruction instruction, ExecutionContext executionContext, CPOperand input, CPOperand output) {
        PrivacyConstraint inputConstraint = getInputPrivacyConstraint(executionContext, input);
        if (inputConstraint != null) {
            instruction.setPrivacyConstraint(inputConstraint);
            if (output != null) {
                output.setPrivacyConstraint(inputConstraint);
            }
        }
        return instruction;
    }

    /**
     * Looks up the constraint of the variable named by {@code operand}.
     *
     * @return the constraint, or {@code null} if the operand, its name, or the variable is missing
     */
    private static PrivacyConstraint getInputPrivacyConstraint(ExecutionContext executionContext, CPOperand operand) {
        if (operand == null || operand.getName() == null) {
            return null;
        }
        Data variable = executionContext.getVariable(operand.getName());
        return variable != null ? variable.getPrivacyConstraint() : null;
    }

    /**
     * Looks up the constraints of all operands, position by position.
     *
     * @return the constraints (entries may be {@code null}), or {@code null} if {@code operands} is
     *         {@code null}/empty or none of them has a constraint
     */
    private static PrivacyConstraint[] getInputPrivacyConstraints(ExecutionContext executionContext, CPOperand[] operands) {
        if (operands == null || operands.length == 0) {
            return null;
        }
        boolean anyConstraint = false;
        PrivacyConstraint[] constraints = new PrivacyConstraint[operands.length];
        for (int i = 0; i < operands.length; i++) {
            constraints[i] = getInputPrivacyConstraint(executionContext, operands[i]);
            anyConstraint |= constraints[i] != null;
        }
        return anyConstraint ? constraints : null;
    }

    /**
     * Sets {@code privacyConstraint} on the variable {@code variableName} and stores it back in the
     * context. Does nothing if the constraint is {@code null} or the variable does not exist.
     */
    private static void setOutputPrivacyConstraint(ExecutionContext executionContext, PrivacyConstraint privacyConstraint, String variableName) {
        if (privacyConstraint == null) {
            return;
        }
        Data variable = executionContext.getVariable(variableName);
        if (variable == null) {
            return;
        }
        variable.setPrivacyConstraints(privacyConstraint);
        executionContext.setVariable(variableName, variable);
    }

    /**
     * Copies every activated ({@code Private} or {@code PrivateAggregation}) constraint found on the
     * instruction's output operands onto the corresponding variables in the execution context.
     *
     * @param instruction instruction whose output operands are inspected
     * @param executionContext context holding the output variables
     */
    public static void postProcessInstruction(Instruction instruction, ExecutionContext executionContext) {
        for (CPOperand output : getOutputOperands(instruction)) {
            if (output == null) {
                continue;
            }
            PrivacyConstraint constraint = output.getPrivacyConstraint();
            if (privacyConstraintActivated(constraint)) {
                setOutputPrivacyConstraint(executionContext, constraint, output.getName());
            }
        }
    }

    /** Returns whether the constraint is non-null and {@code Private} or {@code PrivateAggregation}. */
    private static boolean privacyConstraintActivated(PrivacyConstraint privacyConstraint) {
        if (privacyConstraint == null) {
            return false;
        }
        PrivacyConstraint.PrivacyLevel level = privacyConstraint.getPrivacyLevel();
        return level == PrivacyConstraint.PrivacyLevel.Private || level == PrivacyConstraint.PrivacyLevel.PrivateAggregation;
    }

    /**
     * Returns the output operands of the supported instruction kinds, or an empty mutable list for
     * any other instruction.
     */
    private static List<CPOperand> getOutputOperands(Instruction instruction) {
        if (instruction instanceof MultiReturnParameterizedBuiltinCPInstruction) {
            return ((MultiReturnParameterizedBuiltinCPInstruction) instruction).getOutputs();
        }
        if (instruction instanceof MultiReturnBuiltinCPInstruction) {
            return ((MultiReturnBuiltinCPInstruction) instruction).getOutputs();
        }
        if (instruction instanceof ComputationCPInstruction) {
            return getSingletonList(((ComputationCPInstruction) instruction).getOutput());
        }
        if (instruction instanceof VariableCPInstruction) {
            return getSingletonList(((VariableCPInstruction) instruction).getOutput());
        }
        if (instruction instanceof SqlCPInstruction) {
            return getSingletonList(((SqlCPInstruction) instruction).getOutput());
        }
        return new ArrayList<>();
    }

    /** Returns a mutable list holding {@code operand}, or an empty mutable list if it is {@code null}. */
    private static List<CPOperand> getSingletonList(CPOperand operand) {
        return operand != null ? new ArrayList<>(Collections.singletonList(operand)) : new ArrayList<>();
    }
}
