From 76479fdc4301e69c67f10d8ee0882b298f84da9b Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Tue, 21 Nov 2023 22:24:34 +0000
Subject: [PATCH 01/11] replaced call to _prepare_4d_causal_attention_mask

---
 src/petals/__init__.py           | 8 +-------
 src/petals/models/llama/block.py | 5 +++--
 2 files changed, 4 insertions(+), 9 deletions(-)

diff --git a/src/petals/__init__.py b/src/petals/__init__.py
index 1af8bf951..bd8861c87 100644
--- a/src/petals/__init__.py
+++ b/src/petals/__init__.py
@@ -17,13 +17,7 @@
 from petals.utils import *
 from petals.utils.logging import initialize_logs as _initialize_logs
 
-__version__ = "2.3.0.dev1"
-
-
-if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
-    assert (
-        version.parse("4.32.0") <= version.parse(transformers.__version__) < version.parse("4.35.0")
-    ), "Please install a proper transformers version: pip install transformers>=4.32.0,<4.35.0"
+__version__ = "2.3.0.dev2"
 
 
 def _override_bfloat16_mode_default():
diff --git a/src/petals/models/llama/block.py b/src/petals/models/llama/block.py
index a8d433ded..841c7302a 100644
--- a/src/petals/models/llama/block.py
+++ b/src/petals/models/llama/block.py
@@ -9,6 +9,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.models.llama.modeling_llama import (
     LlamaAttention,
     LlamaConfig,
@@ -244,8 +245,8 @@ def forward(
             attention_mask = torch.ones(
                 (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
             )
-        attention_mask = LlamaModel._prepare_decoder_attention_mask(
-            None, attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+        attention_mask = _prepare_4d_causal_attention_mask(
+            attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
         )
 
         outputs = super().forward(

From d315467aab0895a1be17b06aa69070084b61e243 Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Tue, 21 Nov 2023 23:00:23 +0000
Subject: [PATCH 02/11] upped transformers ver to 4.35

---
 setup.cfg | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.cfg b/setup.cfg
index ef35f8455..7f0a77f3b 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -37,7 +37,7 @@ install_requires =
     accelerate>=0.22.0
     huggingface-hub>=0.11.1,<1.0.0
     tokenizers>=0.13.3
-    transformers>=4.32.0,<4.35.0  # if you change this, please also change version assert in petals/__init__.py
+    transformers>=4.35.0  # 4.35.0 is the minimum that contains modeling_attn_mask_utils.py
     speedtest-cli==2.1.3
     pydantic>=1.10,<2.0  # 2.0 is incompatible with hivemind yet
     hivemind==1.1.10.post2

From c6db638433ad0dece231fd28122143e9b5e7742e Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Tue, 21 Nov 2023 23:40:48 +0000
Subject: [PATCH 03/11] upd Bloom _prepare_attn_mask()

---
 src/petals/models/bloom/block.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/src/petals/models/bloom/block.py b/src/petals/models/bloom/block.py
index f246bd867..6a2a12f36 100644
--- a/src/petals/models/bloom/block.py
+++ b/src/petals/models/bloom/block.py
@@ -6,6 +6,7 @@
 from typing import Optional, Tuple
 
 import torch
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel, build_alibi_tensor
 
 
@@ -26,7 +27,13 @@ def forward(
         attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
         if alibi is None:
             alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
-        attention_mask = BloomModel._prepare_attn_mask(None, attention_mask, (batch_size, seq_length), past_length)
+        fake_inputs_embeds = torch.tensor([42], dtype=torch.float32)
+        attention_mask = _prepare_4d_causal_attention_mask(
+            attention_mask=attention_mask,
+            input_shape=(batch_size, seq_length),
+            inputs_embeds=fake_inputs_embeds,
+            past_key_values_length=past_length,
+        )
         return super().forward(
             hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs
         )

From 2bdbf2da58299ea9edb754547d78b04f4b59d0d2 Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Wed, 22 Nov 2023 08:28:30 +0000
Subject: [PATCH 04/11] mask to bool for bloom fwd

---
 src/petals/models/bloom/block.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/petals/models/bloom/block.py b/src/petals/models/bloom/block.py
index 6a2a12f36..a460d169a 100644
--- a/src/petals/models/bloom/block.py
+++ b/src/petals/models/bloom/block.py
@@ -27,13 +27,17 @@ def forward(
         attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
         if alibi is None:
             alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
-        fake_inputs_embeds = torch.tensor([42], dtype=torch.float32)
+
+        # _prepare_4d only needs inputs_embeds.dtype. And it is changed to bool before .forward() anyways
+        fake_inputs_embeds = torch.tensor([42], dtype=torch.float32)
+
         attention_mask = _prepare_4d_causal_attention_mask(
             attention_mask=attention_mask,
             input_shape=(batch_size, seq_length),
             inputs_embeds=fake_inputs_embeds,
             past_key_values_length=past_length,
         )
+        attention_mask = attention_mask.bool()  # consistent with https://github.com/huggingface/transformers/pull/27086
         return super().forward(
             hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs
         )

From fa254cff02fae934c24b55e00bcb3569b802c9b8 Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Wed, 22 Nov 2023 08:30:08 +0000
Subject: [PATCH 05/11] Llama rotary dims from 4 to 2

---
 src/petals/models/llama/block.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/petals/models/llama/block.py b/src/petals/models/llama/block.py
index 841c7302a..b3deb511b 100644
--- a/src/petals/models/llama/block.py
+++ b/src/petals/models/llama/block.py
@@ -85,8 +85,8 @@ def forward(
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        cos = cos[:, :, kv_seq_len - q_len :]
-        sin = sin[:, :, kv_seq_len - q_len :]
+        cos = cos[kv_seq_len - q_len :]
+        sin = sin[kv_seq_len - q_len :]
 
         if q_len == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
             query_states, key_states = self._optimized_apply_rotary(query_states, key_states, cos, sin)

From 741b5394cc9968f3912172ee1c6bf6532f40ad0b Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Wed, 22 Nov 2023 09:48:44 +0000
Subject: [PATCH 06/11] past_key_values to None if zero shape

---
 src/petals/client/remote_generation.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/petals/client/remote_generation.py b/src/petals/client/remote_generation.py
index 97a115ab2..2f5583fb5 100644
--- a/src/petals/client/remote_generation.py
+++ b/src/petals/client/remote_generation.py
@@ -34,6 +34,9 @@ class _SkipTokensMixin:
     def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> dict:
         input_ids = input_ids[:, _skipped_tokens.get() :]
         _skipped_tokens.set(0)
+        if "past_key_values" in kwargs:
+            if kwargs['past_key_values'][0][0].shape == torch.Size([0]):
+                kwargs['past_key_values'] = None
         return super().prepare_inputs_for_generation(input_ids, **kwargs)
 
 

From 3810049411f8f4a3299d8408cf16063073bd80e9 Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Wed, 22 Nov 2023 12:26:27 +0000
Subject: [PATCH 07/11] _prepare_4d_causal_attention_mask in tests

---
 tests/test_optimized_layers.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/tests/test_optimized_layers.py b/tests/test_optimized_layers.py
index 84cbfffe2..69b77cc1e 100644
--- a/tests/test_optimized_layers.py
+++ b/tests/test_optimized_layers.py
@@ -2,6 +2,7 @@
 
 import pytest
 import torch
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
 
@@ -131,8 +132,8 @@ def forward(
             attention_mask = torch.ones(
                 (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
             )
-        attention_mask = LlamaModel._prepare_decoder_attention_mask(
-            None, attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+        attention_mask = _prepare_4d_causal_attention_mask(
+            attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
         )
 
         outputs = super().forward(

From e5a3fe64e2c27872bcf0a23b2b94ef196a2551dc Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Wed, 22 Nov 2023 12:36:33 +0000
Subject: [PATCH 08/11] attn_mask fixes for falcon

---
 src/petals/models/falcon/block.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/src/petals/models/falcon/block.py b/src/petals/models/falcon/block.py
index a510abaa1..383fcb3c5 100644
--- a/src/petals/models/falcon/block.py
+++ b/src/petals/models/falcon/block.py
@@ -10,6 +10,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.models.falcon.modeling_falcon import (
     FalconAttention,
     FalconConfig,
@@ -418,7 +419,14 @@ def forward(
         attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
         if alibi is None and self.config.alibi:
             alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
-        attention_mask = FalconModel._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)
+
+        fake_inputs_embeds = torch.tensor([42], dtype=torch.float32)
+        attention_mask = _prepare_4d_causal_attention_mask(
+            attention_mask=attention_mask,
+            input_shape=(batch_size, seq_length),
+            inputs_embeds=fake_inputs_embeds,
+            past_key_values_length=past_length,
+        )
 
         outputs = super().forward(
             hidden_states,

From bd0ca0fbae0a704f583497bc0e64afdc3c741690 Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Wed, 22 Nov 2023 12:43:08 +0000
Subject: [PATCH 09/11] llama block reformatted

---
 src/petals/models/llama/block.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/petals/models/llama/block.py b/src/petals/models/llama/block.py
index b3deb511b..6f539a841 100644
--- a/src/petals/models/llama/block.py
+++ b/src/petals/models/llama/block.py
@@ -246,7 +246,10 @@ def forward(
                 (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
             )
         attention_mask = _prepare_4d_causal_attention_mask(
-            attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+            attention_mask=attention_mask,
+            input_shape=(batch_size, seq_length),
+            inputs_embeds=hidden_states,
+            past_key_values_length=past_key_values_length,
         )
 
         outputs = super().forward(

From 5aebd3e8fca092c61ad750093406923714dd8c7e Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Wed, 22 Nov 2023 12:47:28 +0000
Subject: [PATCH 10/11] edits to bloom, falcon inputs_embeds arg

---
 src/petals/models/bloom/block.py  | 7 ++-----
 src/petals/models/falcon/block.py | 3 +--
 2 files changed, 3 insertions(+), 7 deletions(-)

diff --git a/src/petals/models/bloom/block.py b/src/petals/models/bloom/block.py
index a460d169a..355f650b8 100644
--- a/src/petals/models/bloom/block.py
+++ b/src/petals/models/bloom/block.py
@@ -28,16 +28,13 @@ def forward(
         if alibi is None:
             alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
 
-        # _prepare_4d only needs inputs_embeds.dtype. And it is changed to bool before .forward() anyways
-        fake_inputs_embeds = torch.tensor([42], dtype=torch.float32)
-
         attention_mask = _prepare_4d_causal_attention_mask(
             attention_mask=attention_mask,
             input_shape=(batch_size, seq_length),
-            inputs_embeds=fake_inputs_embeds,
+            inputs_embeds=hidden_states,
             past_key_values_length=past_length,
         )
-        attention_mask = attention_mask.bool()  # consistent with https://github.com/huggingface/transformers/pull/27086
+        attention_mask = attention_mask.bool()
         return super().forward(
             hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs
         )
diff --git a/src/petals/models/falcon/block.py b/src/petals/models/falcon/block.py
index 383fcb3c5..d36fcb7cc 100644
--- a/src/petals/models/falcon/block.py
+++ b/src/petals/models/falcon/block.py
@@ -420,11 +420,10 @@ def forward(
         if alibi is None and self.config.alibi:
             alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
 
-        fake_inputs_embeds = torch.tensor([42], dtype=torch.float32)
         attention_mask = _prepare_4d_causal_attention_mask(
             attention_mask=attention_mask,
             input_shape=(batch_size, seq_length),
-            inputs_embeds=fake_inputs_embeds,
+            inputs_embeds=hidden_states,
             past_key_values_length=past_length,
         )
 

From 401e7917000d92782df9ca256e84e6bb13de20b0 Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Wed, 22 Nov 2023 13:00:35 +0000
Subject: [PATCH 11/11] black

---
 src/petals/client/remote_generation.py | 4 ++--
 src/petals/models/bloom/block.py       | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/petals/client/remote_generation.py b/src/petals/client/remote_generation.py
index 2f5583fb5..5324a5183 100644
--- a/src/petals/client/remote_generation.py
+++ b/src/petals/client/remote_generation.py
@@ -35,8 +35,8 @@ def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -
         input_ids = input_ids[:, _skipped_tokens.get() :]
         _skipped_tokens.set(0)
         if "past_key_values" in kwargs:
-            if kwargs['past_key_values'][0][0].shape == torch.Size([0]):
-                kwargs['past_key_values'] = None
+            if kwargs["past_key_values"][0][0].shape == torch.Size([0]):
+                kwargs["past_key_values"] = None
         return super().prepare_inputs_for_generation(input_ids, **kwargs)
 
 
diff --git a/src/petals/models/bloom/block.py b/src/petals/models/bloom/block.py
index 355f650b8..e86b839c0 100644
--- a/src/petals/models/bloom/block.py
+++ b/src/petals/models/bloom/block.py
@@ -27,7 +27,7 @@ def forward(
         attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
         if alibi is None:
             alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
-        
+
         attention_mask = _prepare_4d_causal_attention_mask(
             attention_mask=attention_mask,
             input_shape=(batch_size, seq_length),
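Reviewer note, not part of the patch series above: every patch here converges on the same call into the helper that ships with transformers >= 4.35. The snippet below is a minimal, self-contained sketch of that call; the shapes and values are illustrative assumptions, not code taken from this PR.

    # Sketch of the transformers >= 4.35 helper adopted throughout this PR.
    # Shapes/values below are made up for illustration only.
    import torch
    from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

    batch_size, seq_length, past_length, hidden_size = 2, 5, 3, 16
    # inputs_embeds is consumed mainly for its dtype (see the comment added in PATCH 04),
    # which is why PATCH 03 could get away with passing a dummy tensor.
    hidden_states = torch.zeros(batch_size, seq_length, hidden_size)
    attention_mask = torch.ones(batch_size, seq_length + past_length)  # 2D padding mask over past + current tokens

    mask_4d = _prepare_4d_causal_attention_mask(
        attention_mask=attention_mask,
        input_shape=(batch_size, seq_length),
        inputs_embeds=hidden_states,
        past_key_values_length=past_length,
    )
    # Additive float mask: 0 where attention is allowed, a large negative value where it is not.
    assert mask_4d.shape == (batch_size, 1, seq_length, seq_length + past_length)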