From eced15bdad18b6683190997590e2500a332b03e7 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Thu, 30 Mar 2023 12:39:21 -0700 Subject: [PATCH] Make Enforcer modality-aware (#124) * Add modality-aware enforcer. * Add example. * Backward compatible. * Decorator. * Make it the default Enforcer. * Fix backward compatibility. * Remove previous tuple-aware mechanism. * Fix example. * Fix typo. * Comment. * Fix type annotation. * Move recursion to __call__(). * Simplify condition experession. * Type annotation. * Test Enforcer(). * Comment. * Add tuple tests. * Strengthen data type checking. * Add logic checking. --- mart/attack/enforcer.py | 58 +++++++++++++----- mart/configs/attack/enforcer/default.yaml | 2 +- .../object_detection_rgb_mask_adversary.yaml | 22 +++++++ tests/test_enforcer.py | 59 ++++++++++++++++++- 4 files changed, 123 insertions(+), 18 deletions(-) create mode 100644 mart/configs/attack/object_detection_rgb_mask_adversary.yaml diff --git a/mart/attack/enforcer.py b/mart/attack/enforcer.py index 4d4a1364..babc44e6 100644 --- a/mart/attack/enforcer.py +++ b/mart/attack/enforcer.py @@ -11,6 +11,8 @@ import torch +__all__ = ["Enforcer"] + class ConstraintViolated(Exception): pass @@ -19,16 +21,12 @@ class ConstraintViolated(Exception): class Constraint(abc.ABC): def __call__( self, - input_adv: torch.Tensor | tuple, + input_adv: torch.Tensor, *, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, + input: torch.Tensor, + target: torch.Tensor | dict[str, Any], ) -> None: - if isinstance(input_adv, tuple): - for input_adv_i, input_i, target_i in zip(input_adv, input, target): - self.verify(input_adv_i, input=input_i, target=target_i) - else: - self.verify(input_adv, input=input, target=target) + self.verify(input_adv, input=input, target=target) @abc.abstractmethod def verify( @@ -97,17 +95,45 @@ def verify(self, input_adv, *, input, target): class Enforcer: - def __init__(self, constraints: dict[str, Constraint] | None = None) -> None: - self.constraints = constraints or {} + def __init__(self, **modality_constraints: dict[str, dict[str, Constraint]]) -> None: + self.modality_constraints = modality_constraints @torch.no_grad() + def _enforce( + self, + input_adv: torch.Tensor, + *, + input: torch.Tensor, + target: torch.Tensor | dict[str, Any], + modality: str, + ): + for constraint in self.modality_constraints[modality].values(): + constraint(input_adv, input=input, target=target) + def __call__( self, - input_adv: torch.Tensor | tuple, + input_adv: torch.Tensor | tuple | list[torch.Tensor] | dict[str, torch.Tensor], *, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, + input: torch.Tensor | tuple | list[torch.Tensor] | dict[str, torch.Tensor], + target: torch.Tensor | dict[str, Any], + modality: str = "constraints", **kwargs, - ) -> None: - for constraint in self.constraints.values(): - constraint(input_adv, input=input, target=target) + ): + assert type(input_adv) == type(input) + + if isinstance(input_adv, torch.Tensor): + # Finally we can verify constraints on tensor, per its modality. + # Set modality="constraints" by default, so that it is backward compatible with existing configs without modalities. + self._enforce(input_adv, input=input, target=target, modality=modality) + elif isinstance(input_adv, dict): + # The dict input has modalities specified in keys, passing them recursively. + for modality in input_adv: + self(input_adv[modality], input=input[modality], target=target, modality=modality) + elif isinstance(input_adv, (list, tuple)): + # We assume a modality-dictionary only contains tensors, but not list/tuple. + assert modality == "constraints" + # The list or tuple input is a collection of sub-input and sub-target. + for input_adv_i, input_i, target_i in zip(input_adv, input, target): + self(input_adv_i, input=input_i, target=target_i, modality=modality) + else: + raise ValueError(f"Unsupported data type of input_adv: {type(input_adv)}.") diff --git a/mart/configs/attack/enforcer/default.yaml b/mart/configs/attack/enforcer/default.yaml index 885e0860..94ed9441 100644 --- a/mart/configs/attack/enforcer/default.yaml +++ b/mart/configs/attack/enforcer/default.yaml @@ -1,4 +1,4 @@ defaults: - constraints: null -_target_: mart.attack.enforcer.Enforcer +_target_: mart.attack.Enforcer diff --git a/mart/configs/attack/object_detection_rgb_mask_adversary.yaml b/mart/configs/attack/object_detection_rgb_mask_adversary.yaml new file mode 100644 index 00000000..a2bb039e --- /dev/null +++ b/mart/configs/attack/object_detection_rgb_mask_adversary.yaml @@ -0,0 +1,22 @@ +defaults: + - iterative_sgd + - perturber: batch + - perturber/initializer: constant + - perturber/gradient_modifier: sign + - perturber/projector: mask_range + - callbacks: [progress_bar, image_visualizer] + - objective: zero_ap + - gain: rcnn_training_loss + - composer: overlay + - enforcer: default + - enforcer/constraints@enforcer.rgb: [mask, pixel_range] + +# Make a 5-step attack for the demonstration purpose. +optimizer: + lr: 55 + +max_iters: 5 + +perturber: + initializer: + constant: 127 diff --git a/tests/test_enforcer.py b/tests/test_enforcer.py index 78053b53..2c56b3ad 100644 --- a/tests/test_enforcer.py +++ b/tests/test_enforcer.py @@ -7,7 +7,7 @@ import pytest import torch -from mart.attack.enforcer import ConstraintViolated, Integer, Lp, Mask, Range +from mart.attack.enforcer import ConstraintViolated, Enforcer, Integer, Lp, Mask, Range def test_constraint_range(): @@ -67,3 +67,60 @@ def test_constraint_mask(): constraint(input + perturbation * mask, input=input, target=target) with pytest.raises(ConstraintViolated): constraint(input + perturbation, input=input, target=target) + + +def test_enforcer_non_modality(): + enforcer = Enforcer(constraints={"range": Range(min=0, max=255)}) + + input = torch.tensor([0, 0, 0]) + perturbation = torch.tensor([0, 128, 255]) + input_adv = input + perturbation + target = None + + # tensor input. + enforcer(input_adv, input=input, target=target) + # list of tensor input. + enforcer([input_adv], input=[input], target=[target]) + # tuple of tensor input. + enforcer((input_adv,), input=(input,), target=(target,)) + + perturbation = torch.tensor([0, -1, 255]) + input_adv = input + perturbation + + with pytest.raises(ConstraintViolated): + enforcer(input_adv, input=input, target=target) + + with pytest.raises(ConstraintViolated): + enforcer([input_adv], input=[input], target=[target]) + + with pytest.raises(ConstraintViolated): + enforcer((input_adv,), input=(input,), target=(target,)) + + +def test_enforcer_modality(): + # Assume a rgb modality. + enforcer = Enforcer(rgb={"range": Range(min=0, max=255)}) + + input = torch.tensor([0, 0, 0]) + perturbation = torch.tensor([0, 128, 255]) + input_adv = input + perturbation + target = None + + # Dictionary input. + enforcer({"rgb": input_adv}, input={"rgb": input}, target=target) + # List of dictionary input. + enforcer([{"rgb": input_adv}], input=[{"rgb": input}], target=[target]) + # Tuple of dictionary input. + enforcer(({"rgb": input_adv},), input=({"rgb": input},), target=(target,)) + + perturbation = torch.tensor([0, -1, 255]) + input_adv = input + perturbation + + with pytest.raises(ConstraintViolated): + enforcer({"rgb": input_adv}, input={"rgb": input}, target=target) + + with pytest.raises(ConstraintViolated): + enforcer([{"rgb": input_adv}], input=[{"rgb": input}], target=[target]) + + with pytest.raises(ConstraintViolated): + enforcer(({"rgb": input_adv},), input=({"rgb": input},), target=(target,))