Runtime error with single_prefill_with_kv_cache while Compilation #541

Open

YudiZh opened this issue Oct 20, 2024 · 6 comments


YudiZh commented Oct 20, 2024

I tried to compile single_prefill_with_kv_cache using torch.compile.

import torch
from flashinfer import single_prefill_with_kv_cache

data_type = torch.bfloat16

QH=64
KH=8
S=1024
D=128

def generate_data():
    q = torch.randn(S, QH, D, device='cuda', dtype=data_type)
    k = torch.randn(S, KH, D, device='cuda', dtype=data_type)
    v = torch.randn(S, KH, D, device='cuda', dtype=data_type)
    return q, k, v

def timed(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000

torch.library.define(
    "mylib::custom_func_flashinfer",
    "(Tensor q, Tensor k, Tensor v, Tensor custom_mask) -> Tensor",
)

@torch.library.impl("mylib::custom_func_flashinfer", "cuda")
def custom_func_flashinfer(q, k, v, custom_mask):
    return single_prefill_with_kv_cache(
        q, k, v, custom_mask=custom_mask
    )

@torch.library.impl_abstract("mylib::custom_func_flashinfer")
def custom_func_flashinfer_abstract(q, k, v, custom_mask):
    return torch.empty_like(q)


def attn(q, k, v, custom_mask=None):
    return torch.ops.mylib.custom_func_flashinfer(q, k, v, custom_mask=custom_mask)
attn = torch.compile(attn, mode="reduce-overhead", fullgraph=True)


for i in range(10):
    q, k, v = generate_data()
    mask = torch.tril(
        torch.full((S, S), True, device="cuda:0"),
    )
    o, run_time = timed(lambda: attn(q, k, v, custom_mask=mask))
    print(run_time)

This causes the following runtime error:

/data/home/ydzhang/project/code_test/flashinfer_test/compilation.py:37: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
  @torch.library.impl_abstract("mylib::custom_func_flashinfer")
Traceback (most recent call last):
  File "/data/home/ydzhang/project/code_test/flashinfer_test/compilation.py", line 52, in <module>
    o, run_time = timed(lambda: attn(q, k, v, custom_mask=mask))
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/home/ydzhang/project/code_test/flashinfer_test/compilation.py", line 21, in timed
    result = fn()
             ^^^^
  File "/data/home/ydzhang/project/code_test/flashinfer_test/compilation.py", line 52, in <lambda>
    o, run_time = timed(lambda: attn(q, k, v, custom_mask=mask))
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 433, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/data/home/ydzhang/project/code_test/flashinfer_test/compilation.py", line 42, in attn
    def attn(q, k, v, custom_mask=None):
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 987, in forward
    return compiled_fn(full_args)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 217, in runtime_wrapper
    all_outs = call_func_at_runtime_with_args(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/utils.py", line 120, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
                            ^^^^^^^
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 451, in wrapper
    return compiled_fn(runtime_args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 1131, in __call__
    return self.current_callable(inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 993, in run
    return compiled_fn(new_inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py", line 373, in deferred_cudagraphify
    fn, out = cudagraphify(model, inputs, new_static_input_idxs, *args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py", line 403, in cudagraphify
    return manager.add_function(
           ^^^^^^^^^^^^^^^^^^^^^
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py", line 2089, in add_function
    return fn, fn(inputs)
               ^^^^^^^^^^
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py", line 1841, in run
    out = self._run(new_inputs, function_id)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py", line 1932, in _run
    return self.run_eager(new_inputs, function_id)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py", line 2055, in run_eager
    return node.run(new_inputs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py", line 646, in run
    check_memory_pool(self.device_index, self.cuda_graphs_pool, out_refs)
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py", line 1699, in check_memory_pool
    raise RuntimeError(msg)
RuntimeError: These live storage data ptrs are in the cudagraph pool but not accounted for as an output of cudagraph trees: 

Data Pointer: 22959854977024, history: 
yzh119 (Collaborator) commented Oct 20, 2024

Hi @hnyls2002, have you ever run into errors like this before in the sglang integration?

yzh119 (Collaborator) commented Oct 21, 2024

Seems the issue is with custom_mask, which internally calls flashinfer.packbits:

if custom_mask is not None and packed_custom_mask is None:
    # create packed custom mask from custom mask
    packed_custom_mask = packbits(
        custom_mask.contiguous().view(-1), bitorder="little"
    )

Could you try specifying the packed_custom_mask argument instead of the custom_mask argument, or wrapping flashinfer.packbits with the PyTorch custom ops API as well?
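
For reference, a minimal sketch of the first suggestion: pack the mask once outside the compiled region with flashinfer.packbits and route it through a custom op that accepts packed_custom_mask. The op name mylib::custom_func_flashinfer_packed is made up for illustration, S/q/k/v are the values from the reproduction above, and whether this sidesteps the cudagraph pool error is untested here.

import torch
import flashinfer
from flashinfer import single_prefill_with_kv_cache

torch.library.define(
    "mylib::custom_func_flashinfer_packed",
    "(Tensor q, Tensor k, Tensor v, Tensor packed_custom_mask) -> Tensor",
)

@torch.library.impl("mylib::custom_func_flashinfer_packed", "cuda")
def custom_func_flashinfer_packed(q, k, v, packed_custom_mask):
    # The mask is already bit-packed, so the op does not call packbits internally.
    return single_prefill_with_kv_cache(q, k, v, packed_custom_mask=packed_custom_mask)

@torch.library.impl_abstract("mylib::custom_func_flashinfer_packed")
def custom_func_flashinfer_packed_abstract(q, k, v, packed_custom_mask):
    return torch.empty_like(q)

def attn(q, k, v, packed_custom_mask):
    return torch.ops.mylib.custom_func_flashinfer_packed(q, k, v, packed_custom_mask)

attn = torch.compile(attn, mode="reduce-overhead", fullgraph=True)

# Pack the boolean causal mask once, outside the compiled region.
mask = torch.tril(torch.full((S, S), True, device="cuda"))
packed_mask = flashinfer.packbits(mask.contiguous().view(-1), bitorder="little")
o = attn(q, k, v, packed_mask)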

YudiZh (Author) commented Oct 21, 2024

Even when I only pass the q, k, and v arguments and omit the others, the error still occurs:

torch.library.define(
    "mylib::custom_func_flashinfer",
    "(Tensor q, Tensor k, Tensor v) -> Tensor",
)

@torch.library.impl("mylib::custom_func_flashinfer", "cuda")
def custom_func_flashinfer(q, k, v):
    return single_prefill_with_kv_cache(
        q, k, v
    )

@torch.library.impl_abstract("mylib::custom_func_flashinfer")
def custom_func_flashinfer_abstract(q, k, v):
    return torch.empty_like(q)

def attn(q, k, v):
    return torch.ops.mylib.custom_func_flashinfer(q, k, v)
attn = torch.compile(attn, mode="reduce-overhead", fullgraph=True)

abcdabcd987 (Member) commented Oct 24, 2024

@YudiZh Can you try torch.compile(..., fullgraph=True, mode="max-autotune-no-cudagraphs")? CUDA graphs provide little value when you are just capturing one CUDA kernel.
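
Applied to the reproduction above (reusing attn, q, k, v, mask, and timed from that script), the suggestion only changes the compile mode; Inductor still autotunes but skips CUDA graph capture:

attn = torch.compile(attn, fullgraph=True, mode="max-autotune-no-cudagraphs")
o, run_time = timed(lambda: attn(q, k, v, custom_mask=mask))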

BTW, we are adding torch library annotations in #554

YudiZh (Author) commented Oct 26, 2024

Thank you for your response. I have tried torch.compile(..., fullgraph=True, mode="max-autotune-no-cudagraphs") and the code runs without errors. However, when I want CUDA graphs to cover both flashinfer and the other PyTorch operations in the model's forward function, which mode should I use for compilation? If I use max-autotune-no-cudagraphs, do the PyTorch operations miss out on the expected acceleration because CUDA graphs are not involved? Here is a sample of my code:

import torch
import flashinfer
import torch.nn as nn
from torch import Tensor
data_type = torch.bfloat16

def generate_data_x():
    x = torch.randn(1, 1024, 4096, device='cuda', dtype=data_type)
    return x

def timed(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000


class Attention(nn.Module):
    def __init__(self):
        super().__init__()
        total_head_dim = (32 + 2 * 8) * 128
        self.wqkv = nn.Linear(4096, total_head_dim).to(torch.bfloat16)
        self.attn = flashinfer.single_prefill_with_kv_cache
    
    def forward(self, x: Tensor) -> Tensor:
        bsz, seqlen, _ = x.shape
        q, k, v = self.wqkv(x).split([4096, 8*128, 8*128], dim=-1)

        q = q.view(bsz, seqlen, 32, 128)
        k = k.view(bsz, seqlen, 8, 128)
        v = v.view(bsz, seqlen, 8, 128)

        q = q.squeeze(0)
        k = k.squeeze(0)
        v = v.squeeze(0)

        y = self.attn(q, k, v)
        return y

self_attn = Attention().to("cuda")
attn = lambda model, x: model(x)
attn = torch.compile(attn, mode="reduce-overhead", fullgraph=True)


for i in range(10):
    x = generate_data_x()
    o, run_time = timed(lambda: attn(self_attn, x))
    print(run_time)

yzh119 (Collaborator) commented Oct 26, 2024

The APIs that start with single_ are not compatible with CUDA graphs (I might spend some time making them compatible later).
The BatchPrefill/BatchDecode wrappers have been designed to be compatible with CUDA graphs.
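
As a rough illustration of routing the single-request case through the batch wrapper (a batch of size 1), here is a sketch; the BatchPrefillWithRaggedKVCacheWrapper / begin_forward / forward names follow the current docs and may differ across releases, and this omits the extra buffer setup a real CUDA graph capture would need:

import torch
import flashinfer

QH, KH, S, D = 64, 8, 1024, 128
dtype = torch.bfloat16
q = torch.randn(S, QH, D, device="cuda", dtype=dtype)
k = torch.randn(S, KH, D, device="cuda", dtype=dtype)
v = torch.randn(S, KH, D, device="cuda", dtype=dtype)

# Reusable workspace buffer for the wrapper.
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(workspace, "NHD")

# A "batch" of one request: indptr arrays mark where each request's tokens start and end.
qo_indptr = torch.tensor([0, S], dtype=torch.int32, device="cuda")
kv_indptr = torch.tensor([0, S], dtype=torch.int32, device="cuda")

wrapper.begin_forward(qo_indptr, kv_indptr, QH, KH, D)
o = wrapper.forward(q, k, v, causal=True)
wrapper.end_forward()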
