Skip to content

Commit d61b667

Browse files
committed
fix(exception_elimination): Exception branches are no longer consistent
so cover both cases Also adjusts Half precision thresholds in python Signed-off-by: Naren Dasan <naren@narendasan.com> Signed-off-by: Naren Dasan <narens@nvidia.com>
1 parent a12d249 commit d61b667

File tree

2 files changed

+39
-15
lines changed

2 files changed

+39
-15
lines changed

core/lowering/passes/exception_elimination.cpp

+35-13
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,6 @@ struct ExceptionOrPassPatternElimination {
2727

2828
private:
2929
bool isExceptionOrPassNode(Node* n) {
30-
/// Check if this Node hosts a pattern like so:
31-
/// = prim::If(%5958)
32-
/// block0():
33-
/// = prim::RaiseException(%45)
34-
/// -> ()
35-
/// block1():
36-
/// -> ()
3730
if (n->blocks().size() != 2) {
3831
return false;
3932
}
@@ -46,15 +39,44 @@ struct ExceptionOrPassPatternElimination {
4639
}
4740

4841
auto arm1_start = arm1->nodes().begin();
42+
auto arm2_start = arm2->nodes().begin();
4943

50-
if ((*arm1_start)->kind() != prim::RaiseException && (*(++arm1_start))->kind() != prim::Return) {
51-
// Make sure that block0 is solely just the exception and the return
52-
return false;
44+
/// Check if this Node hosts a pattern like so:
45+
/// = prim::If(%5958)
46+
/// block0():
47+
/// = prim::RaiseException(%45)
48+
/// -> ()
49+
/// block1():
50+
/// -> ()
51+
if ((*arm1_start)->kind() == prim::RaiseException) {
52+
if ((*(++arm1_start))->kind() != prim::Return) {
53+
// Make sure that block0 is solely just the exception and the return
54+
return false;
55+
}
56+
57+
if ((*(arm2_start))->kind() != prim::Return) {
58+
// Make sure that block1 is solely the return
59+
return false;
60+
}
5361
}
5462

55-
if ((*(arm2->nodes().begin()))->kind() != prim::Return) {
56-
// Make sure that block1 is solely the return
57-
return false;
63+
/// Check if this Node hosts a pattern like so:
64+
/// = prim::If(%5958)
65+
/// block0():
66+
/// -> ()
67+
/// block1():
68+
/// = prim::RaiseException(%45)
69+
/// -> ()
70+
if ((*arm2_start)->kind() == prim::RaiseException) {
71+
if ((*(++arm2_start))->kind() != prim::Return) {
72+
// Make sure that block1 is solely just the exception and the return
73+
return false;
74+
}
75+
76+
if ((*(arm1_start))->kind() != prim::Return) {
77+
// Make sure that block0 is solely the return
78+
return false;
79+
}
5880
}
5981

6082
return true;

tests/py/test_api.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ def test_compile_script_half(self):
9393

9494
trt_mod = trtorch.compile(self.scripted_model, compile_spec)
9595
same = (trt_mod(self.input.half()) - self.scripted_model(self.input.half())).abs().max()
96-
self.assertTrue(same < 2e-2)
96+
trtorch.logging.log(trtorch.logging.Level.Debug, "Max diff: " + str(same))
97+
self.assertTrue(same < 3e-2)
9798

9899

99100
class TestCompileHalfDefault(ModelTestCase):
@@ -115,7 +116,8 @@ def test_compile_script_half_by_default(self):
115116

116117
trt_mod = trtorch.compile(self.scripted_model, compile_spec)
117118
same = (trt_mod(self.input.half()) - self.scripted_model(self.input.half())).abs().max()
118-
self.assertTrue(same < 2e-2)
119+
trtorch.logging.log(trtorch.logging.Level.Debug, "Max diff: " + str(same))
120+
self.assertTrue(same < 3e-2)
119121

120122

121123
class TestFallbackToTorch(ModelTestCase):

0 commit comments

Comments
 (0)