Remove BatchComposer in favor of batch-aware Composer (#121)
dxoigmn authored Mar 29, 2023
1 parent ac61b48 commit c36c71d
Showing 4 changed files with 20 additions and 54 deletions.
62 changes: 19 additions & 43 deletions mart/attack/composer.py
@@ -11,72 +11,48 @@
 
 import torch
 
-__all__ = ["BatchComposer"]
-
-
-class Composer(torch.nn.Module, abc.ABC):
-    @abc.abstractclassmethod
-    def forward(
-        self,
-        perturbation: torch.Tensor | tuple,
-        *,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
-    ) -> torch.Tensor | tuple:
-        raise NotImplementedError
-
-
-class BatchComposer(Composer):
-    def __init__(self, composer: Composer):
-        super().__init__()
-
-        self.composer = composer
-
-    def forward(
+class Composer(abc.ABC):
+    def __call__(
         self,
         perturbation: torch.Tensor | tuple,
         *,
         input: torch.Tensor | tuple,
         target: torch.Tensor | dict[str, Any] | tuple,
         **kwargs,
     ) -> torch.Tensor | tuple:
-        output = []
-
-        for input_i, target_i, perturbation_i in zip(input, target, perturbation):
-            output_i = self.composer(perturbation_i, input=input_i, target=target_i, **kwargs)
-            output.append(output_i)
-
-        if isinstance(input, torch.Tensor):
-            output = torch.stack(output)
+        if isinstance(perturbation, tuple):
+            input_adv = tuple(
+                self.compose(perturbation_i, input=input_i, target=target_i)
+                for perturbation_i, input_i, target_i in zip(perturbation, input, target)
+            )
         else:
-            output = tuple(output)
-
-        return output
+            input_adv = self.compose(perturbation, input=input, target=target)
+
+        return input_adv
 
-class Additive(Composer):
-    """We assume an adversary adds perturbation to the input."""
-
-    def forward(
+    @abc.abstractmethod
+    def compose(
         self,
         perturbation: torch.Tensor,
         *,
         input: torch.Tensor,
         target: torch.Tensor | dict[str, Any],
     ) -> torch.Tensor:
+        raise NotImplementedError
+
+
+class Additive(Composer):
+    """We assume an adversary adds perturbation to the input."""
+
+    def compose(self, perturbation, *, input, target):
         return input + perturbation
 
 
 class Overlay(Composer):
     """We assume an adversary overlays a patch to the input."""
 
-    def forward(
-        self,
-        perturbation: torch.Tensor,
-        *,
-        input: torch.Tensor,
-        target: torch.Tensor | dict[str, Any],
-    ) -> torch.Tensor:
+    def compose(self, perturbation, *, input, target):
         # True is mutable, False is immutable.
         mask = target["perturbable_mask"]
 
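For reference, a minimal usage sketch (not part of this commit) of the batch-aware Composer above; the tensor shapes and the empty target dicts are illustrative assumptions.

import torch

from mart.attack.composer import Additive

composer = Additive()

# Batched tensor: compose() runs once over the whole batch.
input = torch.zeros(2, 3, 32, 32)
perturbation = 0.1 * torch.ones_like(input)
input_adv = composer(perturbation, input=input, target=None)

# Tuple of per-sample tensors (e.g. variable-size detection images):
# __call__ now iterates over the samples itself, which is what BatchComposer used to do.
inputs = (torch.zeros(3, 32, 32), torch.zeros(3, 64, 48))
perturbations = tuple(0.1 * torch.ones_like(x) for x in inputs)
targets = ({}, {})  # Additive.compose ignores target
input_advs = composer(perturbations, input=inputs, target=targets)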
5 changes: 0 additions & 5 deletions mart/configs/attack/composer/batch_additive.yaml

This file was deleted.

5 changes: 0 additions & 5 deletions mart/configs/attack/composer/batch_overlay.yaml

This file was deleted.

2 changes: 1 addition & 1 deletion mart/configs/attack/object_detection_mask_adversary.yaml
@@ -7,7 +7,7 @@ defaults:
   - callbacks: [progress_bar, image_visualizer]
   - objective: zero_ap
   - gain: rcnn_training_loss
-  - composer: batch_overlay
+  - composer: overlay
   - enforcer: batch
   - enforcer/constraints: [mask, pixel_range]

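A hypothetical sketch of how the renamed "composer: overlay" default might be instantiated through Hydra; the _target_ path is an assumption inferred from mart/attack/composer.py, not read from the config files in this commit.

from hydra.utils import instantiate
from omegaconf import OmegaConf

# Assumed shape of the overlay composer config selected by the defaults list above.
cfg = OmegaConf.create({"_target_": "mart.attack.composer.Overlay"})
composer = instantiate(cfg)  # an Overlay instance, callable on tensors or tuples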
