55
66import torch
77
8- from captum .attr ._core .dataloader_attr import DataloaderAttribution , InputRole
8+ from captum .attr ._core .dataloader_attr import DataLoaderAttribution , InputRole
99from captum .attr ._core .feature_ablation import FeatureAblation
1010from parameterized import parameterized
1111from tests .helpers .basic import (
@@ -75,13 +75,13 @@ class Test(BaseTest):
7575 )
7676 def test_dl_attr (self , forward ) -> None :
7777 fa = FeatureAblation (forward )
78- dl_fa = DataloaderAttribution (fa )
78+ dl_fa = DataLoaderAttribution (fa )
7979
8080 dataloader = DataLoader (mock_dataset , batch_size = 2 )
8181
8282 dl_attributions = dl_fa .attribute (dataloader )
8383
84- # default reduce of DataloaderAttribution works the same as concat all batches
84+ # default reduce of DataLoaderAttribution works the same as concat all batches
8585 attr_list = []
8686 for batch in dataloader :
8787 batch_attr = fa .attribute (tuple (batch ))
@@ -109,13 +109,13 @@ def test_dl_attr_with_mask(self, forward) -> None:
109109 )
110110
111111 fa = FeatureAblation (forward )
112- dl_fa = DataloaderAttribution (fa )
112+ dl_fa = DataLoaderAttribution (fa )
113113
114114 dataloader = DataLoader (mock_dataset , batch_size = 2 )
115115
116116 dl_attributions = dl_fa .attribute (dataloader , feature_mask = masks )
117117
118- # default reduce of DataloaderAttribution works the same as concat all batches
118+ # default reduce of DataLoaderAttribution works the same as concat all batches
119119 attr_list = []
120120 for batch in dataloader :
121121 batch_attr = fa .attribute (tuple (batch ), feature_mask = masks )
@@ -141,13 +141,13 @@ def test_dl_attr_with_baseline(self, forward) -> None:
141141 )
142142
143143 fa = FeatureAblation (forward )
144- dl_fa = DataloaderAttribution (fa )
144+ dl_fa = DataLoaderAttribution (fa )
145145
146146 dataloader = DataLoader (mock_dataset , batch_size = 2 )
147147
148148 dl_attributions = dl_fa .attribute (dataloader , baselines = baselines )
149149
150- # default reduce of DataloaderAttribution works the same as concat all batches
150+ # default reduce of DataLoaderAttribution works the same as concat all batches
151151 attr_list = []
152152 for batch in dataloader :
153153 batch_attr = fa .attribute (tuple (batch ), baselines = baselines )
@@ -188,7 +188,7 @@ def to_metric(accum):
188188 )
189189
190190 fa = FeatureAblation (forward )
191- dl_fa = DataloaderAttribution (fa )
191+ dl_fa = DataLoaderAttribution (fa )
192192
193193 batch_size = 2
194194 dataloader = DataLoader (mock_dataset , batch_size = batch_size )
@@ -243,7 +243,7 @@ def forward(*forward_inputs):
243243 return sum_forward (* forward_inputs )
244244
245245 fa = FeatureAblation (forward )
246- dl_fa = DataloaderAttribution (fa )
246+ dl_fa = DataLoaderAttribution (fa )
247247
248248 batch_size = 2
249249 dataloader = DataLoader (mock_dataset , batch_size = batch_size )
@@ -257,7 +257,7 @@ def forward(*forward_inputs):
257257 # only inputs needs
258258 self .assertEqual (len (dl_attributions ), n_attr_inputs )
259259
260- # default reduce of DataloaderAttribution works the same as concat all batches
260+ # default reduce of DataLoaderAttribution works the same as concat all batches
261261 attr_list = []
262262 for batch in dataloader :
263263 attr_inputs = tuple (
@@ -283,7 +283,7 @@ def forward(*forward_inputs):
283283 def test_dl_attr_not_return_input_shape (self ) -> None :
284284 forward = sum_forward
285285 fa = FeatureAblation (forward )
286- dl_fa = DataloaderAttribution (fa )
286+ dl_fa = DataLoaderAttribution (fa )
287287
288288 dataloader = DataLoader (mock_dataset , batch_size = 2 )
289289
@@ -295,7 +295,7 @@ def test_dl_attr_not_return_input_shape(self) -> None:
295295 dl_attribution = cast (Tensor , dl_attribution )
296296 self .assertEqual (dl_attribution .shape , expected_attr_shape )
297297
298- # default reduce of DataloaderAttribution works the same as concat all batches
298+ # default reduce of DataLoaderAttribution works the same as concat all batches
299299 attr_list = []
300300 for batch in dataloader :
301301 batch_attr = fa .attribute (tuple (batch ))
@@ -321,7 +321,7 @@ def test_dl_attr_with_mask_not_return_input_shape(self) -> None:
321321 )
322322
323323 fa = FeatureAblation (forward )
324- dl_fa = DataloaderAttribution (fa )
324+ dl_fa = DataLoaderAttribution (fa )
325325
326326 dataloader = DataLoader (mock_dataset , batch_size = 2 )
327327
@@ -340,7 +340,7 @@ def test_dl_attr_with_perturb_per_pass(self, perturb_per_pass) -> None:
340340 forward = sum_forward
341341
342342 fa = FeatureAblation (forward )
343- dl_fa = DataloaderAttribution (fa )
343+ dl_fa = DataLoaderAttribution (fa )
344344
345345 mock_dl_iter = Mock (wraps = DataLoader .__iter__ )
346346
@@ -360,7 +360,7 @@ def test_dl_attr_with_perturb_per_pass(self, perturb_per_pass) -> None:
360360 math .ceil (n_features / perturb_per_pass ) + n_iter_overhead ,
361361 )
362362
363- # default reduce of DataloaderAttribution works the same as concat all batches
363+ # default reduce of DataLoaderAttribution works the same as concat all batches
364364 attr_list = []
365365 for batch in dataloader :
366366 batch_attr = fa .attribute (tuple (batch ))
0 commit comments