|
31 | 31 | from .quantization import gen_quantization_module |
32 | 32 | from .rope import gen_rope_module |
33 | 33 | from .sampling import gen_sampling_module |
| 34 | +from .utils import version_at_least |
34 | 35 |
|
35 | 36 |
|
36 | 37 | def gen_fa2( |
@@ -482,8 +483,12 @@ def main(): |
482 | 483 | if "TORCH_CUDA_ARCH_LIST" not in os.environ: |
483 | 484 | raise RuntimeError("Please explicitly set env var TORCH_CUDA_ARCH_LIST.") |
484 | 485 | gencode_flags = _get_cuda_arch_flags() |
485 | | - has_sm90 = any("compute_90" in flag for flag in gencode_flags) |
486 | | - has_sm100 = any("compute_100" in flag for flag in gencode_flags) |
| 486 | + has_sm90 = any("compute_90" in flag for flag in gencode_flags) and version_at_least( |
| 487 | + torch.version.cuda, "12.3" |
| 488 | + ) |
| 489 | + has_sm100 = any( |
| 490 | + "compute_100" in flag for flag in gencode_flags |
| 491 | + ) and version_at_least(torch.version.cuda, "12.8") |
487 | 492 |
|
488 | 493 | # Update data dir |
489 | 494 | jit_env.FLASHINFER_CSRC_DIR = project_root / "csrc" |
@@ -528,38 +533,41 @@ def main(): |
528 | 533 | jit_env.SPDLOG_INCLUDE_DIR, |
529 | 534 | jit_env.FLASHINFER_INCLUDE_DIR, |
530 | 535 | ], |
531 | | - ), |
532 | | - gen_jit_spec( |
533 | | - "trtllm_utils", |
534 | | - [ |
535 | | - jit_env.FLASHINFER_CSRC_DIR |
536 | | - / "nv_internal" |
537 | | - / "tensorrt_llm" |
538 | | - / "kernels" |
539 | | - / "delayStream.cu", |
540 | | - ], |
541 | | - extra_include_paths=[ |
542 | | - jit_env.FLASHINFER_CSRC_DIR / "nv_internal", |
543 | | - jit_env.FLASHINFER_CSRC_DIR / "nv_internal" / "include", |
544 | | - jit_env.FLASHINFER_CSRC_DIR |
545 | | - / "nv_internal" |
546 | | - / "tensorrt_llm" |
547 | | - / "cutlass_extensions" |
548 | | - / "include", |
549 | | - jit_env.FLASHINFER_CSRC_DIR |
550 | | - / "nv_internal" |
551 | | - / "tensorrt_llm" |
552 | | - / "kernels" |
553 | | - / "internal_cutlass_kernels" |
554 | | - / "include", |
555 | | - jit_env.FLASHINFER_CSRC_DIR |
556 | | - / "nv_internal" |
557 | | - / "tensorrt_llm" |
558 | | - / "kernels" |
559 | | - / "internal_cutlass_kernels", |
560 | | - ], |
561 | | - ), |
| 536 | + ) |
562 | 537 | ] |
| 538 | + if has_sm90: |
| 539 | + jit_specs.append( |
| 540 | + gen_jit_spec( |
| 541 | + "trtllm_utils", |
| 542 | + [ |
| 543 | + jit_env.FLASHINFER_CSRC_DIR |
| 544 | + / "nv_internal" |
| 545 | + / "tensorrt_llm" |
| 546 | + / "kernels" |
| 547 | + / "delayStream.cu", |
| 548 | + ], |
| 549 | + extra_include_paths=[ |
| 550 | + jit_env.FLASHINFER_CSRC_DIR / "nv_internal", |
| 551 | + jit_env.FLASHINFER_CSRC_DIR / "nv_internal" / "include", |
| 552 | + jit_env.FLASHINFER_CSRC_DIR |
| 553 | + / "nv_internal" |
| 554 | + / "tensorrt_llm" |
| 555 | + / "cutlass_extensions" |
| 556 | + / "include", |
| 557 | + jit_env.FLASHINFER_CSRC_DIR |
| 558 | + / "nv_internal" |
| 559 | + / "tensorrt_llm" |
| 560 | + / "kernels" |
| 561 | + / "internal_cutlass_kernels" |
| 562 | + / "include", |
| 563 | + jit_env.FLASHINFER_CSRC_DIR |
| 564 | + / "nv_internal" |
| 565 | + / "tensorrt_llm" |
| 566 | + / "kernels" |
| 567 | + / "internal_cutlass_kernels", |
| 568 | + ], |
| 569 | + ), |
| 570 | + ) |
563 | 571 | jit_specs += gen_all_modules( |
564 | 572 | f16_dtype_, |
565 | 573 | f8_dtype_, |
|
0 commit comments