Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Enforcer modality-aware #124

Merged
merged 19 commits into from
Mar 30, 2023
Merged
52 changes: 36 additions & 16 deletions mart/attack/enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

import torch

__all__ = ["Enforcer"]


class ConstraintViolated(Exception):
pass
Expand All @@ -19,16 +21,12 @@ class ConstraintViolated(Exception):
class Constraint(abc.ABC):
def __call__(
self,
input_adv: torch.Tensor | tuple,
input_adv: torch.Tensor,
*,
input: torch.Tensor | tuple,
target: torch.Tensor | dict[str, Any] | tuple,
input: torch.Tensor,
target: torch.Tensor | dict[str, Any],
) -> None:
if isinstance(input_adv, tuple):
for input_adv_i, input_i, target_i in zip(input_adv, input, target):
self.verify(input_adv_i, input=input_i, target=target_i)
else:
self.verify(input_adv, input=input, target=target)
self.verify(input_adv, input=input, target=target)
dxoigmn marked this conversation as resolved.
Show resolved Hide resolved

@abc.abstractmethod
def verify(
Expand Down Expand Up @@ -97,17 +95,39 @@ def verify(self, input_adv, *, input, target):


class Enforcer:
def __init__(self, constraints: dict[str, Constraint] | None = None) -> None:
self.constraints = constraints or {}
def __init__(self, **modality_constraints: dict[str, dict[str, Constraint]]) -> None:
self.modality_constraints = modality_constraints

@torch.no_grad()
def _enforce(
self,
input_adv: torch.Tensor,
*,
input: torch.Tensor,
target: torch.Tensor | dict[str, Any],
modality: str,
):
for constraint in self.modality_constraints[modality].values():
constraint(input_adv, input=input, target=target)

def __call__(
self,
input_adv: torch.Tensor | tuple,
input_adv: torch.Tensor | tuple | list[torch.Tensor] | dict[str, torch.Tensor],
*,
input: torch.Tensor | tuple,
target: torch.Tensor | dict[str, Any] | tuple,
input: torch.Tensor | tuple | list[torch.Tensor] | dict[str, torch.Tensor],
target: torch.Tensor | dict[str, Any],
modality: str = "constraints",
**kwargs,
) -> None:
for constraint in self.constraints.values():
constraint(input_adv, input=input, target=target)
):
if isinstance(input_adv, torch.Tensor):
# Finally we can verify constraints on tensor, per its modality.
# Set modality="constraints" by default, so that it is backward compatible with existing configs without modalities.
self._enforce(input_adv, input=input, target=target, modality=modality)
elif isinstance(input_adv, dict):
# The dict input has modalities specified in keys, passing them recursively.
for modality in input_adv:
self(input_adv[modality], input=input[modality], target=target, modality=modality)
elif isinstance(input_adv, (list, tuple)):
# The list or tuple input is a collection of sub-input and sub-target.
for input_adv_i, input_i, target_i in zip(input_adv, input, target):
self(input_adv_i, input=input_i, target=target_i)
dxoigmn marked this conversation as resolved.
Show resolved Hide resolved
2 changes: 1 addition & 1 deletion mart/configs/attack/enforcer/default.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
defaults:
- constraints: null

_target_: mart.attack.enforcer.Enforcer
_target_: mart.attack.Enforcer
22 changes: 22 additions & 0 deletions mart/configs/attack/object_detection_rgb_mask_adversary.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
defaults:
- iterative_sgd
- perturber: batch
- perturber/initializer: constant
- perturber/gradient_modifier: sign
- perturber/projector: mask_range
- callbacks: [progress_bar, image_visualizer]
- objective: zero_ap
- gain: rcnn_training_loss
- composer: overlay
- enforcer: default
- enforcer/constraints@enforcer.rgb: [mask, pixel_range]

# Make a 5-step attack for the demonstration purpose.
optimizer:
lr: 55

max_iters: 5

perturber:
initializer:
constant: 127
49 changes: 48 additions & 1 deletion tests/test_enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest
import torch

from mart.attack.enforcer import ConstraintViolated, Integer, Lp, Mask, Range
from mart.attack.enforcer import ConstraintViolated, Enforcer, Integer, Lp, Mask, Range


def test_constraint_range():
Expand Down Expand Up @@ -67,3 +67,50 @@ def test_constraint_mask():
constraint(input + perturbation * mask, input=input, target=target)
with pytest.raises(ConstraintViolated):
constraint(input + perturbation, input=input, target=target)


def test_enforcer_non_modality():
enforcer = Enforcer(constraints={"range": Range(min=0, max=255)})

input = torch.tensor([0, 0, 0])
perturbation = torch.tensor([0, 128, 255])
input_adv = input + perturbation
target = None

# tensor input.
enforcer(input_adv, input=input, target=target)
# list of tensor input.
enforcer([input_adv], input=[input], target=[target])

perturbation = torch.tensor([0, -1, 255])
input_adv = input + perturbation

with pytest.raises(ConstraintViolated):
enforcer(input_adv, input=input, target=target)

with pytest.raises(ConstraintViolated):
enforcer([input_adv], input=[input], target=[target])
dxoigmn marked this conversation as resolved.
Show resolved Hide resolved


def test_enforcer_modality():
# Assume a rgb modality.
enforcer = Enforcer(rgb={"range": Range(min=0, max=255)})

input = torch.tensor([0, 0, 0])
perturbation = torch.tensor([0, 128, 255])
input_adv = input + perturbation
target = None

# Dictionary input.
enforcer({"rgb": input_adv}, input={"rgb": input}, target=target)
# List of dictionary input.
enforcer([{"rgb": input_adv}], input=[{"rgb": input}], target=[target])
dxoigmn marked this conversation as resolved.
Show resolved Hide resolved

perturbation = torch.tensor([0, -1, 255])
input_adv = input + perturbation

with pytest.raises(ConstraintViolated):
enforcer({"rgb": input_adv}, input={"rgb": input}, target=target)

with pytest.raises(ConstraintViolated):
enforcer([{"rgb": input_adv}], input=[{"rgb": input}], target=[target])