1+ import unittest
2+ import trtorch
3+ import torch
4+ import torchvision .models as models
5+
6+
7+ class ModelTestCase (unittest .TestCase ):
8+ def __init__ (self , methodName = 'runTest' , model = None ):
9+ super (ModelTestCase , self ).__init__ (methodName )
10+ self .model = model
11+ self .model .eval ().to ("cuda" )
12+
13+ @staticmethod
14+ def parametrize (testcase_class , model = None ):
15+ testloader = unittest .TestLoader ()
16+ testnames = testloader .getTestCaseNames (testcase_class )
17+ suite = unittest .TestSuite ()
18+ for name in testnames :
19+ suite .addTest (testcase_class (name , model = model ))
20+ return suite
21+
22+ class TestCompile (ModelTestCase ):
23+ def setUp (self ):
24+ self .input = torch .randn ((1 , 3 , 224 , 224 )).to ("cuda" )
25+ self .traced_model = torch .jit .trace (self .model , [self .input ])
26+ self .scripted_model = torch .jit .script (self .model )
27+
28+ def test_compile_traced (self ):
29+ extra_info = {
30+ "input_shapes" : [self .input .shape ],
31+ }
32+
33+ trt_mod = trtorch .compile (self .traced_model , extra_info )
34+ same = (trt_mod (self .input ) - self .traced_model (self .input )).abs ().max ()
35+ self .assertTrue (same < 2e-3 )
36+
37+ #def test_compile_script(self):
38+ # pass
39+
40+ class TestCheckMethodOpSupport (unittest .TestCase ):
41+ def setUp (self ):
42+ module = models .alexnet (pretrained = True ).eval ().to ("cuda" )
43+ self .module = torch .jit .trace (module , torch .ones ((1 , 3 , 224 , 224 )).to ("cuda" ))
44+
45+ def test_check_support (self ):
46+ self .assertTrue (trtorch .check_method_op_support (self .module , "forward" ))
47+
48+ class TestLoggingAPIs (unittest .TestCase ):
49+ def test_logging_prefix (self ):
50+ new_prefix = "TEST"
51+ trtorch .logging .set_logging_prefix (new_prefix )
52+ logging_prefix = trtorch .logging .get_logging_prefix ()
53+ self .assertEqual (new_prefix , logging_prefix )
54+
55+ def test_reportable_log_level (self ):
56+ new_level = trtorch .logging .Level .Warning
57+ trtorch .logging .set_reportable_log_level (new_level )
58+ level = trtorch .logging .get_reportable_log_level ()
59+ self .assertEqual (new_level , level )
60+
61+ def test_is_colored_output_on (self ):
62+ trtorch .logging .set_is_colored_output_on (True )
63+ color = trtorch .logging .get_is_colored_output_on ()
64+ self .assertTrue (color )
65+
66+ def test_suite ():
67+ suite = unittest .TestSuite ()
68+ suite .addTest (TestCompile .parametrize (TestCompile , model = models .resnet18 (pretrained = True )))
69+ suite .addTest (TestCompile .parametrize (TestCompile , model = models .resnet50 (pretrained = True )))
70+ suite .addTest (TestCompile .parametrize (TestCompile , model = models .mobilenet_v2 (pretrained = True )))
71+ suite .addTest (unittest .makeSuite (TestCheckMethodOpSupport ))
72+ suite .addTest (unittest .makeSuite (TestLoggingAPIs ))
73+
74+ return suite
75+
76+ suite = test_suite ()
77+
78+ runner = unittest .TextTestRunner ()
79+ result = runner .run (suite )
80+
81+ exit (int (not result .wasSuccessful ()))
0 commit comments