Skip to content

Commit f01884c

Browse files
Add missing mask check for unit tests
1 parent 645f535 commit f01884c

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

onnxruntime/python/tools/transformers/fusion_bart_attention.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,10 +321,17 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
321321
["Slice", "Unsqueeze", "Gather", "Shape", "Add"],
322322
[1, 2, 0, 0, 0],
323323
)
324+
mask_nodes_whisper_oai_unit_test = self.model.match_parent_path(
325+
add_qk,
326+
["Slice", "Slice"],
327+
[1, 0],
328+
)
324329
if mask_nodes_whisper_hf is not None:
325330
mask_nodes = mask_nodes_whisper_hf
326331
elif mask_nodes_whisper_oai is not None:
327332
mask_nodes = mask_nodes_whisper_oai
333+
elif mask_nodes_whisper_oai_unit_test is not None:
334+
mask_nodes = mask_nodes_whisper_oai_unit_test
328335
elif mask_nodes_bart is not None:
329336
mask_nodes = mask_nodes_bart
330337
else:

0 commit comments

Comments
 (0)