FIX: Prefix tuning with model on multiple devices #2189
Conversation
See huggingface#2134. After the introduction of `DynamicCache` for prefix tuning, a bug can occur if the model is dispatched to multiple devices, because the key and value cache for each layer needs to be moved to that layer's respective device. The new code mostly consists of code copied from transformers, to stay consistent with how transformers solves this. Note that this only works if the `hf_device_map` attribute is set on the model.
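For illustration only (this is not a verbatim excerpt from the diff), here is a minimal sketch of what the remapping could look like. `map_cache_to_layer_device_map` and `get_layer_device_map` are the helper names used in this PR, but the body of `get_layer_device_map` below is a simplified assumption rather than the implementation adapted from transformers:

```python
import transformers


def get_layer_device_map(model):
    # Simplified assumption of the helper's behavior: map each decoder layer
    # index to the device that layer was dispatched to. The real helper is
    # adapted from transformers and handles more edge cases.
    devices = set(model.hf_device_map.values())
    if devices <= {"cpu", "disk"}:
        return None  # nothing on an accelerator, nothing to remap

    main_device = next(d for d in model.hf_device_map.values() if d not in ("cpu", "disk"))
    layer_device_map = {}
    for layer_idx in range(model.config.num_hidden_layers):
        for module_name, device in model.hf_device_map.items():
            # match entries like "model.layers.3" or "transformer.h.3"
            if module_name.endswith(f".{layer_idx}"):
                layer_device_map[layer_idx] = main_device if device in ("cpu", "disk") else device
                break
    return layer_device_map


def map_cache_to_layer_device_map(model, cache) -> None:
    """Ensure that the key and value cache of the model are on the same device
    as their corresponding layers."""
    if not (isinstance(cache, transformers.Cache) and hasattr(model, "hf_device_map")):
        return

    layer_device_map = get_layer_device_map(model)
    if layer_device_map is None:
        return

    for idx in range(model.config.num_hidden_layers):
        layer_device = layer_device_map.get(idx)
        if layer_device is not None:
            cache.key_cache[idx] = cache.key_cache[idx].to(layer_device)
            cache.value_cache[idx] = cache.value_cache[idx].to(layer_device)
```

The remapping is a no-op unless the model was loaded with a device map (e.g. `AutoModelForCausalLM.from_pretrained(..., device_map="auto")`), since that is what sets the `hf_device_map` attribute.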
""" | ||
Ensure that the key and value cache of the model are on the same device as their corresponding layers. | ||
""" | ||
if not (isinstance(cache, transformers.Cache) and hasattr(model, "hf_device_map")): |
Probably `or not hasattr(model, "hf_device_map")`? Otherwise we don't skip the cases where the cache is a tuple and the model has no device map.
Also, I'm wondering whether this didn't fail for the tuple format because accelerate handles device allocation for tensors but can't handle it for objects? In the tuple format we also don't usually do any device mapping.
Hmm, maybe I misunderstand, but right now we have "not (A and B)", which is the same as "not A or not B".
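Spelled out with the guard from the diff (just restating the equivalence, no change to the PR):

```python
# as written:
if not (isinstance(cache, transformers.Cache) and hasattr(model, "hf_device_map")):
    return

# equivalent by De Morgan's law:
if not isinstance(cache, transformers.Cache) or not hasattr(model, "hf_device_map"):
    return
```

So the tuple-cache and no-device-map cases are already skipped.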
Oh right, I guess I didn't see the brackets XD
```python
        cache.key_cache[idx] = cache.key_cache[idx].to(layer_device)
        cache.value_cache[idx] = cache.value_cache[idx].to(layer_device)
```
This might fail for the encoder-decoder cache. Can we also try with a T5 model in a multi-GPU environment to see if it works?
I tried and it didn't work because the `get_layer_device_map` function would fail. I tried to check why it doesn't fail when the function is called in transformers, but for T5 that code path is never reached. Honestly, I couldn't figure out why it's different for encoder-decoder. My solution for now is to call `map_cache_to_layer_device_map` only when `peft_config.num_transformer_submodules == 1` and leave encoder-decoder untouched for now.
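To illustrate the interim approach described above (the call site and variable names here are assumptions, not a verbatim excerpt from the PR):

```python
# Only remap the cache for decoder-only models; encoder-decoder models
# (num_transformer_submodules == 2) are left untouched for now.
if peft_config.num_transformer_submodules == 1:
    map_cache_to_layer_device_map(model, past_key_values)
```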
Yes, sounds good to me. We can add multi-GPU support for encoder-decoder as more models get converted to the new cache format.
LGTM! Thanks for adding this
@zucchini-nlp I did manage to make this work with T5; the fix was actually quite simple. Could you please review again?
Awesome that it worked for T5! Thanks!