34 changes: 2 additions & 32 deletions flashinfer/aot.py
@@ -31,6 +31,7 @@
from .quantization import gen_quantization_module
from .rope import gen_rope_module
from .sampling import gen_sampling_module
from .tllm_utils import get_trtllm_utils_spec
from .utils import version_at_least


@@ -537,38 +538,7 @@ def main():
        )
    ]
    if has_sm90:
        jit_specs.append(
            gen_jit_spec(
                "trtllm_utils",
                [
                    jit_env.FLASHINFER_CSRC_DIR
                    / "nv_internal"
                    / "tensorrt_llm"
                    / "kernels"
                    / "delayStream.cu",
                ],
                extra_include_paths=[
                    jit_env.FLASHINFER_CSRC_DIR / "nv_internal",
                    jit_env.FLASHINFER_CSRC_DIR / "nv_internal" / "include",
                    jit_env.FLASHINFER_CSRC_DIR
                    / "nv_internal"
                    / "tensorrt_llm"
                    / "cutlass_extensions"
                    / "include",
                    jit_env.FLASHINFER_CSRC_DIR
                    / "nv_internal"
                    / "tensorrt_llm"
                    / "kernels"
                    / "internal_cutlass_kernels"
                    / "include",
                    jit_env.FLASHINFER_CSRC_DIR
                    / "nv_internal"
                    / "tensorrt_llm"
                    / "kernels"
                    / "internal_cutlass_kernels",
                ],
            ),
        )
        jit_specs.append(get_trtllm_utils_spec())
    jit_specs += gen_all_modules(
        f16_dtype_,
        f8_dtype_,
2 changes: 1 addition & 1 deletion flashinfer/autotuner.py
@@ -11,7 +11,7 @@

# from tensorrt_llm.bindings.internal.runtime import delay_kernel
# from tensorrt_llm.logger import logger
from flashinfer.utils import delay_kernel
from flashinfer.tllm_utils import delay_kernel

from .jit.core import logger

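Note (not part of the diff): the autotuner's call site is untouched by this PR; only the import source moves from flashinfer.utils to the new flashinfer.tllm_utils module. A minimal sketch of the intended usage, with the surrounding autotuner code paraphrased rather than quoted:

from flashinfer.tllm_utils import delay_kernel

# Paraphrased call site: enqueue a short device-side delay before a timed run;
# the value is in microseconds, per the stream_delay_micro_secs parameter name.
delay_kernel(stream_delay_micro_secs=1000)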
47 changes: 47 additions & 0 deletions flashinfer/tllm_utils.py
@@ -0,0 +1,47 @@
import functools

from .jit import env as jit_env
from .jit import gen_jit_spec


def get_trtllm_utils_spec():
    return gen_jit_spec(
        "trtllm_utils",
        [
            jit_env.FLASHINFER_CSRC_DIR
            / "nv_internal/tensorrt_llm/kernels/delayStream.cu",
            jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/envUtils.cpp",
            jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/logger.cpp",
            jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/stringUtils.cpp",
            jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/tllmException.cpp",
        ],
        extra_include_paths=[
            jit_env.FLASHINFER_CSRC_DIR / "nv_internal",
            jit_env.FLASHINFER_CSRC_DIR / "nv_internal" / "include",
            jit_env.FLASHINFER_CSRC_DIR
            / "nv_internal"
            / "tensorrt_llm"
            / "cutlass_extensions"
            / "include",
            jit_env.FLASHINFER_CSRC_DIR
            / "nv_internal"
            / "tensorrt_llm"
            / "kernels"
            / "internal_cutlass_kernels"
            / "include",
            jit_env.FLASHINFER_CSRC_DIR
            / "nv_internal"
            / "tensorrt_llm"
            / "kernels"
            / "internal_cutlass_kernels",
        ],
    )


@functools.cache
def get_trtllm_utils_module():
    return get_trtllm_utils_spec().build_and_load()


def delay_kernel(stream_delay_micro_secs):
    get_trtllm_utils_module().delay_kernel(stream_delay_micro_secs)
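For reference (not part of the diff), a minimal usage sketch of the new module. It assumes a CUDA-capable device and a working JIT toolchain; the exact behavior of delay_kernel (enqueueing a device-side delay of roughly stream_delay_micro_secs microseconds on the current stream) is inferred from the delayStream.cu source name rather than stated in this PR.

import torch

from flashinfer.tllm_utils import delay_kernel, get_trtllm_utils_spec

# AOT path: aot.py now appends this spec to its build list when SM90 is present.
spec = get_trtllm_utils_spec()

# JIT path: the first call builds and loads the extension once
# (cached via functools.cache); later calls reuse the loaded module.
if torch.cuda.is_available():
    delay_kernel(1000)  # assumed: ~1000 microsecond delay on the current stream
    torch.cuda.synchronize()  # wait for the delay kernel to finish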
39 changes: 0 additions & 39 deletions flashinfer/utils.py
@@ -14,7 +14,6 @@
limitations under the License.
"""

import functools
import math
import os
from enum import Enum
@@ -25,9 +24,6 @@
from torch.torch_version import TorchVersion
from torch.torch_version import __version__ as torch_version

from .jit import env as jit_env
from .jit import gen_jit_spec

IS_BUILDING_DOCS = os.environ.get("FLASHINFER_BUILDING_DOCS") == "1"


@@ -471,41 +467,6 @@ def set_log_level(lvl_str: str) -> None:
    get_logging_module().set_log_level(log_level_map[lvl_str].value)


@functools.cache
def get_trtllm_utils_module():
    return gen_jit_spec(
        "trtllm_utils",
        [
            jit_env.FLASHINFER_CSRC_DIR
            / "nv_internal/tensorrt_llm/kernels/delayStream.cu",
        ],
        extra_include_paths=[
            jit_env.FLASHINFER_CSRC_DIR / "nv_internal",
            jit_env.FLASHINFER_CSRC_DIR / "nv_internal" / "include",
            jit_env.FLASHINFER_CSRC_DIR
            / "nv_internal"
            / "tensorrt_llm"
            / "cutlass_extensions"
            / "include",
            jit_env.FLASHINFER_CSRC_DIR
            / "nv_internal"
            / "tensorrt_llm"
            / "kernels"
            / "internal_cutlass_kernels"
            / "include",
            jit_env.FLASHINFER_CSRC_DIR
            / "nv_internal"
            / "tensorrt_llm"
            / "kernels"
            / "internal_cutlass_kernels",
        ],
    ).build_and_load()


def delay_kernel(stream_delay_micro_secs):
    get_trtllm_utils_module().delay_kernel(stream_delay_micro_secs)


def device_support_pdl(device: torch.device) -> bool:
    major, _ = get_compute_capability(device)
    return major >= 9