Skip to content

Commit 401f6fb

Browse files
committed
Fix missing symbols in trtllm_utils.so
The `trtllm_utils` library also needs symbols from common files, e.g. the `llmException`. Extend the dependencies to link `trtllm_utils.so` with common helpers. ```python >>> flashinfer.utils.get_trtllm_utils_module() Traceback (most recent call last): ... OSError: ...flashinfer/data/aot/trtllm_utils/trtllm_utils.so: undefined symbol: _ZTIN12tensorrt_llm6common13TllmExceptionE ``` Fixes: #1167 Signed-off-by: Christian Heimes <christian@python.org>
1 parent bc50f1a commit 401f6fb

File tree

2 files changed

+12
-36
lines changed

2 files changed

+12
-36
lines changed

flashinfer/aot.py

Lines changed: 2 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from .quantization import gen_quantization_module
3232
from .rope import gen_rope_module
3333
from .sampling import gen_sampling_module
34-
from .utils import version_at_least
34+
from .utils import get_trtllm_utils_spec, version_at_least
3535

3636

3737
def gen_fa2(
@@ -536,38 +536,7 @@ def main():
536536
)
537537
]
538538
if has_sm90:
539-
jit_specs.append(
540-
gen_jit_spec(
541-
"trtllm_utils",
542-
[
543-
jit_env.FLASHINFER_CSRC_DIR
544-
/ "nv_internal"
545-
/ "tensorrt_llm"
546-
/ "kernels"
547-
/ "delayStream.cu",
548-
],
549-
extra_include_paths=[
550-
jit_env.FLASHINFER_CSRC_DIR / "nv_internal",
551-
jit_env.FLASHINFER_CSRC_DIR / "nv_internal" / "include",
552-
jit_env.FLASHINFER_CSRC_DIR
553-
/ "nv_internal"
554-
/ "tensorrt_llm"
555-
/ "cutlass_extensions"
556-
/ "include",
557-
jit_env.FLASHINFER_CSRC_DIR
558-
/ "nv_internal"
559-
/ "tensorrt_llm"
560-
/ "kernels"
561-
/ "internal_cutlass_kernels"
562-
/ "include",
563-
jit_env.FLASHINFER_CSRC_DIR
564-
/ "nv_internal"
565-
/ "tensorrt_llm"
566-
/ "kernels"
567-
/ "internal_cutlass_kernels",
568-
],
569-
),
570-
)
539+
jit_specs.append(get_trtllm_utils_spec())
571540
jit_specs += gen_all_modules(
572541
f16_dtype_,
573542
f8_dtype_,

flashinfer/utils.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -471,13 +471,16 @@ def set_log_level(lvl_str: str) -> None:
471471
get_logging_module().set_log_level(log_level_map[lvl_str].value)
472472

473473

474-
@functools.cache
475-
def get_trtllm_utils_module():
474+
def get_trtllm_utils_spec():
476475
return gen_jit_spec(
477476
"trtllm_utils",
478477
[
479478
jit_env.FLASHINFER_CSRC_DIR
480479
/ "nv_internal/tensorrt_llm/kernels/delayStream.cu",
480+
jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/envUtils.cpp",
481+
jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/logger.cpp",
482+
jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/stringUtils.cpp",
483+
jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/tllmException.cpp",
481484
],
482485
extra_include_paths=[
483486
jit_env.FLASHINFER_CSRC_DIR / "nv_internal",
@@ -499,7 +502,11 @@ def get_trtllm_utils_module():
499502
/ "kernels"
500503
/ "internal_cutlass_kernels",
501504
],
502-
).build_and_load()
505+
)
506+
507+
@functools.cache
508+
def get_trtllm_utils_module():
509+
return get_trtllm_utils_spec().build_and_load()
503510

504511

505512
def delay_kernel(stream_delay_micro_secs):

0 commit comments

Comments
 (0)