
Add cpu_offload_with_hook #1045

Merged: 6 commits into main, Feb 7, 2023
Conversation

@sgugger (Collaborator) commented Feb 7, 2023

This addresses a feature request from Diffusers. This PR adds a new cpu_offload_with_hook function that offloads the model to CPU, then moves it back to the GPU when executed, but without offloading it right after the forward pass the way cpu_offload does. Instead, it is up to the user to call hook.offload() to offload the model again.

Example:

import torch
import torch.nn as nn
from accelerate import cpu_offload_with_hook

model = nn.Linear(4, 5)
model, hook = cpu_offload_with_hook(model)
print(model.weight.device)  # cpu, weights start offloaded

inputs = torch.randn(2, 4)
outputs = model(inputs)
print(outputs.device)  # Outputs are on GPU, execution was done on GPU
print(model.weight.device)  # Stays on the GPU until hook.offload() is called

hook.offload()
print(model.weight.device)  # Back to cpu

cc @pcuenca @patrickvonplaten

@HuggingFaceDocBuilderDev commented Feb 7, 2023

The documentation is not available anymore as the PR was closed or merged.

@patrickvonplaten (Contributor) left a comment


This looks very nice! If possible, we were also envisioning passing a previous user_hook to the CpuOffload hook, so that the previous model can be offloaded when the next one runs its forward. E.g. so that the following would be possible:

model_1, hook_1 = cpu_offload_with_hook(model_1, cuda_device)
model_2, hook_2 = cpu_offload_with_hook(model_2, cuda_device, user_hook=hook_1)
model_3, hook_3 = cpu_offload_with_hook(model_3, cuda_device, user_hook=hook_2)

so that the following would automatically work:

hid_1 = model_1(input)
for i in range(50):
    hid_2 = model_2(hid_1)
hid_3 = model_3(hid_2)

Alternatively, we could also call the hooks ourselves if preferred, but being able to pass the hooks directly as shown in the code review would make the diffusers code much nicer, I think :-)
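The chaining being proposed can be sketched in plain Python, independent of torch and of accelerate's actual hook classes (the class and attribute names below are illustrative stand-ins, not the merged API): each hook, before its own model runs, offloads the previous hook's model in the chain, so at most one model sits on the GPU at a time.

```python
# Illustrative sketch only: mimics the proposed user_hook chaining.
# "Offloading" here just flips a device string; no real tensors move.

class OffloadHook:
    """Toy stand-in for a CPU-offload hook (hypothetical, not accelerate's API)."""
    def __init__(self, model, prev_hook=None):
        self.model = model
        self.prev_hook = prev_hook
        model.device = "cpu"  # model starts offloaded

    def pre_forward(self):
        # Offload the previous model in the chain before loading this one.
        if self.prev_hook is not None:
            self.prev_hook.offload()
        self.model.device = "cuda"

    def offload(self):
        self.model.device = "cpu"


class ToyModel:
    def __init__(self, name):
        self.name = name
        self.device = "cpu"


# Hypothetical two-stage pipeline (names are made up for the sketch).
model_1, model_2 = ToyModel("stage_1"), ToyModel("stage_2")
hook_1 = OffloadHook(model_1)
hook_2 = OffloadHook(model_2, prev_hook=hook_1)

hook_1.pre_forward()                   # model_1 loads onto the GPU
hook_2.pre_forward()                   # model_1 is offloaded, model_2 loads
print(model_1.device, model_2.device)  # cpu cuda
hook_2.offload()                       # final offload is left to the caller
print(model_2.device)                  # cpu
```

This is the pattern the snippet above relies on: the loop over model_2 never touches the hooks, because model_3's hook takes care of offloading model_2 when its own forward starts.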

Review threads (resolved): src/accelerate/big_modeling.py, src/accelerate/hooks.py
sgugger and others added 3 commits February 7, 2023 11:01
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
@pacman100 (Contributor) commented

Really Cool!

@patrickvonplaten (Contributor) left a comment


Awesome! Thanks a mille for the super fast PR ❤️

@sgugger sgugger merged commit 71e81ba into main Feb 7, 2023
@sgugger sgugger deleted the cpu_offload_hook branch February 7, 2023 18:09