Skip to content

Commit

Permalink
fix: Fix for torch scripted module faiure with DLFW
Browse files Browse the repository at this point in the history
Signed-off-by: Anurag Dixit <anuragd@nvidia.com>
  • Loading branch information
Anurag Dixit committed Nov 23, 2021
1 parent a10613e commit 88c02d9
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions tests/py/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,13 @@ def test_compile_traced(self):
self.assertTrue(same < 2e-2)

def test_compile_script(self):
trt_mod = torchtrt.ts.compile(self.scripted_model,
with torch.no_grad():
trt_mod = torchtrt.ts.compile(self.scripted_model,
inputs=[self.input],
device=torchtrt.Device(gpu_id=0),
enabled_precisions={torch.float})
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
self.assertTrue(same < 2e-2)
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
self.assertTrue(same < 2e-2)

def test_compile_global(self):
trt_mod = torchtrt.compile(self.scripted_model,
Expand All @@ -46,12 +47,13 @@ def test_compile_global(self):
self.assertTrue(same < 2e-2)

def test_compile_global_nn_mod(self):
trt_mod = torchtrt.compile(self.model,
with torch.no_grad():
trt_mod = torchtrt.compile(self.model,
inputs=[self.input],
device=torchtrt.Device(gpu_id=0),
enabled_precisions={torch.float})
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
self.assertTrue(same < 2e-2)
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
self.assertTrue(same < 2e-2)

def test_from_torch_tensor(self):
compile_spec = {
Expand Down

0 comments on commit 88c02d9

Please sign in to comment.