Fix TP > 1 cuda graphs
Co-authored-by: Alexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com>
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
LucasWilkinson and alexm-redhat committed Jan 31, 2025
1 parent 433322b commit f2b2500
Showing 5 changed files with 104 additions and 19 deletions.
vllm/attention/backends/abstract.py (3 changes: 2 additions & 1 deletion)

```diff
@@ -168,7 +168,8 @@ def __init__(self, runner: "ModelRunnerBase"):
 
     @abstractmethod
     @contextmanager
-    def graph_capture(self, max_batch_size: int):
+    def graph_capture(self, max_batch_size: int,
+                      positions: Optional[torch.Tensor]):
        """Context manager used when capturing CUDA graphs."""
        yield
 
```
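For reference, the widened interface in isolation: graph_capture now also receives an optional positions tensor at capture time. The sketch below is a standalone illustration (class and attribute names are hypothetical, not vLLM's); a concrete backend may record the tensor for later reuse, or simply assert it is None when it keeps no positions buffer of its own, as CommonAttentionState does later in this diff.

```python
# Hypothetical, standalone sketch of an AttentionState-style implementation
# honoring the new two-argument graph_capture signature. Illustration only.
from contextlib import contextmanager
from typing import Optional

import torch


class SketchAttentionState:

    def __init__(self, device: str = "cpu"):
        self.device = device
        self._is_graph_capturing = False

    @contextmanager
    def graph_capture(self, max_batch_size: int,
                      positions: Optional[torch.Tensor]):
        self._is_graph_capturing = True
        # Keep a persistent positions buffer for the lifetime of the capture:
        # either the caller-provided tensor or a freshly allocated one.
        self._positions = (positions if positions is not None else torch.zeros(
            (max_batch_size, ), dtype=torch.long, device=self.device))
        try:
            yield
        finally:
            self._is_graph_capturing = False
            del self._positions


# Usage: the runner would enter the context before capturing CUDA graphs.
state = SketchAttentionState()
with state.graph_capture(max_batch_size=8, positions=None):
    assert state._is_graph_capturing
```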
vllm/attention/backends/flashinfer.py (3 changes: 2 additions & 1 deletion)

```diff
@@ -213,7 +213,8 @@ def _get_decode_wrapper(self):
         return self._decode_wrapper
 
     @contextmanager
-    def graph_capture(self, max_batch_size: int):
+    def graph_capture(self, max_batch_size: int,
+                      positions: Optional[torch.Tensor]):
         self._is_graph_capturing = True
         self._graph_decode_wrapper = None
         self._graph_slot_mapping = torch.full((max_batch_size, ),
```
vllm/attention/backends/triton_mla.py (92 changes: 81 additions & 11 deletions)

```diff
@@ -90,33 +90,103 @@ class TritonMLAState(AttentionState):
 
     def __init__(self, runner):
         self.runner = runner
+        self._is_graph_capturing = False
 
     @contextmanager
     def graph_capture(self, max_batch_size: int):
-        raise NotImplementedError(
-            "TritonMLAState does not support graph capture")
+        self._is_graph_capturing = True
+
+        self._graph_slot_mapping = torch.full((max_batch_size, ),
+                                              PAD_SLOT_ID,
+                                              dtype=torch.long,
+                                              device=self.runner.device)
+        self._graph_seq_lens = torch.ones(max_batch_size,
+                                          dtype=torch.int32,
+                                          device=self.runner.device)
+        self._graph_block_tables = torch.from_numpy(
+            self.runner.graph_block_tables).to(device=self.runner.device)
+
+        self._positions = torch.zeros((max_batch_size, ),
+                                      dtype=torch.long,
+                                      device=self.runner.device)
+
+        yield
+
+        self._is_graph_capturing = False
+        del self._graph_slot_mapping
+        del self._graph_seq_lens
+        del self._graph_block_tables
+        del self._positions
 
     def graph_clone(self, batch_size: int):
-        raise NotImplementedError(
-            "TritonMLAState does not support graph capture")
+        assert self._is_graph_capturing
+        return self.__class__(self.runner)
 
     def graph_capture_get_metadata_for_batch(
-            self, batch_size: int, is_encoder_decoder_model: bool = False):
-        raise NotImplementedError(
-            "TritonMLAState does not support graph capture")
+            self,
+            batch_size: int,
+            is_encoder_decoder_model: bool = False,
+            positions: Optional[torch.Tensor] = None):
+        assert self._is_graph_capturing
+
+        attn_metadata = self.runner.attn_backend.make_metadata(
+            num_prefills=0,
+            num_prefill_tokens=0,
+            num_decode_tokens=batch_size,
+            slot_mapping=self._graph_slot_mapping[:batch_size],
+            multi_modal_placeholder_index_maps=None,
+            enable_kv_scales_calculation=True,
+            seq_lens=None,
+            seq_lens_tensor=self._graph_seq_lens[:batch_size],
+            max_query_len=1,
+            max_decode_query_len=1,
+            max_prefill_seq_len=0,
+            max_decode_seq_len=self.runner.max_seq_len_to_capture,
+            query_start_loc=None,
+            seq_start_loc=None,
+            context_lens_tensor=None,
+            block_tables=self._graph_block_tables[:batch_size],
+            use_cuda_graph=True,
+            input_positions=self._positions[:batch_size],
+            head_dim=self.runner.model_config.get_head_size())
+
+        if is_encoder_decoder_model:
+            raise NotImplementedError(
+                "TritonMLAState does not support encoder/decoder yet")
+
+        return attn_metadata
 
     def get_graph_input_buffers(self,
                                 attn_metadata,
                                 is_encoder_decoder_model: bool = False):
-        raise NotImplementedError(
-            "TritonMLAState does not support graph capture")
+        input_buffers = {
+            "slot_mapping": attn_metadata.slot_mapping,
+            "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
+            "block_tables": attn_metadata.decode_metadata.block_tables,
+            "input_positions": attn_metadata.decode_metadata.input_positions,
+        }
+        if is_encoder_decoder_model:
+            raise NotImplementedError(
+                "TritonMLAState does not support encoder/decoder yet")
+
+        return input_buffers
 
     def prepare_graph_input_buffers(self,
                                     input_buffers,
                                     attn_metadata,
                                     is_encoder_decoder_model: bool = False):
-        raise NotImplementedError(
-            "TritonMLAState does not support graph capture")
+        input_buffers["seq_lens_tensor"].copy_(
+            attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
+        input_buffers["block_tables"].copy_(
+            attn_metadata.decode_metadata.block_tables, non_blocking=True)
+        input_buffers["input_positions"][:attn_metadata.decode_metadata.
+                                         input_positions.shape[0]].copy_(
+                                             attn_metadata.decode_metadata.
+                                             input_positions,
+                                             non_blocking=True)
+        if is_encoder_decoder_model:
+            raise NotImplementedError(
+                "TritonMLAState does not support encoder/decoder yet")
 
     def begin_forward(self, model_input):
         return
```
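TritonMLAState now supports CUDA-graph capture by following the same persistent-buffer pattern as the common state: slot-mapping, sequence-length, block-table, and positions buffers are allocated once under graph_capture, sliced per batch size into the attention metadata, exposed through get_graph_input_buffers, and refilled in place by prepare_graph_input_buffers before replay. The sketch below (hypothetical names, CPU-only, not vLLM code) demonstrates just that allocate-once-then-copy_ pattern.

```python
# Minimal, self-contained sketch of the buffer-reuse pattern TritonMLAState now
# follows. During capture, a fixed set of tensors is allocated once; at replay
# time, fresh values are copy_()'d into those same tensors so the captured
# graph sees updated inputs without any reallocation. Runs on CPU only to show
# the data flow; no actual CUDA graphs are involved here.
import torch

MAX_BATCH = 8
PAD_SLOT_ID = -1  # assumption: a sentinel in the spirit of vLLM's PAD_SLOT_ID

# "Capture-time" persistent buffers (their addresses must stay stable).
graph_slot_mapping = torch.full((MAX_BATCH, ), PAD_SLOT_ID, dtype=torch.long)
graph_seq_lens = torch.ones(MAX_BATCH, dtype=torch.int32)
graph_positions = torch.zeros((MAX_BATCH, ), dtype=torch.long)


def get_graph_input_buffers(batch_size: int) -> dict:
    # Views into the persistent buffers, sized for this capture batch size.
    return {
        "slot_mapping": graph_slot_mapping[:batch_size],
        "seq_lens_tensor": graph_seq_lens[:batch_size],
        "input_positions": graph_positions[:batch_size],
    }


def prepare_graph_input_buffers(input_buffers: dict, seq_lens: torch.Tensor,
                                positions: torch.Tensor) -> None:
    # "Replay-time": overwrite buffer contents in place, never rebind them.
    input_buffers["seq_lens_tensor"].copy_(seq_lens, non_blocking=True)
    input_buffers["input_positions"][:positions.shape[0]].copy_(
        positions, non_blocking=True)


buffers = get_graph_input_buffers(batch_size=4)
prepare_graph_input_buffers(
    buffers,
    seq_lens=torch.tensor([3, 5, 2, 7], dtype=torch.int32),
    positions=torch.tensor([2, 4, 1, 6]))
print(graph_positions)  # first 4 entries now hold the new positions
```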
vllm/attention/backends/utils.py (16 changes: 13 additions & 3 deletions)

```diff
@@ -2,7 +2,8 @@
 from collections import defaultdict
 from contextlib import contextmanager
 from itertools import accumulate
-from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union
+from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type,
+                    TypeVar, Union)
 
 import numpy as np
 import torch
@@ -288,8 +289,12 @@ def __init__(self, runner: "ModelRunnerBase"):
         self._is_graph_capturing = False
 
     @contextmanager
-    def graph_capture(self, max_batch_size: int):
+    def graph_capture(self, max_batch_size: int,
+                      positions: Optional[torch.Tensor]):
+        assert positions is None
+
         self._is_graph_capturing = True
+
         self._graph_slot_mapping = torch.full((max_batch_size, ),
                                               PAD_SLOT_ID,
                                               dtype=torch.long,
@@ -299,7 +304,9 @@ def graph_capture(self, max_batch_size: int):
                                           device=self.runner.device)
         self._graph_block_tables = torch.from_numpy(
             self.runner.graph_block_tables).to(device=self.runner.device)
+
         yield
+
         self._is_graph_capturing = False
         del self._graph_slot_mapping
         del self._graph_seq_lens
@@ -310,7 +317,10 @@ def graph_clone(self, batch_size: int) -> "CommonAttentionState":
         return self.__class__(self.runner)
 
     def graph_capture_get_metadata_for_batch(
-            self, batch_size: int, is_encoder_decoder_model: bool = False):
+            self,
+            batch_size: int,
+            is_encoder_decoder_model: bool = False,
+            positions: Optional[torch.Tensor] = None):
         assert self._is_graph_capturing
         attn_metadata = self.runner.attn_backend.make_metadata(
             num_prefills=0,
```
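CommonAttentionState only needs to stay in sync with the widened interface: graph_capture gains the positions parameter but asserts it is None (the visible hunk suggests this state does not consume it), and graph_capture_get_metadata_for_batch accepts positions with a None default so existing call sites keep working. A minimal sketch of that accept-but-guard pattern, with hypothetical names:

```python
# Illustrative stand-in (not vLLM's CommonAttentionState) for a state that
# widens its signatures for compatibility but does not use the new argument.
from contextlib import contextmanager
from typing import Optional

import torch


class CommonStateSketch:

    @contextmanager
    def graph_capture(self, max_batch_size: int,
                      positions: Optional[torch.Tensor]):
        # Fail fast if a caller passes a positions tensor this state
        # would otherwise silently ignore.
        assert positions is None
        self._is_graph_capturing = True
        try:
            yield
        finally:
            self._is_graph_capturing = False

    def graph_capture_get_metadata_for_batch(
            self,
            batch_size: int,
            is_encoder_decoder_model: bool = False,
            positions: Optional[torch.Tensor] = None):
        # positions defaults to None so existing callers keep working; what a
        # concrete backend does with it is up to that backend.
        assert self._is_graph_capturing
        return {"num_decode_tokens": batch_size, "use_cuda_graph": True}
```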
vllm/worker/model_runner.py (9 changes: 6 additions & 3 deletions)

```diff
@@ -1483,11 +1483,13 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
                     self.vllm_config.compilation_config.
                     cudagraph_capture_sizes)
                 for batch_size in cudagraph_capture_sizes:
+                    cur_input_positions = input_positions[..., :batch_size]
                     attn_metadata = (
                         self.attn_state.graph_capture_get_metadata_for_batch(
                             batch_size,
                             is_encoder_decoder_model=self.model_config.
-                            is_encoder_decoder))
+                            is_encoder_decoder,
+                            positions=cur_input_positions))
                     # Disable KV Scale Calculation for graph capture
                     attn_metadata.enable_kv_scales_calculation = False
                     if self.lora_config:
@@ -1513,7 +1515,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
                         "input_ids":
                         input_tokens[:batch_size],
                         "positions":
-                        input_positions[..., :batch_size],
+                        cur_input_positions,
                         "intermediate_inputs":
                         intermediate_inputs[:batch_size]
                         if intermediate_inputs is not None else None,
@@ -1974,7 +1976,8 @@ def forward(
 
         # Copy the input tensors to the input buffers.
         self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
-        self.input_buffers["positions"].copy_(positions, non_blocking=True)
+        if positions is not None:
+            self.input_buffers["positions"].copy_(positions, non_blocking=True)
 
         if self.backend_name != "NO_ATTENTION":
             self.input_buffers["slot_mapping"].copy_(
```
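In the runner, each capture batch size now slices the positions tensor once (cur_input_positions), reuses that slice both as the captured "positions" input and as the positions argument to graph_capture_get_metadata_for_batch, and the graph-runner forward only copies into its positions buffer when a tensor is actually supplied. Below is a standalone sketch of that replay-side guard (GraphRunnerSketch is a hypothetical stand-in, not vLLM's CUDAGraphRunner).

```python
# Sketch of the replay-side change: the "positions" input buffer is only
# overwritten when the caller actually supplies a positions tensor.
from typing import Dict, Optional

import torch


class GraphRunnerSketch:

    def __init__(self, max_batch_size: int):
        # Persistent input buffers the captured graph would read from.
        self.input_buffers: Dict[str, torch.Tensor] = {
            "input_ids": torch.zeros(max_batch_size, dtype=torch.long),
            "positions": torch.zeros(max_batch_size, dtype=torch.long),
        }

    def forward(self, input_ids: torch.Tensor,
                positions: Optional[torch.Tensor]) -> None:
        bs = input_ids.shape[0]
        self.input_buffers["input_ids"][:bs].copy_(input_ids,
                                                   non_blocking=True)
        if positions is not None:  # the guard this commit introduces
            self.input_buffers["positions"][:bs].copy_(positions,
                                                       non_blocking=True)
        # ... graph.replay() would run here against the persistent buffers.


runner = GraphRunnerSketch(max_batch_size=8)
runner.forward(torch.arange(4), positions=None)  # positions buffer untouched
runner.forward(torch.arange(4), positions=torch.arange(4))  # buffer updated
```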
