
fix skip_keys usage in forward hooks #3088

Merged: 2 commits, Sep 10, 2024

Conversation

@152334H (Contributor) commented Sep 7, 2024

What does this PR do?

It fixes skip_keys so that it is actually used by the big modeling forward hooks.

Previously, skip_keys was not propagated to nested modules (I'm not sure at exactly which level of nesting, but at some level), so kwargs that were supposed to be skipped were transferred to other devices by send_to_device anyway.
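
For context, this is roughly what "skipping" a key means at the send_to_device level. A minimal sketch, not the PR diff itself; it assumes accelerate.utils.send_to_device accepts a skip_keys argument (as in recent accelerate releases), and the kwarg names are just illustrative:

import torch
from accelerate.utils import send_to_device

# kwargs as a forward hook would see them: one tensor that may move, one that must not
kwargs = {
    "input_ids": torch.zeros(2, dtype=torch.long),  # fine to move
    "cache": torch.zeros(2),                        # listed in skip_keys below
}

moved = send_to_device(kwargs, "cuda:0", skip_keys=["cache"])
print(moved["input_ids"].device)          # cuda:0
print(moved["cache"].device)              # cpu
print(moved["cache"] is kwargs["cache"])  # True: same object, so in-place edits still reach the caller

The bug was that this skip list did not make it down to every hook, so nested modules ended up moving those kwargs anyway.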

Before submitting

  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
    • There is no open issue for this, as far as I know.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
    • No documentation change is needed; this just fixes something the docs already say should work.
  • Did you write any new necessary tests?
    • No, it's a 2-line change.

Who can review?

@SunMarc (Member) left a comment


Nice find @152334H! I'm a bit curious how you found this issue. Did you have a specific case that required you to skip keys?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@152334H (Contributor, Author) commented Sep 9, 2024

@SunMarc I had a simple kv-cache implementation that passes a cache tensor through kwargs. Something like: the top-level transformer receives a gigantic kv tensor that gets indexed per layer, with each layer statefully modifying the cache slice it receives as input. The reason this breaks without skip_keys is that if the tensor is moved to another GPU on the way in, only the copy of the input tensor is modified, and the cache provided to the top-level forward is not.

More simply, any model that expects forward-pass inputs to be statefully modified will be surprised by the big-modeling approach of copying input tensors around, if skip_keys isn't used.
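
To illustrate, here is a stripped-down sketch of the aliasing problem in plain PyTorch (the layer_forward helper is hypothetical, standing in for a per-layer cache write):

import torch

cache = torch.zeros(4)        # cache owned by the top-level module, on CPU

def layer_forward(c, i):
    c[i] = 1.0                # stateful in-place update, like a kv-cache write

layer_forward(cache, 0)       # same object: the write is visible to the caller
print(cache)                  # tensor([1., 0., 0., 0.])

moved = cache.to("cuda:0")    # what a device-moving hook effectively does to the kwarg
layer_forward(moved, 1)       # this write lands on the cuda copy only ...
print(cache)                  # tensor([1., 0., 0., 0.])  ... the original cache never sees it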

@SunMarc (Member) commented Sep 9, 2024


Thanks for the detailed report. It's a bit surprising that the inputs are being copied, since we do the following:
args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
Are you seeing a memory increase when you don't set skip_keys / without this PR, due to some edge case? Just like in your use case, we added skip_keys to deal with the cache in transformers, avoiding expensive to() ops each time.
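
For reference, this is the behaviour skip_keys is meant to give at the hook level. A sketch only: it assumes accelerate's AlignDevicesHook / add_hook_to_module APIs and that the hook's pre_forward applies skip_keys to kwargs; the Probe module is made up for the example:

import torch
from torch import nn
from accelerate.hooks import AlignDevicesHook, add_hook_to_module

class Probe(nn.Module):
    def forward(self, x=None, cache=None):
        # report where each kwarg ended up after the hook's pre_forward
        return x.device, cache.device

m = Probe()
add_hook_to_module(m, AlignDevicesHook(execution_device="cuda:0", skip_keys=["cache"]))

x = torch.zeros(2)      # not in skip_keys: moved to the execution device
cache = torch.zeros(2)  # in skip_keys: should stay on CPU, same object the caller holds
print(m(x=x, cache=cache))  # expected with skip_keys honoured: (cuda:0, cpu)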

@152334H (Contributor, Author) commented Sep 9, 2024

The issue isn't a memory increase (even if that does happen slightly); it's that stateful modifications to input tensors do not persist, because the args are moved to a new device and therefore become new tensors.

Please see this minified problem case:

import torch
from torch import nn, zeros

class Z(nn.Module):
    def __init__(self):
        super().__init__()
        self.l = nn.Linear(1,1) # some dummy weight
    def forward(self, *, x): x[0] = 1 # <--- if you print x.device here it will not always be cuda:0 in the 3rd test case without the patch
class A(nn.Module):
    def __init__(self):
        super().__init__()
        self.z = Z()
    def forward(self, **k): self.z(**k) # <-- k['x'].device here will still be correct in all cases
class B(nn.ModuleList):
    def forward(self, device):
        K = zeros(9, device=device).unsqueeze(1)
        for i,l in enumerate(self): l(x=K[i])
        return K.flatten()

def test_conditions(device='cpu', callback=lambda _:0):
    m = B([A() for _ in range(9)])
    callback(m)
    print(m(device))

from accelerate.big_modeling import dispatch_model, infer_auto_device_map, get_balanced_memory
def split_model(m, skip_keys=None):
    kw = dict(no_split_module_classes=["A"], dtype=torch.bfloat16)
    device_map = infer_auto_device_map(m, max_memory=get_balanced_memory(m, **kw), **kw)
    dispatch_model(m, device_map=device_map, skip_keys=skip_keys)

# this is correct:
test_conditions()
# this is always wrong:
test_conditions("cuda", split_model)
# this is correct with the PR and wrong without:
test_conditions("cuda", lambda m: split_model(m, ["x"]))

With the patch:

$ python3 ac.py
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1.])
tensor([1., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:0')
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1.], device='cuda:0')

Without the patch:

$ python3 ac.py
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1.])
tensor([1., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:0')
tensor([1., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:0')

All tests were run on a multi-3090 node.

@SunMarc (Member) commented Sep 10, 2024

Makes sense! Thanks for the reproducer! Merging it.

@SunMarc SunMarc merged commit 7d3bbe7 into huggingface:main Sep 10, 2024
25 checks passed