Skip to content

Commit af3ddbe

Browse files
tiranyzh119
authored and committed
bugfix: Fix missing symbols in trtllm_utils.so (flashinfer-ai#1168)
## 📌 Description

The `trtllm_utils` library also needs symbols from the common files, e.g. `TllmException`. Extend the dependencies to link `trtllm_utils.so` with the common helpers.

```python
>>> flashinfer.utils.get_trtllm_utils_module()
Traceback (most recent call last):
...
OSError: ...flashinfer/data/aot/trtllm_utils/trtllm_utils.so: undefined symbol: _ZTIN12tensorrt_llm6common13TllmExceptionE
```

## 🔍 Related Issues

Fixes: flashinfer-ai#1167

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc. -->

---------

Signed-off-by: Christian Heimes <christian@python.org>
Co-authored-by: Zihao Ye <expye@outlook.com>
1 parent 104295e commit af3ddbe

File tree

4 files changed

+50
-72
lines changed

4 files changed

+50
-72
lines changed

flashinfer/aot.py

Lines changed: 2 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from .quantization import gen_quantization_module
3232
from .rope import gen_rope_module
3333
from .sampling import gen_sampling_module
34+
from .tllm_utils import get_trtllm_utils_spec
3435
from .utils import version_at_least
3536

3637

@@ -537,38 +538,7 @@ def main():
537538
)
538539
]
539540
if has_sm90:
540-
jit_specs.append(
541-
gen_jit_spec(
542-
"trtllm_utils",
543-
[
544-
jit_env.FLASHINFER_CSRC_DIR
545-
/ "nv_internal"
546-
/ "tensorrt_llm"
547-
/ "kernels"
548-
/ "delayStream.cu",
549-
],
550-
extra_include_paths=[
551-
jit_env.FLASHINFER_CSRC_DIR / "nv_internal",
552-
jit_env.FLASHINFER_CSRC_DIR / "nv_internal" / "include",
553-
jit_env.FLASHINFER_CSRC_DIR
554-
/ "nv_internal"
555-
/ "tensorrt_llm"
556-
/ "cutlass_extensions"
557-
/ "include",
558-
jit_env.FLASHINFER_CSRC_DIR
559-
/ "nv_internal"
560-
/ "tensorrt_llm"
561-
/ "kernels"
562-
/ "internal_cutlass_kernels"
563-
/ "include",
564-
jit_env.FLASHINFER_CSRC_DIR
565-
/ "nv_internal"
566-
/ "tensorrt_llm"
567-
/ "kernels"
568-
/ "internal_cutlass_kernels",
569-
],
570-
),
571-
)
541+
jit_specs.append(get_trtllm_utils_spec())
572542
jit_specs += gen_all_modules(
573543
f16_dtype_,
574544
f8_dtype_,

flashinfer/autotuner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
# from tensorrt_llm.bindings.internal.runtime import delay_kernel
1313
# from tensorrt_llm.logger import logger
14-
from flashinfer.utils import delay_kernel
14+
from flashinfer.tllm_utils import delay_kernel
1515

1616
from .jit.core import logger
1717

flashinfer/tllm_utils.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import functools
2+
3+
from .jit import env as jit_env
4+
from .jit import gen_jit_spec
5+
6+
7+
def get_trtllm_utils_spec():
    """Assemble the JIT spec for the ``trtllm_utils`` extension.

    The source list holds the ``delayStream.cu`` kernel plus the common
    helper sources (``envUtils``, ``logger``, ``stringUtils`` and
    ``tllmException``) so that they are built into the same module.

    Returns:
        The object produced by ``gen_jit_spec`` for the ``"trtllm_utils"``
        name, the sources above and the include directories listed below.
    """
    csrc_dir = jit_env.FLASHINFER_CSRC_DIR

    # Sources, expressed relative to the csrc directory.
    relative_sources = (
        "nv_internal/tensorrt_llm/kernels/delayStream.cu",
        "nv_internal/cpp/common/envUtils.cpp",
        "nv_internal/cpp/common/logger.cpp",
        "nv_internal/cpp/common/stringUtils.cpp",
        "nv_internal/cpp/common/tllmException.cpp",
    )
    sources = [csrc_dir / rel_path for rel_path in relative_sources]

    # Shared prefixes of the include directories.
    nv_internal_dir = csrc_dir / "nv_internal"
    tensorrt_llm_dir = nv_internal_dir / "tensorrt_llm"
    internal_cutlass_dir = tensorrt_llm_dir / "kernels" / "internal_cutlass_kernels"

    include_dirs = [
        nv_internal_dir,
        nv_internal_dir / "include",
        tensorrt_llm_dir / "cutlass_extensions" / "include",
        internal_cutlass_dir / "include",
        internal_cutlass_dir,
    ]

    return gen_jit_spec(
        "trtllm_utils",
        sources,
        extra_include_paths=include_dirs,
    )
39+
40+
41+
@functools.cache
def get_trtllm_utils_module():
    """Build and load the ``trtllm_utils`` module, caching the result.

    Because of ``functools.cache`` the spec is built and loaded on the first
    call only; later calls return the same cached object.

    Returns:
        Whatever ``build_and_load()`` returns for the ``trtllm_utils`` spec.
    """
    trtllm_utils_spec = get_trtllm_utils_spec()
    return trtllm_utils_spec.build_and_load()
44+
45+
46+
def delay_kernel(stream_delay_micro_secs):
    """Call ``delay_kernel`` on the loaded ``trtllm_utils`` module.

    Args:
        stream_delay_micro_secs: Forwarded unchanged to the module's
            ``delay_kernel``; presumably a delay in microseconds, judging by
            the name (confirm against the kernel source).

    Returns:
        None. The result of the underlying call is not propagated.
    """
    trtllm_utils = get_trtllm_utils_module()
    trtllm_utils.delay_kernel(stream_delay_micro_secs)

flashinfer/utils.py

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
limitations under the License.
1515
"""
1616

17-
import functools
1817
import math
1918
import os
2019
from enum import Enum
@@ -25,9 +24,6 @@
2524
from torch.torch_version import TorchVersion
2625
from torch.torch_version import __version__ as torch_version
2726

28-
from .jit import env as jit_env
29-
from .jit import gen_jit_spec
30-
3127
IS_BUILDING_DOCS = os.environ.get("FLASHINFER_BUILDING_DOCS") == "1"
3228

3329

@@ -471,41 +467,6 @@ def set_log_level(lvl_str: str) -> None:
471467
get_logging_module().set_log_level(log_level_map[lvl_str].value)
472468

473469

474-
@functools.cache
475-
def get_trtllm_utils_module():
476-
return gen_jit_spec(
477-
"trtllm_utils",
478-
[
479-
jit_env.FLASHINFER_CSRC_DIR
480-
/ "nv_internal/tensorrt_llm/kernels/delayStream.cu",
481-
],
482-
extra_include_paths=[
483-
jit_env.FLASHINFER_CSRC_DIR / "nv_internal",
484-
jit_env.FLASHINFER_CSRC_DIR / "nv_internal" / "include",
485-
jit_env.FLASHINFER_CSRC_DIR
486-
/ "nv_internal"
487-
/ "tensorrt_llm"
488-
/ "cutlass_extensions"
489-
/ "include",
490-
jit_env.FLASHINFER_CSRC_DIR
491-
/ "nv_internal"
492-
/ "tensorrt_llm"
493-
/ "kernels"
494-
/ "internal_cutlass_kernels"
495-
/ "include",
496-
jit_env.FLASHINFER_CSRC_DIR
497-
/ "nv_internal"
498-
/ "tensorrt_llm"
499-
/ "kernels"
500-
/ "internal_cutlass_kernels",
501-
],
502-
).build_and_load()
503-
504-
505-
def delay_kernel(stream_delay_micro_secs):
506-
get_trtllm_utils_module().delay_kernel(stream_delay_micro_secs)
507-
508-
509470
def device_support_pdl(device: torch.device) -> bool:
510471
major, _ = get_compute_capability(device)
511472
return major >= 9

0 commit comments

Comments
 (0)