
Fix: Jamba batched generation #32914

Merged: 5 commits, Aug 28, 2024
54 changes: 46 additions & 8 deletions src/transformers/models/jamba/modeling_jamba.py
@@ -648,7 +648,12 @@ def __init__(self, config: JambaConfig, layer_idx):
" https://github.com/Dao-AILab/causal-conv1d. If you want to use the naive implementation, set `use_mamba_kernels=False` in the model config"
)

def cuda_kernels_forward(self, hidden_states: torch.Tensor, cache_params: HybridMambaAttentionDynamicCache = None):
def cuda_kernels_forward(
self,
hidden_states: torch.Tensor,
cache_params: HybridMambaAttentionDynamicCache = None,
attention_mask: Optional[torch.LongTensor] = None,
):
batch_size, seq_len, _ = hidden_states.shape
use_precomputed_states = (
cache_params is not None
@@ -665,6 +670,9 @@ def cuda_kernels_forward(self, hidden_states: torch.Tensor, cache_params: Hybrid
# inner layernorms which isn't supported by this fused kernel
hidden_states, gate = projected_states.chunk(2, dim=1)

if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)

# 2. Convolution sequence transformation
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
if use_precomputed_states:
@@ -682,6 +690,9 @@ def cuda_kernels_forward(self, hidden_states: torch.Tensor, cache_params: Hybrid
cache_params.conv_states[self.layer_idx].copy_(conv_states)
hidden_states = causal_conv1d_fn(hidden_states, conv_weights, self.conv1d.bias, activation=self.activation)

if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)

# 3. State Space Model sequence transformation
# 3.a. input varying initialization of time_step, B and C
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
@@ -741,14 +752,17 @@ def cuda_kernels_forward(self, hidden_states: torch.Tensor, cache_params: Hybrid
return contextualized_states

# fmt: off
def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCache = None):
def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCache = None, attention_mask: Optional[torch.LongTensor] = None):
batch_size, seq_len, _ = input_states.shape
dtype = input_states.dtype
# 1. Gated MLP's linear projection
projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len]
hidden_states, gate = projected_states.chunk(2, dim=1)

use_cache = isinstance(cache_params,HybridMambaAttentionDynamicCache)
if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)

use_cache = isinstance(cache_params, HybridMambaAttentionDynamicCache)
# 2. Convolution sequence transformation
if use_cache and cache_params.ssm_states[self.layer_idx].shape[0] == batch_size:
if self.training:
@@ -783,6 +797,9 @@ def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCa
)
hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]

if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)

# 3. State Space Model sequence transformation
# 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
@@ -820,14 +837,19 @@ def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCa
return contextualized_states
# fmt: on

def forward(self, hidden_states, cache_params: HybridMambaAttentionDynamicCache = None):
def forward(
self,
hidden_states,
cache_params: HybridMambaAttentionDynamicCache = None,
attention_mask: Optional[torch.LongTensor] = None,
):
if self.use_fast_kernels:
if not is_fast_path_available or "cuda" not in self.x_proj.weight.device.type:
raise ValueError(
"Fast Mamba kernels are not available. Make sure to they are installed and that the mamba module is on a CUDA device"
)
return self.cuda_kernels_forward(hidden_states, cache_params)
return self.slow_forward(hidden_states, cache_params)
return self.cuda_kernels_forward(hidden_states, cache_params, attention_mask)
return self.slow_forward(hidden_states, cache_params, attention_mask)


# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Jamba
@@ -1039,6 +1061,7 @@ def forward(
hidden_states = self.mamba(
hidden_states=hidden_states,
cache_params=past_key_value,
attention_mask=attention_mask,
)
self_attn_weights = None

@@ -1278,20 +1301,24 @@ def forward(
position_ids = cache_position.unsqueeze(0)

causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
mamba_mask = self._update_mamba_mask(attention_mask, cache_position)

all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_router_logits = () if output_router_logits else None

for decoder_layer in self.layers:
# Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention)
layer_mask = mamba_mask if isinstance(decoder_layer, JambaMambaDecoderLayer) else causal_mask

if output_hidden_states:
all_hidden_states += (hidden_states,)

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_mask,
layer_mask,
position_ids,
past_key_values,
output_attentions,
@@ -1302,7 +1329,7 @@ def forward(
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
attention_mask=layer_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
@@ -1383,6 +1410,17 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position):

return causal_mask

def _update_mamba_mask(self, attention_mask, cache_position):
"""
No need for zeroing states when
1. Cached forward
2. Attending to all inputs
"""
mamba_mask = attention_mask
if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)):
Member:

I suspect this line will fail at compilation time (data-dependent conditional branch). Can you confirm, i.e. try running a compiled forward pass?

If it fails, we can add a compile guard, i.e. start the if with not is_torchdynamo_compiling()
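
For illustration only (not part of the PR or the thread), a sketch of what that guard could look like if it had been needed, assuming is_torchdynamo_compiling from transformers.utils and mirroring the _update_mamba_mask method added above:

import torch
from transformers.utils import is_torchdynamo_compiling

def _update_mamba_mask(self, attention_mask, cache_position):
    # Hypothetical compile-guarded variant: skip the data-dependent branch
    # entirely while torch.compile (dynamo) is tracing.
    mamba_mask = attention_mask
    if not is_torchdynamo_compiling() and (
        cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1))
    ):
        mamba_mask = None
    return mamba_mask

As the exchange below shows, compilation worked without it, so the guard was not added.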

Contributor Author:

Tested via the following mini script:

import torch
from transformers import JambaForCausalLM, AutoTokenizer

model_id = "ai21labs/Jamba-tiny-random"
model = JambaForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True).to("cuda")
model = torch.compile(model)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# tested on both, batched or non-batched input
#input = tokenizer(["Hey how are you doing on this lovely evening?", "What is the purpose of life?"], padding=True, return_tensors="pt").to("cuda")
input = tokenizer(["What is the purpose of life?"], padding=True, return_tensors="pt").to("cuda")

# tested on both, forward call or generate
out = model(**input)
#out = model.generate(**input, do_sample=False, max_new_tokens=10)

Haven't encountered any compilation errors locally, so seems to be fine. Is this what you had in mind to test compilation?

Member:

Yes, that's it!

Perfect, thank you for confirming :)

mamba_mask = None
return mamba_mask
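
As an aside (not part of the diff), a minimal sketch of what the returned 2D mask does inside the mamba mixer: after in_proj, hidden_states has shape [batch, intermediate_size, seq_len], so unsqueezing the mask on dim 1 broadcasts the zeroing across every channel of the padded positions. The shapes and values below are illustrative assumptions:

import torch

batch_size, intermediate_size, seq_len = 2, 4, 3
hidden_states = torch.randn(batch_size, intermediate_size, seq_len)
attention_mask = torch.tensor([[1, 1, 1],
                               [0, 1, 1]])  # second sequence is left-padded by one token

# Same operation as "hidden_states * attention_mask.unsqueeze(1)" in the mixer above
masked = hidden_states * attention_mask.unsqueeze(1)

assert masked.shape == hidden_states.shape
assert torch.all(masked[1, :, 0] == 0)  # the padded position carries no signal into the conv/SSM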


# Adapted from transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM with MIXTRAL->JAMBA, Mixtral->Jamba
class JambaForCausalLM(JambaPreTrainedModel):
52 changes: 4 additions & 48 deletions tests/models/jamba/test_modeling_jamba.py
@@ -458,51 +458,6 @@ def test_attention_outputs(self):
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)

def test_left_padding_compatibility(self):
r"""
Overriding the test_left_padding_compatibility test as the mamba layers accentuate the numerical differences
effect of the left padding discussed in the issue in the note. Using a more permissive tolerance value.
"""
import inspect
# NOTE: left-padding results in small numerical differences. This is expected.
# See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535

# First, filter out models that don't support left padding - generative and decoder-only.
# Jamba is a decoder-only architecture
decoder_only_classes = self.all_generative_model_classes

# Then, test left-padding
def _prepare_model_kwargs(input_ids, attention_mask, signature):
model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask}
if "position_ids" in signature:
position_ids = torch.cumsum(attention_mask, dim=-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
model_kwargs["position_ids"] = position_ids
if "cache_position" in signature:
cache_position = torch.arange(input_ids.shape[-1], device=torch_device)
model_kwargs["cache_position"] = cache_position
return model_kwargs

for model_class in decoder_only_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
model = model_class(config).to(torch_device).eval()
signature = inspect.signature(model.forward).parameters.keys()

# Without padding
model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, signature)
next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :]

# With left-padding (length 32)
pad_size = (input_ids.shape[0], 32)
padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * config.pad_token_id
padded_input_ids = torch.cat((padding, input_ids), dim=1)
padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1)
model_kwargs = _prepare_model_kwargs(padded_input_ids, padded_attention_mask, signature)
next_logits_with_padding = model(**model_kwargs).logits[:, -1, :]

# They should result in very similar logits
self.assertTrue(torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=3e-3))

Comment on lines -461 to -505
Contributor Author:

Passed locally without the higher rtol/atol. Will see if the CI agrees.

Contributor Author:

Seems like it does :D

Contributor Author:

Keeping it open for visibility: left padding works fine now; it was an issue with how padding had been handled in general for mamba-related models.
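
For context, a usage-level sketch of the behaviour this PR fixes, loosely adapted from the mini script in the thread above; the model id, padding side, and generation settings are assumptions for illustration rather than values taken from the test suite:

import torch
from transformers import AutoTokenizer, JambaForCausalLM

model_id = "ai21labs/Jamba-tiny-random"  # tiny checkpoint also used in the compile check above
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
model = JambaForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to("cuda")

# Prompts of different lengths force padding within the batch
prompts = ["Hey how are you doing on this lovely evening?", "What is the purpose of life?"]
inputs = tokenizer(prompts, padding=True, return_tensors="pt").to("cuda")

out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
print(tokenizer.batch_decode(out, skip_special_tokens=True))

The masking added in modeling_jamba.py zeroes the padded positions before both the convolution and the SSM branch, which is what keeps the padded sequence consistent with its unpadded counterpart.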

@require_flash_attn
@require_torch_gpu
@require_bitsandbytes
@@ -692,7 +647,7 @@ def test_simple_generate(self):
EXPECTED_LOGITS_NO_GRAD = torch.tensor(
[
0.0134, -0.2197, 0.0396, -0.1011, 0.0459, 0.2793, -0.1465, 0.1660,
-0.2930, -0.0278, 0.0269, -0.5586, -0.2109, -0.1426, -0.1553, 0.1279,
-0.2930, -0.0278, 0.0269, -0.5586, -0.2109, -0.1426, -0.1553, 0.1279,
0.0713, 0.2246, 0.1660, -0.2314, -0.1187, -0.1162, -0.1377, 0.0292,
0.1245, 0.2275, 0.0374, 0.1089, -0.1348, -0.2305, 0.1484, -0.3906,
0.1709, -0.4590, -0.0447, 0.2422, 0.1592, -0.1855, 0.2441, -0.0562
@@ -737,10 +692,11 @@ def test_simple_batched_generate_with_padding(self):
with torch.no_grad():
logits = self.model(input_ids=inputs["input_ids"]).logits

# TODO fix logits
Contributor Author:

For more visibility so that I don't forget about it.

EXPECTED_LOGITS_NO_GRAD_0 = torch.tensor(
[
0.0166, -0.2227, 0.0396, -0.1035, 0.0459, 0.2754, -0.1445, 0.1641,
-0.2910, -0.0273, 0.0227, -0.5547, -0.2139, -0.1396, -0.1582, 0.1289,
-0.2910, -0.0273, 0.0227, -0.5547, -0.2139, -0.1396, -0.1582, 0.1289,
0.0713, 0.2256, 0.1699, -0.2295, -0.1182, -0.1167, -0.1387, 0.0261,
0.1270, 0.2285, 0.0403, 0.1108, -0.1318, -0.2334, 0.1455, -0.3945,
0.1729, -0.4609, -0.0410, 0.2412, 0.1572, -0.1895, 0.2402, -0.0583
@@ -749,7 +705,7 @@

EXPECTED_LOGITS_NO_GRAD_1 = torch.tensor(
[
-0.1318, 0.2354, -0.4160, -0.0325, -0.0461, 0.0342, 0.2578, 0.0874,
-0.1318, 0.2354, -0.4160, -0.0325, -0.0461, 0.0342, 0.2578, 0.0874,
0.1484, 0.2266, -0.1182, -0.1396, -0.1494, -0.1089, -0.0019, -0.2852,
0.1973, -0.2676, 0.0586, -0.1992, -0.2520, -0.1147, -0.1973, 0.2129,
0.0520, 0.1699, 0.1816, 0.1289, 0.1699, -0.1216, -0.2656, -0.2891,