[Llama ROPE] Fix torch export and slowdowns in forward #29198

Merged 25 commits on Feb 28, 2024
Changes from all commits
6 changes: 4 additions & 2 deletions src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -563,10 +563,11 @@ def forward(self, x, seq_len=None):
)


# copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding.__init__
# TODO @gante bring compatibility back
class GPTNeoXLinearScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
"""GPTNeoXRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding.__init__
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)
@@ -586,7 +587,8 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
class GPTNeoXDynamicNTKScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
"""GPTNeoXRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding.__init__
# copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding.__init__
# TODO @gante no longer copied from
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)
38 changes: 18 additions & 20 deletions src/transformers/models/llama/modeling_llama.py
@@ -92,54 +92,55 @@ def forward(self, hidden_states):


class LlamaRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
super().__init__()
self.scaling_factor = scaling_factor
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
# For BC we register cos and sin cached
self.max_seq_len_cached = max_position_embeddings
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
t = t / self.scaling_factor
freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False)
self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False)

@property
def sin_cached(self):
logger.warning_once(
"The sin_cached attribute will be removed in 4.40. Bear in mind that its contents changed in v4.38. Use "
"the forward method of RoPE from now on instead."
"The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
"the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class"
)
return self._sin_cached

@property
def cos_cached(self):
logger.warning_once(
"The cos_cached attribute will be removed in 4.40. Bear in mind that its contents changed in v4.38. Use "
"the forward method of RoPE from now on instead."
"The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
"the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class"
)
return self._cos_cached

def forward(self, x, position_ids, seq_len=None):
if seq_len is not None:
logger.warning_once("The `seq_len` argument is deprecated and unused. It will be removed in v4.40.")
logger.warning_once("The `seq_len` argument is deprecated and unused. It will be removed in v4.39.")

# x: [bs, num_attention_heads, seq_len, head_size]
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos().to(dtype=x.dtype)
sin = emb.sin().to(dtype=x.dtype)
# backwards compatibility
self._cos_cached = cos
self._sin_cached = sin
return cos, sin
return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype)


class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
"""LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)

def forward(self, x, position_ids, seq_len=None):
# difference to the original RoPE: a scaling factor is applied to the position ids
position_ids = position_ids.float() / self.scaling_factor
@@ -150,10 +151,6 @@ def forward(self, x, position_ids, seq_len=None):
class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
"""LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)

def forward(self, x, position_ids, seq_len=None):
# difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
seq_len = torch.max(position_ids) + 1
@@ -367,6 +364,7 @@ def forward(
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask
if cache_position is not None:
causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
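
The central change in modeling_llama.py is that LlamaRotaryEmbedding.forward now computes cos and sin on the fly from position_ids via a batched matmul, instead of slicing cached buffers and mutating self._cos_cached / self._sin_cached inside forward; per the PR title, this is what restores torch.export compatibility and removes the slowdowns in forward. Below is a minimal, self-contained sketch of that shape logic with illustrative values (dim, base, bs, and seq_len are made up; this is not the module itself):

import torch

# Mirrors the broadcasting in the new LlamaRotaryEmbedding.forward, with toy sizes.
dim, base = 8, 10000
bs, seq_len = 2, 5

inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))  # [dim // 2]
position_ids = torch.arange(seq_len).unsqueeze(0).expand(bs, -1)                       # [bs, seq_len]

inv_freq_expanded = inv_freq[None, :, None].float().expand(bs, -1, 1)  # [bs, dim // 2, 1]
position_ids_expanded = position_ids[:, None, :].float()               # [bs, 1, seq_len]
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)    # [bs, seq_len, dim // 2]
emb = torch.cat((freqs, freqs), dim=-1)                                # [bs, seq_len, dim]
cos, sin = emb.cos(), emb.sin()
print(cos.shape, sin.shape)  # torch.Size([2, 5, 8]) torch.Size([2, 5, 8])

Because cos and sin are now pure functions of position_ids, forward no longer writes to sequence-length-dependent buffers, which is presumably what keeps torch.export and torch.compile graph capture happy.
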
54 changes: 53 additions & 1 deletion tests/models/llama/test_modeling_llama.py
@@ -20,10 +20,12 @@
import pytest
from parameterized import parameterized

from transformers import LlamaConfig, is_torch_available, set_seed
from transformers import LlamaConfig, StaticCache, is_torch_available, logging, set_seed
from transformers.testing_utils import (
CaptureLogger,
require_bitsandbytes,
require_flash_attn,
require_read_token,
require_torch,
require_torch_accelerator,
require_torch_gpu,
@@ -595,6 +597,56 @@ def test_model_13b_greedy_generation(self):
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)

@slow
@require_torch_gpu
@require_read_token
def test_compile_static_cache(self):
NUM_TOKENS_TO_GENERATE = 40
EXPECTED_TEXT_COMPLETION = [
"Simply put, the theory of relativity states that 1) the speed of light is constant, 2) the speed of light is the same for all observers, and 3) the laws of physics are the same for all observers.",
"My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
]
prompts = [
"Simply put, the theory of relativity states that ",
"My favorite all time favorite condiment is ketchup.",
]
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="right")
model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="sequential")
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

def decode_one_tokens(model, cur_token, input_pos, cache_position):
logits = model(
cur_token, position_ids=input_pos, cache_position=cache_position, return_dict=False, use_cache=True
)[0]
new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
return new_token

batch_size, seq_length = inputs["input_ids"].shape
with torch.no_grad():
model._setup_cache(StaticCache, 2, max_cache_len=4096)
cache_position = torch.arange(seq_length, device=torch_device)
generated_ids = torch.zeros(
batch_size, seq_length + NUM_TOKENS_TO_GENERATE + 1, dtype=torch.int, device=torch_device
)
generated_ids[:, cache_position] = inputs["input_ids"].to(torch_device).to(torch.int)

logits = model(**inputs, cache_position=cache_position, return_dict=False, use_cache=True)[0]
next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
generated_ids[:, seq_length] = next_token[:, 0]

decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)
cache_position = torch.tensor([seq_length + 1], device=torch_device)
for _ in range(1, NUM_TOKENS_TO_GENERATE):
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
with CaptureLogger(logging.get_logger(__name__)) as cl:
next_token = decode_one_tokens(model, next_token.clone(), None, cache_position)
self.assertNotIn("skipping cudagraphs due to", cl.out)
generated_ids[:, cache_position] = next_token.int()
cache_position += 1

text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)


@require_torch
class CodeLlamaIntegrationTest(unittest.TestCase):
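
The new test_compile_static_cache test drives the cache_position path this PR touches in LlamaAttention.forward, where the causal mask now defaults to the full attention_mask and is only row-sliced when cache_position is provided. A small illustration of that slice with made-up shapes (not the attention module itself):

import torch

# Illustrative shapes only: a static cache of 16 positions, one decoded query
# token at position 5; with a static cache, key_states exposes all 16 slots.
bs, max_cache_len = 1, 16
attention_mask = torch.triu(
    torch.full((bs, 1, max_cache_len, max_cache_len), float("-inf")), diagonal=1
)

cache_position = torch.tensor([5])                  # position of the current query token
key_states = torch.randn(bs, 1, max_cache_len, 4)   # [bs, num_heads, kv_len, head_dim]

causal_mask = attention_mask                        # default used when cache_position is None
if cache_position is not None:
    causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
print(causal_mask.shape)  # torch.Size([1, 1, 1, 16]), added to attn_weights

With a static cache the mask and key tensors keep fixed shapes across decode steps, which is presumably what lets decode_one_tokens compile with fullgraph=True and avoid the "skipping cudagraphs due to" messages the test asserts against.
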