From c72d04c067172505b39e09c956197b7c81708fec Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 22 Feb 2024 12:11:11 +0900 Subject: [PATCH 01/22] remove control flow --- .../models/llama/modeling_llama.py | 43 ++++++++++++++++--- 1 file changed, 38 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 8e494adefc2d73..7e9ebf704a97c4 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -99,6 +99,14 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): self.base = base inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) + # For BC we register cos and sin cached + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) @property def sin_cached(self): @@ -127,9 +135,6 @@ def forward(self, x, position_ids, seq_len=None): emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos().to(dtype=x.dtype) sin = emb.sin().to(dtype=x.dtype) - # backwards compatibility - self._cos_cached = cos - self._sin_cached = sin return cos, sin @@ -138,7 +143,21 @@ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # For BC we register cos and sin cached + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + t = t / self.scaling_factor + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) def forward(self, x, position_ids, seq_len=None): # difference to the original RoPE: a scaling factor is aplied to the position ids @@ -152,7 +171,21 @@ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, 
persistent=False) + + # For BC we register cos and sin cached + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + t = t / self.scaling_factor + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) def forward(self, x, position_ids, seq_len=None): # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length From 9c6a877705ee6fbf5162af91005127430ba2388b Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 22 Feb 2024 12:11:41 +0900 Subject: [PATCH 02/22] update gptneox --- .../models/gpt_neox/modeling_gpt_neox.py | 35 +++++++++++++++++-- 1 file changed, 32 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 8dd1cde35c7b89..4922b52a72832e 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -566,10 +566,25 @@ def forward(self, x, seq_len=None): class GPTNeoXLinearScalingRotaryEmbedding(GPTNeoXRotaryEmbedding): """GPTNeoXRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" - # Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding.__init__ + # copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding.__init__ + # TODO @gante bring compatibility back def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # For BC we register cos and sin cached + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + t = t / self.scaling_factor + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len @@ -589,7 +604,21 @@ class GPTNeoXDynamicNTKScalingRotaryEmbedding(GPTNeoXRotaryEmbedding): # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding.__init__ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, 
dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # For BC we register cos and sin cached + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + t = t / self.scaling_factor + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len From d04c69754fadd04d84e1f798c2ff72506ccd9f5d Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 22 Feb 2024 12:15:22 +0900 Subject: [PATCH 03/22] update .... --- .../models/llama/modeling_llama.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 7e9ebf704a97c4..4e4ed3f6e4cdb4 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -105,28 +105,28 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) + self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) + self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) @property def sin_cached(self): logger.warning_once( - "The sin_cached attribute will be removed in 4.40. Bear in mind that its contents changed in v4.38. Use " - "the forward method of RoPE from now on instead." + "The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " + "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class" ) return self._sin_cached @property def cos_cached(self): logger.warning_once( - "The cos_cached attribute will be removed in 4.40. Bear in mind that its contents changed in v4.38. Use " - "the forward method of RoPE from now on instead." + "The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " + "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class" ) return self._cos_cached def forward(self, x, position_ids, seq_len=None): if seq_len is not None: - logger.warning_once("The `seq_len` argument is deprecated and unused. It will be removed in v4.40.") + logger.warning_once("The `seq_len` argument is deprecated and unused. 
It will be removed in v4.39.") # x: [bs, num_attention_heads, seq_len, head_size] inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) @@ -156,8 +156,8 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, s freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) + self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) + self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) def forward(self, x, position_ids, seq_len=None): # difference to the original RoPE: a scaling factor is aplied to the position ids @@ -184,8 +184,8 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, s freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) + self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) + self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) def forward(self, x, position_ids, seq_len=None): # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length From 614b8c3572ffe4feab79da2c02ae0b2b3aff1dc7 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 22 Feb 2024 12:17:43 +0900 Subject: [PATCH 04/22] nits --- src/transformers/models/gpt_neox/modeling_gpt_neox.py | 3 ++- src/transformers/models/llama/modeling_llama.py | 4 +--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 4922b52a72832e..70a088d3a11c40 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -601,7 +601,8 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): class GPTNeoXDynamicNTKScalingRotaryEmbedding(GPTNeoXRotaryEmbedding): """GPTNeoXRotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla""" - # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding.__init__ + # copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding.__init__ + # TODO @gante no longer copied from def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): self.scaling_factor = scaling_factor self.dim = dim diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 4e4ed3f6e4cdb4..a9a0a990f7647c 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -133,9 +133,7 @@ def forward(self, x, position_ids, seq_len=None): position_ids_expanded = position_ids[:, None, :].float() freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos().to(dtype=x.dtype) - sin = emb.sin().to(dtype=x.dtype) - return cos, sin + return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype) class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): From 9c3c6f5eeab8bb6266dfc9de837692810e7662f1 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 22 Feb 2024 12:43:50 +0900 Subject: [PATCH 05/22] Actually let's just break. Otherwise we are silently failing which imo is not optimal --- .../models/gpt_neox/modeling_gpt_neox.py | 10 ---- .../models/llama/modeling_llama.py | 52 +++---------------- 2 files changed, 6 insertions(+), 56 deletions(-) diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 70a088d3a11c40..fca0d8489c4c5f 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -576,16 +576,6 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, s inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - # For BC we register cos and sin cached - self.max_seq_len_cached = max_position_embeddings - t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) - t = t / self.scaling_factor - freqs = torch.outer(t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) - def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index a9a0a990f7647c..11ccc2090bfe6e 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -99,30 +99,18 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): self.base = base inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - # For BC we register cos and sin cached - self.max_seq_len_cached = 
max_position_embeddings - t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) - freqs = torch.outer(t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) - self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) @property def sin_cached(self): - logger.warning_once( - "The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " - "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class" + raise ValueError( + "Starting v4.38.0, the cached sin and cos are no longer representative and should not be used" ) - return self._sin_cached @property def cos_cached(self): - logger.warning_once( - "The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " - "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class" + raise ValueError( + "Starting v4.38.0, the cached sin and cos are no longer representative and should not be used" ) - return self._cos_cached def forward(self, x, position_ids, seq_len=None): if seq_len is not None: @@ -141,21 +129,7 @@ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): self.scaling_factor = scaling_factor - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # For BC we register cos and sin cached - self.max_seq_len_cached = max_position_embeddings - t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) - t = t / self.scaling_factor - freqs = torch.outer(t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) - self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) + super().__init__(dim, max_position_embeddings, base, device) def forward(self, x, position_ids, seq_len=None): # difference to the original RoPE: a scaling factor is aplied to the position ids @@ -169,21 +143,7 @@ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): self.scaling_factor = scaling_factor - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # For BC we register cos and sin cached - self.max_seq_len_cached = max_position_embeddings - t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) - t = t / self.scaling_factor - freqs = torch.outer(t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = 
torch.cat((freqs, freqs), dim=-1) - self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) - self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) + super().__init__(dim, max_position_embeddings, base, device) def forward(self, x, position_ids, seq_len=None): # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length From 12d60c62c6d53ce4e0ac3441a9bdc1aa6fea0e0a Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 22 Feb 2024 12:47:18 +0900 Subject: [PATCH 06/22] version BC --- .../models/llama/modeling_llama.py | 52 ++++++++++++++++--- 1 file changed, 46 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 11ccc2090bfe6e..a9a0a990f7647c 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -99,18 +99,30 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): self.base = base inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) + # For BC we register cos and sin cached + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) + self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) @property def sin_cached(self): - raise ValueError( - "Starting v4.38.0, the cached sin and cos are no longer representative and should not be used" + logger.warning_once( + "The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " + "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class" ) + return self._sin_cached @property def cos_cached(self): - raise ValueError( - "Starting v4.38.0, the cached sin and cos are no longer representative and should not be used" + logger.warning_once( + "The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " + "the forward method of RoPE from now on instead. 
It is not used in the `LlamaAttention` class" ) + return self._cos_cached def forward(self, x, position_ids, seq_len=None): if seq_len is not None: @@ -129,7 +141,21 @@ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # For BC we register cos and sin cached + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + t = t / self.scaling_factor + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) + self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) def forward(self, x, position_ids, seq_len=None): # difference to the original RoPE: a scaling factor is aplied to the position ids @@ -143,7 +169,21 @@ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # For BC we register cos and sin cached + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + t = t / self.scaling_factor + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) + self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) def forward(self, x, position_ids, seq_len=None): # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length From 0fcd9ad38495a6e0dff86ad478aca372eb5dbcb6 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 22 Feb 2024 12:56:31 +0900 Subject: [PATCH 07/22] fix tests --- .../models/llama/modeling_llama.py | 41 ++----------------- 1 file changed, 3 insertions(+), 38 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index a9a0a990f7647c..2a89f8818433d1 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -92,8 +92,9 @@ def forward(self, hidden_states): class LlamaRotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, 
scaling_factor=1.0): super().__init__() + self.scaling_factor = scaling_factor self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base @@ -101,7 +102,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): self.register_buffer("inv_freq", inv_freq, persistent=False) # For BC we register cos and sin cached self.max_seq_len_cached = max_position_embeddings - t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) / self.scaling_factor freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) @@ -139,24 +140,6 @@ def forward(self, x, position_ids, seq_len=None): class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - self.scaling_factor = scaling_factor - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # For BC we register cos and sin cached - self.max_seq_len_cached = max_position_embeddings - t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) - t = t / self.scaling_factor - freqs = torch.outer(t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) - self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) - def forward(self, x, position_ids, seq_len=None): # difference to the original RoPE: a scaling factor is aplied to the position ids position_ids = position_ids.float() / self.scaling_factor @@ -167,24 +150,6 @@ def forward(self, x, position_ids, seq_len=None): class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): """LlamaRotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla""" - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - self.scaling_factor = scaling_factor - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # For BC we register cos and sin cached - self.max_seq_len_cached = max_position_embeddings - t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) - t = t / self.scaling_factor - freqs = torch.outer(t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) - self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) - def forward(self, x, position_ids, seq_len=None): # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length seq_len = torch.max(position_ids) + 1 From 3eeef21f8035894f44355f8e7ccd6250a4910790 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 27 Feb 2024 11:23:43 +0900 Subject: [PATCH 08/22] fix eager causal --- src/transformers/models/llama/modeling_llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 2a89f8818433d1..6e3f0df841a56f 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -363,6 +363,7 @@ def forward( attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask if cache_position is not None: causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask From d95f3ed4902e5966f4ff529365b07efef9deaffa Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 28 Feb 2024 15:57:34 +0900 Subject: [PATCH 09/22] nit --- src/transformers/models/llama/modeling_llama.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 85c3c16401fe94..1f9ee6bb1a566c 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -102,7 +102,8 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, s self.register_buffer("inv_freq", inv_freq, persistent=False) # For BC we register cos and sin cached self.max_seq_len_cached = max_position_embeddings - t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) / self.scaling_factor + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + t = t / self.scaling_factor freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) From 4a73df127879a6abda75f56f3965d26992362bda Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 28 Feb 2024 16:10:19 +0900 Subject: [PATCH 10/22] add a test --- 
tests/models/llama/test_modeling_llama.py | 40 ++++++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index a393950232f306..c3bf7461c61747 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -20,7 +20,7 @@ import pytest from parameterized import parameterized -from transformers import LlamaConfig, is_torch_available, set_seed +from transformers import LlamaConfig, is_torch_available, set_seed, StaticCache from transformers.testing_utils import ( require_bitsandbytes, require_flash_attn, @@ -595,6 +595,42 @@ def test_model_13b_greedy_generation(self): text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) self.assertEqual(EXPECTED_TEXT_COMPLETION, text) + @require_torch_gpu + def test_compile_static_cache(self): + NUM_TOKENS_TO_GENERATE = 40 + EXPECTED_TEXT_COMPLETION = ["""Simply put, the theory of relativity states that 1) the laws of physics are the same everywhere in the universe and 2) the passage of time and the length of objects can vary depending on the observer\'s frame of reference.\n\nThe first part of the theory, that the laws of physics are the same everywhere, is known as the "princi""",""""""] + prompts = ["Simply put, the theory of relativity states that ", "My favorit condiment has to be"] + tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token_id = "") + model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="sequential") + input_ids = tokenizer.encode(prompts, return_tensors="pt", padding=True) + + def decode_one_tokens(model, cur_token, input_pos, cache_position): + logits = model(cur_token, position_ids=input_pos,cache_position=cache_position, return_dict=False, use_cache = True)[0] + new_token = torch.argmax(logits,dim=-1)[0] + return new_token + + batch_size, seq_length = input_ids.shape + with torch.no_grad(): + model._setup_cache(StaticCache, 2, max_cache_len=4096) + cache_position = torch.arange(seq_length , device=torch_device) + generated_ids = torch.zeros(batch_size, seq_length+NUM_TOKENS_TO_GENERATE+1, dtype = torch.int, device=torch_device) + generated_ids[:, cache_position] = input_ids.to(torch_device).to(torch.int) + + logits = model(input_ids,cache_position=cache_position, return_dict=False, use_cache = True)[0] + next_token = torch.argmax(logits,dim=-1)[0] + generated_ids[:, seq_length] = next_token + + decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True) + cache_position = torch.tensor([seq_length] , device=torch_device) + for _ in range(1, NUM_TOKENS_TO_GENERATE): + with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): + next_token = decode_one_tokens(model, next_token.clone(), None, cache_position) + generated_ids.index_copy_(1, cache_position, next_token) + cache_position+=1 + + text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, text) + @require_torch class CodeLlamaIntegrationTest(unittest.TestCase): @@ -676,3 +712,5 @@ def test_model_7b_logits(self): ] infilling = tokenizer.batch_decode(generated_ids) self.assertEqual(infilling, EXPECTED_INFILLING) + + From 4323e4140dfa86b889c9a70edaded0b60267a2ad Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 28 Feb 2024 16:10:37 +0900 Subject: [PATCH 11/22] style --- tests/models/llama/test_modeling_llama.py | 35 
+++++++++++++---------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index c3bf7461c61747..49bb7c1cca89cc 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -20,7 +20,7 @@ import pytest from parameterized import parameterized -from transformers import LlamaConfig, is_torch_available, set_seed, StaticCache +from transformers import LlamaConfig, StaticCache, is_torch_available, set_seed from transformers.testing_utils import ( require_bitsandbytes, require_flash_attn, @@ -598,35 +598,42 @@ def test_model_13b_greedy_generation(self): @require_torch_gpu def test_compile_static_cache(self): NUM_TOKENS_TO_GENERATE = 40 - EXPECTED_TEXT_COMPLETION = ["""Simply put, the theory of relativity states that 1) the laws of physics are the same everywhere in the universe and 2) the passage of time and the length of objects can vary depending on the observer\'s frame of reference.\n\nThe first part of the theory, that the laws of physics are the same everywhere, is known as the "princi""",""""""] + EXPECTED_TEXT_COMPLETION = [ + """Simply put, the theory of relativity states that 1) the laws of physics are the same everywhere in the universe and 2) the passage of time and the length of objects can vary depending on the observer\'s frame of reference.\n\nThe first part of the theory, that the laws of physics are the same everywhere, is known as the "princi""", + """""", + ] prompts = ["Simply put, the theory of relativity states that ", "My favorit condiment has to be"] - tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token_id = "") + tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token_id="") model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="sequential") input_ids = tokenizer.encode(prompts, return_tensors="pt", padding=True) def decode_one_tokens(model, cur_token, input_pos, cache_position): - logits = model(cur_token, position_ids=input_pos,cache_position=cache_position, return_dict=False, use_cache = True)[0] - new_token = torch.argmax(logits,dim=-1)[0] + logits = model( + cur_token, position_ids=input_pos, cache_position=cache_position, return_dict=False, use_cache=True + )[0] + new_token = torch.argmax(logits, dim=-1)[0] return new_token batch_size, seq_length = input_ids.shape with torch.no_grad(): model._setup_cache(StaticCache, 2, max_cache_len=4096) - cache_position = torch.arange(seq_length , device=torch_device) - generated_ids = torch.zeros(batch_size, seq_length+NUM_TOKENS_TO_GENERATE+1, dtype = torch.int, device=torch_device) + cache_position = torch.arange(seq_length, device=torch_device) + generated_ids = torch.zeros( + batch_size, seq_length + NUM_TOKENS_TO_GENERATE + 1, dtype=torch.int, device=torch_device + ) generated_ids[:, cache_position] = input_ids.to(torch_device).to(torch.int) - - logits = model(input_ids,cache_position=cache_position, return_dict=False, use_cache = True)[0] - next_token = torch.argmax(logits,dim=-1)[0] + + logits = model(input_ids, cache_position=cache_position, return_dict=False, use_cache=True)[0] + next_token = torch.argmax(logits, dim=-1)[0] generated_ids[:, seq_length] = next_token - + decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True) - cache_position = torch.tensor([seq_length] , device=torch_device) + cache_position = torch.tensor([seq_length], device=torch_device) for _ 
in range(1, NUM_TOKENS_TO_GENERATE): with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): next_token = decode_one_tokens(model, next_token.clone(), None, cache_position) generated_ids.index_copy_(1, cache_position, next_token) - cache_position+=1 + cache_position += 1 text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) self.assertEqual(EXPECTED_TEXT_COMPLETION, text) @@ -712,5 +719,3 @@ def test_model_7b_logits(self): ] infilling = tokenizer.batch_decode(generated_ids) self.assertEqual(infilling, EXPECTED_INFILLING) - - From 7f5ac69d3f82c554d4158e3a4da958c78f3ef73d Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 28 Feb 2024 08:41:03 +0100 Subject: [PATCH 12/22] nits --- tests/models/llama/test_modeling_llama.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 49bb7c1cca89cc..88a781701fccde 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -603,28 +603,28 @@ def test_compile_static_cache(self): """""", ] prompts = ["Simply put, the theory of relativity states that ", "My favorit condiment has to be"] - tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token_id="") + tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="", padding_side="left") model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="sequential") - input_ids = tokenizer.encode(prompts, return_tensors="pt", padding=True) + inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) def decode_one_tokens(model, cur_token, input_pos, cache_position): logits = model( cur_token, position_ids=input_pos, cache_position=cache_position, return_dict=False, use_cache=True )[0] - new_token = torch.argmax(logits, dim=-1)[0] + new_token = torch.argmax(logits[:,-1], dim=-1, keep_dim = True) return new_token - batch_size, seq_length = input_ids.shape + batch_size, seq_length = inputs["input_ids"].shape with torch.no_grad(): model._setup_cache(StaticCache, 2, max_cache_len=4096) cache_position = torch.arange(seq_length, device=torch_device) generated_ids = torch.zeros( batch_size, seq_length + NUM_TOKENS_TO_GENERATE + 1, dtype=torch.int, device=torch_device ) - generated_ids[:, cache_position] = input_ids.to(torch_device).to(torch.int) + generated_ids[:, cache_position] = inputs["input_ids"].to(torch_device).to(torch.int) - logits = model(input_ids, cache_position=cache_position, return_dict=False, use_cache=True)[0] - next_token = torch.argmax(logits, dim=-1)[0] + logits = model(**inputs, cache_position=cache_position, return_dict=False, use_cache=True)[0] + next_token = torch.argmax(logits[:,-1], dim=-1, keep_dim=True) generated_ids[:, seq_length] = next_token decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True) @@ -632,7 +632,7 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position): for _ in range(1, NUM_TOKENS_TO_GENERATE): with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): next_token = decode_one_tokens(model, next_token.clone(), None, cache_position) - generated_ids.index_copy_(1, cache_position, next_token) + generated_ids[:,cache_position] = next_token[:, None].int() cache_position += 1 text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) From 
b7d1884dd4689c431f6b1d12d1fdf049ec9a07f0 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 28 Feb 2024 08:42:46 +0100 Subject: [PATCH 13/22] nits --- tests/models/llama/test_modeling_llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 88a781701fccde..3b5ac3ba84d14a 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -604,7 +604,7 @@ def test_compile_static_cache(self): ] prompts = ["Simply put, the theory of relativity states that ", "My favorit condiment has to be"] tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="", padding_side="left") - model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="sequential") + model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="auto") inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) def decode_one_tokens(model, cur_token, input_pos, cache_position): From 100ab52ea07f32b87e6bd13689584e5e8127e323 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 28 Feb 2024 08:43:43 +0100 Subject: [PATCH 14/22] more nits for the test --- tests/models/llama/test_modeling_llama.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 3b5ac3ba84d14a..877630f648325f 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -611,7 +611,7 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position): logits = model( cur_token, position_ids=input_pos, cache_position=cache_position, return_dict=False, use_cache=True )[0] - new_token = torch.argmax(logits[:,-1], dim=-1, keep_dim = True) + new_token = torch.argmax(logits[:,-1], dim=-1).unsqueeze(1) return new_token batch_size, seq_length = inputs["input_ids"].shape @@ -624,7 +624,7 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position): generated_ids[:, cache_position] = inputs["input_ids"].to(torch_device).to(torch.int) logits = model(**inputs, cache_position=cache_position, return_dict=False, use_cache=True)[0] - next_token = torch.argmax(logits[:,-1], dim=-1, keep_dim=True) + next_token = torch.argmax(logits[:,-1], dim=-1).unsqueeze(1) generated_ids[:, seq_length] = next_token decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True) @@ -632,7 +632,7 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position): for _ in range(1, NUM_TOKENS_TO_GENERATE): with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): next_token = decode_one_tokens(model, next_token.clone(), None, cache_position) - generated_ids[:,cache_position] = next_token[:, None].int() + generated_ids[:,cache_position] = next_token.int() cache_position += 1 text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) From ca82c260cf2df71383da11da18d4157cec05c9dd Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 28 Feb 2024 09:14:26 +0100 Subject: [PATCH 15/22] update and fix --- tests/models/llama/test_modeling_llama.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 877630f648325f..a2d5ad1cd4a845 100644 --- a/tests/models/llama/test_modeling_llama.py +++ 
b/tests/models/llama/test_modeling_llama.py @@ -598,20 +598,17 @@ def test_model_13b_greedy_generation(self): @require_torch_gpu def test_compile_static_cache(self): NUM_TOKENS_TO_GENERATE = 40 - EXPECTED_TEXT_COMPLETION = [ - """Simply put, the theory of relativity states that 1) the laws of physics are the same everywhere in the universe and 2) the passage of time and the length of objects can vary depending on the observer\'s frame of reference.\n\nThe first part of the theory, that the laws of physics are the same everywhere, is known as the "princi""", - """""", - ] - prompts = ["Simply put, the theory of relativity states that ", "My favorit condiment has to be"] - tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="", padding_side="left") - model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="auto") + EXPECTED_TEXT_COMPLETION = ['Simply put, the theory of relativity states that 1) the speed of light is constant, 2) the speed of light is the same for all observers, and 3) the laws of physics are the same for all observers.', 'My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p'] + prompts = ["Simply put, the theory of relativity states that ", "My favorite all time favorite condiment is ketchup."] + tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="", padding_side="right") + model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="sequential") inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) def decode_one_tokens(model, cur_token, input_pos, cache_position): logits = model( cur_token, position_ids=input_pos, cache_position=cache_position, return_dict=False, use_cache=True )[0] - new_token = torch.argmax(logits[:,-1], dim=-1).unsqueeze(1) + new_token = torch.argmax(logits[:,-1], dim=-1)[:,None] return new_token batch_size, seq_length = inputs["input_ids"].shape @@ -624,18 +621,18 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position): generated_ids[:, cache_position] = inputs["input_ids"].to(torch_device).to(torch.int) logits = model(**inputs, cache_position=cache_position, return_dict=False, use_cache=True)[0] - next_token = torch.argmax(logits[:,-1], dim=-1).unsqueeze(1) - generated_ids[:, seq_length] = next_token + next_token = torch.argmax(logits[:,-1], dim=-1)[:,None] + generated_ids[:, seq_length] = next_token[:,0] decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True) - cache_position = torch.tensor([seq_length], device=torch_device) + cache_position = torch.tensor([seq_length+1], device=torch_device) for _ in range(1, NUM_TOKENS_TO_GENERATE): with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): next_token = decode_one_tokens(model, next_token.clone(), None, cache_position) generated_ids[:,cache_position] = next_token.int() cache_position += 1 - text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) self.assertEqual(EXPECTED_TEXT_COMPLETION, text) From 409d97e44a291bb1224951572991a69e6eae273b Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 28 Feb 2024 09:23:22 +0100 Subject: [PATCH 16/22] make sure cuda graphs are not skipped --- tests/models/llama/test_modeling_llama.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 
deletions(-) diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index a2d5ad1cd4a845..fe7ab5296f18f1 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -20,7 +20,7 @@ import pytest from parameterized import parameterized -from transformers import LlamaConfig, StaticCache, is_torch_available, set_seed +from transformers import LlamaConfig, StaticCache, is_torch_available, set_seed, logging from transformers.testing_utils import ( require_bitsandbytes, require_flash_attn, @@ -30,6 +30,7 @@ require_torch_sdpa, slow, torch_device, + CaptureLogger ) from ...generation.test_utils import GenerationTesterMixin @@ -628,7 +629,9 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position): cache_position = torch.tensor([seq_length+1], device=torch_device) for _ in range(1, NUM_TOKENS_TO_GENERATE): with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): - next_token = decode_one_tokens(model, next_token.clone(), None, cache_position) + with CaptureLogger(logging.get_logger(__name__)) as cl: + next_token = decode_one_tokens(model, next_token.clone(), None, cache_position) + self.assertNotIn("skipping cudagraphs due to", cl.out) generated_ids[:,cache_position] = next_token.int() cache_position += 1 From 6b8493685ad60dae92245a5637f6ddae4e43b05a Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 28 Feb 2024 09:24:13 +0100 Subject: [PATCH 17/22] read token is needed for meta llama --- tests/models/llama/test_modeling_llama.py | 26 +++++++++++++++-------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index fe7ab5296f18f1..6cb86b55df795e 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -20,17 +20,18 @@ import pytest from parameterized import parameterized -from transformers import LlamaConfig, StaticCache, is_torch_available, set_seed, logging +from transformers import LlamaConfig, StaticCache, is_torch_available, logging, set_seed from transformers.testing_utils import ( + CaptureLogger, require_bitsandbytes, require_flash_attn, + require_read_token, require_torch, require_torch_accelerator, require_torch_gpu, require_torch_sdpa, slow, torch_device, - CaptureLogger ) from ...generation.test_utils import GenerationTesterMixin @@ -597,10 +598,17 @@ def test_model_13b_greedy_generation(self): self.assertEqual(EXPECTED_TEXT_COMPLETION, text) @require_torch_gpu + @require_read_token def test_compile_static_cache(self): NUM_TOKENS_TO_GENERATE = 40 - EXPECTED_TEXT_COMPLETION = ['Simply put, the theory of relativity states that 1) the speed of light is constant, 2) the speed of light is the same for all observers, and 3) the laws of physics are the same for all observers.', 'My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p'] - prompts = ["Simply put, the theory of relativity states that ", "My favorite all time favorite condiment is ketchup."] + EXPECTED_TEXT_COMPLETION = [ + "Simply put, the theory of relativity states that 1) the speed of light is constant, 2) the speed of light is the same for all observers, and 3) the laws of physics are the same for all observers.", + "My favorite all time favorite condiment is ketchup. I love it on everything. 
I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p", + ] + prompts = [ + "Simply put, the theory of relativity states that ", + "My favorite all time favorite condiment is ketchup.", + ] tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="", padding_side="right") model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="sequential") inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) @@ -609,7 +617,7 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position): logits = model( cur_token, position_ids=input_pos, cache_position=cache_position, return_dict=False, use_cache=True )[0] - new_token = torch.argmax(logits[:,-1], dim=-1)[:,None] + new_token = torch.argmax(logits[:, -1], dim=-1)[:, None] return new_token batch_size, seq_length = inputs["input_ids"].shape @@ -622,17 +630,17 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position): generated_ids[:, cache_position] = inputs["input_ids"].to(torch_device).to(torch.int) logits = model(**inputs, cache_position=cache_position, return_dict=False, use_cache=True)[0] - next_token = torch.argmax(logits[:,-1], dim=-1)[:,None] - generated_ids[:, seq_length] = next_token[:,0] + next_token = torch.argmax(logits[:, -1], dim=-1)[:, None] + generated_ids[:, seq_length] = next_token[:, 0] decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True) - cache_position = torch.tensor([seq_length+1], device=torch_device) + cache_position = torch.tensor([seq_length + 1], device=torch_device) for _ in range(1, NUM_TOKENS_TO_GENERATE): with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): with CaptureLogger(logging.get_logger(__name__)) as cl: next_token = decode_one_tokens(model, next_token.clone(), None, cache_position) self.assertNotIn("skipping cudagraphs due to", cl.out) - generated_ids[:,cache_position] = next_token.int() + generated_ids[:, cache_position] = next_token.int() cache_position += 1 text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) From a01d0e187cd08900cfe7c9f4ac33ae15d0aac09f Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 28 Feb 2024 09:31:44 +0100 Subject: [PATCH 18/22] update! 
--- src/transformers/models/gpt_neox/modeling_gpt_neox.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index fca0d8489c4c5f..2fc42e287236ed 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -569,6 +569,7 @@ class GPTNeoXLinearScalingRotaryEmbedding(GPTNeoXRotaryEmbedding): # copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding.__init__ # TODO @gante bring compatibility back def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + super().__init__() self.scaling_factor = scaling_factor self.dim = dim self.max_position_embeddings = max_position_embeddings @@ -594,6 +595,7 @@ class GPTNeoXDynamicNTKScalingRotaryEmbedding(GPTNeoXRotaryEmbedding): # copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding.__init__ # TODO @gante no longer copied from def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + super().__init__() self.scaling_factor = scaling_factor self.dim = dim self.max_position_embeddings = max_position_embeddings From 7e2c08bf11831e9b06daa45f1c62ac655ea0f56e Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 28 Feb 2024 09:49:00 +0100 Subject: [PATCH 19/22] fiixup --- .../models/gpt_neox/modeling_gpt_neox.py | 30 ++++--------------- 1 file changed, 5 insertions(+), 25 deletions(-) diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 2fc42e287236ed..73615ee86e4a13 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -562,20 +562,15 @@ def forward(self, x, seq_len=None): self.sin_cached[:seq_len], ) - + # copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding.__init__ +# TODO @gante bring compatibility back class GPTNeoXLinearScalingRotaryEmbedding(GPTNeoXRotaryEmbedding): """GPTNeoXRotaryEmbedding extended with linear scaling. 
Credits to the Reddit user /u/kaiokendev""" - # copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding.__init__ - # TODO @gante bring compatibility back + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - super().__init__() self.scaling_factor = scaling_factor - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) + super().__init__(dim, max_position_embeddings, base, device) def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len @@ -595,23 +590,8 @@ class GPTNeoXDynamicNTKScalingRotaryEmbedding(GPTNeoXRotaryEmbedding): # copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding.__init__ # TODO @gante no longer copied from def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - super().__init__() self.scaling_factor = scaling_factor - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # For BC we register cos and sin cached - self.max_seq_len_cached = max_position_embeddings - t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) - t = t / self.scaling_factor - freqs = torch.outer(t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) + super().__init__(dim, max_position_embeddings, base, device) def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len From 0809a9f1db26f3509755cb3ac11973c0e1d9fe23 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 28 Feb 2024 09:53:01 +0100 Subject: [PATCH 20/22] compile test should be slow --- tests/models/llama/test_modeling_llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 6cb86b55df795e..308e5d91195215 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -597,6 +597,7 @@ def test_model_13b_greedy_generation(self): text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) self.assertEqual(EXPECTED_TEXT_COMPLETION, text) + @slow @require_torch_gpu @require_read_token def test_compile_static_cache(self): From 5cfa3fbd34e2b6966351a9fabfcb1623d2bc5cde Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 28 Feb 2024 10:17:01 +0100 Subject: [PATCH 21/22] fix thet fix copies --- src/transformers/models/gpt_neox/modeling_gpt_neox.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 73615ee86e4a13..71ae7b8882c698 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -562,7 +562,7 @@ def forward(self, 
x, seq_len=None): self.sin_cached[:seq_len], ) - # copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding.__init__ +# copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding.__init__ # TODO @gante bring compatibility back class GPTNeoXLinearScalingRotaryEmbedding(GPTNeoXRotaryEmbedding): """GPTNeoXRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" From 14db98ef81b6ef6b0b37ecfe944cc5e21013e2b3 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 28 Feb 2024 10:33:14 +0100 Subject: [PATCH 22/22] =?UTF-8?q?=20stle=20=F0=9F=AB=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/transformers/models/gpt_neox/modeling_gpt_neox.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 71ae7b8882c698..882b4fc9ecc322 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -562,12 +562,12 @@ def forward(self, x, seq_len=None): self.sin_cached[:seq_len], ) + # copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding.__init__ # TODO @gante bring compatibility back class GPTNeoXLinearScalingRotaryEmbedding(GPTNeoXRotaryEmbedding): """GPTNeoXRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): self.scaling_factor = scaling_factor super().__init__(dim, max_position_embeddings, base, device)
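
Note: the series above reworks the rotary embeddings so that cos/sin are computed inside forward() from explicit position_ids rather than sliced out of a pre-grown cache; the registered _cos_cached/_sin_cached buffers remain only for backward compatibility and their accessors warn. Below is a minimal usage sketch of the new call pattern, assuming the final state of modeling_llama.py after this series; the tensor sizes are arbitrary toy values and not taken from the patches.

    import torch
    from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

    rope = LlamaRotaryEmbedding(dim=128, max_position_embeddings=2048, base=10000)
    # x is only used for its dtype; shape convention is [batch, num_heads, seq_len, head_dim]
    x = torch.randn(1, 32, 16, 128)
    position_ids = torch.arange(16)[None, :]   # [batch, seq_len]
    cos, sin = rope(x, position_ids)           # each [batch, seq_len, dim]; no seq_len argument needed
    print(cos.shape, sin.shape)                # torch.Size([1, 16, 128]) torch.Size([1, 16, 128])

The scaling variants (LlamaLinearScalingRotaryEmbedding, LlamaDynamicNTKScalingRotaryEmbedding) keep the same forward signature and only rescale position_ids or recompute inv_freq inside forward, in line with the series' goal of removing cache-growing control flow from the attention path.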