@@ -46,6 +46,34 @@ def test_compile_script(self):
4646 self .assertTrue (same < 2e-3 )
4747
4848
49+ class TestFallbackToTorch (ModelTestCase ):
50+
51+ def setUp (self ):
52+ self .input = torch .randn ((1 , 3 , 224 , 224 )).to ("cuda" )
53+ self .scripted_model = torch .jit .script (self .model )
54+
55+ def test_compile_script (self ):
56+ compile_spec = {
57+ "input_shapes" : [self .input .shape ],
58+ "device" : {
59+ "device_type" : trtorch .DeviceType .GPU ,
60+ "gpu_id" : 0 ,
61+ "dla_core" : 0 ,
62+ "allow_gpu_fallback" : False ,
63+ "disable_tf32" : False
64+ },
65+ "torch_fallback" : {
66+ "enabled" : True ,
67+ "forced_fallback_ops" : ["aten::max_pool2d" ],
68+ "min_block_size" : 1
69+ }
70+ }
71+
72+ trt_mod = trtorch .compile (self .scripted_model , compile_spec )
73+ same = (trt_mod (self .input ) - self .scripted_model (self .input )).abs ().max ()
74+ self .assertTrue (same < 2e-3 )
75+
76+
4977class TestPTtoTRTtoPT (ModelTestCase ):
5078
5179 def setUp (self ):
@@ -106,6 +134,7 @@ def test_suite():
106134 suite .addTest (TestCompile .parametrize (TestCompile , model = models .resnet18 (pretrained = True )))
107135 suite .addTest (TestCompile .parametrize (TestCompile , model = models .mobilenet_v2 (pretrained = True )))
108136 suite .addTest (TestPTtoTRTtoPT .parametrize (TestPTtoTRTtoPT , model = models .mobilenet_v2 (pretrained = True )))
137+ suite .addTest (TestFallbackToTorch .parametrize (TestFallbackToTorch , model = models .resnet18 (pretrained = True )))
109138 suite .addTest (unittest .makeSuite (TestCheckMethodOpSupport ))
110139
111140 return suite
0 commit comments