Compile compatibility for decoder-only models #32617

Merged 6 commits on Sep 9, 2024
Changes from 3 commits
13 changes: 7 additions & 6 deletions src/transformers/generation/utils.py
@@ -1639,13 +1639,14 @@ def _tensor_or_none(token, device=None):

# Set pad token if unset (and there are conditions to do so)
if pad_token_tensor is None and eos_token_tensor is not None:
if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
logger.warning(
"The attention mask and the pad token id were not set. As a consequence, you may observe "
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
)
if not is_torchdynamo_compiling():
if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
logger.warning(
"The attention mask and the pad token id were not set. As a consequence, you may observe "
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
)
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")
pad_token_tensor = eos_token_tensor[0]
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")

# Sanity checks/warnings
if self.config.is_encoder_decoder and decoder_start_token_tensor is None:
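For context, here is a minimal sketch of the control flow introduced above in `generation/utils.py` (the function and variable names are simplified stand-ins, not the actual transformers internals): under `torch.compile`, the warnings are skipped because logging would break the graph, but `pad_token_id` still falls back to the first `eos_token_id`.

```python
import torch

def resolve_pad_token(pad_token_tensor, eos_token_tensor, has_attention_mask, compiling):
    # Mirror of the guard above: warnings are emitted only outside of compilation,
    # while the pad-token fallback itself happens in both cases.
    if pad_token_tensor is None and eos_token_tensor is not None:
        if not compiling:
            if not has_attention_mask:
                print("warning: the attention mask and the pad token id were not set")
            print(f"warning: setting `pad_token_id` to `eos_token_id`: {eos_token_tensor[0]}")
        pad_token_tensor = eos_token_tensor[0]
    return pad_token_tensor

# Eager call: both warnings fire and the pad token becomes the eos token.
pad = resolve_pad_token(None, torch.tensor([2]), has_attention_mask=False, compiling=False)
print(pad)  # tensor(2)
```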
35 changes: 25 additions & 10 deletions src/transformers/models/bloom/modeling_bloom.py
@@ -326,24 +326,21 @@ def forward(

# reshape qkv for further computations
query_layer = query_layer.reshape(batch_size * self.num_heads, -1, self.head_dim)
key_layer = key_layer.reshape(batch_size * self.num_heads, -1, self.head_dim).transpose(1, 2)
key_layer = key_layer.reshape(batch_size * self.num_heads, -1, self.head_dim).transpose(-1, -2)
value_layer = value_layer.reshape(batch_size * self.num_heads, -1, self.head_dim)

kv_length = cache_position[-1] + 1 # cache position is 0-indexed while length should start from 1

# [batch_size * num_heads, q_length, kv_length]
# we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11
matmul_result = alibi.baddbmm(
attention_scores = alibi.baddbmm(
batch1=query_layer,
batch2=key_layer,
beta=self.beta,
alpha=self.inv_norm_factor,
)

# change view to [batch_size, num_heads, q_length, kv_length]
attn_weights = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)
attn_weights = attention_scores.view(batch_size, self.num_heads, q_length, -1)
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, :kv_length]
causal_mask = attention_mask[:, :, :, : key_layer.shape[-1]]
attn_weights = attn_weights + causal_mask

# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype
@@ -356,7 +353,7 @@ def forward(
attention_probs = attention_probs * head_mask

# change view [batch_size x num_heads, q_length, kv_length]
attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length)
attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, -1)

# matmul: [batch_size * num_heads, q_length, head_dim]
context_layer = torch.bmm(attention_probs_reshaped, value_layer)
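A hedged sketch of the shape change above (tensor sizes are made up for illustration): the kv dimension of the attention weights is now inferred with `-1` and the causal mask is sliced with `key_layer.shape[-1]`, both of which are static under `torch.compile`, instead of deriving `kv_length` from `cache_position`, which is data dependent.

```python
import torch

batch_size, num_heads, q_length, head_dim, kv_length = 2, 4, 1, 8, 16

query_layer = torch.randn(batch_size * num_heads, q_length, head_dim)
key_layer = torch.randn(batch_size * num_heads, kv_length, head_dim).transpose(-1, -2)
alibi = torch.zeros(batch_size * num_heads, q_length, kv_length)

# beta/alpha values are illustrative; in the model they come from the attention module.
attention_scores = alibi.baddbmm(
    batch1=query_layer,
    batch2=key_layer,
    beta=1.0,
    alpha=1.0 / head_dim**0.5,
)

# The kv dimension is inferred instead of being computed from cache_position.
attn_weights = attention_scores.view(batch_size, num_heads, q_length, -1)

attention_mask = torch.zeros(batch_size, 1, q_length, kv_length)
causal_mask = attention_mask[:, :, :, : key_layer.shape[-1]]  # static slice bound
attn_weights = attn_weights + causal_mask
print(attn_weights.shape)  # torch.Size([2, 4, 1, 16])
```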
@@ -496,6 +493,8 @@ class BloomPreTrainedModel(PreTrainedModel):
_no_split_modules = ["BloomBlock"]
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
_supports_static_cache = True
_supports_quantized_cache = True

def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
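With `_supports_static_cache = True`, the model advertises that generation can run with a static-shape KV cache, which is the prerequisite for the compile-friendly path. A hedged usage sketch follows; the checkpoint name and generation arguments are illustrative, not taken from this diff, and it assumes a transformers version where `cache_implementation="static"` is accepted by `generate`.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")

# A static cache keeps tensor shapes fixed across decoding steps, so the compiled
# forward does not get recompiled for every new sequence length.
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer("The quick brown fox", return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=8, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```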
@@ -894,9 +893,25 @@ def prepare_inputs_for_generation(

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds}
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
Review comment (Member):

missing `# Copied from ...`?

Reply (Member Author):

Not really, bloom has alibi and needs a 2D attention mask for that, so we can't expand it to 4D; instead we choose to append zeros to the attention mask to keep its shape static.
else:
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead"`, as otherwise the
# input `input_ids` would have a varying stride during decoding. Here, simply using `.contiguous()` is not sufficient as in
# the batch size = 1 case, `input_ids` is already contiguous but with varying stride, which retriggers a capture.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

# This part differs from other models because BLOOM needs a 2D mask to construct the alibi tensor
# The only difference is the use of a 2D mask instead of a 4D one, but the shape will be static
if isinstance(past_key_values, StaticCache) and attention_mask is not None:
target_length = past_key_values.get_max_length()
batch_size, seq_length = attention_mask.shape
diff = target_length - seq_length

new_attn_mask = torch.zeros(batch_size, diff, device=attention_mask.device, dtype=attention_mask.dtype)
attention_mask = torch.cat(
[attention_mask, new_attn_mask],
dim=-1,
)

model_inputs.update(
{
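The mask handling above can be illustrated with a small standalone sketch (shapes are illustrative): because BLOOM builds alibi from a 2D attention mask, the mask is right-padded with zeros up to the static cache length so that its shape stays constant across decoding steps, in line with the review discussion above.

```python
import torch

batch_size, seq_length, target_length = 2, 5, 16  # target_length ~ StaticCache max length

attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long)
diff = target_length - seq_length

# Pad with zeros (masked positions) instead of expanding to a 4D mask,
# so the 2D mask keeps a static shape for torch.compile.
new_attn_mask = torch.zeros(batch_size, diff, dtype=attention_mask.dtype)
attention_mask = torch.cat([attention_mask, new_attn_mask], dim=-1)
print(attention_mask.shape)  # torch.Size([2, 16])
```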
67 changes: 36 additions & 31 deletions src/transformers/models/falcon/configuration_falcon.py
@@ -77,13 +77,42 @@ class FalconConfig(PretrainedConfig):
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
these scaling strategies behave:
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
experimental feature, subject to breaking API changes in future versions.
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type
and you expect the model to work on a longer `max_position_embeddings`, we recommend updating this value
accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied to the attention
computation. If unspecified, it defaults to the value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2.
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (>
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2.
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to the low-frequency components of the RoPE.
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to the high-frequency components of the RoPE.
bos_token_id (`int`, *optional*, defaults to 11):
The id of the "beginning-of-sequence" token.
eos_token_id (`int`, *optional*, defaults to 11):
@@ -167,7 +196,6 @@ def __init__(
self.ffn_hidden_size = hidden_size * 4
else:
self.ffn_hidden_size = ffn_hidden_size
self._rope_scaling_validation()

super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

@@ -178,26 +206,3 @@ def head_dim(self):
@property
def rotary(self):
return not self.alibi

def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
"""
if self.rope_scaling is None:
return

if self.alibi:
raise ValueError("`rope_scaling` is not supported when `alibi` is `True`.")

if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
raise ValueError(
"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}"
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
raise ValueError(
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")