
Conversation

Collaborator

@22quinn 22quinn commented Aug 24, 2025

Purpose

This is the Python-only version of the fix in #22724. See that PR for context and discussion.
Thanks @EricMarcus-ai for the awesome work!

Related:
vLLM: #20431
vLLM: #16993
OpenRLHF: OpenRLHF/OpenRLHF#1052
ModelScope: modelscope/ms-swift#4353
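
For context, the core of the Python-only fix (also visible in the diff excerpt discussed below) is to cache the bound-method callbacks as attributes on the allocator instance, so a long-lived Python-side reference exists even though the C++ extension only borrows the pointers. A minimal sketch of the pattern, illustrative only and simplified; `CuMemAllocatorSketch` and the placeholder method bodies are made up for this example:

class CuMemAllocatorSketch:
    """Illustrative only: demonstrates the strong-reference pattern."""

    def __init__(self) -> None:
        # self._python_malloc_callback produces a *new* bound-method object on
        # every attribute access. Cache one instance here so it stays alive as
        # long as the allocator does, even if C++ only borrows the pointer.
        self.python_malloc_callback = self._python_malloc_callback
        self.python_free_callback = self._python_free_callback

    def _python_malloc_callback(self, allocation_handle) -> None:
        ...  # placeholder: record the allocation

    def _python_free_callback(self, ptr: int):
        ...  # placeholder: look up and return the allocation handle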

Test Plan

Below is a minor modification of the test originally added in #22724. Because this is a stress test, it will not be added to CI for cost/performance reasons.

import gc

import pytest
import torch

from vllm.device_allocator.cumem import CuMemAllocator, get_pluggable_allocator

# create_new_process_for_each_test is vLLM's test-utility decorator; the exact
# import path depends on where this test file sits in the tests tree.


# Helper churn to pressure CPython's small-object allocator so that
# temporary bound-method objects get collected and their memory is reused.
class _Churn:

    def cb(self):
        return 1


def _churn_bound_methods(n: int = 500_000):
    for _ in range(n):
        _ = _Churn().cb  # create & drop a temporary bound-method object


def _churn_small_objs(n: int = 2_000_000):
    for _ in range(n):
        _ = object()


@create_new_process_for_each_test()
def test_cumem_callback_lifetime_regression():
    """
    Test for the borrowed callback bug
    Context
    -------
    The C++ extension originally borrowed references to the Python
    bound methods passed into `init_module`. After leaving the Python
    context that created those bound methods, CPython can free them.
    Later, when tensors that were allocated by the pluggable allocator
    are freed, `my_free` calls back into C++; Python using those
    now-dangling pointers, causing UB / segfaults.
    What this test does
    -------------------
    1) Create the singleton `CuMemAllocator` and build a pluggable
       allocator by calling `get_pluggable_allocator(A.python_malloc_callback,
       A.python_free_callback)`. This mirrors the setup in cumem.py
    2) Aggressively churn Python objects to encourage CPython to reuse
       the memory that held those temporaries.
    3) Perform allocate inside the pluggable mem-pool to
       drive calls into `my_malloc`/`my_free`.
    Expected outcome
    ----------------
    • With the original extension that borrows refs:
      this test aborts (segfault) during or after
      the first/second allocate.
    • With the updated extension (explicit referencing in `init_module`) 
      the test should complete without crashing.
    Pass criterion
    --------------
    The test passes if the process completes the allocate/free cycles
    without aborting. There are no numeric assertions; a failure will
    manifest as a process crash.
    """

    if torch.cuda.device_count() == 0:
        pytest.skip("No CUDA device available")

    A = CuMemAllocator.get_instance()

    # Accessing these callback attributes creates temporary bound-method objects.
    alloc = get_pluggable_allocator(A.python_malloc_callback,
                                    A.python_free_callback)
    mpool = torch.cuda.memory.MemPool(alloc._allocator)

    _churn_bound_methods(400_000)
    _churn_small_objs(500_000)
    gc.collect()

    with torch.cuda.memory.use_mem_pool(mpool):
        bufs = [
            torch.empty(1 << 20, dtype=torch.uint8, device="cuda")
            for _ in range(128)
        ]  # ~ 128 MiB
        del bufs
        gc.collect()

    # churn harder
    _churn_bound_methods(1_000_000)
    gc.collect()

    with torch.cuda.memory.use_mem_pool(mpool):
        buf = torch.empty(1 << 20, dtype=torch.uint8, device="cuda")
        del buf
        gc.collect()

    # Clean shutdown
    torch.cuda.synchronize()
    del mpool, alloc
    gc.collect()

Test Result

(Optional) Documentation Update


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request addresses a critical memory corruption bug caused by Python's garbage collector freeing CUDA callback methods that were still in use by a C++ extension. The fix introduces strong references to these callbacks to ensure they are not prematurely collected.

The approach taken is correct, but it could be made more robust. I've added a comment suggesting a refactoring that would prevent accidental reintroduction of the bug in the future by making the API safer to use. This change would make the fix more resilient to future code modifications.

Comment on lines 155 to 156
Contributor


Severity: high

While this correctly creates strong references to the callbacks, it introduces new public attributes (..._ref) and leaves the original methods (python_malloc_callback, python_free_callback) accessible. If a developer accidentally uses the original methods, they will get a new temporary bound method object, reintroducing the memory corruption bug. This makes the fix fragile.

A more robust approach would be to rename the callback methods with a leading underscore (e.g., _python_malloc_callback) to indicate they are for internal use, and then create public attributes in __init__ that hold the bound methods. This way, any access to self.python_malloc_callback will return the same strongly-referenced object, making the API safer and preventing accidental misuse.

Here's how you could refactor it:

  1. Rename the methods in vllm/device_allocator/cumem.py:
    def _python_malloc_callback(self, allocation_handle: HandleType) -> None:
        ...
    
    def _python_free_callback(self, ptr: int) -> HandleType:
        ...
  2. Update __init__ to create the public, persistent bound methods:
    # In __init__
    self.python_malloc_callback = self._python_malloc_callback
    self.python_free_callback = self._python_free_callback
  3. Update the call site in use_memory_pool to use the public attributes (which no longer need the _ref suffix):
    # In use_memory_pool
    with use_memory_pool_with_allocator(
            self.python_malloc_callback,
            self.python_free_callback) as data:
        ...

This refactoring would make the fix more resilient to future changes.

Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
Co-authored-by: Eric Marcus <eric.marcus@kaiko.ai>
Co-authored-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
@22quinn 22quinn requested a review from youkaichao August 24, 2025 01:52
@22quinn 22quinn added the ready ONLY add when PR is ready to merge/full CI is needed label Aug 24, 2025
The review comment below refers to this excerpt from CuMemAllocator.__init__ in vllm/device_allocator/cumem.py:

self.pointer_to_data: dict[int, AllocationData] = {}
self.current_tag: str = CuMemAllocator.default_tag
self.allocator_and_pools: dict[str, Any] = {}
self.python_malloc_callback = self._python_malloc_callback
Member


Add a comment to explain why we are doing this? `self._python_malloc_callback` is a temporary object every time we access it.
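
For readers unfamiliar with the CPython detail behind this request: each attribute access on a plain method returns a fresh bound-method object, while an attribute cached in __init__ is a single, strongly-referenced object. A small illustrative snippet (the `Demo` class is hypothetical, not part of the PR):

import weakref


class Demo:

    def __init__(self):
        # Strong reference: one bound-method object, kept for the lifetime of `d`.
        self.cached_cb = self._cb

    def _cb(self):
        return 1


d = Demo()

# Each attribute access creates a brand-new bound-method object ...
assert d._cb is not d._cb

# ... so anything that only borrows a reference to one of them sees it die
# immediately; a weakref to the temporary is already dead here.
assert weakref.ref(d._cb)() is None

# The attribute cached in __init__ is stable and stays alive with `d`.
assert d.cached_cb is d.cached_cb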

Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
@youkaichao youkaichao merged commit 9dc30b7 into vllm-project:main Aug 24, 2025
36 checks passed
@22quinn 22quinn added the rl Related to RL workflows label Aug 24, 2025
@22quinn 22quinn deleted the cumem-gc-fix branch August 24, 2025 05:49
johnnynunez pushed a commit to johnnynunez/vllm that referenced this pull request Aug 24, 2025
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 28, 2025
xiao-llm pushed a commit to xiao-llm/vllm that referenced this pull request Aug 28, 2025
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Aug 28, 2025
mengxingkongzhouhan pushed a commit to mengxingkongzhouhan/vllm that referenced this pull request Aug 30, 2025
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Sep 3, 2025
ekagra-ranjan pushed a commit to ekagra-ranjan/vllm that referenced this pull request Sep 4, 2025
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
@acodercat

I still have this error:

[rank1]:[W1110 21:40:15.983774416 ProcessGroupNCCL.cpp:1538] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank2]:[W1110 21:40:15.175479179 ProcessGroupNCCL.cpp:1538] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank3]:[W1110 21:40:15.185425871 ProcessGroupNCCL.cpp:1538] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank0]:[W1110 21:40:16.665186314 ProcessGroupNCCL.cpp:1538] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
terminate called after throwing an instance of 'c10::Error'
what(): Trying to free a pointer not allocated here
Exception raised from raw_delete at /pytorch/torch/csrc/cuda/CUDAPluggableAllocator.cpp:149 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x80 (0x15555437eeb0 in /home/maohao/workspace/agent-training/msswift-training/.venv/lib/python3.12/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, char const*) + 0x69 (0x15555431bb5f in /home/maohao/workspace/agent-training/msswift-training/.venv/lib/python3.12/site-packages/torch/lib/libc10.so)
frame #2: torch::cuda::CUDAPluggableAllocator::CUDAPluggableAllocator::raw_delete(void*) + 0x27f (0x1554fa0d4d9f in /home/maohao/workspace/agent-training/msswift-training/.venv/lib/python3.12/site-packages/torch/lib/libtorch_cuda.so)
frame #3: + 0x1e4c0 (0x15555471c4c0 in /home/maohao/workspace/agent-training/msswift-training/.venv/lib/python3.12/site-packages/torch/lib/libc10_cuda.so)
frame #4: + 0x34351 (0x155554732351 in /home/maohao/workspace/agent-training/msswift-training/.venv/lib/python3.12/site-packages/torch/lib/libc10_cuda.so)
frame #5: + 0x38e33 (0x155554736e33 in /home/maohao/workspace/agent-training/msswift-training/.venv/lib/python3.12/site-packages/torch/lib/libc10_cuda.so)
frame #6: c10::cuda::MemPool::~MemPool() + 0x1bc (0x15555471d50c in /home/maohao/workspace/agent-training/msswift-training/.venv/lib/python3.12/site-packages/torch/lib/libc10_cuda.so)
frame #7: + 0xbc9d62 (0x155547707d62 in /home/maohao/workspace/agent-training/msswift-training/.venv/lib/python3.12/site-packages/torch/lib/libtorch_python.so)
frame #8: + 0x374fcd (0x155546eb2fcd in /home/maohao/workspace/agent-training/msswift-training/.venv/lib/python3.12/site-packages/torch/lib/libtorch_python.so)
frame #9: + 0x37560e (0x155546eb360e in /home/maohao/workspace/agent-training/msswift-training/.venv/lib/python3.12/site-packages/torch/lib/libtorch_python.so)
frame #10: /home/maohao/workspace/agent-training/msswift-training/.venv/bin/python() [0x162f5ba]
frame #11: /home/maohao/workspace/agent-training/msswift-training/.venv/bin/python() [0x1601812]
frame #12: /home/maohao/workspace/agent-training/msswift-training/.venv/bin/python() [0x162ee92]
frame #13: _PyObject_ClearManagedDict + 0x10d (0x168c0cd in /home/maohao/workspace/agent-training/msswift-training/.venv/bin/python)
frame #14: /home/maohao/workspace/agent-training/msswift-training/.venv/bin/python() [0x168bf69]
frame #15: /home/maohao/workspace/agent-training/msswift-training/.venv/bin/python() [0x168b698]
frame #16: /home/maohao/workspace/agent-training/msswift-training/.venv/bin/python() [0x178a194]
frame #17: Py_FinalizeEx + 0xe1 (0x1788abb in /home/maohao/workspace/agent-training/msswift-training/.venv/bin/python)
frame #18: Py_RunMain + 0x23f (0x1738b15 in /home/maohao/workspace/agent-training/msswift-training/.venv/bin/python)
frame #19: /home/maohao/workspace/agent-training/msswift-training/.venv/bin/python() [0x175c9fa]
frame #20: /home/maohao/workspace/agent-training/msswift-training/.venv/bin/python() [0x175c7ed]
frame #21: + 0x29d90 (0x155555208d90 in /lib/x86_64-linux-gnu/libc.so.6)
frame #22: __libc_start_main + 0x80 (0x155555208e40 in /lib/x86_64-linux-gnu/libc.so.6)
frame #23: _start + 0x29 (0x17cb6e9 in /home/maohao/workspace/agent-training/msswift-training/.venv/bin/python)

terminate called after throwing an instance of 'c10::Error'
what(): Trying to free a pointer not allocated here
Exception raised from raw_delete at /pytorch/torch/csrc/cuda/CUDAPluggableAllocator.cpp:149 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x80 (0x15555437eeb0 in /home/maohao/workspace/agent-training/msswift-training/.venv/lib/python3.12/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, char const*) + 0x69 (0x15555431bb5f in /home/maohao/workspace/agent-training/msswift-training/.venv/lib/python3.12/site-packages/torch/lib/libc10.so)
frame #2: torch::cuda::CUDAPluggableAllocator::CUDAPluggableAllocator::raw_delete(void*) + 0x27f (0x1554fa0d4d9f in /home/maohao/workspace/agent-training/msswift-training/.venv/lib/python3.12/site-packages/torch/lib/libtorch_cuda.so)
frame #3: + 0x1e4c0 (0x15555471c4c0 in /home/maohao/workspace/agent-training/msswift-training/.venv/lib/python3.12/site-packages/torch/lib/libc10_cuda.so)
frame #4: + 0x34351 (0x155554732351 in /home/maohao/workspace/agent-training/msswift-training/.venv/lib/python3.12/site-packages/torch/lib/libc10_cuda.so)
frame #5: + 0x38e33 (0x155554736e33 in /home/maohao/workspace/agent-training/msswift-training/.venv/lib/python3.12/site-packages/torch/lib/libc10_cuda.so)
frame #6: c10::cuda::MemPool::~MemPool() + 0x1bc (0x15555471d50c in /home/maohao/workspace/agent-training/msswift-training/.venv/lib/python3.12/site-packages/torch/lib/libc10_cuda.so)
frame #7: + 0xbc9d62 (0x155547707d62 in /home/maohao/workspace/agent-training/msswift-training/.venv/lib/python3.12/site-packages/torch/lib/libtorch_python.so)
