Commit 0e055fc

[FlashInfer] Update include path and interface (#18317)
This PR updates the include paths for FlashInfer JIT compilation and the plan function interface for attention prefill computation, to align with the recent interface change in flashinfer-ai/flashinfer#1661.

Parent: 70e9164

2 files changed: 13 additions, 6 deletions

python/tvm/relax/backend/cuda/flashinfer.py (9 additions, 4 deletions)

```diff
@@ -141,8 +141,8 @@ def get_object_file_path(src: Path) -> Path:
         )
         include_paths += [
             Path(tvm_home).resolve() / "include",
-            Path(tvm_home).resolve() / "ffi" / "include",
-            Path(tvm_home).resolve() / "ffi" / "3rdparty" / "dlpack" / "include",
+            Path(tvm_home).resolve() / "3rdparty" / "tvm-ffi" / "include",
+            Path(tvm_home).resolve() / "3rdparty" / "tvm-ffi" / "3rdparty" / "dlpack" / "include",
             Path(tvm_home).resolve() / "3rdparty" / "dmlc-core" / "include",
         ]
     else:
@@ -160,8 +160,13 @@ def get_object_file_path(src: Path) -> Path:
         # The package is installed from source.
         include_paths += [
             tvm_package_path.parent.parent / "include",
-            tvm_package_path.parent.parent / "ffi" / "include",
-            tvm_package_path.parent.parent / "ffi" / "3rdparty" / "dlpack" / "include",
+            tvm_package_path.parent.parent / "3rdparty" / "tvm-ffi" / "include",
+            tvm_package_path.parent.parent
+            / "3rdparty"
+            / "tvm-ffi"
+            / "3rdparty"
+            / "dlpack"
+            / "include",
             tvm_package_path.parent.parent / "3rdparty" / "dmlc-core" / "include",
         ]
     else:
```
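In both install layouts, these directories end up as -I flags on the JIT compile command for the FlashInfer kernels, so moving the TVM FFI headers from ffi/ to 3rdparty/tvm-ffi/ only requires touching this list. A minimal sketch of that pattern, assuming an nvcc on PATH and a hypothetical compile_object helper (this is not TVM's actual implementation):

```python
import subprocess
from pathlib import Path

def compile_object(src: Path, include_paths: list[Path]) -> Path:
    """Hypothetical helper: JIT-compile one CUDA source file to an object file."""
    obj = src.with_suffix(".o")
    cmd = ["nvcc", "-c", str(src), "-o", str(obj)]
    # Each include directory becomes an -I flag; a header tree that moves on
    # disk (ffi/include -> 3rdparty/tvm-ffi/include) only changes this list.
    for path in include_paths:
        cmd += ["-I", str(path)]
    subprocess.run(cmd, check=True)
    return obj
```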

src/runtime/vm/attn_backend.h (4 additions, 2 deletions)

```diff
@@ -176,7 +176,8 @@ class FlashInferPagedPrefillFunc : public PagedPrefillFunc {
       plan_func_(float_workspace_buffer, int_workspace_buffer, page_locked_int_workspace_buffer,
                  qo_indptr->as_tensor(), page_indptr->as_tensor(), IntTuple(std::move(kv_len)),
                  total_qo_len, batch_size, num_qo_heads, num_kv_heads, page_size,
-                 /*enable_cuda_graph=*/false, qk_head_dim, v_head_dim, causal, copy_stream)
+                 /*enable_cuda_graph=*/false, qk_head_dim, v_head_dim, causal,
+                 /*window_left=*/-1, copy_stream)
           .cast<IntTuple>();
   } else if (attn_kind == AttnKind::kMLA) {
     plan_info_vec =
@@ -280,7 +281,8 @@ class FlashInferRaggedPrefillFunc : public RaggedPrefillFunc {
       plan_func_(float_workspace_buffer, int_workspace_buffer, page_locked_int_workspace_buffer,
                  qo_indptr->as_tensor(), kv_indptr->as_tensor(), IntTuple(std::move(kv_len)),
                  total_qo_len, batch_size, num_qo_heads, num_kv_heads, /*page_size=*/1,
-                 /*enable_cuda_graph=*/false, qk_head_dim, v_head_dim, causal, copy_stream)
+                 /*enable_cuda_graph=*/false, qk_head_dim, v_head_dim, causal,
+                 /*window_left=*/-1, copy_stream)
           .cast<IntTuple>();
   }
```
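The only change in both call sites is the new /*window_left=*/-1 argument inserted before copy_stream, matching the updated plan interface from flashinfer-ai/flashinfer#1661. In FlashInfer's sliding-window convention, a non-negative window_left lets query position i attend only to key positions j in [i - window_left, i], while -1 disables the window, so passing -1 here preserves the previous full-attention behavior. A plain-Python illustration of that masking rule (attends is a hypothetical helper, not FlashInfer code):

```python
def attends(i: int, j: int, window_left: int, causal: bool = True) -> bool:
    """Whether query position i may attend to key position j under a
    sliding window of size window_left (-1 = no window)."""
    if causal and j > i:  # never look ahead under a causal mask
        return False
    if window_left >= 0 and j < i - window_left:  # outside the window
        return False
    return True

# window_left=-1: the query sees its whole causal prefix.
assert attends(10, 0, window_left=-1)
# window_left=2: only positions 8..10 are visible from position 10.
assert not attends(10, 0, window_left=2)
assert attends(10, 8, window_left=2)
```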
