Skip to content

Commit 5455390

Browse files
committed
Fix
1 parent 08280cc commit 5455390

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

backends/nxp/tests/test_batch_norm_fusion.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,10 @@ def test_batch_norm_conv_fusing(bias: bool, input_shape: list[int]):
105105
og_nodes = list(program.graph.nodes)
106106
transformed_nodes = list(graph_module_out.graph.nodes)
107107

108-
assert any(node.target.__name__ == "batch_norm.default" for node in og_nodes)
108+
assert any(
109+
node.op == "call_function" and node.target.__name__ == "batch_norm.default"
110+
for node in og_nodes
111+
)
109112

110113
assert not any(
111114
node.op == "call_function" and "batch_norm" in node.target.__name__
@@ -137,7 +140,10 @@ def test_batch_norm_linear_fusing(bias: bool):
137140
og_nodes = list(og_module.graph.nodes)
138141
transformed_nodes = list(graph_module_out.graph.nodes)
139142

140-
assert any(node.target.__name__ == "linear.default" for node in og_nodes)
143+
assert any(
144+
node.op == "call_function" and node.target.__name__ == "linear.default"
145+
for node in og_nodes
146+
)
141147

142148
assert not any(
143149
node.op == "call_function" and "batch_norm" in node.target.__name__

0 commit comments

Comments
 (0)