From 6cc4dfe3f1e8d421c6d6351388e06e9b123cbfe1 Mon Sep 17 00:00:00 2001
From: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Date: Fri, 13 Sep 2024 15:06:08 +0200
Subject: [PATCH] Fix the initialization of the cache when we have multi gpu
 (#33303)

* init cache multi-gpu

* Update src/transformers/generation/utils.py

Co-authored-by: Joao Gante

* switch to execution device map

* naming more consistent

* fix

* mutually exclusive device

* added an integration example

* remove useless check

* suggestion from joao + typing

* fix couple of typos and add test

* revert check

---------

Co-authored-by: Joao Gante
---
 src/transformers/cache_utils.py      | 40 +++++++++----
 src/transformers/generation/utils.py | 27 +++++++++
 tests/generation/test_utils.py       | 85 ++++++++++++++++++++++++++++
 3 files changed, 141 insertions(+), 11 deletions(-)

diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index b3e94da3d7d7bd..0671157e447038 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -1030,6 +1030,9 @@ class StaticCache(Cache):
             The device on which the cache should be initialized. Should be the same as the layer.
         dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
             The default `dtype` to use when initializing the layer.
+        layer_device_map (`Dict[int, Union[str, torch.device, int]]`, *optional*):
+            Mapping between the layers and their device. This is required when you are manually initializing the cache and the model is split between different GPUs.
+            You can check which layers are mapped to which device by looking at the associated device map: `model.hf_device_map`.
 
     Example:
 
@@ -1060,6 +1063,7 @@ def __init__(
         device: torch.device = None,
         dtype: torch.dtype = torch.float32,
         max_batch_size: Optional[int] = None,
+        layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
     ) -> None:
         super().__init__()
         if max_batch_size is not None:
@@ -1088,16 +1092,20 @@ def __init__(
         # Note: There will be significant perf decrease if switching to use 5D tensors instead.
         cache_shape = (self.batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
         for idx in range(config.num_hidden_layers):
-            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
-            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
+            if layer_device_map is not None:
+                layer_device = layer_device_map[idx]
+            else:
+                layer_device = device
+            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
+            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
             # Notes:
             #     1. `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
             #        breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
             #        it is not needed anyway)
             #     2. `torch.export()` requires mutations to be registered as buffers.
             if not is_torchdynamo_compiling():
-                self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
-                self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
+                self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=layer_device))
+                self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=layer_device))
                 new_layer_key_cache = getattr(self, f"key_cache_{idx}")
                 new_layer_value_cache = getattr(self, f"value_cache_{idx}")
             torch._dynamo.mark_static_address(new_layer_key_cache)
@@ -1130,9 +1138,9 @@ def update(
         Return:
             A tuple containing the updated key and value states.
         """
+
         cache_position = cache_kwargs.get("cache_position")
-        self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device)
-        self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device)
+
         k_out = self.key_cache[layer_idx]
         v_out = self.value_cache[layer_idx]
 
@@ -1201,6 +1209,9 @@ class SlidingWindowCache(StaticCache):
             The device on which the cache should be initialized. Should be the same as the layer.
         dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
             The default `dtype` to use when initializing the layer.
+        layer_device_map (`Dict[int, Union[str, torch.device, int]]`, *optional*):
+            Mapping between the layers and their device. This is required when you are manually initializing the cache and the model is split between different GPUs.
+            You can check which layers are mapped to which device by looking at the associated device map: `model.hf_device_map`.
 
     Example:
 
@@ -1231,6 +1242,7 @@ def __init__(
         device: torch.device = None,
         dtype: torch.dtype = torch.float32,
         max_batch_size: Optional[int] = None,
+        layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
     ) -> None:
         super().__init__()
         if not hasattr(config, "sliding_window") or config.sliding_window is None:
@@ -1247,6 +1259,7 @@ def __init__(
             device=device,
             dtype=dtype,
             max_batch_size=max_batch_size,
+            layer_device_map=layer_device_map,
         )
 
     def update(
@@ -1280,7 +1293,6 @@ def update(
             v_out = v_out[:, :, indices]
 
         try:
-            cache_position.to(device=k_out.device)
             k_out.index_copy_(2, cache_position, key_states)
             v_out.index_copy_(2, cache_position, value_states)
         except NotImplementedError:
@@ -1495,6 +1507,9 @@ class HybridCache(Cache):
             The device on which the cache should be initialized. Should be the same as the layer.
         dtype (torch.dtype, *optional*, defaults to `torch.float32`):
             The default `dtype` to use when initializing the layer.
+        layer_device_map (`Dict[int, Union[str, torch.device, int]]`, *optional*):
+            Mapping between the layers and their device. This is required when you are manually initializing the cache and the model is split between different GPUs.
+            You can check which layers are mapped to which device by looking at the associated device map: `model.hf_device_map`.
 
     Example:
 
@@ -1525,6 +1540,7 @@ def __init__(
         device: Union[torch.device, str] = "cpu",
         dtype: torch.dtype = torch.float32,
         max_batch_size: Optional[int] = None,
+        layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
     ) -> None:
         super().__init__()
         if max_batch_size is not None:
@@ -1562,11 +1578,15 @@ def __init__(
             self.head_dim,
         )
         for i in range(config.num_hidden_layers):
+            if layer_device_map is not None:
+                layer_device = layer_device_map[i]
+            else:
+                layer_device = device
             # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
             # breaks when updating the cache.
             cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
-            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
-            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
+            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
+            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
             torch._dynamo.mark_static_address(new_layer_key_cache)
             torch._dynamo.mark_static_address(new_layer_value_cache)
             self.key_cache.append(new_layer_key_cache)
@@ -1617,8 +1637,6 @@ def update(
     ) -> Tuple[torch.Tensor]:
         cache_position = cache_kwargs.get("cache_position")
         sliding_window = cache_kwargs.get("sliding_window")
-        self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device)
-        self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device)
         k_out = self.key_cache[layer_idx]
         v_out = self.value_cache[layer_idx]
         if sliding_window:
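The same per-layer allocation pattern is applied to `StaticCache`, `SlidingWindowCache` (through its parent class), and `HybridCache` above. As a minimal standalone sketch of the idea, with illustrative names rather than the actual classes:

from typing import Dict, Optional, Tuple, Union

import torch


def allocate_layer_caches(
    num_layers: int,
    cache_shape: Tuple[int, ...],
    dtype: torch.dtype,
    device: Union[str, torch.device],
    layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
):
    """Allocate per-layer key/value buffers, honoring an optional per-layer device map."""
    key_cache, value_cache = [], []
    for idx in range(num_layers):
        # Fall back to the single `device` argument when no per-layer map is given.
        layer_device = layer_device_map[idx] if layer_device_map is not None else device
        key_cache.append(torch.zeros(cache_shape, dtype=dtype, device=layer_device))
        value_cache.append(torch.zeros(cache_shape, dtype=dtype, device=layer_device))
    return key_cache, value_cache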
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 17a234c62b285e..019eb6c27f18cc 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1446,12 +1446,39 @@ def _get_cache(
             # models. May cause trobles with non-text modalities.
             cache_dtype = self.get_output_embeddings().weight.dtype
 
+        def get_layer_device_map(execution_device_map: Optional[dict] = None):
+            if execution_device_map is None or len(execution_device_map) <= 1:
+                return None
+            layer_device_map = {}
+            for layer in execution_device_map:
+                for idx in range(self.config.num_hidden_layers):
+                    if f".{idx}." in f"{layer}.":
+                        layer_device_map[idx] = execution_device_map[layer]
+                        break
+            for idx in range(self.config.num_hidden_layers):
+                if idx not in layer_device_map:
+                    raise RuntimeError(f"layer {idx} has not been mapped to a device.")
+            return layer_device_map
+
+        execution_device_map = None
+        # Taken from dispatch_model from accelerate.
+        # This is needed here if we don't want to make changes in accelerate in order to save execution_device
+        # For offloaded case, we need to get the execution device, not just the device where it is offloaded
+        if hasattr(self, "hf_device_map"):
+            main_device = [d for d in self.hf_device_map.values() if d not in ["cpu", "disk"]][0]
+            execution_device_map = {
+                name: main_device if device in ["cpu", "disk"] else device
+                for name, device in self.hf_device_map.items()
+            }
+        layer_device_map = get_layer_device_map(execution_device_map)
+
         cache_kwargs = {
             "config": self.config.text_config if hasattr(self.config, "text_config") else self.config,
             "max_batch_size": batch_size,
             "max_cache_len": max_cache_len,
             "device": device,
             "dtype": cache_dtype,
+            "layer_device_map": layer_device_map,
         }
         self._cache = cache_cls(**cache_kwargs)
         if requires_cross_attention_cache:
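For context, the mapping logic added to `_get_cache` above can be reproduced on its own: module names from `model.hf_device_map` (with offloaded "cpu"/"disk" entries remapped to the main device) are matched against each decoder layer index via the `".{idx}."` substring. A self-contained sketch, with names chosen here purely for illustration:

from typing import Dict, Optional, Union

import torch


def build_layer_device_map(
    execution_device_map: Optional[Dict[str, Union[str, torch.device, int]]],
    num_hidden_layers: int,
) -> Optional[Dict[int, Union[str, torch.device, int]]]:
    """Derive a {layer_idx: device} map from a module-level device map such as `model.hf_device_map`."""
    if execution_device_map is None or len(execution_device_map) <= 1:
        return None
    layer_device_map = {}
    for module_name, module_device in execution_device_map.items():
        for idx in range(num_hidden_layers):
            # Module names look like "model.layers.0"; the appended "." avoids matching layer 1 against "model.layers.10".
            if f".{idx}." in f"{module_name}.":
                layer_device_map[idx] = module_device
                break
    for idx in range(num_hidden_layers):
        if idx not in layer_device_map:
            raise RuntimeError(f"layer {idx} has not been mapped to a device.")
    return layer_device_map


# Example with a two-layer model split across two GPUs:
# build_layer_device_map({"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": 1, "lm_head": 0}, 2)
# -> {0: 0, 1: 1}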
+ """ + # need to split manually as auto doesn't work well with unbalanced model + device_map = {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": 1, "model.norm": 1, "lm_head": 0} + model = AutoModelForCausalLM.from_pretrained( + "hf-internal-testing/tiny-random-MistralForCausalLM", device_map=device_map + ) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM") + + text = "Hello world" + tokenized_inputs = tokenizer([text], return_tensors="pt") + input_ids = tokenized_inputs.input_ids.to(torch_device) + + generation_kwargs = { + "max_new_tokens": 20, + "return_dict_in_generate": True, # Required to return `past_key_values` + } + + # TODO: We need to raise a warning in case the cache is not set correctly + # with self.assertRaisesRegex(ValueError, "If you are manually initializing the cache"): + # past_key_values = StaticCache( + # config=model.config, batch_size=1, max_cache_len=30, device=torch_device, dtype=model.dtype + # ) + # results = model.generate(input_ids, past_key_values=past_key_values, **generation_kwargs) + + # deduced from the device_map : layer 0 on device 0 and layer 1 on device 1 + layer_device_map = {0: 0, 1: 1} + past_key_values = StaticCache( + config=model.config, + batch_size=1, + max_cache_len=30, + device=torch_device, + dtype=model.dtype, + layer_device_map=layer_device_map, + ) + results = model.generate(input_ids, past_key_values=past_key_values, **generation_kwargs) + + # check device of each layer + key_cache_0 = results.past_key_values.key_cache[0] + value_cache_0 = results.past_key_values.value_cache[0] + self.assertTrue(key_cache_0.device == value_cache_0.device == torch.device(0)) + + key_cache_1 = results.past_key_values.key_cache[1] + value_cache_1 = results.past_key_values.value_cache[1] + self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1)) + @require_torch class TokenHealingTestCase(unittest.TestCase):