WIP: Clone tensors to be able to mutate #33179
Conversation
For huggingface#33178
Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
Thanks for the PR, left a comment!
k_out = self.key_cache[layer_idx].clone()
v_out = self.value_cache[layer_idx].clone()
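For context, here is a minimal, self-contained sketch of the pattern these two lines implement: a static-cache update() that clones the per-layer buffers before writing into them. The class and argument names below are invented for illustration and simplify the real StaticCache.update considerably.

```python
import torch

class TinyStaticCacheSketch:
    """Illustrative only: a stripped-down static KV cache whose update()
    clones the per-layer buffers before mutating them, as in this PR."""

    def __init__(self, num_layers, batch, heads, max_len, head_dim, device="cpu"):
        shape = (batch, heads, max_len, head_dim)
        self.key_cache = [torch.zeros(shape, device=device) for _ in range(num_layers)]
        self.value_cache = [torch.zeros(shape, device=device) for _ in range(num_layers)]

    def update(self, key_states, value_states, layer_idx, cache_position):
        # Clone so the tensors we mutate own fresh (non-frozen) storage.
        k_out = self.key_cache[layer_idx].clone()
        v_out = self.value_cache[layer_idx].clone()
        # Write the new key/value states at the given positions in place.
        k_out[:, :, cache_position] = key_states
        v_out[:, :, cache_position] = value_states
        return k_out, v_out

cache = TinyStaticCacheSketch(num_layers=1, batch=1, heads=2, max_len=8, head_dim=4)
k_out, v_out = cache.update(
    torch.randn(1, 2, 1, 4), torch.randn(1, 2, 1, 4),
    layer_idx=0, cache_position=torch.tensor([0]),
)
```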
Are the values properly updated (self.key_cache and self.value_cache) since you are doing a copy? Also, do you see a memory increase from this modification?
Are the values properly updated (self.key_cache and self.value_cache) since you are doing a copy?
I'm afraid I have an issue with checking this explicitly, since adding print() here does not work. Can you suggest what I should check and how? However, the documentation on clone() says: "This function is differentiable, so gradients will flow back from the result of this operation to input. To create a tensor without an autograd relationship to input see detach()." So I think updates to the cloned tensor should propagate back, but I honestly don't know how to check this explicitly.
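A standalone way to check both points explicitly (not from the thread, just a quick illustration): in-place writes to the clone do not propagate back to the source tensor, while gradients do flow through clone() as the documentation says.

```python
import torch

src = torch.zeros(3)
cloned = src.clone()
cloned[0] = 5.0                # mutate the clone in place

print(src)     # tensor([0., 0., 0.])  -> the original storage is untouched
print(cloned)  # tensor([5., 0., 0.])

# clone() stays in the autograd graph (unlike detach()):
a = torch.ones(3, requires_grad=True)
b = a.clone()
b.sum().backward()
print(a.grad)  # tensor([1., 1., 1.])  -> gradients flow back through the clone
```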
Also, do you see a memory increase from this modification?
I did not check, but my expectation is that there will be a memory increase, since what clone() does is create a new tensor.
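A rough way to measure that increase (illustrative; torch.cuda.memory_allocated assumes a CUDA device, so substitute the matching API for other accelerators):

```python
import torch

# A cache-sized example tensor; the shape is arbitrary and only for illustration.
key_cache = torch.zeros(1, 32, 4096, 128, device="cuda", dtype=torch.float16)

before = torch.cuda.memory_allocated()
k_out = key_cache.clone()      # allocates a second, full-size buffer
after = torch.cuda.memory_allocated()

print("extra bytes from clone():", after - before)
print("tensor size in bytes:    ", key_cache.element_size() * key_cache.nelement())
```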
Also, I noted this issue on the PyTorch side:
I think the better fix here would actually be to check whether the tensors are on the wrong device and, if they are, call to(copy=True). But this assumes a fix to the above PyTorch issue. I also see this change (merged) on the PyTorch side, pytorch/pytorch#132529, where they seem to work around the issue with non-mutable tensors after calling .to() for some other particular case (search for issue 131679 in the change). If I read it correctly, they are doing the same thing - calling clone() and operating on the cloned tensor.
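A minimal sketch of that alternative (hypothetical helper, and it presumes the PyTorch issue above is fixed so the copy is mutable under export):

```python
import torch

def maybe_move(cache_tensor: torch.Tensor, target_device: torch.device) -> torch.Tensor:
    """Return cache_tensor on target_device, copying only when it is elsewhere."""
    if cache_tensor.device != target_device:
        # Devices differ, so .to() has to copy; copy=True just makes that explicit.
        return cache_tensor.to(device=target_device, copy=True)
    return cache_tensor

cache_entry = torch.zeros(1, 8, 16, 64)    # pretend per-layer cache buffer
key_states = torch.randn(1, 8, 1, 64)      # new states, same device in this example
k_out = maybe_move(cache_entry, key_states.device)
print(k_out is cache_entry)  # True: same device, no copy made
```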
@bdhirsh, @SherlockNoMad, @albanD : may I ask your help to weigh in on this issue?
Ah, just commented here: pytorch/pytorch#131679 (comment) - Tugsuu is (tentatively) going to look into making that error less intrusive, so this case should "just work" without a clone. It'll probably take some time, though.
@bdhirsh: thank you so much for looking into this. In the meantime, we are considering 2 workarounds on the HF side:
- This PR (WIP: Clone tensors to be able to mutate #33179)
- PR from @guangy10 (Unbreak torch export with static cache #33287)
@bdhirsh, can you advise whether the 1st or 2nd workaround is preferable, or do you just recommend waiting a couple of days for the fix on the PyTorch side?
I'm actually a bit confused by the code. For example - in https://github.com/huggingface/transformers/pull/33287/files, I don't think that code will do anything:
# let's say self.key_cache is a tensor with device =='cuda'
# and let's say key_states.device == 'cpu'
self.key_cache[0] = self.key_cache[0].to(device=key_states.device)
The above code won't actually change the device of self.key_cache (it will remain on cuda, even if key_states lives on cpu). self.key_cache[0] is taking a view/slice off of self.key_cache, and you can't change the device of "part" of a tensor.
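A standalone illustration of that point (assumes a CUDA device is available):

```python
import torch

key_cache = torch.zeros(2, 3, device="cuda")
key_states = torch.randn(3)                 # lives on the CPU

# This copies values into row 0 of the CUDA tensor; it does not (and cannot)
# move any part of key_cache to the CPU.
key_cache[0] = key_cache[0].to(device=key_states.device)

print(key_cache.device)      # cuda:0 -- unchanged
print(key_cache[0].device)   # cuda:0 -- a view always shares its parent's device
```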
I guess my meta question is: in general, don't you know the device of your kv cache ahead of time? Can you ensure that your model (and its kv cache) are initialized on the right device to begin with, so that when you export, you don't need to do any device conversions? You probably don't want any device conversions in the runtime path (inside of your exported program) anyway, since they will hurt perf
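To make that concrete, the usual pattern is to build the KV cache on the model's device before exporting, so the exported graph never has to convert cache tensors between devices. The sketch below is illustrative only; the checkpoint name is a placeholder and the StaticCache constructor arguments may differ across transformers versions.

```python
import torch
from transformers import AutoModelForCausalLM, StaticCache

device = "cuda" if torch.cuda.is_available() else "cpu"
# Placeholder checkpoint: use whichever decoder-only model you actually export.
model = AutoModelForCausalLM.from_pretrained("your-decoder-model").to(device)

# Allocate the static KV cache on the same device as the model up front.
past_key_values = StaticCache(
    config=model.config,
    batch_size=1,
    max_cache_len=128,
    device=device,               # argument names may vary by version
    dtype=model.dtype,
)
```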
I guess my meta question is: in general, don't you know the device of your kv cache ahead of time?
#33303 🤗
cool :)
Superseded by #33303.
Submitting this as a draft since I am not sure whether this is the correct fix. However, it does help resolve the issue described at #33178, where we get "RuntimeError: cannot mutate tensors with frozen storage". So I am posting this change as an illustration for the above issue.
CC: @gante