
Commit

Remove BatchEnforcer in favor of batch-aware Enforcer (#122)
dxoigmn authored Mar 29, 2023
1 parent c36c71d commit 70e62d9
Showing 3 changed files with 35 additions and 27 deletions.
56 changes: 34 additions & 22 deletions mart/attack/enforcer.py
@@ -17,8 +17,27 @@ class ConstraintViolated(Exception):
 
 
 class Constraint(abc.ABC):
-    @abc.abstractclassmethod
-    def __call__(self, input_adv, *, input, target) -> None:
+    def __call__(
+        self,
+        input_adv: torch.Tensor | tuple,
+        *,
+        input: torch.Tensor | tuple,
+        target: torch.Tensor | dict[str, Any] | tuple,
+    ) -> None:
+        if isinstance(input_adv, tuple):
+            for input_adv_i, input_i, target_i in zip(input_adv, input, target):
+                self.verify(input_adv_i, input=input_i, target=target_i)
+        else:
+            self.verify(input_adv, input=input, target=target)
+
+    @abc.abstractmethod
+    def verify(
+        self,
+        input_adv: torch.Tensor,
+        *,
+        input: torch.Tensor,
+        target: torch.Tensor | dict[str, Any],
+    ) -> None:
         raise NotImplementedError
 
 
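With this refactor a constraint only implements verify for a single example, and the inherited __call__ fans out over tuple inputs. A minimal sketch of a hypothetical subclass (the NonNegative class, shapes, and values below are illustrative, not part of this commit):

import torch

from mart.attack.enforcer import Constraint, ConstraintViolated


class NonNegative(Constraint):
    def verify(self, input_adv, *, input, target):
        # Called once per example; tuple dispatch happens in Constraint.__call__.
        if torch.any(input_adv < 0):
            raise ConstraintViolated("Adversarial input contains negative values.")


constraint = NonNegative()

# Single example.
constraint(torch.rand(3, 32, 32), input=torch.rand(3, 32, 32), target=None)

# Batch passed as a tuple: verify() runs element-wise over the zipped tuples.
batch = tuple(torch.rand(3, 32, 32) for _ in range(4))
constraint(batch, input=batch, target=(None,) * 4)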
@@ -27,19 +46,21 @@ def __init__(self, min, max):
         self.min = min
         self.max = max
 
-    def __call__(self, input_adv, *, input, target):
+    def verify(self, input_adv, *, input, target):
         if torch.any(input_adv < self.min) or torch.any(input_adv > self.max):
             raise ConstraintViolated(f"Adversarial input is outside [{self.min}, {self.max}].")
 
 
 class Lp(Constraint):
-    def __init__(self, eps: float, p: int | float | None = torch.inf, dim=None, keepdim=False):
+    def __init__(
+        self, eps: float, p: int | float = torch.inf, dim: int | None = None, keepdim: bool = False
+    ):
         self.p = p
         self.eps = eps
         self.dim = dim
         self.keepdim = keepdim
 
-    def __call__(self, input_adv, *, input, target):
+    def verify(self, input_adv, *, input, target):
         perturbation = input_adv - input
         norm_vals = perturbation.norm(p=self.p, dim=self.dim, keepdim=self.keepdim)
         norm_max = norm_vals.max()
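For context, a hedged usage sketch of Lp with an L-infinity budget. The values are illustrative, and it assumes the part of verify hidden behind the fold raises ConstraintViolated when the largest norm exceeds eps:

import torch

from mart.attack.enforcer import ConstraintViolated, Lp

constraint = Lp(eps=8 / 255, p=torch.inf)
input = torch.rand(3, 32, 32)

# Perturbation bounded by eps in every pixel: the max-norm check should pass.
input_adv = input + (torch.rand_like(input) - 0.5) * 2 * (8 / 255)
constraint(input_adv, input=input, target=None)

# A perturbation of 1.0 everywhere exceeds the budget and should raise.
try:
    constraint(input + 1.0, input=input, target=None)
except ConstraintViolated as err:
    print(err)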
@@ -50,20 +71,20 @@ def __call__(self, input_adv, *, input, target):
 
 
 class Integer(Constraint):
-    def __init__(self, rtol=0, atol=0, equal_nan=False):
+    def __init__(self, rtol: float = 0.0, atol: float = 0.0, equal_nan: bool = False):
         self.rtol = rtol
         self.atol = atol
         self.equal_nan = equal_nan
 
-    def __call__(self, input_adv, *, input, target):
+    def verify(self, input_adv, *, input, target):
         if not torch.isclose(
             input_adv, input_adv.round(), rtol=self.rtol, atol=self.atol, equal_nan=self.equal_nan
         ).all():
             raise ConstraintViolated("The adversarial example is not in the integer domain.")
 
 
 class Mask(Constraint):
-    def __call__(self, input_adv, *, input, target):
+    def verify(self, input_adv, *, input, target):
         # True/1 is mutable, False/0 is immutable.
         # mask.shape=(H, W)
         mask = target["perturbable_mask"]
@@ -76,26 +97,17 @@ def __call__(self, input_adv, *, input, target):
 
 
 class Enforcer:
-    def __init__(self, constraints=None) -> None:
+    def __init__(self, constraints: dict[str, Constraint] | None = None) -> None:
         self.constraints = constraints or {}
 
-    def _check_constraints(self, input_adv, *, input, target):
-        for constraint in self.constraints.values():
-            constraint(input_adv, input=input, target=target)
-
     @torch.no_grad()
-    def __call__(self, input_adv, *, input, target):
-        self._check_constraints(input_adv, input=input, target=target)
-
-
-class BatchEnforcer(Enforcer):
-    @torch.no_grad()
     def __call__(
         self,
         input_adv: torch.Tensor | tuple,
         *,
         input: torch.Tensor | tuple,
         target: torch.Tensor | dict[str, Any] | tuple,
-    ) -> torch.Tensor | tuple:
-        for input_adv_i, input_i, target_i in zip(input_adv, input, target):
-            self._check_constraints(input_adv_i, input=input_i, target=target_i)
+        **kwargs,
+    ) -> None:
+        for constraint in self.constraints.values():
+            constraint(input_adv, input=input, target=target)
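The regular Enforcer now accepts either a single tensor or a tuple of per-example tensors, because batch handling lives in Constraint.__call__; that is what makes BatchEnforcer redundant. A sketch of the batch-aware usage, where the constraint keys, image sizes, and the strictness of the elided Lp norm check are assumptions:

import torch

from mart.attack.enforcer import Enforcer, Integer, Lp

enforcer = Enforcer(constraints={"lp": Lp(eps=8, p=torch.inf), "integer": Integer()})

# Detection-style batch: a tuple of variable-sized, integer-valued images.
input = tuple(torch.randint(0, 256, (3, h, w)).float() for h, w in [(480, 640), (600, 800)])
input_adv = tuple((x + torch.randint(-8, 9, x.shape).float()).clamp(0, 255) for x in input)
target = ({}, {})

# Each constraint iterates over the tuple itself, so one Enforcer covers
# single tensors and batches alike; no separate BatchEnforcer is needed.
enforcer(input_adv, input=input, target=target)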
4 changes: 0 additions & 4 deletions mart/configs/attack/enforcer/batch.yaml

This file was deleted.

2 changes: 1 addition & 1 deletion mart/configs/attack/object_detection_mask_adversary.yaml
@@ -8,7 +8,7 @@ defaults:
   - objective: zero_ap
   - gain: rcnn_training_loss
   - composer: overlay
-  - enforcer: batch
+  - enforcer: default
   - enforcer/constraints: [mask, pixel_range]
 
 # Make a 5-step attack for the demonstration purpose.
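Because the batch enforcer config group was deleted, this attack config now composes the regular default enforcer, which handles batches on its own. Assuming MART's enforcer configs follow the usual Hydra _target_ convention (the contents of default.yaml and the constraint nodes are not shown in this diff, so the paths and keys below are assumptions), the composed node would resolve to roughly:

from hydra.utils import instantiate
from omegaconf import OmegaConf

# Hypothetical equivalent of "enforcer: default" plus a mask constraint;
# not the literal file contents.
cfg = OmegaConf.create(
    {
        "_target_": "mart.attack.enforcer.Enforcer",
        "constraints": {"mask": {"_target_": "mart.attack.enforcer.Mask"}},
    }
)
enforcer = instantiate(cfg)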
