package org.apache.sysds.runtime.privacy.propagation;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.function.BiFunction;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.AggUnaryOp;
import org.apache.sysds.hops.BinaryOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.ReorgOp;
import org.apache.sysds.hops.TernaryOp;
import org.apache.sysds.hops.UnaryOp;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.cp.AggregateBinaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.AppendCPInstruction;
import org.apache.sysds.runtime.instructions.cp.BuiltinNaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.CPInstruction;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.ComputationCPInstruction;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ListAppendRemoveCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ListObject;
import org.apache.sysds.runtime.instructions.cp.MMChainCPInstruction;
import org.apache.sysds.runtime.instructions.cp.MMTSJCPInstruction;
import org.apache.sysds.runtime.instructions.cp.MultiReturnBuiltinCPInstruction;
import org.apache.sysds.runtime.instructions.cp.MultiReturnParameterizedBuiltinCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.cp.SqlCPInstruction;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.privacy.DMLPrivacyException;
import org.apache.sysds.runtime.privacy.PrivacyConstraint;
import org.apache.sysds.runtime.privacy.PrivacyUtils;
import org.apache.wink.json4j.JSONException;
import org.apache.wink.json4j.JSONObject;

/* loaded from: input_file:org/apache/sysds/runtime/privacy/propagation/PrivacyPropagator.class */
/**
 * Propagates privacy constraints through HOPs during compilation and through runtime
 * instructions around their execution.
 *
 * <p>The rule combining input privacy levels into an output level is {@link #corePropagation}.
 */
public class PrivacyPropagator {

    /**
     * Reads the privacy level stored under {@link DataExpression#PRIVACY} in {@code jSONObject}
     * and, if present, attaches a matching constraint to {@code data}.
     *
     * @param data       the data object to annotate
     * @param jSONObject metadata that may contain a privacy entry
     * @return {@code data} itself, annotated or unchanged
     * @throws JSONException if the privacy entry cannot be read
     */
    public static Data parseAndSetPrivacyConstraint(Data data, JSONObject jSONObject) throws JSONException {
        PrivacyConstraint constraint = parseAndReturnPrivacyConstraint(jSONObject);
        if (constraint != null) {
            data.setPrivacyConstraints(constraint);
        }
        return data;
    }

    /**
     * Builds a privacy constraint from the {@link DataExpression#PRIVACY} entry of
     * {@code jSONObject}.
     *
     * @param jSONObject metadata that may contain a privacy entry
     * @return the parsed constraint, or {@code null} if the key is absent or maps to {@code null}
     * @throws JSONException            if the privacy entry cannot be read
     * @throws IllegalArgumentException if the value does not name a {@code PrivacyLevel}
     *                                  constant (raised by {@code Enum.valueOf})
     */
    public static PrivacyConstraint parseAndReturnPrivacyConstraint(JSONObject jSONObject) throws JSONException {
        if (!jSONObject.containsKey(DataExpression.PRIVACY)) {
            return null;
        }
        String level = jSONObject.getString(DataExpression.PRIVACY);
        if (level == null) {
            return null;
        }
        return new PrivacyConstraint(PrivacyConstraint.PrivacyLevel.valueOf(level));
    }

    /** Returns true if at least one entry of {@code levels} equals {@code target}. */
    private static boolean anyInputHasLevel(PrivacyConstraint.PrivacyLevel[] levels, PrivacyConstraint.PrivacyLevel target) {
        return Arrays.stream(levels).anyMatch(level -> level == target);
    }

    /**
     * Combines input privacy levels into the privacy level of an operator's output.
     *
     * <ul>
     *   <li>{@code Private} if any input is {@code Private};</li>
     *   <li>otherwise {@code PrivateAggregation} if the operator is {@code NonAggregate} and any
     *       input is {@code PrivateAggregation};</li>
     *   <li>otherwise {@code None}.</li>
     * </ul>
     *
     * @param privacyLevelArr privacy levels of the operator inputs
     * @param operatorType    whether the operator aggregates its inputs
     * @return the resulting privacy level
     */
    public static PrivacyConstraint.PrivacyLevel corePropagation(PrivacyConstraint.PrivacyLevel[] privacyLevelArr, OperatorType operatorType) {
        if (anyInputHasLevel(privacyLevelArr, PrivacyConstraint.PrivacyLevel.Private)) {
            return PrivacyConstraint.PrivacyLevel.Private;
        }
        // Aggregating operators drop PrivateAggregation; only non-aggregating ones keep it.
        if (operatorType == OperatorType.NonAggregate
            && anyInputHasLevel(privacyLevelArr, PrivacyConstraint.PrivacyLevel.PrivateAggregation)) {
            return PrivacyConstraint.PrivacyLevel.PrivateAggregation;
        }
        return PrivacyConstraint.PrivacyLevel.None;
    }

    /**
     * Merges any number of input constraints into a new constraint via {@link #corePropagation}.
     * A {@code null} input counts as privacy level {@code None}.
     */
    private static PrivacyConstraint mergeNary(PrivacyConstraint[] privacyConstraintArr, OperatorType operatorType) {
        PrivacyConstraint.PrivacyLevel[] levels = Arrays.stream(privacyConstraintArr)
            .map(pc -> pc != null ? pc.getPrivacyLevel() : PrivacyConstraint.PrivacyLevel.None)
            .toArray(PrivacyConstraint.PrivacyLevel[]::new);
        return new PrivacyConstraint(corePropagation(levels, operatorType));
    }

    /**
     * Merges two constraints as for a non-aggregating operator.
     *
     * @return {@code null} if both are {@code null}; the non-null argument unchanged if exactly
     *         one is {@code null}; otherwise a new constraint built from the merged privacy level
     */
    public static PrivacyConstraint mergeBinary(PrivacyConstraint privacyConstraint, PrivacyConstraint privacyConstraint2) {
        if (privacyConstraint == null) {
            return privacyConstraint2;
        }
        if (privacyConstraint2 == null) {
            return privacyConstraint;
        }
        PrivacyConstraint.PrivacyLevel[] levels = {
            privacyConstraint.getPrivacyLevel(), privacyConstraint2.getPrivacyLevel()};
        return new PrivacyConstraint(corePropagation(levels, OperatorType.NonAggregate));
    }

    /**
     * Sets the privacy constraint of {@code hop} from the constraints of its inputs.
     *
     * <p>Ternary, binary, and reorg ops are merged as non-aggregating. Aggregate-binary,
     * aggregate-unary, and unary ops are merged as aggregating. Other HOP types are left
     * unchanged.
     * NOTE(review): UnaryOp is classified as aggregating here; confirm this is intended.
     */
    public static void hopPropagation(Hop hop) {
        PrivacyConstraint[] inputConstraints = hop.getInput().stream()
            .map(Hop::getPrivacy)
            .toArray(PrivacyConstraint[]::new);
        if (hop instanceof TernaryOp || hop instanceof BinaryOp || hop instanceof ReorgOp) {
            hop.setPrivacy(mergeNary(inputConstraints, OperatorType.NonAggregate));
        } else if (hop instanceof AggBinaryOp || hop instanceof AggUnaryOp || hop instanceof UnaryOp) {
            hop.setPrivacy(mergeNary(inputConstraints, OperatorType.Aggregate));
        }
    }

    /**
     * After execution, copies the privacy constraints set on the instruction's output operands
     * during preprocessing onto the corresponding variables in {@code executionContext}.
     *
     * <p>Only outputs for which {@link PrivacyUtils#someConstraintSetUnary} reports a
     * constraint are copied.
     */
    public static void postProcessInstruction(Instruction instruction, ExecutionContext executionContext) {
        for (CPOperand output : getOutputOperands(instruction)) {
            if (output == null) {
                continue;
            }
            PrivacyConstraint constraint = output.getPrivacyConstraint();
            if (PrivacyUtils.someConstraintSetUnary(constraint)) {
                setOutputPrivacyConstraint(executionContext, constraint, output.getName());
            }
        }
    }

    /**
     * Before execution, propagates privacy constraints from the inputs of {@code instruction} to
     * the instruction and its output operands.
     *
     * <p>CP instructions are dispatched by CP type. Breakpoint, Spark, GPU, and federated
     * instructions are returned unchanged. Any other instruction is rejected if privacy
     * constraints are present.
     *
     * @return the same instruction instance
     * @throws DMLPrivacyException if constraints are present but no propagation rule applies
     */
    public static Instruction preprocessInstruction(Instruction instruction, ExecutionContext executionContext) {
        switch (instruction.getType()) {
            case CONTROL_PROGRAM:
                return preprocessCPInstruction((CPInstruction) instruction, executionContext);
            case BREAKPOINT:
            case SPARK:
            case GPU:
            case FEDERATED:
                return instruction;
            default:
                return throwExceptionIfInputOrInstPrivacy(instruction, executionContext);
        }
    }

    /** Dispatches a CP instruction to the propagation rule for its CP type. */
    private static Instruction preprocessCPInstruction(CPInstruction cPInstruction, ExecutionContext executionContext) {
        switch (cPInstruction.getCPInstructionType()) {
            case Binary:
            case Builtin:
            case BuiltinNary:
            case FCall:
            case ParameterizedBuiltin:
            case Quaternary:
            case Reorg:
            case Ternary:
            case Unary:
            case MultiReturnBuiltin:
            case MultiReturnParameterizedBuiltin:
            case MatrixIndexing:
                return mergePrivacyConstraintsFromInput(cPInstruction, executionContext, OperatorType.NonAggregate);
            case AggregateTernary:
            case AggregateUnary:
                return mergePrivacyConstraintsFromInput(cPInstruction, executionContext, OperatorType.Aggregate);
            case Append:
                return preprocessAppendCPInstruction((AppendCPInstruction) cPInstruction, executionContext);
            case AggregateBinary:
                if (cPInstruction instanceof AggregateBinaryCPInstruction) {
                    return preprocessAggregateBinaryCPInstruction((AggregateBinaryCPInstruction) cPInstruction, executionContext);
                }
                return throwExceptionIfInputOrInstPrivacy(cPInstruction, executionContext);
            case MMTSJ:
                return mergePrivacyConstraintsFromInput(cPInstruction, executionContext,
                    OperatorType.getAggregationType((MMTSJCPInstruction) cPInstruction, executionContext));
            case MMChain:
                return mergePrivacyConstraintsFromInput(cPInstruction, executionContext,
                    OperatorType.getAggregationType((MMChainCPInstruction) cPInstruction, executionContext));
            case Variable:
                return preprocessVariableCPInstruction((VariableCPInstruction) cPInstruction, executionContext);
            default:
                return throwExceptionIfInputOrInstPrivacy(cPInstruction, executionContext);
        }
    }

    /**
     * Propagates privacy for variable instructions.
     *
     * <p>Copy, move, remove-with-file, casts, write, and set-file-name take the first input's
     * constraint. Create-variable takes the second input's constraint. Assign and remove merge
     * all inputs as non-aggregating. Read is passed through unchanged.
     */
    private static Instruction preprocessVariableCPInstruction(VariableCPInstruction variableCPInstruction, ExecutionContext executionContext) {
        switch (variableCPInstruction.getVariableOpcode()) {
            case CopyVariable:
            case MoveVariable:
            case RemoveVariableAndFile:
            case CastAsMatrixVariable:
            case CastAsFrameVariable:
            case Write:
            case SetFileName:
            case CastAsScalarVariable:
            case CastAsDoubleVariable:
            case CastAsIntegerVariable:
            case CastAsBooleanVariable:
                return propagateFirstInputPrivacy(variableCPInstruction, executionContext);
            case CreateVariable:
                return propagateSecondInputPrivacy(variableCPInstruction, executionContext);
            case AssignVariable:
            case RemoveVariable:
                return mergePrivacyConstraintsFromInput(variableCPInstruction, executionContext, OperatorType.NonAggregate);
            case Read:
                return variableCPInstruction;
            default:
                return throwExceptionIfInputOrInstPrivacy(variableCPInstruction, executionContext);
        }
    }

    /**
     * Propagates privacy for matrix multiplication.
     *
     * <p>If neither input has fine-grained constraints, the general levels are merged according
     * to the instruction's aggregation type and set on both the instruction and its output.
     * Otherwise the matrix inputs are pinned and {@code MatrixMultiplicationPropagatorPrivateFirst}
     * computes the output constraint, which is set on the output only.
     */
    private static Instruction preprocessAggregateBinaryCPInstruction(AggregateBinaryCPInstruction aggregateBinaryCPInstruction, ExecutionContext executionContext) {
        final PrivacyConstraint[] inputConstraints = getInputPrivacyConstraints(executionContext, aggregateBinaryCPInstruction.getInputs());
        if (!PrivacyUtils.someConstraintSetBinary(inputConstraints)) {
            return aggregateBinaryCPInstruction;
        }
        PrivacyConstraint merged;
        if (!hasFineGrained(inputConstraints[0]) && !hasFineGrained(inputConstraints[1])) {
            merged = mergeNary(inputConstraints, OperatorType.getAggregationType(aggregateBinaryCPInstruction, executionContext));
            aggregateBinaryCPInstruction.setPrivacyConstraint(merged);
        } else {
            merged = applyToPinnedMatrices(executionContext,
                aggregateBinaryCPInstruction.input1.getName(), aggregateBinaryCPInstruction.input2.getName(),
                (left, right) -> new MatrixMultiplicationPropagatorPrivateFirst(
                    left, inputConstraints[0], right, inputConstraints[1]).propagate());
        }
        aggregateBinaryCPInstruction.output.setPrivacyConstraint(merged);
        return aggregateBinaryCPInstruction;
    }

    /** Returns true if {@code constraint} is non-null and has fine-grained constraints. */
    private static boolean hasFineGrained(PrivacyConstraint constraint) {
        return constraint != null && constraint.hasFineGrainedConstraints();
    }

    /**
     * Propagates privacy for append instructions (string concatenation, list append/remove, and
     * matrix rbind/cbind). Nothing is changed when no input carries a constraint.
     *
     * @throws DMLPrivacyException for an append type this propagator does not support
     */
    private static Instruction preprocessAppendCPInstruction(AppendCPInstruction appendCPInstruction, ExecutionContext executionContext) {
        PrivacyConstraint[] inputConstraints = getInputPrivacyConstraints(executionContext, appendCPInstruction.getInputs());
        if (!PrivacyUtils.someConstraintSetBinary(inputConstraints)) {
            return appendCPInstruction;
        }
        AppendCPInstruction.AppendType appendType = appendCPInstruction.getAppendType();
        if (appendType == AppendCPInstruction.AppendType.STRING) {
            PrivacyConstraint.PrivacyLevel[] levels = {
                PrivacyUtils.getGeneralPrivacyLevel(inputConstraints[0]),
                PrivacyUtils.getGeneralPrivacyLevel(inputConstraints[1])};
            appendCPInstruction.output.setPrivacyConstraint(
                new PrivacyConstraint(corePropagation(levels, OperatorType.NonAggregate)));
        } else if (appendType == AppendCPInstruction.AppendType.LIST) {
            propagateListAppendOrRemove(appendCPInstruction, executionContext, inputConstraints);
        } else {
            propagateMatrixAppend(appendCPInstruction, executionContext, inputConstraints, appendType);
        }
        return appendCPInstruction;
    }

    /**
     * Handles list append and list remove. For the {@code "remove"} opcode, the second input is
     * read as a scalar and two output constraints are produced, for the instruction's first and
     * second outputs.
     */
    private static void propagateListAppendOrRemove(AppendCPInstruction inst, ExecutionContext ec, PrivacyConstraint[] inputConstraints) {
        ListObject list = (ListObject) ec.getVariable(inst.input1);
        if (inst.getOpcode().equals("remove")) {
            ScalarObject position = ec.getScalarInput(inst.input2);
            PrivacyConstraint[] outputs = new ListRemovePropagator(
                list, inputConstraints[0], position, position.getPrivacyConstraint()).propagate();
            inst.output.setPrivacyConstraint(outputs[0]);
            ((ListAppendRemoveCPInstruction) inst).getOutput2().setPrivacyConstraint(outputs[1]);
        } else {
            ListObject appended = (ListObject) ec.getVariable(inst.input2);
            inst.output.setPrivacyConstraint(
                new ListAppendPropagator(list, inputConstraints[0], appended, inputConstraints[1]).propagate());
        }
    }

    /**
     * Handles matrix rbind/cbind. The append type is validated before the inputs are pinned, so
     * an unsupported type cannot leave matrices pinned. The pins are always released.
     */
    private static void propagateMatrixAppend(AppendCPInstruction inst, ExecutionContext ec,
        final PrivacyConstraint[] inputConstraints, AppendCPInstruction.AppendType appendType) {
        final boolean rbind = appendType == AppendCPInstruction.AppendType.RBIND;
        if (!rbind && appendType != AppendCPInstruction.AppendType.CBIND) {
            throw new DMLPrivacyException("Instruction " + inst.getCPInstructionType() + " with append type " + appendType + " is not supported by the privacy propagator");
        }
        PrivacyConstraint merged = applyToPinnedMatrices(ec, inst.input1.getName(), inst.input2.getName(),
            (left, right) -> {
                AppendPropagator propagator = rbind
                    ? new RBindPropagator(left, inputConstraints[0], right, inputConstraints[1])
                    : new CBindPropagator(left, inputConstraints[0], right, inputConstraints[1]);
                return propagator.propagate();
            });
        inst.output.setPrivacyConstraint(merged);
    }

    /**
     * Pins the two named matrix inputs, applies {@code fn} to them, and releases both pins even if
     * {@code fn} or the second acquisition throws.
     */
    private static <T> T applyToPinnedMatrices(ExecutionContext ec, String name1, String name2,
        BiFunction<MatrixBlock, MatrixBlock, T> fn) {
        MatrixBlock first = ec.getMatrixInput(name1);
        try {
            MatrixBlock second = ec.getMatrixInput(name2);
            try {
                return fn.apply(first, second);
            } finally {
                ec.releaseMatrixInput(name2);
            }
        } finally {
            ec.releaseMatrixInput(name1);
        }
    }

    /** Merges the constraints of all input operands into the instruction and its outputs. */
    private static Instruction mergePrivacyConstraintsFromInput(Instruction instruction, ExecutionContext executionContext, OperatorType operatorType) {
        return mergePrivacyConstraintsFromInput(instruction, executionContext,
            getInputOperands(instruction), getOutputOperands(instruction), operatorType);
    }

    /**
     * Merges the constraints of {@code cPOperandArr} with {@link #mergeNary} and sets the result
     * on {@code instruction} and every non-null operand in {@code list}. Nothing is changed if no
     * input carries a constraint.
     */
    private static Instruction mergePrivacyConstraintsFromInput(Instruction instruction, ExecutionContext executionContext, CPOperand[] cPOperandArr, List<CPOperand> list, OperatorType operatorType) {
        // Returns null for null/empty operand arrays as well as for unconstrained inputs.
        PrivacyConstraint[] inputConstraints = getInputPrivacyConstraints(executionContext, cPOperandArr);
        if (inputConstraints == null) {
            return instruction;
        }
        PrivacyConstraint merged = mergeNary(inputConstraints, operatorType);
        instruction.setPrivacyConstraint(merged);
        for (CPOperand output : list) {
            if (output != null) {
                output.setPrivacyConstraint(merged);
            }
        }
        return instruction;
    }

    /**
     * Fallback for instructions without a propagation rule.
     *
     * @throws DMLPrivacyException if the instruction has active constraints or any input variable
     *                             carries a privacy constraint object
     */
    private static Instruction throwExceptionIfInputOrInstPrivacy(Instruction instruction, ExecutionContext executionContext) {
        throwExceptionIfPrivacyActivated(instruction);
        CPOperand[] inputs = getInputOperands(instruction);
        if (inputs != null) {
            for (CPOperand input : inputs) {
                if (getInputPrivacyConstraint(executionContext, input) != null) {
                    throw new DMLPrivacyException("Input of instruction " + instruction + " has privacy constraints activated, but the constraints are not propagated during preprocessing of instruction.");
                }
            }
        }
        return instruction;
    }

    /** Throws if the instruction itself carries active privacy constraints. */
    private static void throwExceptionIfPrivacyActivated(Instruction instruction) {
        PrivacyConstraint constraint = instruction.getPrivacyConstraint();
        if (constraint != null && constraint.hasConstraints()) {
            throw new DMLPrivacyException("Instruction " + instruction + " has privacy constraints activated, but the constraints are not propagated during preprocessing of instruction.");
        }
    }

    private static Instruction propagateFirstInputPrivacy(VariableCPInstruction variableCPInstruction, ExecutionContext executionContext) {
        return propagateInputPrivacy(variableCPInstruction, executionContext, variableCPInstruction.getInput1(), variableCPInstruction.getOutput());
    }

    private static Instruction propagateSecondInputPrivacy(VariableCPInstruction variableCPInstruction, ExecutionContext executionContext) {
        return propagateInputPrivacy(variableCPInstruction, executionContext, variableCPInstruction.getInput2(), variableCPInstruction.getOutput());
    }

    /**
     * Copies the constraint of {@code cPOperand}'s variable, if any, to {@code instruction} and to
     * {@code cPOperand2} (when non-null).
     */
    private static Instruction propagateInputPrivacy(Instruction instruction, ExecutionContext executionContext, CPOperand cPOperand, CPOperand cPOperand2) {
        PrivacyConstraint constraint = getInputPrivacyConstraint(executionContext, cPOperand);
        if (constraint != null) {
            instruction.setPrivacyConstraint(constraint);
            if (cPOperand2 != null) {
                cPOperand2.setPrivacyConstraint(constraint);
            }
        }
        return instruction;
    }

    /**
     * Returns the privacy constraint of the variable named by {@code cPOperand}, or {@code null}
     * if the operand, its name, or the variable is missing.
     */
    private static PrivacyConstraint getInputPrivacyConstraint(ExecutionContext executionContext, CPOperand cPOperand) {
        if (cPOperand == null || cPOperand.getName() == null) {
            return null;
        }
        Data variable = executionContext.getVariable(cPOperand.getName());
        return variable != null ? variable.getPrivacyConstraint() : null;
    }

    /**
     * Looks up the constraint of each operand's variable.
     *
     * @return an array aligned with {@code cPOperandArr}, or {@code null} if the operand array is
     *         null or empty or no operand has a constraint
     */
    private static PrivacyConstraint[] getInputPrivacyConstraints(ExecutionContext executionContext, CPOperand[] cPOperandArr) {
        if (cPOperandArr == null || cPOperandArr.length == 0) {
            return null;
        }
        PrivacyConstraint[] constraints = new PrivacyConstraint[cPOperandArr.length];
        boolean anySet = false;
        for (int i = 0; i < cPOperandArr.length; i++) {
            constraints[i] = getInputPrivacyConstraint(executionContext, cPOperandArr[i]);
            anySet |= constraints[i] != null;
        }
        return anySet ? constraints : null;
    }

    /** Sets {@code privacyConstraint} on the variable {@code str}, if both exist. */
    private static void setOutputPrivacyConstraint(ExecutionContext executionContext, PrivacyConstraint privacyConstraint, String str) {
        if (privacyConstraint == null) {
            return;
        }
        Data variable = executionContext.getVariable(str);
        if (variable != null) {
            variable.setPrivacyConstraints(privacyConstraint);
            executionContext.setVariable(str, variable);
        }
    }

    /** Returns the input operands of supported instruction kinds, or {@code null} otherwise. */
    private static CPOperand[] getInputOperands(Instruction instruction) {
        if (instruction instanceof ComputationCPInstruction) {
            return ((ComputationCPInstruction) instruction).getInputs();
        }
        if (instruction instanceof BuiltinNaryCPInstruction) {
            return ((BuiltinNaryCPInstruction) instruction).getInputs();
        }
        if (instruction instanceof FunctionCallCPInstruction) {
            return ((FunctionCallCPInstruction) instruction).getInputs();
        }
        if (instruction instanceof SqlCPInstruction) {
            return ((SqlCPInstruction) instruction).getInputs();
        }
        return null;
    }

    /**
     * Returns the output operands of supported instruction kinds, or an empty list otherwise.
     * Multi-return types are checked before {@code ComputationCPInstruction} so that all of
     * their outputs are returned.
     */
    private static List<CPOperand> getOutputOperands(Instruction instruction) {
        if (instruction instanceof MultiReturnParameterizedBuiltinCPInstruction) {
            return ((MultiReturnParameterizedBuiltinCPInstruction) instruction).getOutputs();
        }
        if (instruction instanceof MultiReturnBuiltinCPInstruction) {
            return ((MultiReturnBuiltinCPInstruction) instruction).getOutputs();
        }
        if (instruction instanceof ComputationCPInstruction) {
            return getSingletonList(((ComputationCPInstruction) instruction).getOutput());
        }
        if (instruction instanceof VariableCPInstruction) {
            return getSingletonList(((VariableCPInstruction) instruction).getOutput());
        }
        if (instruction instanceof SqlCPInstruction) {
            return getSingletonList(((SqlCPInstruction) instruction).getOutput());
        }
        if (instruction instanceof BuiltinNaryCPInstruction) {
            return getSingletonList(((BuiltinNaryCPInstruction) instruction).getOutput());
        }
        return new ArrayList<>();
    }

    /** Returns a mutable list holding {@code cPOperand}, or an empty list if it is null. */
    private static List<CPOperand> getSingletonList(CPOperand cPOperand) {
        List<CPOperand> operands = new ArrayList<>(1);
        if (cPOperand != null) {
            operands.add(cPOperand);
        }
        return operands;
    }
}
