bugfix: only use sm90 group gemm when torch cuda >= 12.3 (#699)
wgmma is only available on CUDA 12.3 or later, so fall back to the sm80 version
when the torch CUDA version is lower.

cc @xslingcn
yzh119 authored Dec 25, 2024
1 parent d158717 commit 1c8dc36
Showing 1 changed file with 1 addition and 1 deletion.
flashinfer/utils.py (1 addition, 1 deletion)
@@ -257,7 +257,7 @@ def get_cuda_stream(device: torch.device) -> int:

 def determine_gemm_backend(device: torch.device) -> str:
     major, _ = get_compute_capability(device)
-    if major >= 9:
+    if major >= 9 and torch.version.cuda >= "12.3":
         return "sm90"
     else:
         return "sm80"
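A side note, not part of this commit: torch.version.cuda is a plain string, so the >= "12.3" check is a lexicographic comparison, and a hypothetical "12.10" would sort below "12.3". A minimal sketch of a numeric comparison; the helper names below are illustrative assumptions, not existing flashinfer utilities:

    import torch

    def _version_tuple(v: str) -> tuple:
        # "12.3" -> (12, 3); assumes purely numeric dotted components
        return tuple(int(part) for part in v.split("."))

    def cuda_version_at_least(required: str) -> bool:
        # torch.version.cuda is a version string such as "12.3",
        # or None for CPU-only builds of PyTorch
        cuda = torch.version.cuda
        if cuda is None:
            return False
        return _version_tuple(cuda) >= _version_tuple(required)

    # Usage mirroring the patched check:
    # if major >= 9 and cuda_version_at_least("12.3"):
    #     return "sm90"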
