diff --git a/python/tvm/relax/backend/cuda/flashinfer.py b/python/tvm/relax/backend/cuda/flashinfer.py index 4e0fc3e8541a..6b5b1293ff21 100644 --- a/python/tvm/relax/backend/cuda/flashinfer.py +++ b/python/tvm/relax/backend/cuda/flashinfer.py @@ -16,203 +16,36 @@ # under the License. """FlashInfer JIT compilation module for CUDA backend""" -import hashlib -import json -import os -import subprocess -from concurrent.futures import ThreadPoolExecutor +import re from pathlib import Path from typing import List -import tvm_ffi - import tvm from tvm.target import Target -def _compile_flashinfer_kernels( - name: str, source_paths: List[Path], target: Target, num_threads: int -) -> List[Path]: - from flashinfer.jit.env import ( # pylint: disable=import-outside-toplevel - CUTLASS_INCLUDE_DIRS, - FLASHINFER_CSRC_DIR, - FLASHINFER_INCLUDE_DIR, - FLASHINFER_JIT_DIR, - FLASHINFER_TVM_BINDING_DIR, - ) - - # ------------------------------------------------------------------------ - # Caching Flow: create build_directory and compute cache hash. - # ------------------------------------------------------------------------ - build_directory = FLASHINFER_JIT_DIR / name - build_directory.mkdir(parents=True, exist_ok=True) - - def get_object_file_path(src: Path) -> Path: - obj_name = src.stem + ".o" - obj_path = build_directory / obj_name - return obj_path - - # Compute latest modification time among all source files - latest_src_mtime = max(src.stat().st_mtime for src in source_paths) +def _rename_exported_func_names(source_paths: List[Path], prefix: str): + """Rename the ffi-exported function names in the source files to the given prefix.""" + pattern = re.compile(r"^(\s*TVM_FFI_DLL_EXPORT_TYPED_FUNC\()([A-Za-z0-9_]+)(,.*)$") + for source_path in source_paths: + if not source_path.name.endswith("_binding.cu"): + continue - # Get modification time for the current file (the one that contains this function) - current_file_mtime = Path(__file__).stat().st_mtime + original_text = source_path.read_text(encoding="utf-8") + lines = original_text.splitlines(keepends=True) + updated = False + for idx, line in enumerate(lines): + line_body = line.rstrip("\r\n") + line_ending = line[len(line_body) :] + match = pattern.match(line_body) + if not match: + continue + new_body = f"{match.group(1)}{prefix}_{match.group(2)}{match.group(3)}" + lines[idx] = new_body + line_ending + updated = True - # Build the hash key from metadata - hash_key = { - "name": name, - "target": str(target), - "latest_src_mtime": latest_src_mtime, - "current_file_mtime": current_file_mtime, - } - - hash_value = hashlib.md5( - json.dumps(hash_key, sort_keys=True, indent=2).encode("utf-8") - ).hexdigest() - - # Check if a valid hash exists in the build directory - hash_file = build_directory / "hash.md5" - if hash_file.exists(): - with open(hash_file, "r") as f: - cached_hash = f.read().strip() - if cached_hash == hash_value: - # Check that all object files exist - object_files = [] - all_exist = True - for src in source_paths: - obj_path = get_object_file_path(src) - if not obj_path.exists(): - all_exist = False - break - object_files.append(obj_path) - if all_exist: - return object_files - - # If we are here, cache is missing or outdated. 
Write the new hash and compile the paths - with open(hash_file, "w") as f: - f.write(hash_value) - - # ------------------------------------------------------------------------ - # 1) Common CUDA compile flags - # ------------------------------------------------------------------------ - cuda_cflags = [ - "-O3", - "-std=c++17", - "--threads", - str(num_threads), - "-g", - "-use_fast_math", - "--expt-relaxed-constexpr", - # DMLC default - "-DDMLC_USE_FOPEN64=0", - "-DDMLC_USE_LOGGING_LIBRARY=", - # Enable `-fPIC` for the host compiler - "-Xcompiler=-fPIC", - "-DFLASHINFER_ENABLE_F16", - "-DFLASHINFER_ENABLE_BF16", - "-DFLASHINFER_ENABLE_FP8_E4M3", - "-DFLASHINFER_ENABLE_FP8_E5M2", - ] - - # Determine compute version - compute_version = "".join(tvm.contrib.nvcc.get_target_compute_version(target).split(".")) - if compute_version in ["90", "100"]: - compute_version += "a" - cuda_cflags += [ - "-gencode", - f"arch=compute_{compute_version},code=sm_{compute_version}", - ] - - # ------------------------------------------------------------------------ - # 2) Include paths - # ------------------------------------------------------------------------ - include_paths = [ - FLASHINFER_INCLUDE_DIR, - FLASHINFER_CSRC_DIR, - FLASHINFER_TVM_BINDING_DIR, - ] + CUTLASS_INCLUDE_DIRS - - if os.environ.get("TVM_SOURCE_DIR", None) or os.environ.get("TVM_HOME", None): - # Respect TVM_SOURCE_DIR and TVM_HOME if they are set - tvm_home = ( - os.environ["TVM_SOURCE_DIR"] - if os.environ.get("TVM_SOURCE_DIR", None) - else os.environ["TVM_HOME"] - ) - include_paths += [ - Path(tvm_home).resolve() / "include", - Path(tvm_home).resolve() / "3rdparty" / "tvm-ffi" / "include", - Path(tvm_home).resolve() / "3rdparty" / "tvm-ffi" / "3rdparty" / "dlpack" / "include", - Path(tvm_home).resolve() / "3rdparty" / "dmlc-core" / "include", - ] - else: - # If TVM_SOURCE_DIR and TVM_HOME are not set, use the default TVM package path - tvm_package_path = Path(tvm.__file__).resolve().parent - if (tvm_package_path / "include").exists(): - # The package is installed from pip. - tvm_ffi_package_path = Path(tvm_ffi.__file__).resolve().parent - include_paths += [ - tvm_package_path / "include", - tvm_package_path / "3rdparty" / "dmlc-core" / "include", - tvm_ffi_package_path / "include", - ] - elif (tvm_package_path.parent.parent / "include").exists(): - # The package is installed from source. - include_paths += [ - tvm_package_path.parent.parent / "include", - tvm_package_path.parent.parent / "3rdparty" / "tvm-ffi" / "include", - tvm_package_path.parent.parent - / "3rdparty" - / "tvm-ffi" - / "3rdparty" - / "dlpack" - / "include", - tvm_package_path.parent.parent / "3rdparty" / "dmlc-core" / "include", - ] - else: - # warning: TVM is not installed in the system. - print( - "Warning: Include path for TVM cannot be found. " - "FlashInfer kernel compilation may fail due to missing headers." 
- ) - - # ------------------------------------------------------------------------ - # 3) Function to compile a single source file - # ------------------------------------------------------------------------ - def compile_single_source(src: Path) -> Path: - # Derive the .o filename from the source filename - obj_path = get_object_file_path(src) - - # Construct the command - cmd = ( - ["nvcc"] - + cuda_cflags - + [f"-I{inc_path}" for inc_path in include_paths] - + ["-c", "-o", str(obj_path), str(src)] - ) - - proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - out, err = proc.communicate() - if proc.returncode != 0: - raise RuntimeError( - f"FlashInfer JIT compilation failed for {src}\n" - f"Command: {' '.join(cmd)}\n" - f"stdout:\n{out.decode('utf-8')}\n" - f"stderr:\n{err.decode('utf-8')}" - ) - return obj_path - - # ------------------------------------------------------------------------ - # 4) Compile each source in parallel using ThreadPoolExecutor - # ------------------------------------------------------------------------ - object_files = [] - with ThreadPoolExecutor(max_workers=num_threads) as executor: - futures = [executor.submit(compile_single_source, src) for src in source_paths] - for f in futures: - object_files.append(f.result()) # Will raise if there's a compilation error - - # Return list of generated object files for any further linking steps - return object_files + if updated: + source_path.write_text("".join(lines), encoding="utf-8") def _load_flashinfer_modules(object_files: List[Path]) -> List[tvm.runtime.Module]: @@ -228,9 +61,8 @@ def gen_flashinfer_prefill_module( dtype_o: str, qk_head_dim: int, v_head_dim: int, - target: Target, - enable_inline_rope: bool = True, - num_threads: int = 8, + enable_inline_rope: bool, + return_static_libs: bool = False, ) -> List[tvm.runtime.Module]: """Generate a FlashInfer module for prefill. @@ -246,12 +78,12 @@ def gen_flashinfer_prefill_module( The head dimension of the query and key tensors. v_head_dim : int The head dimension of the value tensor. - target : Target - The target device to compile for. enable_inline_rope : bool Whether to enable inline rotary positional embedding. - num_threads : int - The number of threads to use for compilation. + return_static_libs : bool + Whether to return static library modules instead of compiled modules. + When it is False, it returns the loaded shared library that links all the object files. + When it is True, it returns the static libraries of each compiled object files. 
Returns ------- @@ -259,7 +91,7 @@ def gen_flashinfer_prefill_module( """ try: from flashinfer.jit import ( # pylint: disable=import-outside-toplevel - gen_customize_batch_prefill_tvm_binding, + gen_customize_batch_prefill_module, ) except ImportError: raise ImportError( @@ -289,32 +121,33 @@ def gen_flashinfer_prefill_module( if backend == "fa2" else "#include " ) - jit_args = { - "backend": backend, - "uri": f"batch_prefill_tvm_dtype_q_{dtype_q}_" + jit_spec = gen_customize_batch_prefill_module( + backend=backend, + uri=f"batch_prefill_tvm_dtype_q_{dtype_q}_" + f"dtype_kv_{dtype_kv}_" + f"dtype_o_{dtype_o}_" + f"qk_head_dim_{qk_head_dim}_" + f"v_head_dim_{v_head_dim}_" + f"enable_inline_rope_{enable_inline_rope}", - "dtype_q": torch_dtype_q, - "dtype_kv": torch_dtype_kv, - "dtype_o": torch_dtype_o, - "idtype": torch.int32, - "head_dim_qk": qk_head_dim, - "head_dim_vo": v_head_dim, - "additional_tensor_names": [], - "additional_tensor_dtypes": [], - "additional_scalar_names": ["sm_scale", "rope_rcp_scale", "rope_rcp_theta"], - "additional_scalar_dtypes": ["double", "double", "double"], - "variant_name": variant_name, - "variant_decl": variant_decl, - "enable_inline_rope": enable_inline_rope, - } - uri, source_paths = gen_customize_batch_prefill_tvm_binding(**jit_args) - object_files = _compile_flashinfer_kernels(uri, source_paths, target, num_threads) - modules = _load_flashinfer_modules(object_files) - return modules + dtype_q=torch_dtype_q, + dtype_kv=torch_dtype_kv, + dtype_o=torch_dtype_o, + idtype=torch.int32, + head_dim_qk=qk_head_dim, + head_dim_vo=v_head_dim, + pos_encoding_mode=int(enable_inline_rope), + additional_tensor_names=[], + additional_tensor_dtypes=[], + additional_scalar_names=["sm_scale", "rope_rcp_scale", "rope_rcp_theta"], + additional_scalar_dtypes=["double", "double", "double"], + variant_name=variant_name, + variant_decl=variant_decl, + ) + _rename_exported_func_names(jit_spec.sources, "batch_prefill") + if return_static_libs: + jit_spec.build(verbose=False) + return _load_flashinfer_modules(jit_spec.get_object_paths()) + return [jit_spec.build_and_load()] def gen_flashinfer_decode_module( @@ -323,8 +156,8 @@ def gen_flashinfer_decode_module( dtype_o: str, qk_head_dim: int, v_head_dim: int, - target: Target, - num_threads: int = 8, + enable_inline_rope: bool, + return_static_libs: bool = False, ) -> List[tvm.runtime.Module]: """Generate a FlashInfer module for decode. @@ -340,10 +173,12 @@ def gen_flashinfer_decode_module( The head dimension of the query and key tensors. v_head_dim : int The head dimension of the value tensor. - target : Target - The target device to compile for. - num_threads : int - The number of threads to use for compilation. + enable_inline_rope : bool + Whether to enable inline rotary positional embedding. + return_static_libs : bool + Whether to return static library modules instead of compiled modules. + When it is False, it returns the loaded shared library that links all the object files. + When it is True, it returns the static libraries of each compiled object files. 
Returns ------- @@ -351,7 +186,7 @@ def gen_flashinfer_decode_module( """ try: from flashinfer.jit import ( # pylint: disable=import-outside-toplevel - gen_customize_batch_decode_tvm_binding, + gen_customize_batch_decode_module, ) except ImportError: raise ImportError( @@ -366,29 +201,32 @@ def gen_flashinfer_decode_module( torch_dtype_q = getattr(torch, dtype_q) torch_dtype_kv = getattr(torch, dtype_kv) torch_dtype_o = getattr(torch, dtype_o) - jit_args = { - "uri": f"batch_decode_tvm_dtype_q_{dtype_q}_" + jit_spec = gen_customize_batch_decode_module( + uri=f"batch_decode_tvm_dtype_q_{dtype_q}_" + f"dtype_kv_{dtype_kv}_" + f"dtype_o_{dtype_o}_" + f"qk_head_dim_{qk_head_dim}_" - + f"v_head_dim_{v_head_dim}", - "dtype_q": torch_dtype_q, - "dtype_kv": torch_dtype_kv, - "dtype_o": torch_dtype_o, - "idtype": torch.int32, - "head_dim_qk": qk_head_dim, - "head_dim_vo": v_head_dim, - "additional_tensor_names": [], - "additional_tensor_dtypes": [], - "additional_scalar_names": ["sm_scale", "rope_rcp_scale", "rope_rcp_theta"], - "additional_scalar_dtypes": ["double", "double", "double"], - "variant_name": "DefaultAttention", - "variant_decl": "#include ", - } - uri, source_paths = gen_customize_batch_decode_tvm_binding(**jit_args) - object_files = _compile_flashinfer_kernels(uri, source_paths, target, num_threads) - modules = _load_flashinfer_modules(object_files) - return modules + + f"v_head_dim_{v_head_dim}_" + + f"enable_inline_rope_{enable_inline_rope}", + dtype_q=torch_dtype_q, + dtype_kv=torch_dtype_kv, + dtype_o=torch_dtype_o, + idtype=torch.int32, + head_dim_qk=qk_head_dim, + head_dim_vo=v_head_dim, + pos_encoding_mode=int(enable_inline_rope), + additional_tensor_names=[], + additional_tensor_dtypes=[], + additional_scalar_names=["sm_scale", "rope_rcp_scale", "rope_rcp_theta"], + additional_scalar_dtypes=["double", "double", "double"], + variant_name="DefaultAttention", + variant_decl="#include ", + ) + _rename_exported_func_names(jit_spec.sources, "batch_decode") + if return_static_libs: + jit_spec.build(verbose=False) + return _load_flashinfer_modules(jit_spec.get_object_paths()) + return [jit_spec.build_and_load()] def gen_flashinfer_mla_module( @@ -397,8 +235,7 @@ def gen_flashinfer_mla_module( dtype_o: str, head_dim_ckv: int, head_dim_kpe: int, - target: Target, - num_threads: int = 8, + return_static_libs: bool = False, ) -> List[tvm.runtime.Module]: """Generate a FlashInfer module for MLA. @@ -418,6 +255,10 @@ def gen_flashinfer_mla_module( The target device to compile for. num_threads : int The number of threads to use for compilation. + return_static_libs : bool + Whether to return static library modules instead of compiled modules. + When it is False, it returns the loaded shared library that links all the object files. + When it is True, it returns the static libraries of each compiled object files. 
Returns ------- @@ -425,7 +266,7 @@ def gen_flashinfer_mla_module( """ try: from flashinfer.jit import ( # pylint: disable=import-outside-toplevel - gen_batch_mla_tvm_binding, + gen_batch_mla_module, ) except ImportError: raise ImportError( @@ -440,92 +281,36 @@ def gen_flashinfer_mla_module( torch_dtype_q = getattr(torch, dtype_q) torch_dtype_kv = getattr(torch, dtype_kv) torch_dtype_o = getattr(torch, dtype_o) - jit_args = { - "uri": f"batch_mla_tvm_dtype_q_{dtype_q}_" - + f"dtype_kv_{dtype_kv}_" - + f"dtype_o_{dtype_o}_" - + f"head_dim_ckv_{head_dim_ckv}_" - + f"head_dim_kpe_{head_dim_kpe}", - "dtype_q": torch_dtype_q, - "dtype_kv": torch_dtype_kv, - "dtype_o": torch_dtype_o, - "dtype_idx": torch.int32, - "head_dim_ckv": head_dim_ckv, - "head_dim_kpe": head_dim_kpe, - } - uri, source_paths = gen_batch_mla_tvm_binding(**jit_args) - object_files = _compile_flashinfer_kernels(uri, source_paths, target, num_threads) - modules = _load_flashinfer_modules(object_files) - return modules - - -def gen_sampling_module(target: Target, num_threads: int = 8): - """ - Generate a FlashInfer module for sampling kernels. - - Parameters - ---------- - target : Target - The target device for which the module will be compiled. - num_threads : int, optional - The number of threads to use during compilation (default is 8). - - Returns - ------- - List[tvm.runtime.Module] - A list of compiled static library modules for the FlashInfer sampling kernels. - """ - try: - from flashinfer.jit import ( # pylint: disable=import-outside-toplevel - gen_sampling_tvm_binding, - ) - except ImportError: - raise ImportError( - "FlashInfer is not installed. Please follow instructions " - "in https://docs.flashinfer.ai to install FlashInfer." - ) - uri, source_paths = gen_sampling_tvm_binding(uri="sampling") - object_files = _compile_flashinfer_kernels(uri, source_paths, target, num_threads) - modules = _load_flashinfer_modules(object_files) - return modules + jit_spec = gen_batch_mla_module( + backend="fa2", + dtype_q=torch_dtype_q, + dtype_kv=torch_dtype_kv, + dtype_o=torch_dtype_o, + dtype_idx=torch.int32, + head_dim_ckv=head_dim_ckv, + head_dim_kpe=head_dim_kpe, + use_profiler=False, + ) + _rename_exported_func_names(jit_spec.sources, "batch_mla") + if return_static_libs: + jit_spec.build(verbose=False) + return _load_flashinfer_modules(jit_spec.get_object_paths()) + return [jit_spec.build_and_load()] def gen_grouped_gemm_module( - dtype_a: str, - dtype_b: str, - dtype_out: str, - scale_granularity_m: int, - scale_granularity_n: int, - scale_granularity_k: int, - scale_major_mode: str, - mma_sm: int, - target: Target, - num_threads: int = 8, + target: Target, return_static_libs: bool = False ) -> List[tvm.runtime.Module]: """Generate a FlashInfer module for FP8 grouped GEMM. Parameters ---------- - dtype_a : str - The data type of matrix A (e.g., "float8_e4m3fn"). - dtype_b : str - The data type of matrix B (e.g., "float8_e4m3fn"). - dtype_out : str - The data type of the output matrix (e.g., "bfloat16"). - scale_granularity_m : int - The scaling granularity in the M dimension. - scale_granularity_n : int - The scaling granularity in the N dimension. - scale_granularity_k : int - The scaling granularity in the K dimension. - scale_major_mode : str - The scale storage mode ("K" or "MN"). - mma_sm : int - The MMA scheduling mode (1 or 2). target : Target The target device to compile for. - num_threads : int - The number of threads to use for compilation. 
+ return_static_libs : bool + Whether to return static library modules instead of compiled modules. + When it is False, it returns the loaded shared library that links all the object files. + When it is True, it returns the static libraries of each compiled object files. Returns ------- @@ -537,48 +322,24 @@ def gen_grouped_gemm_module( when apply grouped gemm on A: (total_m, k), B: (batch_size, n, k), m_indptr: (batch_size, ) requires all m in m_indptr to be multiple of 4 """ + # NOTE: This function is still under development, + # and we currently only support SM100 grouped gemm try: - from flashinfer.jit import ( # pylint: disable=import-outside-toplevel - gen_grouped_gemm_fp8_tvm_binding, - get_grouped_gemm_fp8_uri, + from flashinfer.gemm import ( # pylint: disable=import-outside-toplevel + gen_gemm_sm100_module, ) except ImportError: raise ImportError( "FlashInfer is not installed. Please follow instructions " "in https://docs.flashinfer.ai to install FlashInfer." ) - try: - import torch # pylint: disable=import-outside-toplevel - except ImportError: - raise ImportError("PyTorch is not installed. Please install PyTorch to use FlashInfer.") - - torch_dtype_a = getattr(torch, dtype_a) - torch_dtype_b = getattr(torch, dtype_b) - torch_dtype_out = getattr(torch, dtype_out) - - uri = get_grouped_gemm_fp8_uri( - dtype_a=torch_dtype_a, - dtype_b=torch_dtype_b, - dtype_out=torch_dtype_out, - scale_granularity_m=scale_granularity_m, - scale_granularity_n=scale_granularity_n, - scale_granularity_k=scale_granularity_k, - scale_major_mode=scale_major_mode, - mma_sm=mma_sm, - ) - uri, source_paths = gen_grouped_gemm_fp8_tvm_binding( - uri=uri, - dtype_a=torch_dtype_a, - dtype_b=torch_dtype_b, - dtype_out=torch_dtype_out, - scale_granularity_m=scale_granularity_m, - scale_granularity_n=scale_granularity_n, - scale_granularity_k=scale_granularity_k, - scale_major_mode=scale_major_mode, - mma_sm=mma_sm, - ) - - object_files = _compile_flashinfer_kernels(uri, source_paths, target, num_threads) - modules = _load_flashinfer_modules(object_files) - return modules + compute_version = "".join(tvm.contrib.nvcc.get_target_compute_version(target).split(".")) + if compute_version == "100": + jit_spec = gen_gemm_sm100_module() + else: + raise ValueError(f"Unsupported compute version: {compute_version}") + if return_static_libs: + jit_spec.build(verbose=False) + return _load_flashinfer_modules(jit_spec.get_object_paths()) + return [jit_spec.build_and_load()] diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py index e6e171da9903..e94d5c42957b 100644 --- a/python/tvm/relax/frontend/nn/llm/kv_cache.py +++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py @@ -371,8 +371,7 @@ def __init__( # pylint: disable=too-many-locals enable_disaggregation : bool Whether to enable disaggregation in the KV cache. """ - if rope_mode == RopeMode.INLINE: - assert rotary_dim == qk_head_dim, "FlashInfer RoPE does not support partial rotary dim." + assert rope_mode != RopeMode.INLINE, "FlashInfer RoPE does not support inline mode." 
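For reference, a minimal usage sketch of the reworked generator API introduced above (the dtypes and head dimensions are illustrative, and FlashInfer, PyTorch, and a CUDA toolchain must be available, since `build_and_load` compiles the kernels on first use):

```python
from tvm import relax

# Default path: return_static_libs=False yields one loaded shared library
# that links all compiled object files.
prefill_mod = relax.backend.cuda.flashinfer.gen_flashinfer_prefill_module(
    dtype_q="float16",
    dtype_kv="float16",
    dtype_o="float16",
    qk_head_dim=128,
    v_head_dim=128,
    enable_inline_rope=False,  # the KV-cache FlashInfer path now rejects inline RoPE
    return_static_libs=False,
)[0]

# Entry points carry the prefix applied by _rename_exported_func_names("batch_prefill"),
# matching the names looked up in kv_cache.py and the tests below.
plan_func = prefill_mod["batch_prefill_plan"]
paged_run_func = prefill_mod["batch_prefill_paged_run"]
ragged_run_func = prefill_mod["batch_prefill_ragged_run"]
```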
attn_kind_single = attn_kind[0] if isinstance(attn_kind, List) else attn_kind if attn_kind_single == "mha_sliding": @@ -383,8 +382,8 @@ def __init__( # pylint: disable=too-many-locals dtype_o=dtype, qk_head_dim=(qk_head_dim if attn_kind_single == "mha" else mla_original_qk_head_dim), v_head_dim=(v_head_dim if attn_kind_single == "mha" else mla_original_v_head_dim), - target=target, - enable_inline_rope=rope_mode == RopeMode.INLINE, + enable_inline_rope=False, + return_static_libs=True, ) flashinfer_decode_mods = ( rx.backend.cuda.flashinfer.gen_flashinfer_decode_module( @@ -393,7 +392,8 @@ def __init__( # pylint: disable=too-many-locals dtype_o=dtype, qk_head_dim=qk_head_dim, v_head_dim=v_head_dim, - target=target, + enable_inline_rope=False, + return_static_libs=True, ) if attn_kind_single == "mha" else [] @@ -405,7 +405,7 @@ def __init__( # pylint: disable=too-many-locals dtype_o=dtype, head_dim_ckv=v_head_dim, head_dim_kpe=qk_head_dim - v_head_dim, - target=target, + return_static_libs=True, ) if attn_kind_single == "mla" else [] @@ -417,8 +417,8 @@ def __init__( # pylint: disable=too-many-locals bb = rx.BlockBuilder.current() mha_functions = ( [ - rx.Tuple([rx.StringImm("flashinfer"), rx.ExternFunc("batch_prefill_with_paged_kv_cache_run"), rx.ExternFunc("batch_prefill_with_kv_cache_plan")]), - rx.Tuple([rx.StringImm("flashinfer"), rx.ExternFunc("batch_decode_with_paged_kv_cache_run"), rx.ExternFunc("batch_decode_with_paged_kv_cache_plan")]), + rx.Tuple([rx.StringImm("flashinfer"), rx.ExternFunc("batch_prefill_paged_run"), rx.ExternFunc("batch_prefill_plan")]), + rx.Tuple([rx.StringImm("flashinfer"), rx.ExternFunc("batch_decode_run"), rx.ExternFunc("batch_decode_plan")]), rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, True, rope_scaling, target), "tir_attention_prefill_sliding_window")]), rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, True, rope_scaling, target), "tir_attention_decode_sliding_window")]), rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache")]), @@ -427,7 +427,8 @@ def __init__( # pylint: disable=too-many-locals if attn_kind_single == "mha" else [rx.Tuple([]) for _ in range(6)] ) - mla_function = rx.Tuple([rx.StringImm("flashinfer"), rx.ExternFunc("batch_mla_paged_attention_run"), rx.ExternFunc("batch_mla_paged_attention_plan")] if attn_kind_single == "mla" else []) + ragged_prefill_function = rx.Tuple([rx.StringImm("flashinfer"), rx.ExternFunc("batch_prefill_ragged_run"), rx.ExternFunc("batch_prefill_plan")]) if attn_kind_single == "mha" else rx.Tuple([rx.StringImm("flashinfer"), rx.ExternFunc("batch_prefill_ragged_run"), rx.ExternFunc("batch_prefill_plan"), rx.PrimValue(mla_original_qk_head_dim), rx.PrimValue(mla_original_v_head_dim)]) + mla_function = rx.Tuple([rx.StringImm("flashinfer"), rx.ExternFunc("batch_mla_run"), rx.ExternFunc("batch_mla_plan")] if attn_kind_single == "mla" else []) attn_merge_functions = [ bb.add_func(_merge_state_inplace(num_attention_heads, v_head_dim, dtype, target, "tir_attention_merge_state"), "tir_attention_merge_state"), ] @@ -463,7 +464,7 @@ def __init__( # pylint: disable=too-many-locals rx.op.zeros((), dtype), bb.add_func(_kv_cache_transpose_append(num_key_value_heads, qk_head_dim, dtype), "kv_cache_transpose_append"), 
bb.add_func(_kv_cache_transpose_append_mla(qk_head_dim, dtype), "kv_cache_transpose_append_mla"), - rx.Tuple([rx.StringImm("flashinfer"), rx.ExternFunc("batch_prefill_with_ragged_kv_cache_run"), rx.ExternFunc("batch_prefill_with_kv_cache_plan")]), + ragged_prefill_function, *mha_functions, mla_function, rx.Tuple(attn_merge_functions), diff --git a/src/runtime/vm/attn_backend.cc b/src/runtime/vm/attn_backend.cc index 3b37d9810b1c..13e151ecd202 100644 --- a/src/runtime/vm/attn_backend.cc +++ b/src/runtime/vm/attn_backend.cc @@ -59,11 +59,18 @@ std::unique_ptr ConvertRaggedPrefillFunc(ffi::Array return std::make_unique(std::move(attn_func), attn_kind); } if (backend_name == "flashinfer") { - CHECK_EQ(args.size(), 3); + CHECK(args.size() == 3 || args.size() == 5); ffi::Function attn_func = args[1].cast(); ffi::Function plan_func = args[2].cast(); + int64_t qk_head_dim_override = -1; + int64_t v_head_dim_override = -1; + if (args.size() == 5) { + qk_head_dim_override = args[3].cast(); + v_head_dim_override = args[4].cast(); + } return std::make_unique(std::move(attn_func), std::move(plan_func), - attn_kind); + attn_kind, qk_head_dim_override, + v_head_dim_override); } LOG(FATAL) << "Cannot reach here"; throw; diff --git a/src/runtime/vm/attn_backend.h b/src/runtime/vm/attn_backend.h index ea5f49c6c08a..1fd22a97abdc 100644 --- a/src/runtime/vm/attn_backend.h +++ b/src/runtime/vm/attn_backend.h @@ -27,6 +27,7 @@ #include #include +#include #include #include @@ -57,6 +58,22 @@ class AttnBackendFunc { virtual ~AttnBackendFunc() = default; protected: + // helper allocator class for creating strided view of a Tensor + // that applies byte offset to the original data pointer + class ViewBasedAlloc { + public: + explicit ViewBasedAlloc(Tensor source) : source_(source) {} + void AllocData(DLTensor* tensor, int64_t* strides, int64_t extra_byte_offset) { + tensor->data = static_cast(source_->data) + extra_byte_offset; + tensor->strides = strides; + } + + void FreeData(DLTensor* tensor) {} + + private: + Tensor source_; + }; + ffi::Function attn_func_; public: @@ -133,16 +150,34 @@ class FlashInferPagedPrefillFunc : public PagedPrefillFunc { Tensor k_rope_pos_offset, bool causal, RoPEMode rope_mode, double rotary_scale, double rotary_theta, double sm_scale, Tensor attn_output, Tensor attn_lse, TVMStreamHandle compute_stream) final { + Device device = q->device; + TVMStreamHandle original_stream = DeviceAPI::Get(device)->GetCurrentStream(device); + DeviceAPI::Get(device)->SetStream(device, compute_stream); auto [float_workspace_buffer, int_workspace_buffer, page_locked_int_workspace_buffer, plan_info_vec] = cached_buffers_[depth]; double rope_rcp_scale = 1 / rotary_scale; double rope_rcp_theta = 1 / rotary_theta; - attn_func_(float_workspace_buffer, int_workspace_buffer, plan_info_vec, q, pages, qo_indptr, - page_indptr, page_indices, length_info, q_rope_position, k_rope_pos_offset, - attn_output, attn_lse, /*mask_mode_code=*/static_cast(causal), - /*pos_encoding_mode_code=*/static_cast(rope_mode == RoPEMode::kInline), - /*layout(HND)=*/1, /*window_left=*/-1, sm_scale, /*rope_rcp_scale=*/rope_rcp_scale, - /*rope_rcp_theta=*/rope_rcp_theta, compute_stream); + + ICHECK_EQ(pages.ndim(), 5); + int H = pages->shape[2]; + int N = pages->shape[3]; + int D = pages->shape[4]; + CHECK(pages.IsContiguous()); + std::vector pages_k_v_shape = {pages->shape[0], H, N, D}; + std::vector pages_k_v_strides = {2 * H * N * D, N * D, D, 1}; + Tensor pages_k = + Tensor::FromNDAlloc(ViewBasedAlloc(pages), 
ffi::Shape(pages_k_v_shape), pages->dtype, + pages->device, pages_k_v_strides.data(), pages->byte_offset); + Tensor pages_v = Tensor::FromNDAlloc( + ViewBasedAlloc(pages), ffi::Shape(pages_k_v_shape), pages->dtype, pages->device, + pages_k_v_strides.data(), pages->byte_offset + (H * N * D) * pages.DataType().bytes()); + + attn_func_(float_workspace_buffer, int_workspace_buffer, plan_info_vec, q, pages_k, pages_v, + qo_indptr, page_indptr, page_indices, length_info, attn_output, attn_lse, + /*mask_mode_code=*/static_cast(causal), /*layout(HND)=*/1, + /*window_left=*/-1, /*enable_pdl=*/false, sm_scale, + /*rope_rcp_scale=*/rope_rcp_scale, /*rope_rcp_theta=*/rope_rcp_theta); + DeviceAPI::Get(device)->SetStream(device, original_stream); } void MLA(int depth, Tensor q, Tensor qo_indptr, Tensor pages, Tensor page_indptr, @@ -150,9 +185,43 @@ class FlashInferPagedPrefillFunc : public PagedPrefillFunc { Tensor attn_output, Tensor attn_lse, TVMStreamHandle compute_stream) final { auto [float_workspace_buffer, int_workspace_buffer, page_locked_int_workspace_buffer, plan_info_vec] = cached_buffers_[depth]; - attn_func_(float_workspace_buffer, int_workspace_buffer, plan_info_vec, q, pages, page_indices, - attn_output, attn_lse, /*mask_mode_code=*/static_cast(causal), - /*num_heads=*/q->shape[1], /*page_size=*/pages->shape[1], sm_scale, compute_stream); + Device device = q->device; + TVMStreamHandle original_stream = DeviceAPI::Get(device)->GetCurrentStream(device); + DeviceAPI::Get(device)->SetStream(device, compute_stream); + ICHECK_NE(qk_head_dim_, -1); + ICHECK_NE(v_head_dim_, -1); + int64_t H = q->shape[1]; + int64_t page_size = pages->shape[1]; + int64_t rope_head_dim = qk_head_dim_ - v_head_dim_; + int64_t nope_head_dim = q->shape[2] - rope_head_dim; + + // Split q into q_nope and q_pe + CHECK(q.IsContiguous()); + std::vector q_nope_shape = {q->shape[0], H, nope_head_dim}; + std::vector q_pe_shape = {q->shape[0], H, rope_head_dim}; + std::vector q_strides = {H * q->shape[2], q->shape[2], 1}; + Tensor q_nope = Tensor::FromNDAlloc(ViewBasedAlloc(q), ffi::Shape(q_nope_shape), q->dtype, + q->device, q_strides.data(), q->byte_offset); + Tensor q_pe = Tensor::FromNDAlloc(ViewBasedAlloc(q), ffi::Shape(q_pe_shape), q->dtype, + q->device, q_strides.data(), + q->byte_offset + nope_head_dim * q.DataType().bytes()); + // Split pages into kv_nope and kv_pe + CHECK(pages.IsContiguous()); + std::vector kv_nope_shape = {pages->shape[0], page_size, nope_head_dim}; + std::vector kv_pe_shape = {pages->shape[0], page_size, rope_head_dim}; + std::vector kv_strides = {page_size * pages->shape[2], pages->shape[2], 1}; + Tensor kv_nope = + Tensor::FromNDAlloc(ViewBasedAlloc(pages), ffi::Shape(kv_nope_shape), pages->dtype, + pages->device, kv_strides.data(), pages->byte_offset); + Tensor kv_pe = Tensor::FromNDAlloc( + ViewBasedAlloc(pages), ffi::Shape(kv_pe_shape), pages->dtype, pages->device, + kv_strides.data(), pages->byte_offset + nope_head_dim * pages.DataType().bytes()); + + attn_func_(float_workspace_buffer, int_workspace_buffer, plan_info_vec, q_nope, q_pe, kv_nope, + kv_pe, page_indices, attn_output, attn_lse, + /*mask_mode_code=*/static_cast(causal), + /*num_heads=*/q->shape[1], /*page_size=*/pages->shape[1], sm_scale); + DeviceAPI::Get(device)->SetStream(device, original_stream); } void BeginForward(int depth, Tensor float_workspace_buffer, Tensor int_workspace_buffer, @@ -161,31 +230,37 @@ class FlashInferPagedPrefillFunc : public PagedPrefillFunc { int64_t batch_size, int64_t total_qo_len, int64_t 
page_size, int64_t num_qo_heads, int64_t num_kv_heads, int64_t qk_head_dim, int64_t v_head_dim, bool causal, TVMStreamHandle copy_stream) final { - std::vector kv_len; - kv_len.reserve(batch_size); + Tensor kv_len_arr = Tensor::Empty({batch_size}, DataType::Int(32), Device{kDLCPU, 0}); + int32_t* kv_len_arr_data = static_cast(kv_len_arr.data_ptr()); for (int i = 0; i < static_cast(batch_size); ++i) { - kv_len.push_back((*page_indptr)[i + 1] != (*page_indptr)[i] - ? ((*page_indptr)[i + 1] - (*page_indptr)[i] - 1) * page_size + - (*last_page_len)[i] - : 0); + kv_len_arr_data[i] = + (*page_indptr)[i + 1] != (*page_indptr)[i] + ? ((*page_indptr)[i + 1] - (*page_indptr)[i] - 1) * page_size + (*last_page_len)[i] + : 0; } - IntTuple plan_info_vec; + qk_head_dim_ = qk_head_dim; + v_head_dim_ = v_head_dim; + ffi::Array plan_info_vec; + Device device = float_workspace_buffer->device; + TVMStreamHandle original_stream = DeviceAPI::Get(device)->GetCurrentStream(device); + DeviceAPI::Get(device)->SetStream(device, copy_stream); if (attn_kind == AttnKind::kMHA) { // Todo(tvm-team): enable cuda graph plan_info_vec = plan_func_(float_workspace_buffer, int_workspace_buffer, page_locked_int_workspace_buffer, - qo_indptr->as_tensor(), page_indptr->as_tensor(), IntTuple(std::move(kv_len)), - total_qo_len, batch_size, num_qo_heads, num_kv_heads, page_size, + qo_indptr->as_tensor(), page_indptr->as_tensor(), kv_len_arr, total_qo_len, + batch_size, num_qo_heads, num_kv_heads, page_size, /*enable_cuda_graph=*/false, qk_head_dim, v_head_dim, causal, - /*window_left=*/-1, copy_stream) - .cast(); + /*window_left=*/-1, /*fixed_split_size=*/-1, /*disable_split_kv=*/false) + .cast>(); } else if (attn_kind == AttnKind::kMLA) { plan_info_vec = plan_func_(float_workspace_buffer, int_workspace_buffer, page_locked_int_workspace_buffer, - qo_indptr->as_tensor(), page_indptr->as_tensor(), IntTuple(std::move(kv_len)), - num_qo_heads, v_head_dim, causal, copy_stream) - .cast(); + qo_indptr->as_tensor(), page_indptr->as_tensor(), kv_len_arr, num_qo_heads, + v_head_dim, causal) + .cast>(); } + DeviceAPI::Get(device)->SetStream(device, original_stream); if (cached_buffers_.size() <= static_cast(depth)) { cached_buffers_.resize(depth + 1); @@ -196,8 +271,10 @@ class FlashInferPagedPrefillFunc : public PagedPrefillFunc { } private: + int64_t qk_head_dim_ = -1; + int64_t v_head_dim_ = -1; ffi::Function plan_func_; - std::vector> cached_buffers_; + std::vector>> cached_buffers_; }; /*! \brief The ragged prefill attention function base class. 
*/ @@ -244,23 +321,30 @@ class TIRRaggedPrefillFunc : public RaggedPrefillFunc { class FlashInferRaggedPrefillFunc : public RaggedPrefillFunc { public: explicit FlashInferRaggedPrefillFunc(ffi::Function attn_func, ffi::Function plan_func, - AttnKind attn_kind) + AttnKind attn_kind, int64_t qk_head_dim_override, + int64_t v_head_dim_override) : RaggedPrefillFunc(std::move(attn_func), attn_kind, AttnBackendKind::kFlashInfer), + qk_head_dim_override_(qk_head_dim_override), + v_head_dim_override_(v_head_dim_override), plan_func_(std::move(plan_func)) {} void MHA(Tensor q, Tensor k, Tensor v, Tensor qo_indptr, Tensor kv_indptr, Tensor q_rope_position, Tensor k_rope_pos_offset, bool causal, RoPEMode rope_mode, double rotary_scale, double rotary_theta, double sm_scale, Tensor attn_output, Tensor attn_lse, TVMStreamHandle compute_stream) final { + Device device = q->device; + TVMStreamHandle original_stream = DeviceAPI::Get(device)->GetCurrentStream(device); + DeviceAPI::Get(device)->SetStream(device, compute_stream); double rope_rcp_scale = 1 / rotary_scale; double rope_rcp_theta = 1 / rotary_theta; attn_func_(float_workspace_buffer_, int_workspace_buffer_, plan_info_vec_, q, k, v, qo_indptr, - kv_indptr, q_rope_position, k_rope_pos_offset, attn_output, attn_lse, + kv_indptr, attn_output, attn_lse, /*mask_mode_code=*/static_cast(causal), - /*pos_encoding_mode_code=*/static_cast(rope_mode == RoPEMode::kInline), - /*layout(NHD)=*/0, /*window_left=*/-1, sm_scale, + /*layout(NHD)=*/0, /*window_left=*/-1, + /*enable_pdl=*/false, sm_scale, /*rope_rcp_scale=*/rope_rcp_scale, - /*rope_rcp_theta=*/rope_rcp_theta, compute_stream); + /*rope_rcp_theta=*/rope_rcp_theta); + DeviceAPI::Get(device)->SetStream(device, original_stream); } void BeginForward(Tensor float_workspace_buffer, Tensor int_workspace_buffer, @@ -268,30 +352,42 @@ class FlashInferRaggedPrefillFunc : public RaggedPrefillFunc { HostMemoryVector* kv_indptr, int64_t batch_size, int64_t total_qo_len, int64_t num_qo_heads, int64_t num_kv_heads, int64_t qk_head_dim, int64_t v_head_dim, bool causal, TVMStreamHandle copy_stream) final { - std::vector kv_len; - kv_len.reserve(batch_size); + Tensor kv_len_arr = Tensor::Empty({batch_size}, DataType::Int(32), Device{kDLCPU, 0}); + int32_t* kv_len_arr_data = static_cast(kv_len_arr.data_ptr()); for (int i = 0; i < static_cast(batch_size); ++i) { - kv_len.push_back((*kv_indptr)[i + 1] - (*kv_indptr)[i]); + kv_len_arr_data[i] = (*kv_indptr)[i + 1] - (*kv_indptr)[i]; + } + if (qk_head_dim_override_ != -1) { + qk_head_dim = qk_head_dim_override_; + } + if (v_head_dim_override_ != -1) { + v_head_dim = v_head_dim_override_; } // Todo(tvm-team): enable cuda graph float_workspace_buffer_ = float_workspace_buffer; int_workspace_buffer_ = int_workspace_buffer; page_locked_int_workspace_buffer_ = page_locked_int_workspace_buffer; + Device device = float_workspace_buffer->device; + TVMStreamHandle original_stream = DeviceAPI::Get(device)->GetCurrentStream(device); + DeviceAPI::Get(device)->SetStream(device, copy_stream); plan_info_vec_ = plan_func_(float_workspace_buffer, int_workspace_buffer, page_locked_int_workspace_buffer, - qo_indptr->as_tensor(), kv_indptr->as_tensor(), IntTuple(std::move(kv_len)), - total_qo_len, batch_size, num_qo_heads, num_kv_heads, /*page_size=*/1, + qo_indptr->as_tensor(), kv_indptr->as_tensor(), kv_len_arr, total_qo_len, + batch_size, num_qo_heads, num_kv_heads, /*page_size=*/1, /*enable_cuda_graph=*/false, qk_head_dim, v_head_dim, causal, - /*window_left=*/-1, copy_stream) - .cast(); + 
/*window_left=*/-1, /*fixed_split_size=*/-1, /*disable_split_kv=*/false) + .cast>(); + DeviceAPI::Get(device)->SetStream(device, original_stream); } private: + int64_t qk_head_dim_override_; + int64_t v_head_dim_override_; ffi::Function plan_func_; Tensor float_workspace_buffer_; Tensor int_workspace_buffer_; Tensor page_locked_int_workspace_buffer_; - IntTuple plan_info_vec_; + ffi::Array plan_info_vec_; }; /*! \brief The paged decode attention function base class. */ @@ -359,15 +455,33 @@ class FlashInferPagedDecodeFunc : public PagedDecodeFunc { Tensor length_info, Tensor k_rope_pos_offset, Tensor q_rope_position, RoPEMode rope_mode, double rotary_scale, double rotary_theta, double sm_scale, Tensor attn_output, Tensor attn_lse, TVMStreamHandle compute_stream) final { + Device device = q->device; + TVMStreamHandle original_stream = DeviceAPI::Get(device)->GetCurrentStream(device); + DeviceAPI::Get(device)->SetStream(device, compute_stream); auto [float_workspace_buffer, int_workspace_buffer, page_locked_int_workspace_buffer, plan_info_vec] = cached_buffers_[depth]; double rope_rcp_scale = 1 / rotary_scale; double rope_rcp_theta = 1 / rotary_theta; - attn_func_(float_workspace_buffer, int_workspace_buffer, plan_info_vec, q, pages, page_indptr, - page_indices, length_info, q_rope_position, k_rope_pos_offset, attn_output, attn_lse, - /*pos_encoding_mode_code=*/static_cast(rope_mode == RoPEMode::kInline), - /*layout(HND)=*/1, /*window_left=*/-1, sm_scale, /*rope_rcp_scale=*/rope_rcp_scale, - /*rope_rcp_theta=*/rope_rcp_theta, compute_stream); + + ICHECK_EQ(pages.ndim(), 5); + int H = pages->shape[2]; + int N = pages->shape[3]; + int D = pages->shape[4]; + CHECK(pages.IsContiguous()); + std::vector pages_k_v_shape = {pages->shape[0], H, N, D}; + std::vector pages_k_v_strides = {2 * H * N * D, N * D, D, 1}; + Tensor pages_k = + Tensor::FromNDAlloc(ViewBasedAlloc(pages), ffi::Shape(pages_k_v_shape), pages->dtype, + pages->device, pages_k_v_strides.data(), pages->byte_offset); + Tensor pages_v = Tensor::FromNDAlloc( + ViewBasedAlloc(pages), ffi::Shape(pages_k_v_shape), pages->dtype, pages->device, + pages_k_v_strides.data(), pages->byte_offset + (H * N * D) * pages.DataType().bytes()); + + attn_func_(float_workspace_buffer, int_workspace_buffer, plan_info_vec, q, pages_k, pages_v, + page_indptr, page_indices, length_info, attn_output, attn_lse, + /*layout(HND)=*/1, /*window_left=*/-1, /*enable_pdl=*/false, sm_scale, + /*rope_rcp_scale=*/rope_rcp_scale, /*rope_rcp_theta=*/rope_rcp_theta); + DeviceAPI::Get(device)->SetStream(device, original_stream); } void BeginForward(int depth, Tensor float_workspace_buffer, Tensor int_workspace_buffer, @@ -377,13 +491,18 @@ class FlashInferPagedDecodeFunc : public PagedDecodeFunc { RoPEMode rope_mode, DataType q_dtype, DataType kv_dtype, TVMStreamHandle copy_stream) final { // Todo(tvm-team): enable cuda graph - IntTuple plan_info_vec = + Tensor empty_qkv_data = Tensor::Empty({1}, q_dtype, Device{kDLCPU, 0}); + Device device = float_workspace_buffer->device; + TVMStreamHandle original_stream = DeviceAPI::Get(device)->GetCurrentStream(device); + DeviceAPI::Get(device)->SetStream(device, copy_stream); + ffi::Array plan_info_vec = plan_func_(float_workspace_buffer, int_workspace_buffer, page_locked_int_workspace_buffer, page_indptr->as_tensor(), batch_size, num_qo_heads, num_kv_heads, page_size, /*enable_cuda_graph=*/false, - static_cast(rope_mode == RoPEMode::kInline), - /*window_left=*/-1, qk_head_dim, v_head_dim, q_dtype, kv_dtype, copy_stream) - .cast(); + 
/*window_left=*/-1, /*logits_soft_cap=*/0.0, qk_head_dim, v_head_dim, + empty_qkv_data, empty_qkv_data) + .cast>(); + DeviceAPI::Get(device)->SetStream(device, original_stream); if (cached_buffers_.size() <= static_cast(depth)) { cached_buffers_.resize(depth + 1); @@ -395,7 +514,7 @@ class FlashInferPagedDecodeFunc : public PagedDecodeFunc { private: ffi::Function plan_func_; - std::vector> cached_buffers_; + std::vector>> cached_buffers_; }; /*! \brief The paged prefill with tree mask attention function base class. */ diff --git a/src/runtime/vm/attn_utils.h b/src/runtime/vm/attn_utils.h index 09557a8f0a27..1c695a10e25d 100644 --- a/src/runtime/vm/attn_utils.h +++ b/src/runtime/vm/attn_utils.h @@ -860,8 +860,9 @@ class CachedPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { sliding_window_offset->data(), n_elem * elem_byte_size_); std::memcpy(merged_attn_aux_data_host_.data() + attn_aux_data_copy_offset_ + 2 * n_elem, sink_size->data(), n_elem * elem_byte_size_); - Tensor view = merged_attn_aux_data_device_.CreateView( - {3, n_elem}, dtype_aux_, attn_aux_data_copy_offset_ * elem_byte_size_); + Tensor view = + Tensor::FromNDAlloc(ViewHelper(merged_attn_aux_data_device_), ffi::Shape({3, n_elem}), + dtype_aux_, device_, attn_aux_data_copy_offset_ * elem_byte_size_); attn_aux_data_copy_offset_ += CeilDivElemAlignment(3 * n_elem); return view; } @@ -895,8 +896,9 @@ class CachedPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { src_data->data(), n_elem * elem_byte_size_); std::memcpy(merged_compact_kv_aux_data_host_.data() + compact_kv_aux_data_copy_offset_ + n_elem, dst_data->data(), n_elem * elem_byte_size_); - Tensor view = merged_compact_kv_aux_data_device_.CreateView( - {2, n_elem}, dtype_aux_, compact_kv_aux_data_copy_offset_ * elem_byte_size_); + Tensor view = Tensor::FromNDAlloc(ViewHelper(merged_compact_kv_aux_data_device_), + ffi::Shape({2, n_elem}), dtype_aux_, device_, + compact_kv_aux_data_copy_offset_ * elem_byte_size_); compact_kv_aux_data_copy_offset_ += CeilDivElemAlignment(2 * n_elem); return view; } @@ -919,6 +921,20 @@ class CachedPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { } private: + // helper allocator class that applies byte offset to the original data pointer + class ViewHelper { + public: + explicit ViewHelper(Tensor source) : source_(source) {} + void AllocData(DLTensor* tensor, int64_t extra_byte_offset) { + tensor->data = static_cast(source_->data) + extra_byte_offset; + } + + void FreeData(DLTensor* tensor) {} + + private: + Tensor source_; + }; + /*! * \brief Calculate the start element offsets of the auxiliary arrays in the local cache. * \return Return the local cache size (total number of elements in the local cache). 
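The zero-copy K/V split that `FlashInferPagedPrefillFunc::MHA` and the decode path build with `ViewBasedAlloc` can be pictured with plain strides; this sketch (NumPy standing in for the runtime `Tensor`, sizes made up) reproduces the same shape/stride/offset arithmetic used in the C++ hunks above:

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

# Combined KV page buffer in HND layout: (num_pages, 2, H, N, D), contiguous.
num_pages, H, N, D = 4, 2, 16, 64
pages = np.arange(num_pages * 2 * H * N * D, dtype=np.float16).reshape(
    num_pages, 2, H, N, D
)

shape = (num_pages, H, N, D)
strides_elems = (2 * H * N * D, N * D, D, 1)  # same strides as pages_k_v_strides
strides_bytes = tuple(s * pages.itemsize for s in strides_elems)

# K view starts at the buffer base; V view starts H*N*D elements later per page.
pages_k = as_strided(pages, shape=shape, strides=strides_bytes)
pages_v = as_strided(pages.reshape(-1)[H * N * D:], shape=shape, strides=strides_bytes)

assert np.array_equal(pages_k, pages[:, 0])
assert np.array_equal(pages_v, pages[:, 1])
```

The MLA path in the same file applies the identical trick, adjusting only the byte offset, to split `q` into `q_nope`/`q_pe` and the compressed KV pages into `kv_nope`/`kv_pe`.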
@@ -990,8 +1006,9 @@ class CachedPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { int64_t n_elem = data->size(); std::memcpy(merged_attn_aux_data_host_.data() + attn_aux_data_copy_offset_, data->data(), n_elem * elem_byte_size_); - Tensor view = merged_attn_aux_data_device_.CreateView( - {n_elem}, dtype_aux_, attn_aux_data_copy_offset_ * elem_byte_size_); + Tensor view = + Tensor::FromNDAlloc(ViewHelper(merged_attn_aux_data_device_), ffi::Shape({n_elem}), + dtype_aux_, device_, attn_aux_data_copy_offset_ * elem_byte_size_); attn_aux_data_copy_offset_ += CeilDivElemAlignment(n_elem); return view; } @@ -1000,8 +1017,9 @@ class CachedPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { int64_t n_elem = data->size(); std::memcpy(merged_compact_kv_aux_data_host_.data() + compact_kv_aux_data_copy_offset_, data->data(), n_elem * elem_byte_size_); - Tensor view = merged_compact_kv_aux_data_device_.CreateView( - {n_elem}, dtype_aux_, compact_kv_aux_data_copy_offset_ * elem_byte_size_); + Tensor view = Tensor::FromNDAlloc(ViewHelper(merged_compact_kv_aux_data_device_), + ffi::Shape({n_elem}), dtype_aux_, device_, + compact_kv_aux_data_copy_offset_ * elem_byte_size_); compact_kv_aux_data_copy_offset_ += CeilDivElemAlignment(n_elem); return view; } diff --git a/src/runtime/vm/paged_kv_cache.cc b/src/runtime/vm/paged_kv_cache.cc index 0f3f56866134..4fb3cd69d60f 100644 --- a/src/runtime/vm/paged_kv_cache.cc +++ b/src/runtime/vm/paged_kv_cache.cc @@ -2052,7 +2052,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { temp_float_attn_workspace_, temp_int_attn_workspace_[0], temp_int_pinned_attn_workspace_[0], &cur_append_lengths_indptr_host_, &cur_append_lengths_indptr_host_, cur_batch_size_, - cur_append_lengths_indptr_host_.back(), num_qo_heads_, num_kv_heads_, qk_head_dim_, + cur_append_lengths_indptr_host_.back(), num_qo_heads_, num_qo_heads_, qk_head_dim_, v_head_dim_, /*causal=*/true, copy_stream_); } } diff --git a/tests/python/relax/test_group_gemm_flashinfer.py b/tests/python/relax/test_group_gemm_flashinfer.py index 8333e4b2d66b..da6fdacebdbd 100644 --- a/tests/python/relax/test_group_gemm_flashinfer.py +++ b/tests/python/relax/test_group_gemm_flashinfer.py @@ -18,14 +18,14 @@ """Test for FlashInfer GroupedGemm TVM integration""" import math + import numpy as np import pytest import torch + import tvm import tvm.testing from tvm import relax -from tvm.contrib import utils -from tvm.relax.backend.cuda import flashinfer DEFAULT_WORKSPACE_SIZE = 32 * 1024 * 1024 fp8_dtype = "float8_e4m3fn" @@ -389,36 +389,11 @@ def test_grouped_gemm_correctness( device = tvm.cuda(0) target = tvm.target.Target.from_device(device) - def _load_module(name: str, static_modules): - """Helper function to load compiled modules.""" - assert len(static_modules) > 0 - if len(static_modules) == 1: - return static_modules[0] - static_mod = static_modules[0] - for mod in static_modules[1:]: - static_mod.import_module(mod) - temp = tvm.contrib.utils.tempdir() - mod_path = temp.relpath(f"{name}.so") - static_mod.export_library(mod_path) - return tvm.runtime.load_module(mod_path) - # Generate the module - modules = relax.backend.cuda.flashinfer.gen_grouped_gemm_module( - dtype_a=dtype_a, - dtype_b=dtype_b, - dtype_out=dtype_out, - scale_granularity_m=scale_granularity_m, - scale_granularity_n=scale_granularity_n, - scale_granularity_k=scale_granularity_k, - scale_major_mode=scale_major_mode, - mma_sm=mma_sm, - target=target, - num_threads=4, - ) + mod = 
relax.backend.cuda.flashinfer.gen_grouped_gemm_module(target=target)[0] # Load the module - mod = _load_module("flashinfer_grouped_gemm", modules) - grouped_gemm_fn = mod["grouped_gemm_fp8_run"] + grouped_gemm_fn = mod["group_gemm_fp8_nt_groupwise"] # Generate test data test_data = generate_test_data( @@ -460,7 +435,11 @@ def _load_module(name: str, static_modules): test_data["m_indptr"], # m_indptr test_data["n"], # n (scalar) test_data["k"], # k (scalar) - None, # cuda_stream (use default stream) + scale_granularity_m, + scale_granularity_n, + scale_granularity_k, + scale_major_mode, + mma_sm, ) # Compute reference result diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py index dd29140e9bb2..4aae9dec5995 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py @@ -23,7 +23,6 @@ import tvm.testing from tvm import dlight as dl from tvm import relax -from tvm.contrib import utils from tvm.relax.frontend.nn.llm.kv_cache import ( AttnKind, RopeMode, @@ -78,7 +77,7 @@ fcompact_copy = None -def set_global_func(): +def set_global_func(rope_mode: RopeMode): global fclear, fadd_sequence, fremove_sequence, ffork_sequence, fpopn global fbegin_forward, fend_forward, fattention, fattention_with_fuse_qkv, fdebug_get_kv global fattention_prefill, fattention_decode, fattention_prefill_ragged @@ -98,48 +97,30 @@ def set_global_func(): ) fdebug_get_kv = tvm.get_global_func("vm.builtin.attention_kv_cache_debug_get_kv") - def load_module(name: str, static_modules: List[tvm.runtime.Module]): - assert len(static_modules) > 0 - if len(static_modules) == 1: - return static_modules[0] - static_mod = static_modules[0] - for mod in static_modules[1:]: - static_mod.import_module(mod) - temp = utils.tempdir() - mod_path = temp.relpath(f"{name}.so") - static_mod.export_library(mod_path) - return tvm.runtime.load_module(mod_path) - target = tvm.target.Target.from_device(device) - flashinfer_prefill_mod = load_module( - "flashinfer_prefill", - relax.backend.cuda.flashinfer.gen_flashinfer_prefill_module( - dtype_q=dtype, - dtype_kv=dtype, - dtype_o=dtype, - qk_head_dim=head_dim, - v_head_dim=head_dim, - target=target, - ), - ) - flashinfer_decode_mod = load_module( - "flashinfer_decode", - relax.backend.cuda.flashinfer.gen_flashinfer_decode_module( - dtype_q=dtype, - dtype_kv=dtype, - dtype_o=dtype, - qk_head_dim=head_dim, - v_head_dim=head_dim, - target=target, - ), - ) - - fattention_prefill = flashinfer_prefill_mod["batch_prefill_with_paged_kv_cache_run"] - fattention_prefill_plan = flashinfer_prefill_mod["batch_prefill_with_kv_cache_plan"] - fattention_prefill_ragged = flashinfer_prefill_mod["batch_prefill_with_ragged_kv_cache_run"] - fattention_prefill_ragged_plan = flashinfer_prefill_mod["batch_prefill_with_kv_cache_plan"] - fattention_decode = flashinfer_decode_mod["batch_decode_with_paged_kv_cache_run"] - fattention_decode_plan = flashinfer_decode_mod["batch_decode_with_paged_kv_cache_plan"] + flashinfer_prefill_mod = relax.backend.cuda.flashinfer.gen_flashinfer_prefill_module( + dtype_q=dtype, + dtype_kv=dtype, + dtype_o=dtype, + qk_head_dim=head_dim, + v_head_dim=head_dim, + enable_inline_rope=rope_mode == RopeMode.INLINE, + )[0] + flashinfer_decode_mod = relax.backend.cuda.flashinfer.gen_flashinfer_decode_module( + dtype_q=dtype, + dtype_kv=dtype, + dtype_o=dtype, + 
qk_head_dim=head_dim, + v_head_dim=head_dim, + enable_inline_rope=rope_mode == RopeMode.INLINE, + )[0] + + fattention_prefill = flashinfer_prefill_mod["batch_prefill_paged_run"] + fattention_prefill_plan = flashinfer_prefill_mod["batch_prefill_plan"] + fattention_prefill_ragged = flashinfer_prefill_mod["batch_prefill_ragged_run"] + fattention_prefill_ragged_plan = flashinfer_prefill_mod["batch_prefill_plan"] + fattention_decode = flashinfer_decode_mod["batch_decode_run"] + fattention_decode_plan = flashinfer_decode_mod["batch_decode_plan"] builts = [] for tir_func in [ @@ -560,8 +541,8 @@ def test_paged_attention_kv_cache_popn(kv_cache_and_rope_mode): if __name__ == "__main__": - set_global_func() - for rope_mode in [RopeMode.NONE, RopeMode.NORMAL, RopeMode.INLINE]: + for rope_mode in [RopeMode.NONE, RopeMode.NORMAL]: + set_global_func(rope_mode) cache = create_kv_cache(rope_mode) test_paged_attention_kv_cache_prefill_and_decode((cache, rope_mode)) test_paged_attention_kv_cache_remove_sequence((cache, rope_mode)) diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py index e3de4944fef9..cd76f9ce20a7 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py @@ -25,7 +25,6 @@ import tvm.testing from tvm import dlight as dl from tvm import relax -from tvm.contrib import utils from tvm.relax.frontend.nn.llm.kv_cache import ( AttnKind, RopeMode, @@ -115,47 +114,27 @@ def set_global_func(dtype): fis_empty = tvm.get_global_func("vm.builtin.attention_kv_cache_empty") fdebug_get_kv = tvm.get_global_func("vm.builtin.attention_kv_cache_debug_get_kv_mla") - def load_module(name: str, static_modules: List[tvm.runtime.Module]): - assert len(static_modules) > 0 - if len(static_modules) == 1: - return static_modules[0] - static_mod = static_modules[0] - for mod in static_modules[1:]: - static_mod.import_module(mod) - temp = utils.tempdir() - mod_path = temp.relpath(f"{name}.so") - static_mod.export_library(mod_path) - return tvm.runtime.load_module(mod_path) - target = tvm.target.Target.from_device(device) - flashinfer_prefill_mod = load_module( - "flashinfer_prefill", - relax.backend.cuda.flashinfer.gen_flashinfer_prefill_module( - dtype_q=dtype, - dtype_kv=dtype, - dtype_o=dtype, - qk_head_dim=qk_nope_head_dim + qk_rope_head_dim, - v_head_dim=v_head_dim, - target=target, - enable_inline_rope=False, - ), - ) - flashinfer_mla_mod = load_module( - "flashinfer_mla", - relax.backend.cuda.flashinfer.gen_flashinfer_mla_module( - dtype_q=dtype, - dtype_kv=dtype, - dtype_o=dtype, - head_dim_ckv=kv_lora_rank, - head_dim_kpe=qk_rope_head_dim, - target=target, - ), - ) - - fattn_prefill_ragged = flashinfer_prefill_mod["batch_prefill_with_ragged_kv_cache_run"] - fattn_prefill_ragged_plan = flashinfer_prefill_mod["batch_prefill_with_kv_cache_plan"] - fmla_prefill = flashinfer_mla_mod["batch_mla_paged_attention_run"] - fmla_prefill_plan = flashinfer_mla_mod["batch_mla_paged_attention_plan"] + flashinfer_prefill_mod = relax.backend.cuda.flashinfer.gen_flashinfer_prefill_module( + dtype_q=dtype, + dtype_kv=dtype, + dtype_o=dtype, + qk_head_dim=qk_nope_head_dim + qk_rope_head_dim, + v_head_dim=v_head_dim, + enable_inline_rope=False, + )[0] + flashinfer_mla_mod = relax.backend.cuda.flashinfer.gen_flashinfer_mla_module( + dtype_q=dtype, + dtype_kv=dtype, + dtype_o=dtype, + 
head_dim_ckv=kv_lora_rank, + head_dim_kpe=qk_rope_head_dim, + )[0] + + fattn_prefill_ragged = flashinfer_prefill_mod["batch_prefill_ragged_run"] + fattn_prefill_ragged_plan = flashinfer_prefill_mod["batch_prefill_plan"] + fmla_prefill = flashinfer_mla_mod["batch_mla_run"] + fmla_prefill_plan = flashinfer_mla_mod["batch_mla_plan"] builts = [] for tir_func in [ @@ -221,7 +200,13 @@ def create_kv_cache(dtype): tvm.runtime.empty((), dtype, device=device), None, # f_transpose_append_mha ftranspose_append, - ["flashinfer", fattn_prefill_ragged, fattn_prefill_ragged_plan], # fattn_prefill_ragged + [ + "flashinfer", + fattn_prefill_ragged, + fattn_prefill_ragged_plan, + qk_nope_head_dim + qk_rope_head_dim, + v_head_dim, + ], # fattn_prefill_ragged [], # fattn_prefill [], # fattn_decode [], # fattn_prefill_sliding_window
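The renamed entry points these tests look up (`batch_prefill_plan`, `batch_mla_run`, and so on) come from the `_rename_exported_func_names` pass added at the top of this patch. A self-contained sketch of that rewrite, using a made-up binding file and an illustrative original symbol name:

```python
import re
from pathlib import Path
from tempfile import TemporaryDirectory

# Same regex as _rename_exported_func_names in flashinfer.py.
pattern = re.compile(r"^(\s*TVM_FFI_DLL_EXPORT_TYPED_FUNC\()([A-Za-z0-9_]+)(,.*)$")

with TemporaryDirectory() as tmp:
    # Hypothetical binding file; real FlashInfer sources follow the *_binding.cu convention.
    binding = Path(tmp) / "batch_prefill_binding.cu"
    binding.write_text("TVM_FFI_DLL_EXPORT_TYPED_FUNC(plan, BatchPrefillWithKVCachePlan);\n")

    renamed = [
        pattern.sub(lambda m: f"{m.group(1)}batch_prefill_{m.group(2)}{m.group(3)}", line)
        for line in binding.read_text().splitlines()
    ]
    binding.write_text("\n".join(renamed) + "\n")
    print(binding.read_text().strip())
    # TVM_FFI_DLL_EXPORT_TYPED_FUNC(batch_prefill_plan, BatchPrefillWithKVCachePlan);
```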