
Commit 6ddda58

xuechendi authored and charlifu committed
[BUG FIX][NON-CUDA] Quick fix to avoid calling cudagraph_unsafe in attention (vllm-project#25298)
Signed-off-by: Chendi Xue <Chendi.Xue@intel.com>
Signed-off-by: charlifu <charlifu@amd.com>
1 parent e11fe87 commit 6ddda58

File tree

1 file changed: +6 −2 lines changed

vllm/attention/layer.py

Lines changed: 6 additions & 2 deletions

@@ -29,6 +29,10 @@
 
 logger = init_logger(__name__)
 USE_XFORMERS_OPS = None
+try:
+    tag_cudagraph_unsafe = (torch._C.Tag.cudagraph_unsafe, )
+except AttributeError:
+    tag_cudagraph_unsafe = ()  # type: ignore[assignment]
 
 
 def check_xformers_availability():
@@ -577,7 +581,7 @@ def unified_attention_fake(
     mutates_args=[],
     fake_impl=unified_attention_fake,
     dispatch_key=current_platform.dispatch_key,
-    tags=(torch._C.Tag.cudagraph_unsafe, ),
+    tags=tag_cudagraph_unsafe,
 )
 
 
@@ -628,5 +632,5 @@ def unified_attention_with_output_fake(
     mutates_args=["output", "output_block_scale"],
     fake_impl=unified_attention_with_output_fake,
     dispatch_key=current_platform.dispatch_key,
-    tags=(torch._C.Tag.cudagraph_unsafe, ),
+    tags=tag_cudagraph_unsafe,
 )
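For context, below is a minimal standalone sketch of the guard this commit introduces, assuming only a working PyTorch install. It is not part of the commit itself: on builds where torch._C.Tag exposes cudagraph_unsafe, the tuple carries the tag; on builds where the attribute is missing (the non-CUDA case this fix targets), the except branch leaves the tuple empty so op registration can proceed without tags.

    import torch

    # Fall back gracefully on PyTorch builds where torch._C.Tag has no
    # `cudagraph_unsafe` member (e.g. some non-CUDA or older builds).
    try:
        tag_cudagraph_unsafe = (torch._C.Tag.cudagraph_unsafe, )
    except AttributeError:
        tag_cudagraph_unsafe = ()  # no tag available; register ops untagged

    # The resulting tuple can be passed unconditionally wherever a tags
    # argument is accepted, as the diff does for vLLM's attention ops.
    print("tags:", tag_cudagraph_unsafe)

The design choice here is feature detection over version checks: probing for the attribute once at import time keeps every downstream registration site unconditional.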
