[VM][PooledAllocator] try reallocation once when OOM #8285
Conversation
LGTM! Thanks @ganler for the contribution!
LOG(WARNING) << "PooledAllocator got InternalError during allocation: " << err.message();
LOG(WARNING) << "Trying to release all unused memory and reallocate...";
ReleaseAll();
buf.data = DeviceAPI::Get(device_)->AllocDataSpace(device_, size, alignment, type_hint);
What should we expect if it still fails here?
If it still fails, an InternalError will be thrown, surfacing as a TVMError about OOM on the Python end.
Hmm, would it be better to let ReleaseAll return the size it released and check whether that is larger than the requested size? Then we could directly throw a message including both sizes without calling alloc again.
Thanks for the suggestion, but IMHO that is not robust enough. Say we have 8 GB of GPU memory, the PooledAllocator has cached 4 GB, and we want to allocate 6 GB.
- Applying your idea, ReleaseAll() returns "4 GB", which is less than "6 GB", thus resulting in a failed allocation.
- Instead, if we release the unused memory and retry the allocation, the "6 GB" is very likely to be allocated successfully.
The big picture behind your idea would be practical if we had APIs like "total_system_memory" and "available_system_memory", but that may require introducing a series of runtime/driver libraries, e.g., cudaMemGetInfo from the CUDA runtime (user space) or NVML (if some system privilege is allowed).
Fair enough. I don't have other comments then.
This change doesn't solve the issue in #8233, because of the call at Lines 196 to 197 in 4d9bc9b.
That call is not protected by try/catch, so it can still fail if almost all memory is held by
I think we need to revisit the memory release strategy of
See TVM Discussion Topic.
This change aims to make TVM's behaviour more robust when an OOM occurs, and resolves a mysterious uncaught-exception bug.
Potential reviewers: @junrushao1994 @icemelon9 @jroesch