 """
 from __future__ import annotations
 
+import math
+from typing import TYPE_CHECKING, Callable
+
 import numpy as np
-from typing import TYPE_CHECKING
+
 if TYPE_CHECKING:
+    import torch
     from art.utils import CLASSIFIER_NEURALNETWORK_TYPE
 
 
 from art.defences.detector.evasion.evasion_detector import EvasionDetector
 
-class BeyondDetector(EvasionDetector):
+
+class BeyondDetectorPyTorch(EvasionDetector):
     """
     BEYOND detector for adversarial samples detection.
     This detector uses a combination of SSL and target model predictions to detect adversarial examples.
-
+
     | Paper link: https://openreview.net/pdf?id=S4LqI6CcJ3
     """
-
+
     defence_params = ["target_model", "ssl_model", "augmentations", "aug_num", "alpha", "K", "percentile"]
 
-    def __init__(self,
-        target_model: "CLASSIFIER_NEURALNETWORK_TYPE",
-        ssl_model: "CLASSIFIER_NEURALNETWORK_TYPE",
+    def __init__(
+        self,
+        target_classifier: "CLASSIFIER_NEURALNETWORK_TYPE",
+        ssl_classifier: "CLASSIFIER_NEURALNETWORK_TYPE",
         augmentations: Callable | None,
-        aug_num: int = 50,
-        alpha: float = 0.8,
-        K:int = 20,
-        percentile:int = 5) -> None:
+        aug_num: int = 50,
+        alpha: float = 0.8,
+        K: int = 20,
+        percentile: int = 5,
+    ) -> None:
         """
         Initialize the BEYOND detector.
 
-        :param target_model: The target model to be protected
-        :param ssl_model: The self-supervised learning model used for feature extraction
-        :param augmentation: data augmentations for generating neighborhoods
+        :param target_classifier: The target model to be protected
+        :param ssl_classifier: The self-supervised learning model used for feature extraction
+        :param augmentations: data augmentations for generating neighborhoods
         :param aug_num: Number of augmentations to apply to each sample (default: 50)
         :param alpha: Weight factor for combining label and representation similarities (default: 0.8)
         :param K: Number of top similarities to consider (default: 20)
         :param percentile: Percentile of the clean similarity scores used to set the detection threshold
         """
+        import torch
+
         super().__init__()
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-        self.target_model = target_model.to(self.device)
-        self.ssl_model = ssl_model.to(self.device)
+        self.target_model = target_classifier.model.to(self.device)
+        self.ssl_model = ssl_classifier.model.to(self.device)
         self.aug_num = aug_num
         self.alpha = alpha
         self.K = K
 
-        self.backbone = ssl_model.backbone
-        self.classifier = ssl_model.classifier
-        self.projector = ssl_model.projector
+        self.backbone = self.ssl_model.backbone
+        self.model_classifier = self.ssl_model.classifier
+        self.projector = self.ssl_model.projector
 
         self.img_augmentations = augmentations
 
-        self.percentile = percentile # determinate the threshold
-        self.threshold = None
+        self.percentile = percentile  # determine the threshold
+        self.threshold: float | None = None
+
+    def _multi_transform(self, img: "torch.Tensor") -> "torch.Tensor":
+        import torch
 
-
-
-    def _multi_transform(self, img: torch.Tensor) -> torch.Tensor:
         return torch.stack([self.img_augmentations(img) for _ in range(self.aug_num)], dim=1)
 
-    def _get_metrics(self, x: np.ndarray, batch_size: int = 128) -> tuple[dict, np.ndarray]:
+    def _get_metrics(self, x: np.ndarray, batch_size: int = 128) -> np.ndarray:
         """
         Calculate similarities combining label consistency and representation similarity for the given samples
 
         :param x: Input samples
         :param batch_size: Batch size for processing
         :return: Sorted similarity scores for the given samples
         """
+        import torch
+        import torch.nn.functional as F
+
         samples = torch.from_numpy(x).to(self.device)
-
+
         self.target_model.eval()
         self.backbone.eval()
-        self.classifier.eval()
+        self.model_classifier.eval()
         self.projector.eval()
 
         number_batch = int(math.ceil(len(samples) / batch_size))
-
+
         similarities = []
 
         with torch.no_grad():
@@ -113,23 +125,31 @@ def _get_metrics(self, x: np.ndarray, batch_size: int = 128) -> tuple[dict, np.n
                 ssl_backbone_out = self.backbone(batch_samples)
 
                 ssl_repre = self.projector(ssl_backbone_out)
-                ssl_pred = self.classifier(ssl_backbone_out)
+                ssl_pred = self.model_classifier(ssl_backbone_out)
                 ssl_label = torch.max(ssl_pred, -1)[1]
 
                 aug_backbone_out = self.backbone(trans_images.reshape(-1, c, h, w))
                 aug_repre = self.projector(aug_backbone_out)
-                aug_pred = self.classifier(aug_backbone_out)
+                aug_pred = self.model_classifier(aug_backbone_out)
                 aug_pred = aug_pred.reshape(b, self.aug_num, -1)
 
-                sim_repre = F.cosine_similarity(ssl_repre.unsqueeze(dim=1), aug_repre.reshape(b, self.aug_num, -1), dim=2)
-                sim_preds = F.cosine_similarity(F.one_hot(torch.argmax(ssl_label, dim=1), num_classes=ssl_pred.shape[-1]).unsqueeze(dim=1), aug_pred, dim=2)
+                sim_repre = F.cosine_similarity(
+                    ssl_repre.unsqueeze(dim=1), aug_repre.reshape(b, self.aug_num, -1), dim=2
+                )
+
+                sim_preds = F.cosine_similarity(
+                    F.one_hot(ssl_label, num_classes=ssl_pred.shape[-1]).unsqueeze(dim=1),
+                    aug_pred,
+                    dim=2,
+                )
 
-                similarities.append((self.alpha * sim_preds + (1 - self.alpha)*sim_repre).sort(descending=True)[0].cpu().numpy())
+                similarities.append(
+                    (self.alpha * sim_preds + (1 - self.alpha) * sim_repre).sort(descending=True)[0].cpu().numpy()
+                )
 
         similarities = np.concatenate(similarities, axis=0)
-
-        return similarities
 
+        return similarities
 
     def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 20, **kwargs) -> None:
         """
@@ -140,26 +160,26 @@ def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: in
         :param batch_size: Batch size for processing
         :param nb_epochs: Number of training epochs (not used in this method)
         """
-        k_minus_one_metrics = clean_metrics[:, self.K - 1]
-
-        self.threshold = np.percentile(k_minus_one_metrics, self.threshold)
+        clean_metrics = self._get_metrics(x=x, batch_size=batch_size)
+        k_minus_one_metrics = clean_metrics[:, self.K - 1]
+        self.threshold = np.percentile(k_minus_one_metrics, q=self.percentile)
 
     def detect(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> tuple[dict, np.ndarray]:
         """
         Detect whether given samples are adversarial
-
+
         :param x: Input samples
         :param batch_size: Batch size for processing
         :return: (report, is_adversarial):
-            where report containing detection results
+            where report contains the per-sample detection scores (the K-th largest combined similarity)
             where is_adversarial is a boolean list indicating whether samples are adversarial or not
         """
         if self.threshold is None:
             raise ValueError("Detector has not been fitted. Call fit() before detect().")
-
+
         similarities = self._get_metrics(x, batch_size)
-
-        report = similarities[:, self.K - 1]
+
+        report = similarities[:, self.K - 1]
         is_adversarial = report < self.threshold
-
+
         return report, is_adversarial
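
For orientation, here is a small standalone illustration (made-up numbers, not part of the commit) of how the percentile computed in fit() drives the decision in detect():

# Illustration only: fit() stores the q-th percentile of the clean samples'
# K-th largest combined similarity; detect() flags anything that falls below it.
import numpy as np

# Hypothetical K-th largest combined similarity for five clean calibration samples
# (column K - 1 of the descending-sorted similarity matrix).
clean_kth = np.array([0.91, 0.88, 0.93, 0.86, 0.90])

threshold = np.percentile(clean_kth, q=5)  # fit(): 5th percentile of the clean scores

# detect(): two hypothetical test samples; the second scores well below the threshold.
test_kth = np.array([0.89, 0.52])
is_adversarial = test_kth < threshold
print(round(float(threshold), 3), is_adversarial)  # 0.864 [False  True]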
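
And a minimal end-to-end usage sketch of the renamed class. Everything below is assumed for illustration: the import path, the toy networks, the torchvision augmentations, and the PyTorchClassifier wrapper arguments. The only requirements taken from the diff are that the SSL network exposes backbone, classifier and projector attributes, and that both models arrive wrapped as ART classifiers so that .model is available.

# Hedged usage sketch (not from the commit): wiring up BeyondDetectorPyTorch.
import numpy as np
import torch
import torchvision.transforms as T

from art.estimators.classification import PyTorchClassifier
from art.defences.detector.evasion import BeyondDetectorPyTorch  # assumed import path


class TinySSL(torch.nn.Module):
    """Toy stand-in exposing the backbone/classifier/projector attributes the detector reads."""

    def __init__(self, nb_classes: int = 10) -> None:
        super().__init__()
        self.backbone = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 64))
        self.classifier = torch.nn.Linear(64, nb_classes)
        self.projector = torch.nn.Linear(64, 32)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.backbone(x))


def wrap(net: torch.nn.Module) -> PyTorchClassifier:
    # Placeholder loss / input_shape / nb_classes; real values depend on the task.
    return PyTorchClassifier(
        model=net, loss=torch.nn.CrossEntropyLoss(), input_shape=(3, 32, 32), nb_classes=10
    )


target_net = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 10))
augmentations = T.Compose([T.RandomResizedCrop(32, scale=(0.2, 1.0)), T.RandomHorizontalFlip()])

detector = BeyondDetectorPyTorch(
    target_classifier=wrap(target_net),
    ssl_classifier=wrap(TinySSL()),
    augmentations=augmentations,
    aug_num=50,
    alpha=0.8,
    K=20,
    percentile=5,
)

x_clean = np.random.rand(64, 3, 32, 32).astype(np.float32)
y_clean = np.eye(10)[np.random.randint(0, 10, size=64)].astype(np.float32)

detector.fit(x_clean, y_clean)                     # calibrates detector.threshold on clean data
report, is_adversarial = detector.detect(x_clean)  # flags scores below the calibrated threshold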