
PyTorch XLA erroneously updates t's views when it shouldn't (because t.data isn't in-place modified) #3392

Closed
ronghanghu opened this issue Feb 23, 2022 · 18 comments · Fixed by #3411
@ronghanghu
Collaborator

ronghanghu commented Feb 23, 2022

🐛 Bug

There seems to be a bug in the IR implementation of .view in PyTorch XLA that causes behavior to diverge between CPU/GPU and PyTorch XLA. This discrepancy produces different computation results when running the same model code on GPUs vs. on TPUs (which might surprise users). I found this bug while tracing a model that gives different results on GPU and TPU.

In short, the bug is that PyTorch XLA seems to update a tensor's view when it shouldn't (because the tensor isn't in-place modified). It is illustrated in the following example:

To Reproduce

On TPU (error behavior)

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

x0 = torch.zeros(10, device=device)
y0 = x0.view(2, 5)
# this should update y0 to all 1's (since x0.data is in-place modified)
x0.data.add_(1.0)

x1 = torch.zeros(10, device=device)
y1 = x1.view(2, 5)
# this should update y1 to all 1's (since __iadd__ maps to `THPVariable_add_` and is in-place modification)
x1.data += 1.0

x2 = torch.zeros(10, device=device)
y2 = x2.view(2, 5)
# this should NOT update y2 because x2.data now points to a new memory buffer (not in-place modification)
# while y2 is a view on top of the old memory buffer
x2.data = x2.data + 1.0

xm.mark_step()
print(f"y0: {y0}")
print(f"y1: {y1}")
print(f"y2: {y2}")

prints

y0: tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]], device='xla:1')
y1: tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]], device='xla:1')
y2: tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]], device='xla:1')

As can be seen, the TPU behavior from PyTorch XLA diverges from CPU (see below): y2 is erroneously updated to 1's even though x2.data is not in-place modified. The IR of y2 is incorrectly built in PyTorch XLA, which results in different computation outputs between CPU/GPU and TPU and may break model training or inference.

I think that to fix this discrepancy and get the same computation outputs on CPU and TPU, it's necessary to distinguish whether t's data is being in-place modified (such as t.data.add_(1.0)) or reassigned (such as t.data = <another tensor>). Only the former (in-place) case should trigger an update to all of the views associated with t; the latter shouldn't update views of the previous t (just as t = t + 1 doesn't update t's views).
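For example, here is a minimal CPU-only sketch of the distinction (data_ptr() is used purely for illustration and isn't meaningful for XLA tensors):

import torch

t = torch.zeros(10)
ptr = t.data.data_ptr()

t.data.add_(1.0)                 # in-place: same underlying buffer
assert t.data.data_ptr() == ptr  # views of t should see the update

t.data = t.data + 1.0            # reassignment: a brand-new buffer
assert t.data.data_ptr() != ptr  # existing views still point at the old buffer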

Expected behavior

On CPU (expected behavior)

import torch

device = torch.device("cpu")

x0 = torch.zeros(10, device=device)
y0 = x0.view(2, 5)
# this should update y0 to all 1's (since x0.data is in-place modified)
x0.data.add_(1.0)

x1 = torch.zeros(10, device=device)
y1 = x1.view(2, 5)
# this should update y1 to all 1's (since __iadd__ maps to `THPVariable_add_` and is in-place modification)
x1.data += 1.0

x2 = torch.zeros(10, device=device)
y2 = x2.view(2, 5)
# this should NOT update y2 because x2.data now points to a new memory buffer (not in-place modification)
# while y2 is a view on top of the old memory buffer
x2.data = x2.data + 1.0

print(f"y0: {y0}")
print(f"y1: {y1}")
print(f"y2: {y2}")

prints

y0: tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])
y1: tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])
y2: tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])

Here in PyTorch CPU/GPU, only y0 and y1 are updated to 1's because x0.data and x1.data are in-place modified. (For x1.data += 1.0, __iadd__ maps to THPVariable_add_, which modifies x1.data in place, so the underlying storage is unchanged after the op.) On the other hand, y2 is NOT (and should not be) updated because x2.data is assigned to a new memory buffer (which is not an in-place modification) while y2 is a view on the old memory buffer.

Note that accessing x2 is equivalent (except for PyTorch's auto-grad tracing) to accessing the exact version of x2.data at the moment x2 is accessed. So in the example above, calling y2 = x2.view(2, 5) should give the same view as calling y2 = x2.data.view(2, 5), except that the former lets us back-propagate to x2 and put the gradient in x2.grad (when we call something like y2.sum().backward()), while the latter breaks auto-grad tracing on x2 and leaves y2.requires_grad as False. Several PyTorch-based packages rely on this behavior.
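To make this equivalence concrete, a minimal CPU sketch:

import torch

x = torch.zeros(10, requires_grad=True)

y_a = x.view(2, 5)       # tracked by auto-grad
y_b = x.data.view(2, 5)  # same underlying storage, but detached from auto-grad

assert y_a.data_ptr() == y_b.data_ptr()
assert y_a.requires_grad and not y_b.requires_grad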

Environment

  • Reproducible on XLA backend [CPU/TPU]: v3-8 TPU VM with tpu-vm-pt-1.10
  • torch_xla version: 1.10

Additional context

I suspect by fixing this problem, the original shape mismatch error in #3330 (comment) should also go away.

On CPU (expected behavior):

import torch

device = torch.device("cpu")

x3 = torch.zeros(10, device=device)
y3 = x3.view(-1)
# This should NOT update y3 because `x3.data` is not in-place modified
x3.data = y3[:5] + 1

print(f"y3: {y3}")

prints

y3: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

Here y3 did not change (as expected) because x3.data = y3[:5] + 1 is not an in-place update.

On TPU (error behavior):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

x3 = torch.zeros(10, device=device)
y3 = x3.view(-1)
# This should NOT update y3 because `x3.data` is not in-place modified
x3.data = y3[:5] + 1  # crashes on XLA

# print(f"y3: {y3}")

gives

RuntimeError: torch_xla/csrc/tensor.cpp:781 : Check failed: xla::util::Multiply<xla::int64_t>(ir_value.shape().dimensions()) == xla::util::Multiply<xla::int64_t>(view->shape().dimensions()) (5 vs. 10)
*** Begin stack trace ***
        tensorflow::CurrentStackTrace()
        torch_xla::XLATensor::UpdateView(std::shared_ptr<torch_xla::View>, torch_xla::ir::Value) const
        torch_xla::XLATensor::SetIrValue(torch_xla::ir::Value)
        torch_xla::XLATensor::ShallowCopyTo(torch_xla::XLATensor*) const
        torch_xla::XLATensorImpl::shallow_copy_from(c10::intrusive_ptr<c10::TensorImpl, c10::detail::intrusive_target_default_null_type<c10::TensorImpl> > const&)

        THPVariable_set_data(THPVariable*, _object*, void*)

        PyObject_SetAttr
        _PyEval_EvalFrameDefault
        _PyEval_EvalCodeWithName
        PyEval_EvalCode



        PyRun_InteractiveLoopFlags
        PyRun_AnyFileExFlags

        Py_BytesMain
        __libc_start_main
        _start
*** End stack trace ***

It crashes because PyTorch XLA erroneously tries to update y3 (the view associated with the former version of x3), although it shouldn't, since x3.data is not being in-place modified (it is reassigned to something else).

On TPU (correct with an extra xm.mark_step()):

Note that one can also get the correct results in PyTorch XLA by inserting an extra xm.mark_step() after the view op to break the IR graph. This prevents the crash above and gives results consistent with the CPU ones:

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

x3 = torch.zeros(10, device=device)
y3 = x3.view(-1)
xm.mark_step()  # this extra `xm.mark_step()` prevents the crash we saw earlier

# This should NOT update y3 because `x3.data` is not in-place modified
x3.data = y3[:5] + 1  # doesn't crash in this case
print(f"y3: {y3}")

prints

y3: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='xla:1')

which is what one would expect from PyTorch and is consistent with the behavior on CPU/GPU. However, in actual PyTorch XLA training it is often impractical to insert an extra xm.mark_step() call in the middle of a model's forward pass, so it would be great if this example could also work without the extra xm.mark_step().


So in summary, I think assignments to a tensor t's .data in PyTorch shouldn't update any previous references (including views) to t. The torch.Tensor Python object t is just a shallow handle that can point to different storage at different times, so a new assignment to t.data should be independent of t's previous references, regardless of whether they were created by t.abs() or t.view(-1). Only in-place modification of t.data such as t.data.add_(1.0) or t.data += 1.0 (or equivalently, in-place modification of t such as t.add_(1.0) or t += 1.0) should affect t's previous views.

(I guess the fix might be to remove the IR update on views when THPVariable_set_data is called on a PyTorch XLA tensor, and maybe to add a few test cases on views such as y0, y1, y2 and y3 above?)
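For example, a regression test along these lines (a hypothetical sketch, not an existing test) could cover the in-place, reassignment, and shape-change cases:

import torch
import torch_xla.core.xla_model as xm


def test_set_data_does_not_update_views():
    device = xm.xla_device()

    x0 = torch.zeros(10, device=device)
    y0 = x0.view(2, 5)
    x0.data.add_(1.0)          # in-place: the view should follow

    x2 = torch.zeros(10, device=device)
    y2 = x2.view(2, 5)
    x2.data = x2.data + 1.0    # reassignment: the view should NOT follow

    x3 = torch.zeros(10, device=device)
    y3 = x3.view(-1)
    x3.data = y3[:5] + 1       # reassignment to a different shape: should not crash

    xm.mark_step()
    assert y0.cpu().eq(1).all()
    assert y2.cpu().eq(0).all()
    assert y3.cpu().eq(0).all()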

@JackCaoG
Collaborator

@bdhirsh Could you take this one?

@ronghanghu
Collaborator Author

Hi @bdhirsh and @JackCaoG, do we have a plan and an ETA for this issue? (It is currently blocking our use case on scaling.)

I'm happy to provide additional info or examples if needed. Thanks!

@JackCaoG
Collaborator

JackCaoG commented Mar 2, 2022

Sorry for the delay, I will try to take a look today or tomorrow.

@ronghanghu
Collaborator Author

@JackCaoG Thank you!

@bdhirsh
Collaborator

bdhirsh commented Mar 2, 2022

@ronghanghu One thing I'm wondering - are you able to update your code to not use .data?

I tried running the following in colab, and it seems to run fine (the same results on cpu vs xla):

import torch
import torch_xla.core.xla_model as xm

dev = xm.xla_device()

a = torch.zeros(4, device=dev)
b = a.view(2, 2)
b = b + 1
print(a)

> tensor([0., 0., 0., 0.], device='xla:1')  # a wasn't updated, since we set b to be a totally new tensor

a.data is a pretty weird tensor property (AFAIK it shouldn't be necessary for anything you need) - it effectively returns a new tensor, but:

  • it points to the same storage as a
  • it removes autograd metadata from the new tensor
  • it removes inplace/view autograd metadata (VersionCounter) from the new tensor

Some context - the tensor.data field is a relic of the tensor-variable merge. It made more sense when we used to have Variable as a separate concept in Python (you would wrap your tensor in a Variable in order to make it "trainable"), but makes less sense now that requires_grad is just a field on the tensor.
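A quick way to see these properties (a CPU sketch; _version is an internal attribute used here only for illustration):

import torch

a = torch.zeros(4, requires_grad=True)
d = a.data

assert d.data_ptr() == a.data_ptr()  # same storage as a
assert d.requires_grad is False      # autograd metadata stripped
d.add_(1.0)
assert a._version == 0               # the in-place edit did not bump a's version counter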

@ronghanghu
Collaborator Author

ronghanghu commented Mar 3, 2022

@ronghanghu One thing I'm wondering - are you able to update your code to not use .data?

Hi @bdhirsh, unfortunately our use case has to rely on .data. It allows us to keep the same variable a (e.g. so that it can be added to an optimizer, registered as a module's parameter, etc.) but point it to different storage at different times by changing a.data, and also to manipulate a variable's underlying storage without it being traced by auto-grad.

Assigning to or editing .data like this is a frequent pattern in gradient checkpointing and scaling scenarios (for example, there are many such usages in the fairscale library, like https://github.com/facebookresearch/fairscale/blob/2ca4f0eefa0f6b83f4448e1c10c3624ae5d24dee/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1863).
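As a simplified sketch of that pattern (the real fairscale code does sharding and flattening on top of this; the module and optimizer here are just for illustration):

import torch

lin = torch.nn.Linear(4, 4, bias=False)
opt = torch.optim.SGD(lin.parameters(), lr=0.1)

p = lin.weight
full = p.data                # keep a handle to the full storage
p.data = torch.empty(0)      # free the parameter's storage (e.g. after sharding)

# ... later, rebuild the storage right before it is needed again ...
p.data = full

# the Parameter object registered in the module and the optimizer never changed
assert lin.weight is p
assert opt.param_groups[0]["params"][0] is p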

I tried running the following in colab, and it seems to run fine (the same results on cpu vs xla):

Yes, the case you showed above works well on both CPU and PyTorch XLA. However, the following case doesn't work (it gives different results on CPU and TPU):

import torch
import torch_xla.core.xla_model as xm

dev = xm.xla_device()
# dev = torch.device("cpu")

a = torch.zeros(4, device=dev)
b = a.view(2, 2)
a.data = a.data + 1
print(b)  # prints 1 on XLA, 0 on CPU

Some context - the tensor.data field is a relic of the tensor-variable merge. It made more sense when we used to have Variable as a separate concept in Python (you would wrap your tensor in a Variable in order to make it "trainable"), but makes less sense now that requires_grad is just a field on the tensor.

Yeah, exactly! In our use case, we would like it to behave more like the old Variable class (a shallow handle whose .data holds the real storage), and to allow its .data to point to different storage at different times. Also, after the tensor-variable merge, it is necessary to access and edit t.data to manage PyTorch's auto-grad behavior (e.g. using a tensor t as a variable for auto-grad tracing while editing t.data to do something that we don't want traced by auto-grad). In general, we often still need to separate Tensor (data/storage properties) from Variable (auto-grad properties), and we need .data for this purpose.
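As a small CPU sketch of that separation, edits made through t.data are not recorded by auto-grad:

import torch

t = torch.zeros(3, requires_grad=True)

t.data += 1.0              # edited outside auto-grad: no graph is recorded
assert t.grad_fn is None   # t is still a leaf
assert t.requires_grad     # and still participates in auto-grad going forward

t.sum().backward()
print(t.grad)              # tensor([1., 1., 1.])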

So our use case requires that the .data assignment behavior in PyTorch XLA be the same as on CPU. (Currently they are the same except when views are involved.)

Thanks a lot for your help!

@bdhirsh
Collaborator

bdhirsh commented Mar 3, 2022

Would detach() get you the behavior that you want? I think it should give you behavior similar to .data - it returns a new tensor that views the same "data" as the old tensor, but is detached from the autograd graph.

import torch
import torch_xla.core.xla_model as xm

dev = xm.xla_device()
c = torch.zeros(4, device=dev)
d = c.view(2, 2)
d = d.detach() + 1
print(c) # on colab this printed all zeros

@ronghanghu
Collaborator Author

ronghanghu commented Mar 3, 2022

Hi @bdhirsh, unfortunately detach() doesn't give us what we need. We need to rebind a Variable's .data so that it points to new storage (while keeping the same Variable Python object for auto-grad, module parameters, and the optimizer), and detach() doesn't let us do that (it only gives us access to a Variable's data; it doesn't let us reassign the handle).
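A minimal CPU sketch of the difference:

import torch

x = torch.zeros(4, requires_grad=True)

d = x.detach()
d = d + 1                 # only rebinds the local name d; x is untouched
assert x.eq(0).all()

x.data = torch.ones(2)    # rebinds the storage behind the same Variable object
assert x.shape == (2,) and x.requires_grad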

@ronghanghu
Collaborator Author

ronghanghu commented Mar 3, 2022

A toy use case is as follows -- suppose we want to change a Variable's .data to point to different underlying tensors between the forward and backward passes (this example uses a positive x.data for the forward pass and a negative x.data for the backward pass; it reflects real use cases that do more sophisticated editing):

import torch

# Forward pass has x.data > 0
x = torch.ones(4, requires_grad=True)
loss = torch.abs(x).sum()  # note that `abs` will use `x.data` in its backward pass

# a dummy computation to simulate something involving views (OK on CPU but crashes on XLA)
y = x.view(2, 2)
loss += 0 * y.detach().sum()

# Change Variable x to have `x.data < 0` (of a broadcast-able shape) in backward pass
# This is where we need *edit* access to `x.data`
x.data = -100 * torch.ones(1)
loss.backward()

print(loss)
# prints `tensor(4., grad_fn=<SumBackward0>)` because `x.data` are +1's in `abs` forward pass

print(x.grad)
# prints `tensor([-1., -1., -1., -1.])` because we changed `x.data` to -100's in `abs` backward pass

Despite being a toy example, it reflects many real use cases where one wants to control the auto-grad behavior in PyTorch, such as changing a variable's .data between the forward and backward passes. For example, in Microsoft's ZeRO optimizer, a layer parameter's .data storage is first destroyed after the layer's forward pass (to save memory for the next layer's computation) and then reconstructed before the layer's backward pass. There are many other use cases that involve .data in auto-grad control.

All these applications require us to edit a Variable's .data. Currently this works in PyTorch XLA for most scenarios, except when views are involved.

@ronghanghu
Collaborator Author

ronghanghu commented Mar 3, 2022

There are many important use cases that involve editing a tensor/variable's .data; the examples above are just a few of them.

The expected behavior of a tensor/variable in PyTorch (as on CPU/GPU) is that accessing x is equivalent (except for PyTorch's auto-grad tracing) to accessing the exact version of x.data at the moment x is accessed. So in the example above, calling y = x.view(2, 2) should give the same view as calling y = x.data.view(2, 2) (except that the former allows us to back-propagate to x and put the gradient in x.grad).

Although everything is okay when we don't use views, the IR generation for views in PyTorch XLA deviates from this behavior and therefore breaks many cases involving .data access (including all three examples above from Meta and Microsoft) when applied to a neural network with .view ops.

@JackCaoG
Collaborator

JackCaoG commented Mar 4, 2022

OK, I think I've root-caused the issue.

Using the example

import torch
import torch_xla.core.xla_model as xm

dev = xm.xla_device()
# dev = torch.device("cpu")

a = torch.zeros(4, device=dev)
b = a.view(2, 2)
a.data = a.data + 1
print(b)  # prints 1 on XLA, 0 on CPU
a.data = a.data + 1

calls

void XLATensor::ShallowCopyTo(XLATensor* dest) const {

and then

void XLATensor::SetIrValue(ir::Value ir_value) {

In SetIrValue we have an incorrect assumption:

  if (data()->view != nullptr) {
    // If we have an active view, and a SetIrValue() happens, it means we are
    // within an in-place execution context, and we need to update the view's
    // alias as well.
    data()->view = UpdateView(data()->view, std::move(ir_value));
    data()->generation += 1;
  } 

In this case SetIrValue is called, but we are not doing an in-place operation. I think if we can distinguish between these two cases we can solve this issue. I will try to fix it tomorrow. Not sure if this change (if it works) can make it into 1.11, since the release will happen next Friday.

@ronghanghu
Collaborator Author

Thanks, @JackCaoG! Even if it won't be part of the 1.11 release, it's still very helpful to have it in a nightly version.

@JackCaoG
Collaborator

JackCaoG commented Mar 5, 2022

I think I've fixed the .data issue mentioned here locally; I'm checking whether that fix also resolves the shape issue we discussed in the previous PR. The change is pretty small but very deep in our stack, so I think I will just fix it in the nightly and provide you with a working wheel after I've verified everything works.

@JackCaoG
Collaborator

JackCaoG commented Mar 7, 2022

@ronghanghu All of the shape-related issues also seem to be fixed. Let me clean up the PR and submit it. Hopefully we can get you the nightly wheel with the fix this Wednesday.

@ronghanghu
Collaborator Author

This is very nice! Thanks again @JackCaoG

@ronghanghu
Collaborator Author

ronghanghu commented Apr 14, 2022

@JackCaoG it seems that the .data assignment example above (with a different shape) has broken again and is now breaking the XLA FSDP implementation in #3431.

The previous example

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

x3 = torch.zeros(10, device=device)
y3 = x3.view(-1)
# This should NOT update y3 because `x3.data` is not in-place modified
x3.data = y3[:5] + 1

print(f"y3: {y3}")

was working on nightly 20220408 but is broken on 20220413.

After installing 20220413 builds with

sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch-nightly+20220413-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torchvision-nightly+20220413-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-nightly+20220413-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20220413-py3-none-any.whl

now the example above gives

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.8/dist-packages/torch/_tensor.py", line 655, in __format__
    return object.__format__(self, format_spec)
  File "/usr/local/lib/python3.8/dist-packages/torch/_tensor.py", line 338, in __repr__
    return torch._tensor_str._str(self)
  File "/usr/local/lib/python3.8/dist-packages/torch/_tensor_str.py", line 439, in _str
    return _str_intern(self)
  File "/usr/local/lib/python3.8/dist-packages/torch/_tensor_str.py", line 325, in _str_intern
    self = self.to('cpu')
RuntimeError: INVALID_ARGUMENT: From /job:localservice/replica:0/task:0:
2 root error(s) found.
  (0) INVALID_ARGUMENT: Run-time shape mismatch for XRTExecute argument[0] (4423548877118753). Expected element_type: F32
dimensions: 10
layout {
  minor_to_major: 0
  format: DENSE
  tiles {
    dimensions: 256
  }
}
is_dynamic_dimension: false
; got element_type: F32
dimensions: 10
layout {
  minor_to_major: 0
  format: DENSE
}
is_dynamic_dimension: false

         [[{{node XRTExecute}}]]
         [[XRTExecute_G12]]
  (1) INVALID_ARGUMENT: Run-time shape mismatch for XRTExecute argument[0] (4423548877118753). Expected element_type: F32
dimensions: 10
layout {
  minor_to_major: 0
  format: DENSE
  tiles {
    dimensions: 256
  }
}
is_dynamic_dimension: false
; got element_type: F32
dimensions: 10
layout {
  minor_to_major: 0
  format: DENSE
}
is_dynamic_dimension: false

         [[{{node XRTExecute}}]]
0 successful operations.
0 derived errors ignored.
Recent warning and error logs:
  OP_REQUIRES failed at tpu_execute_op.cc:266 : INVALID_ARGUMENT: Run-time shape mismatch for XRTExecute argument[0] (4423548877118753). Expected element_type: F32
dimensions: 10
layout {
  minor_to_major: 0
  format: DENSE
  tiles {
    dimensions: 256
  }
}
is_dynamic_dimension: false
; got element_type: F32
dimensions: 10
layout {
  minor_to_major: 0
  format: DENSE
}
is_dynamic_dimension: false

@JackCaoG
Collaborator

@ronghanghu Can you open a new issue so I can follow up?

@ronghanghu
Collaborator Author

Thanks, @JackCaoG! I just submitted a new issue to #3502.
