1+ from distutils .command .clean import clean
12import nox
23import os
34
89# TOP_DIR
910TOP_DIR = os .path .dirname (os .path .realpath (__file__ )) if not 'TOP_DIR' in os .environ else os .environ ["TOP_DIR" ]
1011
11- nox .options .sessions = ["developer_tests -3" ]
12+ nox .options .sessions = ["l0_api_tests -3" ]
1213
1314def install_deps (session ):
1415 print ("Installing deps" )
@@ -30,31 +31,6 @@ def install_torch_trt(session):
3031 session .chdir (os .path .join (TOP_DIR , "py" ))
3132 session .run ("python" , "setup.py" , "develop" )
3233
33- def run_base_tests (session , use_host_env = False ):
34- print ("Running basic tests" )
35- session .chdir (os .path .join (TOP_DIR , 'tests/py' ))
36- tests = [
37- "test_api.py" ,
38- "test_to_backend_api.py"
39- ]
40- for test in tests :
41- if use_host_env :
42- session .run_always ('python' , test , env = {'PYTHONPATH' : PYT_PATH })
43- else :
44- session .run_always ("python" , test )
45-
46-
47- # Install the latest build of torch-tensorrt
48- @nox .session (python = ["3" ], reuse_venv = True )
49- def developer_tests (session ):
50- """Basic set of tests that need to pass for code to get merged"""
51- install_deps (session )
52- install_torch_trt (session )
53- download_models (session )
54- run_base_tests (session )
55-
56- # Download the dataset
57- @nox .session (python = ["3" ], reuse_venv = True )
5834def download_datasets (session ):
5935 print ("Downloading dataset to path" , os .path .join (TOP_DIR , 'examples/int8/training/vgg16' ))
6036 session .chdir (os .path .join (TOP_DIR , 'examples/int8/training/vgg16' ))
@@ -68,98 +44,70 @@ def download_datasets(session):
6844 os .path .join (TOP_DIR , 'tests/accuracy/datasets/data/cidar-10-batches-bin' ),
6945 external = True )
7046
71- # Download the model
72- @nox .session (python = ["3" ], reuse_venv = True )
73- def download_test_models (session ):
74- download_models (session , use_host_env = True )
75-
76- # Train the model
77- @nox .session (python = ["3" ], reuse_venv = True )
78- def train_model (session ):
47+ def train_model (session , use_host_env = False ):
7948 session .chdir (os .path .join (TOP_DIR , 'examples/int8/training/vgg16' ))
80- session . run_always ( 'python' ,
81- 'main.py ' ,
82- '--lr' , '0.01 ' ,
83- '--batch-size ' , '128 ' ,
84- '--drop-ratio ' , '0.15 ' ,
85- '--ckpt-dir ' , 'vgg16_ckpts ' ,
86- '--epochs ' , '25 ' ,
87- env = { 'PYTHONPATH' : PYT_PATH })
88-
89- # Export model
90- session .run_always ('python' ,
49+ if use_host_env :
50+ session . run_always ( 'python ' ,
51+ 'main.py ' ,
52+ '--lr ' , '0.01 ' ,
53+ '--batch-size ' , '128 ' ,
54+ '--drop-ratio ' , '0.15 ' ,
55+ '--ckpt-dir ' , 'vgg16_ckpts ' ,
56+ '--epochs' , '25' ,
57+ env = { 'PYTHONPATH' : PYT_PATH })
58+
59+ session .run_always ('python' ,
9160 'export_ckpt.py' ,
92- 'vgg16_ckpts/ckpt_epoch25.pth' ,
93- env = {'PYTHONPATH' : PYT_PATH })
61+ 'vgg16_ckpts/ckpt_epoch25.pth' )
62+ else :
63+ session .run_always ('python' ,
64+ 'main.py' ,
65+ '--lr' , '0.01' ,
66+ '--batch-size' , '128' ,
67+ '--drop-ratio' , '0.15' ,
68+ '--ckpt-dir' , 'vgg16_ckpts' ,
69+ '--epochs' , '25' )
9470
95- # Finetune the model
96- @nox .session (python = ["3" ], reuse_venv = True )
97- def finetune_model (session ):
71+ session .run_always ('python' ,
72+ 'export_ckpt.py' ,
73+ 'vgg16_ckpts/ckpt_epoch25.pth' )
74+
75+ def finetune_model (session , use_host_env = False ):
9876 # Install pytorch-quantization dependency
9977 session .install ('pytorch-quantization' , '--extra-index-url' , 'https://pypi.ngc.nvidia.com' )
100-
10178 session .chdir (os .path .join (TOP_DIR , 'examples/int8/training/vgg16' ))
102- session .run_always ('python' ,
103- 'finetune_qat.py' ,
104- '--lr' , '0.01' ,
105- '--batch-size' , '128' ,
106- '--drop-ratio' , '0.15' ,
107- '--ckpt-dir' , 'vgg16_ckpts' ,
108- '--start-from' , '25' ,
109- '--epochs' , '26' ,
110- env = {'PYTHONPATH' : PYT_PATH })
111-
112- # Export model
113- session .run_always ('python' ,
114- 'export_qat.py' ,
115- 'vgg16_ckpts/ckpt_epoch26.pth' ,
116- env = {'PYTHONPATH' : PYT_PATH })
117-
118- # Run PTQ tests
119- @nox .session (python = ["3" ], reuse_venv = True )
120- def ptq_test (session ):
121- session .chdir (os .path .join (TOP_DIR , 'tests/py' ))
122- session .run_always ('cp' , '-rf' ,
123- os .path .join (TOP_DIR , 'examples/int8/training/vgg16' , 'trained_vgg16.jit.pt' ),
124- '.' ,
125- external = True )
126- tests = [
127- 'test_ptq_dataloader_calibrator.py' ,
128- 'test_ptq_to_backend.py' ,
129- 'test_ptq_trt_calibrator.py'
130- ]
131- for test in tests :
132- session .run_always ('python' , test ,
133- env = {'PYTHONPATH' : PYT_PATH })
13479
135- # Run QAT tests
136- @nox .session (python = ["3" ], reuse_venv = True )
137- def qat_test (session ):
138- session .chdir (os .path .join (TOP_DIR , 'tests/py' ))
139- session .run_always ('cp' , '-rf' ,
140- os .path .join (TOP_DIR , 'examples/int8/training/vgg16' , 'trained_vgg16_qat.jit.pt' ),
141- '.' ,
142- external = True )
143-
144- session .run_always ('python' ,
145- 'test_qat_trt_accuracy.py' ,
146- env = {'PYTHONPATH' : PYT_PATH })
80+ if use_host_env :
81+ session .run_always ('python' ,
82+ 'finetune_qat.py' ,
83+ '--lr' , '0.01' ,
84+ '--batch-size' , '128' ,
85+ '--drop-ratio' , '0.15' ,
86+ '--ckpt-dir' , 'vgg16_ckpts' ,
87+ '--start-from' , '25' ,
88+ '--epochs' , '26' ,
89+ env = {'PYTHONPATH' : PYT_PATH })
14790
148- # Run Python API tests
149- @nox .session (python = ["3" ], reuse_venv = True )
150- def api_test (session ):
151- session .chdir (os .path .join (TOP_DIR , 'tests/py' ))
152- tests = [
153- "test_api.py" ,
154- "test_to_backend_api.py"
155- ]
156- for test in tests :
91+ # Export model
15792 session .run_always ('python' ,
158- test ,
93+ 'export_qat.py' ,
94+ 'vgg16_ckpts/ckpt_epoch26.pth' ,
15995 env = {'PYTHONPATH' : PYT_PATH })
96+ else :
97+ session .run_always ('python' ,
98+ 'finetune_qat.py' ,
99+ '--lr' , '0.01' ,
100+ '--batch-size' , '128' ,
101+ '--drop-ratio' , '0.15' ,
102+ '--ckpt-dir' , 'vgg16_ckpts' ,
103+ '--start-from' , '25' ,
104+ '--epochs' , '26' )
105+
106+ # Export model
107+ session .run_always ('python' ,
108+ 'export_qat.py' ,
109+ 'vgg16_ckpts/ckpt_epoch26.pth' )
160110
161- # Clean up
162- @nox .session (reuse_venv = True )
163111def cleanup (session ):
164112 target = [
165113 'examples/int8/training/vgg16/*.jit.pt' ,
@@ -173,4 +121,186 @@ def cleanup(session):
173121 target = ' ' .join (x for x in [os .path .join (TOP_DIR , i ) for i in target ])
174122 session .run_always ('bash' , '-c' ,
175123 str ('rm -rf ' ) + target ,
176- external = True )
124+ external = True )
125+
126+ def run_base_tests (session , use_host_env = False ):
127+ print ("Running basic tests" )
128+ session .chdir (os .path .join (TOP_DIR , 'tests/py' ))
129+ tests = [
130+ "test_api.py" ,
131+ "test_to_backend_api.py"
132+ ]
133+ for test in tests :
134+ if use_host_env :
135+ session .run_always ('python' , test , env = {'PYTHONPATH' : PYT_PATH })
136+ else :
137+ session .run_always ("python" , test )
138+
139+ def run_accuracy_tests (session , use_host_env = False ):
140+ print ("Running accuracy tests" )
141+ session .chdir (os .path .join (TOP_DIR , 'tests/py' ))
142+ tests = []
143+ for test in tests :
144+ if use_host_env :
145+ session .run_always ('python' , test , env = {'PYTHONPATH' : PYT_PATH })
146+ else :
147+ session .run_always ("python" , test )
148+
149+ def run_int8_accuracy_tests (session , use_host_env = False ):
150+ print ("Running accuracy tests" )
151+ session .chdir (os .path .join (TOP_DIR , 'tests/py' ))
152+ tests = [
153+ "test_ptq_dataloader.py" ,
154+ "test_ptq_to_backend.py" ,
155+ "test_qat_trt_accuracy" ,
156+ ]
157+ for test in tests :
158+ if use_host_env :
159+ session .run_always ('python' , test , env = {'PYTHONPATH' : PYT_PATH })
160+ else :
161+ session .run_always ("python" , test )
162+
163+ def run_trt_compatibility_tests (session , use_host_env = False ):
164+ print ("Running TensorRT compatibility tests" )
165+ session .chdir (os .path .join (TOP_DIR , 'tests/py' ))
166+ tests = [
167+ "test_trt_intercompatibilty.py" ,
168+ "test_ptq_trt_calibrator.py" ,
169+ ]
170+ for test in tests :
171+ if use_host_env :
172+ session .run_always ('python' , test , env = {'PYTHONPATH' : PYT_PATH })
173+ else :
174+ session .run_always ("python" , test )
175+
176+ def run_dla_tests (session , use_host_env = False ):
177+ print ("Running DLA tests" )
178+ session .chdir (os .path .join (TOP_DIR , 'tests/py' ))
179+ tests = [
180+ "test_api_dla.py" ,
181+ ]
182+ for test in tests :
183+ if use_host_env :
184+ session .run_always ('python' , test , env = {'PYTHONPATH' : PYT_PATH })
185+ else :
186+ session .run_always ("python" , test )
187+
188+ def run_multi_gpu_tests (session , use_host_env = False ):
189+ print ("Running multi GPU tests" )
190+ session .chdir (os .path .join (TOP_DIR , 'tests/py' ))
191+ tests = [
192+ "test_multi_gpu.py" ,
193+ ]
194+ for test in tests :
195+ if use_host_env :
196+ session .run_always ('python' , test , env = {'PYTHONPATH' : PYT_PATH })
197+ else :
198+ session .run_always ("python" , test )
199+
200+ def run_l0_api_tests (session , use_host_env = False ):
201+ if not use_host_env :
202+ install_deps (session )
203+ install_torch_trt (session )
204+ download_models (session , use_host_env )
205+ run_base_tests (session , use_host_env )
206+ cleanup (session )
207+
208+ def run_l0_dla_tests (session , use_host_env = False ):
209+ if not use_host_env :
210+ install_deps (session )
211+ install_torch_trt (session )
212+ download_models (session , use_host_env )
213+ run_base_tests (session , use_host_env )
214+ cleanup (session )
215+
216+ def run_l1_accuracy_tests (session , use_host_env = False ):
217+ if not use_host_env :
218+ install_deps (session )
219+ install_torch_trt (session )
220+ download_models (session , use_host_env )
221+ download_datasets (session , use_host_env )
222+ train_model (session , use_host_env )
223+ run_accuracy_tests (session , use_host_env )
224+ cleanup (session )
225+
226+ def run_l1_int8_accuracy_tests (session , use_host_env = False ):
227+ if not use_host_env :
228+ install_deps (session )
229+ install_torch_trt (session )
230+ download_models (session , use_host_env )
231+ download_datasets (session , use_host_env )
232+ train_model (session , use_host_env )
233+ finetune_model (session , use_host_env )
234+ run_int8_accuracy_tests (session , use_host_env )
235+ cleanup (session )
236+
237+ def run_l2_trt_compatibility_tests (session , use_host_env = False ):
238+ if not use_host_env :
239+ install_deps (session )
240+ install_torch_trt (session )
241+ download_models (session , use_host_env )
242+ run_trt_compatibility_tests (session , use_host_env )
243+ cleanup (session )
244+
245+ def run_l2_multi_gpu_tests (session , use_host_env = False ):
246+ if not use_host_env :
247+ install_deps (session )
248+ install_torch_trt (session )
249+ download_models (session , use_host_env )
250+ run_multi_gpu_tests (session , use_host_env )
251+ cleanup (session )
252+
253+ @nox .session (python = ["3" ], reuse_venv = True )
254+ def l0_api_tests (session ):
255+ """When a developer needs to check correctness for a PR or something"""
256+ run_l0_api_tests (session , use_host_env = False )
257+
258+ @nox .session (python = ["3" ], reuse_venv = True )
259+ def l0_api_tests_host_deps (session ):
260+ """When a developer needs to check basic api functionality using host dependencies"""
261+ run_l0_api_tests (session , use_host_env = True )
262+
263+ @nox .session (python = ["3" ], reuse_venv = True )
264+ def l0_dla_tests_host_deps (session ):
265+ """When a developer needs to check basic api functionality using host dependencies"""
266+ run_l0_dla_tests (session , use_host_env = True )
267+
268+ @nox .session (python = ["3" ], reuse_venv = True )
269+ def l1_accuracy_tests (session ):
270+ """Checking accuracy performance on various usecases"""
271+ run_l1_accuracy_tests (session , use_host_env = False )
272+
273+ @nox .session (python = ["3" ], reuse_venv = True )
274+ def l1_accuracy_tests_host_deps (session ):
275+ """Checking accuracy performance on various usecases using host dependencies"""
276+ run_l1_accuracy_tests (session , use_host_env = True )
277+
278+ @nox .session (python = ["3" ], reuse_venv = True )
279+ def l1_int8_accuracy_tests (session ):
280+ """Checking accuracy performance on various usecases"""
281+ run_l1_int8_accuracy_tests (session , use_host_env = False )
282+
283+ @nox .session (python = ["3" ], reuse_venv = True )
284+ def l1_int8_accuracy_tests_host_deps (session ):
285+ """Checking accuracy performance on various usecases using host dependencies"""
286+ run_l1_int8_accuracy_tests (session , use_host_env = True )
287+
288+ @nox .session (python = ["3" ], reuse_venv = True )
289+ def l2_trt_compatibility_tests (session ):
290+ """Makes sure that TensorRT Python and Torch-TensorRT can work together"""
291+ run_l2_trt_compatibility_tests (session , use_host_env = False )
292+
293+ @nox .session (python = ["3" ], reuse_venv = True )
294+ def l2_trt_compatibility_tests_host_deps (session ):
295+ """Makes sure that TensorRT Python and Torch-TensorRT can work together using host dependencies"""
296+ run_l2_trt_compatibility_tests (session , use_host_env = True )
297+
298+ @nox .session (python = ["3" ], reuse_venv = True )
299+ def l2_multi_gpu_tests (session ):
300+ """Makes sure that Torch-TensorRT can operate on multi-gpu systems"""
301+ run_l2_multi_gpu_tests (session , use_host_env = False )
302+
303+ @nox .session (python = ["3" ], reuse_venv = True )
304+ def l2_multi_gpu_tests_host_deps (session ):
305+ """Makes sure that Torch-TensorRT can operate on multi-gpu systems using host dependencies"""
306+ run_l2_multi_gpu_tests (session , use_host_env = True )
0 commit comments