From 3e83a537cc22df8d7ffe1a12743cd1f621d2180e Mon Sep 17 00:00:00 2001 From: Evan Smothers Date: Mon, 7 Oct 2024 17:50:44 -0700 Subject: [PATCH 1/2] [WIP] Fix save adapter weights only --- recipes/lora_finetune_single_device.py | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index 6cc57d7bcd..6641863e4d 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -577,20 +577,15 @@ def save_checkpoint(self, epoch: int) -> None: } ) + adapter_state_dict = {k: v.cpu() for k, v in self.adapter_params.items()} + ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict}) + if not self._save_adapter_weights_only: # Construct the full state dict with LoRA weights merged into base LLM weights # Move to CPU to avoid a copy on GPU state_dict = {k: v.cpu() for k, v in self._model.state_dict().items()} - # Construct the adapter weights - # Do this using the state_dict to avoid running upcast and H2D in state_dict post hook twice - # Must be before get_merged_lora_ckpt because get_merged_lora_ckpt will remove lora keys - adapter_key_filter = lambda x: x in self.adapter_params - adapter_state_dict = { - k: v for k, v in state_dict.items() if adapter_key_filter(k) - } - merged_state_dict = get_merged_lora_ckpt( state_dict, rank=self._lora_rank, @@ -598,13 +593,7 @@ def save_checkpoint(self, epoch: int) -> None: ) ckpt_dict.update({training.MODEL_KEY: merged_state_dict}) - else: - # No need to merge state dict if we're only saving adapter weights - adapter_state_dict = { - k: v.cpu() for k, v in get_adapter_params(self._model).items() - } - ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict}) adapter_config = { "r": self._lora_rank, "lora_alpha": self._lora_alpha, From 8357be370cd320f0b9e985a0c6f9d532df6c6d81 Mon Sep 17 00:00:00 2001 From: Evan Smothers Date: Mon, 7 Oct 2024 21:15:58 -0700 Subject: [PATCH 2/2] Update tests and DPO recipe --- recipes/lora_dpo_single_device.py | 17 ++--------------- tests/recipes/test_lora_dpo_single_device.py | 4 +++- tests/recipes/test_lora_finetune_distributed.py | 5 ++++- .../recipes/test_lora_finetune_single_device.py | 5 ++++- 4 files changed, 13 insertions(+), 18 deletions(-) diff --git a/recipes/lora_dpo_single_device.py b/recipes/lora_dpo_single_device.py index edd2d10427..b7d931accc 100644 --- a/recipes/lora_dpo_single_device.py +++ b/recipes/lora_dpo_single_device.py @@ -401,20 +401,14 @@ def save_checkpoint(self, epoch: int) -> None: } ) - adapter_key_filter = lambda x: x in self.adapter_params + adapter_state_dict = {k: v.cpu() for k, v in self.adapter_params.items()} + ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict}) if not self._save_adapter_weights_only: # Construct the full state dict with LoRA weights merged into base LLM weights # Move to CPU to avoid a copy on GPU state_dict = {k: v.cpu() for k, v in self._model.state_dict().items()} - # Construct the adapter weights - # Do this using the state_dict to avoid running upcast and H2D in state_dict post hook twice - # Must be before get_merged_lora_ckpt because get_merged_lora_ckpt will remove lora keys - adapter_state_dict = { - k: v for k, v in state_dict.items() if adapter_key_filter(k) - } - merged_state_dict = get_merged_lora_ckpt( state_dict, rank=self._lora_rank, @@ -422,13 +416,6 @@ def save_checkpoint(self, epoch: int) -> None: ) ckpt_dict.update({training.MODEL_KEY: merged_state_dict}) - else: - # No need to merge state dict if we're only saving adapter weights - adapter_state_dict = { - k: v.cpu() for k, v in get_adapter_params(self._model).items() - } - - ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict}) self._checkpointer.save_checkpoint( ckpt_dict, diff --git a/tests/recipes/test_lora_dpo_single_device.py b/tests/recipes/test_lora_dpo_single_device.py index 195d3181d0..d8cdca76c2 100644 --- a/tests/recipes/test_lora_dpo_single_device.py +++ b/tests/recipes/test_lora_dpo_single_device.py @@ -32,7 +32,6 @@ def _get_test_config_overrides(self, dtype_str: str = "fp32", epochs: int = 2): "batch_size=8", "device=cpu", f"dtype={dtype_str}", - "enable_activation_checkpointing=False", "dataset.train_on_input=False", "seed=9", f"epochs={epochs}", @@ -83,6 +82,7 @@ def test_training_state_on_resume( tokenizer.prompt_template=null \ save_adapter_weights_only={save_adapter_weights_only} \ metric_logger.filename={log_file} \ + enable_activation_checkpointing=True \ """.split() model_config = MODEL_TEST_CONFIGS["llama2_lora"] @@ -112,6 +112,7 @@ def test_training_state_on_resume( metric_logger.filename={resumed_log_file} \ tokenizer.path=/tmp/test-artifacts/tokenizer.model \ tokenizer.prompt_template=null \ + enable_activation_checkpointing=True \ """.split() cmd_2 = cmd_2 + self._get_test_config_overrides(epochs=3) + model_config monkeypatch.setattr(sys, "argv", cmd_2) @@ -142,6 +143,7 @@ def test_save_and_load_merged_weights(self, tmpdir, monkeypatch): checkpointer.model_type=LLAMA2 \ tokenizer.path=/tmp/test-artifacts/tokenizer.model \ tokenizer.prompt_template=null \ + enable_activation_checkpointing=False \ """.split() model_config = MODEL_TEST_CONFIGS["llama2_lora"] diff --git a/tests/recipes/test_lora_finetune_distributed.py b/tests/recipes/test_lora_finetune_distributed.py index 233309b4d4..7777b02862 100644 --- a/tests/recipes/test_lora_finetune_distributed.py +++ b/tests/recipes/test_lora_finetune_distributed.py @@ -33,7 +33,6 @@ class TestLoRAFinetuneDistributedRecipe: def _get_test_config_overrides(self): return [ "batch_size=4", - "enable_activation_checkpointing=False", "dataset.train_on_input=False", "seed=9", "epochs=2", @@ -81,6 +80,7 @@ def test_loss(self, reshard_after_forward, tmpdir, monkeypatch): tokenizer.path=/tmp/test-artifacts/tokenizer.model \ tokenizer.prompt_template=null \ reshard_after_forward={reshard_after_forward} \ + enable_activation_checkpointing=False \ """.split() model_config = MODEL_TEST_CONFIGS["llama2_lora"] @@ -147,6 +147,7 @@ def test_training_state_on_resume( tokenizer.path='{tokenizer_path}' \ tokenizer.prompt_template=null \ save_adapter_weights_only={save_adapter_weights_only} \ + enable_activation_checkpointing=True \ """.split() model_config = MODEL_TEST_CONFIGS[model_type + "_lora"] @@ -171,6 +172,7 @@ def test_training_state_on_resume( tokenizer.prompt_template=null \ resume_from_checkpoint=True \ metric_logger.filename={log_file} \ + enable_activation_checkpointing=True \ """.split() cmd_2 = cmd_2 + self._get_test_config_overrides() + model_config @@ -213,6 +215,7 @@ def test_save_and_load_merged_weights( checkpointer.model_type={model_type.upper()} \ tokenizer.path='{tokenizer_path}' \ tokenizer.prompt_template=null \ + enable_activation_checkpointing=True \ """.split() model_config = MODEL_TEST_CONFIGS[model_type + "_lora"] diff --git a/tests/recipes/test_lora_finetune_single_device.py b/tests/recipes/test_lora_finetune_single_device.py index 4499e6614f..f2d7409042 100644 --- a/tests/recipes/test_lora_finetune_single_device.py +++ b/tests/recipes/test_lora_finetune_single_device.py @@ -35,7 +35,6 @@ def _get_test_config_overrides(self, dtype_str: str = "fp32", epochs: int = 2): "batch_size=8", "device=cpu", f"dtype={dtype_str}", - "enable_activation_checkpointing=False", "dataset.train_on_input=False", "seed=9", f"epochs={epochs}", @@ -133,6 +132,7 @@ def test_loss_qlora(self, compile, dtype, tmpdir, monkeypatch): tokenizer.path=/tmp/test-artifacts/tokenizer.model \ tokenizer.prompt_template=null \ compile={compile} \ + enable_activation_checkpointing=False \ """.split() model_config = MODEL_TEST_CONFIGS["llama2_qlora"] @@ -188,6 +188,7 @@ def test_training_state_on_resume( tokenizer.path=/tmp/test-artifacts/tokenizer.model \ tokenizer.prompt_template=null \ save_adapter_weights_only={save_adapter_weights_only} \ + enable_activation_checkpointing=True \ """.split() model_config = MODEL_TEST_CONFIGS["llama2_lora"] @@ -213,6 +214,7 @@ def test_training_state_on_resume( metric_logger.filename={log_file} \ tokenizer.path=/tmp/test-artifacts/tokenizer.model \ tokenizer.prompt_template=null \ + enable_activation_checkpointing=True \ """.split() cmd_2 = cmd_2 + self._get_test_config_overrides(epochs=3) + model_config monkeypatch.setattr(sys, "argv", cmd_2) @@ -244,6 +246,7 @@ def test_save_and_load_merged_weights(self, tmpdir, monkeypatch): checkpointer.model_type=LLAMA2 \ tokenizer.path=/tmp/test-artifacts/tokenizer.model \ tokenizer.prompt_template=null \ + enable_activation_checkpointing=True \ """.split() model_config = MODEL_TEST_CONFIGS["llama2_lora"]