44import unittest .mock
55from datetime import datetime
66from os import path
7+ from urllib .error import HTTPError
78from urllib .parse import urlparse
89from urllib .request import urlopen , Request
910
@@ -86,25 +87,26 @@ def retry(fn, times=1, wait=5.0):
8687 )
8788
8889
89- def assert_server_response_ok (response , url = None ):
90- msg = f"The server returned status code { response .code } "
91- if url is not None :
92- msg += f"for the the URL { url } "
93- assert 200 <= response .code < 300 , msg
90+ @contextlib .contextmanager
91+ def assert_server_response_ok ():
92+ try :
93+ yield
94+ except HTTPError as error :
95+ raise AssertionError (f"The server returned { error .code } : { error .reason } ." ) from error
9496
9597
9698def assert_url_is_accessible (url ):
9799 request = Request (url , headers = dict (method = "HEAD" ))
98- response = urlopen ( request )
99- assert_server_response_ok ( response , url )
100+ with assert_server_response_ok ():
101+ urlopen ( request )
100102
101103
102104def assert_file_downloads_correctly (url , md5 ):
103105 with get_tmp_dir () as root :
104106 file = path .join (root , path .basename (url ))
105- with urlopen ( url ) as response , open ( file , "wb" ) as fh :
106- assert_server_response_ok ( response , url )
107- fh .write (response .read ())
107+ with assert_server_response_ok () :
108+ with urlopen ( url ) as response , open ( file , "wb" ) as fh :
109+ fh .write (response .read ())
108110
109111 assert check_integrity (file , md5 = md5 ), "The MD5 checksums mismatch"
110112
@@ -125,6 +127,16 @@ def make_download_configs(urls_and_md5s, name=None):
125127 ]
126128
127129
130+ def collect_download_configs (dataset_loader , name ):
131+ try :
132+ with log_download_attempts () as urls_and_md5s :
133+ dataset_loader ()
134+ except Exception :
135+ pass
136+
137+ return make_download_configs (urls_and_md5s , name )
138+
139+
128140def places365 ():
129141 with log_download_attempts (patch = False ) as urls_and_md5s :
130142 for split , small in itertools .product (("train-standard" , "train-challenge" , "val" ), (False , True )):
@@ -137,23 +149,19 @@ def places365():
137149
138150
139151def caltech101 ():
140- try :
141- with log_download_attempts () as urls_and_md5s :
142- datasets .Caltech101 ("." , download = True )
143- except Exception :
144- pass
145-
146- return make_download_configs (urls_and_md5s , "Caltech101" )
152+ return collect_download_configs (lambda : datasets .Caltech101 ("." , download = True ), "Caltech101" )
147153
148154
149155def caltech256 ():
150- try :
151- with log_download_attempts () as urls_and_md5s :
152- datasets .Caltech256 ("." , download = True )
153- except Exception :
154- pass
156+ return collect_download_configs (lambda : datasets .Caltech256 ("." , download = True ), "Caltech256" )
157+
158+
159+ def cifar10 ():
160+ return collect_download_configs (lambda : datasets .CIFAR10 ("." , download = True ), "CIFAR10" )
161+
155162
156- return make_download_configs (urls_and_md5s , "Caltech256" )
163+ def cifar100 ():
164+ return collect_download_configs (lambda : datasets .CIFAR10 ("." , download = True ), "CIFAR100" )
157165
158166
159167def make_parametrize_kwargs (download_configs ):
@@ -166,7 +174,9 @@ def make_parametrize_kwargs(download_configs):
166174 return dict (argnames = ("url" , "md5" ), argvalues = argvalues , ids = ids )
167175
168176
169- @pytest .mark .parametrize (** make_parametrize_kwargs (itertools .chain (places365 (), caltech101 (), caltech256 ())))
177+ @pytest .mark .parametrize (
178+ ** make_parametrize_kwargs (itertools .chain (places365 (), caltech101 (), caltech256 (), cifar10 (), cifar100 ()))
179+ )
170180def test_url_is_accessible (url , md5 ):
171181 retry (lambda : assert_url_is_accessible (url ))
172182
0 commit comments