From c13fe68061e2e22d8abdddfcc3199fe3fa4c1270 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Mon, 21 Oct 2024 15:22:28 +0100 Subject: [PATCH] beautiful --- torchtune/modules/common_utils.py | 34 +++++++++++++---------- torchtune/modules/model_fusion/_fusion.py | 2 +- torchtune/modules/transformer.py | 14 ++++++---- 3 files changed, 29 insertions(+), 21 deletions(-) diff --git a/torchtune/modules/common_utils.py b/torchtune/modules/common_utils.py index ead3c7ad1e..055252cf72 100644 --- a/torchtune/modules/common_utils.py +++ b/torchtune/modules/common_utils.py @@ -199,18 +199,18 @@ def disable_kv_cache(model: nn.Module) -> Generator[None, None, None]: >>> # now temporarily disable caches >>> with disable_kv_cache(model): >>> print(model.caches_are_setup()) - >>> True + True >>> print(model.caches_are_enabled()) - >>> False + False >>> print(model.layers[0].attn.kv_cache) - >>> # KVCache() + KVCache() >>> # caches are now re-enabled, and their state is untouched >>> print(model.caches_are_setup()) True >>> print(model.caches_are_enabled()) True >>> print(model.layers[0].attn.kv_cache) - >>> KVCache() + KVCache() Args: model (nn.Module): model to disable KV-cacheing for. @@ -219,7 +219,8 @@ def disable_kv_cache(model: nn.Module) -> Generator[None, None, None]: None: Returns control to the caller with KV-caches disabled on the given model. Raises: - ValueError: If the model does not have caches setup. + ValueError: If the model does not have caches setup. Use :func:`~torchtune.modules.TransformerDecoder.setup_caches` to + setup caches first. """ if not model.caches_are_setup(): raise ValueError( @@ -306,6 +307,7 @@ def local_kv_cache( Raises: ValueError: If the model already has caches setup. + You may use :func:`~torchtune.modules.common_utils.delete_kv_caches` to delete existing caches. """ if model.caches_are_setup(): raise ValueError( @@ -340,29 +342,31 @@ def delete_kv_caches(model: nn.Module): >>> dtype=torch.float32, >>> decoder_max_seq_len=1024) >>> print(model.caches_are_setup()) - >>> True + True >>> print(model.caches_are_enabled()) - >>> True + True >>> print(model.layers[0].attn.kv_cache) - >>> KVCache() + KVCache() >>> delete_kv_caches(model) >>> print(model.caches_are_setup()) - >>> False + False >>> print(model.caches_are_enabled()) - >>> False + False >>> print(model.layers[0].attn.kv_cache) - >>> None + None + Args: model (nn.Module): model to enable KV-cacheing for. Raises: - ValueError: if ``delete_kv_caches`` is called on a model which does not have - caches setup. + ValueError: if this function is called on a model which does not have + caches setup. Use :func:`~torchtune.modules.TransformerDecoder.setup_caches` to + setup caches first. """ if not model.caches_are_setup(): raise ValueError( - "You have tried to delete model caches, but `model.caches_are_setup()` " - "is False!" + "You have tried to delete model caches, but model.caches_are_setup() " + "is False! Please setup caches on the model first." ) for module in model.modules(): if hasattr(module, "kv_cache") and callable(module.kv_cache): diff --git a/torchtune/modules/model_fusion/_fusion.py b/torchtune/modules/model_fusion/_fusion.py index 40ede4feec..1a5452daae 100644 --- a/torchtune/modules/model_fusion/_fusion.py +++ b/torchtune/modules/model_fusion/_fusion.py @@ -405,7 +405,7 @@ def caches_are_enabled(self) -> bool: Checks if the key value caches are enabled. Once KV-caches have been setup, the relevant attention modules will be "enabled" and all forward passes will update the caches. This behaviour can be disabled without altering the state of the KV-caches by "disabling" the KV-caches - using ``torchtune.modules.disable_kv_cache``, upon which ``caches_are_enabled`` would return False. + using :func:`~torchtune.modules.common_utils.disable_kv_cache`, upon which ``caches_are_enabled`` would return False. """ return self.decoder.caches_are_enabled() diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py index 910cb8273b..b509fb81e5 100644 --- a/torchtune/modules/transformer.py +++ b/torchtune/modules/transformer.py @@ -410,9 +410,9 @@ def setup_caches( ): """ Sets up key-value attention caches for inference. For each layer in ``self.layers``: - - :class:`~torchtune.modules.TransformerSelfAttentionLayer` will use ``decoder_max_seq_len``. - - :class:`~torchtune.modules.TransformerCrossAttentionLayer` will use ``encoder_max_seq_len``. - - :class:`~torchtune.modules.fusion.FusionLayer` will use both ``decoder_max_seq_len`` and ``encoder_max_seq_len``. + - :class:`~torchtune.modules.TransformerSelfAttentionLayer` will use ``decoder_max_seq_len``. + - :class:`~torchtune.modules.TransformerCrossAttentionLayer` will use ``encoder_max_seq_len``. + - :class:`~torchtune.modules.model_fusion.FusionLayer` will use ``decoder_max_seq_len`` and ``encoder_max_seq_len``. Args: batch_size (int): batch size for the caches. @@ -460,7 +460,7 @@ def caches_are_enabled(self) -> bool: Checks if the key value caches are enabled. Once KV-caches have been setup, the relevant attention modules will be "enabled" and all forward passes will update the caches. This behaviour can be disabled without altering the state of the KV-caches by "disabling" the KV-caches - using ``torchtune.modules.disable_kv_cache``, upon which ``caches_are_enabled`` would return False. + using :func:`torchtune.modules.common_utils.disable_kv_cache`, upon which ``caches_are_enabled`` would return False. """ return self.layers[0].caches_are_enabled() @@ -468,10 +468,14 @@ def reset_caches(self): """ Resets KV-cache buffers on relevant attention modules to zero, and reset cache positions to zero, without deleting or reallocating cache tensors. + + Raises: + RuntimeError: if KV-caches are not setup. Use :func:`~torchtune.modules.TransformerDecoder.setup_caches` to + setup caches first. """ if not self.caches_are_enabled(): raise RuntimeError( - "Key value caches are not setup. Call ``setup_caches()`` first." + "Key value caches are not setup. Call model.setup_caches first." ) for layer in self.layers: