55import argparse
66import torch
77import platform
8+ import importlib
9+ import subprocess
810
911gpu_arch_ver = os .getenv ("GPU_ARCH_VER" )
1012gpu_arch_type = os .getenv ("GPU_ARCH_TYPE" )
1416SCRIPT_DIR = Path (__file__ ).parent
1517NIGHTLY_ALLOWED_DELTA = 3
1618
19+ MODULES = [
20+ {
21+ "name" : "torchvision" ,
22+ "repo" : "https://github.com/pytorch/vision.git" ,
23+ "smoke_test" : "python ./vision/test/smoke_test.py" ,
24+ "extension" : "extension" ,
25+ },
26+ {
27+ "name" : "torchaudio" ,
28+ "repo" : "https://github.com/pytorch/audio.git" ,
29+ "smoke_test" : "python ./audio/test/smoke_test/smoke_test.py --no-ffmpeg" ,
30+ "extension" : "_extension" ,
31+ },
32+ ]
33+
1734def check_nightly_binaries_date (package : str ) -> None :
1835 from datetime import datetime , timedelta
1936 format_dt = '%Y%m%d'
@@ -27,33 +44,16 @@ def check_nightly_binaries_date(package: str) -> None:
2744 )
2845
2946 if (package == "all" ):
30- import torchaudio
31- import torchvision
32- ta_str = torchaudio .__version__
33- tv_str = torchvision .__version__
34- date_ta_str = re .findall ("dev\d+" , torchaudio .__version__ )
35- date_tv_str = re .findall ("dev\d+" , torchvision .__version__ )
36- date_ta_delta = datetime .now () - datetime .strptime (date_ta_str [0 ][3 :], format_dt )
37- date_tv_delta = datetime .now () - datetime .strptime (date_tv_str [0 ][3 :], format_dt )
38-
39- # check that the above three lists are equal and none of them is empty
40- if date_ta_delta .days > NIGHTLY_ALLOWED_DELTA or date_tv_delta .days > NIGHTLY_ALLOWED_DELTA :
41- raise RuntimeError (
42- f"Expected torchaudio, torchvision to be less then { NIGHTLY_ALLOWED_DELTA } days. But they are from { date_ta_str } , { date_tv_str } respectively"
43- )
44-
45- def check_cuda_version (version : str , dlibary : str ):
46- version = torch .ops .torchaudio .cuda_version ()
47- if version is not None and torch .version .cuda is not None :
48- version_str = str (version )
49- ta_version = f"{ version_str [:- 3 ]} .{ version_str [- 2 ]} "
50- t_version = torch .version .cuda .split ("." )
51- t_version = f"{ t_version [0 ]} .{ t_version [1 ]} "
52- if ta_version != t_version :
53- raise RuntimeError (
54- "Detected that PyTorch and {dlibary} were compiled with different CUDA versions. "
55- f"PyTorch has CUDA version { t_version } whereas { dlibary } has CUDA version { ta_version } . "
56- )
47+ for module in MODULES :
48+ imported_module = importlib .import_module (module ["name" ])
49+ module_version = imported_module .__version__
50+ date_m_str = re .findall ("dev\d+" , module_version )
51+ date_m_delta = datetime .now () - datetime .strptime (date_m_str [0 ][3 :], format_dt )
52+ print (f"Nightly date check for { module ['name' ]} version { module_version } " )
53+ if date_m_delta .days > NIGHTLY_ALLOWED_DELTA :
54+ raise RuntimeError (
55+ f"Expected { module ['name' ]} to be less then { NIGHTLY_ALLOWED_DELTA } days. But its { date_m_delta } "
56+ )
5757
5858def smoke_test_cuda (package : str ) -> None :
5959 if not torch .cuda .is_available () and is_cuda_system :
@@ -69,12 +69,15 @@ def smoke_test_cuda(package: str) -> None:
6969 print (f"cuDNN enabled? { torch .backends .cudnn .enabled } " )
7070
7171 if (package == 'all' and is_cuda_system ):
72- import torchaudio
73- import torchvision
74- print (f"torchvision cuda: { torch .ops .torchvision ._cuda_version ()} " )
75- print (f"torchaudio cuda: { torch .ops .torchaudio .cuda_version ()} " )
76- check_cuda_version (torch .ops .torchvision ._cuda_version (), "TorchVision" )
77- check_cuda_version (torch .ops .torchaudio .cuda_version (), "TorchAudio" )
72+ for module in MODULES :
73+ imported_module = importlib .import_module (module ["name" ])
74+ # TBD for vision move extension module to private so it will
75+ # be _extention. For audio add version return from the check
76+ if module ["extension" ] == "extension" :
77+ version = imported_module .extension ._check_cuda_version ()
78+ print (f"{ module ['name' ]} CUDA: { version } " )
79+ else :
80+ imported_module ._extension ._check_cuda_version ()
7881
7982
8083def smoke_test_conv2d () -> None :
@@ -97,67 +100,20 @@ def smoke_test_conv2d() -> None:
97100 out = conv (x )
98101
99102
100- def smoke_test_torchvision () -> None :
101- print (
102- "Is torchvision useable?" ,
103- all (
104- x is not None
105- for x in [torch .ops .image .decode_png , torch .ops .torchvision .roi_align ]
106- ),
107- )
108-
109-
110- def smoke_test_torchvision_read_decode () -> None :
111- from torchvision .io import read_image
112-
113- img_jpg = read_image (str (SCRIPT_DIR / "assets" / "rgb_pytorch.jpg" ))
114- if img_jpg .ndim != 3 or img_jpg .numel () < 100 :
115- raise RuntimeError (f"Unexpected shape of img_jpg: { img_jpg .shape } " )
116- img_png = read_image (str (SCRIPT_DIR / "assets" / "rgb_pytorch.png" ))
117- if img_png .ndim != 3 or img_png .numel () < 100 :
118- raise RuntimeError (f"Unexpected shape of img_png: { img_png .shape } " )
119-
120-
121- def smoke_test_torchvision_resnet50_classify (device : str = "cpu" ) -> None :
122- from torchvision .io import read_image
123- from torchvision .models import resnet50 , ResNet50_Weights
124-
125- img = read_image (str (SCRIPT_DIR / "assets" / "dog2.jpg" )).to (device )
126-
127- # Step 1: Initialize model with the best available weights
128- weights = ResNet50_Weights .DEFAULT
129- model = resnet50 (weights = weights ).to (device )
130- model .eval ()
131-
132- # Step 2: Initialize the inference transforms
133- preprocess = weights .transforms ()
134-
135- # Step 3: Apply inference preprocessing transforms
136- batch = preprocess (img ).unsqueeze (0 )
137-
138- # Step 4: Use the model and print the predicted category
139- prediction = model (batch ).squeeze (0 ).softmax (0 )
140- class_id = prediction .argmax ().item ()
141- score = prediction [class_id ].item ()
142- category_name = weights .meta ["categories" ][class_id ]
143- expected_category = "German shepherd"
144- print (f"{ category_name } : { 100 * score :.1f} %" )
145- if category_name != expected_category :
146- raise RuntimeError (
147- f"Failed ResNet50 classify { category_name } Expected: { expected_category } "
148- )
149-
150-
151- def smoke_test_torchaudio () -> None :
152- import torchaudio
153- import torchaudio .compliance .kaldi # noqa: F401
154- import torchaudio .datasets # noqa: F401
155- import torchaudio .functional # noqa: F401
156- import torchaudio .models # noqa: F401
157- import torchaudio .pipelines # noqa: F401
158- import torchaudio .sox_effects # noqa: F401
159- import torchaudio .transforms # noqa: F401
160- import torchaudio .utils # noqa: F401
103+ def smoke_test_modules ():
104+ for module in MODULES :
105+ if module ["repo" ]:
106+ subprocess .check_output (f"git clone --depth 1 { module ['repo' ]} " , stderr = subprocess .STDOUT , shell = True )
107+ try :
108+ output = subprocess .check_output (
109+ module ["smoke_test" ], stderr = subprocess .STDOUT , shell = True ,
110+ universal_newlines = True )
111+ except subprocess .CalledProcessError as exc :
112+ raise RuntimeError (
113+ f"Module { module ['name' ]} FAIL: { exc .returncode } Output: { exc .output } "
114+ )
115+ else :
116+ print ("Output: \n {}\n " .format (output ))
161117
162118
163119def main () -> None :
@@ -171,25 +127,16 @@ def main() -> None:
171127 )
172128 options = parser .parse_args ()
173129 print (f"torch: { torch .__version__ } " )
174-
175130 smoke_test_cuda (options .package )
176131 smoke_test_conv2d ()
177132
133+ if options .package == "all" :
134+ smoke_test_modules ()
135+
178136 # only makes sense to check nightly package where dates are known
179137 if installation_str .find ("nightly" ) != - 1 :
180138 check_nightly_binaries_date (options .package )
181139
182- if options .package == "all" :
183- import torchaudio
184- import torchvision
185- print (f"torchvision: { torchvision .__version__ } " )
186- print (f"torchaudio: { torchaudio .__version__ } " )
187- smoke_test_torchaudio ()
188- smoke_test_torchvision ()
189- smoke_test_torchvision_read_decode ()
190- smoke_test_torchvision_resnet50_classify ()
191- if torch .cuda .is_available ():
192- smoke_test_torchvision_resnet50_classify ("cuda" )
193140
194141if __name__ == "__main__" :
195142 main ()
0 commit comments