XLM-R model OOM (PyTorch XLA limitations vs TF) #1870

Closed
tmabraham opened this issue Apr 3, 2020 · 146 comments

@tmabraham

I am trying to train an XLM-R model in Kaggle Kernels with TPU enabled. There was a TF kernel that was able to do this successfully:
https://www.kaggle.com/xhlulu/jigsaw-tpu-xlm-roberta

However, attempts to train a similar model with PyTorch XLA have not been successful due to OOM errors. I tried to keep the code as similar as possible and made sure all non-XLA variables (dataset, model, etc.) were defined globally so they weren't replicated 8 times. I am actually using a smaller version of the model (base vs. large) and much lower batch sizes. I even tried the multi-threading interface (which is apparently now deprecated), as I read that multi-threading uses less memory. In all cases I get OOM errors. In most cases, it loads the model and runs the forward pass, but fails when calculating the loss. In some cases, it fails at loss.backward().

I have two questions related to this:

  1. Is there a way to get PyTorch XLA to work with this model in Kaggle TPU kernels, as was possible with TF?
  2. What are the limitations of PyTorch XLA usage on TPU compared to TF usage on TPU? Are there certain models that cannot be used?

@jysohn23 (Collaborator) commented Apr 3, 2020

There is a fundamental difference between the PyTorch/XLA and TF/TPU paradigms. PT/TPU builds all the graphs, initializes the weights, runs the input pipelines, etc. on the host VM and then feeds the TPUs, whereas TF/TPU builds the TF graph, converts it into an XLA graph, and hands it over to the TPU to do all the heavy lifting.

Also, based on the Kaggle Kernel you posted, I assume the SIGKILL was issued due to host RAM OOM (not memory on the TPU cores), though we'd need to check the kernel logs to know for sure. @ifigotin may be working on bumping that limit, but I'll let him chime in on the status.

@tmabraham (Author) commented Apr 3, 2020

@jysohn23 Thanks for your reply. So does this difference in paradigm lead to some models not working, either due to OOM problems or other problems? If so, why?

I would also note that Kaggle did give me a separate message saying:

Your notebook tried to allocate more memory than is available.

So you are probably right; I realize now it is probably VM RAM OOM. How can I reduce the memory usage in this notebook?

@dlibenzi (Collaborator) commented Apr 3, 2020

Can you try this?

!free -h

@dlibenzi (Collaborator) commented Apr 3, 2020

And this:

!cat /proc/cpuinfo | grep processor | wc -l

@jysohn23 (Collaborator) commented Apr 3, 2020

@tmabraham Yeah, PT/TPU sometimes uses more RAM on the GCE VM, whereas TF/TPU uses memory on the TPU side. But as long as you can get a GCE VM with more RAM you should be fine.

@tmabraham (Author)

@dlibenzi There is a total of 18 GB of memory:

              total        used        free      shared  buff/cache   available
Mem:            18G        866M         12G        1.0M        5.5G         17G
Swap:            0B          0B          0B

The output of the second command is 4

@tmabraham (Author)

@jysohn23 Are there any steps I can take to reduce VM RAM usage in my notebook?

@dlibenzi (Collaborator) commented Apr 3, 2020

How big are the x_train and x_valid tensors?

@tmabraham (Author)

@dlibenzi x_train is (435775,128) and y_train is (435775,1). Note the TF TPU kernel uses (435775,192).

@dlibenzi (Collaborator) commented Apr 3, 2020

That gets encoded AFAICT. Can you print the final x_train shape?

@tmabraham (Author)

@dlibenzi I didn't understand. This is the tokenized/encoded shape. There are 435775 sentences and when encoded the representation has max_len equal to 128. This is the shape of the data that I use when creating my TensorDataset.

@tmabraham (Author)

Just wanted to follow up on this... Is there any way this can be fixed? Or is this a limitation of PyTorch XLA?

@tmabraham (Author)

Also, I will note that I get similar OOM problems when using bert-base-cased, which is supposed to be a model of the same size. I haven't investigated this thoroughly, so I don't know if it's the same issue or not.

@tmabraham (Author)

@jysohn23 @dlibenzi Sorry to keep bothering you. Just wanted to see if this is a known issue or if there is something I am supposed to do in my environment to prevent this issue.

@jysohn23 (Collaborator) commented Apr 7, 2020

Hi @tmabraham, sorry for the late reply. As far as I can tell, it's a limitation of the VM that Kaggle provides. Can you try running the exact same thing in a Colab notebook? Make a copy of this Colab notebook sample we provide: https://colab.research.google.com/github/pytorch/xla/blob/master/contrib/colab/getting-started.ipynb and paste in your content.

@tmabraham (Author)

@jysohn23 Yes I was able to get something that seems to work in Colab.

> @tmabraham Yeah, PT/TPU sometimes uses more RAM on the GCE VM, whereas TF/TPU uses memory on the TPU side. But as long as you can get a GCE VM with more RAM you should be fine.

Are there steps we can take to reduce PT/TPU RAM usage, or is this an inherent limitation of PyTorch XLA?

@dlibenzi (Collaborator) commented Apr 7, 2020

So I ran your Kaggle notebook, and after tokenization, there are about 3GB left.


There are 11GB buffer cache, but the dataset seems pretty big.

@jysohn23 (Collaborator) commented Apr 7, 2020

To be clear, it's not a limitation of pytorch/xla, but rather an imbalance in the resources that are given out for free. On Kaggle they grant a couple of CPU cores and a few GB of RAM to feed 8 TPU cores. You'd have the exact same problem if you were given 4 free V100s on Kaggle kernels with only a couple of CPU cores and a few GB of RAM. You can try creating the model before forking the processes, as long as it is treated as read-only, to reduce the memory footprint caused by the model weights.

@tmabraham (Author) commented Apr 7, 2020

@jysohn23 I still think this could be a limitation of PyTorch XLA because there is a TF kernel that works in Kaggle Kernels. Maybe there are some optimizations in TF that are not possible in PyTorch XLA?

I am creating the model before forking the processes.

@jysohn23 (Collaborator) commented Apr 7, 2020

PT/TPU vs TF/TPU is not an apples-to-apples comparison as they have different paradigms: #1870 (comment)

@tmabraham (Author) commented Apr 7, 2020

@jysohn23 I understand that, but I guess I was hoping there was still something that could be done to reduce the higher RAM usage of PyTorch XLA compared to TF on TPU. I guess the answer is no.

I will try to ask Kaggle if they can potentially increase the RAM for the VMs. If not, I will train in Colab. Thanks for the clarification!

@jysohn23 (Collaborator) commented Apr 8, 2020

I ran your Kaggle kernel and tried adding things like del df_train, df_valid, which freed up about 500 MB, but it still looks like it runs out of memory right as we do model = mx.to(xm.xla_device()) in the xmp context, since at that point we create 8 copies of the model when sending it to the device 😞.
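
For what it's worth, a minimal sketch of that kind of host-RAM trimming (df_train and df_valid are the DataFrame names used in the kernel; treat the exact names as placeholders):

import gc

# Drop the raw DataFrames once the tokenized arrays have been built ...
del df_train, df_valid
# ... and trigger a garbage collection so the objects are actually freed.
gc.collect()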

@dlibenzi (Collaborator)

We have made a change (which should be on nightly) that lowers the host memory utilization.
@tmabraham mind giving it a try on nightly?

@dlibenzi (Collaborator) commented Apr 10, 2020

I tried myself with nightly and it trains:

https://www.kaggle.com/davidelibenzi/simple-xlmr-tpu-pytorch

The trick is adding --version nightly to the env-setup script.
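
For reference, the setup cell then looks like this (the same commands are quoted later in this thread; the extra --apt-packages arguments used there are optional):

!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version nightly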


@Jerry2001Qu

In that kernel: https://www.kaggle.com/davidelibenzi/simple-xlmr-tpu-pytorch

nprocs is set to 1 instead of 8.

I'm running into essentially the same issue in my own Kaggle kernel. Bert Base loads fine without OOM issues; however, XLM-RoBERTa-Base does not. I don't even dare try XLM-RoBERTa-Large, which TensorFlow is able to load, as can be seen in this kernel: https://www.kaggle.com/xhlulu/jigsaw-tpu-xlm-roberta

@dlibenzi (Collaborator) commented Apr 10, 2020

Have tried with 8 as well. With nightly, it trains.
Rerunning now ...

@dlibenzi (Collaborator) commented Apr 10, 2020

So, we normally recommend this kind of structure:

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(...):
  model = Net().to(xm.xla_device())
  ...

xmp.spawn(_mp_fn, ..., start_method='fork')

But if the VM is RAM-starved and the model has a hefty parameter count, the recommended setup is more like:

# Create once at global scope to share pages with child processes.
model = Net()

def _mp_fn(...):
  model.to(xm.xla_device())
  ...

xmp.spawn(_mp_fn, ..., start_method='fork')

That, together with fork(2), makes sure there will be only one copy of the model's parameters in PyTorch CPU host memory, shared copy-on-write with the child processes.

@Jerry2001Qu

Oh that's awesome. @dlibenzi can you commit the updated Kaggle Kernel? I don't see --version nightly on the notebook when I click your link.

@dlibenzi (Collaborator)

Since the input data is pretty tiny here, and given that this Colab is running tight on memory, I suggest not using the data loaders.
I honestly suggest you take my notebook as a baseline.
There are two major things there:

  1. Loading the model onto the device in a serialized fashion (see the sketch after this list)
  2. Creating a file-indexed dataset to minimize the initial footprint
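
As a rough illustration of point 1, serialized loading can be approximated by letting each process send its model to the device in turn (a sketch only, built from public xm helpers; the notebook itself may do this differently):

import torch_xla.core.xla_model as xm

def to_device_serialized(model):
    device = xm.xla_device()
    for rank in range(xm.xrt_world_size()):
        if xm.get_ordinal() == rank:
            # Only one process copies its weights to the TPU core at a time,
            model = model.to(device)
        # while the others wait for their turn at the barrier.
        xm.rendezvous('serialized_to_device_{}'.format(rank))
    return model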

@AdityaSoni19031997 commented Apr 21, 2020

Inspired by @dlibenzi's streaming code idea: I am pretty sure it will seem complicated to many people who haven't worked directly with binary streaming (at least for me it was), so I have added inline comments wherever necessary and put it up as an inline_commented_gist.

I hope it helps everyone who comes across this issue in the future!

My question is this, @dlibenzi:
We can stream data directly using numpy's memmap as well, or using lmdb. Is there any benefit of the BytesIO streaming that those approaches lack?

  • One obvious limitation is that with memmap you cannot store tensors of different shapes; anything else in your experience?

If we don't do bucketing while caching, then we can wrap a memmap directly in a torch.utils.data.Dataset as shown below (and it should be as optimal as the BytesIO version, right?):

import numpy as np
import torch


class MyStreamingDataset(torch.utils.data.Dataset):
    # np.array(x_train).shape -> (435712, 3)

    def __init__(self, data_memmap, target_memmap, shape=()):
        self.data = np.memmap(data_memmap, shape=shape, mode="r", dtype="int32")
        self.target = np.memmap(target_memmap, shape=(shape[1],), mode="r", dtype="int32")
        self.shape = shape

    def __len__(self):
        return self.shape[1]

    def __getitem__(self, idx):
        # The mem-map contains input_ids, masks, targets in that index order.
        return np.array(self.data[0][idx]), np.array(self.data[1][idx]), np.array(self.target[idx])

Thanks a lot!

@dlibenzi (Collaborator)

There is no need to buffer reads, for (at least) two reasons.
First, there are at least two buffering layers underneath already: the Python one and the OS one.
Second, with shuffling (which defaults to true), read offsets are all over the place, so there is very little locality for caching to exploit.

@dlibenzi (Collaborator) commented May 4, 2020

We have added a new API which makes global model sharing and serialized to() less hacky:

WRAPPED_MODEL = MpModelWrapper(MyNetwork())
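
A minimal usage sketch, assuming the wrapper is exposed through torch_xla.distributed.xla_multiprocessing as in the Colab linked a few comments below (MyNetwork is a placeholder):

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

# Wrap the model once at global scope so the fork()ed children share the host copy.
WRAPPED_MODEL = xmp.MpModelWrapper(MyNetwork())

def _mp_fn(index):
    device = xm.xla_device()
    # to() moves the shared model to each TPU core without every process
    # doing the transfer at the same time.
    model = WRAPPED_MODEL.to(device)
    ...

xmp.spawn(_mp_fn, args=(), start_method='fork')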

@psinger commented May 25, 2020

@dlibenzi Thanks - does this solve some of the memory issues?

@dlibenzi (Collaborator)

Yes, it does.
It avoids all processes doing the send-to-device at the same time.
We have also updated the Colab with the example usage:

https://colab.sandbox.google.com/github/pytorch/xla/blob/master/contrib/colab/mnist-training.ipynb

@taylanbil (Collaborator) commented May 25, 2020 via email

@psinger commented May 27, 2020

Thanks @dlibenzi & @taylanbil - this is really helpful.

Unfortunately, I am still struggling with saving the model weights for large models. The model is pulled into host memory when doing so and the kernel (on Kaggle) fails.

I think I asked that before, but is there any way to optimize / circumvent that?

@dlibenzi (Collaborator)

You are using xm.save() right?

@psinger commented May 28, 2020

Yes exactly.

@taylanbil (Collaborator)

@dlibenzi, although super sub-optimal, could we do this?

  • xm.save from master (ordinal 0) while the others sit in rendezvous-1
  • ordinal 0 reaches rv-1
  • ordinal 1 xm.saves while ordinals 0, 2-7 sit in rv-2
  • ordinal 1 reaches rv-2
  • ordinal 2 xm.saves while the others sit in rv-3
  • ...

etc.

@dlibenzi (Collaborator)

No.
We already have only ordinal 0 actually fetch device data to CPU tensors, and save.
The issue is that at that point the memory is already low, and even if only one process fetches the tensors to CPU, it OOMs.

@taylanbil (Collaborator)

I thought we send data to CPU from all devices, and only save from the master after we send.

https://github.com/pytorch/xla/blob/master/torch_xla/core/xla_model.py#L631

Is that not true?

@dlibenzi (Collaborator)

No.
We sync from all devices. Sync leaves data on device.
Then we fetch CPU data from master in order to call torch.save().
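
An illustrative paraphrase of that flow (not the actual xm.save() implementation; see xla_model.py for the real code):

import torch
import torch_xla.core.xla_model as xm

def save_from_master(state_dict, path):
    xm.mark_step()                 # flush pending work; tensors stay on the device
    xm.rendezvous('pre_save')      # wait until every core is done
    if xm.is_master_ordinal():     # only ordinal 0 fetches tensors to CPU and saves
        cpu_state = {k: v.cpu() for k, v in state_dict.items()}
        torch.save(cpu_state, path)
    xm.rendezvous('post_save')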

@taylanbil (Collaborator) commented May 28, 2020 via email

@dlibenzi (Collaborator)

So the memory issue with the PyTorch serialization comes from the fact that not only must all the CPU tensors be loaded into host memory at the same time, but PyTorch also uses a memory buffer to store them.
This effectively doubles the memory required:

https://github.com/pytorch/pytorch/blob/e029d678b63d8970d92e7d9713af74eb0ea69ad8/torch/serialization.py#L462

I have created #2140, which streams tensors to CPU (and then to file) one at a time.
But this requires using the matching load() API.
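
If I read the change right, usage would look roughly like this (the module path torch_xla.utils.serialization is an assumption here; double-check against #2140):

import torch_xla.utils.serialization as xser

# Save: tensors are streamed to CPU and written to file one at a time.
xser.save(model.state_dict(), 'model.bin')

# Load: must go through the matching load(), not torch.load().
state_dict = xser.load('model.bin')
model.load_state_dict(state_dict)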

@psinger commented Jun 5, 2020

@dlibenzi Thanks again for this!

Just tried it and I get the following error:
FileExistsError: [Errno 17] File exists: 'model.bin.tensors'

I am trying to save the model in the _run method, at the same spot where I would call xm.save.

@dlibenzi (Collaborator) commented Jun 5, 2020

I was envisioning the API to explicitly fail on existing checkpoints, but I realized this is not the normal torch.save() behavior.
Let me fix that ...

@dlibenzi (Collaborator) commented Jun 5, 2020

#2173

@zcain117 added the kaggle label Jul 1, 2020
@garyongguanjie

> @dlibenzi Thanks again for this!
>
> Just tried it and I get the following error:
> FileExistsError: [Errno 17] File exists: 'model.bin.tensors'
>
> I am trying to save the model in the _run method, at the same spot where I would call xm.save.

May I ask how you fixed this?

@dlibenzi (Collaborator)

Made the serialization module overwrite an existing checkpoint, like torch.save() would.

@kbrajwani

Hey @dlibenzi,
I was trying to train a GPT large model on Kaggle. It runs fine, but I am getting memory issues. When I try to track down where the memory goes, it shows me that running this script:
! curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
! python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev
takes 3 GB of RAM.
So my question is: how can I free that RAM?
