@@ -46,6 +46,34 @@ def test_compile_script(self):
46
46
self .assertTrue (same < 2e-3 )
47
47
48
48
49
+ class TestFallbackToTorch (ModelTestCase ):
50
+
51
+ def setUp (self ):
52
+ self .input = torch .randn ((1 , 3 , 224 , 224 )).to ("cuda" )
53
+ self .scripted_model = torch .jit .script (self .model )
54
+
55
+ def test_compile_script (self ):
56
+ compile_spec = {
57
+ "input_shapes" : [self .input .shape ],
58
+ "device" : {
59
+ "device_type" : trtorch .DeviceType .GPU ,
60
+ "gpu_id" : 0 ,
61
+ "dla_core" : 0 ,
62
+ "allow_gpu_fallback" : False ,
63
+ "disable_tf32" : False
64
+ },
65
+ "torch_fallback" : {
66
+ "enabled" : True ,
67
+ "forced_fallback_ops" : ["aten::max_pool2d" ],
68
+ "min_block_size" : 1
69
+ }
70
+ }
71
+
72
+ trt_mod = trtorch .compile (self .scripted_model , compile_spec )
73
+ same = (trt_mod (self .input ) - self .scripted_model (self .input )).abs ().max ()
74
+ self .assertTrue (same < 2e-3 )
75
+
76
+
49
77
class TestPTtoTRTtoPT (ModelTestCase ):
50
78
51
79
def setUp (self ):
@@ -106,6 +134,7 @@ def test_suite():
106
134
suite .addTest (TestCompile .parametrize (TestCompile , model = models .resnet18 (pretrained = True )))
107
135
suite .addTest (TestCompile .parametrize (TestCompile , model = models .mobilenet_v2 (pretrained = True )))
108
136
suite .addTest (TestPTtoTRTtoPT .parametrize (TestPTtoTRTtoPT , model = models .mobilenet_v2 (pretrained = True )))
137
+ suite .addTest (TestFallbackToTorch .parametrize (TestFallbackToTorch , model = models .resnet18 (pretrained = True )))
109
138
suite .addTest (unittest .makeSuite (TestCheckMethodOpSupport ))
110
139
111
140
return suite
0 commit comments