Runtime error with single_prefill_with_kv_cache during compilation #541
Comments
Hi @hnyls2002, have you ever encountered such errors before in the sglang integration?
It seems the issue is with flashinfer/python/flashinfer/prefill.py, lines 277 to 281 at commit 78e26e4.
Could you try specifying
Even when I only pass the q, k, and v arguments and omit the others, the error still occurs:

```python
import torch
from flashinfer import single_prefill_with_kv_cache

# Register the FlashInfer kernel as a custom op so torch.compile does not trace into it.
torch.library.define(
    "mylib::custom_func_flashinfer",
    "(Tensor q, Tensor k, Tensor v) -> Tensor",
)

@torch.library.impl("mylib::custom_func_flashinfer", "cuda")
def custom_func_flashinfer(q, k, v):
    return single_prefill_with_kv_cache(q, k, v)

# Meta implementation used for shape inference during tracing.
@torch.library.impl_abstract("mylib::custom_func_flashinfer")
def custom_func_flashinfer_abstract(q, k, v):
    return torch.empty_like(q)

def attn(q, k, v):
    return torch.ops.mylib.custom_func_flashinfer(q, k, v)

attn = torch.compile(attn, mode="reduce-overhead", fullgraph=True)
```
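For reference, a minimal smoke test of the wrapper above might look like the sketch below. The shapes are chosen purely for illustration and assume the default NHD layout, i.e. q of shape (qo_len, num_qo_heads, head_dim) and k/v of shape (kv_len, num_kv_heads, head_dim):

```python
# Illustrative only: exercise the compiled custom-op wrapper with GQA-style shapes.
qo_len, kv_len, num_qo_heads, num_kv_heads, head_dim = 1024, 1024, 32, 8, 128
q = torch.randn(qo_len, num_qo_heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.bfloat16)
v = torch.randn(kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.bfloat16)
o = attn(q, k, v)   # compiled path; eager single_prefill_with_kv_cache(q, k, v) should agree
print(o.shape)      # expected: torch.Size([1024, 32, 128])
```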
Thank you for your response. I have tried using the following script:

```python
import torch
import torch.nn as nn
from torch import Tensor

import flashinfer

data_type = torch.bfloat16

def generate_data_x():
    x = torch.randn(1, 1024, 4096, device="cuda", dtype=data_type)
    return x

def timed(fn):
    # Time a CUDA call with events; returns (result, seconds).
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000

class Attention(nn.Module):
    def __init__(self):
        super().__init__()
        # 32 query heads + 2 * 8 KV heads, head_dim = 128
        total_head_dim = (32 + 2 * 8) * 128
        self.wqkv = nn.Linear(4096, total_head_dim).to(torch.bfloat16)
        self.attn = flashinfer.single_prefill_with_kv_cache

    def forward(self, x: Tensor) -> Tensor:
        bsz, seqlen, _ = x.shape
        q, k, v = self.wqkv(x).split([4096, 8 * 128, 8 * 128], dim=-1)
        q = q.view(bsz, seqlen, 32, 128)
        k = k.view(bsz, seqlen, 8, 128)
        v = v.view(bsz, seqlen, 8, 128)
        # single_prefill_with_kv_cache expects unbatched tensors.
        q = q.squeeze(0)
        k = k.squeeze(0)
        v = v.squeeze(0)
        y = self.attn(q, k, v)
        return y

self_attn = Attention().to("cuda")
attn = lambda model, x: model(x)
attn = torch.compile(attn, mode="reduce-overhead", fullgraph=True)

for i in range(10):
    x = generate_data_x()
    o, run_time = timed(lambda: attn(self_attn, x))
    print(run_time)
```
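A possible workaround, sketched here under the assumption that the `mylib::custom_func_flashinfer` op from the earlier comment has already been registered, is to route the module's attention call through that op, so that `torch.compile(..., fullgraph=True)` treats the FlashInfer kernel as an opaque node rather than tracing into it:

```python
# Sketch only: reuse the custom op registered above inside the Attention module,
# so Dynamo never traces into flashinfer.single_prefill_with_kv_cache directly.
class AttentionViaCustomOp(Attention):
    def forward(self, x: Tensor) -> Tensor:
        bsz, seqlen, _ = x.shape
        q, k, v = self.wqkv(x).split([4096, 8 * 128, 8 * 128], dim=-1)
        q = q.view(bsz, seqlen, 32, 128).squeeze(0)
        k = k.view(bsz, seqlen, 8, 128).squeeze(0)
        v = v.view(bsz, seqlen, 8, 128).squeeze(0)
        return torch.ops.mylib.custom_func_flashinfer(q, k, v)

self_attn = AttentionViaCustomOp().to("cuda")
compiled = torch.compile(lambda m, x: m(x), mode="reduce-overhead", fullgraph=True)
o, run_time = timed(lambda: compiled(self_attn, generate_data_x()))
print(run_time)
```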
The APIs start with
I tried to compile single_prefill_with_kv_cache using torch.compile, which caused the following runtime error: