
Commit bc50f1a

ci: add guard for aot compilation (#1127)
## πŸ“Œ Description

Do not compile sm90a/sm100a kernels for CUDA versions lower than a certain threshold.

## πŸ” Related Issues

Check the failed jobs: https://github.com/flashinfer-ai/flashinfer/actions/runs/15504091163/job/43657444651

## Reviewer Notes

cc @abcdabcd987 @wenscarl
1 parent 5a15b43 commit bc50f1a
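
Background on the guard (illustration only, not part of this commit): `torch.version.cuda` is a string, and comparing version strings with `>=` is lexicographic, so it misorders double-digit minor versions. The `version_at_least` helper added in flashinfer/utils.py below compares versions numerically via `packaging` instead.

# Sketch of the pitfall; nothing here is from the diff itself.
from packaging import version as pkg_version

print("12.10" >= "12.3")  # False: string comparison stops at "1" < "3"
print(pkg_version.parse("12.10") >= pkg_version.parse("12.3"))  # True: numeric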

File tree

flashinfer/aot.py
flashinfer/utils.py

2 files changed: +49 -35 lines changed

flashinfer/aot.py

Lines changed: 41 additions & 33 deletions
@@ -31,6 +31,7 @@
 from .quantization import gen_quantization_module
 from .rope import gen_rope_module
 from .sampling import gen_sampling_module
+from .utils import version_at_least


 def gen_fa2(
@@ -482,8 +483,12 @@ def main():
     if "TORCH_CUDA_ARCH_LIST" not in os.environ:
         raise RuntimeError("Please explicitly set env var TORCH_CUDA_ARCH_LIST.")
     gencode_flags = _get_cuda_arch_flags()
-    has_sm90 = any("compute_90" in flag for flag in gencode_flags)
-    has_sm100 = any("compute_100" in flag for flag in gencode_flags)
+    has_sm90 = any("compute_90" in flag for flag in gencode_flags) and version_at_least(
+        torch.version.cuda, "12.3"
+    )
+    has_sm100 = any(
+        "compute_100" in flag for flag in gencode_flags
+    ) and version_at_least(torch.version.cuda, "12.8")

     # Update data dir
     jit_env.FLASHINFER_CSRC_DIR = project_root / "csrc"
@@ -528,38 +533,41 @@ def main():
                 jit_env.SPDLOG_INCLUDE_DIR,
                 jit_env.FLASHINFER_INCLUDE_DIR,
             ],
-        ),
-        gen_jit_spec(
-            "trtllm_utils",
-            [
-                jit_env.FLASHINFER_CSRC_DIR
-                / "nv_internal"
-                / "tensorrt_llm"
-                / "kernels"
-                / "delayStream.cu",
-            ],
-            extra_include_paths=[
-                jit_env.FLASHINFER_CSRC_DIR / "nv_internal",
-                jit_env.FLASHINFER_CSRC_DIR / "nv_internal" / "include",
-                jit_env.FLASHINFER_CSRC_DIR
-                / "nv_internal"
-                / "tensorrt_llm"
-                / "cutlass_extensions"
-                / "include",
-                jit_env.FLASHINFER_CSRC_DIR
-                / "nv_internal"
-                / "tensorrt_llm"
-                / "kernels"
-                / "internal_cutlass_kernels"
-                / "include",
-                jit_env.FLASHINFER_CSRC_DIR
-                / "nv_internal"
-                / "tensorrt_llm"
-                / "kernels"
-                / "internal_cutlass_kernels",
-            ],
-        ),
+        )
     ]
+    if has_sm90:
+        jit_specs.append(
+            gen_jit_spec(
+                "trtllm_utils",
+                [
+                    jit_env.FLASHINFER_CSRC_DIR
+                    / "nv_internal"
+                    / "tensorrt_llm"
+                    / "kernels"
+                    / "delayStream.cu",
+                ],
+                extra_include_paths=[
+                    jit_env.FLASHINFER_CSRC_DIR / "nv_internal",
+                    jit_env.FLASHINFER_CSRC_DIR / "nv_internal" / "include",
+                    jit_env.FLASHINFER_CSRC_DIR
+                    / "nv_internal"
+                    / "tensorrt_llm"
+                    / "cutlass_extensions"
+                    / "include",
+                    jit_env.FLASHINFER_CSRC_DIR
+                    / "nv_internal"
+                    / "tensorrt_llm"
+                    / "kernels"
+                    / "internal_cutlass_kernels"
+                    / "include",
+                    jit_env.FLASHINFER_CSRC_DIR
+                    / "nv_internal"
+                    / "tensorrt_llm"
+                    / "kernels"
+                    / "internal_cutlass_kernels",
+                ],
+            ),
+        )
     jit_specs += gen_all_modules(
         f16_dtype_,
         f8_dtype_,
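
To see the effect of the aot.py hunks above in isolation, here is a minimal sketch; `version_at_least` mirrors the helper from utils.py, while `cuda_version` and `gencode_flags` are hypothetical stand-ins for `torch.version.cuda` and `_get_cuda_arch_flags()`.

# Minimal sketch of the new guard; the values below are made up for illustration.
from packaging import version as pkg_version


def version_at_least(version: str, base_version: str) -> bool:
    return pkg_version.parse(version) >= pkg_version.parse(base_version)


cuda_version = "12.2"  # hypothetical stand-in for torch.version.cuda
gencode_flags = ["-gencode=arch=compute_90,code=sm_90"]  # hypothetical flag list

# The arch flag alone is no longer enough: the toolkit version must also clear
# the threshold (12.3 for sm90a, 12.8 for sm100a), or the specs are skipped.
has_sm90 = any("compute_90" in flag for flag in gencode_flags) and version_at_least(
    cuda_version, "12.3"
)
print(has_sm90)  # False: compute_90 is requested, but CUDA 12.2 < 12.3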

flashinfer/utils.py

Lines changed: 8 additions & 2 deletions
@@ -394,14 +394,20 @@ def determine_attention_backend(
     return "fa2"


+def version_at_least(version: str, base_version: str) -> bool:
+    from packaging import version as pkg_version
+
+    return pkg_version.parse(version) >= pkg_version.parse(base_version)
+
+
 def is_sm90a_supported(device: torch.device) -> bool:
     major, _ = get_compute_capability(device)
-    return major == 9 and torch.version.cuda >= "12.3"
+    return major == 9 and version_at_least(torch.version.cuda, "12.3")


 def is_sm100a_supported(device: torch.device) -> bool:
     major, _ = get_compute_capability(device)
-    return major == 10 and torch.version.cuda >= "12.8"
+    return major == 10 and version_at_least(torch.version.cuda, "12.8")


 def determine_mla_backend(device: torch.device) -> str:
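
A quick standalone check of the patched predicate's behavior; only `version_at_least` is taken from this commit, and `is_sm90a_supported` is re-stated here with the device lookup replaced by a plain argument so it runs without a GPU.

# Standalone sketch: the compute-capability lookup is replaced by an argument.
from packaging import version as pkg_version


def version_at_least(version: str, base_version: str) -> bool:
    return pkg_version.parse(version) >= pkg_version.parse(base_version)


def is_sm90a_supported(major: int, cuda_version: str) -> bool:
    # Mirrors the patched predicate: Hopper (SM 9.x) and CUDA >= 12.3.
    return major == 9 and version_at_least(cuda_version, "12.3")


print(is_sm90a_supported(9, "12.3"))   # True: threshold is inclusive
print(is_sm90a_supported(9, "12.2"))   # False: toolkit too old
print(is_sm90a_supported(9, "12.10"))  # True: the case plain string >= got wrong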
