Skip to content

Commit

Permalink
Fix: Jamba batched generation (#32914)
Browse files Browse the repository at this point in the history
* init fix

* fix mask during cached forward, move mask related stuff to own function

* adjust tests as left padding does not change logits as much anymore + batch gen (with todo on logits comp)

* revert overwriting new integration tests

* move some comments to docstring
  • Loading branch information
vasqu authored Aug 28, 2024
1 parent 386931d commit 3bfd3e4
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 56 deletions.
54 changes: 46 additions & 8 deletions src/transformers/models/jamba/modeling_jamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,12 @@ def __init__(self, config: JambaConfig, layer_idx):
" https://github.com/Dao-AILab/causal-conv1d. If you want to use the naive implementation, set `use_mamba_kernels=False` in the model config"
)

def cuda_kernels_forward(self, hidden_states: torch.Tensor, cache_params: HybridMambaAttentionDynamicCache = None):
def cuda_kernels_forward(
self,
hidden_states: torch.Tensor,
cache_params: HybridMambaAttentionDynamicCache = None,
attention_mask: Optional[torch.LongTensor] = None,
):
batch_size, seq_len, _ = hidden_states.shape
use_precomputed_states = (
cache_params is not None
Expand All @@ -666,6 +671,9 @@ def cuda_kernels_forward(self, hidden_states: torch.Tensor, cache_params: Hybrid
# inner layernorms which isn't supported by this fused kernel
hidden_states, gate = projected_states.chunk(2, dim=1)

if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)

# 2. Convolution sequence transformation
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
if use_precomputed_states:
Expand All @@ -683,6 +691,9 @@ def cuda_kernels_forward(self, hidden_states: torch.Tensor, cache_params: Hybrid
cache_params.conv_states[self.layer_idx].copy_(conv_states)
hidden_states = causal_conv1d_fn(hidden_states, conv_weights, self.conv1d.bias, activation=self.activation)

if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)

# 3. State Space Model sequence transformation
# 3.a. input varying initialization of time_step, B and C
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
Expand Down Expand Up @@ -742,14 +753,17 @@ def cuda_kernels_forward(self, hidden_states: torch.Tensor, cache_params: Hybrid
return contextualized_states

# fmt: off
def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCache = None):
def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCache = None, attention_mask: Optional[torch.LongTensor] = None):
batch_size, seq_len, _ = input_states.shape
dtype = input_states.dtype
# 1. Gated MLP's linear projection
projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len]
hidden_states, gate = projected_states.chunk(2, dim=1)

use_cache = isinstance(cache_params,HybridMambaAttentionDynamicCache)
if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)

use_cache = isinstance(cache_params, HybridMambaAttentionDynamicCache)
# 2. Convolution sequence transformation
if use_cache and cache_params.ssm_states[self.layer_idx].shape[0] == batch_size:
if self.training:
Expand Down Expand Up @@ -784,6 +798,9 @@ def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCa
)
hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]

if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)

# 3. State Space Model sequence transformation
# 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
Expand Down Expand Up @@ -821,14 +838,19 @@ def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCa
return contextualized_states
# fmt: on

def forward(self, hidden_states, cache_params: HybridMambaAttentionDynamicCache = None):
def forward(
self,
hidden_states,
cache_params: HybridMambaAttentionDynamicCache = None,
attention_mask: Optional[torch.LongTensor] = None,
):
if self.use_fast_kernels:
if not is_fast_path_available or "cuda" not in self.x_proj.weight.device.type:
raise ValueError(
"Fast Mamba kernels are not available. Make sure to they are installed and that the mamba module is on a CUDA device"
)
return self.cuda_kernels_forward(hidden_states, cache_params)
return self.slow_forward(hidden_states, cache_params)
return self.cuda_kernels_forward(hidden_states, cache_params, attention_mask)
return self.slow_forward(hidden_states, cache_params, attention_mask)


# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Jamba
Expand Down Expand Up @@ -1040,6 +1062,7 @@ def forward(
hidden_states = self.mamba(
hidden_states=hidden_states,
cache_params=past_key_value,
attention_mask=attention_mask,
)
self_attn_weights = None

Expand Down Expand Up @@ -1279,20 +1302,24 @@ def forward(
position_ids = cache_position.unsqueeze(0)

causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
mamba_mask = self._update_mamba_mask(attention_mask, cache_position)

all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_router_logits = () if output_router_logits else None

for decoder_layer in self.layers:
# Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention)
layer_mask = mamba_mask if isinstance(decoder_layer, JambaMambaDecoderLayer) else causal_mask

if output_hidden_states:
all_hidden_states += (hidden_states,)

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_mask,
layer_mask,
position_ids,
past_key_values,
output_attentions,
Expand All @@ -1303,7 +1330,7 @@ def forward(
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
attention_mask=layer_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
Expand Down Expand Up @@ -1384,6 +1411,17 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position):

return causal_mask

def _update_mamba_mask(self, attention_mask, cache_position):
"""
No need for zeroing states when
1. Cached forward
2. Attending to all inputs
"""
mamba_mask = attention_mask
if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)):
mamba_mask = None
return mamba_mask


# Adapted from transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM with MIXTRAL->JAMBA, Mixtral->Jamba
class JambaForCausalLM(JambaPreTrainedModel):
Expand Down
52 changes: 4 additions & 48 deletions tests/models/jamba/test_modeling_jamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,51 +458,6 @@ def test_attention_outputs(self):
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)

def test_left_padding_compatibility(self):
r"""
Overriding the test_left_padding_compatibility test as the mamba layers accentuate the numerical differences
effect of the left padding discussed in the issue in the note. Using a more permissive tolerance value.
"""
import inspect
# NOTE: left-padding results in small numerical differences. This is expected.
# See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535

# First, filter out models that don't support left padding - generative and decoder-only.
# Jamba is a decoder-only architecture
decoder_only_classes = self.all_generative_model_classes

# Then, test left-padding
def _prepare_model_kwargs(input_ids, attention_mask, signature):
model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask}
if "position_ids" in signature:
position_ids = torch.cumsum(attention_mask, dim=-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
model_kwargs["position_ids"] = position_ids
if "cache_position" in signature:
cache_position = torch.arange(input_ids.shape[-1], device=torch_device)
model_kwargs["cache_position"] = cache_position
return model_kwargs

for model_class in decoder_only_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
model = model_class(config).to(torch_device).eval()
signature = inspect.signature(model.forward).parameters.keys()

# Without padding
model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, signature)
next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :]

# With left-padding (length 32)
pad_size = (input_ids.shape[0], 32)
padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * config.pad_token_id
padded_input_ids = torch.cat((padding, input_ids), dim=1)
padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1)
model_kwargs = _prepare_model_kwargs(padded_input_ids, padded_attention_mask, signature)
next_logits_with_padding = model(**model_kwargs).logits[:, -1, :]

# They should result in very similar logits
self.assertTrue(torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=3e-3))

@require_flash_attn
@require_torch_gpu
@require_bitsandbytes
Expand Down Expand Up @@ -692,7 +647,7 @@ def test_simple_generate(self):
EXPECTED_LOGITS_NO_GRAD = torch.tensor(
[
0.0134, -0.2197, 0.0396, -0.1011, 0.0459, 0.2793, -0.1465, 0.1660,
-0.2930, -0.0278, 0.0269, -0.5586, -0.2109, -0.1426, -0.1553, 0.1279,
-0.2930, -0.0278, 0.0269, -0.5586, -0.2109, -0.1426, -0.1553, 0.1279,
0.0713, 0.2246, 0.1660, -0.2314, -0.1187, -0.1162, -0.1377, 0.0292,
0.1245, 0.2275, 0.0374, 0.1089, -0.1348, -0.2305, 0.1484, -0.3906,
0.1709, -0.4590, -0.0447, 0.2422, 0.1592, -0.1855, 0.2441, -0.0562
Expand Down Expand Up @@ -737,10 +692,11 @@ def test_simple_batched_generate_with_padding(self):
with torch.no_grad():
logits = self.model(input_ids=inputs["input_ids"]).logits

# TODO fix logits
EXPECTED_LOGITS_NO_GRAD_0 = torch.tensor(
[
0.0166, -0.2227, 0.0396, -0.1035, 0.0459, 0.2754, -0.1445, 0.1641,
-0.2910, -0.0273, 0.0227, -0.5547, -0.2139, -0.1396, -0.1582, 0.1289,
-0.2910, -0.0273, 0.0227, -0.5547, -0.2139, -0.1396, -0.1582, 0.1289,
0.0713, 0.2256, 0.1699, -0.2295, -0.1182, -0.1167, -0.1387, 0.0261,
0.1270, 0.2285, 0.0403, 0.1108, -0.1318, -0.2334, 0.1455, -0.3945,
0.1729, -0.4609, -0.0410, 0.2412, 0.1572, -0.1895, 0.2402, -0.0583
Expand All @@ -749,7 +705,7 @@ def test_simple_batched_generate_with_padding(self):

EXPECTED_LOGITS_NO_GRAD_1 = torch.tensor(
[
-0.1318, 0.2354, -0.4160, -0.0325, -0.0461, 0.0342, 0.2578, 0.0874,
-0.1318, 0.2354, -0.4160, -0.0325, -0.0461, 0.0342, 0.2578, 0.0874,
0.1484, 0.2266, -0.1182, -0.1396, -0.1494, -0.1089, -0.0019, -0.2852,
0.1973, -0.2676, 0.0586, -0.1992, -0.2520, -0.1147, -0.1973, 0.2129,
0.0520, 0.1699, 0.1816, 0.1289, 0.1699, -0.1216, -0.2656, -0.2891,
Expand Down

0 comments on commit 3bfd3e4

Please sign in to comment.