PyTorch XLA erroneously updates `t`'s views when it shouldn't (because `t.data` isn't in-place modified) #3392
Comments
@bdhirsh Could you take this one?
Sorry for the delay, I will try to take a look today or tomorrow.
@JackCaoG Thank you!
@ronghanghu One thing I'm wondering - are you able to update your code to not use `.data`? I tried running the following in colab, and it seems to run fine (the same results on cpu vs xla):
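(The snippet from this comment wasn't preserved in this copy; as a rough sketch, and per the rest of the thread, the kind of in-place `.data` pattern that does give matching results on CPU and XLA looks like the following. The shapes and initial values are my assumptions, not the original code.)

```python
import torch
import torch_xla.core.xla_model as xm

def run(device):
    x = torch.zeros(10, device=device)
    y = x.view(2, 5)
    x.data.add_(1.0)   # in-place modification of x.data
    return y.cpu()

print(run("cpu"))
print(run(xm.xla_device()))  # same result as on CPU: y is all 1's
```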
some context -
Hi @bdhirsh, unfortunately, our use case has to rely on the `.data` assignment.
Yes, the case you showed above works well in both CPU and PyTorch XLA. However, the following case doesn't work (giving different results between CPU/TPU)
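(The original snippet from this comment isn't reproduced here either; based on the rest of the thread, the failing pattern is reassigning `.data` and then reading a pre-existing view. A sketch, with shapes and initial values assumed:)

```python
import torch
import torch_xla.core.xla_model as xm

def run(device):
    x = torch.zeros(10, device=device)
    y = x.view(2, 5)                         # view taken before .data is reassigned
    x.data = torch.ones(10, device=device)   # NOT in-place: x.data now points to a new buffer
    return y.cpu()

print(run("cpu"))             # y is unchanged (all 0's)
print(run(xm.xla_device()))   # PyTorch XLA erroneously returns all 1's
```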
Yeah, exactly! In our use case, we would like it to behave more like the previous `.data` semantics. So our case would require that assigning to `.data` does not update the existing views. Thanks a lot for your help!
Would
Hi @bdhirsh, unfortunately
A toy use case is as follows -- for example, if we want to change a Variable's `.data`:
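(The original toy snippet isn't preserved in this copy; a minimal sketch of the kind of pattern meant here, with names, shapes, and initial values as my own assumptions:)

```python
import torch

w = torch.nn.Parameter(torch.zeros(10))
y = w.view(2, 5)          # a view taken before .data is swapped

# Swap the parameter's storage without touching its autograd identity,
# e.g. to point it at a newly materialized buffer.
w.data = torch.randn(10)

out = (w * 2).sum()
out.backward()            # gradients still flow into w.grad
print(y)                  # y still refers to the old (all-zero) storage
print(w.grad)             # all 2's
```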
Despite being a toy example, it reflects many real use cases where one wants to control the auto-grad behavior in PyTorch, such as by changing a variable's `.data`. All these applications require us to edit a Variable's `.data`.
There are many profound use cases that involve editing a tensor/variable's `.data`.
The expected behavior of a tensor/variable in PyTorch (as in the case of CPU/GPU) is that accessing `t` is equivalent to accessing the current `t.data`. Although everything is okay when we don't use views, the IR generation of views in PyTorch XLA deviates from this behavior and therefore breaks many cases (including all the 3 examples above from Meta or Microsoft) involving `.data`.
OK, I think I root caused the issue using the example above. The code path calls Line 888 in 3ceb226 and then Line 622 in 3ceb226. In this case the views end up being updated even though `t.data` is not modified in place.
Thanks, @JackCaoG! Even if it won't be part of the 1.11 release, it's still very helpful to have it in a nightly version.
I think I fixed the data issue mentioned here locally, and I'm trying to see if that fix will also fix the shape issue we discussed in the previous PR. The change is pretty small but very deep in our stack, so I think I will just fix it in the nightly and provide you with a working wheel after I verify everything works.
@ronghanghu All of the shape-related issues also seem to be fixed. Let me clean up the PR and submit it. Hopefully we can get you the nightly wheel with the fix this Wednesday.
This is very nice! Thanks again @JackCaoG
@JackCaoG it seems that the previous example was working on nightly 20220408 but is broken on 20220413. After installing the 20220413 builds, the example above now gives an error.
@ronghanghu Can you open a new issue and I can follow up. |
🐛 Bug
There seems to be a bug in the IR implementation of `.view` in PyTorch XLA that causes different behaviors between CPU/GPU and PyTorch XLA. This discrepancy gives different computation results when running the same model code on GPUs vs on TPUs (which might be a surprise to users). I found this bug when tracing a model that had different results between GPU and TPU.
In short, the bug is that PyTorch XLA seems to update a tensor's view when it shouldn't (because the tensor isn't in-place modified). It is illustrated in the following example:
To Reproduce
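(The original reproduction script isn't preserved in this copy of the issue; the sketch below reconstructs it from the description that follows. The shapes, dtypes, and initial values are my assumptions.)

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# Three tensors, each with a view taken before .data is touched.
x0 = torch.zeros(10, device=device)
x1 = torch.zeros(10, device=device)
x2 = torch.zeros(10, device=device)
y0 = x0.view(2, 5)
y1 = x1.view(2, 5)
y2 = x2.view(2, 5)

x0.data.add_(1.0)                         # in-place modification of x0.data
x1.data += 1.0                            # also in-place (maps to THPVariable_add_)
x2.data = torch.ones(10, device=device)   # NOT in-place: x2.data now points to a new buffer

print(y0)  # all 1's (view of in-place-modified storage)
print(y1)  # all 1's (view of in-place-modified storage)
print(y2)  # should stay all 0's as on CPU, but PyTorch XLA erroneously gives all 1's
```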
On TPU (error behavior), this prints all of `y0`, `y1`, and `y2` filled with 1's.
As can be seen, the TPU behavior of PyTorch XLA diverges from CPU (see below), since here `y2` is erroneously updated to 1's even though `x2.data` is not in-place modified. The IR of `y2` is incorrectly built in PyTorch XLA, which results in different computation outputs between CPU/GPU and TPU and may break model training or inference.
I think to fix this discrepancy and get the same computational outputs between CPU and TPU, it's necessary to distinguish whether `t`'s data is being in-place modified (such as `t.data.add_(1.0)`) or not (such as `t.data = <another tensor>`). Only the former (in-place) case should trigger an update to all of the views associated with it, while the latter shouldn't cause views on the previous `t` to be updated (just like it won't update the views when we call something like `t = t + 1`).
Expected behavior
On CPU (expected behavior), the same script prints `y0` and `y1` filled with 1's while `y2` keeps its original values.
Here in PyTorch CPU/GPU, only `y0` and `y1` are updated to 1's because `x0.data` and `x1.data` are in-place modified. (The op `x1.data += 1.0` has `__iadd__` mapped to `THPVariable_add_` to in-place modify `x1.data`, and it doesn't change `id(x1.data)` after the op.) On the other hand, `y2` is NOT (and should not be) updated because `x2.data` is assigned to a new memory buffer (which is not an in-place modification) while `y2` is a view on the old memory buffer.
Note that accessing `x2` is equivalent (except for PyTorch's auto-grad tracing) to accessing the exact version of `x2.data` at the moment when `x2` is accessed. So in the example above, calling `y2 = x2.view(2, 5)` should give the same view as calling `y2 = x2.data.view(2, 5)`, except that the former allows us to back-propagate to `x2` and put the gradient in `x2.grad` (when we call something like `y2.sum().backward()`), while the latter breaks auto-grad tracing on `x2` and has `y2.requires_grad` being `False`. Several PyTorch-based packages rely on this behavior in PyTorch.
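(A small illustration of that equivalence; the shapes and `requires_grad` setup are my assumptions, not taken from the original report:)

```python
import torch

x2 = torch.zeros(10, requires_grad=True)
y2a = x2.view(2, 5)        # tracked by autograd: y2a.requires_grad is True
y2b = x2.data.view(2, 5)   # same values, but auto-grad tracing is broken: y2b.requires_grad is False

y2a.sum().backward()       # gradients flow back into x2.grad
print(x2.grad.sum())       # tensor(10.)
```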
Environment
tpu-vm-pt-1.10
Additional context
I suspect by fixing this problem, the original shape mismatch error in #3330 (comment) should also go away.
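(Again, the snippet for this second example isn't preserved here; based on the description below, it was roughly of the following shape, with sizes and initial values as my assumptions:)

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()   # or "cpu" for the expected behavior

x3 = torch.zeros(10, device=device)
y3 = x3.view(2, 5)         # a view of x3's original storage
x3.data = y3[:5] + 1       # NOT in-place: x3.data now points to a freshly computed buffer
print(y3)
```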
On CPU (expected behavior), this leaves `y3` unchanged.
Here `y3` did not change (as expected) because `x3.data = y3[:5] + 1` is not an in-place update.
On TPU (error behavior)
It crashes because PyTorch XLA erroneously tries to update `y3` (the view associated with the former version of `x3`), even though it shouldn't: `x3.data` is not being in-place modified but is assigned to something else.
On TPU (correct with an extra `xm.mark_step()`)
Note that one can also get the correct results on PyTorch XLA by inserting an extra `xm.mark_step()` after the view op to break the IR graph. This prevents the crash above and prints results consistent with the CPU ones, which is what one would expect from PyTorch. However, in actual PyTorch XLA training it is often not practical to have an extra `xm.mark_step()` call in the middle of a model's forward pass, so it would be great if this example could also work without the extra `xm.mark_step()`.
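(For reference, a sketch of the workaround on the toy example above, with the same assumed shapes:)

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x3 = torch.zeros(10, device=device)
y3 = x3.view(2, 5)
xm.mark_step()          # break the IR graph right after the view op
x3.data = y3[:5] + 1    # still not an in-place update
print(y3)               # no crash, and y3 is unchanged, matching the CPU result
```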
So I think in summary, assignments to a tensor `t`'s `.data` in PyTorch shouldn't update any previous references (including views) to `t`. The `torch.Tensor` Python object `t` is just a shallow handle, and it can point to different storage at different times, so any new assignments to `t.data` should be independent from `t`'s previous references, regardless of whether they are `t.abs()` or `t.view(-1)`. Only in-place modification to `t.data` such as `t.data.add_(1.0)` or `t.data += 1.0` (or equivalently, in-place modification on `t` such as `t.add_(1.0)` or `t += 1.0`) should affect `t`'s previous views.
(I guess perhaps the fix would be removing the IR update on views when calling `THPVariable_set_data` on a PyTorch XLA tensor, and maybe adding a few test cases on views such as `y0`, `y1`, `y2`, and `y3` above?)