
Prevent OOM during checkpoint save on colab for llama3-8b qlora recipe #1315

Merged: 36 commits, Sep 10, 2024

Conversation

@mikaylagawarecki (Contributor) commented Aug 12, 2024

Context

This PR prevents OOM during checkpoint save on Colab for the following recipe:

tune run recipes/lora_finetune_single_device.py --config recipes/configs/llama3/8B_qlora_single_device.yaml

I believe it is possible to do better (perf improvements for this Colab case) if we refactor the checkpointing logic, but that would be a more invasive refactor. For now, this is the most minimally invasive change that unblocks this use case.

Changelog

Old changelog (keeping around for posterity)

  • Set a seed in 8B_qlora_single_device.yaml to make dataloader samples (and hence weights) deterministic
  • [for testing velocity purposes] Changed the number of epochs to 8 and reduced max_steps_per_epoch to 20 and gradient_accumulation_steps to 2 so that save_checkpoint is called sooner
  • [for loss curves] Made some changes to lora_finetune_single_device.py to mimic a user doing resume_from_checkpoint after each epoch (while preserving the logger)
  • [Not to be landed in torchtune, for PoC purposes] Patched FakeTensor.__reduce_ex__, which is needed so that the write_record_metadata utility for creating empty checkpoints is called (requires the changes in pytorch#133272, "Prototype changes to create fake checkpoints with empty storages"). I need to figure out how to land this piece :)
  • Skipped registration of reparametrize_as_dtype_state_dict_post_hook for llama3
    - Showed an example of how to do the corresponding reparametrization (using mmap to prevent OOM) in lora_finetune_single_device.py:save_checkpoint

Test plan

Sanity check

Ran the tune run command on devgpu (with only the changes in 8B_qlora_single_device.yaml to set the seed and decrease steps per epoch) and verified with a small snippet that the generated meta_model_0.pt is the same before and after the changes in this PR.
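For reference, a verification snippet along these lines could be used (a minimal sketch; the paths are illustrative, not the exact script that was run):

```python
import torch

# Hypothetical paths to the checkpoint produced before vs. after this PR's changes.
sd_before = torch.load("before/meta_model_0.pt", map_location="cpu")
sd_after = torch.load("after/meta_model_0.pt", map_location="cpu")

assert sd_before.keys() == sd_after.keys()
for k in sd_before:
    # With the seed fixed, the runs are deterministic, so we expect bitwise equality.
    assert torch.equal(sd_before[k], sd_after[k]), f"mismatch in {k}"
print("checkpoints match")
```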

         Before ckpt save time (s)   After ckpt save time (s)
devgpu   178                         ~200 (ranged from 190 to ~220)
colab    OOM (no baseline)           510 (ranged from 421 to 610)

Verified that colab does not OOM
https://colab.research.google.com/drive/1y7Az78ATauK7gkewZkcMO3cNgVWm1233?usp=sharing

Loss Curves

The validation was run on commit abdbd7, which has special logic to mimic resume_from_checkpoint for each epoch.

Config is per the changes in recipes/configs/llama3/8B_qlora_single_device.yaml (8 epochs with 20 steps per epoch and gradient accumulation every 2 steps) on 6cf31b6. For the checkpoint-reloading case, I modified lora_finetune_single_device.py:recipe_main to mimic resume_from_checkpoint after each epoch.

Devgpu: (screenshot of loss curves, 2024-09-05)

Colab: (screenshot of loss curves, 2024-09-05)

pytorch-bot bot commented Aug 12, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1315

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit dc8c6b0 with merge base 66590b4:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the "CLA Signed" label on Aug 12, 2024
@ebsmothers (Contributor) left a comment:

This is awesome! Left a few basic questions, but overall really excited to see that we'll be able to provide proper end-to-end Colab support with these changes. One high-level comment is that we should think about writing a utility for the portion in ~L520-L550. Then we can gate behind a config like low_memory_save or something like that. But that's more of a UX thing, happy to help out there if needed.


# Construct the full state dict with LoRA weights merged into base LLM weights
merged_state_dict = get_merged_lora_ckpt(
    state_dict,
    rank=self._lora_rank,
    alpha=self._lora_alpha,
    dest_state_dict=dest_state_dict,
)
Contributor:

(another noob q:) so after this, we can save the checkpoint normally (e.g. in the call to save_checkpoint on L580), even though we are now saving something like {"model": dest_state_dict, ... (other stuff)} to a separate file from the mmapped one?

@mikaylagawarecki (Contributor, Author) replied Aug 22, 2024:

We should think of dest_state_dict as being backed by disk (and yes we are writing it once to disk in a torch.save format in this whole bit)

The final save_checkpoint on L580 re-saves a new checkpoint with {"model": merged_state_dict, ... (other stuff)} to a new checkpoint file. So we are saving dest_state_dict twice.

We could potentially do something smarter to avoid this "re-save", but given that the checkpointers seem to do some remapping, the code changes would be more invasive than they are now, so I refrained from doing that for v0. wdyt?
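To make the "saving twice" point concrete, here is a toy sketch (stand-in tensors, shapes, and paths; not the recipe's actual code):

```python
import torch

# Toy stand-ins: in the recipe, dest_state_dict is the mmap-backed (disk-backed)
# dict and state_dict holds the NF4 model weights plus LoRA params.
state_dict = {"w": torch.randn(4, 4)}
dest_state_dict = {"w": torch.empty(4, 4)}

# In the recipe this copy lands in mmap'd, disk-backed storage (the first write
# to disk); here it is just an in-RAM stand-in for the merge step.
dest_state_dict["w"].copy_(state_dict["w"])

# The checkpointer then re-serializes the merged weights into the final
# checkpoint file, so the same bytes are written to disk a second time.
torch.save({"model": dest_state_dict}, "/tmp/meta_model_0.pt")
```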

Comment on lines +511 to +539
# Do this using the state_dict to avoid running upcast and H2D in state_dict post hook twice
# Must be before get_merged_lora_ckpt because get_merged_lora_ckpt will remove lora keys
adapter_key_filter = lambda x: x in self.adapter_params
adapter_state_dict = {
    k: v for k, v in state_dict.items() if adapter_key_filter(k)
}
@mikaylagawarecki (Contributor, Author) commented Aug 29, 2024:

Ideally we want to only run state_dict post hooks once, so we should reuse the state_dict

This does change the semantics slightly, though: before, adapter_*.pt contained weights tagged with CUDA, but now it contains weights tagged with CPU.

I'm not sure whether the old behavior was intended or whether this change is OK (but no CI seems to fail :D). Also, when loading, we use map_location="cpu" regardless.
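For reference, a minimal sketch of the load path mentioned above (the adapter file name is illustrative):

```python
import torch

# On resume, adapter weights are loaded onto CPU regardless of the device tag
# they were saved with, so CPU-tagged tensors in adapter_*.pt change nothing here.
adapter_state_dict = torch.load("adapter_0.pt", map_location="cpu")
```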

Contributor:

Sorry I missed this comment before. But yeah I think this change makes sense, I don't think there's any reason we need to require CUDA weights. And as you point out since we load on CPU when resuming it was probably never really an issue. Plus not re-running the state dict post hooks is a nice bonus.

@mikaylagawarecki marked this pull request as ready for review August 29, 2024 17:11
@mikaylagawarecki changed the title from "[PoC] Prevent OOM during checkpoint save on colab" to "Prevent OOM during checkpoint save on colab for llama38b qlora recipe" on Aug 29, 2024
@mikaylagawarecki changed the title from "Prevent OOM during checkpoint save on colab for llama38b qlora recipe" to "Prevent OOM during checkpoint save on colab for llama3-8b qlora recipe" on Aug 29, 2024
@@ -23,6 +23,7 @@ model:
   apply_lora_to_output: False
   lora_rank: 8
   lora_alpha: 16
+  low_cpu_ram: False
Contributor:

Given that we have to modify the state dict hook, I see why it makes sense to put this in the model config. But it does feel a bit weird to me, since it's not really a property of the model (it's more that the model is a convenient place for us to know that we're going to have to upcast NF4 tensors).

I wonder if we can instead define a standalone config, parse it in the recipe with e.g. low_cpu_ram = cfg.get("low_cpu_ram", False), then use that to overwrite the reparametrize_as_dtype_state_dict_post_hook. Maybe a bit hacky, but we can at least assert that the expected state dict hook is there before replacing it to ensure that we aren't adding this onto any old non-QLoRA model. Is that obviously worse? (Mainly I want to avoid our model classes having to know or care about low-level details of how they're gonna be checkpointed)
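A rough sketch of that alternative (hypothetical recipe-side code; swap_reparametrize_hook_for_low_ram is an illustrative helper name, and the PR ultimately patches the hook via a flag in common_utils instead):

```python
from omegaconf import DictConfig
from torchtune import config


def swap_reparametrize_hook_for_low_ram(model) -> None:
    """Hypothetical helper: assert that the expected QLoRA reparametrize state dict
    post hook is registered, then replace it with a low-CPU-RAM variant."""
    ...


def setup_model(cfg: DictConfig):
    # Standalone flag, parsed in the recipe instead of living in the model config.
    low_cpu_ram = cfg.get("low_cpu_ram", False)
    model = config.instantiate(cfg.model)
    if low_cpu_ram:
        swap_reparametrize_hook_for_low_ram(model)
    return model
```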

Contributor (Author):

Hmm, I think that to manually remove and re-register the hook we would need the handle returned by _register_state_dict_hook.

Is the way I updated this to patch the hook OK with you?

Comment on lines 261 to 266
if sys.platform == "win32":
    raise RuntimeError(
        "low_cpu_ram=True not supported on Windows."
    )
else:
    raise RuntimeError(
        "low_cpu_ram=True requires torch.__version__ >= 2.5.0.dev20240830."
    )
Contributor:

Similar comment here: ideally we can do these checks in the recipe (or in a utility) rather than in the builder of the model



# mmap.MAP_SHARED is not supported on Windows but this change targets colab.
if hasattr(torch.serialization, "skip_data") and not sys.platform == "win32":
Contributor:

Is the first half of this check just a proxy for a particular torch version? If so maybe better to just directly gate on that (with torch_version_ge or something)

@mikaylagawarecki (Contributor, Author) replied Aug 30, 2024:

torch_version_ge seems to cause a circular import when imported in this file :/ so I'm just using torch.__version__ directly.

Contributor:

Oh yeah, we are working to fix that; using torch.__version__ is good too.

@ebsmothers (Contributor) left a comment:

A few more small comments and questions, but otherwise this looks good to go!

Comment on lines 724 to 725
if cfg.get("low_cpu_ram", False):
    common_utils._use_low_cpu_ram = True
Contributor:

Does this have to be in recipe_main? Can we instead do it somewhere inside the recipe class (before the model gets instantiated)? I would also add a one-line comment explaining this.
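For illustration, moving the flag into the recipe class might look roughly like this (a sketch; only the flag handling is shown, and its placement inside setup is an assumption):

```python
from omegaconf import DictConfig
from torchtune.modules import common_utils


class LoRAFinetuneRecipeSingleDevice:
    """Sketch only; everything except the flag handling is elided."""

    def setup(self, cfg: DictConfig) -> None:
        # Set the flag before the model (and its state dict hooks) is built,
        # so that hook registration can pick up the low-CPU-RAM path.
        if cfg.get("low_cpu_ram", False):
            common_utils._use_low_cpu_ram = True
        # ... model / optimizer / dataloader setup continues as in the recipe ...
```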



# mmap.MAP_SHARED is not supported on Windows but this change targets colab.
if torch.__version__ >= "2.5.0.dev20240906" and not sys.platform == "win32":
Contributor:

Just a sanity check here: is this if/else just to ensure that no one tries to directly import the _low_ram_reparametrize_as_dtype_state_dict_post_hook API on an unsupported environment? Mainly asking because we have the equivalent checks in _register_reparametrize_state_dict_hooks now

Contributor (Author):

You're right, removing this if else

-    model._register_state_dict_hook(
-        partial(reparametrize_as_dtype_state_dict_post_hook, offload_to_cpu=True)
-    )
+    _register_reparametrize_state_dict_hooks(model)
Contributor:

Doesn't have to be done in this PR, but we can think about adding this for other models that would have a similar memory situation when running QLoRA (Llama 3.1 8B is an obvious choice, but there are a handful of other similarly-sized models supported in our repo that could benefit from this)

Contributor (Author):

Will do in a follow-up.

# Create a state_dict on disk with space reserved for storage bytes
# Then load with mmap and MAP_SHARED (can writeback to disk file)
dest_state_dict_path = "/tmp/fake_state_dict.pt"
with torch.serialization.skip_data(materialize_fake_tensors=True):
Contributor:

noob q: what does materialize_fake_tensors mean in this context?

@mikaylagawarecki (Contributor, Author) replied Sep 9, 2024:

It means that FakeTensors in the object passed to torch.save will be treated as if they were real tensors. The implication is that torch.load will load a tensor (not FakeTensor) on the FakeTensor's device with storage allocated but uninitialized (0s)
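A minimal end-to-end sketch of the mechanism being described (assumes a recent torch nightly that has torch.serialization.skip_data, per the version gate discussed above; the tensor shape and the path are illustrative):

```python
import mmap

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

dest_state_dict_path = "/tmp/fake_state_dict.pt"

# 1) Save a "fake" state dict: checkpoint metadata plus reserved (uninitialized)
#    storage bytes, without ever allocating the real data in RAM.
with FakeTensorMode():
    fake_sd = {"weight": torch.empty(4096, 4096, dtype=torch.bfloat16)}
with torch.serialization.skip_data(materialize_fake_tensors=True):
    torch.save(fake_sd, dest_state_dict_path)

# 2) Load it back with mmap + MAP_SHARED so that writes to the tensors are
#    paged back to the file on disk instead of consuming CPU RAM.
torch.serialization.set_default_mmap_options(mmap.MAP_SHARED)
dest_state_dict = torch.load(dest_state_dict_path, mmap=True, map_location="cpu")

# 3) Filling dest_state_dict["weight"] in place (e.g. via .copy_()) now writes
#    through to the file, keeping resident memory low during checkpoint save.
```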

Comment on lines 127 to 131
# In-place update the original state_dict object. Although the private state dict
# post hook supports out-of-place behavior, the semantics are actually buggy. We eventually
# want to use the public state_dict post hook, which does not support out-of-place behavior.
for k in state_dict.keys():
    state_dict[k] = dest_state_dict[k]
Contributor:

I think I have some misunderstanding here. If we inplace update the state dict to the upcasted version of the weights, why won't it cause an OOM?

Contributor (Author):

When we do sd = torch.load(..., mmap=True), the storages of the tensors in sd are mmap-backed

state_dict[k] = dest_state_dict[k] does not access any pages of the storage of the tensor given by dest_state_dict[k], so the storage is not materialized, and no OOM will happen
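As a toy illustration of why the in-place update is cheap (reuses the mmap-backed file from the sketch above; purely illustrative):

```python
import torch

# Tensors loaded with mmap=True are lazily backed by the file; no data is read yet.
dest_state_dict = torch.load("/tmp/fake_state_dict.pt", mmap=True, map_location="cpu")

state_dict = {}
for k in dest_state_dict.keys():
    # Rebinding a dict entry only copies tensor metadata and a storage handle;
    # no storage pages are faulted in, so CPU RAM usage stays flat.
    state_dict[k] = dest_state_dict[k]

# Pages are materialized only when the data is actually touched, e.g.
# state_dict["weight"].float().sum()
```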

Contributor:

Oh yeah nvm I get it now, I was not thinking it through carefully enough. Thanks for the explanation!

codecov-commenter commented Sep 9, 2024

Codecov Report

Attention: Patch coverage is 17.64706% with 42 lines in your changes missing coverage. Please review.

Project coverage is 27.18%. Comparing base (66590b4) to head (dc8c6b0).

Files with missing lines                          Patch %   Lines
torchtune/modules/common_utils.py                 22.22%    28 Missing ⚠️
recipes/lora_finetune_single_device.py             0.00%     7 Missing ⚠️
torchtune/modules/peft/_utils.py                    0.00%     6 Missing ⚠️
torchtune/models/llama3/_component_builders.py     50.00%     1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1315      +/-   ##
==========================================
- Coverage   27.22%   27.18%   -0.04%     
==========================================
  Files         286      286              
  Lines       13828    13869      +41     
==========================================
+ Hits         3764     3770       +6     
- Misses      10064    10099      +35     

☔ View full report in Codecov by Sentry.

@ebsmothers (Contributor) left a comment:

Thank you for enabling this, can't wait to put together some torchtune Colab notebooks for our users!

@mikaylagawarecki (Contributor, Author) commented Sep 10, 2024:

Added the changes from #1535, plus one more round of loss curves comparing the fake resume_from_checkpoint in eba1ffa with the same config on base.

(Two screenshots of loss curves, 2024-09-10)

@ebsmothers ebsmothers merged commit 515efbe into pytorch:main Sep 10, 2024
17 checks passed
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
4 participants