XLM-R model OOM (PyTorch XLA limitations vs TF) #1870
There is a fundamental difference between the PyTorch/XLA and TF/TPU paradigms. PT/TPU builds all the graphs, initializes the weights, runs the input pipelines, etc. on the host VM and then feeds the TPUs, whereas TF/TPU builds the TF graph, converts it into an XLA graph, and hands it over to the TPU to do all the heavy lifting. Also, based on the Kaggle kernel you posted, I assume the SIGKILL was issued due to host RAM OOM (not the memory on the TPU core), though we'd need to check the kernel logs to know for sure. @ifigotin may be working on bumping that limit but I'll let him chime in on that status. |
@jysohn23 Thanks for your reply. So does this difference in paradigm lead to some models not working, either due to OOM problems or other problems? If so, why? I would also note that Kaggle did give me a separate message saying:
So you are probably right; I realize now it is likely VM RAM OOM. How can I reduce the memory usage in this notebook? |
Can you try this?
|
And this:
|
@tmabraham Yeah, PT/TPU sometimes uses more RAM on the GCE VM, whereas TF/TPU uses it on the TPU VM. But as long as you can get a GCE VM with more RAM you should be fine. |
@dlibenzi There is a total of 18 GB of memory:
The output of the second command is |
@jysohn23 Are there any steps I can take to reduce VM RAM usage in my notebook? |
How big are the x_train and y_train tensors? |
@dlibenzi x_train is (435775,128) and y_train is (435775,1). Note the TF TPU kernel uses (435775,192). |
That gets encoded AFAICT. Can you print the final encoded shapes? |
@dlibenzi I didn't understand. This is the tokenized/encoded shape. There are 435775 sentences, and when encoded the representation has 128 tokens per sentence. |
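For context, a shape of (435775, 128) is what fixed-length tokenization to 128 tokens produces. Below is a minimal, hypothetical sketch of how such an x_train could be built, assuming the HuggingFace transformers tokenizer API; the original kernel's exact code is not shown in this thread.

```python
# Hypothetical sketch: produce a (num_sentences, 128) int array of token ids.
# The tokenizer name and the commented-out dataframe column are assumptions.
import numpy as np
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-base')

def encode(sentences, max_len=128):
    enc = tokenizer(
        list(sentences),
        padding='max_length',   # pad every sentence to max_len
        truncation=True,        # cut longer sentences down to max_len
        max_length=max_len,
    )
    return np.asarray(enc['input_ids'], dtype=np.int32)

# x_train = encode(train_df['comment_text'])  # -> shape (435775, 128)
```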
Just wanted to follow up on this... Is there any way this can be fixed? Or is this a limitation of PyTorch XLA? |
Also, I will note that I get similar OOM problems when using bert-base-cased, which is supposed to be a model of the same size. I haven't investigated this thoroughly, so I don't know if it's the same issue or not. |
Hi @tmabraham, sorry for the late reply. As far as I can tell, it's a limitation of the VM that Kaggle provides. Can you try running the exact same code in a Colab notebook? Make a copy of this Colab notebook sample we provide: https://colab.research.google.com/github/pytorch/xla/blob/master/contrib/colab/getting-started.ipynb and paste in your content. |
@jysohn23 Yes I was able to get something that seems to work in Colab.
Are there steps we can take to reduce PT/TPU RAM usage, or is this an inherent limitation of PyTorch XLA? |
To be clear, it's not a limitation of pytorch/xla, but rather an imbalance in the resources that are given out for free. On Kaggle they're granting a couple of CPU cores and a few GB of RAM to feed 8 TPU cores. You'd have the exact same problem if you were given 4 free V100s on Kaggle kernels with only a couple of CPU cores and a few GB of RAM. You can try creating the model before forking processes, as long as it stays read-only, to reduce the memory footprint caused by the model weights. |
PT/TPU vs TF/TPU is not an apples-to-apples comparison as they have different paradigms: #1870 (comment) |
@jysohn23 I understand that, but I guess I was hoping there is still something that could be done to prevent the higher usage of RAM with PyTorch XLA vs TF TPU. I guess the answer is no. I will try to ask Kaggle if they can potentially increase the RAM for the VMs. If not, I will train in Colab. Thanks for the clarification! |
I ran your Kaggle kernel trying to add stuff like: |
We have made a change (which should be on nightly) that lowers the host memory utilization. |
I tried it myself with nightly and it trains: https://www.kaggle.com/davidelibenzi/simple-xlmr-tpu-pytorch The trick is adding `--version nightly` to the PyTorch/XLA install step. |
In that kernel (https://www.kaggle.com/davidelibenzi/simple-xlmr-tpu-pytorch) nprocs is set to 1 instead of 8. I'm running into essentially the same issue in my own Kaggle kernel: BERT Base loads fine without OOM issues, but XLM-RoBERTa-Base does not. I don't even dare try XLM-RoBERTa-Large, which does load on TensorFlow, as can be seen in this kernel: https://www.kaggle.com/xhlulu/jigsaw-tpu-xlm-roberta |
Have tried with 8 as well. With nightly, it trains. |
So, we normally recommend this kind of structure:

```python
def _mp_fn(...):
  model = Net().to(xla_device)
  ...

xmp.spawn(_mp_fn, ..., start_method='fork')
```

But if the VM is RAM starved and the model has a hefty parameter count, the recommended setup is more like:

```python
# Create once at global scope to share pages with child processes.
model = Net()

def _mp_fn(...):
  model.to(xla_device)
  ...

xmp.spawn(_mp_fn, ..., start_method='fork')
```

That, together with |
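As a concrete, self-contained illustration of that second pattern, here is a hedged sketch; the tiny model and the dummy training step are placeholders, not the kernel's actual code.

```python
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

# Placeholder model created once at global scope, before the fork, so its
# weight pages are shared copy-on-write with the child processes.
model = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 2))

def _mp_fn(index):
    device = xm.xla_device()
    model.to(device)                       # move the shared model to this core
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    # Dummy batch, just to show the shape of a training step.
    data = torch.randn(16, 128, device=device)
    target = torch.randint(0, 2, (16,), device=device)

    optimizer.zero_grad()
    loss = nn.functional.cross_entropy(model(data), target)
    loss.backward()
    xm.optimizer_step(optimizer)           # all-reduce gradients, then step

if __name__ == '__main__':
    xmp.spawn(_mp_fn, nprocs=8, start_method='fork')
```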
Oh that's awesome. @dlibenzi can you commit the updated Kaggle Kernel? I don't see --version nightly on the notebook when I click your link. |
Since the input data is pretty tiny here, and given that this Colab is running tight on memory, I suggest not using the data loaders.
|
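The snippet itself was not preserved in this archive. As a rough, hypothetical illustration of the idea of skipping DataLoader workers by caching the encoded inputs once and reading slices directly, here is a sketch using a numpy memmap; the names and binary layout are assumptions, not @dlibenzi's actual code.

```python
import numpy as np
import torch

def build_cache(x_train, cache_path='x_train.int32.bin'):
    # Write the encoded token ids once to a flat binary file.
    np.asarray(x_train, dtype=np.int32).tofile(cache_path)

def batches(cache_path, num_samples, seq_len=128, batch_size=16):
    # Memory-map the cache so host RAM holds roughly one batch at a time,
    # instead of the full dataset plus DataLoader worker copies.
    data = np.memmap(cache_path, dtype=np.int32, mode='r',
                     shape=(num_samples, seq_len))
    for start in range(0, num_samples, batch_size):
        chunk = np.array(data[start:start + batch_size])  # copy just this slice
        yield torch.from_numpy(chunk).long()
```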
Being inspired by @dlibenzi's streaming code idea, and since I am pretty sure it will seem complicated to many people who haven't worked directly with binary streaming (at least it was for me), I have added inline comments wherever necessary and put it up as an inline_commented_gist. I hope it helps everyone who comes across this issue in the future! My question is this, @dlibenzi:
If we don't do bucketing while caching, then we can use
Thanks a lot! |
There is no need to buffer reads, for (at least) two reasons. |
We have added a new API which makes global model sharing and serialized to() less hacky:
|
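The code block for the new API did not survive here. Assuming it refers to `xmp.MpModelWrapper` (which the MNIST colab linked in the next comment uses), a minimal sketch might look like this; the placeholder model is illustrative only.

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

# Wrap a (placeholder) model once at global scope; the wrapper keeps the
# weights in shared memory so the forked processes do not duplicate them.
WRAPPED_MODEL = xmp.MpModelWrapper(torch.nn.Linear(128, 2))

def _mp_fn(index):
    device = xm.xla_device()
    # to() moves the shared model onto this process's XLA device, one process
    # at a time (the "serialized to()" mentioned above).
    model = WRAPPED_MODEL.to(device)
    # ... training loop would go here ...

if __name__ == '__main__':
    xmp.spawn(_mp_fn, nprocs=8, start_method='fork')
```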
@dlibenzi Thanks - does this solve some of the memory issues? |
Yes, it does. https://colab.sandbox.google.com/github/pytorch/xla/blob/master/contrib/colab/mnist-training.ipynb |
See [PR #2107](#2107) and the [colab](https://colab.sandbox.google.com/github/pytorch/xla/blob/master/contrib/colab/resnet18-training.ipynb).

> On Mon, May 25, 2020, Philipp Singer wrote: @dlibenzi Thanks - is there any example on how to use it? |
Thanks @dlibenzi & @taylanbil - this is really helpful. Unfortunately, I am still struggling with saving the model weights in the case of large models. The model is pulled into host memory when doing so and the kernel (on Kaggle) fails. I think I asked this before, but is there any way to optimize or circumvent that? |
You are using |
Yes exactly. |
@dlibenzi, although super sub-optimal, could we do this?
etc |
No. |
I thought we send data to cpu from all devices, and only save from master after we send. https://github.com/pytorch/xla/blob/master/torch_xla/core/xla_model.py#L631 Is that not true? |
No. We sync from all devices. Sync leaves the data on device. Then we fetch the CPU data from the master in order to call torch.save(). |
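For reference, the checkpoint call being discussed is presumably `xm.save()`; a minimal, hedged sketch of how it is typically invoked from every process (model and path names illustrative):

```python
import torch_xla.core.xla_model as xm

def save_checkpoint(model, path='xlmr_checkpoint.pt'):
    # All processes must call xm.save(): it synchronizes the devices first,
    # then only the master process pulls the tensors to CPU and writes `path`.
    xm.save(model.state_dict(), path)
```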
I see, thanks for the info. |
So the memory issue with the PyTorch serialization comes from the fact that not only must all the CPU tensors be loaded in host memory at the same time, but PyTorch also uses a memory buffer to store them. I have created #2140, which streams tensors to CPU (and then to file) one at a time. |
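Assuming the module from that PR is what later shipped as `torch_xla.utils.serialization`, usage would look roughly like the sketch below; the checkpoint path is illustrative.

```python
import torch_xla.utils.serialization as xser

def save_checkpoint_streamed(model, path='xlmr_checkpoint.pt'):
    # Streams one tensor at a time to CPU and then to disk, instead of
    # materializing the whole state_dict in host memory at once.
    xser.save(model.state_dict(), path)

def load_checkpoint_streamed(model, path='xlmr_checkpoint.pt'):
    model.load_state_dict(xser.load(path))
    return model
```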
@dlibenzi Thanks again for this! Just tried it and I get the following error: I am trying to save the model within the |
I was envisioning the API explicitly failing on existing checkpoints, but I realized this is not the normal torch.save() behavior. |
May I ask how you fixed this? |
Made the serialization module overwrite an existing checkpoint, like torch.save() would. |
hey @dlibenzi |
I am trying to train an XLM-R model in Kaggle Kernels with TPU enabled. There was a TF kernel that was able to do this successfully:
https://www.kaggle.com/xhlulu/jigsaw-tpu-xlm-roberta
However, attempts to train a similar model with PyTorch XLA have not been successful due to OOM errors. I tried to keep the code as similar as possible and made sure all non-XLA variables (dataset, model, etc.) were defined globally so they weren't replicated 8 times. I am actually using a smaller version of the model (base vs. large) and much lower batch sizes. I even tried using the multi-threading interface (which is apparently now deprecated), as I read that multi-threading uses less memory. In all cases I get OOM errors. In most cases, it will load the model and run the forward pass, but fail when calculating the loss; in some cases, it fails at loss.backward().
I have two questions related to this: