Skip to content

Commit 5ec99dc

Browse files
committed
Finalize integration of BEYOND detector
Signed-off-by: Beat Buesser <beat.buesser@ibm.com>
1 parent 94c6ced commit 5ec99dc

File tree

6 files changed

+175
-132
lines changed

6 files changed

+175
-132
lines changed

art/defences/detector/evasion/__init__.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,4 @@
66
from art.defences.detector.evasion.binary_input_detector import BinaryInputDetector
77
from art.defences.detector.evasion.binary_activation_detector import BinaryActivationDetector
88
from art.defences.detector.evasion.subsetscanning.detector import SubsetScanningDetector
9-
from art.defences.detector.evasion.beyond_detector import BeyondDetector
10-
9+
from art.defences.detector.evasion.beyond_detector import BeyondDetectorPyTorch

art/defences/detector/evasion/beyond_detector.py

+64-44
Original file line numberDiff line numberDiff line change
@@ -22,83 +22,95 @@
2222
"""
2323
from __future__ import annotations
2424

25+
import math
26+
from typing import TYPE_CHECKING, Callable
27+
2528
import numpy as np
26-
from typing import TYPE_CHECKING
29+
2730
if TYPE_CHECKING:
31+
import torch
2832
from art.utils import CLASSIFIER_NEURALNETWORK_TYPE
2933

3034

3135
from art.defences.detector.evasion.evasion_detector import EvasionDetector
3236

33-
class BeyondDetector(EvasionDetector):
37+
38+
class BeyondDetectorPyTorch(EvasionDetector):
3439
"""
3540
BEYOND detector for adversarial samples detection.
3641
This detector uses a combination of SSL and target model predictions to detect adversarial examples.
37-
42+
3843
| Paper link: https://openreview.net/pdf?id=S4LqI6CcJ3
3944
"""
40-
45+
4146
defence_params = ["target_model", "ssl_model", "augmentations", "aug_num", "alpha", "K", "percentile"]
4247

43-
def __init__(self,
44-
target_model: "CLASSIFIER_NEURALNETWORK_TYPE",
45-
ssl_model: "CLASSIFIER_NEURALNETWORK_TYPE",
48+
def __init__(
49+
self,
50+
target_classifier: "CLASSIFIER_NEURALNETWORK_TYPE",
51+
ssl_classifier: "CLASSIFIER_NEURALNETWORK_TYPE",
4652
augmentations: Callable | None,
47-
aug_num: int=50,
48-
alpha: float=0.8,
49-
K:int=20,
50-
percentile:int=5) -> None:
53+
aug_num: int = 50,
54+
alpha: float = 0.8,
55+
K: int = 20,
56+
percentile: int = 5,
57+
) -> None:
5158
"""
5259
Initialize the BEYOND detector.
5360
54-
:param target_model: The target model to be protected
55-
:param ssl_model: The self-supervised learning model used for feature extraction
56-
:param augmentation: data augmentations for generating neighborhoods
61+
:param target_classifier: The target model to be protected
62+
:param ssl_classifier: The self-supervised learning model used for feature extraction
63+
:param augmentations: data augmentations for generating neighborhoods
5764
:param aug_num: Number of augmentations to apply to each sample (default: 50)
5865
:param alpha: Weight factor for combining label and representation similarities (default: 0.8)
5966
:param K: Number of top similarities to consider (default: 20)
6067
:param percentile: using to calculate the threshold
6168
"""
69+
import torch
70+
6271
super().__init__()
6372
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
6473

65-
self.target_model = target_model.to(self.device)
66-
self.ssl_model = ssl_model.to(self.device)
74+
self.target_model = target_classifier.model.to(self.device)
75+
self.ssl_model = ssl_classifier.model.to(self.device)
6776
self.aug_num = aug_num
6877
self.alpha = alpha
6978
self.K = K
7079

71-
self.backbone = ssl_model.backbone
72-
self.classifier = ssl_model.classifier
73-
self.projector = ssl_model.projector
80+
self.backbone = self.ssl_model.backbone
81+
self.model_classifier = self.ssl_model.classifier
82+
self.projector = self.ssl_model.projector
7483

7584
self.img_augmentations = augmentations
7685

77-
self.percentile = percentile # determinate the threshold
78-
self.threshold = None
86+
self.percentile = percentile # determine the threshold
87+
self.threshold: float | None = None
88+
89+
def _multi_transform(self, img: "torch.Tensor") -> "torch.Tensor":
90+
import torch
7991

80-
81-
82-
def _multi_transform(self, img: torch.Tensor) -> torch.Tensor:
8392
return torch.stack([self.img_augmentations(img) for _ in range(self.aug_num)], dim=1)
8493

85-
def _get_metrics(self, x: np.ndarray, batch_size: int = 128) -> tuple[dict, np.ndarray]:
94+
def _get_metrics(self, x: np.ndarray, batch_size: int = 128) -> np.ndarray:
8695
"""
8796
Calculate similarities that combining label consistency and representation similarity for given samples
8897
8998
:param x: Input samples
9099
:param batch_size: Batch size for processing
91100
:return: A report similarities
92101
"""
102+
import torch
103+
import torch.nn.functional as F
104+
93105
samples = torch.from_numpy(x).to(self.device)
94-
106+
95107
self.target_model.eval()
96108
self.backbone.eval()
97-
self.classifier.eval()
109+
self.model_classifier.eval()
98110
self.projector.eval()
99111

100112
number_batch = int(math.ceil(len(samples) / batch_size))
101-
113+
102114
similarities = []
103115

104116
with torch.no_grad():
@@ -113,23 +125,31 @@ def _get_metrics(self, x: np.ndarray, batch_size: int = 128) -> tuple[dict, np.n
113125
ssl_backbone_out = self.backbone(batch_samples)
114126

115127
ssl_repre = self.projector(ssl_backbone_out)
116-
ssl_pred = self.classifier(ssl_backbone_out)
128+
ssl_pred = self.model_classifier(ssl_backbone_out)
117129
ssl_label = torch.max(ssl_pred, -1)[1]
118130

119131
aug_backbone_out = self.backbone(trans_images.reshape(-1, c, h, w))
120132
aug_repre = self.projector(aug_backbone_out)
121-
aug_pred = self.classifier(aug_backbone_out)
133+
aug_pred = self.model_classifier(aug_backbone_out)
122134
aug_pred = aug_pred.reshape(b, self.aug_num, -1)
123135

124-
sim_repre = F.cosine_similarity(ssl_repre.unsqueeze(dim=1), aug_repre.reshape(b, self.aug_num, -1), dim=2)
125-
sim_preds = F.cosine_similarity(F.one_hot(torch.argmax(ssl_label, dim=1), num_classes=ssl_pred.shape[-1]).unsqueeze(dim=1), aug_pred, dim=2)
136+
sim_repre = F.cosine_similarity(
137+
ssl_repre.unsqueeze(dim=1), aug_repre.reshape(b, self.aug_num, -1), dim=2
138+
)
139+
140+
sim_preds = F.cosine_similarity(
141+
F.one_hot(ssl_label, num_classes=ssl_pred.shape[-1]).unsqueeze(dim=1),
142+
aug_pred,
143+
dim=2,
144+
)
126145

127-
similarities.append((self.alpha * sim_preds + (1-self.alpha)*sim_repre).sort(descending=True)[0].cpu().numpy())
146+
similarities.append(
147+
(self.alpha * sim_preds + (1 - self.alpha) * sim_repre).sort(descending=True)[0].cpu().numpy()
148+
)
128149

129150
similarities = np.concatenate(similarities, axis=0)
130-
131-
return similarities
132151

152+
return similarities
133153

134154
def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 20, **kwargs) -> None:
135155
"""
@@ -140,26 +160,26 @@ def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: in
140160
:param batch_size: Batch size for processing
141161
:param nb_epochs: Number of training epochs (not used in this method)
142162
"""
143-
k_minus_one_metrics = clean_metrics[:, self.K-1]
144-
145-
self.threshold = np.percentile(k_minus_one_metrics, self.threshold)
163+
clean_metrics = self._get_metrics(x=x, batch_size=batch_size)
164+
k_minus_one_metrics = clean_metrics[:, self.K - 1]
165+
self.threshold = np.percentile(k_minus_one_metrics, q=self.percentile)
146166

147167
def detect(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> tuple[dict, np.ndarray]:
148168
"""
149169
Detect whether given samples are adversarial
150-
170+
151171
:param x: Input samples
152172
:param batch_size: Batch size for processing
153173
:return: (report, is_adversarial):
154-
where report containing detection results
174+
where report containing detection results
155175
where is_adversarial is a boolean list indicating whether samples are adversarial or not
156176
"""
157177
if self.threshold is None:
158178
raise ValueError("Detector has not been fitted. Call fit() before detect().")
159-
179+
160180
similarities = self._get_metrics(x, batch_size)
161-
162-
report = similarities[:, self.K-1]
181+
182+
report = similarities[:, self.K - 1]
163183
is_adversarial = report < self.threshold
164-
184+
165185
return report, is_adversarial

run_tests.sh

+4
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,10 @@ else
146146
"tests/defences/test_rounded.py" \
147147
"tests/defences/test_thermometer_encoding.py" \
148148
"tests/defences/test_variance_minimization.py" \
149+
"tests/defences/detector/evasion/test_beyond_detector.py" \
150+
"tests/defences/detector/evasion/test_binary_activation_detector.py" \
151+
"tests/defences/detector/evasion/test_binary_input_detector.py" \
152+
"tests/defences/detector/evasion/test_subsetscanning_detector.py" \
149153
"tests/defences/detector/poison/test_activation_defence.py" \
150154
"tests/defences/detector/poison/test_clustering_analyzer.py" \
151155
"tests/defences/detector/poison/test_ground_truth_evaluator.py" \

0 commit comments

Comments
 (0)