Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Enforcer modality-aware #124

Merged
merged 19 commits into from
Mar 30, 2023
Merged
58 changes: 42 additions & 16 deletions mart/attack/enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

import torch

__all__ = ["Enforcer"]


class ConstraintViolated(Exception):
pass
Expand All @@ -19,16 +21,12 @@ class ConstraintViolated(Exception):
class Constraint(abc.ABC):
def __call__(
self,
input_adv: torch.Tensor | tuple,
input_adv: torch.Tensor,
*,
input: torch.Tensor | tuple,
target: torch.Tensor | dict[str, Any] | tuple,
input: torch.Tensor,
target: torch.Tensor | dict[str, Any],
) -> None:
if isinstance(input_adv, tuple):
for input_adv_i, input_i, target_i in zip(input_adv, input, target):
self.verify(input_adv_i, input=input_i, target=target_i)
else:
self.verify(input_adv, input=input, target=target)
self.verify(input_adv, input=input, target=target)
dxoigmn marked this conversation as resolved.
Show resolved Hide resolved

@abc.abstractmethod
def verify(
Expand Down Expand Up @@ -97,17 +95,45 @@ def verify(self, input_adv, *, input, target):


class Enforcer:
def __init__(self, constraints: dict[str, Constraint] | None = None) -> None:
self.constraints = constraints or {}
def __init__(self, **modality_constraints: dict[str, dict[str, Constraint]]) -> None:
self.modality_constraints = modality_constraints

@torch.no_grad()
def _enforce(
self,
input_adv: torch.Tensor,
*,
input: torch.Tensor,
target: torch.Tensor | dict[str, Any],
modality: str,
):
for constraint in self.modality_constraints[modality].values():
constraint(input_adv, input=input, target=target)

def __call__(
self,
input_adv: torch.Tensor | tuple,
input_adv: torch.Tensor | tuple | list[torch.Tensor] | dict[str, torch.Tensor],
*,
input: torch.Tensor | tuple,
target: torch.Tensor | dict[str, Any] | tuple,
input: torch.Tensor | tuple | list[torch.Tensor] | dict[str, torch.Tensor],
target: torch.Tensor | dict[str, Any],
modality: str = "constraints",
**kwargs,
) -> None:
for constraint in self.constraints.values():
constraint(input_adv, input=input, target=target)
):
assert type(input_adv) == type(input)

if isinstance(input_adv, torch.Tensor):
# Finally we can verify constraints on tensor, per its modality.
# Set modality="constraints" by default, so that it is backward compatible with existing configs without modalities.
self._enforce(input_adv, input=input, target=target, modality=modality)
elif isinstance(input_adv, dict):
# The dict input has modalities specified in keys, passing them recursively.
for modality in input_adv:
self(input_adv[modality], input=input[modality], target=target, modality=modality)
elif isinstance(input_adv, (list, tuple)):
# We assume a modality-dictionary only contains tensors, but not list/tuple.
assert modality == "constraints"
# The list or tuple input is a collection of sub-input and sub-target.
for input_adv_i, input_i, target_i in zip(input_adv, input, target):
self(input_adv_i, input=input_i, target=target_i, modality=modality)
else:
raise ValueError(f"Unsupported data type of input_adv: {type(input_adv)}.")
2 changes: 1 addition & 1 deletion mart/configs/attack/enforcer/default.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
defaults:
- constraints: null

_target_: mart.attack.enforcer.Enforcer
_target_: mart.attack.Enforcer
22 changes: 22 additions & 0 deletions mart/configs/attack/object_detection_rgb_mask_adversary.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
defaults:
- iterative_sgd
- perturber: batch
- perturber/initializer: constant
- perturber/gradient_modifier: sign
- perturber/projector: mask_range
- callbacks: [progress_bar, image_visualizer]
- objective: zero_ap
- gain: rcnn_training_loss
- composer: overlay
- enforcer: default
- enforcer/constraints@enforcer.rgb: [mask, pixel_range]

# Make a 5-step attack for the demonstration purpose.
optimizer:
lr: 55

max_iters: 5

perturber:
initializer:
constant: 127
59 changes: 58 additions & 1 deletion tests/test_enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest
import torch

from mart.attack.enforcer import ConstraintViolated, Integer, Lp, Mask, Range
from mart.attack.enforcer import ConstraintViolated, Enforcer, Integer, Lp, Mask, Range


def test_constraint_range():
Expand Down Expand Up @@ -67,3 +67,60 @@ def test_constraint_mask():
constraint(input + perturbation * mask, input=input, target=target)
with pytest.raises(ConstraintViolated):
constraint(input + perturbation, input=input, target=target)


def test_enforcer_non_modality():
enforcer = Enforcer(constraints={"range": Range(min=0, max=255)})

input = torch.tensor([0, 0, 0])
perturbation = torch.tensor([0, 128, 255])
input_adv = input + perturbation
target = None

# tensor input.
enforcer(input_adv, input=input, target=target)
# list of tensor input.
enforcer([input_adv], input=[input], target=[target])
# tuple of tensor input.
enforcer((input_adv,), input=(input,), target=(target,))

perturbation = torch.tensor([0, -1, 255])
input_adv = input + perturbation

with pytest.raises(ConstraintViolated):
enforcer(input_adv, input=input, target=target)

with pytest.raises(ConstraintViolated):
enforcer([input_adv], input=[input], target=[target])

with pytest.raises(ConstraintViolated):
enforcer((input_adv,), input=(input,), target=(target,))


def test_enforcer_modality():
# Assume a rgb modality.
enforcer = Enforcer(rgb={"range": Range(min=0, max=255)})

input = torch.tensor([0, 0, 0])
perturbation = torch.tensor([0, 128, 255])
input_adv = input + perturbation
target = None

# Dictionary input.
enforcer({"rgb": input_adv}, input={"rgb": input}, target=target)
# List of dictionary input.
enforcer([{"rgb": input_adv}], input=[{"rgb": input}], target=[target])
dxoigmn marked this conversation as resolved.
Show resolved Hide resolved
# Tuple of dictionary input.
enforcer(({"rgb": input_adv},), input=({"rgb": input},), target=(target,))

perturbation = torch.tensor([0, -1, 255])
input_adv = input + perturbation

with pytest.raises(ConstraintViolated):
enforcer({"rgb": input_adv}, input={"rgb": input}, target=target)

with pytest.raises(ConstraintViolated):
enforcer([{"rgb": input_adv}], input=[{"rgb": input}], target=[target])

with pytest.raises(ConstraintViolated):
enforcer(({"rgb": input_adv},), input=({"rgb": input},), target=(target,))