From e7f892b5650d84968f8dded0d083fa730588595c Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Thu, 29 Aug 2024 16:54:39 +0100 Subject: [PATCH 01/22] refactoring setup_caches --- docs/source/api_ref_modules.rst | 1 + recipes/eleuther_eval.py | 4 +- recipes/generate.py | 3 +- .../modules/model_fusion/test_fusion_layer.py | 6 +- .../model_fusion/test_fusion_models.py | 24 ++++++-- tests/torchtune/modules/test_common_utils.py | 15 +++++ .../modules/test_transformer_decoder.py | 3 +- tests/torchtune/utils/test_generation.py | 5 +- torchtune/modules/__init__.py | 4 +- torchtune/modules/attention.py | 7 ++- torchtune/modules/common_utils.py | 45 +++++++++++++- torchtune/modules/kv_cache.py | 23 +++++++- torchtune/modules/model_fusion/_fusion.py | 28 +++++---- torchtune/modules/transformer.py | 58 +++++++------------ 14 files changed, 157 insertions(+), 69 deletions(-) create mode 100644 tests/torchtune/modules/test_common_utils.py diff --git a/docs/source/api_ref_modules.rst b/docs/source/api_ref_modules.rst index c62e338801..2497fdbfe5 100644 --- a/docs/source/api_ref_modules.rst +++ b/docs/source/api_ref_modules.rst @@ -92,6 +92,7 @@ These are utilities that are common to and can be used by all modules. :nosignatures: common_utils.reparametrize_as_dtype_state_dict_post_hook + common_utils.setup_caches Vision Transforms diff --git a/recipes/eleuther_eval.py b/recipes/eleuther_eval.py index 49f1430047..26117b4d44 100644 --- a/recipes/eleuther_eval.py +++ b/recipes/eleuther_eval.py @@ -16,7 +16,7 @@ from torch.nn.utils.rnn import pad_sequence from torchtune import config, utils -from torchtune.modules import TransformerDecoder +from torchtune.modules import setup_caches, TransformerDecoder from torchtune.modules.tokenizers import ModelTokenizer from torchtune.recipe_interfaces import EvalRecipeInterface @@ -149,7 +149,7 @@ def _model_generate( # are not needed for a regular model call, so we just setup here if self.enable_kv_cache: with context.device: - self._model.setup_caches(batch_size=curr_batch_size, dtype=self._dtype) + setup_caches(batch_size=curr_batch_size, dtype=self._dtype) temperature = generation_kwargs.get("temperature", 0.0) do_sample = generation_kwargs.get("do_sample", False) diff --git a/recipes/generate.py b/recipes/generate.py index 8f094dac40..8405aa7a52 100644 --- a/recipes/generate.py +++ b/recipes/generate.py @@ -15,6 +15,7 @@ from torchtune import config, utils from torchtune.config._utils import _get_component_from_path from torchtune.data import ChatFormat, InstructTemplate, Message +from torchtune.modules import setup_caches logger = utils.get_logger("DEBUG") @@ -81,7 +82,7 @@ def _setup_model( # Ensure the cache is setup on the right device if enable_kv_cache: with self._device: - model.setup_caches(batch_size=1, dtype=self._dtype) + setup_caches(batch_size=1, dtype=self._dtype) return model diff --git a/tests/torchtune/modules/model_fusion/test_fusion_layer.py b/tests/torchtune/modules/model_fusion/test_fusion_layer.py index 33258f3d60..74047690ca 100644 --- a/tests/torchtune/modules/model_fusion/test_fusion_layer.py +++ b/tests/torchtune/modules/model_fusion/test_fusion_layer.py @@ -24,7 +24,7 @@ def __init__(self, dim): self.linear = nn.Linear(dim, dim) self.cache_enabled = False - def setup_cache(self, batch_size, dtype): + def setup_cache(self, batch_size, dtype, encoder_max_seq_len, decoder_max_seq_len): self.cache_enabled = True def reset_cache(self): @@ -115,7 +115,9 @@ def test_setup_cache(self, fused_layer): """ Test that the cache methods works as 
expected. """ - fused_layer.setup_cache(2, torch.float32) + fused_layer.setup_cache( + 2, torch.float32, encoder_max_seq_len=1, decoder_max_seq_len=1 + ) assert fused_layer.cache_enabled fused_layer.reset_cache() assert not fused_layer.cache_enabled diff --git a/tests/torchtune/modules/model_fusion/test_fusion_models.py b/tests/torchtune/modules/model_fusion/test_fusion_models.py index 1c579ea6ca..a6af12daef 100644 --- a/tests/torchtune/modules/model_fusion/test_fusion_models.py +++ b/tests/torchtune/modules/model_fusion/test_fusion_models.py @@ -9,6 +9,7 @@ import torch from tests.test_utils import assert_expected, fixed_init_model from torch import nn +from torchtune.modules import setup_caches from torchtune.modules.model_fusion import DeepFusionModel from torchtune.utils.seed import set_seed @@ -18,6 +19,17 @@ def random(): set_seed(1) +class DummyLayer: + def __init__(self): + self.cache_enabled = False + + def setup_cache(self, batch_size, dtype, encoder_max_seq_len, decoder_max_seq_len): + self.cache_enabled = True + + def reset_cache(self): + self.cache_enabled = False + + class DummyModel(nn.Module): def __init__(self, dim, vocab_size): super().__init__() @@ -27,15 +39,15 @@ def __init__(self, dim, vocab_size): self.k = nn.Linear(dim, dim) self.v = nn.Linear(dim, dim) self.output = nn.Linear(dim, vocab_size) - - def setup_caches(self, batch_size, dtype): - self.cache_enabled = True + self.max_seq_len = 2 + self.layers = [DummyLayer()] def caches_are_enabled(self): - return self.cache_enabled + return self.layers[0].cache_enabled def reset_caches(self): - self.cache_enabled = False + for layer in self.layers: + layer.reset_cache() def forward(self, tokens, mask, encoder_input, encoder_mask, input_pos): x = self.embed(tokens) @@ -141,7 +153,7 @@ def test_setup_cache(self, fused_model): """ Test that the cache methods works as expected. """ - fused_model.setup_caches(2, torch.float32) + setup_caches(fused_model.decoder, batch_size=2, dtype=torch.float32) assert fused_model.caches_are_enabled() fused_model.reset_caches() assert not fused_model.caches_are_enabled() diff --git a/tests/torchtune/modules/test_common_utils.py b/tests/torchtune/modules/test_common_utils.py new file mode 100644 index 0000000000..b4ca73c92c --- /dev/null +++ b/tests/torchtune/modules/test_common_utils.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# from torchtune.modules import setup_caches + + +# class TestSetupCaches: +# def test_setup_caches_decoder_model(self): +# pass + +# def test_setup_caches_fusion_model(self): +# pass diff --git a/tests/torchtune/modules/test_transformer_decoder.py b/tests/torchtune/modules/test_transformer_decoder.py index 3cc9ff8c0b..7dbbf9e92d 100644 --- a/tests/torchtune/modules/test_transformer_decoder.py +++ b/tests/torchtune/modules/test_transformer_decoder.py @@ -23,6 +23,7 @@ MultiHeadAttention, RMSNorm, RotaryPositionalEmbeddings, + setup_caches, TanhGate, TransformerCrossAttentionLayer, TransformerDecoder, @@ -301,7 +302,7 @@ def decoder_with_kv_cache_enabled( for p in decoder.parameters(): nn.init.constant_(p, 0.2) decoder.eval() - decoder.setup_caches(batch_size=4, dtype=torch.float32) + setup_caches(decoder, batch_size=4, dtype=torch.float32) return decoder def test_forward( diff --git a/tests/torchtune/utils/test_generation.py b/tests/torchtune/utils/test_generation.py index 15c4a336fa..8a451897e9 100644 --- a/tests/torchtune/utils/test_generation.py +++ b/tests/torchtune/utils/test_generation.py @@ -11,6 +11,7 @@ from torchtune import utils from torchtune.models.llama2 import llama2 +from torchtune.modules import setup_caches from torchtune.utils._generation import sample @@ -30,7 +31,7 @@ def generation_model(self, dtype=torch.float32): max_seq_len=2048, ) fixed_init_model(model) - model.setup_caches(batch_size=1, dtype=dtype) + setup_caches(model, batch_size=1, dtype=dtype) model.eval() return model @@ -59,7 +60,7 @@ def generation_model_batched(self, dtype=torch.float32): max_seq_len=2048, ) fixed_init_model(model) - model.setup_caches(batch_size=2, dtype=dtype) + setup_caches(model, batch_size=2, dtype=dtype) model.eval() return model diff --git a/torchtune/modules/__init__.py b/torchtune/modules/__init__.py index 1c508941f7..c1ec585f0e 100644 --- a/torchtune/modules/__init__.py +++ b/torchtune/modules/__init__.py @@ -5,9 +5,10 @@ # LICENSE file in the root directory of this source tree. from .attention import MultiHeadAttention # noqa -from .common_utils import reparametrize_as_dtype_state_dict_post_hook +from .common_utils import reparametrize_as_dtype_state_dict_post_hook, setup_caches from .feed_forward import FeedForward # noqa from .kv_cache import KVCache # noqa + from .layer_norm import Fp32LayerNorm # noqa from .low_precision import FrozenNF4Linear # noqa from .lr_schedulers import get_cosine_schedule_with_warmup # noqa @@ -29,6 +30,7 @@ "FrozenNF4Linear", "get_cosine_schedule_with_warmup", "KVCache", + "setup_caches", "RotaryPositionalEmbeddings", "RMSNorm", "Fp32LayerNorm", diff --git a/torchtune/modules/attention.py b/torchtune/modules/attention.py index 99ddb17b1c..efbe243975 100644 --- a/torchtune/modules/attention.py +++ b/torchtune/modules/attention.py @@ -136,13 +136,16 @@ def __init__( self.k_norm = k_norm self.pos_embeddings = pos_embeddings - def setup_cache(self, batch_size: int, dtype: torch.dtype) -> None: + def setup_cache( + self, batch_size: int, dtype: torch.dtype, max_seq_len: int + ) -> None: """Setup key value caches for attention calculation. If called after kv_cache is already setup, this will be skipped. Args: batch_size (int): batch size for the caches. dtype (torch.dtype): dtype for the caches. + max_seq_len (int): maximum sequence length model will be run with. 
""" # Don't overwrite user defined kv_cache from init if self.kv_cache is not None: @@ -152,7 +155,7 @@ def setup_cache(self, batch_size: int, dtype: torch.dtype) -> None: else: self.kv_cache = KVCache( batch_size=batch_size, - max_seq_len=self.max_seq_len, + max_seq_len=max_seq_len, num_heads=self.num_heads, head_dim=self.head_dim, dtype=dtype, diff --git a/torchtune/modules/common_utils.py b/torchtune/modules/common_utils.py index 9c588fabba..b4ab3fb5cd 100644 --- a/torchtune/modules/common_utils.py +++ b/torchtune/modules/common_utils.py @@ -4,12 +4,55 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Dict, Tuple +from typing import Any, Dict, Optional, Tuple import torch import torch.nn as nn from torchao.dtypes.nf4tensor import NF4Tensor +from torchtune.modules.transformer import TransformerDecoder + + +def setup_caches( + model: TransformerDecoder, + batch_size: int, + dtype: torch.dtype, + *, + encoder_max_seq_len: Optional[int] = None, + decoder_max_seq_len: Optional[int] = None, +): + """ + Setup static key-value caches for attention calculation for a given ``TransformerDecoder` model. + This function supports cache setup for both decoder, and encoder-decoder models. + + Concretely, all layers which are an instance of :class:`~torchtune.modules.TransformerSelfAttentionLayer` + will use ``decoder_max_seq_len``, and all layers which are an instance + of :class:`~torchtune.modules.TransformerCrossAttentionLayer` will use ``encoder_max_seq_len``. + :class:`~torchtune.modules.model_fusion.FusionLayer` will use both. + + Args: + model (TransformerDecoder): An instance of a ``TransformerDecoder`` model. + batch_size (int): batch size for the caches. + dtype (torch.dtype): dtype for the caches. + encoder_max_seq_len (Optional[int]): maximum cache sequence length for encoder layers. + Default None, in which case ``model.max_seq_len`` is used. + decoder_max_seq_len (Optional[int]): maximum cache sequence length for decoder layers. + Default None, in which case ``model.max_seq_len`` is used. + + """ + encoder_max_seq_len = ( + model.max_seq_len if encoder_max_seq_len is None else encoder_max_seq_len + ) + decoder_max_seq_len = ( + model.max_seq_len if decoder_max_seq_len is None else decoder_max_seq_len + ) + for layer in model.layers: + layer.setup_cache( + batch_size, + dtype, + encoder_max_seq_len=encoder_max_seq_len, + decoder_max_seq_len=decoder_max_seq_len, + ) def reparametrize_as_dtype_state_dict_post_hook( diff --git a/torchtune/modules/kv_cache.py b/torchtune/modules/kv_cache.py index 1ad55a4b8e..4118bedd49 100644 --- a/torchtune/modules/kv_cache.py +++ b/torchtune/modules/kv_cache.py @@ -53,7 +53,7 @@ def update( ) -> Tuple[Tensor, Tensor]: """Update KV cache with the new k_val, v_val and return the updated cache. - Raises an assertion error if ``input_pos`` is longer than the maximum sequence length. + Raises an assertion error Args: input_pos (Tensor): Current position tensor with shape [S] @@ -62,12 +62,29 @@ def update( Returns: Tuple[Tensor, Tensor]: Updated KV cache with key first + + Raises: + ValueError: if the sequence length of ``input_pos`` is longer than the maximum sequence length. + ValueError: if the batch size of the new key (or value) tensor is greater than the batch size + used during cache setup. 
""" - assert input_pos.shape[0] == k_val.shape[2] - self.size = input_pos.max().item() + 1 + if input_pos.shape[0] != k_val.shape[2]: + raise ValueError( + f"The current cache has been setup with a sequence length of {k_val.shape[2]}" + f", but found cache positions with sequence length {input_pos.shape[0]}!" + ) + + if k_val.shape[0] > self.k_cache.shape[0]: + raise ValueError( + f"The current cache has been setup with a batch size of {self.k_cache.shape[0]}" + f", but found new key tensors with batch size {k_val.shape[0]}!" + ) + + self.size = input_pos.max().item() + 1 k_out = self.k_cache v_out = self.v_cache + k_out[:, :, input_pos] = k_val v_out[:, :, input_pos] = v_val diff --git a/torchtune/modules/model_fusion/_fusion.py b/torchtune/modules/model_fusion/_fusion.py index 689d823dec..fbe9ed61a1 100644 --- a/torchtune/modules/model_fusion/_fusion.py +++ b/torchtune/modules/model_fusion/_fusion.py @@ -87,15 +87,28 @@ def _load_state_dict_hook(self, state_dict, *args, **kwargs): state_dict[new_key] = state_dict[key] del state_dict[key] - def setup_cache(self, batch_size: int, dtype: torch.dtype) -> None: + def setup_cache( + self, + batch_size: int, + dtype: torch.dtype, + *, + encoder_max_seq_len: int, + decoder_max_seq_len: int, + ) -> None: """Setup key value cache for both layers. Args: batch_size (int): batch size for the caches. dtype (torch.dtype): dtype for the caches. + encoder_max_seq_len (int): maximum cache sequence length. + decoder_max_seq_len (int): this parameter is ignored in this layer. """ - self.layer.setup_cache(batch_size, dtype) - self.fusion_layer.setup_cache(batch_size, dtype) + self.layer.setup_cache( + batch_size, dtype, encoder_max_seq_len, decoder_max_seq_len + ) + self.fusion_layer.setup_cache( + batch_size, dtype, encoder_max_seq_len, decoder_max_seq_len + ) @property def cache_enabled(self) -> bool: @@ -304,15 +317,6 @@ def __init__( self.decoder = decoder self.encoder = encoder - def setup_caches(self, batch_size: int, dtype: torch.dtype) -> None: - """Setup key value caches for attention calculation. - - Args: - batch_size (int): batch size for the caches. - dtype (torch.dtype): dtype for the caches. - """ - self.decoder.setup_caches(batch_size, dtype) - def caches_are_enabled(self) -> bool: """Check if the key value caches are setup.""" return self.decoder.caches_are_enabled() diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py index f9d9515f61..a2adaa050e 100644 --- a/torchtune/modules/transformer.py +++ b/torchtune/modules/transformer.py @@ -43,14 +43,23 @@ def __init__( self.sa_scale = sa_scale or nn.Identity() self.mlp_scale = mlp_scale or nn.Identity() - def setup_cache(self, batch_size: int, dtype: torch.dtype) -> None: + def setup_cache( + self, + batch_size: int, + dtype: torch.dtype, + *, + encoder_max_seq_len: int, + decoder_max_seq_len: int, + ) -> None: """Setup key value caches for attention calculation. Args: batch_size (int): batch size for the caches. dtype (torch.dtype): dtype for the caches. + encoder_max_seq_len (int): this parameter is ignored in this layer. + decoder_max_seq_len (int): maximum cache sequence length. 
""" - self.attn.setup_cache(batch_size, dtype) + self.attn.setup_cache(batch_size, dtype, max_seq_len=decoder_max_seq_len) @property def cache_enabled(self) -> bool: @@ -148,14 +157,23 @@ def __init__( self.ca_scale = ca_scale or nn.Identity() self.mlp_scale = mlp_scale or nn.Identity() - def setup_cache(self, batch_size: int, dtype: torch.dtype) -> None: + def setup_cache( + self, + batch_size: int, + dtype: torch.dtype, + *, + encoder_max_seq_len: int, + decoder_max_seq_len: int, + ) -> None: """Setup key value caches for attention calculation. Args: batch_size (int): batch size for the caches. dtype (torch.dtype): dtype for the caches. + encoder_max_seq_len (int): maximum cache sequence length. + decoder_max_seq_len (int): this parameter is ignored in this layer. """ - self.attn.setup_cache(batch_size, dtype) + self.attn.setup_cache(batch_size, dtype, encoder_max_seq_len) @property def cache_enabled(self) -> bool: @@ -339,22 +357,6 @@ def __init__( self.head_dim = head_dim self.causal_mask = None - def setup_caches(self, batch_size: int, dtype: torch.dtype) -> None: - """Setup key value caches for attention calculation. - - Args: - batch_size (int): batch size for the caches. - dtype (torch.dtype): dtype for the caches. - """ - for layer in self.layers: - layer.setup_cache(batch_size, dtype) - - # causal_mask is used during inference to ensure we're attending - # to the right tokens - self.causal_mask = torch.tril( - torch.ones(self.max_seq_len, self.max_seq_len, dtype=torch.bool) - ) - def caches_are_enabled(self) -> bool: """Check if the key value caches are setup.""" return self.layers[0].cache_enabled @@ -534,22 +536,6 @@ def __init__( self.head_dim = head_dim self.causal_mask = None - def setup_caches(self, batch_size: int, dtype: torch.dtype) -> None: - """Setup key value caches for attention calculation. - - Args: - batch_size (int): batch size for the caches. - dtype (torch.dtype): dtype for the caches. 
- """ - for layer in self.layers: - layer.setup_cache(batch_size, dtype) - - # causal_mask is used during inference to ensure we're attending - # to the right tokens - self.causal_mask = torch.tril( - torch.ones(self.max_seq_len, self.max_seq_len, dtype=torch.bool) - ) - def caches_are_enabled(self) -> bool: """Check if the key value caches are setup.""" return self.layers[0].cache_enabled From 895e93729a1bd3b645f3c342000593395e10f5db Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Wed, 4 Sep 2024 20:39:08 +0100 Subject: [PATCH 02/22] merging --- tests/assets/invalid_dummy_config.yaml | 4 +- tests/assets/valid_dummy_config.yaml | 2 +- tests/recipes/test_eleuther_eval.py | 4 +- .../test_full_finetune_single_device.py | 9 +- .../recipes/test_lora_finetune_distributed.py | 2 +- tests/recipes/test_lora_finetune_fsdp2.py | 2 +- .../test_lora_finetune_single_device.py | 9 +- .../test_ppo_full_tunetune_single_device.py | 20 +- tests/recipes/test_qat_distributed.py | 2 +- tests/recipes/utils.py | 6 +- tests/regression_tests/test_llama2_7b.py | 4 +- .../datasets/test_instruct_dataset.py | 16 +- .../models/llama2/scripts/compare_dora.py | 259 ++++++ .../models/llama2/test_lora_llama2.py | 10 +- tests/torchtune/models/phi3/test_lora_phi3.py | 10 +- .../loss/test_ce_chunked_output_loss.py | 50 ++ .../modules/model_fusion/test_fusion_layer.py | 6 +- .../model_fusion/test_fusion_models.py | 24 +- tests/torchtune/modules/peft/test_dora.py | 192 +++++ tests/torchtune/modules/peft/test_lora.py | 8 +- tests/torchtune/modules/peft/test_utils.py | 118 ++- .../modules/test_transformer_decoder.py | 3 +- .../checkpointing/test_checkpointer.py | 766 ++++++++++++++++++ .../checkpointing/test_checkpointer_utils.py | 177 ++++ .../torchtune/training/test_metric_logging.py | 201 +++++ tests/torchtune/training/test_precision.py | 91 +++ tests/torchtune/training/test_profiler.py | 253 ++++++ tests/torchtune/utils/test_distributed.py | 12 +- tests/torchtune/utils/test_generation.py | 5 +- torchtune/models/gemma/transformer.py | 36 +- torchtune/modules/model_fusion/_fusion.py | 30 +- torchtune/modules/transformer.py | 54 +- 32 files changed, 2221 insertions(+), 164 deletions(-) create mode 100644 tests/torchtune/models/llama2/scripts/compare_dora.py create mode 100644 tests/torchtune/modules/loss/test_ce_chunked_output_loss.py create mode 100644 tests/torchtune/modules/peft/test_dora.py create mode 100644 tests/torchtune/training/checkpointing/test_checkpointer.py create mode 100644 tests/torchtune/training/checkpointing/test_checkpointer_utils.py create mode 100644 tests/torchtune/training/test_metric_logging.py create mode 100644 tests/torchtune/training/test_precision.py create mode 100644 tests/torchtune/training/test_profiler.py diff --git a/tests/assets/invalid_dummy_config.yaml b/tests/assets/invalid_dummy_config.yaml index d40b764469..56f0ff26db 100644 --- a/tests/assets/invalid_dummy_config.yaml +++ b/tests/assets/invalid_dummy_config.yaml @@ -1,8 +1,8 @@ test1: - _component_: torchtune.utils.get_dtype + _component_: torchtune.training.get_dtype dtype: fp32 dummy: 3 test2: - _component_: torchtune.utils.get_dtype + _component_: torchtune.training.get_dtype dtype: fp32 dummy: 3 diff --git a/tests/assets/valid_dummy_config.yaml b/tests/assets/valid_dummy_config.yaml index d98e013086..6969ee8292 100644 --- a/tests/assets/valid_dummy_config.yaml +++ b/tests/assets/valid_dummy_config.yaml @@ -1,3 +1,3 @@ test: - _component_: torchtune.utils.get_dtype + _component_: torchtune.training.get_dtype dtype: fp32 diff 
--git a/tests/recipes/test_eleuther_eval.py b/tests/recipes/test_eleuther_eval.py index 8ce747f070..c522805b38 100644 --- a/tests/recipes/test_eleuther_eval.py +++ b/tests/recipes/test_eleuther_eval.py @@ -29,7 +29,7 @@ def test_torchtune_checkpoint_eval_results(self, capsys, monkeypatch, tmpdir): tune run eleuther_eval \ --config eleuther_evaluation \ output_dir={tmpdir} \ - checkpointer=torchtune.utils.FullModelTorchTuneCheckpointer \ + checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \ checkpointer.checkpoint_dir='{ckpt_dir}' \ checkpointer.checkpoint_files=[{ckpt_path}]\ checkpointer.output_dir={tmpdir} \ @@ -90,7 +90,7 @@ def test_eval_recipe_errors_without_lm_eval(self, caplog, monkeypatch, tmpdir): tune run eleuther_eval \ --config eleuther_evalation \ output_dir={tmpdir} \ - checkpointer=torchtune.utils.FullModelTorchTuneCheckpointer \ + checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \ checkpointer.checkpoint_dir='{ckpt_dir}' \ checkpointer.checkpoint_files=[{ckpt_path}]\ checkpointer.output_dir={tmpdir} \ diff --git a/tests/recipes/test_full_finetune_single_device.py b/tests/recipes/test_full_finetune_single_device.py index bbefe3df12..646e53382c 100644 --- a/tests/recipes/test_full_finetune_single_device.py +++ b/tests/recipes/test_full_finetune_single_device.py @@ -46,6 +46,7 @@ def _get_test_config_overrides(self): "optimizer=torch.optim.AdamW", "optimizer.lr=2e-5", "log_every_n_steps=1", + "clip_grad_norm=100", ] + dummy_alpaca_dataset_config() def _fetch_expected_loss_values(self, model_type): @@ -129,7 +130,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch): tune run full_finetune_single_device \ --config llama2/7B_full_low_memory \ output_dir={tmpdir} \ - checkpointer._component_=torchtune.utils.FullModelHFCheckpointer \ + checkpointer._component_=torchtune.training.FullModelHFCheckpointer \ checkpointer.checkpoint_dir='{ckpt_dir}' \ checkpointer.checkpoint_files=[{ckpt_path}]\ checkpointer.output_dir={tmpdir} \ @@ -150,7 +151,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch): tune run full_finetune_single_device \ --config llama2/7B_full_low_memory \ output_dir={tmpdir} \ - checkpointer._component_=torchtune.utils.FullModelHFCheckpointer \ + checkpointer._component_=torchtune.training.FullModelHFCheckpointer \ checkpointer.checkpoint_dir={tmpdir} \ checkpointer.checkpoint_files=[{os.path.join(tmpdir, "hf_model_0001_0.pt")}]\ checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")} @@ -216,7 +217,7 @@ def test_gradient_accumulation(self, tmpdir, monkeypatch): cmd_1 = f""" tune run full_finetune_single_device \ --config llama2/7B_full_low_memory \ - checkpointer._component_=torchtune.utils.FullModelTorchTuneCheckpointer \ + checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \ checkpointer.checkpoint_dir={ckpt_dir} \ checkpointer.checkpoint_files=[{ckpt_path}]\ checkpointer.output_dir={tmpdir} \ @@ -242,7 +243,7 @@ def test_gradient_accumulation(self, tmpdir, monkeypatch): cmd_2 = f""" tune run full_finetune_single_device \ --config llama2/7B_full_low_memory \ - checkpointer._component_=torchtune.utils.FullModelTorchTuneCheckpointer \ + checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \ checkpointer.checkpoint_dir={ckpt_dir} \ checkpointer.checkpoint_files=[{ckpt_path}]\ checkpointer.output_dir={tmpdir} \ diff --git a/tests/recipes/test_lora_finetune_distributed.py b/tests/recipes/test_lora_finetune_distributed.py index d8e268a99c..1828d5332e 100644 
--- a/tests/recipes/test_lora_finetune_distributed.py +++ b/tests/recipes/test_lora_finetune_distributed.py @@ -71,7 +71,7 @@ def test_loss(self, fsdp_sharding_strategy, tmpdir, monkeypatch): tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config llama2/7B_lora \ output_dir={tmpdir} \ - checkpointer=torchtune.utils.FullModelTorchTuneCheckpointer \ + checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \ checkpointer.checkpoint_dir='{ckpt_dir}' \ checkpointer.checkpoint_files=[{ckpt_path}]\ checkpointer.output_dir={tmpdir} \ diff --git a/tests/recipes/test_lora_finetune_fsdp2.py b/tests/recipes/test_lora_finetune_fsdp2.py index 25a65016e3..5293b6cfe4 100644 --- a/tests/recipes/test_lora_finetune_fsdp2.py +++ b/tests/recipes/test_lora_finetune_fsdp2.py @@ -69,7 +69,7 @@ def test_loss(self, tmpdir, monkeypatch): tune run --nnodes 1 --nproc_per_node 2 lora_finetune_fsdp2 --config llama2/7B_lora \ output_dir={tmpdir} \ - checkpointer=torchtune.utils.FullModelTorchTuneCheckpointer \ + checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \ checkpointer.checkpoint_dir='{ckpt_dir}' \ checkpointer.checkpoint_files=[{ckpt_path}]\ checkpointer.output_dir={tmpdir} \ diff --git a/tests/recipes/test_lora_finetune_single_device.py b/tests/recipes/test_lora_finetune_single_device.py index e12dec9fa7..c899e91136 100644 --- a/tests/recipes/test_lora_finetune_single_device.py +++ b/tests/recipes/test_lora_finetune_single_device.py @@ -43,6 +43,7 @@ def _get_test_config_overrides(self, dtype_str: str = "fp32", epochs: int = 2): "optimizer.lr=2e-5", "log_every_n_steps=1", "gradient_accumulation_steps=1", + "clip_grad_norm=100", ] + dummy_alpaca_dataset_config() def _fetch_expected_loss_values(self, model_type): @@ -127,7 +128,7 @@ def test_loss_qlora(self, compile, dtype, tmpdir, monkeypatch): tune run lora_finetune_single_device --config llama2/7B_qlora_single_device \ output_dir={tmpdir} \ - checkpointer=torchtune.utils.FullModelMetaCheckpointer + checkpointer=torchtune.training.FullModelMetaCheckpointer checkpointer.checkpoint_dir='{ckpt_dir}' \ checkpointer.checkpoint_files=[{ckpt_path}]\ checkpointer.output_dir={tmpdir} \ @@ -176,7 +177,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch): tune run lora_finetune_single_device \ --config llama2/7B_lora_single_device \ output_dir={tmpdir} \ - checkpointer=torchtune.utils.FullModelHFCheckpointer \ + checkpointer=torchtune.training.FullModelHFCheckpointer \ checkpointer.checkpoint_dir='{ckpt_dir}' \ checkpointer.checkpoint_files=[{ckpt_path}]\ checkpointer.output_dir={tmpdir} \ @@ -197,7 +198,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch): tune run lora_finetune_single_device \ --config llama2/7B_lora_single_device \ output_dir={tmpdir} \ - checkpointer=torchtune.utils.FullModelHFCheckpointer \ + checkpointer=torchtune.training.FullModelHFCheckpointer \ checkpointer.checkpoint_dir={tmpdir} \ checkpointer.checkpoint_files=[{ckpt_path}]\ checkpointer.adapter_checkpoint={os.path.join(tmpdir, "adapter_0.pt")} @@ -232,7 +233,7 @@ def test_save_and_load_merged_weights(self, tmpdir, monkeypatch): tune run lora_finetune_single_device \ --config llama2/7B_lora_single_device \ output_dir={tmpdir} \ - checkpointer=torchtune.utils.FullModelTorchTuneCheckpointer \ + checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \ checkpointer.checkpoint_dir='{ckpt_dir}' \ checkpointer.checkpoint_files=[{ckpt_path}]\ checkpointer.output_dir={tmpdir} \ diff --git 
a/tests/recipes/test_ppo_full_tunetune_single_device.py b/tests/recipes/test_ppo_full_tunetune_single_device.py index 80163cc5f0..c895924e02 100644 --- a/tests/recipes/test_ppo_full_tunetune_single_device.py +++ b/tests/recipes/test_ppo_full_tunetune_single_device.py @@ -67,7 +67,7 @@ def test_loss(self, tmpdir, monkeypatch): tune run ppo_full_finetune_single_device \ --config mistral/7B_full_ppo_low_memory \ output_dir={tmpdir} \ - checkpointer._component_=torchtune.utils.FullModelHFCheckpointer \ + checkpointer._component_=torchtune.training.FullModelHFCheckpointer \ checkpointer.checkpoint_dir='{ckpt_dir}' \ checkpointer.checkpoint_files=[{policy_ckpt_path}]\ checkpointer.output_dir={policy_tmpdir} \ @@ -83,7 +83,7 @@ def test_loss(self, tmpdir, monkeypatch): reward_checkpointer.checkpoint_dir='{ckpt_dir}' \ reward_checkpointer.checkpoint_files=[{reward_ckpt_path}]\ - metric_logger._component_=torchtune.utils.metric_logging.DiskLogger \ + metric_logger._component_=torchtune.training.metric_logging.DiskLogger \ metric_logger.filename={log_file} \ """.split() @@ -157,7 +157,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch): tune run ppo_full_finetune_single_device \ --config mistral/7B_full_ppo_low_memory \ output_dir={tmpdir} \ - checkpointer._component_=torchtune.utils.FullModelHFCheckpointer \ + checkpointer._component_=torchtune.training.FullModelHFCheckpointer \ checkpointer.checkpoint_dir='{ckpt_dir}' \ checkpointer.checkpoint_files=[{policy_ckpt_path}]\ checkpointer.output_dir={policy_tmpdir} \ @@ -173,7 +173,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch): reward_checkpointer.checkpoint_dir='{ckpt_dir}' \ reward_checkpointer.checkpoint_files=[{reward_ckpt_path}]\ - metric_logger._component_=torchtune.utils.metric_logging.DiskLogger \ + metric_logger._component_=torchtune.training.metric_logging.DiskLogger \ metric_logger.filename={log_file} \ """.split() @@ -209,7 +209,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch): tune run ppo_full_finetune_single_device \ --config mistral/7B_full_ppo_low_memory \ output_dir={tmpdir} \ - checkpointer._component_=torchtune.utils.FullModelHFCheckpointer \ + checkpointer._component_=torchtune.training.FullModelHFCheckpointer \ checkpointer.checkpoint_dir='{policy_tmpdir}' \ checkpointer.checkpoint_files=[{os.path.join(policy_tmpdir, "hf_model_0001_0.pt")}]\ checkpointer.recipe_checkpoint={os.path.join(policy_tmpdir, "recipe_state.pt")}\ @@ -227,7 +227,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch): reward_checkpointer.checkpoint_files=[{reward_ckpt_path}]\ resume_from_checkpoint=True \ - metric_logger._component_=torchtune.utils.metric_logging.DiskLogger \ + metric_logger._component_=torchtune.training.metric_logging.DiskLogger \ metric_logger.filename={resumed_log_file} \ """.split() @@ -277,7 +277,7 @@ def test_training_state_on_resume_with_optimizer_in_bwd(self, tmpdir, monkeypatc tune run ppo_full_finetune_single_device \ --config mistral/7B_full_ppo_low_memory \ output_dir={tmpdir} \ - checkpointer._component_=torchtune.utils.FullModelHFCheckpointer \ + checkpointer._component_=torchtune.training.FullModelHFCheckpointer \ checkpointer.checkpoint_dir='{ckpt_dir}' \ checkpointer.checkpoint_files=[{policy_ckpt_path}]\ checkpointer.output_dir={policy_tmpdir} \ @@ -293,7 +293,7 @@ def test_training_state_on_resume_with_optimizer_in_bwd(self, tmpdir, monkeypatc reward_checkpointer.checkpoint_dir='{ckpt_dir}' \ reward_checkpointer.checkpoint_files=[{reward_ckpt_path}]\ - 
metric_logger._component_=torchtune.utils.metric_logging.DiskLogger \ + metric_logger._component_=torchtune.training.metric_logging.DiskLogger \ metric_logger.filename={log_file} \ optimizer_in_bwd=True @@ -331,7 +331,7 @@ def test_training_state_on_resume_with_optimizer_in_bwd(self, tmpdir, monkeypatc tune run ppo_full_finetune_single_device \ --config mistral/7B_full_ppo_low_memory \ output_dir={tmpdir} \ - checkpointer._component_=torchtune.utils.FullModelHFCheckpointer \ + checkpointer._component_=torchtune.training.FullModelHFCheckpointer \ checkpointer.checkpoint_dir='{policy_tmpdir}' \ checkpointer.checkpoint_files=[{os.path.join(policy_tmpdir, "hf_model_0001_0.pt")}]\ checkpointer.recipe_checkpoint={os.path.join(policy_tmpdir, "recipe_state.pt")}\ @@ -349,7 +349,7 @@ def test_training_state_on_resume_with_optimizer_in_bwd(self, tmpdir, monkeypatc reward_checkpointer.checkpoint_files=[{reward_ckpt_path}]\ resume_from_checkpoint=True \ - metric_logger._component_=torchtune.utils.metric_logging.DiskLogger \ + metric_logger._component_=torchtune.training.metric_logging.DiskLogger \ metric_logger.filename={resumed_log_file} \ optimizer_in_bwd=True diff --git a/tests/recipes/test_qat_distributed.py b/tests/recipes/test_qat_distributed.py index 6eac534a22..5d4d7069f1 100644 --- a/tests/recipes/test_qat_distributed.py +++ b/tests/recipes/test_qat_distributed.py @@ -95,5 +95,5 @@ def test_loss(self, config, model_type, ckpt_type, tmpdir, monkeypatch): loss_values = get_loss_values_from_metric_logger(log_file) expected_loss_values = self._fetch_expected_loss_values(model_type) torch.testing.assert_close( - loss_values, expected_loss_values, rtol=1e-4, atol=1e-4 + loss_values, expected_loss_values, rtol=1e-3, atol=1e-3 ) diff --git a/tests/recipes/utils.py b/tests/recipes/utils.py index ae47b813a5..dea31e6221 100644 --- a/tests/recipes/utils.py +++ b/tests/recipes/utils.py @@ -13,9 +13,9 @@ from torch.utils.data import Dataset CKPT_COMPONENT_MAP = { - "tune": "torchtune.utils.FullModelTorchTuneCheckpointer", - "meta": "torchtune.utils.FullModelMetaCheckpointer", - "hf": "torchtune.utils.FullModelHFCheckpointer", + "tune": "torchtune.training.FullModelTorchTuneCheckpointer", + "meta": "torchtune.training.FullModelMetaCheckpointer", + "hf": "torchtune.training.FullModelHFCheckpointer", } diff --git a/tests/regression_tests/test_llama2_7b.py b/tests/regression_tests/test_llama2_7b.py index a7569f36c1..cba0a39032 100644 --- a/tests/regression_tests/test_llama2_7b.py +++ b/tests/regression_tests/test_llama2_7b.py @@ -37,7 +37,7 @@ def test_finetune_and_eval(self, tmpdir, capsys, monkeypatch): tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config llama2/7B_lora \ output_dir={tmpdir} \ - checkpointer=torchtune.utils.FullModelTorchTuneCheckpointer + checkpointer=torchtune.training.FullModelTorchTuneCheckpointer checkpointer.checkpoint_dir='{ckpt_dir}' \ checkpointer.checkpoint_files=[{ckpt_path}]\ checkpointer.output_dir={tmpdir} \ @@ -52,7 +52,7 @@ def test_finetune_and_eval(self, tmpdir, capsys, monkeypatch): tune run eleuther_eval \ --config eleuther_evaluation \ output_dir={tmpdir} \ - checkpointer=torchtune.utils.FullModelTorchTuneCheckpointer \ + checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \ checkpointer.checkpoint_dir='{tmpdir}' \ checkpointer.checkpoint_files=[torchtune_model_0.pt] \ checkpointer.output_dir={tmpdir} \ diff --git a/tests/torchtune/datasets/test_instruct_dataset.py b/tests/torchtune/datasets/test_instruct_dataset.py index 
352f173df5..c734bec885 100644 --- a/tests/torchtune/datasets/test_instruct_dataset.py +++ b/tests/torchtune/datasets/test_instruct_dataset.py @@ -62,16 +62,18 @@ def test_get_item(self, train_on_input): assert label == expected_labels[i] expected_tokenized_prompts = [ - [0, 4, 4, 2, 2, 2, 7, 2, 2, 5, 2, 2, 6, -1], - [0, 2, 2, 8, 2, 15, 8, 3, 15, 3, 4, 9, 3, 15, -1], + [0, 6, 4, 6, 4, 4, 2, 2, 2, 7, 2, 2, 5, 2, 2, 6, -1], + [0, 6, 4, 6, 2, 2, 8, 2, 15, 8, 3, 15, 3, 4, 9, 3, 15, -1], ] - prompt_lengths = (7, 6) + prompt_lengths = (10, 9) expected_labels = [ [CROSS_ENTROPY_IGNORE_IDX] * prompt_lengths[0] + [2, 2, 5, 2, 2, 6, -1], [CROSS_ENTROPY_IGNORE_IDX] * prompt_lengths[1] + [8, 3, 15, 3, 4, 9, 3, 15, -1], ] + system_prompt = "follow this prompt" + dataset = instruct_dataset( tokenizer=DummyTokenizer(), source="json", @@ -79,13 +81,19 @@ def test_get_item(self, train_on_input): data_files=str(ASSETS / "instruct_tiny.json"), column_map={"input": "instruction", "output": "response"}, split="train", + new_system_prompt=system_prompt, ) + system_prompt_offset = len(system_prompt.split(" ")) + 1 # +1 for bos token + assert len(dataset) == 2 for i in range(len(dataset)): prompt, label = dataset[i]["tokens"], dataset[i]["labels"] assert prompt == expected_tokenized_prompts[i] if train_on_input: - assert label == expected_tokenized_prompts[i] + assert ( + label[system_prompt_offset:] + == expected_tokenized_prompts[i][system_prompt_offset:] + ) else: assert label == expected_labels[i] diff --git a/tests/torchtune/models/llama2/scripts/compare_dora.py b/tests/torchtune/models/llama2/scripts/compare_dora.py new file mode 100644 index 0000000000..3ff8c760c9 --- /dev/null +++ b/tests/torchtune/models/llama2/scripts/compare_dora.py @@ -0,0 +1,259 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import math + +import pytest + +import torch +import torch.nn.functional as F +from torch import nn +from torchao.dtypes.nf4tensor import linear_nf4, to_nf4 +from torchtune import utils +from torchtune.modules.peft import ( + DoRALinear, + get_merged_lora_ckpt, + load_dora_magnitudes, + LoRALinear, +) +from torchtune.utils.seed import set_seed + + +def compare_dora(self, dtype, use_bias, quantize_base): + dropout = 0.0 + batch_size = 2 + in_dim = 64 + out_dim = 128 + rank = 4 + alpha = 1.0 + use_bias = False + quantize_base = False + dtype = torch.bfloat16 + + constructor_kwargs = { + "in_dim": in_dim, + "out_dim": out_dim, + "rank": rank, + "alpha": alpha, + "dropout": dropout, + "use_bias": use_bias, + "quantize_base": quantize_base, + } + + # this combo is not supported yet + if use_bias: + with pytest.raises( + NotImplementedError, match="DoRALinear does not support using bias" + ): + DoRALinear(**constructor_kwargs) + return + + # build our DoRA module and a reference module for comparison + with utils.set_default_dtype(dtype): + module = DoRALinear(**constructor_kwargs) + ref = _DoraReference(dtype=dtype, **constructor_kwargs) + lora_module = LoRALinear(**constructor_kwargs) + + # make the initial parameters equal + state_dict = ref.state_dict() + lora_state_dict = ref.state_dict() + if quantize_base: + state_dict["weight"] = state_dict["weight"].to(torch.float32) + lora_state_dict["weight"] = lora_state_dict["weight"].to(torch.float32) + state_dict["magnitude"] = state_dict.pop("lora_magnitude") + lora_state_dict.pop("lora_magnitude") + module.load_state_dict(state_dict) + lora_module.load_state_dict(lora_state_dict) + + # freeze the base params + module.weight.requires_grad_(False) + lora_module.weight.requires_grad_(False) + ref.weight.requires_grad_(False) + if use_bias: + module.bias.requires_grad_(False) + module.lora_module.requires_grad_(False) + ref.bias.requires_grad_(False) + + @torch.no_grad + def _dora_is_the_same_as_lora(): + module.eval() + lora_module.eval() + x = torch.randn(batch_size, in_dim, dtype=dtype) + lora_out = lora_module(x) + dora_out = module(x) + return torch.equal(lora_out, dora_out) + + # DoRA initializes the magnitude vector (after the base params are loaded) + # such that its outputs are initially identical to standard LoRA's outputs. + # Verify that this is true. 
+ assert not _dora_is_the_same_as_lora() + module.initialize_dora_magnitude() + load_dora_magnitudes(module) + assert _dora_is_the_same_as_lora() + + def _compare_params(): + assert torch.allclose( + module.weight.to(torch.float32), ref.weight.to(torch.float32) + ) + assert torch.allclose(module.lora_a.weight, ref.lora_a.weight) + assert torch.allclose(module.lora_b.weight, ref.lora_b.weight) + assert torch.allclose(module.magnitude, ref.lora_magnitude) + + # verify that the param values match the reference + ref.initialize_dora() + _compare_params() + + # compare a single training step to the reference + module.train() + ref.train() + opt = torch.optim.Adam(module.parameters()) + opt_ref = torch.optim.Adam(ref.parameters()) + opt.zero_grad() + opt_ref.zero_grad() + # x = torch.randn(batch_size, in_dim, dtype=dtype) + x = torch.randn(batch_size, 32, in_dim) + y = torch.randn(batch_size, out_dim) + torch.manual_seed(0) + y1 = module(x.detach()) + torch.manual_seed(0) + y2 = ref(x.detach()) + F.mse_loss(y1.to(torch.float32), y.detach()).backward() + F.mse_loss(y2.to(torch.float32), y.detach()).backward() + assert torch.allclose(y1, y2) + assert torch.allclose(module.magnitude.grad, ref.lora_magnitude.grad) + assert torch.allclose(module.lora_a.weight.grad, ref.lora_a.weight.grad) + assert torch.allclose(module.lora_b.weight.grad, ref.lora_b.weight.grad) + opt.step() + opt_ref.step() + _compare_params() + + # verify that the merged and unmerged DoRA modules have nearly identical outputs + state_dict = get_merged_lora_ckpt(_Wrapper(module).state_dict(), rank, alpha) + merged = _Wrapper(nn.Linear(in_dim, out_dim, bias=use_bias, dtype=dtype)) + merged.load_state_dict(state_dict) + merged = merged.layer + module.eval() + merged.eval() + with torch.no_grad(): + x = torch.randn(batch_size, in_dim, dtype=dtype) + y1 = module(x) + y2 = merged(x) + mse = F.mse_loss(y1.float(), y2.float()) + assert mse < (1e-8 if dtype == torch.float32 else 1e-2) + + +class _Wrapper(nn.Module): + """ + For testing the merged checkpoint which requires that the LoRA layer has a parent. + """ + + def __init__(self, layer): + super().__init__() + self.layer = layer + + def forward(self, x): + return self.layer(x) + + +class _DoraReference(nn.Module): + """ + DoRA linear layer reference. 
+ + Paper: https://arxiv.org/abs/2402.09353 + + Based on the code from: + https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/layer.py + https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/dora.py + + For more info, see the discussion here: + https://github.com/huggingface/peft/pull/1474 + """ + + def __init__( + self, + dtype: torch.dtype, + in_dim: int, + out_dim: int, + rank: int, + alpha: float, + dropout: float = 0.0, + use_bias: bool = False, + quantize_base: bool = False, + use_dora: bool = True, + ): + super().__init__() + self.use_bias = use_bias + self.quantize_base = quantize_base + self.use_dora = use_dora + + linear = nn.Linear( + in_features=in_dim, out_features=out_dim, bias=use_bias, dtype=dtype + ) + weight = linear.weight if not quantize_base else to_nf4(linear.weight) + bias = None + if use_bias: + if quantize_base: + raise NotImplementedError() + bias = linear.bias + self.register_parameter("weight", nn.Parameter(weight)) + self.register_parameter( + "bias", nn.Parameter(bias) if bias is not None else None + ) + + self.lora_a = nn.Linear(in_dim, rank, bias=False, dtype=dtype) + self.lora_b = nn.Linear(rank, out_dim, bias=False, dtype=dtype) + nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5)) + nn.init.zeros_(self.lora_b.weight) + self.scaling = alpha / rank + if use_dora: + self.lora_magnitude = nn.Parameter(torch.randn(out_dim, dtype=dtype)) + self.dropout = nn.Dropout(p=dropout) + + def initialize_dora(self): + weight = self.weight.to(self.lora_a.weight.dtype) + lora_weight = self.lora_b.weight @ self.lora_a.weight + weight_norm = self._get_weight_norm(weight, lora_weight) + self.lora_magnitude = nn.Parameter(weight_norm, requires_grad=True) + + def forward(self, x): + result = self._base_forward(x) + torch_result_dtype = result.dtype + x = x.to(self.lora_a.weight.dtype) + if not self.use_dora: + result = result + self.lora_b(self.lora_a(self.dropout(x))) * self.scaling + else: + x = self.dropout(x) + result = result + self._dora_forward(x) + result = result.to(torch_result_dtype) + print("result mean", result.mean()) + return result + + def _base_forward(self, x): + if self.quantize_base: + return linear_nf4(input=x, weight=self.weight) + return F.linear(x, self.weight, self.bias) + + def _dora_forward(self, x): + lora_result = self.lora_b(self.lora_a(x)) + x_eye = torch.eye( + self.lora_a.weight.shape[1], device=self.lora_a.weight.device, dtype=x.dtype + ) + lora_weight = self.lora_b(self.lora_a(x_eye)).T + + magnitude = self.lora_magnitude + weight = self.weight.to(x.dtype) + weight_norm = self._get_weight_norm(weight, lora_weight.detach()) + weight_norm = weight_norm.detach() + mag_norm_scale = (magnitude / weight_norm).view(1, -1) + result_dora = (mag_norm_scale - 1) * ( + F.linear(x, weight) + ) + mag_norm_scale * lora_result * self.scaling + return result_dora + + def _get_weight_norm(self, weight, lora_weight): + weight = weight + self.scaling * lora_weight + weight_norm = torch.linalg.norm(weight, dim=1).to(weight.dtype) + return weight_norm diff --git a/tests/torchtune/models/llama2/test_lora_llama2.py b/tests/torchtune/models/llama2/test_lora_llama2.py index 3bbd5b10ee..10ad0ce8dd 100644 --- a/tests/torchtune/models/llama2/test_lora_llama2.py +++ b/tests/torchtune/models/llama2/test_lora_llama2.py @@ -12,7 +12,7 @@ from tests.test_utils import assert_expected, fixed_init_model from torch import nn from torchao.dtypes.nf4tensor import NF4Tensor -from torchtune import utils +from torchtune import training from 
torchtune.models.llama2 import llama2, lora_llama2 from torchtune.models.llama2._component_builders import lora_llama2_self_attention from torchtune.modules.low_precision import FrozenNF4Linear @@ -225,7 +225,7 @@ def test_qlora_linear_quantize_base_weights(self): @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) def test_qlora_llama2_parity(self, dtype, inputs): - with utils.set_default_dtype(dtype): + with training.set_default_dtype(dtype): model_ref = self.get_lora_llama2( lora_modules=["q_proj", "v_proj", "k_proj", "output_proj"], apply_lora_to_mlp=True, @@ -255,7 +255,7 @@ def test_qlora_llama2_parity(self, dtype, inputs): @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) def test_qlora_llama2_state_dict(self, dtype): - with utils.set_default_dtype(dtype): + with training.set_default_dtype(dtype): model_ref = self.get_lora_llama2( lora_modules=["q_proj", "v_proj", "k_proj", "output_proj"], apply_lora_to_mlp=True, @@ -291,7 +291,7 @@ def test_qlora_llama2_state_dict(self, dtype): @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) def test_qlora_llama2_merged_state_dict(self, dtype): - with utils.set_default_dtype(dtype): + with training.set_default_dtype(dtype): qlora = self.get_lora_llama2( lora_modules=["q_proj", "v_proj", "k_proj", "output_proj"], apply_lora_to_mlp=True, @@ -312,7 +312,7 @@ def test_qlora_llama2_merged_state_dict(self, dtype): assert v.dtype == dtype # Ensure checkpoint can be loaded into non-LoRA model - with utils.set_default_dtype(dtype): + with training.set_default_dtype(dtype): llama2 = self.get_ref_llama2(vocab_size=50, embed_dim=512) llama2.load_state_dict(merged_ckpt) diff --git a/tests/torchtune/models/phi3/test_lora_phi3.py b/tests/torchtune/models/phi3/test_lora_phi3.py index d3c1cf1c44..9d0f9eb7e6 100644 --- a/tests/torchtune/models/phi3/test_lora_phi3.py +++ b/tests/torchtune/models/phi3/test_lora_phi3.py @@ -12,7 +12,7 @@ from tests.test_utils import assert_expected, fixed_init_model from torch import nn from torchao.dtypes.nf4tensor import NF4Tensor -from torchtune import utils +from torchtune import training from torchtune.models.phi3 import lora_phi3, phi3 from torchtune.models.phi3._component_builders import lora_phi3_self_attention from torchtune.modules.peft import get_merged_lora_ckpt, LoRALinear @@ -205,7 +205,7 @@ def test_lora_linear_quantize_base(self): @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) def test_qlora_phi3_parity(self, dtype, inputs): - with utils.set_default_dtype(dtype): + with training.set_default_dtype(dtype): model_ref = self.get_lora_phi3( lora_modules=["q_proj", "v_proj", "k_proj", "output_proj"], apply_lora_to_mlp=True, @@ -235,7 +235,7 @@ def test_qlora_phi3_parity(self, dtype, inputs): @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) def test_qlora_phi3_state_dict(self, dtype): - with utils.set_default_dtype(dtype): + with training.set_default_dtype(dtype): model_ref = self.get_lora_phi3( lora_modules=["q_proj", "v_proj", "k_proj", "output_proj"], apply_lora_to_mlp=True, @@ -271,7 +271,7 @@ def test_qlora_phi3_state_dict(self, dtype): @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) def test_qlora_phi3_merged_state_dict(self, dtype): - with utils.set_default_dtype(dtype): + with training.set_default_dtype(dtype): qlora = self.get_lora_phi3( lora_modules=["q_proj", "v_proj", "k_proj", "output_proj"], apply_lora_to_mlp=True, @@ -295,7 +295,7 @@ def test_qlora_phi3_merged_state_dict(self, dtype): assert v.dtype == 
dtype # Ensure checkpoint can be loaded into non-LoRA model - with utils.set_default_dtype(dtype): + with training.set_default_dtype(dtype): phi3 = self.get_ref_phi3(vocab_size=50, embed_dim=512) phi3.load_state_dict(merged_ckpt) diff --git a/tests/torchtune/modules/loss/test_ce_chunked_output_loss.py b/tests/torchtune/modules/loss/test_ce_chunked_output_loss.py new file mode 100644 index 0000000000..47b5596dd0 --- /dev/null +++ b/tests/torchtune/modules/loss/test_ce_chunked_output_loss.py @@ -0,0 +1,50 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch +from tests.test_utils import assert_expected +from torchtune.modules.loss import CEWithChunkedOutputLoss +from torchtune.utils.seed import set_seed + + +@pytest.fixture(autouse=True) +def random(): + set_seed(42) + + +class TestCEWithChunkedOutputLoss: + def test_chunked_cross_entropy_loss(self): + # Create a sample input and label + ignore_index = -100 + batch_size = 3 + num_tokens = 50 + vocab_size = 50 + logits = torch.randn(batch_size, num_tokens, vocab_size, dtype=torch.bfloat16) + labels = torch.randint( + 0, vocab_size, (batch_size, num_tokens), dtype=torch.long + ) + + # add random ignore index to random tokens in the label + random_indices = torch.randint(0, num_tokens, (batch_size, num_tokens)) + labels[random_indices < num_tokens // 5] = ignore_index + + # chunked CE + ce_loss = CEWithChunkedOutputLoss( + num_output_chunks=8, ignore_index=ignore_index + ) + logits_chunks = logits.chunk(ce_loss.num_output_chunks, dim=1) + chunked_loss = ce_loss(logits_chunks, labels) + + # vanilla CE + logits = logits.reshape(-1, logits.size(-1)) + labels = labels.reshape(-1) + standard_loss = torch.nn.functional.cross_entropy( + logits.float(), labels, reduction="mean", ignore_index=ignore_index + ) + + # Assert + assert_expected(chunked_loss, standard_loss, rtol=1e-2, atol=1e-2) diff --git a/tests/torchtune/modules/model_fusion/test_fusion_layer.py b/tests/torchtune/modules/model_fusion/test_fusion_layer.py index 74047690ca..33258f3d60 100644 --- a/tests/torchtune/modules/model_fusion/test_fusion_layer.py +++ b/tests/torchtune/modules/model_fusion/test_fusion_layer.py @@ -24,7 +24,7 @@ def __init__(self, dim): self.linear = nn.Linear(dim, dim) self.cache_enabled = False - def setup_cache(self, batch_size, dtype, encoder_max_seq_len, decoder_max_seq_len): + def setup_cache(self, batch_size, dtype): self.cache_enabled = True def reset_cache(self): @@ -115,9 +115,7 @@ def test_setup_cache(self, fused_layer): """ Test that the cache methods works as expected. 
""" - fused_layer.setup_cache( - 2, torch.float32, encoder_max_seq_len=1, decoder_max_seq_len=1 - ) + fused_layer.setup_cache(2, torch.float32) assert fused_layer.cache_enabled fused_layer.reset_cache() assert not fused_layer.cache_enabled diff --git a/tests/torchtune/modules/model_fusion/test_fusion_models.py b/tests/torchtune/modules/model_fusion/test_fusion_models.py index a6af12daef..1c579ea6ca 100644 --- a/tests/torchtune/modules/model_fusion/test_fusion_models.py +++ b/tests/torchtune/modules/model_fusion/test_fusion_models.py @@ -9,7 +9,6 @@ import torch from tests.test_utils import assert_expected, fixed_init_model from torch import nn -from torchtune.modules import setup_caches from torchtune.modules.model_fusion import DeepFusionModel from torchtune.utils.seed import set_seed @@ -19,17 +18,6 @@ def random(): set_seed(1) -class DummyLayer: - def __init__(self): - self.cache_enabled = False - - def setup_cache(self, batch_size, dtype, encoder_max_seq_len, decoder_max_seq_len): - self.cache_enabled = True - - def reset_cache(self): - self.cache_enabled = False - - class DummyModel(nn.Module): def __init__(self, dim, vocab_size): super().__init__() @@ -39,15 +27,15 @@ def __init__(self, dim, vocab_size): self.k = nn.Linear(dim, dim) self.v = nn.Linear(dim, dim) self.output = nn.Linear(dim, vocab_size) - self.max_seq_len = 2 - self.layers = [DummyLayer()] + + def setup_caches(self, batch_size, dtype): + self.cache_enabled = True def caches_are_enabled(self): - return self.layers[0].cache_enabled + return self.cache_enabled def reset_caches(self): - for layer in self.layers: - layer.reset_cache() + self.cache_enabled = False def forward(self, tokens, mask, encoder_input, encoder_mask, input_pos): x = self.embed(tokens) @@ -153,7 +141,7 @@ def test_setup_cache(self, fused_model): """ Test that the cache methods works as expected. """ - setup_caches(fused_model.decoder, batch_size=2, dtype=torch.float32) + fused_model.setup_caches(2, torch.float32) assert fused_model.caches_are_enabled() fused_model.reset_caches() assert not fused_model.caches_are_enabled() diff --git a/tests/torchtune/modules/peft/test_dora.py b/tests/torchtune/modules/peft/test_dora.py new file mode 100644 index 0000000000..4e630533a2 --- /dev/null +++ b/tests/torchtune/modules/peft/test_dora.py @@ -0,0 +1,192 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from functools import partial + +import pytest + +import torch +from tests.test_utils import fixed_init_model +from torch import nn +from torchao.dtypes.nf4tensor import NF4Tensor, to_nf4 +from torchtune import training +from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook +from torchtune.modules.peft import DoRALinear +from torchtune.utils.seed import set_seed + +RANK = 4 +ALPHA = 1.0 +BSZ = 2 +SEQ_LEN = 32 +EXPECTED_VAL = 0.05201 + + +@pytest.fixture(autouse=True) +def random(): + set_seed(16) + + +class TestDoRALinear: + """ + Class for testing our DoRALinear implementation. Expected values are computed + from the reference implementation and calculated in scripts/compare_lora.py. 
+ """ + + @pytest.fixture + def in_dim(self) -> int: + return 64 + + @pytest.fixture + def out_dim(self) -> int: + return 128 + + @pytest.fixture + def inputs(self, in_dim) -> torch.Tensor: + inputs = torch.randn(BSZ, SEQ_LEN, in_dim) + return inputs + + @pytest.fixture + def dora_linear(self, in_dim, out_dim) -> DoRALinear: + dora_linear = DoRALinear( + in_dim=in_dim, + out_dim=out_dim, + rank=RANK, + alpha=ALPHA, + use_bias=False, + ) + + fixed_init_model(dora_linear) + return dora_linear + + @pytest.fixture + def qdora_linear(self, in_dim, out_dim) -> DoRALinear: + with training.set_default_dtype(torch.bfloat16): + qdora_linear = DoRALinear( + in_dim=512, + out_dim=512, + rank=RANK, + alpha=ALPHA, + use_bias=False, + quantize_base=True, + ) + fixed_init_model(qdora_linear, dtype=torch.bfloat16) + return qdora_linear + + def test_forward(self, inputs, dora_linear, out_dim) -> None: + expected = torch.tensor(EXPECTED_VAL) + actual = dora_linear(inputs) + assert actual.shape == (BSZ, SEQ_LEN, out_dim) + torch.testing.assert_close(actual.mean(), expected, atol=1e-4, rtol=1e-6) + + def test_dora_weight_nf4_when_quantized(self, qdora_linear): + assert isinstance(qdora_linear.weight, NF4Tensor) + + def test_bias_raises(self): + with pytest.raises( + NotImplementedError, match="DoRALinear does not support using bias" + ): + DoRALinear( + in_dim=512, + out_dim=512, + rank=RANK, + alpha=ALPHA, + use_bias=True, + quantize_base=False, + ) + + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) + def test_qdora_parity(self, dtype): + with training.set_default_dtype(dtype): + torch.manual_seed(0) + qdora_linear = DoRALinear( + in_dim=512, + out_dim=512, + rank=RANK, + alpha=ALPHA, + use_bias=False, + quantize_base=True, + ) + torch.manual_seed(0) + dora_linear = DoRALinear( + in_dim=512, + out_dim=512, + rank=RANK, + alpha=ALPHA, + use_bias=False, + quantize_base=False, + ) + + # set weight of dora_linear to unquantized weight of qdora_linear and check + # parity. + dora_linear.weight.data = qdora_linear.weight.to(dtype) + + qdora_linear.initialize_dora_magnitude() + dora_linear.initialize_dora_magnitude() + + # Ensure forward passes are the same. This is because DoRALinear should use a special + # quantized linear operator that runs compute in higher prec (but only saves the 4 bit quantized tensor) + # for autograd. + inputs = torch.randn(BSZ, SEQ_LEN, 512, dtype=dtype) + torch.manual_seed(0) + dora_linear_out = dora_linear(inputs) + torch.manual_seed(0) + qdora_linear_out = qdora_linear(inputs) + torch.testing.assert_close(dora_linear_out, qdora_linear_out) + + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) + def test_quantized_state_dict(self, dtype): + with training.set_default_dtype(dtype): + dora_linear = DoRALinear( + in_dim=512, + out_dim=512, + rank=RANK, + alpha=ALPHA, + use_bias=False, + quantize_base=True, + ) + + dora_linear._register_state_dict_hook( + partial( + reparametrize_as_dtype_state_dict_post_hook, + dtype=dtype, + offload_to_cpu=False, + ) + ) + sd = dora_linear.state_dict() + # No nf4 tensors, all have type dtype + for v in sd.values(): + assert v.dtype == dtype + assert not isinstance(v, NF4Tensor) + + # Load back in results in re-quant and creates the same nf4 tensor. + # This also ensures that DoRALinear can load a bf16 state_dict. 
+ dora_linear_reload = DoRALinear( + in_dim=512, + out_dim=512, + rank=RANK, + alpha=ALPHA, + use_bias=False, + quantize_base=True, + ) + # Zero out weight to verify reloading works + dora_linear_reload.weight = nn.Parameter( + to_nf4( + torch.zeros_like( + dora_linear.weight.get_original_weight(), + dtype=dtype, + device=dora_linear.weight.device, + ) + ) + ) + # nf4 tensors should be different + assert not torch.allclose( + dora_linear.weight.quantized_data, dora_linear_reload.weight.quantized_data + ) + # but should be the same after loading + dora_linear_reload.load_state_dict(sd) + assert torch.allclose( + dora_linear.weight.quantized_data, dora_linear_reload.weight.quantized_data + ) diff --git a/tests/torchtune/modules/peft/test_lora.py b/tests/torchtune/modules/peft/test_lora.py index d48f2bea82..ae03948a15 100644 --- a/tests/torchtune/modules/peft/test_lora.py +++ b/tests/torchtune/modules/peft/test_lora.py @@ -12,7 +12,7 @@ from tests.test_utils import fixed_init_model from torch import nn from torchao.dtypes.nf4tensor import NF4Tensor, to_nf4 -from torchtune import utils +from torchtune import training from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook from torchtune.modules.peft import LoRALinear from torchtune.utils.seed import set_seed @@ -62,7 +62,7 @@ def lora_linear(self, in_dim, out_dim) -> LoRALinear: @pytest.fixture def qlora_linear(self, in_dim, out_dim) -> LoRALinear: - with utils.set_default_dtype(torch.bfloat16): + with training.set_default_dtype(torch.bfloat16): qlora_linear = LoRALinear( in_dim=512, out_dim=512, @@ -113,7 +113,7 @@ def test_quantize_with_bias_raises(self): @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) def test_qlora_parity(self, dtype): - with utils.set_default_dtype(dtype): + with training.set_default_dtype(dtype): qlora_linear = LoRALinear( in_dim=512, out_dim=512, @@ -145,7 +145,7 @@ def test_qlora_parity(self, dtype): @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) def test_quantized_state_dict(self, dtype): - with utils.set_default_dtype(dtype): + with training.set_default_dtype(dtype): lora_linear = LoRALinear( in_dim=512, out_dim=512, diff --git a/tests/torchtune/modules/peft/test_utils.py b/tests/torchtune/modules/peft/test_utils.py index f0f09e3584..032cd88ec4 100644 --- a/tests/torchtune/modules/peft/test_utils.py +++ b/tests/torchtune/modules/peft/test_utils.py @@ -14,6 +14,7 @@ from torchtune.modules.peft import ( AdapterModule, disable_adapter, + DoRALinear, get_adapter_params, get_merged_lora_ckpt, LoRALinear, @@ -116,11 +117,32 @@ def lora_llama2_model(): ) +@pytest.fixture +def dora_llama2_model(): + return lora_llama2( + lora_attn_modules=["q_proj", "v_proj"], + vocab_size=VOCAB_SIZE, + num_layers=N_LAYERS, + num_heads=NUM_HEADS, + num_kv_heads=NUM_KV_HEADS, + embed_dim=EMBED_DIM, + max_seq_len=MAX_SEQ_LEN, + lora_rank=4, + lora_alpha=1.0, + use_dora=True, + ) + + @pytest.fixture def lora_llama2_model_all_keys(lora_llama2_model): return lora_llama2_model.state_dict().keys() +@pytest.fixture +def dora_llama2_model_all_keys(dora_llama2_model): + return dora_llama2_model.state_dict().keys() + + @pytest.fixture def lora_llama2_expected_adapter_keys(): keys = [] @@ -136,6 +158,23 @@ def lora_llama2_expected_adapter_keys(): return keys +@pytest.fixture +def dora_llama2_expected_adapter_keys(): + keys = [] + for i in range(N_LAYERS): + keys.extend( + [ + f"layers.{i}.attn.q_proj.lora_a.weight", + f"layers.{i}.attn.q_proj.lora_b.weight", + 
f"layers.{i}.attn.v_proj.lora_a.weight", + f"layers.{i}.attn.v_proj.lora_b.weight", + f"layers.{i}.attn.q_proj.magnitude", + f"layers.{i}.attn.v_proj.magnitude", + ] + ) + return keys + + @pytest.fixture def lora_llama2_expected_base_model_keys(): @@ -156,6 +195,7 @@ class TestPeftUtils: [ ("dummy_adapter_parent_model", "dummy_model_expected_adapter_keys"), ("lora_llama2_model", "lora_llama2_expected_adapter_keys"), + ("dora_llama2_model", "dora_llama2_expected_adapter_keys"), ], ) def test_get_adapter_params(self, request, model_name, expected_keys): @@ -177,6 +217,11 @@ def test_get_adapter_params(self, request, model_name, expected_keys): "lora_llama2_expected_adapter_keys", "lora_llama2_expected_base_model_keys", ), + ( + "dora_llama2_model", + "dora_llama2_expected_adapter_keys", + "lora_llama2_expected_base_model_keys", + ), ], ) def test_set_trainable_params( @@ -305,6 +350,15 @@ def test_set_trainable_params( "lora_llama2_expected_base_model_keys", "", ), + ( + ["q_proj", "v_proj"], + False, + False, + "dora_llama2_model_all_keys", + "dora_llama2_expected_adapter_keys", + "lora_llama2_expected_base_model_keys", + "", + ), ], ) def test_validate_lora_state_dict( @@ -360,6 +414,7 @@ def test_validate_lora_state_dict( ), [ (["k_proj.lora"], [], ["q_proj.lora"], [], "Missing LoRA"), + (["k_proj.lora"], [], ["q_proj.magnitude"], [], "Missing LoRA"), (["output_proj.lora"], [], ["q_proj.lora"], [], "Missing non-LoRA"), ( ["k_proj.lora"], @@ -408,7 +463,7 @@ def test_validate_missing_and_unexpected_for_lora( class TestGetMergedLoRACkpt: - def dummy_model(self): + def dummy_lora_model(self): model = nn.Sequential( LoRALinear(in_dim=4, out_dim=6, rank=RANK, alpha=ALPHA), nn.Linear(6, 3), @@ -422,23 +477,58 @@ def dummy_model(self): model[0].weight = nn.Parameter(3 * torch.ones((6, 4))) return model - def test_get_merged_lora_ckpt(self): - dummy_model = self.dummy_model() + def dummy_dora_model(self): + model = nn.Sequential( + DoRALinear(in_dim=4, out_dim=6, rank=RANK, alpha=ALPHA), + nn.Linear(6, 3), + ) + model[0].lora_a.weight = nn.Parameter( + torch.Tensor([[1, 2, 3, 4], [5, 6, 7, 8]]) + ) + model[0].lora_b.weight = nn.Parameter( + torch.Tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]]) + ) + model[0].magnitude = nn.Parameter(torch.Tensor([1, 2, 3, 4, 5, 6])) + model[0].weight = nn.Parameter(3 * torch.ones((6, 4))) + return model + + @pytest.mark.parametrize("use_dora", [True, False]) + def test_get_merged_lora_ckpt(self, use_dora): + if use_dora: + dummy_model = self.dummy_dora_model() + else: + dummy_model = self.dummy_lora_model() merged_sd = get_merged_lora_ckpt( deepcopy(dummy_model.state_dict()), rank=RANK, alpha=ALPHA ) - expected_merged_weight = torch.Tensor( - [ - [8.5, 10.0, 11.5, 13.0], - [14.5, 18.0, 21.5, 25.0], - [20.5, 26.0, 31.5, 37.0], - [26.5, 34.0, 41.5, 49.0], - [32.5, 42.0, 51.5, 61.0], - [38.5, 50.0, 61.5, 73.0], - ] - ) + if use_dora: + expected_merged_weight = torch.Tensor( + [ + [0.3906, 0.4596, 0.5285, 0.5974], + [0.7202, 0.8940, 1.0671, 1.2417], + [1.0459, 1.3265, 1.6071, 1.8877], + [1.3706, 1.7585, 2.1464, 2.5343], + [1.6948, 2.1902, 2.6856, 3.1810], + [2.0188, 2.6218, 3.2248, 3.8278], + ] + ) + else: + expected_merged_weight = torch.Tensor( + [ + [8.5, 10.0, 11.5, 13.0], + [14.5, 18.0, 21.5, 25.0], + [20.5, 26.0, 31.5, 37.0], + [26.5, 34.0, 41.5, 49.0], + [32.5, 42.0, 51.5, 61.0], + [38.5, 50.0, 61.5, 73.0], + ] + ) + + print("dora", expected_merged_weight) assert merged_sd.keys() == {"0.weight", "1.weight", "1.bias"} - 
torch.testing.assert_close(merged_sd["0.weight"], expected_merged_weight) + torch.testing.assert_close( + merged_sd["0.weight"], expected_merged_weight, atol=1e-3, rtol=1e-3 + ) merged_model = nn.Sequential(nn.Linear(4, 6, bias=False), nn.Linear(6, 3)) merged_model.load_state_dict(merged_sd, strict=True) diff --git a/tests/torchtune/modules/test_transformer_decoder.py b/tests/torchtune/modules/test_transformer_decoder.py index 7dbbf9e92d..3cc9ff8c0b 100644 --- a/tests/torchtune/modules/test_transformer_decoder.py +++ b/tests/torchtune/modules/test_transformer_decoder.py @@ -23,7 +23,6 @@ MultiHeadAttention, RMSNorm, RotaryPositionalEmbeddings, - setup_caches, TanhGate, TransformerCrossAttentionLayer, TransformerDecoder, @@ -302,7 +301,7 @@ def decoder_with_kv_cache_enabled( for p in decoder.parameters(): nn.init.constant_(p, 0.2) decoder.eval() - setup_caches(decoder, batch_size=4, dtype=torch.float32) + decoder.setup_caches(batch_size=4, dtype=torch.float32) return decoder def test_forward( diff --git a/tests/torchtune/training/checkpointing/test_checkpointer.py b/tests/torchtune/training/checkpointing/test_checkpointer.py new file mode 100644 index 0000000000..f2660c239f --- /dev/null +++ b/tests/torchtune/training/checkpointing/test_checkpointer.py @@ -0,0 +1,766 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import json + +from pathlib import Path +from typing import Tuple + +import pytest +import torch +from torch import randn + +from torchtune.models import gemma, llama2, mistral +from torchtune.modules.peft import ( + get_adapter_params, + get_lora_module_names, + validate_missing_and_unexpected_for_lora, +) + +from torchtune.training.checkpointing import FullModelHFCheckpointer +from torchtune.training.checkpointing._utils import ( + ADAPTER_CONFIG, + ADAPTER_KEY, + safe_torch_load, +) +from torchtune.utils.seed import set_seed + +_VOCAB_SIZE = 100 +_DIM = 64 +_HIDDEN_DIM = 256 +_NUM_HEADS = 4 +_NUM_KV_HEADS = 4 +_HEAD_DIM = 16 + + +@pytest.fixture(autouse=True) +def random(): + set_seed(16) + + +class TestHFLlama2FullModelCheckpointer: + @pytest.fixture + def weight_dtype(self): + return torch.float16 + + @pytest.fixture + def state_dict_1(self, weight_dtype): + """ + State dict for a HF format checkpoint. This state dict is "complete" and + can be loaded into a torchtune model once correctly converted. 
+ """ + state_dict = { + "model.embed_tokens.weight": randn(_VOCAB_SIZE, _DIM, dtype=weight_dtype), + "model.layers.0.input_layernorm.weight": randn(_DIM, dtype=weight_dtype), + "model.layers.0.self_attn.q_proj.weight": randn( + _DIM, _DIM, dtype=weight_dtype + ), + "model.layers.0.self_attn.k_proj.weight": randn( + _DIM, _DIM, dtype=weight_dtype + ), + "model.layers.0.self_attn.v_proj.weight": randn( + _DIM, _DIM, dtype=weight_dtype + ), + "model.layers.0.self_attn.o_proj.weight": randn( + _DIM, _DIM, dtype=weight_dtype + ), + "model.layers.0.post_attention_layernorm.weight": randn( + _DIM, dtype=weight_dtype + ), + "model.layers.0.self_attn.rotary_emb.inv_freq": randn( + _DIM, dtype=weight_dtype + ), + "model.layers.0.mlp.gate_proj.weight": randn( + _HIDDEN_DIM, _DIM, dtype=weight_dtype + ), + "model.layers.0.mlp.down_proj.weight": randn( + _DIM, _HIDDEN_DIM, dtype=weight_dtype + ), + "model.layers.0.mlp.up_proj.weight": randn( + _HIDDEN_DIM, _DIM, dtype=weight_dtype + ), + "model.norm.weight": torch.randn(_DIM, dtype=weight_dtype), + "lm_head.weight": torch.randn(_VOCAB_SIZE, _DIM, dtype=weight_dtype), + } + return state_dict + + @pytest.fixture + def state_dict_2(self, weight_dtype): + """ + State dict for a HF format checkpoint. This state dict is "incomplete" and + should be used along with ``state_dict_1`` to test multi-file checkpointing. Specifically + it's missing the embedding, norm and lm_head keys. + """ + state_dict = { + "model.layers.1.input_layernorm.weight": randn(_DIM, dtype=weight_dtype), + "model.layers.1.self_attn.q_proj.weight": randn( + _DIM, _DIM, dtype=weight_dtype + ), + "model.layers.1.self_attn.k_proj.weight": randn( + _DIM, _DIM, dtype=weight_dtype + ), + "model.layers.1.self_attn.v_proj.weight": randn( + _DIM, _DIM, dtype=weight_dtype + ), + "model.layers.1.self_attn.o_proj.weight": randn( + _DIM, _DIM, dtype=weight_dtype + ), + "model.layers.1.post_attention_layernorm.weight": randn( + _DIM, dtype=weight_dtype + ), + "model.layers.1.self_attn.rotary_emb.inv_freq": randn( + _DIM, dtype=weight_dtype + ), + "model.layers.1.mlp.gate_proj.weight": randn( + _HIDDEN_DIM, _DIM, dtype=weight_dtype + ), + "model.layers.1.mlp.down_proj.weight": randn( + _DIM, _HIDDEN_DIM, dtype=weight_dtype + ), + "model.layers.1.mlp.up_proj.weight": randn( + _HIDDEN_DIM, _DIM, dtype=weight_dtype + ), + } + return state_dict + + @pytest.fixture + def llama2_hf_checkpoints(self, tmp_path, state_dict_1, state_dict_2): + """ + Fixture which creates two checkpoint files for the Llama2 model. The + state dict follows the HF_FORMAT for the checkpoint format. + + The state dicts are structured in such a way that both single file and + multiple file checkpoints can be tested. 
+ * The first checkpoint contains layer0 + embed + norm + lm_head keys + and can be tested in isolation + * The second checkpoint contains all layer1 keys and should be tested + in the multiple file checkpoint test along with the first checkpoint + + The model corresponds to the following config: + * vocab_size: 100 + * num_layers: 1 for single checkpoint and 2 for multiple checkpoint + * num_heads: 4 + * num_kv_heads: 4 + * embed_dim: 64 + * max_seq_len: 128 + """ + checkpoint_file_1 = tmp_path / "llama2_hf_checkpoint_01.pt" + checkpoint_file_2 = tmp_path / "llama2_hf_checkpoint_02.pt" + + torch.save(state_dict_1, checkpoint_file_1) + torch.save(state_dict_2, checkpoint_file_2) + + config = { + "hidden_size": 64, + "num_attention_heads": 4, + "num_key_value_heads": 4, + } + config_file = Path.joinpath(tmp_path, "config.json") + with config_file.open("w") as f: + json.dump(config, f) + + return (checkpoint_file_1, checkpoint_file_2) + + @pytest.fixture + def single_file_checkpointer( + self, llama2_hf_checkpoints, tmp_path + ) -> FullModelHFCheckpointer: + checkpoint_file, _ = llama2_hf_checkpoints + return FullModelHFCheckpointer( + checkpoint_dir=tmp_path, + checkpoint_files=[checkpoint_file], + model_type="LLAMA2", + output_dir=tmp_path, + ) + + @pytest.fixture + def multi_file_checkpointer( + self, llama2_hf_checkpoints, tmp_path + ) -> FullModelHFCheckpointer: + checkpoint_file_1, checkpoint_file_2 = llama2_hf_checkpoints + return FullModelHFCheckpointer( + checkpoint_dir=tmp_path, + checkpoint_files=[checkpoint_file_1, checkpoint_file_2], + model_type="LLAMA2", + output_dir=tmp_path, + ) + + def test_load_save_checkpoint_single_file( + self, + single_file_checkpointer: FullModelHFCheckpointer, + llama2_hf_checkpoints: Tuple[Path, Path], + ): + """ + Test ``load_checkpoint`` and ``save_checkpoint`` method within the + FullModelHFCheckpointer for a single checkpoint file. + + We test: + * ``load_checkpoint`` loads the right sets of keys + * Internal state of the checkpointer is correctly updated + * Converted checkpoint can be loaded into the llama2 torchtune implementation + * Saved checkpoint keys match the original checkpoint + """ + # Read the state dict directly from file using torch.load. This will be the state + # dict we test against + checkpoint_file, _ = llama2_hf_checkpoints + orig_state_dict = safe_torch_load(checkpoint_file) + + # Converted state dict from the checkpointer + state_dict = single_file_checkpointer.load_checkpoint() + + # Check that we've loaded all the keys; We ignore inv_freq as is standard practice + assert len(state_dict["model"].keys()) + 1 == len(orig_state_dict.keys()) + + # the keys in original state dict should match up with the keys in the weight_map + for key in orig_state_dict.keys(): + if "inv_freq" in key: + continue + assert key in single_file_checkpointer._weight_map + + # loading the state dict into the model implementation should work correctly + model = llama2.llama2( + vocab_size=_VOCAB_SIZE, + num_layers=1, + num_heads=_NUM_HEADS, + num_kv_heads=_NUM_KV_HEADS, + embed_dim=_DIM, + max_seq_len=128, + ) + model.load_state_dict(state_dict["model"]) + + single_file_checkpointer.save_checkpoint(state_dict, epoch=1) + + # Reload the output checkpoint file and compare to the original checkpoint. This + # assumes we know what the name of the file is. 
This is fine, breaking this logic + # should be something we capture through this test + output_file = Path.joinpath(checkpoint_file.parent, "hf_model_0001_1.pt") + output_state_dict = safe_torch_load(output_file) + + # We ignore inv_freq as is standard practice and so output dict will have one less key + assert len(output_state_dict.keys()) + 1 == len(orig_state_dict.keys()) + + def test_save_load_checkpoint_multiple_file( + self, + multi_file_checkpointer: FullModelHFCheckpointer, + llama2_hf_checkpoints: Tuple[Path, Path], + ): + """ + Test ``load_checkpoint`` method within the FullModelCheckpointer for multiple + checkpoint file. + + We test: + * ``load_checkpoint`` loads the right sets of keys + * Internal state of the checkpointer is correctly updated + * Converted checkpoint can be loaded into the llama2 torchtune implementation + """ + # Read the state dict directly from files + checkpoint_file_1, checkpoint_file_2 = llama2_hf_checkpoints + orig_state_dict_1 = safe_torch_load(checkpoint_file_1) + orig_state_dict_2 = safe_torch_load(checkpoint_file_2) + + # merged state dict from checkpointer + state_dict = multi_file_checkpointer.load_checkpoint() + + # We ignore inv_freq as is standard practice + assert len(state_dict["model"].keys()) + 2 == len( + orig_state_dict_1.keys() + ) + len(orig_state_dict_2.keys()) + + # the keys in the weight_map should match up with the keys in the weight_map + for key in orig_state_dict_1.keys(): + if "inv_freq" in key: + continue + assert key in multi_file_checkpointer._weight_map + + for key in orig_state_dict_2.keys(): + if "inv_freq" in key: + continue + assert key in multi_file_checkpointer._weight_map + + # finally loading into the model should work + model = llama2.llama2( + vocab_size=_VOCAB_SIZE, + num_layers=2, + num_heads=_NUM_HEADS, + num_kv_heads=_NUM_KV_HEADS, + embed_dim=_DIM, + max_seq_len=128, + ) + model.load_state_dict(state_dict["model"]) + + multi_file_checkpointer.save_checkpoint(state_dict, epoch=1) + + # Reload the output checkpoint file and compare to the original checkpoint. This + # assumes we know what the name of the file is. 
This is fine, breaking this logic + # should be something we capture through this test + output_file_1 = Path.joinpath(checkpoint_file_1.parent, "hf_model_0001_1.pt") + output_file_2 = Path.joinpath(checkpoint_file_2.parent, "hf_model_0002_1.pt") + output_state_dict_1 = safe_torch_load(output_file_1) + output_state_dict_2 = safe_torch_load(output_file_2) + + assert len(output_state_dict_1.keys()) + 1 == len(orig_state_dict_1.keys()) + assert len(output_state_dict_2.keys()) + 1 == len(orig_state_dict_2.keys()) + + def test_load_save_adapter_only( + self, tmp_path, single_file_checkpointer, llama2_hf_checkpoints + ): + """ """ + state_dict = single_file_checkpointer.load_checkpoint() + + with pytest.raises( + ValueError, match="Adapter checkpoint not found in state_dict" + ): + single_file_checkpointer.save_checkpoint( + state_dict, epoch=2, adapter_only=True + ) + + state_dict[ADAPTER_KEY] = {} + single_file_checkpointer.save_checkpoint(state_dict, epoch=2, adapter_only=True) + + output_file_1 = Path.joinpath(tmp_path, "hf_model_0001_2.pt") + output_file_2 = Path.joinpath(tmp_path, "adapter_2.pt") + + with pytest.raises(ValueError, match="Unable to load checkpoint from"): + _ = safe_torch_load(output_file_1) + + output_state_dict_2 = safe_torch_load(output_file_2) + # Check that the empty adapter we saved is the one loaded succesfully + assert len(output_state_dict_2.keys()) == 0 + + def test_save_checkpoint_in_peft_format( + self, + single_file_checkpointer: FullModelHFCheckpointer, + llama2_hf_checkpoints: Tuple[Path, Path], + ): + """ + Test save_checkpoint method within the FullModelCheckpointer for + integration with HF PEFT (i.e. save_in_peft_format=True). + + We test that: + * The file adapter_config.json contains the fields required by PEFT + and the correct values + * The state dict keys of the saved adapter checkpoint are remapped as expected + * The state dict values of the saved adapter checkpoint (after key remapping) + match those in torchtune for parameters that are not permuted by HF + # The state dict values of the saved adapter checkpoint (after key remapping) + do not match those in torchtune for parameters that are permuted by HF, but the + sums along the dimension of permutation match + """ + + # Define LoRA params for this test + lora_attn_modules = ["q_proj", "output_proj"] + apply_lora_to_mlp = True + apply_lora_to_output = True + lora_rank = 4 + lora_alpha = 8 + + checkpoint_file, _ = llama2_hf_checkpoints + state_dict = single_file_checkpointer.load_checkpoint() + + # Build LoRA Llama2 model and load in base model weights + model = llama2.lora_llama2( + lora_attn_modules=lora_attn_modules, + apply_lora_to_mlp=apply_lora_to_mlp, + apply_lora_to_output=apply_lora_to_output, + vocab_size=_VOCAB_SIZE, + num_layers=1, + num_heads=_NUM_HEADS, + num_kv_heads=_NUM_KV_HEADS, + embed_dim=_DIM, + max_seq_len=128, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + ) + missing, unexpected = model.load_state_dict(state_dict["model"], strict=False) + validate_missing_and_unexpected_for_lora( + lora_attn_modules=lora_attn_modules, + apply_lora_to_mlp=apply_lora_to_mlp, + apply_lora_to_output=apply_lora_to_output, + base_missing=missing, + base_unexpected=unexpected, + ) + + # LoRA B params are zero-initialized, randomly initialize them to make + # the test of their permutation on checkpoint save nontrivial + lora_b_sd = { + k: torch.randn_like(v) + for k, v in model.state_dict().items() + if "lora_b" in k + } + model.load_state_dict(lora_b_sd, strict=False) + + # Construct the 
adapter weights and config and save using checkpointer + adapter_params = get_adapter_params(model) + adapter_key_filter = lambda x: x in adapter_params + expected_adapter_state_dict = { + k: v for k, v in model.state_dict().items() if adapter_key_filter(k) + } + adapter_config = { + "r": lora_rank, + "lora_alpha": lora_alpha, + "target_modules": get_lora_module_names( + lora_attn_modules, + apply_lora_to_mlp, + apply_lora_to_output, + ), + "peft_type": "LORA", + } + state_dict.update({ADAPTER_KEY: expected_adapter_state_dict}) + state_dict.update({ADAPTER_CONFIG: adapter_config}) + single_file_checkpointer.save_checkpoint(state_dict, epoch=1) + + # Load saved adapter weights and config from file for comparison + adapter_weights_file = Path.joinpath( + checkpoint_file.parent, "adapter_model.bin" + ) + actual_adapter_state_dict = safe_torch_load(adapter_weights_file) + + adapter_config_file = Path.joinpath( + checkpoint_file.parent, "adapter_config.json" + ) + with open(adapter_config_file, "r") as f: + adapter_config = json.load(f) + + expected_target_modules = [ + "down_proj", + "gate_proj", + "lm_head", + "o_proj", + "q_proj", + "up_proj", + ] + assert sorted(adapter_config["target_modules"]) == expected_target_modules + + # Map PEFT keys back to torchtune keys + peft_to_tt = { + "o_proj": "output_proj", + "gate_proj": "w1", + "down_proj": "w2", + "up_proj": "w3", + "lm_head": "output", + } + for k, v in actual_adapter_state_dict.items(): + new_k = k.replace("base_model.model.", "").replace("self_attn", "attn") + if "lm_head" not in new_k: + new_k = new_k.replace("model.", "") + for kk, vv in peft_to_tt.items(): + if kk in k: + new_k = new_k.replace(kk, vv) + new_k = new_k.replace("lora_A", "lora_a").replace("lora_B", "lora_b") + + # LoRA B matrix for Q should not match due to Q and K permutation + # However, since they're permuted along embed dim, their sum along that axis should match + if "lora_b" in new_k and "q_proj" in new_k: + assert not torch.allclose( + actual_adapter_state_dict[k], expected_adapter_state_dict[new_k] + ) + torch.testing.assert_close( + actual_adapter_state_dict[k].sum(dim=0), + expected_adapter_state_dict[new_k].sum(dim=0), + ) + + # All other matrices should match exactly + if "lora_b" not in new_k: + torch.testing.assert_close( + actual_adapter_state_dict[k], expected_adapter_state_dict[new_k] + ) + + +class TestHFMistralRewardModelFullModelCheckpointer: + @pytest.fixture + def weight_dtype(self): + return torch.float16 + + @pytest.fixture + def state_dict(self, weight_dtype): + """ + State dict for a HF format mistral reward model checkpoint. This state dict is + "complete" and can be loaded into a torchtune model once correctly converted. 
+ """ + state_dict = { + "model.embed_tokens.weight": randn(_VOCAB_SIZE, _DIM, dtype=weight_dtype), + "model.layers.0.input_layernorm.weight": randn(_DIM, dtype=weight_dtype), + "model.layers.0.self_attn.q_proj.weight": randn( + _DIM, _DIM, dtype=weight_dtype + ), + "model.layers.0.self_attn.k_proj.weight": randn( + _DIM, _DIM, dtype=weight_dtype + ), + "model.layers.0.self_attn.v_proj.weight": randn( + _DIM, _DIM, dtype=weight_dtype + ), + "model.layers.0.self_attn.o_proj.weight": randn( + _DIM, _DIM, dtype=weight_dtype + ), + "model.layers.0.post_attention_layernorm.weight": randn( + _DIM, dtype=weight_dtype + ), + "model.layers.0.mlp.gate_proj.weight": randn( + _HIDDEN_DIM, _DIM, dtype=weight_dtype + ), + "model.layers.0.mlp.down_proj.weight": randn( + _DIM, _HIDDEN_DIM, dtype=weight_dtype + ), + "model.layers.0.mlp.up_proj.weight": randn( + _HIDDEN_DIM, _DIM, dtype=weight_dtype + ), + "model.norm.weight": randn(_DIM, dtype=weight_dtype), + "score.weight": randn(1, _DIM, dtype=weight_dtype), + # adding bias to ensure it doesn't cause an unexpected key + "score.bias": randn(1, _DIM, dtype=weight_dtype), + } + return state_dict + + @pytest.fixture + def mistral_reward_model_hf_checkpoint(self, tmp_path, state_dict): + """ + Fixture which creates a checkpoint file for the Mistral reward model. The + state dict follows the HF_FORMAT for the checkpoint format. + + The state dicts supports testing for a single-file checkpoint. + Multiple file checkpoints are already tested for Llama2. + * The checkpoint contains layer0 + embed + norm + score keys + and can be tested in isolation + + The model corresponds to the following config: + * num_layers: 1 + * num_heads: 4 + * num_kv_heads: 4 + * embed_dim: 64 + * max_seq_len: 128 + * num_classes: 1 + * intermediate_dim: 256 + + """ + checkpoint_file = tmp_path / "mistral_reward_model_hf_checkpoint.pt" + + torch.save(state_dict, checkpoint_file) + + config = { + "hidden_size": 64, + "num_attention_heads": 4, + "num_key_value_heads": 4, + "num_classes": 1, + } + config_file = Path.joinpath(tmp_path, "config.json") + with config_file.open("w") as f: + json.dump(config, f) + + return checkpoint_file + + @pytest.fixture + def single_file_checkpointer( + self, mistral_reward_model_hf_checkpoint, tmp_path + ) -> FullModelHFCheckpointer: + checkpoint_file = mistral_reward_model_hf_checkpoint + return FullModelHFCheckpointer( + checkpoint_dir=tmp_path, + checkpoint_files=[checkpoint_file], + model_type="REWARD", + output_dir=tmp_path, + ) + + def test_load_save_checkpoint_single_file( + self, + single_file_checkpointer: FullModelHFCheckpointer, + mistral_reward_model_hf_checkpoint: Path, + ): + """ + Test ``load_checkpoint`` and ``save_checkpoint`` method within the + FullModelHFCheckpointer for a single checkpoint file for a mistral reward model. + + We test: + * ``load_checkpoint`` loads the right sets of keys + * Internal state of the checkpointer is correctly updated + * Converted checkpoint can be loaded into the `mistral_classifier` torchtune implementation + * Saved checkpoint keys match the original checkpoint + """ + # Read the state dict directly from file using torch.load. 
This will be the state + # dict we test against + checkpoint_file = mistral_reward_model_hf_checkpoint + orig_state_dict = safe_torch_load(checkpoint_file) + + # Converted state dict from the checkpointer + state_dict = single_file_checkpointer.load_checkpoint() + # Check that we've loaded all the keys minus the output bias + assert len(state_dict["model"].keys()) == len(orig_state_dict.keys()) - 1 + + # the keys in original state dict should match up with the keys in the weight_map + for key in orig_state_dict.keys(): + if "inv_freq" in key or "output.bias" in key: + continue + assert key in single_file_checkpointer._weight_map + + # loading the state dict into the model implementation should work correctly + model = mistral.mistral_classifier( + num_classes=1, + vocab_size=_VOCAB_SIZE, + num_layers=1, + num_heads=_NUM_HEADS, + num_kv_heads=_NUM_KV_HEADS, + embed_dim=_DIM, + intermediate_dim=_HIDDEN_DIM, + max_seq_len=128, + ) + model.load_state_dict(state_dict["model"]) + + single_file_checkpointer.save_checkpoint(state_dict, epoch=1) + + # Reload the output checkpoint file and compare to the original checkpoint. This + # assumes we know what the name of the file is. This is fine, breaking this logic + # should be something we capture through this test + output_file = Path.joinpath(checkpoint_file.parent, "hf_model_0001_1.pt") + output_state_dict = safe_torch_load(output_file) + + assert len(output_state_dict.keys()) == len(orig_state_dict.keys()) - 1 + + +class TestHFGemmaFullModelCheckpointer: + @pytest.fixture + def weight_dtype(self): + return torch.float16 + + @pytest.fixture + def state_dict(self, weight_dtype): + """ + State dict for a HF format Gemma checkpoint. This state dict is + "complete" and can be loaded into a TorchTune model once correctly converted. + """ + state_dict = { + "model.embed_tokens.weight": randn(_VOCAB_SIZE, _DIM, dtype=weight_dtype), + "model.layers.0.input_layernorm.weight": randn(_DIM, dtype=weight_dtype), + "model.layers.0.self_attn.q_proj.weight": randn( + _DIM, _NUM_HEADS * _HEAD_DIM, dtype=weight_dtype + ), + # setting num_kv_heads to 1 + "model.layers.0.self_attn.k_proj.weight": randn( + _HEAD_DIM, _DIM, dtype=weight_dtype + ), + "model.layers.0.self_attn.v_proj.weight": randn( + _HEAD_DIM, _DIM, dtype=weight_dtype + ), + "model.layers.0.self_attn.o_proj.weight": randn( + _NUM_HEADS * _HEAD_DIM, _DIM, dtype=weight_dtype + ), + "model.layers.0.post_attention_layernorm.weight": randn( + _DIM, dtype=weight_dtype + ), + "model.layers.0.mlp.gate_proj.weight": randn( + _HIDDEN_DIM, _DIM, dtype=weight_dtype + ), + "model.layers.0.mlp.down_proj.weight": randn( + _DIM, _HIDDEN_DIM, dtype=weight_dtype + ), + "model.layers.0.mlp.up_proj.weight": randn( + _HIDDEN_DIM, _DIM, dtype=weight_dtype + ), + "model.norm.weight": randn(_DIM, dtype=weight_dtype), + } + return state_dict + + @pytest.fixture + def gemma_hf_checkpoint(self, tmp_path, state_dict): + """ + Fixture which creates a checkpoint file for Gemma. The + state dict follows the HF_FORMAT for the checkpoint format. + + The state dicts supports testing for a single-file checkpoint. + Multiple file checkpoints are already tested for Llama2. 
+ + The model corresponds to the following config: + * num_layers: 1 + * num_heads: 4 + * num_kv_heads: 1 + * embed_dim: 64 + * max_seq_len: 128 + * num_classes: 1 + * intermediate_dim: 256 + * head_dim : 16 + + """ + checkpoint_file = tmp_path / "gemma_hf_checkpoint.pt" + + torch.save(state_dict, checkpoint_file) + + config = { + "hidden_size": _DIM, + "num_attention_heads": _NUM_HEADS, + "num_key_value_heads": 1, + "head_dim": _HEAD_DIM, + "intermediate_size": _HIDDEN_DIM, + } + config_file = Path.joinpath(tmp_path, "config.json") + with config_file.open("w") as f: + json.dump(config, f) + + return checkpoint_file + + @pytest.fixture + def single_file_checkpointer( + self, gemma_hf_checkpoint, tmp_path + ) -> FullModelHFCheckpointer: + checkpoint_file = gemma_hf_checkpoint + return FullModelHFCheckpointer( + checkpoint_dir=tmp_path, + checkpoint_files=[checkpoint_file], + model_type="GEMMA", + output_dir=tmp_path, + ) + + def test_load_save_checkpoint_single_file( + self, + single_file_checkpointer: FullModelHFCheckpointer, + gemma_hf_checkpoint: Path, + ): + """ + Test ``load_checkpoint`` and ``save_checkpoint`` method within the + FullModelHFCheckpointer for a single checkpoint file for Gemma. + + We test: + * ``load_checkpoint`` loads the right sets of keys + * Internal state of the checkpointer is correctly updated + * Converted checkpoint can be loaded into the `gemma` TorchTune implementation + * lm_head weights are tied to the embed_tokens weights during saving + * lmhead weights are popped during loading + """ + # Read the state dict directly from file using torch.load. This will be the state + # dict we test against + checkpoint_file = gemma_hf_checkpoint + orig_state_dict = safe_torch_load(checkpoint_file) + + # Converted state dict from the checkpointer + + state_dict = single_file_checkpointer.load_checkpoint() + assert len(state_dict["model"].keys()) == len(orig_state_dict.keys()) + + # the keys in original state dict should match up with the keys in the weight_map + for key in orig_state_dict.keys(): + if "inv_freq" in key: + continue + assert key in single_file_checkpointer._weight_map + + # loading the state dict into the model implementation should work correctly + model = gemma.gemma( + vocab_size=_VOCAB_SIZE, + num_layers=1, + num_heads=_NUM_HEADS, + head_dim=_HEAD_DIM, + num_kv_heads=1, + embed_dim=_DIM, + intermediate_dim=_HIDDEN_DIM, + max_seq_len=128, + ) + model.load_state_dict(state_dict["model"]) + + single_file_checkpointer.save_checkpoint(state_dict, epoch=1) + + # Reload the output checkpoint file and compare to the original checkpoint. This + # assumes we know what the name of the file is. This is fine, breaking this logic + # should be something we capture through this test + output_file = Path.joinpath(checkpoint_file.parent, "hf_model_0001_1.pt") + output_state_dict = safe_torch_load(output_file) + + assert len(output_state_dict.keys()) == len(orig_state_dict.keys()) diff --git a/tests/torchtune/training/checkpointing/test_checkpointer_utils.py b/tests/torchtune/training/checkpointing/test_checkpointer_utils.py new file mode 100644 index 0000000000..326d3768f6 --- /dev/null +++ b/tests/torchtune/training/checkpointing/test_checkpointer_utils.py @@ -0,0 +1,177 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from copy import deepcopy +from pathlib import Path + +import pytest +import torch +from torchtune.models.llama2 import llama2, llama2_classifier +from torchtune.training.checkpointing._utils import ( + safe_torch_load, + update_state_dict_for_classifier, +) + +N_LAYERS = 3 +IN_DIM = 5 +OUT_DIM = 10 +VOCAB_SIZE = 50 +NUM_HEADS = 4 +NUM_KV_HEADS = 2 +EMBED_DIM = 64 +MAX_SEQ_LEN = 64 +NUM_CLASSES = 6 + + +class TestCheckpointerUtils: + @pytest.fixture + def model_checkpoint(self, tmp_path): + """ + Fixture which creates a checkpoint file for testing checkpointer utils. + """ + checkpoint_file = tmp_path / "model_checkpoint_01.pt" + + state_dict = { + "token_embeddings.weight": torch.ones(1, 10), + "output.weight": torch.ones(1, 10), + } + + torch.save(state_dict, checkpoint_file) + + return checkpoint_file + + @pytest.mark.parametrize("weights_only", [True, False]) + def test_safe_torch_load(self, model_checkpoint, weights_only): + state_dict = safe_torch_load(Path(model_checkpoint), weights_only) + + assert "token_embeddings.weight" in state_dict + assert "output.weight" in state_dict + + assert state_dict["token_embeddings.weight"].shape[1] == 10 + assert state_dict["output.weight"].shape[0] == 1 + + +class TestUpdateStateDictForClassifer: + @pytest.fixture() + def llama2_state_dict(self): + model = llama2( + vocab_size=VOCAB_SIZE, + num_layers=N_LAYERS, + num_heads=NUM_KV_HEADS, + num_kv_heads=NUM_KV_HEADS, + embed_dim=EMBED_DIM, + max_seq_len=MAX_SEQ_LEN, + ) + return model.state_dict() + + @pytest.fixture() + def llama2_classifier_model(self): + return llama2_classifier( + num_classes=NUM_CLASSES, + vocab_size=VOCAB_SIZE, + num_layers=N_LAYERS, + num_heads=NUM_KV_HEADS, + num_kv_heads=NUM_KV_HEADS, + embed_dim=EMBED_DIM, + max_seq_len=MAX_SEQ_LEN, + ) + + def test_bias_in_classifier_checkpoint_is_removed(self, llama2_classifier_model): + # construct bogus state dict with output.bias included + state_dict_with_bias = llama2_classifier_model.state_dict().copy() + state_dict_with_bias["output.bias"] = torch.tensor([NUM_CLASSES]) + + # function should remove output.bias + update_state_dict_for_classifier( + state_dict_with_bias, llama2_classifier_model.named_parameters() + ) + + assert "output.bias" not in state_dict_with_bias + + def test_loading_base_checkpoint_into_classifier( + self, llama2_state_dict, llama2_classifier_model + ): + # grabbing the expected output.weight - the correct outcome here + # is for all weights aside from output.weight to be loaded in + # from the base model, so output.weight will remain in its rand init state + expected_output_weight = llama2_classifier_model.state_dict()[ + "output.weight" + ].clone() + + # update the state dict to load with the classifier's output.weight + update_state_dict_for_classifier( + llama2_state_dict, llama2_classifier_model.named_parameters() + ) + + # load in all the base params + llama2_classifier_model.load_state_dict(llama2_state_dict) + + # now we can assert that output.weight was unchanged + output_weight = llama2_classifier_model.state_dict()["output.weight"] + assert torch.equal(expected_output_weight, output_weight) + + def test_assertion_error_when_missing_output_in_state_dict( + self, llama2_state_dict, llama2_classifier_model + ): + llama2_state_dict.pop("output.weight") + with pytest.raises( + AssertionError, match="Expected output.weight in state_dict" + ): + update_state_dict_for_classifier( + llama2_state_dict, llama2_classifier_model.named_parameters() + ) + + def 
test_assertion_error_when_missing_output_in_model_named_parameters( + self, llama2_state_dict, llama2_classifier_model + ): + named_params = [ + (k, v) + for (k, v) in llama2_classifier_model.named_parameters() + if k != "output.weight" + ] + with pytest.raises( + AssertionError, match="Expected output.weight in model_named_parameters" + ): + update_state_dict_for_classifier(llama2_state_dict, named_params) + + def test_loading_classifier_weights(self, llama2_classifier_model): + state_dict_to_load = deepcopy(llama2_classifier_model.state_dict()) + state_dict_to_load["output.weight"] = torch.ones_like( + state_dict_to_load["output.weight"] + ) + + update_state_dict_for_classifier( + state_dict_to_load, llama2_classifier_model.named_parameters() + ) + llama2_classifier_model.load_state_dict(state_dict_to_load) + + model_state_dict = llama2_classifier_model.state_dict() + + assert set(model_state_dict.keys()) == set(state_dict_to_load.keys()) + assert torch.equal( + model_state_dict["output.weight"], + torch.ones_like(model_state_dict["output.weight"]), + ) + + def test_loading_classifier_weights_force_override(self, llama2_classifier_model): + state_dict_to_load = deepcopy(llama2_classifier_model.state_dict()) + state_dict_to_load["output.weight"] = torch.ones_like( + state_dict_to_load["output.weight"] + ) + + expected_output_weight = llama2_classifier_model.state_dict()[ + "output.weight" + ].clone() + + update_state_dict_for_classifier( + state_dict_to_load, llama2_classifier_model.named_parameters(), True + ) + llama2_classifier_model.load_state_dict(state_dict_to_load) + + model_state_dict = llama2_classifier_model.state_dict() + + assert set(model_state_dict.keys()) == set(state_dict_to_load.keys()) + assert torch.equal(model_state_dict["output.weight"], expected_output_weight) diff --git a/tests/torchtune/training/test_metric_logging.py b/tests/torchtune/training/test_metric_logging.py new file mode 100644 index 0000000000..2fc29e72aa --- /dev/null +++ b/tests/torchtune/training/test_metric_logging.py @@ -0,0 +1,201 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ + +import tempfile +from io import StringIO +from typing import cast +from unittest.mock import patch + +import pytest +from omegaconf import OmegaConf +from tensorboard.backend.event_processing.event_accumulator import EventAccumulator + +from tests.test_utils import assert_expected, captured_output + +from torchtune.training.metric_logging import ( + CometLogger, + DiskLogger, + StdoutLogger, + TensorBoardLogger, + WandBLogger, +) + + +class TestDiskLogger: + def test_log(self) -> None: + with tempfile.TemporaryDirectory() as log_dir: + logger = DiskLogger(log_dir=log_dir) + for i in range(5): + logger.log("test_log", float(i) ** 2, i) + logger.close() + + log_path = logger.path_to_log_file() + assert log_path.exists() + values = open(log_path).readlines() + assert_expected(len(values), 5) + for i in range(5): + assert values[i] == f"Step {i} | test_log:{float(i) ** 2}\n" + + def test_log_dict(self) -> None: + with tempfile.TemporaryDirectory() as log_dir: + logger = DiskLogger(log_dir=log_dir) + for i in range(5): + logger.log_dict(step=i, payload={"metric_1": i, "metric_2": i**2}) + logger.close() + + log_path = logger.path_to_log_file() + assert log_path.exists() + values = open(log_path).readlines() + assert_expected(len(values), 5) + for i in range(5): + assert values[i] == f"Step {i} | metric_1:{i} metric_2:{i ** 2} \n" + + +class TestStdoutLogger: + def test_stdout_log(self) -> None: + logger = StdoutLogger() + with captured_output() as (out, _): + logger.log(step=0, name="metric_1", data=1.1) + out = cast(StringIO, out) + assert ( + out.getvalue() == "Step 0 | metric_1:1.1\n" + ), f"Actual output: {out.getvalue()}" + + logger.log(step=1, name="metric_1", data=2.1) + assert ( + out.getvalue() == "Step 0 | metric_1:1.1\nStep 1 | metric_1:2.1\n" + ), f"Actual output: {out.getvalue()}" + + logger.close() + assert ( + out.getvalue() == "Step 0 | metric_1:1.1\nStep 1 | metric_1:2.1\n" + ), f"Actual output: {out.getvalue()}" + + def test_stdout_log_dict(self) -> None: + logger = StdoutLogger() + with captured_output() as (out, _): + logger.log_dict(step=0, payload={"metric_1": 1, "metric_2": 1}) + out = cast(StringIO, out) + assert ( + out.getvalue() == "Step 0 | metric_1:1 metric_2:1 \n" + ), f"Actual output: {out.getvalue()}" + + logger.log_dict( + step=1, payload={"metric_1": 2, "metric_2": 2.2, "metric_3": 2.2344} + ) + assert ( + out.getvalue() + == "Step 0 | metric_1:1 metric_2:1 \nStep 1 | metric_1:2 metric_2:2.2 metric_3:2.2344 \n" + ), f"Actual output: {out.getvalue()}" + + logger.close() + assert ( + out.getvalue() + == "Step 0 | metric_1:1 metric_2:1 \nStep 1 | metric_1:2 metric_2:2.2 metric_3:2.2344 \n" + ), f"Actual output: {out.getvalue()}" + + +class TestTensorBoardLogger: + def test_log(self) -> None: + with tempfile.TemporaryDirectory() as log_dir: + logger = TensorBoardLogger(log_dir=log_dir) + for i in range(5): + logger.log("test_log", float(i) ** 2, i) + logger.close() + + acc = EventAccumulator(logger.log_dir) + acc.Reload() + for i, event in enumerate(acc.Tensors("test_log")): + assert_expected(event.tensor_proto.float_val[0], float(i) ** 2) + assert_expected(event.step, i) + + def test_log_dict(self) -> None: + with tempfile.TemporaryDirectory() as log_dir: + logger = TensorBoardLogger(log_dir=log_dir) + metric_dict = {f"log_dict_{i}": float(i) ** 2 for i in range(5)} + logger.log_dict(metric_dict, 1) + logger.close() + + acc = EventAccumulator(logger.log_dir) + acc.Reload() + for i in range(5): + tensor_tag = acc.Tensors(f"log_dict_{i}")[0] + 
assert_expected(tensor_tag.tensor_proto.float_val[0], float(i) ** 2) + assert_expected(tensor_tag.step, 1) + + +@pytest.mark.skip(reason="This was never running and needs to be fixed") +class TestWandBLogger: + def test_log(self) -> None: + with patch("wandb.init") as mock_init, patch("wandb.log") as mock_log: + logger = WandBLogger(project="test_project") + for i in range(5): + logger.log("test_log", float(i) ** 2, i) + logger.close() + + assert mock_log.call_count == 5 + for i in range(5): + mock_log.assert_any_call({"test_log": float(i) ** 2}, step=i) + + def test_log_dict(self) -> None: + with patch("wandb.init") as mock_init, patch("wandb.log") as mock_log: + logger = WandBLogger(project="test_project") + metric_dict = {f"log_dict_{i}": float(i) ** 2 for i in range(5)} + logger.log_dict(metric_dict, 1) + logger.close() + + mock_log.assert_called_with(metric_dict, step=1) + + def test_save_config(self) -> None: + with patch("wandb.init") as mock_init, patch( + "wandb.run", create=True + ) as mock_run, patch("OmegaConf.save") as mock_save, patch( + "wandb.save" + ) as mock_wandb_save: + + logger = WandBLogger(project="test_project") + cfg = OmegaConf.create({"a": 1, "b": 2}) + + with patch.object(logger, "_wandb", mock_run): + logger.save_config(cfg) + + expected_config_path = "torchtune_config.yaml" + mock_save.assert_called_once_with(cfg, expected_config_path) + mock_wandb_save.assert_called_once_with(expected_config_path) + + +class TestCometLogger: + def test_log(self) -> None: + with patch("comet_ml.start") as mock_experiment: + logger = CometLogger(project="test_project") + for i in range(5): + logger.log("test_log", float(i) ** 2, i) + logger.close() + + assert mock_experiment.return_value.log_metric.call_count == 5 + for i in range(5): + mock_experiment.return_value.log_metric.assert_any_call( + "test_log", float(i) ** 2, step=i + ) + + def test_log_dict(self) -> None: + with patch("comet_ml.start") as mock_experiment: + logger = CometLogger(project="test_project") + metric_dict = {f"log_dict_{i}": float(i) ** 2 for i in range(5)} + logger.log_dict(metric_dict, 1) + logger.close() + + mock_experiment.return_value.log_metrics.assert_called_with( + metric_dict, step=1 + ) + + def test_log_config(self) -> None: + with patch("comet_ml.start") as mock_experiment: + logger = CometLogger(project="test_project") + cfg = OmegaConf.create({"a": 1, "b": 2}) + logger.log_config(cfg) + mock_experiment.return_value.log_parameters.assert_called_with(cfg) diff --git a/tests/torchtune/training/test_precision.py b/tests/torchtune/training/test_precision.py new file mode 100644 index 0000000000..6f94ffd9db --- /dev/null +++ b/tests/torchtune/training/test_precision.py @@ -0,0 +1,91 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + + +from unittest import mock + +import pytest +import torch + +from torchtune.training.precision import ( + _set_float32_precision, + get_dtype, + PRECISION_STR_TO_DTYPE, + set_default_dtype, + validate_expected_param_dtype, + verify_bf16_support, +) + + +class TestPrecisionUtils: + + cuda_available: bool = torch.cuda.is_available() + + def test_get_dtype(self): + """ + Tests that the correct dtype is returned based on the input string. 
+ """ + dtypes = [None, torch.half] + list(PRECISION_STR_TO_DTYPE.keys()) + expected_dtypes = [ + torch.float32, + torch.float16, + torch.float16, + torch.bfloat16, + torch.float32, + torch.float64, + ] + for dtype, expected_dtype in zip(dtypes, expected_dtypes): + if dtype == "bf16" and not verify_bf16_support(): + continue # skip bf16 tests if not supported. + assert ( + get_dtype(dtype) == expected_dtype + ), f"{dtype} should return {expected_dtype}" + + @mock.patch("torchtune.training.precision.verify_bf16_support", return_value=False) + def test_error_bf16_unsupported(self, mock_verify): + """ + Tests that an error is raised if bf16 is specified but not supported. + """ + with pytest.raises( + RuntimeError, match="bf16 precision was requested but not available" + ): + get_dtype(torch.bfloat16) + + @pytest.mark.skipif(not cuda_available, reason="The test requires GPUs to run.") + def test_set_float32_precision(self) -> None: + setattr( # noqa: B010 + torch.backends, "__allow_nonbracketed_mutation_flag", True + ) + _set_float32_precision("highest") + assert torch.get_float32_matmul_precision() == "highest" + assert not torch.backends.cudnn.allow_tf32 + assert not torch.backends.cuda.matmul.allow_tf32 + + _set_float32_precision("high") + setattr( # noqa: B010 + torch.backends, "__allow_nonbracketed_mutation_flag", False + ) + assert torch.get_float32_matmul_precision() == "high" + assert torch.backends.cudnn.allow_tf32 + assert torch.backends.cuda.matmul.allow_tf32 + + def test_set_default_dtype(self): + dtype = torch.bfloat16 + prev_dtype = torch.get_default_dtype() + with set_default_dtype(dtype): + assert torch.get_default_dtype() == dtype + + assert torch.get_default_dtype() == prev_dtype + + def test_validate_expected_param_dtype(self): + """ + Tests that we raise if any model param has a different dtype than the expected dtype. + """ + m = torch.nn.Linear(10, 10) + with pytest.raises(ValueError, match=f"has dtype {next(m.parameters()).dtype}"): + validate_expected_param_dtype(m.named_parameters(), dtype=torch.float16) diff --git a/tests/torchtune/training/test_profiler.py b/tests/torchtune/training/test_profiler.py new file mode 100644 index 0000000000..58d4c4a164 --- /dev/null +++ b/tests/torchtune/training/test_profiler.py @@ -0,0 +1,253 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import pytest +import torch +from omegaconf import DictConfig, OmegaConf +from torch._C._profiler import _ExperimentalConfig +from torchtune import config +from torchtune.training import ( + DEFAULT_PROFILE_DIR, + DEFAULT_PROFILER_ACTIVITIES, + DEFAULT_SCHEDULE, + DEFAULT_TRACE_OPTS, + DummyProfiler, + PROFILER_KEY, +) + +# Disable logging otherwise output will be very verbose +logging.basicConfig(level=logging.ERROR) + +PROFILER_ATTRS = [ + "activities", + "profile_memory", + "with_stack", + "record_shapes", + "with_flops", +] + + +@pytest.fixture +def profiler_cfg(): + return """ +profiler: + enabled: True + cpu: True + cuda: True + profile_memory: False + with_stack: False + record_shapes: True + with_flops: True + wait_steps: 3 + warmup_steps: 1 + active_steps: 1 + num_cycles: 0 +""" + + +# This is a reference implementation of a profiler setup method to be defined within a `recipe`. +# A version of this lives in `torch.utils._profiler` but is not exported as the public API. 
+# Rather, the user is expected to define their own high-level setup function that parses the `cfg` +# and call a user-facing profiler setup function (e.g. `setup_torch_profiler`). +def _setup_profiler( + cfg_profiler: DictConfig, return_cfg: bool = False +) -> torch.profiler.profile: + """ + Parses the `profiler` section of top-level `cfg` and sets up profiler + + Args: + cfg_profiler (DictConfig): `profiler` section of the top-level `cfg` (the main config passed to `recipe.main`) + return_cfg (bool): Doesn't seem to be used. Default False. + + Returns: + profiler: torch.profiler.profile | DummyProfiler - DummyProfiler is a nullcontext with no-op methods + for `start`, `stop`, and `step` that can be used in place of `torch.profiler.profile` if profiler is not enabled such + that the instrumented training loop does not need to be changed profiling is disabled. + """ + # Missing profiler section in config, assume disabled + if cfg_profiler is None: + cfg_profiler = DictConfig({"enabled": False}) + + # Check that component is included and set correctly + if cfg_profiler.get("_component_", None) is None: + cfg_profiler["_component_"] = "torchtune.training.setup_torch_profiler" + else: + assert ( + cfg_profiler.get("_component_") == "torchtune.training.setup_torch_profiler" + ), "Only torch profiler supported currently: component must be `torchtune.training.setup_torch_profiler`" + + profiler, profiler_cfg = config.instantiate(cfg_profiler) + + return profiler, profiler_cfg + + +@pytest.fixture +def reference_profiler_basic(): + return torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + schedule=torch.profiler.schedule(wait=3, warmup=1, active=1, repeat=0), + profile_memory=False, + with_stack=False, + record_shapes=True, + with_flops=True, + ) + + +@pytest.fixture +def reference_profiler_full(): + return torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + schedule=torch.profiler.schedule(wait=3, warmup=1, active=1, repeat=0), + profile_memory=True, + with_stack=True, + record_shapes=True, + with_flops=True, + experimental_config=_ExperimentalConfig(verbose=True), + ) + + +def check_profiler_attrs(profiler, ref_profiler): + for attr in PROFILER_ATTRS: + assert getattr(profiler, attr) == getattr(ref_profiler, attr) + + +def check_schedule(schedule, ref_schedule, num_steps=10): + ref_steps = [ref_schedule(i) for i in range(num_steps)] + test_steps = [schedule(i) for i in range(num_steps)] + assert ref_steps == test_steps + + +def test_instantiate_basic(profiler_cfg, reference_profiler_basic): + cfg = OmegaConf.create(profiler_cfg)[PROFILER_KEY] + + profiler, updated_cfg = _setup_profiler(cfg) + + check_profiler_attrs(profiler, reference_profiler_basic) + + ref_schedule = torch.profiler.schedule( + wait=updated_cfg["wait_steps"], + warmup=updated_cfg["warmup_steps"], + active=updated_cfg["active_steps"], + repeat=updated_cfg["num_cycles"], + ) + check_schedule(profiler.schedule, ref_schedule) + + +def test_instantiate_full(profiler_cfg, reference_profiler_full): + cfg = OmegaConf.create(profiler_cfg)[PROFILER_KEY] + + # Check `setup` automatically overrides `with_stack` and `record_shapes` when profile_memory is True and adds + # experimental_config, which is needed for stack exporting (see comments in `setup_torch_profiler`) + cfg.profile_memory = True + cfg.with_stack = False + cfg.record_shapes = False + profiler, updated_cfg = _setup_profiler(cfg) + 
+ check_profiler_attrs(profiler, reference_profiler_full) + assert profiler.experimental_config is not None + assert updated_cfg.with_stack is True + assert updated_cfg.record_shapes is True + + +def test_schedule_setup(profiler_cfg, reference_profiler_basic): + + cfg = OmegaConf.create(profiler_cfg)[PROFILER_KEY] + + # Test that after removing schedule, setup method will implement default schedule + _ = [cfg.pop(k) for k in DEFAULT_SCHEDULE.keys()] + profiler, updated_cfg = _setup_profiler(cfg) + test_schedule = profiler.schedule + ref_schedule = torch.profiler.schedule( + wait=DEFAULT_SCHEDULE["wait_steps"], + warmup=DEFAULT_SCHEDULE["warmup_steps"], + active=DEFAULT_SCHEDULE["active_steps"], + repeat=DEFAULT_SCHEDULE["num_cycles"], + ) + check_schedule(ref_schedule, test_schedule) + + # Check cfg is updated correctly + for k in DEFAULT_SCHEDULE.keys(): + assert updated_cfg[k] == DEFAULT_SCHEDULE[k] + + # Test missing key is automatically set to default + for k in DEFAULT_SCHEDULE.keys(): + cfg = OmegaConf.create(profiler_cfg)[PROFILER_KEY] + cfg.pop(k) + profiler, updated_cfg = _setup_profiler(cfg) + assert updated_cfg[k] == DEFAULT_SCHEDULE[k] + + +def test_default_activities(profiler_cfg): + cfg = OmegaConf.create(profiler_cfg)[PROFILER_KEY] + + # Test setup automatically adds CPU + CUDA tracing if neither CPU nor CUDA is specified + cfg.pop("cpu") + cfg.pop("cuda") + profiler, updated_cfg = _setup_profiler(cfg) + assert profiler.activities == DEFAULT_PROFILER_ACTIVITIES + assert updated_cfg.cpu is True + assert updated_cfg.cuda is True + + +def test_default_output_dir(profiler_cfg): + cfg = OmegaConf.create(profiler_cfg)[PROFILER_KEY] + + # Test cfg output_dir is set correctly + if cfg.get("output_dir", None) is not None: + cfg.pop("output_dir") + _, updated_cfg = _setup_profiler(cfg, return_cfg=True) + assert updated_cfg.output_dir == DEFAULT_PROFILE_DIR + + +def test_default_trace_opts(profiler_cfg): + cfg = OmegaConf.create(profiler_cfg)[PROFILER_KEY] + + # Test missing profiler options are set to defaults + cfg.pop("profile_memory") + cfg.pop("with_stack") + cfg.pop("record_shapes") + cfg.pop("with_flops") + profiler, updated_cfg = _setup_profiler(cfg) + check_profiler_attrs( + profiler, + torch.profiler.profile( + activities=DEFAULT_PROFILER_ACTIVITIES, **DEFAULT_TRACE_OPTS + ), + ) + for k in ["profile_memory", "with_stack", "record_shapes", "with_flops"]: + assert updated_cfg[k] == DEFAULT_TRACE_OPTS[k] + + +def test_dummy_profiler(profiler_cfg): + + # Test missing `profile` key returns fake profiler + cfg = OmegaConf.create(profiler_cfg) + cfg.pop(PROFILER_KEY) + profiler, _ = _setup_profiler(cfg) + assert isinstance(profiler, DummyProfiler) + + # Test that disabled profiler creates fake profiler + cfg = OmegaConf.create(profiler_cfg)[PROFILER_KEY] + cfg.enabled = False + profiler, _ = _setup_profiler(cfg) + assert isinstance(profiler, DummyProfiler) + + # Test that fake_profiler.step() does nothing both when used as context manager and as standalone object + with profiler as prof: + prof.step() + + # Additional DummyProfiler no-ops when used as object and not context + assert profiler.step() is None + assert profiler.start() is None + assert profiler.stop() is None diff --git a/tests/torchtune/utils/test_distributed.py b/tests/torchtune/utils/test_distributed.py index 318de2b1c9..b18ff75d3c 100644 --- a/tests/torchtune/utils/test_distributed.py +++ b/tests/torchtune/utils/test_distributed.py @@ -28,7 +28,12 @@ from torchtune.models.llama2._component_builders import llama2, 
lora_llama2 from torchtune.models.llama3._component_builders import llama3 from torchtune.modules import TransformerSelfAttentionLayer -from torchtune.modules.peft import get_adapter_params, LoRALinear, set_trainable_params +from torchtune.modules.peft import ( + DoRALinear, + get_adapter_params, + LoRALinear, + set_trainable_params, +) class TestDistributed: @@ -165,7 +170,7 @@ def _get_n_lora_and_tformer_layers(model): num_lora_ab = 0 num_transformer_layers = 0 for module in model.modules(): - if isinstance(module, LoRALinear): + if isinstance(module, LoRALinear) or isinstance(module, DoRALinear): num_nested_linears = len( [m for m in module.modules() if isinstance(m, nn.Linear)] ) @@ -212,11 +217,10 @@ def test_lora_fsdp_wrap(self): assert not p.is_meta for m in wrapped_lora.modules(): - if isinstance(m, LoRALinear): + if isinstance(m, LoRALinear) or isinstance(m, DoRALinear): torch.testing.assert_close( m.lora_b.weight, torch.zeros_like(m.lora_b.weight) ) - # Total # FSDP modules should be num_transformer + num_lora_ab + 1 total_fsdp_submodules = len([m for m in FSDP.fsdp_modules(wrapped_lora)]) assert total_fsdp_submodules == (num_lora_ab + num_transformer_layers + 1) diff --git a/tests/torchtune/utils/test_generation.py b/tests/torchtune/utils/test_generation.py index 8a451897e9..15c4a336fa 100644 --- a/tests/torchtune/utils/test_generation.py +++ b/tests/torchtune/utils/test_generation.py @@ -11,7 +11,6 @@ from torchtune import utils from torchtune.models.llama2 import llama2 -from torchtune.modules import setup_caches from torchtune.utils._generation import sample @@ -31,7 +30,7 @@ def generation_model(self, dtype=torch.float32): max_seq_len=2048, ) fixed_init_model(model) - setup_caches(model, batch_size=1, dtype=dtype) + model.setup_caches(batch_size=1, dtype=dtype) model.eval() return model @@ -60,7 +59,7 @@ def generation_model_batched(self, dtype=torch.float32): max_seq_len=2048, ) fixed_init_model(model) - setup_caches(model, batch_size=2, dtype=dtype) + model.setup_caches(batch_size=2, dtype=dtype) model.eval() return model diff --git a/torchtune/models/gemma/transformer.py b/torchtune/models/gemma/transformer.py index 293fe2dcc9..60d91998f6 100644 --- a/torchtune/models/gemma/transformer.py +++ b/torchtune/models/gemma/transformer.py @@ -66,29 +66,15 @@ def __init__( self.head_dim = head_dim self.causal_mask = None self.norm_embeddings = norm_embeddings + self.cache_max_seq_len = None - def setup_caches(self, batch_size: int, dtype: torch.dtype) -> None: - """Setup key value caches for attention calculation. - - Args: - batch_size (int): batch size for the caches. - dtype (torch.dtype): dtype for the caches. 
- """ + def setup_caches( + self, batch_size: int, dtype: torch.dtype, *, encoder_max_seq_len: int = None, decoder_max_seq_len: int = None + ): + self.cache_max_seq_len = decoder_max_seq_len if decoder_max_seq_len is not None else self.max_seq_len for layer in self.layers: - layer.attn.kv_cache = KVCache( - batch_size=batch_size, - max_seq_len=self.max_seq_len, - num_heads=self.num_heads, - head_dim=self.head_dim, - dtype=dtype, - ) - - # causal_mask is used during inference to ensure we're attending - # to the right tokens - self.causal_mask = torch.tril( - torch.ones(self.max_seq_len, self.max_seq_len, dtype=torch.bool) - ) - + layer.setup_cache(batch_size, dtype, encoder_max_seq_len=None, decoder_max_seq_len=self.cache_max_seq_len) + def forward( self, tokens: Tensor, @@ -136,13 +122,9 @@ def forward( if self.causal_mask is not None: if input_pos is None: - raise ValueError( - "Caches are setup, but the position of input token is missing" - ) + raise ValueError("Caches are setup, but the position of input token is missing") if mask is not None: - raise ValueError( - "An attention mask was set. Cannot use a non-causal mask for inference" - ) + raise ValueError("An attention mask was set. Cannot use a non-causal mask for inference") # shape: [1, input_pos_len, m_s] # in most cases input_pos_len should be 1 mask = self.causal_mask[None, input_pos] diff --git a/torchtune/modules/model_fusion/_fusion.py b/torchtune/modules/model_fusion/_fusion.py index fbe9ed61a1..bf0dce81e0 100644 --- a/torchtune/modules/model_fusion/_fusion.py +++ b/torchtune/modules/model_fusion/_fusion.py @@ -46,9 +46,7 @@ class FusionLayer(nn.Module): fusion_first (bool): boolean to insert fusion layer before or after the decoder layer. """ - def __init__( - self, layer: nn.Module, fusion_layer: nn.Module, fusion_first: bool = True - ): + def __init__(self, layer: nn.Module, fusion_layer: nn.Module, fusion_first: bool = True): super().__init__() self.layer = layer self.fusion_layer = fusion_layer @@ -56,9 +54,7 @@ def __init__( # Keep FusionLayer wrappings out of the state_dict self._register_state_dict_hook(FusionLayer._state_dict_hook) - self._register_load_state_dict_pre_hook( - FusionLayer._load_state_dict_hook, with_module=True - ) + self._register_load_state_dict_pre_hook(FusionLayer._load_state_dict_hook, with_module=True) # TODO: Switch to register_load_state_dict_pre_hook and # register_state_dict_pre_hook after PyTorch v2.5 @@ -104,10 +100,11 @@ def setup_cache( decoder_max_seq_len (int): this parameter is ignored in this layer. """ self.layer.setup_cache( - batch_size, dtype, encoder_max_seq_len, decoder_max_seq_len + batch_size, dtype, encoder_max_seq_len=encoder_max_seq_len, decoder_max_seq_len=decoder_max_seq_len ) + self.fusion_layer.setup_cache( - batch_size, dtype, encoder_max_seq_len, decoder_max_seq_len + batch_size, dtype, encoder_max_seq_len=encoder_max_seq_len, decoder_max_seq_len=decoder_max_seq_len ) @property @@ -124,9 +121,7 @@ def fusion_params(self) -> List[str]: """ Return parameters of fusion layer. 
""" - fusion_params = [ - f"fusion_layer.{k}" for k, v in self.fusion_layer.named_parameters() - ] + fusion_params = [f"fusion_layer.{k}" for k, v in self.fusion_layer.named_parameters()] return fusion_params def forward(self, x: Tensor, **kwargs: Dict) -> Tensor: @@ -190,9 +185,7 @@ def __init__(self, vocab_size: int, fusion_vocab_size: int, embed_dim: int) -> N # Keep FusionLayer wrappings out of the state_dict self._register_state_dict_hook(FusionEmbedding._state_dict_hook) - self._register_load_state_dict_pre_hook( - FusionEmbedding._load_state_dict_hook, with_module=True - ) + self._register_load_state_dict_pre_hook(FusionEmbedding._load_state_dict_hook, with_module=True) # TODO: Switch to register_load_state_dict_pre_hook and # register_state_dict_pre_hook after PyTorch v2.5 @@ -317,13 +310,20 @@ def __init__( self.decoder = decoder self.encoder = encoder + def setup_caches( + self, batch_size: int, dtype: torch.dtype, *, encoder_max_seq_len: int = None, decoder_max_seq_len: int = None + ): + self.encoder.setup_caches(batch_size, dtype, encoder_max_seq_len=encoder_max_seq_len) + self.decoder.setup_caches(batch_size, dtype, decoder_max_seq_len=decoder_max_seq_len) + def caches_are_enabled(self) -> bool: """Check if the key value caches are setup.""" - return self.decoder.caches_are_enabled() + return self.decoder.caches_are_enabled() and self.encoder.caches_are_enabled() def reset_caches(self): """Reset the key value caches.""" self.decoder.reset_caches() + self.encoder.reset_caches() def forward( self, diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py index a2adaa050e..866d9890bf 100644 --- a/torchtune/modules/transformer.py +++ b/torchtune/modules/transformer.py @@ -338,9 +338,7 @@ def __init__( super().__init__() if num_layers is None: if isinstance(layers, nn.Module): - raise AssertionError( - "If num_layers is undefined, it is assumed that a list of layers is provided." - ) + raise AssertionError("If num_layers is undefined, it is assumed that a list of layers is provided.") layers = nn.ModuleList(layers) else: if not isinstance(layers, nn.Module): @@ -356,6 +354,14 @@ def __init__( self.num_heads = num_heads self.head_dim = head_dim self.causal_mask = None + self.cache_max_seq_len = None + + def setup_caches( + self, batch_size: int, dtype: torch.dtype, *, encoder_max_seq_len: int = None, decoder_max_seq_len: int = None + ): + self.cache_max_seq_len = decoder_max_seq_len if decoder_max_seq_len is not None else self.max_seq_len + for layer in self.layers: + layer.setup_cache(batch_size, dtype, encoder_max_seq_len=None, decoder_max_seq_len=self.cache_max_seq_len) def caches_are_enabled(self) -> bool: """Check if the key value caches are setup.""" @@ -364,9 +370,7 @@ def caches_are_enabled(self) -> bool: def reset_caches(self): """Reset the key value caches.""" if not self.caches_are_enabled(): - raise RuntimeError( - "Key value caches are not setup. Call ``setup_caches()`` first." - ) + raise RuntimeError("Key value caches are not setup. 
Call ``setup_caches()`` first.") for layer in self.layers: layer.reset_cache() @@ -426,8 +430,7 @@ def forward( if seq_len > self.max_seq_len: raise ValueError( - f"seq_len ({seq_len}) of input tensor should be smaller " - f"than max_seq_len ({self.max_seq_len})" + f"seq_len ({seq_len}) of input tensor should be smaller " f"than max_seq_len ({self.max_seq_len})" ) # shape: [b, s, d] @@ -435,13 +438,9 @@ def forward( if self.causal_mask is not None: if input_pos is None: - raise ValueError( - "Caches are setup, but the position of input token is missing" - ) + raise ValueError("Caches are setup, but the position of input token is missing") if mask is not None: - raise ValueError( - "An attention mask was set. Cannot use a non-causal mask for inference" - ) + raise ValueError("An attention mask was set. Cannot use a non-causal mask for inference") # shape: [1, input_pos_len, m_s] # in most cases input_pos_len should be 1 mask = self.causal_mask[None, input_pos] @@ -518,9 +517,7 @@ def __init__( super().__init__() if num_layers is None: if isinstance(layers, nn.Module): - raise AssertionError( - "If num_layers is undefined, it is assumed that a list of layers is provided." - ) + raise AssertionError("If num_layers is undefined, it is assumed that a list of layers is provided.") layers = nn.ModuleList(layers) else: if not isinstance(layers, nn.Module): @@ -535,6 +532,14 @@ def __init__( self.num_heads = num_heads self.head_dim = head_dim self.causal_mask = None + self.cache_max_seq_len = None + + def setup_caches( + self, batch_size: int, dtype: torch.dtype, *, encoder_max_seq_len: int = None, decoder_max_seq_len: int = None + ): + self.cache_max_seq_len = decoder_max_seq_len if decoder_max_seq_len is not None else self.max_seq_len + for layer in self.layers: + layer.setup_cache(batch_size, dtype, encoder_max_seq_len=None, decoder_max_seq_len=self.cache_max_seq_len) def caches_are_enabled(self) -> bool: """Check if the key value caches are setup.""" @@ -543,9 +548,7 @@ def caches_are_enabled(self) -> bool: def reset_caches(self): """Reset the key value caches.""" if not self.caches_are_enabled(): - raise RuntimeError( - "Key value caches are not setup. Call ``setup_caches()`` first." - ) + raise RuntimeError("Key value caches are not setup. Call ``setup_caches()`` first.") for layer in self.layers: layer.reset_cache() @@ -604,8 +607,7 @@ def forward( if seq_len > self.max_seq_len: raise ValueError( - f"seq_len ({seq_len}) of input tensor should be smaller " - f"than max_seq_len ({self.max_seq_len})" + f"seq_len ({seq_len}) of input tensor should be smaller " f"than max_seq_len ({self.max_seq_len})" ) # shape: [b, s, d] @@ -613,13 +615,9 @@ def forward( if self.causal_mask is not None: if input_pos is None: - raise ValueError( - "Caches are setup, but the position of input token is missing" - ) + raise ValueError("Caches are setup, but the position of input token is missing") if mask is not None: - raise ValueError( - "An attention mask was set. Cannot use a non-causal mask for inference" - ) + raise ValueError("An attention mask was set. 
Cannot use a non-causal mask for inference") # shape: [1, input_pos_len, m_s] # in most cases input_pos_len should be 1 mask = self.causal_mask[None, input_pos] From 854609c839ac20f5fdc38b22d6e7d4c25715fdc4 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Wed, 4 Sep 2024 20:56:19 +0100 Subject: [PATCH 03/22] refactoring setup_cache and kv_cache --- torchtune/models/gemma/transformer.py | 31 ++++++--- torchtune/modules/attention.py | 6 +- torchtune/modules/kv_cache.py | 30 ++++---- torchtune/modules/model_fusion/_fusion.py | 41 ++++++++--- torchtune/modules/transformer.py | 83 ++++++++++++++++++----- 5 files changed, 136 insertions(+), 55 deletions(-) diff --git a/torchtune/models/gemma/transformer.py b/torchtune/models/gemma/transformer.py index aa8dae0df9..bc281ae539 100644 --- a/torchtune/models/gemma/transformer.py +++ b/torchtune/models/gemma/transformer.py @@ -10,7 +10,6 @@ import torch.nn as nn import torch.nn.functional as F from torch import Tensor -from torchtune.modules import KVCache from torchtune.modules.transformer import _get_clones, TransformerSelfAttentionLayer @@ -72,19 +71,31 @@ def __init__( def caches_are_enabled(self) -> bool: """Check if the key value caches are setup.""" return self.layers[0].cache_enabled -# + def set_num_output_chunks(self, num_output_chunks: int) -> None: """Used to save memory in combination with :class:`~torchtune.modules.loss.CEWithChunkedOutputLoss`. This should be called before the first forward pass, in the recipe.""" self.num_output_chunks = num_output_chunks def setup_caches( - self, batch_size: int, dtype: torch.dtype, *, encoder_max_seq_len: int = None, decoder_max_seq_len: int = None + self, + batch_size: int, + dtype: torch.dtype, + *, + encoder_max_seq_len: int = None, + decoder_max_seq_len: int = None, ): - self.cache_max_seq_len = decoder_max_seq_len if decoder_max_seq_len is not None else self.max_seq_len + self.cache_max_seq_len = ( + decoder_max_seq_len if decoder_max_seq_len is not None else self.max_seq_len + ) for layer in self.layers: - layer.setup_cache(batch_size, dtype, encoder_max_seq_len=None, decoder_max_seq_len=self.cache_max_seq_len) - + layer.setup_cache( + batch_size, + dtype, + encoder_max_seq_len=None, + decoder_max_seq_len=self.cache_max_seq_len, + ) + def forward( self, tokens: Tensor, @@ -132,9 +143,13 @@ def forward( if self.causal_mask is not None: if input_pos is None: - raise ValueError("Caches are setup, but the position of input token is missing") + raise ValueError( + "Caches are setup, but the position of input token is missing" + ) if mask is not None: - raise ValueError("An attention mask was set. Cannot use a non-causal mask for inference") + raise ValueError( + "An attention mask was set. 
Cannot use a non-causal mask for inference" + ) # shape: [1, input_pos_len, m_s] # in most cases input_pos_len should be 1 mask = self.causal_mask[None, input_pos] diff --git a/torchtune/modules/attention.py b/torchtune/modules/attention.py index efbe243975..8bf7f1fe1f 100644 --- a/torchtune/modules/attention.py +++ b/torchtune/modules/attention.py @@ -213,10 +213,6 @@ def forward( y = y if y is not None else x s_y = y.shape[1] - if self.kv_cache and input_pos is None: - cache_size = self.kv_cache.size - input_pos = torch.arange(cache_size, cache_size + s_y, device=x.device) - # q has shape [b, s_x, num_heads * head_dim] # k has shape [b, s_y, num_kv_heads * head_dim] # v has shape [b, s_y, num_kv_heads * head_dim] @@ -266,7 +262,7 @@ def forward( # Update key-value cache if self.kv_cache is not None: - k, v = self.kv_cache.update(input_pos, k, v) + k, v = self.kv_cache.update(k, v) # shape: [b, 1, s, s] if mask is not None: diff --git a/torchtune/modules/kv_cache.py b/torchtune/modules/kv_cache.py index 4118bedd49..7048f01c88 100644 --- a/torchtune/modules/kv_cache.py +++ b/torchtune/modules/kv_cache.py @@ -48,15 +48,12 @@ def reset(self) -> None: self.k_cache.zero_() self.v_cache.zero_() - def update( - self, input_pos: Tensor, k_val: Tensor, v_val: Tensor - ) -> Tuple[Tensor, Tensor]: + def update(self, k_val: Tensor, v_val: Tensor) -> Tuple[Tensor, Tensor]: """Update KV cache with the new k_val, v_val and return the updated cache. Raises an assertion error Args: - input_pos (Tensor): Current position tensor with shape [S] k_val (Tensor): Current key tensor with shape [B, H, S, D] v_val (Tensor): Current value tensor with shape [B, H, S, D] @@ -64,28 +61,31 @@ def update( Tuple[Tensor, Tensor]: Updated KV cache with key first Raises: - ValueError: if the sequence length of ``input_pos`` is longer than the maximum sequence length. + ValueError: if the sequence length of ``k_val`` is longer than the maximum cache sequence length. ValueError: if the batch size of the new key (or value) tensor is greater than the batch size used during cache setup. """ - - if input_pos.shape[0] != k_val.shape[2]: + bsz, _, seq_len, _ = k_val.shape + if bsz > self.k_cache.shape[0]: raise ValueError( - f"The current cache has been setup with a sequence length of {k_val.shape[2]}" - f", but found cache positions with sequence length {input_pos.shape[0]}!" + f"The current cache has been setup with a batch size of {self.k_cache.shape[0]}" + f", but found new key tensors with batch size {k_val.shape[0]}!" ) - if k_val.shape[0] > self.k_cache.shape[0]: + if (self.size + seq_len) > self.k_cache.shape[2]: raise ValueError( - f"The current cache has been setup with a batch size of {self.k_cache.shape[0]}" - f", but found new key tensors with batch size {k_val.shape[0]}!" + f"The current cache has been setup with a sequence length of {self.k_cache.shape[2]}" + f", but the cache has reached a sequence length of {(self.size + seq_len)}!" 
) + cache_pos = torch.arange( + self.size, self.size + seq_len, device=k_val.device + ).unsqueeze(0) + self.size += seq_len - self.size = input_pos.max().item() + 1 k_out = self.k_cache v_out = self.v_cache - k_out[:, :, input_pos] = k_val - v_out[:, :, input_pos] = v_val + k_out.index_copy_(2, cache_pos, k_val) + v_out.index_copy_(2, cache_pos, v_val) return k_out, v_out diff --git a/torchtune/modules/model_fusion/_fusion.py b/torchtune/modules/model_fusion/_fusion.py index bf0dce81e0..1a3ca0276f 100644 --- a/torchtune/modules/model_fusion/_fusion.py +++ b/torchtune/modules/model_fusion/_fusion.py @@ -46,7 +46,9 @@ class FusionLayer(nn.Module): fusion_first (bool): boolean to insert fusion layer before or after the decoder layer. """ - def __init__(self, layer: nn.Module, fusion_layer: nn.Module, fusion_first: bool = True): + def __init__( + self, layer: nn.Module, fusion_layer: nn.Module, fusion_first: bool = True + ): super().__init__() self.layer = layer self.fusion_layer = fusion_layer @@ -54,7 +56,9 @@ def __init__(self, layer: nn.Module, fusion_layer: nn.Module, fusion_first: bool # Keep FusionLayer wrappings out of the state_dict self._register_state_dict_hook(FusionLayer._state_dict_hook) - self._register_load_state_dict_pre_hook(FusionLayer._load_state_dict_hook, with_module=True) + self._register_load_state_dict_pre_hook( + FusionLayer._load_state_dict_hook, with_module=True + ) # TODO: Switch to register_load_state_dict_pre_hook and # register_state_dict_pre_hook after PyTorch v2.5 @@ -100,11 +104,17 @@ def setup_cache( decoder_max_seq_len (int): this parameter is ignored in this layer. """ self.layer.setup_cache( - batch_size, dtype, encoder_max_seq_len=encoder_max_seq_len, decoder_max_seq_len=decoder_max_seq_len + batch_size, + dtype, + encoder_max_seq_len=encoder_max_seq_len, + decoder_max_seq_len=decoder_max_seq_len, ) self.fusion_layer.setup_cache( - batch_size, dtype, encoder_max_seq_len=encoder_max_seq_len, decoder_max_seq_len=decoder_max_seq_len + batch_size, + dtype, + encoder_max_seq_len=encoder_max_seq_len, + decoder_max_seq_len=decoder_max_seq_len, ) @property @@ -121,7 +131,9 @@ def fusion_params(self) -> List[str]: """ Return parameters of fusion layer. 
""" - fusion_params = [f"fusion_layer.{k}" for k, v in self.fusion_layer.named_parameters()] + fusion_params = [ + f"fusion_layer.{k}" for k, v in self.fusion_layer.named_parameters() + ] return fusion_params def forward(self, x: Tensor, **kwargs: Dict) -> Tensor: @@ -185,7 +197,9 @@ def __init__(self, vocab_size: int, fusion_vocab_size: int, embed_dim: int) -> N # Keep FusionLayer wrappings out of the state_dict self._register_state_dict_hook(FusionEmbedding._state_dict_hook) - self._register_load_state_dict_pre_hook(FusionEmbedding._load_state_dict_hook, with_module=True) + self._register_load_state_dict_pre_hook( + FusionEmbedding._load_state_dict_hook, with_module=True + ) # TODO: Switch to register_load_state_dict_pre_hook and # register_state_dict_pre_hook after PyTorch v2.5 @@ -311,10 +325,19 @@ def __init__( self.encoder = encoder def setup_caches( - self, batch_size: int, dtype: torch.dtype, *, encoder_max_seq_len: int = None, decoder_max_seq_len: int = None + self, + batch_size: int, + dtype: torch.dtype, + *, + encoder_max_seq_len: int = None, + decoder_max_seq_len: int = None, ): - self.encoder.setup_caches(batch_size, dtype, encoder_max_seq_len=encoder_max_seq_len) - self.decoder.setup_caches(batch_size, dtype, decoder_max_seq_len=decoder_max_seq_len) + self.encoder.setup_caches( + batch_size, dtype, encoder_max_seq_len=encoder_max_seq_len + ) + self.decoder.setup_caches( + batch_size, dtype, decoder_max_seq_len=decoder_max_seq_len + ) def caches_are_enabled(self) -> bool: """Check if the key value caches are setup.""" diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py index d5871f19a7..7cbd304e3b 100644 --- a/torchtune/modules/transformer.py +++ b/torchtune/modules/transformer.py @@ -338,7 +338,9 @@ def __init__( super().__init__() if num_layers is None: if isinstance(layers, nn.Module): - raise AssertionError("If num_layers is undefined, it is assumed that a list of layers is provided.") + raise AssertionError( + "If num_layers is undefined, it is assumed that a list of layers is provided." + ) layers = nn.ModuleList(layers) else: if not isinstance(layers, nn.Module): @@ -363,11 +365,23 @@ def set_num_output_chunks(self, num_output_chunks: int) -> None: self.num_output_chunks = num_output_chunks def setup_caches( - self, batch_size: int, dtype: torch.dtype, *, encoder_max_seq_len: int = None, decoder_max_seq_len: int = None + self, + batch_size: int, + dtype: torch.dtype, + *, + encoder_max_seq_len: int = None, + decoder_max_seq_len: int = None, ): - self.cache_max_seq_len = decoder_max_seq_len if decoder_max_seq_len is not None else self.max_seq_len + self.cache_max_seq_len = ( + decoder_max_seq_len if decoder_max_seq_len is not None else self.max_seq_len + ) for layer in self.layers: - layer.setup_cache(batch_size, dtype, encoder_max_seq_len=None, decoder_max_seq_len=self.cache_max_seq_len) + layer.setup_cache( + batch_size, + dtype, + encoder_max_seq_len=None, + decoder_max_seq_len=self.cache_max_seq_len, + ) def caches_are_enabled(self) -> bool: """Check if the key value caches are setup.""" @@ -376,7 +390,9 @@ def caches_are_enabled(self) -> bool: def reset_caches(self): """Reset the key value caches.""" if not self.caches_are_enabled(): - raise RuntimeError("Key value caches are not setup. Call ``setup_caches()`` first.") + raise RuntimeError( + "Key value caches are not setup. Call ``setup_caches()`` first." 
+ ) for layer in self.layers: layer.reset_cache() @@ -436,7 +452,8 @@ def forward( if seq_len > self.max_seq_len: raise ValueError( - f"seq_len ({seq_len}) of input tensor should be smaller " f"than max_seq_len ({self.max_seq_len})" + f"seq_len ({seq_len}) of input tensor should be smaller " + f"than max_seq_len ({self.max_seq_len})" ) # shape: [b, s, d] @@ -444,9 +461,13 @@ def forward( if self.causal_mask is not None: if input_pos is None: - raise ValueError("Caches are setup, but the position of input token is missing") + raise ValueError( + "Caches are setup, but the position of input token is missing" + ) if mask is not None: - raise ValueError("An attention mask was set. Cannot use a non-causal mask for inference") + raise ValueError( + "An attention mask was set. Cannot use a non-causal mask for inference" + ) # shape: [1, input_pos_len, m_s] # in most cases input_pos_len should be 1 mask = self.causal_mask[None, input_pos] @@ -471,7 +492,9 @@ def forward( # shape: [b, seq_len/num_chunks, out_dim] - out_dim is usually the vocab size # Used with CEWithChunkedOutputLoss. Need to set num_output_chunks in the recipe, # before calling forward. Upcasting it done inside of the loss function. - output = [self.output(chunk) for chunk in h.chunk(self.num_output_chunks, dim=1)] + output = [ + self.output(chunk) for chunk in h.chunk(self.num_output_chunks, dim=1) + ] else: # shape: [b, seq_len, out_dim] output = self.output(h).float() @@ -529,7 +552,9 @@ def __init__( super().__init__() if num_layers is None: if isinstance(layers, nn.Module): - raise AssertionError("If num_layers is undefined, it is assumed that a list of layers is provided.") + raise AssertionError( + "If num_layers is undefined, it is assumed that a list of layers is provided." + ) layers = nn.ModuleList(layers) else: if not isinstance(layers, nn.Module): @@ -553,11 +578,23 @@ def set_num_output_chunks(self, num_output_chunks: int) -> None: self.num_output_chunks = num_output_chunks def setup_caches( - self, batch_size: int, dtype: torch.dtype, *, encoder_max_seq_len: int = None, decoder_max_seq_len: int = None + self, + batch_size: int, + dtype: torch.dtype, + *, + encoder_max_seq_len: int = None, + decoder_max_seq_len: int = None, ): - self.cache_max_seq_len = decoder_max_seq_len if decoder_max_seq_len is not None else self.max_seq_len + self.cache_max_seq_len = ( + decoder_max_seq_len if decoder_max_seq_len is not None else self.max_seq_len + ) for layer in self.layers: - layer.setup_cache(batch_size, dtype, encoder_max_seq_len=None, decoder_max_seq_len=self.cache_max_seq_len) + layer.setup_cache( + batch_size, + dtype, + encoder_max_seq_len=None, + decoder_max_seq_len=self.cache_max_seq_len, + ) def caches_are_enabled(self) -> bool: """Check if the key value caches are setup.""" @@ -566,7 +603,9 @@ def caches_are_enabled(self) -> bool: def reset_caches(self): """Reset the key value caches.""" if not self.caches_are_enabled(): - raise RuntimeError("Key value caches are not setup. Call ``setup_caches()`` first.") + raise RuntimeError( + "Key value caches are not setup. Call ``setup_caches()`` first." 
+ ) for layer in self.layers: layer.reset_cache() @@ -625,7 +664,8 @@ def forward( if seq_len > self.max_seq_len: raise ValueError( - f"seq_len ({seq_len}) of input tensor should be smaller " f"than max_seq_len ({self.max_seq_len})" + f"seq_len ({seq_len}) of input tensor should be smaller " + f"than max_seq_len ({self.max_seq_len})" ) # shape: [b, s, d] @@ -633,9 +673,13 @@ def forward( if self.causal_mask is not None: if input_pos is None: - raise ValueError("Caches are setup, but the position of input token is missing") + raise ValueError( + "Caches are setup, but the position of input token is missing" + ) if mask is not None: - raise ValueError("An attention mask was set. Cannot use a non-causal mask for inference") + raise ValueError( + "An attention mask was set. Cannot use a non-causal mask for inference" + ) # shape: [1, input_pos_len, m_s] # in most cases input_pos_len should be 1 mask = self.causal_mask[None, input_pos] @@ -660,7 +704,10 @@ def forward( # shape: [b, seq_len/num_chunks, out_dim] - out_dim is usually the vocab size # Used with CEWithChunkedOutputLoss. Need to set num_output_chunks in the recipe, # before calling forward. Upcasting it done inside of the loss function. - output = [F.linear(chunk, self.tok_embeddings.weight) for chunk in h.chunk(self.num_output_chunks, dim=1)] + output = [ + F.linear(chunk, self.tok_embeddings.weight) + for chunk in h.chunk(self.num_output_chunks, dim=1) + ] else: # shape: [b, seq_len, out_dim] output = F.linear(h, self.tok_embeddings.weight).float() From aa4be3251bd44084a58a7e46dc1ffa66df6ff52a Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Wed, 4 Sep 2024 23:04:25 +0100 Subject: [PATCH 04/22] removing setup_caches function --- docs/source/api_ref_modules.rst | 1 - recipes/eleuther_eval.py | 2 +- recipes/generate.py | 3 +-- tests/torchtune/modules/test_common_utils.py | 15 --------------- torchtune/modules/__init__.py | 3 +-- 5 files changed, 3 insertions(+), 21 deletions(-) delete mode 100644 tests/torchtune/modules/test_common_utils.py diff --git a/docs/source/api_ref_modules.rst b/docs/source/api_ref_modules.rst index 5ebeb40e15..a95d9306f0 100644 --- a/docs/source/api_ref_modules.rst +++ b/docs/source/api_ref_modules.rst @@ -92,7 +92,6 @@ These are utilities that are common to and can be used by all modules. 
:nosignatures: common_utils.reparametrize_as_dtype_state_dict_post_hook - common_utils.setup_caches Vision Transforms diff --git a/recipes/eleuther_eval.py b/recipes/eleuther_eval.py index 18151d9bb9..ffb5e37551 100644 --- a/recipes/eleuther_eval.py +++ b/recipes/eleuther_eval.py @@ -144,7 +144,7 @@ def _model_generate( # are not needed for a regular model call, so we just setup here if self.enable_kv_cache: with context.device: - setup_caches(batch_size=curr_batch_size, dtype=self._dtype) + self._model.setup_caches(batch_size=curr_batch_size, dtype=self._dtype) temperature = generation_kwargs.get("temperature", 0.0) do_sample = generation_kwargs.get("do_sample", False) diff --git a/recipes/generate.py b/recipes/generate.py index c4b48c4748..31ddd9dee4 100644 --- a/recipes/generate.py +++ b/recipes/generate.py @@ -15,7 +15,6 @@ from torchtune import config, training, utils from torchtune.config._utils import _get_component_from_path from torchtune.data import ChatFormat, InstructTemplate, Message -from torchtune.modules import setup_caches logger = utils.get_logger("DEBUG") @@ -84,7 +83,7 @@ def _setup_model( # Ensure the cache is setup on the right device if enable_kv_cache: with self._device: - setup_caches(batch_size=1, dtype=self._dtype) + model.setup_caches(batch_size=1, dtype=self._dtype) return model diff --git a/tests/torchtune/modules/test_common_utils.py b/tests/torchtune/modules/test_common_utils.py deleted file mode 100644 index b4ca73c92c..0000000000 --- a/tests/torchtune/modules/test_common_utils.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# from torchtune.modules import setup_caches - - -# class TestSetupCaches: -# def test_setup_caches_decoder_model(self): -# pass - -# def test_setup_caches_fusion_model(self): -# pass diff --git a/torchtune/modules/__init__.py b/torchtune/modules/__init__.py index c1ec585f0e..c60ff69bad 100644 --- a/torchtune/modules/__init__.py +++ b/torchtune/modules/__init__.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from .attention import MultiHeadAttention # noqa -from .common_utils import reparametrize_as_dtype_state_dict_post_hook, setup_caches +from .common_utils import reparametrize_as_dtype_state_dict_post_hook from .feed_forward import FeedForward # noqa from .kv_cache import KVCache # noqa @@ -30,7 +30,6 @@ "FrozenNF4Linear", "get_cosine_schedule_with_warmup", "KVCache", - "setup_caches", "RotaryPositionalEmbeddings", "RMSNorm", "Fp32LayerNorm", From f37c803b5b7ff6b5c0f7e2c6a9d591f9c31d2ddd Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Wed, 4 Sep 2024 23:44:40 +0100 Subject: [PATCH 05/22] adding tests --- tests/torchtune/modules/test_kv_cache.py | 80 +++++++++++++++++++ .../modules/test_transformer_decoder.py | 27 +++++++ torchtune/modules/kv_cache.py | 4 +- 3 files changed, 108 insertions(+), 3 deletions(-) create mode 100644 tests/torchtune/modules/test_kv_cache.py diff --git a/tests/torchtune/modules/test_kv_cache.py b/tests/torchtune/modules/test_kv_cache.py new file mode 100644 index 0000000000..23e257030c --- /dev/null +++ b/tests/torchtune/modules/test_kv_cache.py @@ -0,0 +1,80 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import pytest +import torch +from torchtune.modules import KVCache + +BSZ = 2 +MAX_SEQ_LEN = 16 +NUM_HEADS = 4 +HEAD_DIM = 256 +DTYPE = torch.float32 + + +class TestKVCache: + @pytest.fixture() + def k_vals_full(self): + return torch.ones((BSZ, NUM_HEADS, MAX_SEQ_LEN, HEAD_DIM)) + + @pytest.fixture() + def v_vals_full(self): + return torch.ones((BSZ, NUM_HEADS, MAX_SEQ_LEN, HEAD_DIM)) * 2 + + @pytest.fixture() + def kv_cache(self): + return KVCache(BSZ, MAX_SEQ_LEN, NUM_HEADS, HEAD_DIM, DTYPE) + + def test_kv_cache_init(self, kv_cache): + # kv cache should be init with zero + assert (kv_cache.k_cache == 0).all() and (kv_cache.v_cache == 0).all() + + def test_kv_cache_reset(self, kv_cache, k_vals_full, v_vals_full): + kv_cache.update(k_vals_full, v_vals_full) + kv_cache.reset() + assert (kv_cache.k_cache == 0).all() and (kv_cache.v_cache == 0).all() + + def test_kv_cache_error_when_bsz_exceeded(self, kv_cache, k_vals_full, v_vals_full): + with pytest.raises(ValueError): + kv_cache.update(k_vals_full.repeat(4, 1, 1, 1), v_vals_full) + + def test_kv_cache_error_when_seq_len_exceeded( + self, kv_cache, k_vals_full, v_vals_full + ): + with pytest.raises(ValueError): + kv_cache.update(k_vals_full.repeat(1, 1, 4, 1), v_vals_full) + + def test_kv_cache_error_when_seq_len_exceeded_after_update( + self, kv_cache, k_vals_full, v_vals_full + ): + # test that the cache position is correctly being used to check for seq len exceeded + # make a valid update filling half the cache + kv_cache.update( + k_vals_full[:, :, : (MAX_SEQ_LEN // 2)], + v_vals_full[:, :, : (MAX_SEQ_LEN // 2)], + ) + with pytest.raises( + ValueError, + match=f"cache has reached a sequence length of {MAX_SEQ_LEN + MAX_SEQ_LEN // 2}", + ): + # now an invalid update exceeding the cache + kv_cache.update(k_vals_full, v_vals_full) + + def test_kv_cache_size_update(self, kv_cache, k_vals_full, v_vals_full): + # tests that the kv_cache is correctly tracking the cache position + + # make a valid update filling half the cache - like a prefill + kv_cache.update( + k_vals_full[:, :, : (MAX_SEQ_LEN // 2)], + v_vals_full[:, :, : (MAX_SEQ_LEN // 2)], + ) + assert kv_cache.size == MAX_SEQ_LEN // 2 + # now one update with the next key and value + kv_cache.update( + k_vals_full[:, :, (MAX_SEQ_LEN // 2) + 1].unsqueeze(-2), + v_vals_full[:, :, (MAX_SEQ_LEN // 2) + 1].unsqueeze(-2), + ) + assert kv_cache.size == (MAX_SEQ_LEN // 2) + 1 diff --git a/tests/torchtune/modules/test_transformer_decoder.py b/tests/torchtune/modules/test_transformer_decoder.py index 4c999eb37d..78ffee8935 100644 --- a/tests/torchtune/modules/test_transformer_decoder.py +++ b/tests/torchtune/modules/test_transformer_decoder.py @@ -304,6 +304,33 @@ def decoder_with_kv_cache_enabled( decoder.setup_caches(batch_size=4, dtype=torch.float32) return decoder + @pytest.fixture + def decoder_with_kv_cache_fixed_length( + self, decoder_params: Tuple[int, int, int, int, int, int] + ) -> TransformerDecoder: + ( + vocab_size, + embed_dim, + num_layers, + num_heads, + max_seq_len, + num_kv_heads, + ) = decoder_params + decoder = llama2( + vocab_size=vocab_size, + num_layers=num_layers, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + embed_dim=embed_dim, + max_seq_len=max_seq_len, + ) + # TODO: fix weight initialization to use fixed_init_model + for p in decoder.parameters(): + nn.init.constant_(p, 0.2) + decoder.eval() + decoder.setup_caches(batch_size=4, dtype=torch.float32, decoder_max_seq_len=12) + return decoder + def test_forward( self, input: torch.Tensor, diff --git 
a/torchtune/modules/kv_cache.py b/torchtune/modules/kv_cache.py index 9c43a28be7..2d9775ac48 100644 --- a/torchtune/modules/kv_cache.py +++ b/torchtune/modules/kv_cache.py @@ -79,9 +79,7 @@ def update( f"The current cache has been setup with a sequence length of {self.k_cache.shape[2]}" f", but the cache has reached a sequence length of {(self.size + seq_len)}!" ) - cache_pos = torch.arange( - self.size, self.size + seq_len, device=k_val.device - ).unsqueeze(0) + cache_pos = torch.arange(self.size, self.size + seq_len, device=k_val.device) self.size += seq_len k_out = self.k_cache From e175592968eebb7814db927b4b41ab89e192bdc4 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Thu, 5 Sep 2024 13:52:33 +0100 Subject: [PATCH 06/22] updating tests --- .../modules/model_fusion/test_fusion_layer.py | 46 +++++++++-- .../model_fusion/test_fusion_models.py | 2 +- tests/torchtune/modules/test_kv_cache.py | 77 ++++++++++++++++++- torchtune/modules/kv_cache.py | 19 ++++- torchtune/modules/model_fusion/_fusion.py | 3 - 5 files changed, 135 insertions(+), 12 deletions(-) diff --git a/tests/torchtune/modules/model_fusion/test_fusion_layer.py b/tests/torchtune/modules/model_fusion/test_fusion_layer.py index a275687823..c30ed94110 100644 --- a/tests/torchtune/modules/model_fusion/test_fusion_layer.py +++ b/tests/torchtune/modules/model_fusion/test_fusion_layer.py @@ -18,14 +18,34 @@ def random(): set_seed(1) -class DummyLayer(nn.Module): +class DummyCrossAttentionLayer(nn.Module): def __init__(self, dim): super().__init__() self.linear = nn.Linear(dim, dim) self.cache_enabled = False + self.encoder_max_seq_len = None - def setup_cache(self, batch_size, dtype): + def setup_cache(self, batch_size, dtype, encoder_max_seq_len, decoder_max_seq_len): self.cache_enabled = True + self.encoder_max_seq_len = encoder_max_seq_len + + def reset_cache(self): + self.cache_enabled = False + + def forward(self, x): + return self.linear(x) + + +class DummySelfAttentionLayer(nn.Module): + def __init__(self, dim): + super().__init__() + self.linear = nn.Linear(dim, dim) + self.cache_enabled = False + self.decoder_max_seq_len = None + + def setup_cache(self, batch_size, dtype, encoder_max_seq_len, decoder_max_seq_len): + self.cache_enabled = True + self.decoder_max_seq_len = decoder_max_seq_len def reset_cache(self): self.cache_enabled = False @@ -45,13 +65,13 @@ def dim(self) -> int: @pytest.fixture def layer(self, dim) -> nn.Module: - layer = DummyLayer(dim) + layer = DummySelfAttentionLayer(dim) fixed_init_model(layer, min_val=-0.1, max_val=0.1) return layer @pytest.fixture def fusion_layer(self, dim) -> nn.Module: - layer = DummyLayer(dim) + layer = DummyCrossAttentionLayer(dim) fixed_init_model(layer, min_val=-0.2, max_val=0.2) return layer @@ -115,7 +135,23 @@ def test_setup_cache(self, fused_layer): """ Test that the cache methods works as expected. """ - fused_layer.setup_cache(2, torch.float32) + fused_layer.setup_cache( + 2, torch.float32, encoder_max_seq_len=10, decoder_max_seq_len=10 + ) assert fused_layer.cache_enabled fused_layer.reset_cache() assert not fused_layer.cache_enabled + + def test_setup_cache_different_cache_seq_len(self, fused_layer): + """ + Test that the cache methods works as expected. 
+ """ + fused_layer.setup_cache( + 2, torch.float32, encoder_max_seq_len=5, decoder_max_seq_len=10 + ) + + assert fused_layer.layer.decoder_max_seq_len == 10 + assert fused_layer.fusion_layer.encoder_max_seq_len == 5 + + assert not fused_layer.layer.hasattr("encoder_max_seq_len") + assert not fused_layer.fusion_layer.hasattr("decoder_max_seq_len") diff --git a/tests/torchtune/modules/model_fusion/test_fusion_models.py b/tests/torchtune/modules/model_fusion/test_fusion_models.py index 1c3dc44132..53dc1b0d5e 100644 --- a/tests/torchtune/modules/model_fusion/test_fusion_models.py +++ b/tests/torchtune/modules/model_fusion/test_fusion_models.py @@ -28,7 +28,7 @@ def __init__(self, dim, vocab_size): self.v = nn.Linear(dim, dim) self.output = nn.Linear(dim, vocab_size) - def setup_caches(self, batch_size, dtype): + def setup_caches(self, batch_size, dtype, *args, **kwargs): self.cache_enabled = True def caches_are_enabled(self): diff --git a/tests/torchtune/modules/test_kv_cache.py b/tests/torchtune/modules/test_kv_cache.py index 23e257030c..75c4be8fb2 100644 --- a/tests/torchtune/modules/test_kv_cache.py +++ b/tests/torchtune/modules/test_kv_cache.py @@ -18,11 +18,28 @@ class TestKVCache: @pytest.fixture() def k_vals_full(self): - return torch.ones((BSZ, NUM_HEADS, MAX_SEQ_LEN, HEAD_DIM)) + return ( + torch.tril(torch.ones(MAX_SEQ_LEN, HEAD_DIM))[ + None, + None, + :, + :, + ] + .repeat(BSZ, NUM_HEADS, 1, 1) + .to(DTYPE) + ) @pytest.fixture() def v_vals_full(self): - return torch.ones((BSZ, NUM_HEADS, MAX_SEQ_LEN, HEAD_DIM)) * 2 + return ( + torch.tril(torch.ones(MAX_SEQ_LEN, HEAD_DIM))[ + None, + None, + :, + :, + ].repeat(BSZ, NUM_HEADS, 1, 1) + * 2 + ).to(DTYPE) @pytest.fixture() def kv_cache(self): @@ -36,6 +53,7 @@ def test_kv_cache_reset(self, kv_cache, k_vals_full, v_vals_full): kv_cache.update(k_vals_full, v_vals_full) kv_cache.reset() assert (kv_cache.k_cache == 0).all() and (kv_cache.v_cache == 0).all() + assert kv_cache.size == 0 def test_kv_cache_error_when_bsz_exceeded(self, kv_cache, k_vals_full, v_vals_full): with pytest.raises(ValueError): @@ -78,3 +96,58 @@ def test_kv_cache_size_update(self, kv_cache, k_vals_full, v_vals_full): v_vals_full[:, :, (MAX_SEQ_LEN // 2) + 1].unsqueeze(-2), ) assert kv_cache.size == (MAX_SEQ_LEN // 2) + 1 + + def test_kv_cache_single_update(self, kv_cache, k_vals_full, v_vals_full): + # tests that the kv_cache is correctly returning the updated cache values + # after a single cache update + + # make a valid update filling half the cache - like a prefill + k_out, v_out = kv_cache.update( + k_vals_full[:, :, : (MAX_SEQ_LEN // 2)], + v_vals_full[:, :, : (MAX_SEQ_LEN // 2)], + ) + + expected_k_out = torch.zeros_like(k_vals_full) + expected_v_out = torch.zeros_like(v_vals_full) + + expected_k_out[:, :, torch.arange(0, (MAX_SEQ_LEN // 2))] = k_vals_full[ + :, :, : (MAX_SEQ_LEN // 2) + ] + expected_v_out[:, :, torch.arange(0, (MAX_SEQ_LEN // 2))] = v_vals_full[ + :, :, : (MAX_SEQ_LEN // 2) + ] + + assert torch.equal(expected_k_out, k_out) + assert torch.equal(expected_v_out, v_out) + + def test_kv_cache_multiple_updates(self, kv_cache, k_vals_full, v_vals_full): + # tests that the kv_cache is correctly returning the updated cache values + # after a single cache update, followed by another cache update with seq_len=1 + + # make an update filling half the cache - like a prefill + # fills position 0 through to (MAX_SEQ_LEN // 2) - 1 + kv_cache.update( + k_vals_full[:, :, : (MAX_SEQ_LEN // 2)], + v_vals_full[:, :, : (MAX_SEQ_LEN // 2)], + ) + + # make an update 
for one more token, which is the value at + # (MAX_SEQ_LEN // 2) + k_out, v_out = kv_cache.update( + k_vals_full[:, :, (MAX_SEQ_LEN // 2)].unsqueeze(2), + v_vals_full[:, :, (MAX_SEQ_LEN // 2)].unsqueeze(2), + ) + + expected_k_out = torch.zeros_like(k_vals_full) + expected_v_out = torch.zeros_like(v_vals_full) + + # cache should be incremented by one position + expected_k_out[:, :, torch.arange(0, ((MAX_SEQ_LEN // 2) + 1))] = k_vals_full[ + :, :, : ((MAX_SEQ_LEN // 2) + 1) + ] + expected_v_out[:, :, torch.arange(0, ((MAX_SEQ_LEN // 2) + 1))] = v_vals_full[ + :, :, : ((MAX_SEQ_LEN // 2) + 1) + ] + + assert torch.equal(expected_k_out, k_out) + assert torch.equal(expected_v_out, v_out) diff --git a/torchtune/modules/kv_cache.py b/torchtune/modules/kv_cache.py index 2d9775ac48..21aa9ac50e 100644 --- a/torchtune/modules/kv_cache.py +++ b/torchtune/modules/kv_cache.py @@ -47,13 +47,30 @@ def reset(self) -> None: """Reset the cache to zero.""" self.k_cache.zero_() self.v_cache.zero_() + self.size = 0 def update( self, k_val: torch.Tensor, v_val: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: """Update KV cache with the new k_val, v_val and return the updated cache. - Raises an assertion error + Note: + When updating the KV cache, it is assumed that subsequent updates should update key-value + positions in consecutive sequence positions. If you wish to update cache values which have + already been filled, use ``.reset()``, which will reset the cache to the zero-th position. + + Example: + >>> cache = KVCache(batch_size=2, max_seq_len=16, num_heads=4, head_dim=32, dtype=torch.bfloat16) + >>> keys, values = torch.ones((2, 4, 8, 32)), torch.ones((2, 4, 8, 32)) + >>> cache.update(keys, values) + >>> # now positions 0 through 7 are filled + >>> cache.size + >>> 8 + >>> keys, values = torch.ones((2, 4, 1, 32)), torch.ones((2, 4, 1, 32)) + >>> cache.update(keys, values) + >>> # this will fill at position 8 + >>> cache.size + >>> 9 Args: k_val (torch.Tensor): Current key tensor with shape [B, H, S, D] diff --git a/torchtune/modules/model_fusion/_fusion.py b/torchtune/modules/model_fusion/_fusion.py index 2e9277f3db..2615a2af1e 100644 --- a/torchtune/modules/model_fusion/_fusion.py +++ b/torchtune/modules/model_fusion/_fusion.py @@ -332,9 +332,6 @@ def setup_caches( encoder_max_seq_len: int = None, decoder_max_seq_len: int = None, ): - self.encoder.setup_caches( - batch_size, dtype, encoder_max_seq_len=encoder_max_seq_len - ) self.decoder.setup_caches( batch_size, dtype, decoder_max_seq_len=decoder_max_seq_len ) From e9b4871808e227bfcda71343ab7dc300147a343f Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Thu, 5 Sep 2024 13:57:11 +0100 Subject: [PATCH 07/22] fixing tests for real --- .../torchtune/modules/model_fusion/test_fusion_layer.py | 4 ++-- torchtune/modules/model_fusion/_fusion.py | 9 ++++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/torchtune/modules/model_fusion/test_fusion_layer.py b/tests/torchtune/modules/model_fusion/test_fusion_layer.py index c30ed94110..94ca29085e 100644 --- a/tests/torchtune/modules/model_fusion/test_fusion_layer.py +++ b/tests/torchtune/modules/model_fusion/test_fusion_layer.py @@ -153,5 +153,5 @@ def test_setup_cache_different_cache_seq_len(self, fused_layer): assert fused_layer.layer.decoder_max_seq_len == 10 assert fused_layer.fusion_layer.encoder_max_seq_len == 5 - assert not fused_layer.layer.hasattr("encoder_max_seq_len") - assert not fused_layer.fusion_layer.hasattr("decoder_max_seq_len") + assert not hasattr(fused_layer.layer, 
"encoder_max_seq_len") + assert not hasattr(fused_layer.fusion_layer, "decoder_max_seq_len") diff --git a/torchtune/modules/model_fusion/_fusion.py b/torchtune/modules/model_fusion/_fusion.py index 2615a2af1e..5a3f31c1c6 100644 --- a/torchtune/modules/model_fusion/_fusion.py +++ b/torchtune/modules/model_fusion/_fusion.py @@ -332,18 +332,21 @@ def setup_caches( encoder_max_seq_len: int = None, decoder_max_seq_len: int = None, ): + """Sets up caches for self-attention, cross-attention, and fusion layers in the decoder.""" self.decoder.setup_caches( - batch_size, dtype, decoder_max_seq_len=decoder_max_seq_len + batch_size, + dtype, + encoder_max_seq_len=encoder_max_seq_len, + decoder_max_seq_len=decoder_max_seq_len, ) def caches_are_enabled(self) -> bool: """Check if the key value caches are setup.""" - return self.decoder.caches_are_enabled() and self.encoder.caches_are_enabled() + return self.decoder.caches_are_enabled() def reset_caches(self): """Reset the key value caches.""" self.decoder.reset_caches() - self.encoder.reset_caches() def forward( self, From e26c781198ed240036e0e87d3f3f59f145ad64ea Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Thu, 5 Sep 2024 14:14:40 +0100 Subject: [PATCH 08/22] fixing docs --- torchtune/modules/kv_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtune/modules/kv_cache.py b/torchtune/modules/kv_cache.py index 21aa9ac50e..c18102183b 100644 --- a/torchtune/modules/kv_cache.py +++ b/torchtune/modules/kv_cache.py @@ -77,7 +77,7 @@ def update( v_val (torch.Tensor): Current value tensor with shape [B, H, S, D] Returns: - Tuple[Tensor, Tensor]: Updated KV cache with key first + Tuple[torch.Tensor, torch.Tensor]: Updated key and value cache tensors, respectively. Raises: ValueError: if the sequence length of ``k_val`` is longer than the maximum cache sequence length. 
From 7e5625bbb038aa550d19e395604f1d4c604abe64 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Thu, 5 Sep 2024 14:19:43 +0100 Subject: [PATCH 09/22] removing cache_max_seq_len --- torchtune/models/gemma/transformer.py | 10 ++++++---- torchtune/modules/transformer.py | 20 ++++++++++++-------- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/torchtune/models/gemma/transformer.py b/torchtune/models/gemma/transformer.py index c3fb98d1a3..df8cfde4c5 100644 --- a/torchtune/models/gemma/transformer.py +++ b/torchtune/models/gemma/transformer.py @@ -63,7 +63,6 @@ def __init__( self.head_dim = head_dim self.causal_mask = None self.norm_embeddings = norm_embeddings - self.cache_max_seq_len = None self.num_output_chunks = 0 def caches_are_enabled(self) -> bool: @@ -83,15 +82,18 @@ def setup_caches( encoder_max_seq_len: int = None, decoder_max_seq_len: int = None, ): - self.cache_max_seq_len = ( + encoder_max_seq_len = ( + encoder_max_seq_len if encoder_max_seq_len is not None else self.max_seq_len + ) + decoder_max_seq_len = ( decoder_max_seq_len if decoder_max_seq_len is not None else self.max_seq_len ) for layer in self.layers: layer.setup_cache( batch_size, dtype, - encoder_max_seq_len=None, - decoder_max_seq_len=self.cache_max_seq_len, + encoder_max_seq_len=encoder_max_seq_len, + decoder_max_seq_len=decoder_max_seq_len, ) def forward( diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py index 330997d140..da3d165da7 100644 --- a/torchtune/modules/transformer.py +++ b/torchtune/modules/transformer.py @@ -356,7 +356,6 @@ def __init__( self.num_heads = num_heads self.head_dim = head_dim self.causal_mask = None - self.cache_max_seq_len = None self.num_output_chunks = 0 def set_num_output_chunks(self, num_output_chunks: int) -> None: @@ -372,15 +371,18 @@ def setup_caches( encoder_max_seq_len: int = None, decoder_max_seq_len: int = None, ): - self.cache_max_seq_len = ( + encoder_max_seq_len = ( + encoder_max_seq_len if encoder_max_seq_len is not None else self.max_seq_len + ) + decoder_max_seq_len( decoder_max_seq_len if decoder_max_seq_len is not None else self.max_seq_len ) for layer in self.layers: layer.setup_cache( batch_size, dtype, - encoder_max_seq_len=None, - decoder_max_seq_len=self.cache_max_seq_len, + encoder_max_seq_len=encoder_max_seq_len, + decoder_max_seq_len=decoder_max_seq_len, ) def caches_are_enabled(self) -> bool: @@ -569,7 +571,6 @@ def __init__( self.num_heads = num_heads self.head_dim = head_dim self.causal_mask = None - self.cache_max_seq_len = None self.num_output_chunks = 0 def set_num_output_chunks(self, num_output_chunks: int) -> None: @@ -585,15 +586,18 @@ def setup_caches( encoder_max_seq_len: int = None, decoder_max_seq_len: int = None, ): - self.cache_max_seq_len = ( + encoder_max_seq_len = ( + encoder_max_seq_len if encoder_max_seq_len is not None else self.max_seq_len + ) + decoder_max_seq_len = ( decoder_max_seq_len if decoder_max_seq_len is not None else self.max_seq_len ) for layer in self.layers: layer.setup_cache( batch_size, dtype, - encoder_max_seq_len=None, - decoder_max_seq_len=self.cache_max_seq_len, + encoder_max_seq_len=encoder_max_seq_len, + decoder_max_seq_len=decoder_max_seq_len, ) def caches_are_enabled(self) -> bool: From 03506893a45e455d7984da15479e9ed32c94b38f Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Thu, 5 Sep 2024 14:24:54 +0100 Subject: [PATCH 10/22] updating docs --- torchtune/models/gemma/transformer.py | 12 ++++++++++++ torchtune/modules/model_fusion/_fusion.py | 14 ++++++++++++- 
torchtune/modules/transformer.py | 24 +++++++++++++++++++++++ 3 files changed, 49 insertions(+), 1 deletion(-) diff --git a/torchtune/models/gemma/transformer.py b/torchtune/models/gemma/transformer.py index df8cfde4c5..46bd777261 100644 --- a/torchtune/models/gemma/transformer.py +++ b/torchtune/models/gemma/transformer.py @@ -82,6 +82,18 @@ def setup_caches( encoder_max_seq_len: int = None, decoder_max_seq_len: int = None, ): + """ + Sets up key-value attention caches for inference. For each layer in ``self.layers``: + - :class:`torchtune.modules.TransformerSelfAttentionLayer` will use ``decoder_max_seq_len``. + - :class:`torchtune.modules.TransformerCrossAttentionLayer` will use ``encoder_max_seq_len``. + - :class:`torchtune.modules.fusion.FusionLayer` will use both ``decoder_max_seq_len`` and ``encoder_max_seq_len``. + + Args: + batch_size (int): batch size for the caches. + dtype (torch.dtype): dtype for the caches. + encoder_max_seq_len (int): maximum encoder cache sequence length. + decoder_max_seq_len (int): maximum decoder cache sequence length. + """ encoder_max_seq_len = ( encoder_max_seq_len if encoder_max_seq_len is not None else self.max_seq_len ) diff --git a/torchtune/modules/model_fusion/_fusion.py b/torchtune/modules/model_fusion/_fusion.py index 5a3f31c1c6..a868c600db 100644 --- a/torchtune/modules/model_fusion/_fusion.py +++ b/torchtune/modules/model_fusion/_fusion.py @@ -332,7 +332,19 @@ def setup_caches( encoder_max_seq_len: int = None, decoder_max_seq_len: int = None, ): - """Sets up caches for self-attention, cross-attention, and fusion layers in the decoder.""" + """ + Sets up key-value attention caches for inference for ``self.decoder``. + For each layer in ``self.decoder.layers``: + - :class:`torchtune.modules.TransformerSelfAttentionLayer` will use ``decoder_max_seq_len``. + - :class:`torchtune.modules.TransformerCrossAttentionLayer` will use ``encoder_max_seq_len``. + - :class:`torchtune.modules.fusion.FusionLayer` will use both ``decoder_max_seq_len`` and ``encoder_max_seq_len``. + + Args: + batch_size (int): batch size for the caches. + dtype (torch.dtype): dtype for the caches. + encoder_max_seq_len (int): maximum encoder cache sequence length. + decoder_max_seq_len (int): maximum decoder cache sequence length. + """ self.decoder.setup_caches( batch_size, dtype, diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py index da3d165da7..d1a5926a7c 100644 --- a/torchtune/modules/transformer.py +++ b/torchtune/modules/transformer.py @@ -371,6 +371,18 @@ def setup_caches( encoder_max_seq_len: int = None, decoder_max_seq_len: int = None, ): + """ + Sets up key-value attention caches for inference. For each layer in ``self.layers``: + - :class:`torchtune.modules.TransformerSelfAttentionLayer` will use ``decoder_max_seq_len``. + - :class:`torchtune.modules.TransformerCrossAttentionLayer` will use ``encoder_max_seq_len``. + - :class:`torchtune.modules.fusion.FusionLayer` will use both ``decoder_max_seq_len`` and ``encoder_max_seq_len``. + + Args: + batch_size (int): batch size for the caches. + dtype (torch.dtype): dtype for the caches. + encoder_max_seq_len (int): maximum encoder cache sequence length. + decoder_max_seq_len (int): maximum decoder cache sequence length. + """ encoder_max_seq_len = ( encoder_max_seq_len if encoder_max_seq_len is not None else self.max_seq_len ) @@ -586,6 +598,18 @@ def setup_caches( encoder_max_seq_len: int = None, decoder_max_seq_len: int = None, ): + """ + Sets up key-value attention caches for inference. 
For each layer in ``self.layers``: + - :class:`torchtune.modules.TransformerSelfAttentionLayer` will use ``decoder_max_seq_len``. + - :class:`torchtune.modules.TransformerCrossAttentionLayer` will use ``encoder_max_seq_len``. + - :class:`torchtune.modules.fusion.FusionLayer` will use both ``decoder_max_seq_len`` and ``encoder_max_seq_len``. + + Args: + batch_size (int): batch size for the caches. + dtype (torch.dtype): dtype for the caches. + encoder_max_seq_len (int): maximum encoder cache sequence length. + decoder_max_seq_len (int): maximum decoder cache sequence length. + """ encoder_max_seq_len = ( encoder_max_seq_len if encoder_max_seq_len is not None else self.max_seq_len ) From 432adeff95af45b91020e8145a9fcd34e8cb9e59 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Thu, 5 Sep 2024 14:26:11 +0100 Subject: [PATCH 11/22] updating docs --- torchtune/modules/kv_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtune/modules/kv_cache.py b/torchtune/modules/kv_cache.py index c18102183b..84996518ad 100644 --- a/torchtune/modules/kv_cache.py +++ b/torchtune/modules/kv_cache.py @@ -52,7 +52,7 @@ def reset(self) -> None: def update( self, k_val: torch.Tensor, v_val: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: - """Update KV cache with the new k_val, v_val and return the updated cache. + """Update KV cache with the new ``k_val``, ``v_val`` and return the updated cache. Note: When updating the KV cache, it is assumed that subsequent updates should update key-value From 18318033df424d27157e0288a19c8a0dc1d91683 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Thu, 5 Sep 2024 14:36:29 +0100 Subject: [PATCH 12/22] i'm a muppet --- torchtune/modules/transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py index d1a5926a7c..a2eb9cd483 100644 --- a/torchtune/modules/transformer.py +++ b/torchtune/modules/transformer.py @@ -386,7 +386,7 @@ def setup_caches( encoder_max_seq_len = ( encoder_max_seq_len if encoder_max_seq_len is not None else self.max_seq_len ) - decoder_max_seq_len( + decoder_max_seq_len = ( decoder_max_seq_len if decoder_max_seq_len is not None else self.max_seq_len ) for layer in self.layers: From a898bc85fdb3cc68db2dc6d53e71dda10d4dfa1e Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Thu, 5 Sep 2024 14:48:39 +0100 Subject: [PATCH 13/22] adding to generation tests --- tests/torchtune/utils/test_generation.py | 25 ++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/torchtune/utils/test_generation.py b/tests/torchtune/utils/test_generation.py index 15c4a336fa..a08ccd1a2b 100644 --- a/tests/torchtune/utils/test_generation.py +++ b/tests/torchtune/utils/test_generation.py @@ -63,6 +63,21 @@ def generation_model_batched(self, dtype=torch.float32): model.eval() return model + @pytest.fixture + def generation_model_batched_fixed_cache_seq_len(self, dtype=torch.float32): + model = llama2( + vocab_size=4_000, + embed_dim=128, + num_layers=2, + num_heads=4, + num_kv_heads=4, + max_seq_len=2048, + ) + fixed_init_model(model) + model.setup_caches(batch_size=2, dtype=dtype, decoder_max_seq_len=1024) + model.eval() + return model + @pytest.fixture def prompt_tokens(self): """ @@ -103,6 +118,16 @@ def test_sample_consistency(self): "generation_model_no_kv_cache", "prompt_tokens_batched", ), + ( + "generation_model_batched", + "generation_model_batched_fixed_cache_seq_len", + "prompt_tokens_batched", + ), + ( + 
"generation_model_batched_fixed_cache_seq_len", + "generation_model_no_kv_cache", + "prompt_tokens_batched", + ), ], ) def test_reproducibility(self, request, model1, model2, prompt): From 1b0e7a973989e7d5316ce6c457de1295ff7c58f3 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Thu, 5 Sep 2024 15:07:19 +0100 Subject: [PATCH 14/22] updating kv cache transformerdecoder tests --- .../modules/test_transformer_decoder.py | 41 +++++++++++++---- torchtune/modules/__init__.py | 1 - torchtune/modules/common_utils.py | 45 +------------------ torchtune/modules/model_fusion/_fusion.py | 4 +- 4 files changed, 35 insertions(+), 56 deletions(-) diff --git a/tests/torchtune/modules/test_transformer_decoder.py b/tests/torchtune/modules/test_transformer_decoder.py index 78ffee8935..2796cca73e 100644 --- a/tests/torchtune/modules/test_transformer_decoder.py +++ b/tests/torchtune/modules/test_transformer_decoder.py @@ -278,7 +278,7 @@ def decoder( return decoder @pytest.fixture - def decoder_with_kv_cache_enabled( + def decoder_with_kv_cache( self, decoder_params: Tuple[int, int, int, int, int, int] ) -> TransformerDecoder: ( @@ -328,7 +328,9 @@ def decoder_with_kv_cache_fixed_length( for p in decoder.parameters(): nn.init.constant_(p, 0.2) decoder.eval() - decoder.setup_caches(batch_size=4, dtype=torch.float32, decoder_max_seq_len=12) + decoder.setup_caches( + batch_size=4, dtype=torch.float32, decoder_max_seq_len=max_seq_len + 512 + ) return decoder def test_forward( @@ -351,30 +353,42 @@ def test_max_seq_len_exceeded( with pytest.raises(Exception): output = decoder(input_max_len_exceeded) + @pytest.mark.parametrize( + "decoder_with_kv_cache_enabled", + ["decoder_with_kv_cache", "decoder_with_kv_cache_fixed_length"], + ) def test_kv_cache( self, + request, input: torch.Tensor, - decoder_with_kv_cache_enabled: TransformerDecoder, + decoder_with_kv_cache_enabled: str, decoder: TransformerDecoder, ) -> None: _, seq_len = input.shape - input_pos = torch.arange(seq_len) - + decoder_with_kv_cache_enabled = request.getfixturevalue( + decoder_with_kv_cache_enabled + ) with torch.no_grad(): - output_cache = decoder_with_kv_cache_enabled(input, input_pos=input_pos) + output_cache = decoder_with_kv_cache_enabled(input) output_no_cache = decoder(input) assert_expected(output_cache.mean(), output_no_cache.mean()) + @pytest.mark.parametrize( + "decoder_with_kv_cache_enabled", + ["decoder_with_kv_cache", "decoder_with_kv_cache_fixed_length"], + ) def test_kv_cache_reset_values( self, + request, input: torch.Tensor, decoder_with_kv_cache_enabled: TransformerDecoder, ) -> None: - _, seq_len = input.shape - input_pos = torch.arange(seq_len) + decoder_with_kv_cache_enabled = request.getfixturevalue( + decoder_with_kv_cache_enabled + ) with torch.no_grad(): - _ = decoder_with_kv_cache_enabled(input, input_pos=input_pos) + _ = decoder_with_kv_cache_enabled(input) kv_cache_k_val = decoder_with_kv_cache_enabled.layers[ 0 ].attn.kv_cache.k_cache.clone() @@ -400,11 +414,20 @@ def test_kv_cache_reset_values_fails_when_not_enabled_first( with pytest.raises(RuntimeError, match="Key value caches are not setup"): decoder.reset_caches() + @pytest.mark.parametrize( + "decoder_with_kv_cache_enabled", + ["decoder_with_kv_cache", "decoder_with_kv_cache_fixed_length"], + ) def test_kv_cache_batch_size_exceeded( self, + request, input_max_bs_exceeded: torch.Tensor, decoder_with_kv_cache_enabled: TransformerDecoder, ) -> None: + + decoder_with_kv_cache_enabled = request.getfixturevalue( + decoder_with_kv_cache_enabled + ) with 
pytest.raises(ValueError): decoder_with_kv_cache_enabled(input_max_bs_exceeded) diff --git a/torchtune/modules/__init__.py b/torchtune/modules/__init__.py index c60ff69bad..1c508941f7 100644 --- a/torchtune/modules/__init__.py +++ b/torchtune/modules/__init__.py @@ -8,7 +8,6 @@ from .common_utils import reparametrize_as_dtype_state_dict_post_hook from .feed_forward import FeedForward # noqa from .kv_cache import KVCache # noqa - from .layer_norm import Fp32LayerNorm # noqa from .low_precision import FrozenNF4Linear # noqa from .lr_schedulers import get_cosine_schedule_with_warmup # noqa diff --git a/torchtune/modules/common_utils.py b/torchtune/modules/common_utils.py index b4ab3fb5cd..9c588fabba 100644 --- a/torchtune/modules/common_utils.py +++ b/torchtune/modules/common_utils.py @@ -4,55 +4,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Tuple import torch import torch.nn as nn from torchao.dtypes.nf4tensor import NF4Tensor -from torchtune.modules.transformer import TransformerDecoder - - -def setup_caches( - model: TransformerDecoder, - batch_size: int, - dtype: torch.dtype, - *, - encoder_max_seq_len: Optional[int] = None, - decoder_max_seq_len: Optional[int] = None, -): - """ - Setup static key-value caches for attention calculation for a given ``TransformerDecoder` model. - This function supports cache setup for both decoder, and encoder-decoder models. - - Concretely, all layers which are an instance of :class:`~torchtune.modules.TransformerSelfAttentionLayer` - will use ``decoder_max_seq_len``, and all layers which are an instance - of :class:`~torchtune.modules.TransformerCrossAttentionLayer` will use ``encoder_max_seq_len``. - :class:`~torchtune.modules.model_fusion.FusionLayer` will use both. - - Args: - model (TransformerDecoder): An instance of a ``TransformerDecoder`` model. - batch_size (int): batch size for the caches. - dtype (torch.dtype): dtype for the caches. - encoder_max_seq_len (Optional[int]): maximum cache sequence length for encoder layers. - Default None, in which case ``model.max_seq_len`` is used. - decoder_max_seq_len (Optional[int]): maximum cache sequence length for decoder layers. - Default None, in which case ``model.max_seq_len`` is used. - - """ - encoder_max_seq_len = ( - model.max_seq_len if encoder_max_seq_len is None else encoder_max_seq_len - ) - decoder_max_seq_len = ( - model.max_seq_len if decoder_max_seq_len is None else decoder_max_seq_len - ) - for layer in model.layers: - layer.setup_cache( - batch_size, - dtype, - encoder_max_seq_len=encoder_max_seq_len, - decoder_max_seq_len=decoder_max_seq_len, - ) def reparametrize_as_dtype_state_dict_post_hook( diff --git a/torchtune/modules/model_fusion/_fusion.py b/torchtune/modules/model_fusion/_fusion.py index a868c600db..97f8de9e71 100644 --- a/torchtune/modules/model_fusion/_fusion.py +++ b/torchtune/modules/model_fusion/_fusion.py @@ -100,8 +100,8 @@ def setup_cache( Args: batch_size (int): batch size for the caches. dtype (torch.dtype): dtype for the caches. - encoder_max_seq_len (int): maximum cache sequence length. - decoder_max_seq_len (int): this parameter is ignored in this layer. + encoder_max_seq_len (int): maximum cache sequence length for cross-attention layer. + decoder_max_seq_len (int): maximum cache sequence length for self-attention layer. 
""" self.layer.setup_cache( batch_size, From 807f8b2d51d7152050bc74fa41d00e796dd92f21 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Fri, 6 Sep 2024 17:06:27 +0100 Subject: [PATCH 15/22] adding support for encoder-decoder max seq len overrides --- .../modules/test_transformer_decoder.py | 113 ++++++------ torchtune/models/gemma/transformer.py | 18 +- torchtune/modules/model_fusion/_fusion.py | 7 +- torchtune/modules/transformer.py | 172 ++++++++++++------ torchtune/utils/_generation.py | 49 ++++- 5 files changed, 229 insertions(+), 130 deletions(-) diff --git a/tests/torchtune/modules/test_transformer_decoder.py b/tests/torchtune/modules/test_transformer_decoder.py index 53d58c49ba..1415859914 100644 --- a/tests/torchtune/modules/test_transformer_decoder.py +++ b/tests/torchtune/modules/test_transformer_decoder.py @@ -9,7 +9,6 @@ import pytest import torch - from tests.test_utils import assert_expected from torch import nn @@ -185,7 +184,7 @@ def transformer_layer( def test_forward( self, - input: [torch.Tensor, torch.Tensor, torch.Tensor], + input: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], transformer_layer: TransformerSelfAttentionLayer, ) -> None: input_x, input_y, mask = input @@ -219,6 +218,20 @@ def input(self, input_params: Tuple[int, int, int]) -> torch.Tensor: batch_size, seq_len, vocab_size = input_params return torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len)) + @pytest.fixture + def causal_mask(self, input_params: Tuple[int, int, int]) -> torch.Tensor: + batch_size, seq_len, _ = input_params + return ( + torch.tril(torch.ones((seq_len, seq_len))) + .unsqueeze(0) + .repeat(batch_size, 1, 1) + ) + + @pytest.fixture + def input_pos(self, input_params: Tuple[int, int, int]) -> torch.Tensor: + batch_size, seq_len, _ = input_params + return torch.arange(0, seq_len).unsqueeze(0).repeat(batch_size, 1) + @pytest.fixture def decoder_params(self) -> Tuple[int, int, int, int, int, int]: vocab_size = 256 @@ -278,7 +291,7 @@ def decoder( return decoder @pytest.fixture - def decoder_with_kv_cache( + def decoder_with_kv_cache_enabled( self, decoder_params: Tuple[int, int, int, int, int, int] ) -> TransformerDecoder: ( @@ -304,35 +317,6 @@ def decoder_with_kv_cache( decoder.setup_caches(batch_size=4, dtype=torch.float32) return decoder - @pytest.fixture - def decoder_with_kv_cache_fixed_length( - self, decoder_params: Tuple[int, int, int, int, int, int] - ) -> TransformerDecoder: - ( - vocab_size, - embed_dim, - num_layers, - num_heads, - max_seq_len, - num_kv_heads, - ) = decoder_params - decoder = llama2( - vocab_size=vocab_size, - num_layers=num_layers, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - embed_dim=embed_dim, - max_seq_len=max_seq_len, - ) - # TODO: fix weight initialization to use fixed_init_model - for p in decoder.parameters(): - nn.init.constant_(p, 0.2) - decoder.eval() - decoder.setup_caches( - batch_size=4, dtype=torch.float32, decoder_max_seq_len=max_seq_len + 512 - ) - return decoder - def test_forward( self, input: torch.Tensor, @@ -353,42 +337,34 @@ def test_max_seq_len_exceeded( with pytest.raises(Exception): output = decoder(input_max_len_exceeded) - @pytest.mark.parametrize( - "decoder_with_kv_cache_enabled", - ["decoder_with_kv_cache", "decoder_with_kv_cache_fixed_length"], - ) def test_kv_cache( self, - request, input: torch.Tensor, - decoder_with_kv_cache_enabled: str, + causal_mask: torch.Tensor, + input_pos: torch.Tensor, + decoder_with_kv_cache_enabled: TransformerDecoder, decoder: TransformerDecoder, ) -> None: _, seq_len = 
input.shape - decoder_with_kv_cache_enabled = request.getfixturevalue( - decoder_with_kv_cache_enabled - ) with torch.no_grad(): - output_cache = decoder_with_kv_cache_enabled(input) + output_cache = decoder_with_kv_cache_enabled( + input, mask=causal_mask, input_pos=input_pos + ) output_no_cache = decoder(input) assert_expected(output_cache.mean(), output_no_cache.mean()) - @pytest.mark.parametrize( - "decoder_with_kv_cache_enabled", - ["decoder_with_kv_cache", "decoder_with_kv_cache_fixed_length"], - ) def test_kv_cache_reset_values( self, - request, input: torch.Tensor, + causal_mask: torch.Tensor, + input_pos: torch.Tensor, decoder_with_kv_cache_enabled: TransformerDecoder, ) -> None: - decoder_with_kv_cache_enabled = request.getfixturevalue( - decoder_with_kv_cache_enabled - ) with torch.no_grad(): - _ = decoder_with_kv_cache_enabled(input) + _ = decoder_with_kv_cache_enabled( + input, mask=causal_mask, input_pos=input_pos + ) kv_cache_k_val = decoder_with_kv_cache_enabled.layers[ 0 ].attn.kv_cache.k_cache.clone() @@ -414,22 +390,37 @@ def test_kv_cache_reset_values_fails_when_not_enabled_first( with pytest.raises(RuntimeError, match="Key value caches are not setup"): decoder.reset_caches() - @pytest.mark.parametrize( - "decoder_with_kv_cache_enabled", - ["decoder_with_kv_cache", "decoder_with_kv_cache_fixed_length"], - ) + def test_kv_cache_setup_no_mask_in_forward( + self, + input: torch.Tensor, + decoder_with_kv_cache_enabled: TransformerDecoder, + ) -> None: + + with pytest.raises(ValueError, match="causal masks must be provided"): + decoder_with_kv_cache_enabled(input) + + def test_kv_cache_setup_mask_no_input_pos_in_forward( + self, + input: torch.Tensor, + causal_mask: torch.Tensor, + decoder_with_kv_cache_enabled: TransformerDecoder, + ) -> None: + + with pytest.raises(ValueError, match="input positions must be provided!"): + decoder_with_kv_cache_enabled(input, mask=causal_mask) + def test_kv_cache_batch_size_exceeded( self, - request, input_max_bs_exceeded: torch.Tensor, + causal_mask: torch.Tensor, + input_pos: torch.Tensor, decoder_with_kv_cache_enabled: TransformerDecoder, ) -> None: - decoder_with_kv_cache_enabled = request.getfixturevalue( - decoder_with_kv_cache_enabled - ) - with pytest.raises(RuntimeError, match="shape mismatch:"): - decoder_with_kv_cache_enabled(input_max_bs_exceeded) + with pytest.raises(RuntimeError, match="The size of tensor a"): + decoder_with_kv_cache_enabled( + input_max_bs_exceeded, mask=causal_mask, input_pos=input_pos + ) def test_rms_norm_propagation( self, decoder_params: Tuple[int, int, int, int, int, int] diff --git a/torchtune/models/gemma/transformer.py b/torchtune/models/gemma/transformer.py index 46bd777261..8ae3508e60 100644 --- a/torchtune/models/gemma/transformer.py +++ b/torchtune/models/gemma/transformer.py @@ -94,12 +94,10 @@ def setup_caches( encoder_max_seq_len (int): maximum encoder cache sequence length. decoder_max_seq_len (int): maximum decoder cache sequence length. 
""" - encoder_max_seq_len = ( - encoder_max_seq_len if encoder_max_seq_len is not None else self.max_seq_len - ) - decoder_max_seq_len = ( - decoder_max_seq_len if decoder_max_seq_len is not None else self.max_seq_len - ) + if encoder_max_seq_len is not None: + self.encoder_max_seq_len = encoder_max_seq_len + if decoder_max_seq_len is not None: + self.decoder_max_seq_len = decoder_max_seq_len for layer in self.layers: layer.setup_cache( batch_size, @@ -126,16 +124,16 @@ def forward( input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids of each token. During training, this is used to indicate the positions of each token relative to its sample when packed, shape [b x s]. - During inference, this indicates the position of the current token. - If none, assume the index of the token is its position id. Default is None. + During inference, this indicates the position of the current token and + is required. Note: At the very first step of inference, when the model is provided with a prompt, - ``input_pos`` would contain the positions of all of the tokens in the prompt + ``input_pos`` should contain the positions of all of the tokens in the prompt (eg: ``torch.arange(prompt_length)``). This is because we will need to compute the KV values for each position. Returns: - Tensor: output tensor with shape [b x s x v] + torch.Tensor: output tensor with shape [b x s x v] Raises: ValueError: if causal_mask is set but input_pos is None diff --git a/torchtune/modules/model_fusion/_fusion.py b/torchtune/modules/model_fusion/_fusion.py index 9aaded001a..95ca1896d5 100644 --- a/torchtune/modules/model_fusion/_fusion.py +++ b/torchtune/modules/model_fusion/_fusion.py @@ -3,12 +3,13 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations from typing import Dict, List, Optional, Union import torch +import torchtune from torch import nn -from torchtune.modules import TransformerDecoder class FusionLayer(nn.Module): @@ -313,13 +314,13 @@ class DeepFusionModel(nn.Module): >>> output = model(tokens, mask, encoder_input, encoder_mask, input_pos) Args: - decoder (TransformerDecoder): decoder module + decoder (torchtune.modules.TransformerDecoder): decoder module encoder (nn.Module): encoder module """ def __init__( self, - decoder: TransformerDecoder, + decoder: torchtune.modules.TransformerDecoder, encoder: nn.Module, ): super().__init__() diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py index a407673299..b9bf5699f1 100644 --- a/torchtune/modules/transformer.py +++ b/torchtune/modules/transformer.py @@ -11,6 +11,7 @@ from torch import nn from torchtune.modules import MultiHeadAttention +from torchtune.modules.model_fusion import FusionLayer class TransformerSelfAttentionLayer(nn.Module): @@ -360,6 +361,10 @@ def __init__( self.causal_mask = None self.num_output_chunks = 0 + # attributes for KV caches during inference + self.encoder_max_cache_seq_len = None + self.decoder_max_cache_seq_len = None + def set_num_output_chunks(self, num_output_chunks: int) -> None: """Used to save memory in combination with :class:`~torchtune.modules.loss.CEWithChunkedOutputLoss`. This should be called before the first forward pass, in the recipe.""" @@ -375,9 +380,9 @@ def setup_caches( ): """ Sets up key-value attention caches for inference. 
For each layer in ``self.layers``: - - :class:`torchtune.modules.TransformerSelfAttentionLayer` will use ``decoder_max_seq_len``. - - :class:`torchtune.modules.TransformerCrossAttentionLayer` will use ``encoder_max_seq_len``. - - :class:`torchtune.modules.fusion.FusionLayer` will use both ``decoder_max_seq_len`` and ``encoder_max_seq_len``. + - :class:`~torchtune.modules.TransformerSelfAttentionLayer` will use ``decoder_max_seq_len``. + - :class:`~torchtune.modules.TransformerCrossAttentionLayer` will use ``encoder_max_seq_len``. + - :class:`~torchtune.modules.fusion.FusionLayer` will use both ``decoder_max_seq_len`` and ``encoder_max_seq_len``. Args: batch_size (int): batch size for the caches. @@ -385,27 +390,49 @@ def setup_caches( encoder_max_seq_len (int): maximum encoder cache sequence length. decoder_max_seq_len (int): maximum decoder cache sequence length. """ - encoder_max_seq_len = ( - encoder_max_seq_len if encoder_max_seq_len is not None else self.max_seq_len + + has_encoder_layers = any( + isinstance(l, TransformerCrossAttentionLayer) or isinstance(l, FusionLayer) + for l in self.layers ) - decoder_max_seq_len = ( - decoder_max_seq_len if decoder_max_seq_len is not None else self.max_seq_len + has_decoder_layers = any( + isinstance(l, TransformerSelfAttentionLayer) for l in self.layers ) + if has_encoder_layers: + if encoder_max_seq_len is not None: + self.encoder_max_cache_seq_len = encoder_max_seq_len + else: + self.encoder_max_cache_seq_len = self.max_seq_len + + if has_decoder_layers: + if decoder_max_seq_len is not None: + self.decoder_max_cache_seq_len = decoder_max_seq_len + else: + self.decoder_max_cache_seq_len = self.max_seq_len + for layer in self.layers: layer.setup_cache( batch_size, dtype, - encoder_max_seq_len=encoder_max_seq_len, - decoder_max_seq_len=decoder_max_seq_len, + encoder_max_seq_len=self.encoder_max_cache_seq_len, + decoder_max_seq_len=self.decoder_max_cache_seq_len, ) - def caches_are_enabled(self) -> bool: + @property + def encoder_caches_are_enabled(self) -> bool: + """Checks if there are any :class:`~torchtune.modules.TransformerCrossAttentionLayer`, + or :class:`~torchtune.modules.fusion.FusionLayer` layers which have cache enabled. + """ + return self.encoder_max_cache_seq_len is not None + + @property + def decoder_caches_are_enabled(self) -> bool: """Check if the key value caches are setup.""" - return self.layers[0].cache_enabled + return self.decoder_max_cache_seq_len is not None def reset_caches(self): """Reset the key value caches.""" - if not self.caches_are_enabled(): + if not (self.encoder_caches_are_enabled or self.decoder_caches_are_enabled): raise RuntimeError( "Key value caches are not setup. Call ``setup_caches()`` first." ) @@ -413,8 +440,6 @@ def reset_caches(self): for layer in self.layers: layer.reset_cache() - self.pos = 0 - def forward( self, tokens: torch.Tensor, @@ -431,16 +456,20 @@ def forward( with shape [b x s x s]. This is applied after the query-key multiplication and before the softmax. A value of True in row i and column j means token i attends to token j. A value of False means token i does not attend to token j. If no - mask is specified, a causal mask is used by default. Default is None. + mask is specified, a causal mask is used by default. Default is None, + but this is required during inference if the model has been setup with any + self-attention layers. encoder_input (Optional[torch.Tensor]): Optional input embeds from the encoder. 
Shape [b x s_e x d_e] encoder_mask (Optional[torch.Tensor]): Boolean tensor defining a relational matrix between tokens and encoder embeddings. A True value at position i,j means token i can attend - to embedding j in the decoder. Mask has shape [b x s x s_e]. Default is None. + to embedding j in the decoder. Mask has shape [b x s x s_e]. Default is None, + but this is required during inference if the model has been setup with any layers + which use encoder embeddings. input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids of each token. During training, this is used to indicate the positions of each token relative to its sample when packed, shape [b x s]. - During inference, this indicates the position of the current token. - If none, assume the index of the token is its position id. Default is None. + During inference, this indicates the position of the current token, and + is a required parameter. Default is None. Note: At the very first step of inference, when the model is provided with a prompt, ``input_pos`` would contain the positions of all of the tokens in the prompt @@ -477,18 +506,24 @@ def forward( # shape: [b, s, d] h = self.tok_embeddings(tokens) - if self.causal_mask is not None: - if mask is not None: + if self.decoder_caches_are_enabled: + if mask is None: + raise ValueError( + "KV-caches for self-attention layers are setup for inference mode, causal masks must be provided!" + " Use the `mask` arg to provide a causal mask." + ) + if self.encoder_caches_are_enabled: + if encoder_mask is None: raise ValueError( - "An attention mask was set. Cannot use a non-causal mask for inference" + "KV-caches for cross-attention/fusion layers are setup for inference mode, causal masks must be provided!" + " Use the `encoder_mask` arg to provide a causal mask." ) - # Track the input position - if input_pos is None: - input_pos = torch.arange(self.pos, self.pos + seq_len, device=h.device) - self.pos = input_pos.max() + 1 - # shape: [1, input_pos_len, m_s] - # in most cases input_pos_len should be 1 - mask = self.causal_mask[None, input_pos] + if ( + self.encoder_caches_are_enabled or self.decoder_caches_are_enabled + ) and input_pos is None: + raise ValueError( + "KV-caches are setup for inference mode, input positions must be provided!" + ) hidden = [] for i, layer in enumerate(self.layers): @@ -590,6 +625,10 @@ def __init__( self.causal_mask = None self.num_output_chunks = 0 + # attributes for KV caches during inference + self.encoder_max_cache_seq_len = None + self.decoder_max_cache_seq_len = None + def set_num_output_chunks(self, num_output_chunks: int) -> None: """Used to save memory in combination with :class:`~torchtune.modules.loss.CEWithChunkedOutputLoss`. This should be called before the first forward pass, in the recipe.""" @@ -605,9 +644,9 @@ def setup_caches( ): """ Sets up key-value attention caches for inference. For each layer in ``self.layers``: - - :class:`torchtune.modules.TransformerSelfAttentionLayer` will use ``decoder_max_seq_len``. - - :class:`torchtune.modules.TransformerCrossAttentionLayer` will use ``encoder_max_seq_len``. - - :class:`torchtune.modules.fusion.FusionLayer` will use both ``decoder_max_seq_len`` and ``encoder_max_seq_len``. + - :class:`~torchtune.modules.TransformerSelfAttentionLayer` will use ``decoder_max_seq_len``. + - :class:`~torchtune.modules.TransformerCrossAttentionLayer` will use ``encoder_max_seq_len``. 
+ - :class:`~torchtune.modules.fusion.FusionLayer` will use both ``decoder_max_seq_len`` and ``encoder_max_seq_len``. Args: batch_size (int): batch size for the caches. @@ -615,27 +654,48 @@ def setup_caches( encoder_max_seq_len (int): maximum encoder cache sequence length. decoder_max_seq_len (int): maximum decoder cache sequence length. """ - encoder_max_seq_len = ( - encoder_max_seq_len if encoder_max_seq_len is not None else self.max_seq_len + has_encoder_layers = any( + isinstance(l, TransformerCrossAttentionLayer) or isinstance(l, FusionLayer) + for l in self.layers ) - decoder_max_seq_len = ( - decoder_max_seq_len if decoder_max_seq_len is not None else self.max_seq_len + has_decoder_layers = any( + isinstance(l, TransformerSelfAttentionLayer) for l in self.layers ) + if has_encoder_layers: + if encoder_max_seq_len is not None: + self.encoder_max_cache_seq_len = encoder_max_seq_len + else: + self.encoder_max_cache_seq_len = self.max_seq_len + + if has_decoder_layers: + if decoder_max_seq_len is not None: + self.decoder_max_cache_seq_len = decoder_max_seq_len + else: + self.decoder_max_cache_seq_len = self.decoder_max_cache_seq_len + for layer in self.layers: layer.setup_cache( batch_size, dtype, - encoder_max_seq_len=encoder_max_seq_len, - decoder_max_seq_len=decoder_max_seq_len, + self.encoder_max_cache_seq_len, + self.decoder_max_cache_seq_len, ) - def caches_are_enabled(self) -> bool: + @property + def encoder_caches_are_enabled(self) -> bool: + """Checks if there are any :class:`~torchtune.modules.TransformerCrossAttentionLayer`, + or :class:`~torchtune.modules.fusion.FusionLayer` layers which have cache enabled. + """ + return self.encoder_max_cache_seq_len is not None + + @property + def decoder_caches_are_enabled(self) -> bool: """Check if the key value caches are setup.""" - return self.layers[0].cache_enabled + return self.decoder_max_cache_seq_len is not None def reset_caches(self): """Reset the key value caches.""" - if not self.caches_are_enabled(): + if not (self.encoder_caches_are_enabled or self.decoder_caches_are_enabled): raise RuntimeError( "Key value caches are not setup. Call ``setup_caches()`` first." ) @@ -643,8 +703,6 @@ def reset_caches(self): for layer in self.layers: layer.reset_cache() - self.pos = 0 - def forward( self, tokens: torch.Tensor, @@ -653,7 +711,7 @@ def forward( encoder_input: Optional[torch.Tensor] = None, encoder_mask: Optional[torch.Tensor] = None, input_pos: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, List[torch.Tensor]]: """ Args: tokens (torch.Tensor): input tensor with shape [b x s] @@ -678,7 +736,7 @@ def forward( KV values for each position. Returns: - torch.Tensor: output tensor with shape [b x s x v] or a list of layer + Union[torch.Tensor, List[torch.Tensor]]: output tensor with shape [b x s x v] or a list of layer output tensors defined by ``output_hidden_states`` with the final output tensor appended to the list. @@ -707,18 +765,26 @@ def forward( # shape: [b, s, d] h = self.tok_embeddings(tokens) - if self.causal_mask is not None: + if self.decoder_caches_are_enabled: if mask is not None: raise ValueError( - "An attention mask was set. Cannot use a non-causal mask for inference" + "KV-caches for self-attention layers are setup for inference mode, causal masks must be provided!" + " Use the `mask` arg to provide a causal mask." 
) - # Track the input position - if input_pos is None: - input_pos = torch.arange(self.pos, self.pos + seq_len, device=h.device) - self.pos = input_pos.max() + 1 - # shape: [1, input_pos_len, m_s] - # in most cases input_pos_len should be 1 - mask = self.causal_mask[None, input_pos] + if self.encoder_caches_are_enabled: + if encoder_mask is not None: + raise ValueError( + "KV-caches for cross-attention/fusion layers are setup for inference mode, causal masks must be provided!" + " Use the `encoder_mask` arg to provide a causal mask." + ) + if ( + self.encoder_caches_are_enabled + or self.decoder_caches_are_enabled + and input_pos is None + ): + raise ValueError( + "KV-caches are setup for inference mode, input positions must be provided!" + ) hidden = [] for i, layer in enumerate(self.layers): diff --git a/torchtune/utils/_generation.py b/torchtune/utils/_generation.py index b60fe09099..2386d31895 100644 --- a/torchtune/utils/_generation.py +++ b/torchtune/utils/_generation.py @@ -40,11 +40,15 @@ def generate_next_token( x: torch.Tensor, temperature: float = 1.0, top_k: int = None, + *, + # this will all be better once https://github.com/pytorch/torchtune/pull/1424 lands. stopgap for now + mask: torch.Tensor = None, + encoder_mask: torch.Tensor = None, ) -> torch.Tensor: """Generates the next tokens.""" # model produces logits in [bsz, seq_length, vocab_size] # we want to take the last token's logits as the input to the next model call - logits = model(x, input_pos=input_pos)[:, -1] + logits = model(x, mask=mask, encoder_mask=encoder_mask, input_pos=input_pos)[:, -1] return sample(logits, temperature, top_k) @@ -114,6 +118,35 @@ def generate( (bsz, prompt_length + 1), dtype=torch.int32, device=prompt.device ) + # if key value caches are enabled, we can incrementally decode + incremental_decoding = ( + model.encoder_caches_are_enabled or model.decoder_caches_are_enabled + ) + + # setup encoder+/decoder masks for caches + mask, encoder_mask = None, None + curr_mask, curr_encoder_mask = None, None + if model.encoder_caches_are_enabled: + # generation for bsz=1 isn't supported right now anyway, broadcasting is fine here + encoder_mask = torch.tril( + torch.ones( + 1, + model.encoder_max_cache_seq_len, + model.encoder_max_cache_seq_len, + device=prompt.device, + ) + ) + if model.decoder_caches_are_enabled: + # generation for bsz=1 isn't supported right now anyway, broadcasting is fine here + mask = torch.tril( + torch.ones( + 1, + model.decoder_max_cache_seq_len, + model.decoder_max_cache_seq_len, + device=prompt.device, + ) + ) + if custom_generate_next_token is None: custom_generate_next_token = generate_next_token @@ -121,6 +154,10 @@ def generate( input_pos = torch.arange(0, model.max_seq_len, device=prompt.device) tokens = generate_next_token( model, + mask=mask[:, :prompt_length] if mask is not None else None, + encoder_mask=encoder_mask[:, :prompt_length] + if encoder_mask is not None + else None, input_pos=input_pos[:prompt_length], x=prompt, temperature=temperature, @@ -137,8 +174,6 @@ def generate( return generated_tokens.tolist() curr_pos = prompt_length - # if key value caches are enabled, we can incrementally decode - incremental_decoding = model.caches_are_enabled() for _ in range(max_generated_tokens - 1): # update stop_token_mask if we reached a stop token in a previous step # by appending the logical not of stop_token_reached to the end of the mask @@ -152,12 +187,20 @@ def generate( # otherwise, we take the whole sequence up to the current position if incremental_decoding: 
curr_input_pos = input_pos[curr_pos].unsqueeze(0) + curr_mask = mask[:, curr_pos, None, :] if mask is not None else None + curr_encoder_mask = ( + encoder_mask[:, curr_pos, None, :] + if curr_encoder_mask is not None + else None + ) else: curr_input_pos = input_pos[: curr_pos + 1] tokens = generated_tokens.clone() tokens = custom_generate_next_token( model, + mask=curr_mask, + encoder_mask=encoder_mask, input_pos=curr_input_pos, x=tokens, temperature=temperature, From 36ccc35fa59775b1d2ac5da3bb8c6ab15d6d8f28 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Fri, 6 Sep 2024 17:10:15 +0100 Subject: [PATCH 16/22] bug in generate --- torchtune/utils/_generation.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/torchtune/utils/_generation.py b/torchtune/utils/_generation.py index 2386d31895..30a4ea06db 100644 --- a/torchtune/utils/_generation.py +++ b/torchtune/utils/_generation.py @@ -155,9 +155,13 @@ def generate( tokens = generate_next_token( model, mask=mask[:, :prompt_length] if mask is not None else None, - encoder_mask=encoder_mask[:, :prompt_length] - if encoder_mask is not None - else None, + encoder_mask=( + encoder_mask[:, :prompt_length] + if encoder_mask is not None + else None + if encoder_mask is not None + else None + ), input_pos=input_pos[:prompt_length], x=prompt, temperature=temperature, From 232a91ccfba87fa411adddcf745ee186396485b9 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Fri, 6 Sep 2024 17:21:49 +0100 Subject: [PATCH 17/22] updating typign --- torchtune/models/gemma/transformer.py | 8 ++++---- torchtune/modules/transformer.py | 16 ++++++++-------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/torchtune/models/gemma/transformer.py b/torchtune/models/gemma/transformer.py index 8ae3508e60..a176f410b2 100644 --- a/torchtune/models/gemma/transformer.py +++ b/torchtune/models/gemma/transformer.py @@ -79,8 +79,8 @@ def setup_caches( batch_size: int, dtype: torch.dtype, *, - encoder_max_seq_len: int = None, - decoder_max_seq_len: int = None, + encoder_max_seq_len: Optional[int] = None, + decoder_max_seq_len: Optional[int] = None, ): """ Sets up key-value attention caches for inference. For each layer in ``self.layers``: @@ -91,8 +91,8 @@ def setup_caches( Args: batch_size (int): batch size for the caches. dtype (torch.dtype): dtype for the caches. - encoder_max_seq_len (int): maximum encoder cache sequence length. - decoder_max_seq_len (int): maximum decoder cache sequence length. + encoder_max_seq_len (Optional[int]): maximum encoder cache sequence length. + decoder_max_seq_len (Optional[int]): maximum decoder cache sequence length. """ if encoder_max_seq_len is not None: self.encoder_max_seq_len = encoder_max_seq_len diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py index b9bf5699f1..38664999b1 100644 --- a/torchtune/modules/transformer.py +++ b/torchtune/modules/transformer.py @@ -375,8 +375,8 @@ def setup_caches( batch_size: int, dtype: torch.dtype, *, - encoder_max_seq_len: int = None, - decoder_max_seq_len: int = None, + encoder_max_seq_len: Optional[int] = None, + decoder_max_seq_len: Optional[int] = None, ): """ Sets up key-value attention caches for inference. For each layer in ``self.layers``: @@ -387,8 +387,8 @@ def setup_caches( Args: batch_size (int): batch size for the caches. dtype (torch.dtype): dtype for the caches. - encoder_max_seq_len (int): maximum encoder cache sequence length. - decoder_max_seq_len (int): maximum decoder cache sequence length. 
+ encoder_max_seq_len (Optional[int]): maximum encoder cache sequence length. + decoder_max_seq_len (Optional[int]): maximum decoder cache sequence length. """ has_encoder_layers = any( @@ -639,8 +639,8 @@ def setup_caches( batch_size: int, dtype: torch.dtype, *, - encoder_max_seq_len: int = None, - decoder_max_seq_len: int = None, + encoder_max_seq_len: Optional[int] = None, + decoder_max_seq_len: Optional[int] = None, ): """ Sets up key-value attention caches for inference. For each layer in ``self.layers``: @@ -651,8 +651,8 @@ def setup_caches( Args: batch_size (int): batch size for the caches. dtype (torch.dtype): dtype for the caches. - encoder_max_seq_len (int): maximum encoder cache sequence length. - decoder_max_seq_len (int): maximum decoder cache sequence length. + encoder_max_seq_len (Optional[int]): maximum encoder cache sequence length. + decoder_max_seq_len (Optional[int]): maximum decoder cache sequence length. """ has_encoder_layers = any( isinstance(l, TransformerCrossAttentionLayer) or isinstance(l, FusionLayer) From f00a79581d68dcd957d06b84e698d12737cf9bef Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Mon, 16 Sep 2024 19:57:08 +0100 Subject: [PATCH 18/22] undoing changes --- recipes/eleuther_eval.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/recipes/eleuther_eval.py b/recipes/eleuther_eval.py index 60cb83ff57..98d4f98973 100644 --- a/recipes/eleuther_eval.py +++ b/recipes/eleuther_eval.py @@ -141,6 +141,7 @@ def _model_generate( # Technically this is not necessary, but it's a good way to ensure that # the caches won't error on a different batch size. In addition, caches # are not needed for a regular model call, so we just setup here + # TODO @joecummings this is being called multiple times resulting in many WARNINGs if self.enable_kv_cache: with context.device: self._model.setup_caches(batch_size=curr_batch_size, dtype=self._dtype) @@ -154,7 +155,7 @@ def _model_generate( "``do_sample`` for generation tasks is not supported yet in torchtune." 
) - toks = generation.generate( + toks, _ = generation.generate( self._model, context, max_generated_tokens=self.max_gen_toks, From 4677ff58365a6068fa0e53aa76f675cfe6f55dbb Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Mon, 16 Sep 2024 19:57:13 +0100 Subject: [PATCH 19/22] undoing changes --- tests/recipes/test_eleuther_eval.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/tests/recipes/test_eleuther_eval.py b/tests/recipes/test_eleuther_eval.py index c522805b38..1575ca04cc 100644 --- a/tests/recipes/test_eleuther_eval.py +++ b/tests/recipes/test_eleuther_eval.py @@ -19,12 +19,19 @@ class TestEleutherEval: + @pytest.mark.parametrize( + "eval_name, expected_acc, bsz", + [("truthfulqa_gen", 0.1, 1), ("truthfulqa_mc2", 0.3, 8)], + ) @pytest.mark.integration_test - def test_torchtune_checkpoint_eval_results(self, capsys, monkeypatch, tmpdir): + def test_torchtune_checkpoint_eval_results( + self, capsys, monkeypatch, tmpdir, eval_name, expected_acc, bsz + ): ckpt = "llama2_tune" ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) ckpt_dir = ckpt_path.parent + # TODO @joecummings bsz > 1 isn't supported for generation tasks, update test once integrated cmd = f""" tune run eleuther_eval \ --config eleuther_evaluation \ @@ -39,6 +46,8 @@ def test_torchtune_checkpoint_eval_results(self, capsys, monkeypatch, tmpdir): limit=10 \ dtype=fp32 \ device=cpu \ + tasks=[{eval_name}]\ + batch_size={bsz} \ """.split() model_config = llama2_test_config() @@ -66,7 +75,7 @@ def test_torchtune_checkpoint_eval_results(self, capsys, monkeypatch, tmpdir): ) assert search_results is not None acc_result = float(search_results.group(1)) - assert math.isclose(acc_result, 0.3, abs_tol=0.05) + assert math.isclose(acc_result, expected_acc, abs_tol=0.05) @pytest.fixture def hide_available_pkg(self, monkeypatch): From 33ad2f5394d5f6a4fd9bf92ec6c5d02c6aa35010 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Mon, 16 Sep 2024 20:05:55 +0100 Subject: [PATCH 20/22] nits --- torchtune/modules/transformer.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py index ff1f0bb3c8..fa70f4fd35 100644 --- a/torchtune/modules/transformer.py +++ b/torchtune/modules/transformer.py @@ -438,7 +438,7 @@ def decoder_caches_are_enabled(self) -> bool: def reset_caches(self): """Reset the key value caches.""" - if not (self.encoder_caches_are_enabled or self.decoder_caches_are_enabled): + if not (self.encoder_caches_are_enabled() or self.decoder_caches_are_enabled()): raise RuntimeError( "Key value caches are not setup. Call ``setup_caches()`` first." ) @@ -695,6 +695,25 @@ def __init__( self.encoder_max_cache_seq_len = None self.decoder_max_cache_seq_len = None + @torch.compiler.disable + def chunked_output(self, last_hidden_state: torch.Tensor) -> List[torch.Tensor]: + """ + Apply output projection in chunks. This should be applied in conjunction with + :class:`~torchtune.modules.loss.CEWithChunkedOutputLoss` as upcasting to fp32 is done there. + To use this method, you should first call + :func:`~torchtune.modules.TiedEmbeddingTransformerDecoder.set_num_output_chunks`. + Args: + last_hidden_state (torch.Tensor): last hidden state of the decoder, having shape + [b, seq_len, embed_dim]. + Returns: + List[torch.Tensor]: List of num_chunks output tensors, each with shape + [b, seq_len/num_chunks, out_dim], where out_dim is usually the vocab size. 
+ """ + return [ + F.linear(chunk, self.tok_embeddings.weight) + for chunk in last_hidden_state.chunk(self.num_output_chunks, dim=1) + ] + def set_num_output_chunks(self, num_output_chunks: int) -> None: """Used to save memory in combination with :class:`~torchtune.modules.loss.CEWithChunkedOutputLoss`. This should be called before the first forward pass, in the recipe.""" From 0dd43dc29fe233142f23bfb103ecc489aee84ae5 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Mon, 16 Sep 2024 20:14:33 +0100 Subject: [PATCH 21/22] NITS --- torchtune/modules/transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py index fa70f4fd35..6750045422 100644 --- a/torchtune/modules/transformer.py +++ b/torchtune/modules/transformer.py @@ -564,7 +564,7 @@ def forward( At the very first step of inference, when the model is provided with a prompt, ``input_pos`` should contain the positions of all of the tokens in the prompt. For a single-batch prompt, or a batch of prompts with identical lengths, this - will be``torch.arange(prompt_length)``. For a batch of varying-length prompts, + will be ``torch.arange(prompt_length)``. For a batch of varying-length prompts, shorter prompts are left-padded and position ids are correspondingly right-shifted, thus positional ids should be of shape ``[b, padded_prompt_length]``. This is because we will need to retrieve the positional embeddings for each input id. From 3fc1135f1df82cded7637cc3baefd82605e7222e Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Mon, 16 Sep 2024 21:40:49 +0100 Subject: [PATCH 22/22] fixing eval recipe --- recipes/eleuther_eval.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recipes/eleuther_eval.py b/recipes/eleuther_eval.py index 98d4f98973..2fe8db1a29 100644 --- a/recipes/eleuther_eval.py +++ b/recipes/eleuther_eval.py @@ -141,7 +141,6 @@ def _model_generate( # Technically this is not necessary, but it's a good way to ensure that # the caches won't error on a different batch size. In addition, caches # are not needed for a regular model call, so we just setup here - # TODO @joecummings this is being called multiple times resulting in many WARNINGs if self.enable_kv_cache: with context.device: self._model.setup_caches(batch_size=curr_batch_size, dtype=self._dtype) @@ -163,6 +162,7 @@ def _model_generate( top_k=None, # do_sample is not supported currently stop_tokens=self._tokenizer.stop_tokens, ) + self._model.reset_caches() return torch.tensor(toks, dtype=torch.int32)