Skip to content

Commit

Permalink
fix: Fix python API tests for mobilenet v2
Browse files Browse the repository at this point in the history
This commit modifies test cases to use traced model instead of scripting. Model execution with  mobilenet(using torch.jit.script) has problems with Pytorch 1.10

Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
  • Loading branch information
peri044 committed Oct 20, 2021
1 parent 4d95b04 commit e5a38ff
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions tests/py/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,15 @@ def test_from_torch_tensor(self):
"enabled_precisions": {torch.float}
}

trt_mod = trtorch.compile(self.scripted_model, compile_spec)
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
trt_mod = trtorch.compile(self.traced_model, compile_spec)
same = (trt_mod(self.input) - self.traced_model(self.input)).abs().max()
self.assertTrue(same < 2e-2)

def test_device(self):
compile_spec = {"inputs": [self.input], "device": trtorch.Device("gpu:0"), "enabled_precisions": {torch.float}}

trt_mod = trtorch.compile(self.scripted_model, compile_spec)
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
trt_mod = trtorch.compile(self.traced_model, compile_spec)
same = (trt_mod(self.input) - self.traced_model(self.input)).abs().max()
self.assertTrue(same < 2e-2)


Expand Down Expand Up @@ -169,7 +169,7 @@ class TestPTtoTRTtoPT(ModelTestCase):

def setUp(self):
self.input = torch.randn((1, 3, 224, 224)).to("cuda")
self.ts_model = torch.jit.script(self.model)
self.ts_model = torch.jit.trace(self.model, [self.input])

def test_pt_to_trt_to_pt(self):
compile_spec = {
Expand Down

0 comments on commit e5a38ff

Please sign in to comment.