@@ -33,6 +33,17 @@ def test_simple_input_conv(self) -> None:
3333            net , net .conv1 , inp , [[[[11.25 , 13.5 ], [20.25 , 22.5 ]]]]
3434        )
3535
36+     def  test_simple_input_conv_split_channels (self ) ->  None :
37+         net  =  BasicModel_ConvNet_One_Conv ()
38+         inp  =  torch .arange (16 ).view (1 , 1 , 4 , 4 ).float ()
39+         expected_result  =  [[[[- 3.7500 , 3.0000 ],
40+                              [23.2500 , 30.0000 ]],
41+                             [[15.0000 , 10.5000 ],
42+                              [- 3.0000 , - 7.5000 ]]]]
43+         self ._grad_cam_test_assert (
44+             net , net .conv1 , inp , expected_activation = expected_result , split_channels = True 
45+         )
46+ 
3647    def  test_simple_input_conv_no_grad (self ) ->  None :
3748        net  =  BasicModel_ConvNet_One_Conv ()
3849
@@ -100,6 +111,7 @@ def _grad_cam_test_assert(
100111        additional_input : Any  =  None ,
101112        attribute_to_layer_input : bool  =  False ,
102113        relu_attributions : bool  =  False ,
114+         split_channels : bool  =  False ,
103115    ):
104116        layer_gc  =  LayerGradCam (model , target_layer )
105117        self .assertFalse (layer_gc .multiplies_by_inputs )
@@ -109,6 +121,7 @@ def _grad_cam_test_assert(
109121            additional_forward_args = additional_input ,
110122            attribute_to_layer_input = attribute_to_layer_input ,
111123            relu_attributions = relu_attributions ,
124+             split_channels = split_channels ,
112125        )
113126        assertTensorTuplesAlmostEqual (
114127            self , attributions , expected_activation , delta = 0.01 
0 commit comments