22import re
33import sys
44from pathlib import Path
5-
5+ import argparse
66import torch
7- import torchaudio
87
9- # the following import would invoke
10- # _check_cuda_version()
11- # via torchvision.extension._check_cuda_version()
12- import torchvision
8+
139
1410gpu_arch_ver = os .getenv ("GPU_ARCH_VER" )
1511gpu_arch_type = os .getenv ("GPU_ARCH_TYPE" )
@@ -38,20 +34,21 @@ def get_anaconda_output_for_package(pkg_name_str):
3834 return output .strip ().split ('\n ' )[- 1 ]
3935
4036
41- def check_nightly_binaries_date () -> None :
37+ def check_nightly_binaries_date (package : str ) -> None :
4238 torch_str = torch .__version__
43- ta_str = torchaudio .__version__
44- tv_str = torchvision .__version__
45-
4639 date_t_str = re .findall ("dev\d+" , torch .__version__ )
47- date_ta_str = re .findall ("dev\d+" , torchaudio .__version__ )
48- date_tv_str = re .findall ("dev\d+" , torchvision .__version__ )
4940
50- # check that the above three lists are equal and none of them is empty
51- if not date_t_str or not date_t_str == date_ta_str == date_tv_str :
52- raise RuntimeError (
53- f"Expected torch, torchaudio, torchvision to be the same date. But they are from { date_t_str } , { date_ta_str } , { date_tv_str } respectively"
54- )
41+ if (package == "all" ):
42+ ta_str = torchaudio .__version__
43+ tv_str = torchvision .__version__
44+ date_ta_str = re .findall ("dev\d+" , torchaudio .__version__ )
45+ date_tv_str = re .findall ("dev\d+" , torchvision .__version__ )
46+
47+ # check that the above three lists are equal and none of them is empty
48+ if not date_t_str or not date_t_str == date_ta_str == date_tv_str :
49+ raise RuntimeError (
50+ f"Expected torch, torchaudio, torchvision to be the same date. But they are from { date_t_str } , { date_ta_str } , { date_tv_str } respectively"
51+ )
5552
5653 # check that the date is recent, at this point, date_torch_str is not empty
5754 binary_date_str = date_t_str [0 ][3 :]
@@ -65,8 +62,7 @@ def check_nightly_binaries_date() -> None:
6562 f"the binaries are from { binary_date_obj } and are more than 2 days old!"
6663 )
6764
68-
69- def smoke_test_cuda () -> None :
65+ def smoke_test_cuda (package : str ) -> None :
7066 if not torch .cuda .is_available () and is_cuda_system :
7167 raise RuntimeError (f"Expected CUDA { gpu_arch_ver } . However CUDA is not loaded." )
7268 if torch .cuda .is_available ():
@@ -79,23 +75,23 @@ def smoke_test_cuda() -> None:
7975 print (f"torch cudnn: { torch .backends .cudnn .version ()} " )
8076 print (f"cuDNN enabled? { torch .backends .cudnn .enabled } " )
8177
82- if installation_str . find ( "nightly" ) != - 1 :
83- # just print out cuda version, as version check were already performed during import
84- print ( f"torchvision cuda: { torch . ops . torchvision . _cuda_version () } " )
85- print (f"torchaudio cuda: { torch .ops .torchaudio . cuda_version ()} " )
86- else :
87- # torchaudio runtime added the cuda verison check on 09/23/2022 via
88- # https://github.com/pytorch/audio/pull/2707
89- # so relying on anaconda output for pytorch-test and pytorch channel
90- torchaudio_allstr = get_anaconda_output_for_package ( torchaudio . __name__ )
91- if (
92- is_cuda_system
93- and "cu" + str ( gpu_arch_ver ). replace ( "." , "" ) not in torchaudio_allstr
94- ):
95- raise RuntimeError (
96- f"CUDA version issue. Loaded: { torchaudio_allstr } Expected: { gpu_arch_ver } "
97- )
98-
78+ if ( package == 'all' ) :
79+ if installation_str . find ( "nightly" ) != - 1 :
80+ # just print out cuda version, as version check were already performed during import
81+ print (f"torchvision cuda: { torch .ops .torchvision . _cuda_version ()} " )
82+ print ( f"torchaudio cuda: { torch . ops . torchaudio . cuda_version () } " )
83+ else :
84+ # torchaudio runtime added the cuda verison check on 09/23/2022 via
85+ # https://github.com/ pytorch/audio/pull/2707
86+ # so relying on anaconda output for pytorch-test and pytorch channel
87+ torchaudio_allstr = get_anaconda_output_for_package ( torchaudio . __name__ )
88+ if (
89+ is_cuda_system
90+ and "cu" + str ( gpu_arch_ver ). replace ( "." , "" ) not in torchaudio_allstr
91+ ):
92+ raise RuntimeError (
93+ f"CUDA version issue. Loaded: { torchaudio_allstr } Expected: { gpu_arch_ver } "
94+ )
9995
10096def smoke_test_conv2d () -> None :
10197 import torch .nn as nn
@@ -180,24 +176,37 @@ def smoke_test_torchaudio() -> None:
180176
181177
182178def main () -> None :
183- # todo add torch, torchvision and torchaudio tests
179+ parser = argparse .ArgumentParser ()
180+ parser .add_argument (
181+ "--package" ,
182+ help = "Package to include in smoke testing" ,
183+ type = str ,
184+ choices = ["all" , "torchonly" ],
185+ default = "all" ,
186+ )
187+
184188 print (f"torch: { torch .__version__ } " )
185- print (f"torchvision: { torchvision .__version__ } " )
186- print (f"torchaudio: { torchaudio .__version__ } " )
187- smoke_test_cuda ()
189+ smoke_test_cuda (options .package )
190+ smoke_test_conv2d ()
188191
189192 # only makes sense to check nightly package where dates are known
190193 if installation_str .find ("nightly" ) != - 1 :
191194 check_nightly_binaries_date ()
192195
193- smoke_test_conv2d ()
194- smoke_test_torchaudio ()
195- smoke_test_torchvision ()
196- smoke_test_torchvision_read_decode ()
197- smoke_test_torchvision_resnet50_classify ()
198- if torch .cuda .is_available ():
199- smoke_test_torchvision_resnet50_classify ("cuda" )
200-
196+ if options .package == "all" :
197+ import torchaudio
198+ # the following import would invoke
199+ # _check_cuda_version()
200+ # via torchvision.extension._check_cuda_version()
201+ import torchvision
202+ print (f"torchvision: { torchvision .__version__ } " )
203+ print (f"torchaudio: { torchaudio .__version__ } " )
204+ smoke_test_torchaudio ()
205+ smoke_test_torchvision ()
206+ smoke_test_torchvision_read_decode ()
207+ smoke_test_torchvision_resnet50_classify ()
208+ if torch .cuda .is_available ():
209+ smoke_test_torchvision_resnet50_classify ("cuda" )
201210
202211if __name__ == "__main__" :
203212 main ()
0 commit comments