Skip to content

Commit

Permalink
Make Enforcer modality-aware (#124)
Browse files Browse the repository at this point in the history
* Add modality-aware enforcer.

* Add example.

* Backward compatible.

* Decorator.

* Make it the default Enforcer.

* Fix backward compatibility.

* Remove previous tuple-aware mechanism.

* Fix example.

* Fix typo.

* Comment.

* Fix type annotation.

* Move recursion to __call__().

* Simplify condition experession.

* Type annotation.

* Test Enforcer().

* Comment.

* Add tuple tests.

* Strengthen data type checking.

* Add logic checking.
  • Loading branch information
mzweilin committed Mar 30, 2023
1 parent 70e62d9 commit eced15b
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 18 deletions.
58 changes: 42 additions & 16 deletions mart/attack/enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

import torch

__all__ = ["Enforcer"]


class ConstraintViolated(Exception):
pass
Expand All @@ -19,16 +21,12 @@ class ConstraintViolated(Exception):
class Constraint(abc.ABC):
def __call__(
self,
input_adv: torch.Tensor | tuple,
input_adv: torch.Tensor,
*,
input: torch.Tensor | tuple,
target: torch.Tensor | dict[str, Any] | tuple,
input: torch.Tensor,
target: torch.Tensor | dict[str, Any],
) -> None:
if isinstance(input_adv, tuple):
for input_adv_i, input_i, target_i in zip(input_adv, input, target):
self.verify(input_adv_i, input=input_i, target=target_i)
else:
self.verify(input_adv, input=input, target=target)
self.verify(input_adv, input=input, target=target)

@abc.abstractmethod
def verify(
Expand Down Expand Up @@ -97,17 +95,45 @@ def verify(self, input_adv, *, input, target):


class Enforcer:
def __init__(self, constraints: dict[str, Constraint] | None = None) -> None:
self.constraints = constraints or {}
def __init__(self, **modality_constraints: dict[str, dict[str, Constraint]]) -> None:
self.modality_constraints = modality_constraints

@torch.no_grad()
def _enforce(
self,
input_adv: torch.Tensor,
*,
input: torch.Tensor,
target: torch.Tensor | dict[str, Any],
modality: str,
):
for constraint in self.modality_constraints[modality].values():
constraint(input_adv, input=input, target=target)

def __call__(
self,
input_adv: torch.Tensor | tuple,
input_adv: torch.Tensor | tuple | list[torch.Tensor] | dict[str, torch.Tensor],
*,
input: torch.Tensor | tuple,
target: torch.Tensor | dict[str, Any] | tuple,
input: torch.Tensor | tuple | list[torch.Tensor] | dict[str, torch.Tensor],
target: torch.Tensor | dict[str, Any],
modality: str = "constraints",
**kwargs,
) -> None:
for constraint in self.constraints.values():
constraint(input_adv, input=input, target=target)
):
assert type(input_adv) == type(input)

if isinstance(input_adv, torch.Tensor):
# Finally we can verify constraints on tensor, per its modality.
# Set modality="constraints" by default, so that it is backward compatible with existing configs without modalities.
self._enforce(input_adv, input=input, target=target, modality=modality)
elif isinstance(input_adv, dict):
# The dict input has modalities specified in keys, passing them recursively.
for modality in input_adv:
self(input_adv[modality], input=input[modality], target=target, modality=modality)
elif isinstance(input_adv, (list, tuple)):
# We assume a modality-dictionary only contains tensors, but not list/tuple.
assert modality == "constraints"
# The list or tuple input is a collection of sub-input and sub-target.
for input_adv_i, input_i, target_i in zip(input_adv, input, target):
self(input_adv_i, input=input_i, target=target_i, modality=modality)
else:
raise ValueError(f"Unsupported data type of input_adv: {type(input_adv)}.")
2 changes: 1 addition & 1 deletion mart/configs/attack/enforcer/default.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
defaults:
- constraints: null

_target_: mart.attack.enforcer.Enforcer
_target_: mart.attack.Enforcer
22 changes: 22 additions & 0 deletions mart/configs/attack/object_detection_rgb_mask_adversary.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
defaults:
- iterative_sgd
- perturber: batch
- perturber/initializer: constant
- perturber/gradient_modifier: sign
- perturber/projector: mask_range
- callbacks: [progress_bar, image_visualizer]
- objective: zero_ap
- gain: rcnn_training_loss
- composer: overlay
- enforcer: default
- enforcer/constraints@enforcer.rgb: [mask, pixel_range]

# Make a 5-step attack for the demonstration purpose.
optimizer:
lr: 55

max_iters: 5

perturber:
initializer:
constant: 127
59 changes: 58 additions & 1 deletion tests/test_enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest
import torch

from mart.attack.enforcer import ConstraintViolated, Integer, Lp, Mask, Range
from mart.attack.enforcer import ConstraintViolated, Enforcer, Integer, Lp, Mask, Range


def test_constraint_range():
Expand Down Expand Up @@ -67,3 +67,60 @@ def test_constraint_mask():
constraint(input + perturbation * mask, input=input, target=target)
with pytest.raises(ConstraintViolated):
constraint(input + perturbation, input=input, target=target)


def test_enforcer_non_modality():
enforcer = Enforcer(constraints={"range": Range(min=0, max=255)})

input = torch.tensor([0, 0, 0])
perturbation = torch.tensor([0, 128, 255])
input_adv = input + perturbation
target = None

# tensor input.
enforcer(input_adv, input=input, target=target)
# list of tensor input.
enforcer([input_adv], input=[input], target=[target])
# tuple of tensor input.
enforcer((input_adv,), input=(input,), target=(target,))

perturbation = torch.tensor([0, -1, 255])
input_adv = input + perturbation

with pytest.raises(ConstraintViolated):
enforcer(input_adv, input=input, target=target)

with pytest.raises(ConstraintViolated):
enforcer([input_adv], input=[input], target=[target])

with pytest.raises(ConstraintViolated):
enforcer((input_adv,), input=(input,), target=(target,))


def test_enforcer_modality():
# Assume a rgb modality.
enforcer = Enforcer(rgb={"range": Range(min=0, max=255)})

input = torch.tensor([0, 0, 0])
perturbation = torch.tensor([0, 128, 255])
input_adv = input + perturbation
target = None

# Dictionary input.
enforcer({"rgb": input_adv}, input={"rgb": input}, target=target)
# List of dictionary input.
enforcer([{"rgb": input_adv}], input=[{"rgb": input}], target=[target])
# Tuple of dictionary input.
enforcer(({"rgb": input_adv},), input=({"rgb": input},), target=(target,))

perturbation = torch.tensor([0, -1, 255])
input_adv = input + perturbation

with pytest.raises(ConstraintViolated):
enforcer({"rgb": input_adv}, input={"rgb": input}, target=target)

with pytest.raises(ConstraintViolated):
enforcer([{"rgb": input_adv}], input=[{"rgb": input}], target=[target])

with pytest.raises(ConstraintViolated):
enforcer(({"rgb": input_adv},), input=({"rgb": input},), target=(target,))

0 comments on commit eced15b

Please sign in to comment.