
Commit 6d336f6

mzusman, Mor Zusman, and tomeras91 committed
Cuda graph (vllm-project#5)
* Drop indecies when finish
* min 1 attention layer
* CG is working on forward pass passing
* Remove comments
* cosmetics - rename indecies -> indices, organize some whitespaces
* Add some TODOs
* Adding mamba cache for cg
* Remove useless vars from input_metadata
* Remove unused import
* Set the seqlen offset to boolean
* Return only hidden state
* Return only hidden states
* Add padding to match forward pass bs
* Is prompt instead of seqlen offset
* Remove mamba cache class (not used)
* Another remove
* Remove
* Use mamba4gc
* Fix mamba forward, run update only on non prompt
* Use 1 index after the maximal index
* Remove import
* Remove import
* typo
* typo
* place holder
* Padding and empty token takes it from the first empty place
* reformat
* Apply suggestions from code review: Whitespaces

---------

Co-authored-by: Mor Zusman <morz@ai21.com>
Co-authored-by: Tomer Asida <tomera@ai21.com>
Co-authored-by: tomeras91 <57313761+tomeras91@users.noreply.github.com>
1 parent 07cc899 commit 6d336f6
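
For context on why the changes above revolve around preallocated Mamba state: CUDA graph capture records a fixed sequence of kernels operating on fixed memory addresses, so any state touched during decode (here the conv/ssm cache) must live in static buffers that are reused across steps rather than allocated per request. The snippet below is a generic, minimal sketch of that capture-and-replay pattern using the public torch.cuda graph API; it is illustrative only and not code from this PR (buffer names and sizes are made up).

import torch

# Generic torch.cuda graph capture/replay sketch (illustrative only, not from this PR).
# A captured graph replays the same kernels on the same memory, which is why this
# commit moves the Jamba conv/ssm cache into preallocated tensors.
if torch.cuda.is_available():
    static_x = torch.zeros(8, 4096, device="cuda")        # static input buffer
    weight = torch.randn(4096, 4096, device="cuda")

    # Warm up on a side stream before capture, as recommended in the PyTorch docs.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        static_y = static_x @ weight
    torch.cuda.current_stream().wait_stream(s)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_y = static_x @ weight                       # recorded, not executed eagerly

    static_x.copy_(torch.randn_like(static_x))             # refill the static buffer in place
    graph.replay()                                         # rerun the captured kernels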

7 files changed: 132 additions, 87 deletions

vllm/model_executor/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1,11 +1,13 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_random_seed
 from vllm.model_executor.mamba_metadata import MambaCacheParams, RequestInfo, MambaCache
+from vllm.model_executor.utils import set_random_seed
 
 __all__ = [
     "SamplingMetadata",
     "set_random_seed",
     "MambaCacheParams",
     "RequestInfo",
     "MambaCache",
+    "RequestInfo"
 ]

vllm/model_executor/input_metadata.py

Lines changed: 1 addition & 2 deletions
@@ -2,7 +2,7 @@
 
 import torch
 
-from vllm.model_executor.mamba_metadata import MambaCache, RequestInfo
+from vllm.model_executor.mamba_metadata import RequestInfo
 
 
 class InputMetadata:
@@ -45,7 +45,6 @@ def __init__(
         # Set during the execution of the first attention op.
         # FIXME(woosuk): This is a hack.
         self.attn_bias = None
-        self.mamba_cache_batch: List[MambaCache] = []
         self.requests_info = requests_info
 
     def __repr__(self) -> str:

vllm/model_executor/mamba_metadata.py

Lines changed: 1 addition & 13 deletions
@@ -5,7 +5,7 @@
 
 @dataclass
 class MambaCacheParams:
-    seqlen_offset: int = 0
+    is_prompt: bool = False
     conv_state: torch.Tensor = torch.Tensor()
     ssm_state: torch.Tensor = torch.Tensor()
 
@@ -16,15 +16,3 @@ class RequestInfo:
     n: int = 1
 
 
-class MambaCache:
-    def __init__(
-        self,
-        request_info: RequestInfo,
-        layer_idx2mamba_cache: Optional[Dict[int, MambaCacheParams]] = None
-    ) -> None:
-        self.request_info = request_info
-        if layer_idx2mamba_cache is None:
-            self.layer_idx2mamba_cache = defaultdict(MambaCacheParams)
-        else:
-            self.layer_idx2mamba_cache = layer_idx2mamba_cache
-
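
As a quick orientation (not part of the diff): with the per-request MambaCache class gone, the remaining MambaCacheParams is just a thin view onto whole-batch state tensors plus an is_prompt flag that replaces the old seqlen_offset > 0 check. A minimal usage sketch with made-up sizes, assuming this fork of vLLM is installed:

import torch
from vllm.model_executor.mamba_metadata import MambaCacheParams

# Hypothetical sizes for illustration only.
batch_size, inner_dim, d_conv, d_state = 4, 8192, 4, 16

# Whole-batch state for one layer; the batch dimension comes first.
conv_state = torch.zeros(batch_size, inner_dim, d_conv)
ssm_state = torch.zeros(batch_size, inner_dim, d_state)

prefill = MambaCacheParams(is_prompt=True, conv_state=conv_state, ssm_state=ssm_state)
decode = MambaCacheParams(is_prompt=False, conv_state=conv_state, ssm_state=ssm_state)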

vllm/model_executor/models/jamba.py

Lines changed: 38 additions & 52 deletions
@@ -33,7 +33,7 @@
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
-from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
+from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
 from mamba_ssm.ops.triton.selective_state_update import selective_state_update
 from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
 
@@ -114,7 +114,7 @@ def mamba_forward(self, hidden_states: torch.Tensor, cache_params: MambaCachePar
 
         # 2. Convolution sequence transformation
         conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
-        if cache_params is not None and cache_params.seqlen_offset > 0:
+        if cache_params is not None and not cache_params.is_prompt:
             hidden_states = causal_conv1d_update(
                 hidden_states.squeeze(-1),
                 cache_params.conv_state,
@@ -154,7 +154,7 @@ def mamba_forward(self, hidden_states: torch.Tensor, cache_params: MambaCachePar
         A = -torch.exp(self.A_log.float())
         # 3.c perform the recurrence y ← SSM(A, B, C)(x)
         time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None
-        if cache_params is not None and cache_params.seqlen_offset > 0:
+        if cache_params is not None and not cache_params.is_prompt:
             scan_outputs = selective_state_update(
                 cache_params.ssm_state,
                 hidden_states[..., 0],
@@ -187,50 +187,14 @@ def mamba_forward(self, hidden_states: torch.Tensor, cache_params: MambaCachePar
         contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))
         return contextualized_states
 
-    def forward(self, hidden_states: torch.Tensor, input_metadata: InputMetadata):
-        if input_metadata.is_prompt:
-            batch_size = hidden_states.shape[0]
-            conv_cache = torch.zeros(
-                batch_size,
-                self.config.mamba_expand * self.config.hidden_size,
-                self.config.mamba_d_conv,
-                device=hidden_states.device,
-                dtype=hidden_states.dtype
-            )
-            ssm_cache = torch.zeros(
-                batch_size,
-                self.config.mamba_expand * self.config.hidden_size,
-                self.config.mamba_d_state,
-                device=hidden_states.device,
-                dtype=hidden_states.dtype
-            )
-            cache = MambaCacheParams(0, conv_cache, ssm_cache)
-        else:
-            for mamba_cache_request in input_metadata.mamba_cache_batch:
-                # check if batch size of cache fits "n"
-                n = mamba_cache_request.request_info.n
-                if mamba_cache_request.layer_idx2mamba_cache[self.layer_idx].conv_state.shape[0] < n:
-                    expanded_dims_conv = (n, *mamba_cache_request.layer_idx2mamba_cache[self.layer_idx].conv_state.shape[1:])
-                    conv_state = mamba_cache_request.layer_idx2mamba_cache[self.layer_idx].conv_state.expand(*expanded_dims_conv)
-                    expanded_dims_ssm = (n, *mamba_cache_request.layer_idx2mamba_cache[self.layer_idx].ssm_state.shape[1:])
-                    ssm_state = mamba_cache_request.layer_idx2mamba_cache[self.layer_idx].ssm_state.expand(*expanded_dims_ssm)
-                    mamba_cache_request.layer_idx2mamba_cache[self.layer_idx].conv_state = conv_state
-                    mamba_cache_request.layer_idx2mamba_cache[self.layer_idx].ssm_state = ssm_state
-
-            # mamba requires concatenated cache
-            conv_state = torch.concat([req.layer_idx2mamba_cache[self.layer_idx].conv_state for req in input_metadata.mamba_cache_batch], dim=0)
-            ssm_state = torch.concat([req.layer_idx2mamba_cache[self.layer_idx].ssm_state for req in input_metadata.mamba_cache_batch], dim=0)
-            cache = MambaCacheParams(1, conv_state, ssm_state)
+    def forward(self, hidden_states: torch.Tensor, input_metadata: InputMetadata, conv_state: torch.Tensor, ssm_state: torch.Tensor):
+        cache = MambaCacheParams(
+            input_metadata.is_prompt,
+            conv_state=conv_state[self.layer_idx],
+            ssm_state=ssm_state[self.layer_idx]
+        )
         hidden_states = self.mamba_forward(hidden_states, cache_params=cache)
 
-        # split cache back to individual requests
-        sample_id = 0
-        for req_mamba_metadata in input_metadata.mamba_cache_batch:
-            n = 1 if input_metadata.is_prompt else req_mamba_metadata.request_info.n
-            req_mamba_metadata.layer_idx2mamba_cache[self.layer_idx].conv_state=cache.conv_state[sample_id:sample_id + n]
-            req_mamba_metadata.layer_idx2mamba_cache[self.layer_idx].ssm_state=cache.ssm_state[sample_id:sample_id + n]
-            sample_id += n
-
         return hidden_states
 
 
@@ -352,6 +316,8 @@ def forward(self,
                 hidden_states: torch.Tensor,
                 input_metadata: InputMetadata,
                 residual: Optional[torch.Tensor],
+                conv_state: torch.Tensor,
+                ssm_state: torch.Tensor,
                 **kwargs):
 
         if residual is None:
@@ -360,7 +326,12 @@ def forward(self,
         else:
             hidden_states, residual = self.input_layernorm(hidden_states, residual)
 
-        hidden_states = self.mamba(hidden_states, input_metadata)
+        hidden_states = self.mamba(
+            hidden_states,
+            input_metadata,
+            conv_state,
+            ssm_state
+        )
         # Fully Connected
         hidden_states, residual = self.pre_moe_layernorm(
             hidden_states, residual)
@@ -433,7 +404,8 @@ def self_attention(self,
                        positions: torch.Tensor,
                        hidden_states: torch.Tensor,
                        kv_cache: KVCache,
-                       input_metadata: InputMetadata) -> torch.Tensor:
+                       input_metadata: InputMetadata,
+                       **kwargs) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         # TODO - add embedding flag
@@ -450,7 +422,8 @@ def forward(
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
-        residual: Optional[torch.Tensor]):
+        residual: Optional[torch.Tensor],
+        **kwargs):
         if residual is None:
             residual = hidden_states
             hidden_states = self.input_layernorm(hidden_states)
@@ -524,6 +497,8 @@ def forward(
         positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
+        conv_state: torch.Tensor,
+        ssm_state: torch.Tensor
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         residual = None
@@ -534,7 +509,10 @@ def forward(
                 hidden_states=hidden_states,
                 kv_cache=kv_caches[i],
                 input_metadata=input_metadata,
-                residual=residual)
+                residual=residual,
+                conv_state=conv_state,
+                ssm_state=ssm_state
+            )
         hidden_states, _ = self.final_layernorm(hidden_states, residual)
         return hidden_states
 
@@ -593,9 +571,17 @@ def forward(
         positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-    ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   input_metadata)
+        conv_state: torch.Tensor,
+        ssm_state: torch.Tensor
+    ):
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            kv_caches,
+            input_metadata,
+            conv_state,
+            ssm_state
+        )
         return hidden_states
 
     def sample(
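
A rough sketch of the new calling convention (illustrative, not code from this PR): the caller preallocates one conv buffer and one ssm buffer spanning all Mamba layers and the maximum batch size, then threads the same two tensors through every forward call so that CUDA graph capture sees stable addresses. Per-sample shapes follow the deleted allocation code above (mamba_expand * hidden_size by mamba_d_conv / mamba_d_state); the leading layer dimension is inferred from the conv_state[self.layer_idx] indexing, and all concrete numbers here are assumptions.

import torch

# All values below are illustrative assumptions, not taken from the Jamba config.
num_mamba_layers = 28               # layers that carry Mamba state
max_batch_size = 8                  # largest batch a captured CUDA graph will serve
intermediate_size = 2 * 4096        # mamba_expand * hidden_size
d_conv, d_state = 4, 16             # mamba_d_conv, mamba_d_state

device = "cuda" if torch.cuda.is_available() else "cpu"
conv_state = torch.zeros(num_mamba_layers, max_batch_size, intermediate_size, d_conv,
                         dtype=torch.float16, device=device)
ssm_state = torch.zeros(num_mamba_layers, max_batch_size, intermediate_size, d_state,
                        dtype=torch.float16, device=device)

# Matching the new top-level signature in this diff (sketch only):
# hidden_states = model(input_ids, positions, kv_caches, input_metadata,
#                       conv_state, ssm_state)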

vllm/worker/cache_engine.py

Lines changed: 1 addition & 1 deletion
@@ -93,7 +93,7 @@ def get_cache_block_size(
 
         if is_mamba:
             attention_period = model_config.hf_config.attn_layer_period
-            num_layers = num_layers // attention_period
+            num_layers = max(num_layers // attention_period, 1)
 
         key_cache_block = cache_config.block_size * num_heads * head_size
         value_cache_block = key_cache_block
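
The max(..., 1) guard corresponds to the "min 1 attention layer" item in the commit message: with plain integer division, a configuration whose layer count is smaller than attn_layer_period would report zero attention layers and therefore a zero-sized KV cache block. A tiny worked example with made-up numbers:

# Illustrative numbers only.
def kv_cache_layers(num_layers: int, attention_period: int) -> int:
    # Old behaviour: num_layers // attention_period, which can be 0.
    # New behaviour: never fewer than one attention layer's worth of KV cache.
    return max(num_layers // attention_period, 1)

print(kv_cache_layers(32, 8))  # 4
print(kv_cache_layers(4, 8))   # old code gave 0; the guard keeps 1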
