34 changes: 19 additions & 15 deletions torchtune/modules/common_utils.py
@@ -199,18 +199,18 @@ def disable_kv_cache(model: nn.Module) -> Generator[None, None, None]:
>>> # now temporarily disable caches
>>> with disable_kv_cache(model):
>>> print(model.caches_are_setup())
>>> True
True
>>> print(model.caches_are_enabled())
>>> False
False
>>> print(model.layers[0].attn.kv_cache)
>>> # KVCache()
KVCache()
>>> # caches are now re-enabled, and their state is untouched
>>> print(model.caches_are_setup())
True
>>> print(model.caches_are_enabled())
True
>>> print(model.layers[0].attn.kv_cache)
>>> KVCache()
KVCache()

Args:
model (nn.Module): model to disable KV-caching for.
@@ -219,7 +219,8 @@ def disable_kv_cache(model: nn.Module) -> Generator[None, None, None]:
None: Returns control to the caller with KV-caches disabled on the given model.

Raises:
ValueError: If the model does not have caches setup.
ValueError: If the model does not have caches setup. Use :func:`~torchtune.modules.TransformerDecoder.setup_caches` to
setup caches first.
"""
if not model.caches_are_setup():
raise ValueError(
@@ -306,6 +307,7 @@ def local_kv_cache(

Raises:
ValueError: If the model already has caches setup.
You may use :func:`~torchtune.modules.common_utils.delete_kv_caches` to delete existing caches.
"""
if model.caches_are_setup():
raise ValueError(
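For the change above, a minimal usage sketch of ``local_kv_cache`` — assuming its keyword arguments ``batch_size``, ``device``, ``dtype`` and ``decoder_max_seq_len``, and a deliberately tiny ``llama2`` builder model chosen purely for illustration; the context manager sets caches up on entry and deletes them on exit, which is why it raises if caches already exist:

import torch
from torchtune.models.llama2 import llama2  # assumed builder, used only for illustration
from torchtune.modules.common_utils import local_kv_cache

# A deliberately tiny decoder so the sketch runs quickly (assumed builder arguments).
model = llama2(
    vocab_size=32_000, num_layers=2, num_heads=4,
    num_kv_heads=4, embed_dim=64, max_seq_len=1024,
)

# local_kv_cache sets caches up on entry and deletes them again on exit,
# which is why it raises a ValueError if caches are already set up.
with local_kv_cache(
    model,
    batch_size=1,
    device=torch.device("cpu"),
    dtype=torch.float32,
    decoder_max_seq_len=1024,
):
    assert model.caches_are_setup() and model.caches_are_enabled()
    # run generation / decoding steps here
assert not model.caches_are_setup()  # caches were torn down on exit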
@@ -340,29 +342,31 @@ def delete_kv_caches(model: nn.Module):
>>> dtype=torch.float32,
>>> decoder_max_seq_len=1024)
>>> print(model.caches_are_setup())
>>> True
True
>>> print(model.caches_are_enabled())
>>> True
True
>>> print(model.layers[0].attn.kv_cache)
>>> KVCache()
KVCache()
>>> delete_kv_caches(model)
>>> print(model.caches_are_setup())
>>> False
False
>>> print(model.caches_are_enabled())
>>> False
False
>>> print(model.layers[0].attn.kv_cache)
>>> None
None

Args:
model (nn.Module): model to delete KV-caches for.

Raises:
ValueError: if ``delete_kv_caches`` is called on a model which does not have
caches setup.
ValueError: if this function is called on a model which does not have
caches setup. Use :func:`~torchtune.modules.TransformerDecoder.setup_caches` to
setup caches first.
"""
if not model.caches_are_setup():
raise ValueError(
"You have tried to delete model caches, but `model.caches_are_setup()` "
"is False!"
"You have tried to delete model caches, but model.caches_are_setup() "
"is False! Please setup caches on the model first."
)
for module in model.modules():
if hasattr(module, "kv_cache") and callable(module.kv_cache):
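To complement the updated ``delete_kv_caches`` docstring, a sketch of the setup/delete round-trip it describes, using the same illustrative ``llama2`` model (an assumption for the example, not part of this PR):

import torch
from torchtune.models.llama2 import llama2  # assumed builder, used only for illustration
from torchtune.modules.common_utils import delete_kv_caches

model = llama2(
    vocab_size=32_000, num_layers=2, num_heads=4,
    num_kv_heads=4, embed_dim=64, max_seq_len=1024,
)
model.setup_caches(batch_size=1, dtype=torch.float32, decoder_max_seq_len=1024)
assert model.caches_are_setup() and model.caches_are_enabled()

delete_kv_caches(model)                 # frees the cache tensors entirely
assert not model.caches_are_setup()
assert model.layers[0].attn.kv_cache is None
# Calling delete_kv_caches(model) again at this point raises the ValueError documented above.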
2 changes: 1 addition & 1 deletion torchtune/modules/model_fusion/_fusion.py
@@ -405,7 +405,7 @@ def caches_are_enabled(self) -> bool:
Checks if the key value caches are enabled. Once KV-caches have been setup, the relevant
attention modules will be "enabled" and all forward passes will update the caches. This behaviour
can be disabled without altering the state of the KV-caches by "disabling" the KV-caches
using ``torchtune.modules.disable_kv_cache``, upon which ``caches_are_enabled`` would return False.
using :func:`~torchtune.modules.common_utils.disable_kv_cache`, upon which ``caches_are_enabled`` would return False.
"""
return self.decoder.caches_are_enabled()

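The reworded cross-reference points at :func:`~torchtune.modules.common_utils.disable_kv_cache`; a minimal sketch of the setup-vs-enabled distinction it draws, again on an illustrative ``llama2`` model:

import torch
from torchtune.models.llama2 import llama2  # assumed builder, used only for illustration
from torchtune.modules.common_utils import disable_kv_cache

model = llama2(
    vocab_size=32_000, num_layers=2, num_heads=4,
    num_kv_heads=4, embed_dim=64, max_seq_len=1024,
)
model.setup_caches(batch_size=1, dtype=torch.float32, decoder_max_seq_len=1024)

with disable_kv_cache(model):
    # The caches still exist, but forward passes will not update them.
    assert model.caches_are_setup()
    assert not model.caches_are_enabled()
# On exit the caches are re-enabled with their contents left untouched.
assert model.caches_are_enabled()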
14 changes: 9 additions & 5 deletions torchtune/modules/transformer.py
@@ -410,9 +410,9 @@ def setup_caches(
):
"""
Sets up key-value attention caches for inference. For each layer in ``self.layers``:
- :class:`~torchtune.modules.TransformerSelfAttentionLayer` will use ``decoder_max_seq_len``.
- :class:`~torchtune.modules.TransformerCrossAttentionLayer` will use ``encoder_max_seq_len``.
- :class:`~torchtune.modules.fusion.FusionLayer` will use both ``decoder_max_seq_len`` and ``encoder_max_seq_len``.
- :class:`~torchtune.modules.TransformerSelfAttentionLayer` will use ``decoder_max_seq_len``.
- :class:`~torchtune.modules.TransformerCrossAttentionLayer` will use ``encoder_max_seq_len``.
- :class:`~torchtune.modules.model_fusion.FusionLayer` will use ``decoder_max_seq_len`` and ``encoder_max_seq_len``.

Args:
batch_size (int): batch size for the caches.
@@ -460,18 +460,22 @@ def caches_are_enabled(self) -> bool:
Checks if the key value caches are enabled. Once KV-caches have been setup, the relevant
attention modules will be "enabled" and all forward passes will update the caches. This behaviour
can be disabled without altering the state of the KV-caches by "disabling" the KV-caches
using ``torchtune.modules.disable_kv_cache``, upon which ``caches_are_enabled`` would return False.
using :func:`torchtune.modules.common_utils.disable_kv_cache`, upon which ``caches_are_enabled`` would return False.
"""
return self.layers[0].caches_are_enabled()

def reset_caches(self):
"""
Resets KV-cache buffers on relevant attention modules to zero, and resets cache positions to zero,
without deleting or reallocating cache tensors.

Raises:
RuntimeError: if KV-caches are not setup. Use :func:`~torchtune.modules.TransformerDecoder.setup_caches` to
setup caches first.
"""
if not self.caches_are_enabled():
raise RuntimeError(
"Key value caches are not setup. Call ``setup_caches()`` first."
"Key value caches are not setup. Call model.setup_caches first."
)

for layer in self.layers:
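Finally, a short sketch of the ``reset_caches`` behaviour documented above — the ``RuntimeError`` before ``setup_caches``, then an in-place reset afterwards (illustrative ``llama2`` model, assumed builder arguments):

import torch
from torchtune.models.llama2 import llama2  # assumed builder, used only for illustration

model = llama2(
    vocab_size=32_000, num_layers=2, num_heads=4,
    num_kv_heads=4, embed_dim=64, max_seq_len=1024,
)

# Calling reset_caches() before setup_caches() raises the documented RuntimeError.
try:
    model.reset_caches()
except RuntimeError as err:
    print(err)

model.setup_caches(batch_size=1, dtype=torch.float32, decoder_max_seq_len=1024)
# ... run some decoding steps that fill the caches ...
model.reset_caches()  # zeroes cache buffers and positions without reallocating them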