Fix save adapter weights only (pytorch#1764)
ebsmothers authored and mori360 committed Oct 14, 2024
1 parent 7f175a0 commit 951294f
Showing 5 changed files with 16 additions and 32 deletions.
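Summary of the change: both LoRA recipes now build the adapter state dict directly from `self.adapter_params` and store it under `training.ADAPTER_KEY` before the `save_adapter_weights_only` branch, so adapter weights are saved the same way on both paths; the merged full-model state dict is only constructed when a full checkpoint is requested. Below is a minimal sketch of that control flow; names such as `save_lora_checkpoint`, `ADAPTER_KEY`, and `merge_fn` are illustrative stand-ins, not torchtune APIs.

# Minimal sketch of the checkpoint-saving flow after this commit.
# `adapter_params` stands in for self.adapter_params, `model_state` for
# self._model.state_dict(), and `merge_fn` for get_merged_lora_ckpt.
from typing import Callable, Dict

import torch

ADAPTER_KEY = "adapter"  # stand-in for training.ADAPTER_KEY
MODEL_KEY = "model"      # stand-in for training.MODEL_KEY


def save_lora_checkpoint(
    adapter_params: Dict[str, torch.Tensor],
    model_state: Dict[str, torch.Tensor],
    merge_fn: Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]],
    save_adapter_weights_only: bool,
) -> Dict[str, Dict[str, torch.Tensor]]:
    ckpt_dict: Dict[str, Dict[str, torch.Tensor]] = {}

    # Adapter weights are always saved, taken directly from the tracked
    # adapter params and moved to CPU.
    ckpt_dict[ADAPTER_KEY] = {k: v.cpu() for k, v in adapter_params.items()}

    if not save_adapter_weights_only:
        # Only build the (potentially large) merged state dict when a full
        # checkpoint was requested.
        state_dict = {k: v.cpu() for k, v in model_state.items()}
        ckpt_dict[MODEL_KEY] = merge_fn(state_dict)

    return ckpt_dict


if __name__ == "__main__":
    adapter = {"layer.lora_a.weight": torch.randn(4, 8)}
    full = {"layer.weight": torch.randn(8, 8), **adapter}
    ckpt = save_lora_checkpoint(
        adapter, full, merge_fn=lambda sd: sd, save_adapter_weights_only=True
    )
    assert ADAPTER_KEY in ckpt and MODEL_KEY not in ckpt

With `save_adapter_weights_only=True` the merge is skipped entirely, and the saved adapter weights no longer depend on which branch was taken.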
17 changes: 2 additions & 15 deletions recipes/lora_dpo_single_device.py
@@ -401,34 +401,21 @@ def save_checkpoint(self, epoch: int) -> None:
                 }
             )
 
-        adapter_key_filter = lambda x: x in self.adapter_params
+        adapter_state_dict = {k: v.cpu() for k, v in self.adapter_params.items()}
+        ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})
         if not self._save_adapter_weights_only:
             # Construct the full state dict with LoRA weights merged into base LLM weights
 
             # Move to CPU to avoid a copy on GPU
             state_dict = {k: v.cpu() for k, v in self._model.state_dict().items()}
 
-            # Construct the adapter weights
-            # Do this using the state_dict to avoid running upcast and H2D in state_dict post hook twice
-            # Must be before get_merged_lora_ckpt because get_merged_lora_ckpt will remove lora keys
-            adapter_state_dict = {
-                k: v for k, v in state_dict.items() if adapter_key_filter(k)
-            }
-
             merged_state_dict = get_merged_lora_ckpt(
                 state_dict,
                 rank=self._lora_rank,
                 alpha=self._lora_alpha,
             )
 
             ckpt_dict.update({training.MODEL_KEY: merged_state_dict})
-        else:
-            # No need to merge state dict if we're only saving adapter weights
-            adapter_state_dict = {
-                k: v.cpu() for k, v in get_adapter_params(self._model).items()
-            }
-
-        ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})
 
         self._checkpointer.save_checkpoint(
             ckpt_dict,
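For reference, the `get_merged_lora_ckpt` call above folds each LoRA adapter into its base weight (and, per the removed comment, strips the LoRA keys from the state dict). The underlying operation is the standard LoRA merge W' = W + (alpha / r) * B A; a toy, self-contained illustration of that formula, not torchtune's implementation:

# Toy illustration of merging a single LoRA adapter into its base weight.
# Not torchtune's get_merged_lora_ckpt; just the underlying formula
# W_merged = W + (alpha / rank) * B @ A.
import torch


def merge_single_lora(
    base_weight: torch.Tensor,  # shape (out_features, in_features)
    lora_a: torch.Tensor,       # shape (rank, in_features)
    lora_b: torch.Tensor,       # shape (out_features, rank)
    rank: int,
    alpha: float,
) -> torch.Tensor:
    return base_weight + (alpha / rank) * (lora_b @ lora_a)


if __name__ == "__main__":
    out_f, in_f, r = 16, 32, 4
    w = torch.randn(out_f, in_f)
    a = torch.randn(r, in_f)
    b = torch.zeros(out_f, r)  # LoRA B is typically zero-initialized
    merged = merge_single_lora(w, a, b, rank=r, alpha=8.0)
    assert torch.allclose(merged, w)  # zero B => merging is a no-op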
17 changes: 3 additions & 14 deletions recipes/lora_finetune_single_device.py
@@ -577,34 +577,23 @@ def save_checkpoint(self, epoch: int) -> None:
                 }
             )
 
+        adapter_state_dict = {k: v.cpu() for k, v in self.adapter_params.items()}
+        ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})
+
         if not self._save_adapter_weights_only:
             # Construct the full state dict with LoRA weights merged into base LLM weights
 
             # Move to CPU to avoid a copy on GPU
             state_dict = {k: v.cpu() for k, v in self._model.state_dict().items()}
 
-            # Construct the adapter weights
-            # Do this using the state_dict to avoid running upcast and H2D in state_dict post hook twice
-            # Must be before get_merged_lora_ckpt because get_merged_lora_ckpt will remove lora keys
-            adapter_key_filter = lambda x: x in self.adapter_params
-            adapter_state_dict = {
-                k: v for k, v in state_dict.items() if adapter_key_filter(k)
-            }
-
             merged_state_dict = get_merged_lora_ckpt(
                 state_dict,
                 rank=self._lora_rank,
                 alpha=self._lora_alpha,
             )
 
             ckpt_dict.update({training.MODEL_KEY: merged_state_dict})
-        else:
-            # No need to merge state dict if we're only saving adapter weights
-            adapter_state_dict = {
-                k: v.cpu() for k, v in get_adapter_params(self._model).items()
-            }
 
-        ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})
         adapter_config = {
             "r": self._lora_rank,
             "lora_alpha": self._lora_alpha,
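The unchanged `adapter_config` lines above use PEFT-style field names (`r`, `lora_alpha`). Assuming the intent is compatibility with the Hugging Face PEFT format (not shown in this diff), a config of this shape is conventionally serialized as `adapter_config.json`; a hedged sketch with hypothetical values:

# Hedged sketch: serializing a PEFT-style LoRA adapter config.
# The exact keys and values torchtune writes are not visible in this diff;
# the fields below are the common PEFT LoRA ones, with made-up values.
import json

adapter_config = {
    "r": 8,
    "lora_alpha": 16,
    "target_modules": ["q_proj", "v_proj"],
    "peft_type": "LORA",
}

print(json.dumps(adapter_config, indent=2))  # would typically go to adapter_config.json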
4 changes: 3 additions & 1 deletion tests/recipes/test_lora_dpo_single_device.py
@@ -32,7 +32,6 @@ def _get_test_config_overrides(self, dtype_str: str = "fp32", epochs: int = 2):
             "batch_size=8",
             "device=cpu",
             f"dtype={dtype_str}",
-            "enable_activation_checkpointing=False",
             "dataset.train_on_input=False",
             "seed=9",
             f"epochs={epochs}",
@@ -83,6 +82,7 @@ def test_training_state_on_resume(
             tokenizer.prompt_template=null \
             save_adapter_weights_only={save_adapter_weights_only} \
             metric_logger.filename={log_file} \
+            enable_activation_checkpointing=True \
         """.split()
 
         model_config = MODEL_TEST_CONFIGS["llama2_lora"]
@@ -112,6 +112,7 @@ def test_training_state_on_resume(
             metric_logger.filename={resumed_log_file} \
             tokenizer.path=/tmp/test-artifacts/tokenizer.model \
             tokenizer.prompt_template=null \
+            enable_activation_checkpointing=True \
         """.split()
         cmd_2 = cmd_2 + self._get_test_config_overrides(epochs=3) + model_config
         monkeypatch.setattr(sys, "argv", cmd_2)
@@ -142,6 +143,7 @@ def test_save_and_load_merged_weights(self, tmpdir, monkeypatch):
             checkpointer.model_type=LLAMA2 \
             tokenizer.path=/tmp/test-artifacts/tokenizer.model \
             tokenizer.prompt_template=null \
+            enable_activation_checkpointing=False \
         """.split()
 
         model_config = MODEL_TEST_CONFIGS["llama2_lora"]
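The test-side changes all follow one pattern: `enable_activation_checkpointing` is removed from the shared `_get_test_config_overrides` list and passed explicitly on each test's command line, so individual tests choose their own setting. A simplified sketch of that pattern; the helper and recipe names mirror the test files but are abbreviated, not the actual test code:

# Simplified sketch of the per-test override pattern used in these tests:
# shared defaults come from a helper list, while flags such as
# enable_activation_checkpointing are appended per test command.
def get_shared_overrides(epochs: int = 2) -> list:
    # enable_activation_checkpointing is intentionally no longer set here.
    return ["batch_size=8", "device=cpu", "seed=9", f"epochs={epochs}"]


def build_cmd(log_file: str, ac: bool) -> list:
    cmd = f"""
    tune run lora_dpo_single_device \
        metric_logger.filename={log_file} \
        enable_activation_checkpointing={ac} \
    """.split()
    return cmd + get_shared_overrides()


if __name__ == "__main__":
    # In the real tests the resulting list is injected with
    # monkeypatch.setattr(sys, "argv", cmd) before running the recipe.
    print(build_cmd("/tmp/log.txt", ac=True))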
5 changes: 4 additions & 1 deletion tests/recipes/test_lora_finetune_distributed.py
@@ -33,7 +33,6 @@ class TestLoRAFinetuneDistributedRecipe:
     def _get_test_config_overrides(self):
         return [
             "batch_size=4",
-            "enable_activation_checkpointing=False",
             "dataset.train_on_input=False",
             "seed=9",
             "epochs=2",
@@ -81,6 +80,7 @@ def test_loss(self, reshard_after_forward, tmpdir, monkeypatch):
             tokenizer.path=/tmp/test-artifacts/tokenizer.model \
             tokenizer.prompt_template=null \
             reshard_after_forward={reshard_after_forward} \
+            enable_activation_checkpointing=False \
         """.split()
 
         model_config = MODEL_TEST_CONFIGS["llama2_lora"]
@@ -147,6 +147,7 @@ def test_training_state_on_resume(
             tokenizer.path='{tokenizer_path}' \
             tokenizer.prompt_template=null \
             save_adapter_weights_only={save_adapter_weights_only} \
+            enable_activation_checkpointing=True \
         """.split()
 
         model_config = MODEL_TEST_CONFIGS[model_type + "_lora"]
@@ -171,6 +172,7 @@ def test_training_state_on_resume(
             tokenizer.prompt_template=null \
             resume_from_checkpoint=True \
             metric_logger.filename={log_file} \
+            enable_activation_checkpointing=True \
         """.split()
 
         cmd_2 = cmd_2 + self._get_test_config_overrides() + model_config
@@ -213,6 +215,7 @@ def test_save_and_load_merged_weights(
             checkpointer.model_type={model_type.upper()} \
             tokenizer.path='{tokenizer_path}' \
             tokenizer.prompt_template=null \
+            enable_activation_checkpointing=True \
         """.split()
 
         model_config = MODEL_TEST_CONFIGS[model_type + "_lora"]
5 changes: 4 additions & 1 deletion tests/recipes/test_lora_finetune_single_device.py
@@ -35,7 +35,6 @@ def _get_test_config_overrides(self, dtype_str: str = "fp32", epochs: int = 2):
             "batch_size=8",
             "device=cpu",
             f"dtype={dtype_str}",
-            "enable_activation_checkpointing=False",
             "dataset.train_on_input=False",
             "seed=9",
             f"epochs={epochs}",
@@ -133,6 +132,7 @@ def test_loss_qlora(self, compile, dtype, tmpdir, monkeypatch):
             tokenizer.path=/tmp/test-artifacts/tokenizer.model \
             tokenizer.prompt_template=null \
             compile={compile} \
+            enable_activation_checkpointing=False \
         """.split()
 
         model_config = MODEL_TEST_CONFIGS["llama2_qlora"]
@@ -188,6 +188,7 @@ def test_training_state_on_resume(
             tokenizer.path=/tmp/test-artifacts/tokenizer.model \
             tokenizer.prompt_template=null \
             save_adapter_weights_only={save_adapter_weights_only} \
+            enable_activation_checkpointing=True \
         """.split()
 
         model_config = MODEL_TEST_CONFIGS["llama2_lora"]
@@ -213,6 +214,7 @@ def test_training_state_on_resume(
             metric_logger.filename={log_file} \
             tokenizer.path=/tmp/test-artifacts/tokenizer.model \
             tokenizer.prompt_template=null \
+            enable_activation_checkpointing=True \
         """.split()
         cmd_2 = cmd_2 + self._get_test_config_overrides(epochs=3) + model_config
         monkeypatch.setattr(sys, "argv", cmd_2)
@@ -244,6 +246,7 @@ def test_save_and_load_merged_weights(self, tmpdir, monkeypatch):
             checkpointer.model_type=LLAMA2 \
             tokenizer.path=/tmp/test-artifacts/tokenizer.model \
             tokenizer.prompt_template=null \
+            enable_activation_checkpointing=True \
         """.split()
 
         model_config = MODEL_TEST_CONFIGS["llama2_lora"]
