@@ -81,9 +81,10 @@ def test_clip_resnet50x4_image_load_and_forward(self) -> None:
8181 + " insufficient Torch version."
8282 )
8383 x = torch .zeros (1 , 3 , 288 , 288 )
84- model = clip_resnet50x4_image (pretrained = True )
84+ model = clip_resnet50x4_image (pretrained = True , use_attnpool = True )
8585 output = model (x )
8686 self .assertEqual (list (output .shape ), [1 , 640 ])
87+ self .assertTrue (model .use_attnpool )
8788
8889 def test_untrained_clip_resnet50x4_image_load_and_forward (self ) -> None :
8990 if version .parse (torch .__version__ ) <= version .parse ("1.6.0" ):
@@ -92,9 +93,10 @@ def test_untrained_clip_resnet50x4_image_load_and_forward(self) -> None:
9293 + " insufficient Torch version."
9394 )
9495 x = torch .zeros (1 , 3 , 288 , 288 )
95- model = clip_resnet50x4_image (pretrained = False )
96+ model = clip_resnet50x4_image (pretrained = False , use_attnpool = True )
9697 output = model (x )
9798 self .assertEqual (list (output .shape ), [1 , 640 ])
99+ self .assertTrue (model .use_attnpool )
98100
99101 def test_clip_resnet50x4_image_warning (self ) -> None :
100102 if version .parse (torch .__version__ ) <= version .parse ("1.6.0" ):
@@ -109,6 +111,30 @@ def test_clip_resnet50x4_image_warning(self) -> None:
109111 with self .assertWarns (UserWarning ):
110112 _ = model ._transform_input (x )
111113
114+ def test_clip_resnet50x4_image_use_attnpool_false (self ) -> None :
115+ if version .parse (torch .__version__ ) <= version .parse ("1.6.0" ):
116+ raise unittest .SkipTest (
117+ "Skipping basic pretrained CLIP ResNet 50x4 Image use_attnpool"
118+ + " forward due to insufficient Torch version."
119+ )
120+ x = torch .zeros (1 , 3 , 288 , 288 )
121+ model = clip_resnet50x4_image (pretrained = True , use_attnpool = False )
122+ output = model (x )
123+ self .assertEqual (list (output .shape ), [1 , 2560 , 9 , 9 ])
124+ self .assertFalse (model .use_attnpool )
125+
126+ def test_clip_resnet50x4_image_use_attnpool_false_size_128 (self ) -> None :
127+ if version .parse (torch .__version__ ) <= version .parse ("1.6.0" ):
128+ raise unittest .SkipTest (
129+ "Skipping basic pretrained CLIP ResNet 50x4 Image use_attnpool"
130+ + " forward with 128x128 input due to insufficient Torch version."
131+ )
132+ x = torch .zeros (1 , 3 , 128 , 128 )
133+ model = clip_resnet50x4_image (pretrained = True , use_attnpool = False )
134+ output = model (x )
135+ self .assertEqual (list (output .shape ), [1 , 2560 , 4 , 4 ])
136+ self .assertFalse (model .use_attnpool )
137+
112138 def test_clip_resnet50x4_image_forward_cuda (self ) -> None :
113139 if version .parse (torch .__version__ ) <= version .parse ("1.6.0" ):
114140 raise unittest .SkipTest (
@@ -121,11 +147,12 @@ def test_clip_resnet50x4_image_forward_cuda(self) -> None:
121147 + " not supporting CUDA."
122148 )
123149 x = torch .zeros (1 , 3 , 288 , 288 ).cuda ()
124- model = clip_resnet50x4_image (pretrained = True ).cuda ()
150+ model = clip_resnet50x4_image (pretrained = True , use_attnpool = True ).cuda ()
125151 output = model (x )
126152
127153 self .assertTrue (output .is_cuda )
128154 self .assertEqual (list (output .shape ), [1 , 640 ])
155+ self .assertTrue (model .use_attnpool )
129156
130157 def test_clip_resnet50x4_image_jit_module_no_redirected_relu (self ) -> None :
131158 if version .parse (torch .__version__ ) <= version .parse ("1.8.0" ):
@@ -135,11 +162,12 @@ def test_clip_resnet50x4_image_jit_module_no_redirected_relu(self) -> None:
135162 )
136163 x = torch .zeros (1 , 3 , 288 , 288 )
137164 model = clip_resnet50x4_image (
138- pretrained = True , replace_relus_with_redirectedrelu = False
165+ pretrained = True , replace_relus_with_redirectedrelu = False , use_attnpool = True
139166 )
140167 jit_model = torch .jit .script (model )
141168 output = jit_model (x )
142169 self .assertEqual (list (output .shape ), [1 , 640 ])
170+ self .assertTrue (model .use_attnpool )
143171
144172 def test_clip_resnet50x4_image_jit_module_with_redirected_relu (self ) -> None :
145173 if version .parse (torch .__version__ ) <= version .parse ("1.8.0" ):
@@ -149,8 +177,9 @@ def test_clip_resnet50x4_image_jit_module_with_redirected_relu(self) -> None:
149177 )
150178 x = torch .zeros (1 , 3 , 288 , 288 )
151179 model = clip_resnet50x4_image (
152- pretrained = True , replace_relus_with_redirectedrelu = True
180+ pretrained = True , replace_relus_with_redirectedrelu = True , use_attnpool = True
153181 )
154182 jit_model = torch .jit .script (model )
155183 output = jit_model (x )
156184 self .assertEqual (list (output .shape ), [1 , 640 ])
185+ self .assertTrue (model .use_attnpool )
0 commit comments