
Conversation


@wenscarl wenscarl commented Nov 6, 2025

πŸ“Œ Description

πŸ” Related Issues

πŸš€ Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

βœ… Pre-commit Checks

  • I have installed pre-commit by running `pip install pre-commit` (or used your preferred method).
  • I have installed the hooks with `pre-commit install`.
  • I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

πŸ§ͺ Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added optional communication-backend parameter for multi-node memory and buffer allocation to allow using a provided communicator for handle transfer.
  • Bug Fixes / Reliability

    • Multi-node synchronization now uses the provided communicator's barrier when available, preserving previous behavior otherwise.
  • Tests

    • Added end-to-end tests covering custom communication backends and multi-node all-reduce synchronization.


coderabbitai bot commented Nov 6, 2025

Walkthrough

The changes thread an optional comm_backend_for_handle_transfer communication backend through multi-node multicast memory initialization and all-reduce workspace construction. A new barrier() method was added to CommBackend and MPIBackend. Call sites and tests were updated to use the provided backend when present, falling back to MPI when it is None, and a new test validates a custom communicator end-to-end.
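For orientation, the interface being extended is sketched below; the method names come from the walkthrough and the code-graph listing later in this review, but the exact signatures in flashinfer/comm/mnnvl.py may differ.

```python
from abc import ABC, abstractmethod
from typing import Any, List

class CommBackend(ABC):
    """Sketch of the communicator contract; barrier() is the addition in this PR."""

    @abstractmethod
    def Get_rank(self) -> int: ...          # rank of this process

    @abstractmethod
    def Get_size(self) -> int: ...          # total number of ranks

    @abstractmethod
    def allgather(self, data: Any) -> List[Any]: ...  # gather one value per rank

    @abstractmethod
    def bcast(self, data: Any, root: int = 0) -> Any: ...  # broadcast from root

    @abstractmethod
    def barrier(self) -> None: ...          # block until all ranks arrive (new)

    @abstractmethod
    def Split(self, color: int, key: int) -> "CommBackend": ...  # MPI-style split
```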

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| **Comm backend and multicast memory**<br>`flashinfer/comm/mnnvl.py` | Added `CommBackend.barrier()` and `MPIBackend.barrier()` methods. Added a `comm_backend_for_handle_transfer: Optional[CommBackend] = None` parameter to `McastDeviceMemory.__init__` and `McastGPUBuffer.__init__`. Updated `_alloc_mn_mcast_mem(self, buf_size, comm_backend_for_handle_transfer: Any = None)` to use the provided backend or lazily create an `MpiComm`. Propagated `comm_backend_for_handle_transfer` through the multi-node allocation flow. |
| **All-reduce workspace setup**<br>`flashinfer/comm/trtllm_mnnvl_ar.py` | Imported `CommBackend`. Extended the `get_allreduce_mnnvl_workspace()` signature to accept `comm_backend_for_handle_transfer: Optional[CommBackend] = None`. Passed the backend into the `McastGPUBuffer` construction and replaced the unconditional MPI barrier with conditional logic: use `comm_backend_for_handle_transfer.barrier()` if provided, otherwise call `mpi_barrier()` (sketched after the table). |
| **Tests: custom communicator**<br>`tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py` | New test module adding `CustomCommunicator` (implements `CommBackend` with `Get_rank`, `Get_size`, `allgather`, `bcast`, `barrier`, `Split`), utilities for dynamic port allocation and multi-process orchestration, and parametrized tests (world sizes 2 and 4) validating all-reduce using the custom backend. |
| **Tests: existing all-reduce updates**<br>`tests/comm/test_trtllm_mnnvl_allreduce.py` | Added `Optional` typing and imported `CommBackend` and `MpiComm`. Extended the test helper and test signatures to accept `comm_backend_for_handle_transfer: Optional[CommBackend] = None`. Replaced the unconditional MPI barrier with conditional barrier selection using the provided backend or `MpiComm`. |
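The barrier replacement called out in the trtllm_mnnvl_ar.py row reduces to a small conditional. A minimal sketch, assuming the names used in the summaries above:

```python
# Assumed shape of the synchronization in get_allreduce_mnnvl_workspace();
# see flashinfer/comm/trtllm_mnnvl_ar.py for the actual code.
if comm_backend_for_handle_transfer is not None:
    comm_backend_for_handle_transfer.barrier()  # synchronize via the custom backend
else:
    mpi_barrier()  # previous unconditional behavior, now the fallback
```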

Sequence Diagram

```mermaid
sequenceDiagram
    actor Test
    participant Workspace as get_allreduce_mnnvl_workspace
    participant Buffer as McastGPUBuffer
    participant DeviceMem as McastDeviceMemory
    participant Comm

    Test->>Workspace: call(mapping, dtype, comm_backend_for_handle_transfer?, buffer_size)
    Workspace->>Buffer: new McastGPUBuffer(..., comm_backend_for_handle_transfer)
    Buffer->>DeviceMem: initialize(buf_size, ..., comm_backend_for_handle_transfer)
    DeviceMem->>DeviceMem: _alloc_mn_mcast_mem(buf_size, comm_backend_for_handle_transfer)
    alt comm_backend_for_handle_transfer provided
        DeviceMem->>Comm: use provided backend for handle transfer
    else no backend provided
        DeviceMem->>DeviceMem: lazily create MpiComm and use it
    end
    Workspace->>Workspace: synchronize before use
    alt comm_backend_for_handle_transfer provided
        Workspace->>Comm: comm_backend_for_handle_transfer.barrier()
    else
        Workspace->>Workspace: mpi_barrier()
    end
    Workspace-->>Test: return workspace (McastGPUBuffer, tensor, size)
```
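End to end, a caller would opt in roughly as follows. This is a hedged usage sketch: the keyword name follows the walkthrough, CustomCommunicator mirrors the new test's backend, and the three-value return unpacking follows the diagram's final message rather than verified code.

```python
import torch
import torch.distributed as dist
from flashinfer.comm.trtllm_mnnvl_ar import get_allreduce_mnnvl_workspace

# CustomCommunicator: the test's CommBackend implementation over torch.distributed.
comm = CustomCommunicator(dist.group.WORLD)
mcast_buffer, workspace_tensor, size = get_allreduce_mnnvl_workspace(
    mapping,                 # tensor-parallel Mapping for this rank
    torch.bfloat16,          # dtype of the tensors being reduced
    comm_backend_for_handle_transfer=comm,  # omit to fall back to MPI
)
```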

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

  • Check that comm_backend_for_handle_transfer is consistently passed and not dropped across call chains.
  • Verify lazy MpiComm creation avoids double-init/resource leaks (see the sketch after this list).
  • Confirm conditional barrier logic preserves synchronization semantics.
  • Review CustomCommunicator test for correct torch.distributed usage and process orchestration/cleanup.
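The lazy-creation shape referenced in the second bullet presumably looks like the following; this is an assumption about the code, not a verified excerpt.

```python
# Assumed fallback pattern inside _alloc_mn_mcast_mem():
comm = comm_backend_for_handle_transfer
if comm is None:
    comm = MpiComm()  # created only on demand, so MPI is never initialized
                      # when a custom backend is supplied
```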

Poem

🐰 I tunneled through code, with a nibble and cheer,

A barrier added so ranks can draw near.
Handles now travel on a backend well-known,
Multicast and tests dance, their petals full-blown.
Hooray for the hop β€” buffers ready to own!

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Description check | ⚠️ Warning | The PR description consists entirely of the template with all sections unfilled: the Description, Related Issues, and Reviewer Notes sections are empty, and all checklist items are unchecked. | Please fill in the Description section explaining what this PR does and why the custom communicator is needed, complete the checklist items, and link any relevant issues. |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 24.00%, which is below the required threshold of 80.00%. | You can run `@coderabbitai generate docstrings` to improve docstring coverage. |

βœ… Passed checks (1 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | βœ… Passed | The title accurately captures the main objective of the PR: adding a custom communicator for trtllm_mnnvl_ar, which aligns with the new CustomCommunicator class and related infrastructure changes. |

@wenscarl wenscarl marked this pull request as ready for review November 14, 2025 21:01

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (2)
tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py (2)

29-46: Simplify redundant initialization in allgather bytes case.

On line 42, gathered = [data] * self.Get_size() is immediately overwritten by dist.all_gather_object on line 43, making the initialization redundant.

Apply this diff:

```diff
         elif isinstance(data, bytes):
             local_tensor = torch.ByteTensor(list(data)).unsqueeze(0).to(device)
             world_size = self.Get_size()
-            gathered = [data] * self.Get_size()
+            gathered = [None] * self.Get_size()
             dist.all_gather_object(gathered, data, group=self._group)
             return gathered
```

66-67: Document or implement Split method properly.

The Split method currently just returns self without performing any actual split operation. If this is intentional for test simplicity (since the test doesn't require sub-communicators), consider documenting this limitation. Otherwise, implement proper group splitting.

If this stub is intentional, add a docstring:

```diff
     def Split(self, color: int, key: int) -> "CustomCommunicator":
+        """
+        Stub implementation for testing - does not actually split the communicator.
+        Returns self since test scenarios don't require sub-communicators.
+        """
         return self
```
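If sub-communicators are ever needed, one possible real implementation on top of torch.distributed is sketched below. This is hypothetical: it assumes CustomCommunicator can be constructed from a process group, and it relies on dist.new_group() being collective, so every rank creates every subgroup.

```python
import torch.distributed as dist

def Split(self, color: int, key: int) -> "CustomCommunicator":
    # Gather (color, key, rank) from all ranks, then build one subgroup per color.
    world = self.Get_size()
    entries = [None] * world
    dist.all_gather_object(entries, (color, key, self.Get_rank()), group=self._group)
    my_group = None
    for c in sorted({e[0] for e in entries}):
        # Ranks sharing this color, ordered by key (ties broken by rank), MPI-style.
        ranks = [r for (cc, k, r) in sorted(entries) if cc == c]
        group = dist.new_group(ranks=ranks)  # collective: all ranks must call this
        if c == color:
            my_group = group
    return CustomCommunicator(my_group)  # assumes a group-accepting constructor
```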
πŸ“œ Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

πŸ“₯ Commits

Reviewing files that changed from the base of the PR and between 28b4dd4 and 602adfe.

πŸ“’ Files selected for processing (4)
  • flashinfer/comm/mnnvl.py (8 hunks)
  • flashinfer/comm/trtllm_mnnvl_ar.py (4 hunks)
  • tests/comm/test_trtllm_mnnvl_allreduce.py (3 hunks)
  • tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • flashinfer/comm/mnnvl.py
🧰 Additional context used
🧬 Code graph analysis (3)
tests/comm/test_trtllm_mnnvl_allreduce.py (3)
flashinfer/comm/mapping.py (1)
  • Mapping (21-475)
flashinfer/comm/mnnvl.py (4)
  • CommBackend (146-162)
  • MpiComm (179-199)
  • barrier (159-159)
  • barrier (215-216)
tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py (1)
  • barrier (60-64)
tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py (4)
flashinfer/comm/mapping.py (4)
  • Mapping (21-475)
  • local_rank (391-392)
  • node_rank (387-388)
  • tp_rank (325-326)
flashinfer/comm/mnnvl.py (11)
  • CommBackend (146-162)
  • Get_rank (150-150)
  • Get_rank (206-207)
  • Get_size (153-153)
  • Get_size (209-210)
  • allgather (156-156)
  • allgather (212-213)
  • barrier (159-159)
  • barrier (215-216)
  • Split (162-162)
  • Split (218-220)
flashinfer/comm/trtllm_mnnvl_ar.py (1)
  • get_allreduce_mnnvl_workspace (124-199)
tests/comm/test_trtllm_mnnvl_allreduce.py (1)
  • row_linear_residual_norm_fusion_forward (17-152)
flashinfer/comm/trtllm_mnnvl_ar.py (1)
flashinfer/comm/mnnvl.py (6)
  • McastGPUBuffer (966-1030)
  • CommBackend (146-162)
  • lamport_initialize (946-963)
  • lamport_initialize (1005-1006)
  • barrier (159-159)
  • barrier (215-216)
πŸͺ› Ruff (0.14.4)
tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py

46-46: Avoid specifying long messages outside the exception class

(TRY003)


66-66: Unused method argument: color

(ARG002)


66-66: Unused method argument: key

(ARG002)


89-89: Consider iterable unpacking instead of concatenation

Replace with iterable unpacking

(RUF005)


210-210: Do not catch blind exception: Exception

(BLE001)


253-255: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
πŸ”‡ Additional comments (9)
tests/comm/test_trtllm_mnnvl_allreduce.py (2)

2-2: LGTM! Clean import and parameter additions.

The imports are properly used: Optional for type hints, CommBackend for the parameter type, and MpiComm for the fallback communicator. The parameter addition maintains backward compatibility with a sensible default.

Also applies to: 10-10, 32-32


41-45: LGTM! Correct conditional barrier pattern.

The logic correctly selects between the provided custom backend and the default MPI backend, then uses the unified interface to call barrier().
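A sketch of that selection, under the same naming assumptions as the summaries above:

```python
# Assumed shape of the test helper's barrier selection:
comm = (
    comm_backend_for_handle_transfer
    if comm_backend_for_handle_transfer is not None
    else MpiComm()
)
comm.barrier()  # same unified CommBackend interface either way
```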

flashinfer/comm/trtllm_mnnvl_ar.py (3)

18-18: LGTM! Necessary imports added.

Both McastGPUBuffer (used on line 168) and CommBackend (used on line 127) are properly utilized in this file.


168-185: LGTM! Consistent parameter threading and barrier logic.

The comm_backend_for_handle_transfer is correctly passed to the McastGPUBuffer constructor, and the barrier logic follows the same conditional pattern as in the test file: fall back to mpi_barrier() when no custom backend is provided.


124-129: No issues foundβ€”all call sites have been updated correctly.

The verification shows both test call sites are properly aligned with the new parameter order. The test in test_trtllm_mnnvl_allreduce_custom_comm.py passes 4 positional arguments that map correctly to the new signature (comm β†’ comm_backend_for_handle_transfer, explicit_workspace_bytes β†’ buffer_size_in_bytes). The other test uses keyword arguments, which is resilient to signature changes. No breaking changes detected.

tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py (4)

70-78: LGTM! Robust port allocation with IPv6 fallback.

The function correctly attempts IPv4 first and falls back to IPv6 if needed, using proper context managers for socket cleanup.
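A hypothetical helper with the described behavior (not the test's exact code) could look like:

```python
import socket

def find_free_port() -> int:
    # Bind to port 0 so the OS assigns a free ephemeral port;
    # try IPv4 first, then fall back to IPv6.
    for family, host in ((socket.AF_INET, "127.0.0.1"), (socket.AF_INET6, "::1")):
        try:
            with socket.socket(family, socket.SOCK_STREAM) as s:
                s.bind((host, 0))
                return s.getsockname()[1]
        except OSError:
            continue
    raise RuntimeError("no usable loopback interface found")
```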


81-98: LGTM! Proper multi-process test orchestration.

The function correctly spawns processes, waits for completion, and validates exit codes. The force=True in set_start_method is appropriate for pytest environments where the method may be set multiple times.
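The pattern being praised is roughly the following; an illustrative sketch, not the test's code.

```python
import torch.multiprocessing as mp

def run_distributed(worker, world_size: int) -> None:
    # force=True tolerates pytest invoking this repeatedly in one interpreter.
    mp.set_start_method("spawn", force=True)
    procs = [mp.Process(target=worker, args=(rank, world_size)) for rank in range(world_size)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
        assert p.exitcode == 0, f"worker exited with code {p.exitcode}"
```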


217-217: Verify allgather handles boolean values correctly.

The code passes rank_failed (a bool) to comm.allgather(), but the allgather method implementation (lines 29-46) only explicitly handles int and bytes types. While Python booleans are a subtype of int and this will likely work, it's not explicitly supported in the type signature.

Consider either:

  1. Explicitly casting to int: comm.allgather(int(rank_failed))
  2. Updating the allgather method to accept int | bool | bytes and handle bool explicitly

Apply this diff for the explicit cast:

```diff
-        all_failures = comm.allgather(rank_failed)
+        all_failures = comm.allgather(int(rank_failed))
```

241-263: LGTM! Well-structured test with proper GPU validation.

The test correctly validates available GPU resources before attempting to spawn processes, preventing cryptic failures. The parameterization over world sizes provides good coverage.
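The guard in question is presumably of this shape (assumed, not the test's exact code):

```python
import pytest
import torch

def require_gpus(world_size: int) -> None:
    # Skip early with a clear message instead of failing mid-spawn.
    available = torch.cuda.device_count()
    if available < world_size:
        pytest.skip(f"test needs {world_size} GPUs, found {available}")
```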

```python
    Args:
        mapping: Tensor parallel mapping configuration containing rank info
        dtype: Data type of the tensors being reduced
        comm: Optional communication backend for multi-node synchronization
```

⚠️ Potential issue | 🟑 Minor

Fix docstring parameter name mismatch.

The docstring references comm: but the actual parameter is named comm_backend_for_handle_transfer. This inconsistency may confuse users and tools that parse docstrings.

Apply this diff:

```diff
-        comm: Optional communication backend for multi-node synchronization
+        comm_backend_for_handle_transfer: Optional communication backend for multi-node synchronization
```
πŸ“ Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
comm: Optional communication backend for multi-node synchronization
comm_backend_for_handle_transfer: Optional communication backend for multi-node synchronization
πŸ€– Prompt for AI Agents
In flashinfer/comm/trtllm_mnnvl_ar.py around line 144, the docstring refers to a
parameter named "comm:" while the actual function parameter is
"comm_backend_for_handle_transfer"; update the docstring to use the exact
parameter name "comm_backend_for_handle_transfer" (and adjust its short
description if needed) so the parameter list matches the function signature and
docstring parsers/tools can correctly map the description to the parameter.

```python
reference_output = (allreduce_result,)

# Run the test with the same workspace
from .test_trtllm_mnnvl_allreduce import row_linear_residual_norm_fusion_forward
```

πŸ› οΈ Refactor suggestion | 🟠 Major

Extract shared test helper to common module.

As noted in a previous review, row_linear_residual_norm_fusion_forward is now used across multiple test files. Importing from another test file creates test-to-test dependencies which can be fragile.

Based on learnings

Consider extracting this function to a shared test utilities module, e.g., tests/comm/test_helpers.py or tests/comm/common.py, and importing from there in both test files.

πŸ€– Prompt for AI Agents
In tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py around line 185, the
test imports row_linear_residual_norm_fusion_forward from another test file
creating fragile test-to-test coupling; extract that helper into a new shared
test module (e.g., tests/comm/test_helpers.py or tests/comm/common.py), move the
function implementation there, update both test files to import the helper from
the new shared module, and ensure the new module has any required
imports/fixtures so tests remain self-contained.

@wenscarl wenscarl requested a review from yzh119 November 15, 2025 05:22