fix skip_keys usage in forward hooks #3088
Conversation
Nice find @152334H! I'm a bit curious how you found this issue. Did you have a specific case that required you to skip keys?
@SunMarc I had a simple KV cache implementation that passes a cache tensor through kwargs. Something like: the top-level transformer receives a gigantic KV tensor that gets indexed per layer, and each layer statefully modifies its slice of the cache in place. The reason this breaks without `skip_keys` is that if the tensor is moved to another GPU on the way in, only a copy of the input tensor is modified, and the cache provided at the top-level forward is not. More simply, any model that expects forward-pass inputs to be statefully modified will be surprised by the big modeling approach of copying input tensors around, if `skip_keys` isn't used.
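For concreteness, a rough sketch of that pattern (the `Layer`/`Transformer` classes and the `cache` kwarg below are illustrative placeholders, not the actual implementation):

```python
import torch
from torch import nn

class Layer(nn.Module):
    # Illustrative only: each layer writes its slice of a shared cache in place.
    def __init__(self, idx):
        super().__init__()
        self.idx = idx
        self.proj = nn.Linear(8, 8)

    def forward(self, hidden, *, cache):
        out = self.proj(hidden)
        cache[self.idx] = out.detach()  # stateful in-place update the caller relies on
        return out

class Transformer(nn.Module):
    def __init__(self, n_layers=2):
        super().__init__()
        self.layers = nn.ModuleList(Layer(i) for i in range(n_layers))

    def forward(self, hidden, *, cache):
        for layer in self.layers:
            hidden = layer(hidden, cache=cache)
        return hidden

model = Transformer()
cache = torch.zeros(2, 1, 8)            # stands in for the gigantic KV tensor
model(torch.randn(1, 8), cache=cache)   # caller expects `cache` to be filled in
# If a dispatch hook copies `cache` to another device before each layer runs,
# only the copy is written to and the caller's tensor stays all zeros.
```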
Thanks for the detailed report. It is a bit surprising that the inputs are being copied, since we do the following:
The issue isn't memory increase (even if that does happen slightly); it is that stateful modifications to input tensors do not persist, because the args are moved to a new device and hence a new pointer. Please see this minified problem case:

```python
import torch
from torch import nn, zeros

class Z(nn.Module):
    def __init__(self):
        super().__init__()
        self.l = nn.Linear(1, 1)  # some dummy weight
    def forward(self, *, x):
        x[0] = 1  # <--- if you print x.device here, it will not always be cuda:0 in the 3rd test case without the patch

class A(nn.Module):
    def __init__(self):
        super().__init__()
        self.z = Z()
    def forward(self, **k):
        self.z(**k)  # <-- k['x'].device here will still be correct in all cases

class B(nn.ModuleList):
    def forward(self, device):
        K = zeros(9, device=device).unsqueeze(1)
        for i, l in enumerate(self):
            l(x=K[i])
        return K.flatten()

def test_conditions(device='cpu', callback=lambda _: 0):
    m = B([A() for _ in range(9)])
    callback(m)
    print(m(device))

from accelerate.big_modeling import dispatch_model, infer_auto_device_map, get_balanced_memory

def split_model(m, skip_keys=None):
    kw = dict(no_split_module_classes=["A"], dtype=torch.bfloat16)
    device_map = infer_auto_device_map(m, max_memory=get_balanced_memory(m, **kw), **kw)
    dispatch_model(m, device_map=device_map, skip_keys=skip_keys)

# this is correct:
test_conditions()
# this is always wrong:
test_conditions("cuda", split_model)
# this is correct with the PR and wrong without:
test_conditions("cuda", lambda m: split_model(m, ["x"]))
```

With the patch:

```
$ python3 ac.py
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1.])
tensor([1., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:0')
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1.], device='cuda:0')
```

Without the patch:

```
$ python3 ac.py
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1.])
tensor([1., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:0')
tensor([1., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:0')
```

All tests were done on a multi-3090 node.
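For reference, the behavior that `skip_keys` opts into can be seen directly with `accelerate.utils.send_to_device`, which the dispatch hooks use to move inputs; a minimal sketch (assumes a CUDA device is available):

```python
import torch
from accelerate.utils import send_to_device

batch = {"x": torch.zeros(3), "cache": torch.zeros(3)}
moved = send_to_device(batch, "cuda:0", skip_keys=["cache"])

moved["x"][0] = 1      # writes to a copy that now lives on cuda:0
moved["cache"][0] = 1  # same object as batch["cache"], so the write persists

print(batch["x"])      # tensor([0., 0., 0.])  -- caller's tensor untouched
print(batch["cache"])  # tensor([1., 0., 0.])  -- in-place update visible to caller
```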
Makes sense! Thanks for the reproducer! Merging it.
What does this PR do?
It fixes `skip_keys` so that it is actually used by the big modeling forward hooks. Previously, `skip_keys` would not be propagated to nested modules (at some level of nesting), and kwargs that were supposed to be skipped would just be transferred to other devices by `send_to_device` anyway.
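As a hedged usage sketch of what this enables (the toy `model` and the `"cache"` key below are placeholders, not from this PR):

```python
import torch
from torch import nn
from accelerate.big_modeling import dispatch_model, infer_auto_device_map

# Toy stand-in for a real multi-device model.
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
device_map = infer_auto_device_map(model)

# With this PR, skip_keys is forwarded to the hooks attached to nested
# submodules as well, so a kwarg named "cache" is left on the device the
# caller chose instead of being copied by send_to_device.
model = dispatch_model(model, device_map=device_map, skip_keys=["cache"])
```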
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?