@@ -58,7 +58,8 @@ def train_model(session, use_host_env=False):
5858
5959 session .run_always ('python' ,
6060 'export_ckpt.py' ,
61- 'vgg16_ckpts/ckpt_epoch25.pth' )
61+ 'vgg16_ckpts/ckpt_epoch25.pth' ,
62+ env = {'PYTHONPATH' : PYT_PATH })
6263 else :
6364 session .run_always ('python' ,
6465 'main.py' ,
@@ -146,13 +147,27 @@ def run_accuracy_tests(session, use_host_env=False):
146147 else :
147148 session .run_always ("python" , test )
148149
150+ def copy_model (session ):
151+ model_files = [ 'trained_vgg16.jit.pt' ,
152+ 'trained_vgg16_qat.jit.pt' ]
153+
154+ for file_name in model_files :
155+ src_file = os .path .join (TOP_DIR , str ('examples/int8/training/vgg16/' ) + file_name )
156+ if os .path .exists (src_file ):
157+ session .run_always ('cp' ,
158+ '-rpf' ,
159+ os .path .join (TOP_DIR , src_file ),
160+ os .path .join (TOP_DIR , str ('tests/py/' ) + file_name ),
161+ external = True )
162+
149163def run_int8_accuracy_tests (session , use_host_env = False ):
150164 print ("Running accuracy tests" )
165+ copy_model (session )
151166 session .chdir (os .path .join (TOP_DIR , 'tests/py' ))
152167 tests = [
153- "test_ptq_dataloader .py" ,
168+ "test_ptq_dataloader_calibrator .py" ,
154169 "test_ptq_to_backend.py" ,
155- "test_qat_trt_accuracy" ,
170+ "test_qat_trt_accuracy.py " ,
156171 ]
157172 for test in tests :
158173 if use_host_env :
@@ -162,9 +177,10 @@ def run_int8_accuracy_tests(session, use_host_env=False):
162177
163178def run_trt_compatibility_tests (session , use_host_env = False ):
164179 print ("Running TensorRT compatibility tests" )
180+ copy_model (session )
165181 session .chdir (os .path .join (TOP_DIR , 'tests/py' ))
166182 tests = [
167- "test_trt_intercompatibilty .py" ,
183+ "test_trt_intercompatability .py" ,
168184 "test_ptq_trt_calibrator.py" ,
169185 ]
170186 for test in tests :
@@ -218,7 +234,7 @@ def run_l1_accuracy_tests(session, use_host_env=False):
218234 install_deps (session )
219235 install_torch_trt (session )
220236 download_models (session , use_host_env )
221- download_datasets (session , use_host_env )
237+ download_datasets (session )
222238 train_model (session , use_host_env )
223239 run_accuracy_tests (session , use_host_env )
224240 cleanup (session )
@@ -228,7 +244,7 @@ def run_l1_int8_accuracy_tests(session, use_host_env=False):
228244 install_deps (session )
229245 install_torch_trt (session )
230246 download_models (session , use_host_env )
231- download_datasets (session , use_host_env )
247+ download_datasets (session )
232248 train_model (session , use_host_env )
233249 finetune_model (session , use_host_env )
234250 run_int8_accuracy_tests (session , use_host_env )
@@ -239,6 +255,8 @@ def run_l2_trt_compatibility_tests(session, use_host_env=False):
239255 install_deps (session )
240256 install_torch_trt (session )
241257 download_models (session , use_host_env )
258+ download_datasets (session )
259+ train_model (session , use_host_env )
242260 run_trt_compatibility_tests (session , use_host_env )
243261 cleanup (session )
244262
0 commit comments