
High memory usage with bucketing #5035

Closed
tdomhan opened this issue Feb 15, 2017 · 21 comments

Comments

@tdomhan
Contributor

tdomhan commented Feb 15, 2017

When using the bucketing module I'd expect the memory usage to be about the same as when using the normal module unrolled to the largest bucket size. However, we observe unusually high GPU memory usage in MXNet when using multiple buckets.
This can be reproduced with the lstm_bucketing.py example from the latest MXNet commit as follows:
in examples/rnn/lstm_bucketing.py change:

num-layers to 4
num-hidden to 1024
num-embed to 512

When using multiple buckets (see line 49), overall memory usage is 1419MB.
When changing line 49 to only use a single bucket (e.g. 60), overall memory usage is only 1185MB.
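For concreteness, the change on line 49 of the example amounts to something like the following (assuming the example's default bucket list of [10, 20, 30, 40, 50, 60]; treat this as a sketch of the reproduction, not an exact diff):

```python
# examples/rnn/lstm_bucketing.py, around line 49 (sketch)

# multiple buckets: overall GPU memory usage grows to ~1419 MB after a few batches
buckets = [10, 20, 30, 40, 50, 60]

# single bucket: overall GPU memory usage stays at ~1185 MB
# buckets = [60]
```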

It should be noted that the initial memory usage for bucketing is the same (1185MB), but after a couple of batches the memory usage increases. We suspect this is due to the BucketingModule binding another sub module when a new bucket size is given by the data iterator and memory sharing across modules isn't working properly.

While for this model the difference is only about 230 MB, we observed much larger differences in practice, making it difficult to train any reasonably sized model with bucketing.

Note: the default bucket key is of course the largest bucket.

@tdomhan tdomhan changed the title Increased memory usage with bucketing High memory usage with bucketing Feb 15, 2017
@tdomhan
Contributor Author

tdomhan commented Feb 15, 2017

Diving a little deeper into this issue. This is my current understanding of memory sharing:

  • Python land: BucketingModule
    • BucketingModule has a _curr_module member that is passed to module.bind
    • initially this member is the default module, but each call to self.switch_bucket sets it to a different module
  • C++ land: GraphExecutor class
    • GraphExecutor::Init gets a shared_exec, which is the executor of the shared module passed to module.bind
    • GraphExecutor::InitDataEntryMemory gets the data_pool_ member of the shared executor
    • data_pool_ is a vector of ndarrays
    • after GraphExecutor::InitDataEntryMemory is called, data_pool_ will contain all ndarrays borrowed from the shared executor and potentially also new ndarrays it created. Important: since not all shapes necessarily fit, it might have fewer entries than the original data_pool_ coming from the shared executor.

Given the above, one potential problem occurs when we first see a small bucket and subsequently a larger one. In this scenario _curr_module in BucketingModule points to the small bucket's module, whose data_pool_ is smaller than the data_pool_ of the default bucket module.

Given this, I tried to modify BucketingModule to always pass the default module as the shared module (my assumption being that the default module will always occupy more space than any other module). This actually works, but only when turning off all memory optimizations (NNVM_EXEC_ENABLE_INPLACE=false NNVM_EXEC_MATCH_RANGE=0). If these are not turned off, we still allocate new memory in InitDataEntryMemory, as the shapes across buckets no longer match up.
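For illustration, a minimal sketch of that modification, following the Module/BucketingModule API of that era (attribute and parameter names are paraphrased assumptions, not the exact patch):

```python
from mxnet.module import Module

# Sketch: the essential change is which module is passed as shared_module
# when a previously unseen bucket key is bound.
def switch_bucket(self, bucket_key, data_shapes, label_shapes=None):
    if bucket_key not in self._buckets:
        symbol, data_names, label_names = self._sym_gen(bucket_key)
        module = Module(symbol, data_names, label_names,
                        logger=self.logger, context=self._context)
        module.bind(data_shapes, label_shapes, for_training=True,
                    # before: shared_module=self._curr_module
                    # after:  always share with the default (largest) bucket module
                    shared_module=self._buckets[self._default_bucket_key])
        self._buckets[bucket_key] = module
    self._curr_module = self._buckets[bucket_key]
```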

@piiswrong
Contributor

piiswrong commented Feb 15, 2017

Thanks for the analysis. Haibin will be working on this.

In the meantime, the cuDNN bucketing example should work better since fewer memory blobs are allocated: https://github.com/dmlc/mxnet/pull/5004/files

@eric-haibin-lin

also see #4795

@feiyulv

feiyulv commented Feb 16, 2017

@tdomhan I have met the same problem. In the nnvm version, even when the default bucket key is the largest, some memory is still allocated for new, smaller keys (maybe caused by the memory planning strategy), so memory increases after some batches.
My current workaround is:

  1. set the GPU storage to NaiveStorage
  2. bind a new executor for each key and release the previous one; only keep the executor for the largest key (a sketch follows below)
    @piiswrong @eric-haibin-lin
    Hope this problem can be solved soon
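A minimal sketch of workaround 2, again against the Module API of the time (the method name and the _tmp_module/_default_module bookkeeping are illustrative assumptions, not feiyulv's actual code):

```python
from mxnet.module import Module

# Keep only the default (largest) bucket's module; bind a throwaway module for
# any other key and drop the previous throwaway so its memory can be released.
def module_for_key(self, bucket_key, data_shapes, label_shapes=None):
    if bucket_key == self._default_bucket_key:
        return self._default_module
    self._tmp_module = None   # release the previously bound small-bucket module
    symbol, data_names, label_names = self._sym_gen(bucket_key)
    module = Module(symbol, data_names, label_names, context=self._context)
    module.bind(data_shapes, label_shapes, for_training=True,
                shared_module=self._default_module)
    self._tmp_module = module
    return module
```

With NaiveStorage, dropping the temporary module actually returns its memory to the driver; with PooledStorage the freed blocks stay in the pool and, as discussed further down, can't necessarily be reused by the next, differently planned executor.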

@tdomhan
Contributor Author

tdomhan commented Feb 16, 2017

I'm glad I'm not the only one who has this issue. I still need to check whether this problem also exists in 0.8. One potentially interesting thing is that in 0.8 the allocator has access to shared_mem_, whereas in nnvm it doesn't when planning the memory.

@feiyulv

feiyulv commented Feb 16, 2017

The problem still exists in the non-nnvm version, but the increase is not very obvious.

@tdomhan
Contributor Author

tdomhan commented Feb 20, 2017

Here are the counts of different ndarray sizes in the default bucket's memory pool (for some arbitrary model and not the RNN example, but that shouldn't change the point):

        {0: 4,
         256: 1,
         512: 4,
         1024: 2,
         16384: 836,
         30720: 121,
         32768: 542,
         49152: 59,
         65536: 179,
         131072: 59,
         983040: 2,
         1966080: 5,
         30842880: 2}

versus the ones from a different, smaller bucket:

        {0: 3,
         256: 1,
         512: 5,
         10240: 38,
         16384: 279,
         32768: 160,
         49152: 19,
         65536: 40,
         131072: 40,
         327680: 3,
         655360: 23,
         10280960: 2}

The numbers mean, for example, that for the default bucket we got 1 ndarray of size 256 bytes. One can see that the two pools are not really compatible with each other: for example, the 23 arrays of 655360 bytes can't all be fit into the ndarrays of the default bucket. This is probably because memory planning is done independently per bucket, leading to incompatible sets of ndarrays.
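A toy illustration of that incompatibility, unrelated to MXNet's internals (the greedy best-fit rule here is a simplification of what InitDataEntryMemory does; zero-size entries are dropped):

```python
from collections import Counter

def extra_allocation(shared_pool, requested):
    """How many bytes must be newly allocated when serving `requested`
    (size -> count) from `shared_pool` (size -> count), using best fit?"""
    free = sorted(Counter(shared_pool).elements())        # ascending sizes
    extra = 0
    for size in sorted(Counter(requested).elements(), reverse=True):
        fitting = [b for b in free if b >= size]
        if fitting:
            free.remove(fitting[0])   # reuse the smallest block that is big enough
        else:
            extra += size             # nothing in the pool fits -> new allocation
    return extra

default_pool = {256: 1, 512: 4, 1024: 2, 16384: 836, 30720: 121, 32768: 542,
                49152: 59, 65536: 179, 131072: 59, 983040: 2, 1966080: 5,
                30842880: 2}
small_bucket = {256: 1, 512: 5, 10240: 38, 16384: 279, 32768: 160, 49152: 19,
                65536: 40, 131072: 40, 327680: 3, 655360: 23, 10280960: 2}
print(extra_allocation(default_pool, small_bucket))  # non-zero: the plans don't line up
```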

@eric-haibin-lin
Member

@tdomhan Thanks for the analysis. The problem with bucketing is that for each bucket we currently first plan its memory unaware of the shared memory pool. After the memory is planned, we try to reuse what's available in the shared pool. Because the graphs of a big bucket and a small bucket don't necessarily produce the same memory plan, there is no guarantee that no extra memory is allocated. I'm working on a fix.

@feiyulv

feiyulv commented Feb 21, 2017

Setting the storage type to naive and not keeping an executor for each bucket key helps. The memory won't increase, though it floats within a small range. @tdomhan
@eric-haibin-lin I don't understand the purpose of PooledStorage, since the shared executor is already some kind of pool. Is it useful for the temp space allocated during the forward/backward pass?

@tdomhan
Contributor Author

tdomhan commented Feb 21, 2017

As the naive storage manager will do a cudaMalloc for each ndarray created (unlike PooledStorage), I'm guessing this will decrease speed quite a bit. I don't have any numbers on this though. What's your experience with that?
Why aren't you using PooledStorage in combination with creating a new executor? That should also work while avoiding the cudaMallocs. You would still have the overhead of creating a new executor per batch. Ideally we'd be able to fix the sharing of memory between different graphs, to avoid both the overhead of repeatedly creating executors and of repeatedly calling cudaMalloc.

From my understanding, PooledStorage is useful for use cases where you create many short-lived executors (e.g. minpy or your use case), in order to reuse memory from previously created ndarrays. For memory sharing between different buckets, the ndarrays haven't been deallocated yet, but we still want to share them between different graphs.
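For reference, a minimal conceptual sketch of a size-pooled allocator of the kind being discussed (this is the general idea, not MXNet's actual PooledStorageManager; raw_alloc/raw_free stand in for cudaMalloc/cudaFree):

```python
from collections import defaultdict

class PooledAllocator:
    """Toy size-bucketed pool: free() caches blocks, alloc() reuses them."""
    def __init__(self, raw_alloc, raw_free):
        self._raw_alloc, self._raw_free = raw_alloc, raw_free
        self._free = defaultdict(list)   # size -> [handles]

    def alloc(self, size):
        if self._free[size]:             # reuse a previously freed block of this size
            return self._free[size].pop()
        return self._raw_alloc(size)     # otherwise fall back to the driver

    def free(self, handle, size):
        self._free[size].append(handle)  # keep it for later, don't release

    def release_all(self):
        for handles in self._free.values():
            for h in handles:
                self._raw_free(h)
        self._free.clear()
```

Blocks freed into such a pool only help later requests of the same size, which is consistent with the observation below that PooledStorage plus a fresh executor per bucket keeps growing memory.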

@eric-haibin-lin
Member

@feiyulv Regarding PooledStorage, I believe its purpose is to avoid an excessive number of cudaMalloc calls for the NDArrays created. We should try to reuse NDArrays already malloc'ed in the pool, if any. Excessive cudaMallocs could lead to much slower runtime performance, depending on the workload.

@feiyulv

feiyulv commented Feb 22, 2017

@eric-haibin-lin @tdomhan thanks
@tdomhan My experience is that NaiveStorage affects speed only a little: since we use the default shared executor for the different keys, only a small amount of memory is allocated for a new executor. PooledStorage with a new executor won't work. The reason memory keeps increasing is that memory allocated by the new executor through PooledStorage can't be released, and can't be reused by future executors.

@eric-haibin-lin
Member

eric-haibin-lin commented Feb 24, 2017

Found some inefficiencies in the system and made a few changes related to memory allocation in MXNet:

  1. In the bucketing module, curr_module is passed in as the shared_module, but the module with the default_bucket_key should be passed instead. Link.
  2. When the memory pool of the default bucket module doesn't hold sufficient memory for another bucket to bind, extra NDArrays are allocated, but these NDArrays are not added back to the pool. Link
  3. In the PlanMemory pass in nnvm, the memory allocator is not aware of the already-allocated memory pool of the shared module. Link
  4. The NNVM_EXEC_MATCH_RANGE variable has an impact on the result of memory planning. Instead of letting the user choose it, the backend could just try different values and pick the best one, automating the process (a sketch of this idea follows below). Some users are not even aware of this variable. Link

Fixing 1 and 2 reduces memory usage quite a lot, while 3 and 4 bring a marginal further reduction (5% ~ 10%) once 1 and 2 are fixed.
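For item 4, a hypothetical sketch of the auto-search idea (plan_memory and its return value are stand-ins for the nnvm PlanMemory pass, not its real API): sweep candidate match-range values and keep the plan with the smallest total allocation.

```python
# Hypothetical sketch of item 4: instead of relying on a user-set
# NNVM_EXEC_MATCH_RANGE, try a few candidate values and keep whichever
# memory plan allocates the fewest bytes in total.
def best_memory_plan(graph, plan_memory, candidates=(1, 2, 4, 8, 16, 32)):
    best_plan, best_bytes = None, float("inf")
    for match_range in candidates:
        plan = plan_memory(graph, match_range)   # assumed planning callable
        total = plan.total_allocated_bytes       # assumed attribute on the plan
        if total < best_bytes:
            best_plan, best_bytes = plan, total
    return best_plan
```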

Benchmark result on LSTM workload:

| Version     | 22673b6 (baseline)       | 1    | 1 + 2 | 1 + 2 + 3 | 1 + 2 + 3 + 4 |
|-------------|--------------------------|------|-------|-----------|---------------|
| Memory (MB) | > 12288 (Out of Memory)  | 4297 | 1956  | 1817      | 1754          |

Benchmark result on Neural Style workload

| Version     | 22673b6 (baseline)      | 1 + 2 + 3 + 4 |
|-------------|-------------------------|---------------|
| Memory (MB) | > 12288 (Out of Memory) | 2034          |
  • LSTM configuration: python lstm_bucketing.py --gpus=0 --num-layers=4 --num-hidden=1024 --num-embed=512 --num-epochs=1
  • Neural style uses default configuration

@whaozl @tdomhan @piiswrong @tqchen @mli

eric-haibin-lin added a commit to eric-haibin-lin/mxnet that referenced this issue Feb 24, 2017
Always pass in the bucket with default_bucket_key as the shared_module while binding new buckets

Imbalance version of shared pool during plan memory

Auto search and updated shared mem pool

Cleanup unused code

Fix compilation error
@piiswrong
Contributor

How long does 4 take?

@whaozl

whaozl commented Feb 24, 2017

@eric-haibin-lin Is the bug fixed? In which version? I'll check the version you mentioned and take a look.

@eric-haibin-lin
Member

eric-haibin-lin commented Feb 24, 2017

@piiswrong
I profiled the total time spent on switch_bucket with 4 turned on. It introduces less than 1 second overhead per epoch for this LSTM network. The total training time per epoch is around 5 minutes.

@whaozl I'll add one more test case before merging this in. Will keep you posted.

@tdomhan
Contributor Author

tdomhan commented Feb 26, 2017

The memory planning is only run once per bucket, no? So the overhead shouldn't be per epoch but just a fixed time at the beginning of training. I assume that for a use case like minpy or any other use case where you dynamically construct a graph the overhead should be much higher. So does it take about 1s to do the sweep?

@piiswrong
Contributor

If you are building graphs for trees for parsing, then you can have a different graph for every batch of data. In that case 1/40 s of per-bucket overhead is non-negligible.

I think a better strategy is to run the extra graph optimization for the first 100-1000 binds and then turn it off.

@tdomhan
Contributor Author

tdomhan commented Feb 27, 2017

One other problem, besides those listed in 1-4, is the order of allocations in GraphExecutor::InitDataEntryMemory. Arrays were allocated in the order they are encountered in the graph. This could lead to situations where large ndarrays from the shared pool were used for much smaller ndarrays. Allocating the largest ndarrays first further reduces memory consumption. Here's a PR with the change:
#5161
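A toy model of that ordering effect, in plain Python rather than the C++ of InitDataEntryMemory (the first-fit reuse rule here is an assumption made for illustration): processing entries in graph order lets a small entry claim a big shared block that a later large entry then misses, while processing largest-first avoids the extra allocation.

```python
def extra_bytes(required_sizes, pool_sizes, largest_first=True):
    """Toy model of reusing a shared pool with a first-fit rule.
    `required_sizes` is in graph-encounter order; the fix processes it
    largest-first instead. Returns bytes that must be newly allocated."""
    free = list(pool_sizes)
    order = sorted(required_sizes, reverse=True) if largest_first else list(required_sizes)
    extra = 0
    for size in order:
        match = next((b for b in free if b >= size), None)  # first block that fits
        if match is not None:
            free.remove(match)
        else:
            extra += size
    return extra

# A small entry seen early can claim a big shared block that a later large
# entry then misses:
pool = [1000, 100, 1000]
print(extra_bytes([100, 900, 900], pool, largest_first=False))  # 900
print(extra_bytes([100, 900, 900], pool, largest_first=True))   # 0
```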

@eric-haibin-lin
Member

Yeah, that's a good point! I actually made the same change in my PR over the weekend: https://github.com/dmlc/mxnet/pull/5133/files#diff-d8b5a5b027d00584737fb6486cba38b9R488

@tdomhan
Contributor Author

tdomhan commented Feb 27, 2017 via email

@tdomhan
Contributor Author

tdomhan commented Feb 28, 2017

Are there any remaining blockers for merging #5133?

piiswrong pushed a commit to piiswrong/mxnet that referenced this issue Mar 30, 2017
* Imbalance version of shared pool during plan memory

* Bug fix for no shared_pool case

* Auto search and updated shared mem pool

* Cleanup unused code

* Cleanup logging code

* Add unit test for shared storage

* Remove shared pool in PlanMemory. Fix lint warnings

* Fix lint warnings

* Use reference instead of ptrs