Add custom communicator for trtllm_mnnvl_ar #2056
Conversation
Walkthrough: The changes thread an optional comm_backend_for_handle_transfer backend through the MNNVL all-reduce workspace allocation path; when no backend is supplied, the code lazily creates and uses an MpiComm.
Sequence Diagram

```mermaid
sequenceDiagram
    actor Test
    participant Workspace as get_allreduce_mnnvl_workspace
    participant Buffer as McastGPUBuffer
    participant DeviceMem as McastDeviceMemory
    participant Comm

    Test->>Workspace: call(mapping, dtype, comm_backend_for_handle_transfer?, buffer_size)
    Workspace->>Buffer: new McastGPUBuffer(..., comm_backend_for_handle_transfer)
    Buffer->>DeviceMem: initialize(buf_size, ..., comm_backend_for_handle_transfer)
    DeviceMem->>DeviceMem: _alloc_mn_mcast_mem(buf_size, comm_backend_for_handle_transfer)
    alt comm_backend_for_handle_transfer provided
        DeviceMem->>Comm: use provided backend for handle transfer
    else no backend provided
        DeviceMem->>DeviceMem: lazily create MpiComm and use it
    end
    Workspace->>Workspace: synchronize before use
    alt comm_backend_for_handle_transfer provided
        Workspace->>Comm: comm_backend_for_handle_transfer.barrier()
    else
        Workspace->>Workspace: mpi_barrier()
    end
    Workspace-->>Test: return workspace (McastGPUBuffer, tensor, size)
```
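To make the control flow above concrete, here is a minimal, self-contained sketch of the injection-with-fallback pattern the diagram describes; the names (`CommBackendLike`, `DefaultComm`, `get_workspace`) are illustrative stand-ins, not FlashInfer's actual API:

```python
from typing import Optional, Protocol


class CommBackendLike(Protocol):
    def barrier(self) -> None: ...


class DefaultComm:
    """Stand-in for the lazily created MpiComm fallback."""

    def barrier(self) -> None:
        print("default (MPI-style) barrier")


def get_workspace(comm: Optional[CommBackendLike] = None) -> object:
    # Use the caller-provided backend if given, else fall back to the default.
    backend = comm if comm is not None else DefaultComm()
    backend.barrier()  # synchronize all ranks before handing the workspace out
    return object()  # stands in for (McastGPUBuffer, tensor, size)


get_workspace()  # prints: default (MPI-style) barrier
```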
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Pre-merge checks and finishing touches
❌ Failed checks (2 warnings)
✅ Passed checks (1 passed)
Force-pushed 7edb1b6 to 28b4dd4
Force-pushed 28b4dd4 to 602adfe
Actionable comments posted: 2
🧹 Nitpick comments (2)
tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py (2)
29-46: Simplify redundant initialization in `allgather` bytes case.

On line 42, `gathered = [data] * self.Get_size()` is immediately overwritten by `dist.all_gather_object` on line 43, making the initialization redundant. Apply this diff:

```diff
 elif isinstance(data, bytes):
     local_tensor = torch.ByteTensor(list(data)).unsqueeze(0).to(device)
     world_size = self.Get_size()
-    gathered = [data] * self.Get_size()
+    gathered = [None] * self.Get_size()
     dist.all_gather_object(gathered, data, group=self._group)
     return gathered
```
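For context, `torch.distributed.all_gather_object` overwrites every slot of the preallocated output list, which is why seeding it with `None` is the conventional pattern; a sketch, assuming a process group has already been initialized on every rank:

```python
import torch.distributed as dist

# Assumes dist.init_process_group(...) has already run on each rank.
world_size = dist.get_world_size()
gathered = [None] * world_size  # placeholders only; all slots are overwritten
dist.all_gather_object(gathered, b"payload-from-this-rank")
# gathered[i] now holds rank i's object on every rank.
```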
66-67: Document or implement `Split` method properly.

The `Split` method currently just returns `self` without performing any actual split operation. If this is intentional for test simplicity (since the test doesn't require sub-communicators), consider documenting this limitation; otherwise, implement proper group splitting. If the stub is intentional, add a docstring:

```diff
 def Split(self, color: int, key: int) -> "CustomCommunicator":
+    """
+    Stub implementation for testing - does not actually split the communicator.
+    Returns self since test scenarios don't require sub-communicators.
+    """
     return self
```
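If a real split were ever needed, one possible sketch on top of `torch.distributed.new_group` (hypothetical; assumes the `CustomCommunicator` constructor accepts a process group, and relies on `new_group` being called collectively on every rank for every group):

```python
import torch.distributed as dist

def Split(self, color: int, key: int) -> "CustomCommunicator":
    colors = self.allgather(color)
    keys = self.allgather(key)
    my_group = None
    for c in sorted(set(colors)):
        # Order each color group by (key, rank), mirroring MPI_Comm_split.
        members = sorted(
            (r for r, col in enumerate(colors) if col == c),
            key=lambda r: (keys[r], r),
        )
        # new_group is collective: every rank must call it for every group.
        group = dist.new_group(ranks=members)
        if c == color:
            my_group = group
    return CustomCommunicator(my_group)
```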
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)

- flashinfer/comm/mnnvl.py (8 hunks)
- flashinfer/comm/trtllm_mnnvl_ar.py (4 hunks)
- tests/comm/test_trtllm_mnnvl_allreduce.py (3 hunks)
- tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)

- flashinfer/comm/mnnvl.py
🧰 Additional context used

🧬 Code graph analysis (3)

tests/comm/test_trtllm_mnnvl_allreduce.py (3)
- flashinfer/comm/mapping.py (1): Mapping (21-475)
- flashinfer/comm/mnnvl.py (4): CommBackend (146-162), MpiComm (179-199), barrier (159-159), barrier (215-216)
- tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py (1): barrier (60-64)

tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py (4)
- flashinfer/comm/mapping.py (4): Mapping (21-475), local_rank (391-392), node_rank (387-388), tp_rank (325-326)
- flashinfer/comm/mnnvl.py (11): CommBackend (146-162), Get_rank (150-150), Get_rank (206-207), Get_size (153-153), Get_size (209-210), allgather (156-156), allgather (212-213), barrier (159-159), barrier (215-216), Split (162-162), Split (218-220)
- flashinfer/comm/trtllm_mnnvl_ar.py (1): get_allreduce_mnnvl_workspace (124-199)
- tests/comm/test_trtllm_mnnvl_allreduce.py (1): row_linear_residual_norm_fusion_forward (17-152)

flashinfer/comm/trtllm_mnnvl_ar.py (1)
- flashinfer/comm/mnnvl.py (6): McastGPUBuffer (966-1030), CommBackend (146-162), lamport_initialize (946-963), lamport_initialize (1005-1006), barrier (159-159), barrier (215-216)
🪛 Ruff (0.14.4)

tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py

- 46-46: Avoid specifying long messages outside the exception class (TRY003)
- 66-66: Unused method argument: color (ARG002)
- 66-66: Unused method argument: key (ARG002)
- 89-89: Consider iterable unpacking instead of concatenation; replace with iterable unpacking (RUF005)
- 210-210: Do not catch blind exception: Exception (BLE001)
- 253-255: Avoid specifying long messages outside the exception class (TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (9)
tests/comm/test_trtllm_mnnvl_allreduce.py (2)
2-2: LGTM! Clean import and parameter additions.

The imports are properly used: `Optional` for type hints, `CommBackend` for the parameter type, and `MpiComm` for the fallback communicator. The parameter addition maintains backward compatibility with a sensible default.

Also applies to: 10-10, 32-32
41-45: LGTM! Correct conditional barrier pattern.

The logic correctly selects between the provided custom backend and the default MPI backend, then uses the unified interface to call `barrier()`.

flashinfer/comm/trtllm_mnnvl_ar.py (3)
18-18: LGTM! Necessary imports added.

Both `McastGPUBuffer` (used on line 168) and `CommBackend` (used on line 127) are properly utilized in this file.
168-185: LGTM! Consistent parameter threading and barrier logic.

The `comm_backend_for_handle_transfer` is correctly passed to the `McastGPUBuffer` constructor, and the barrier logic follows the same conditional pattern as in the test file: fall back to `mpi_barrier()` when no custom backend is provided.
124-129: No issues found: all call sites have been updated correctly.

The verification shows both test call sites are properly aligned with the new parameter order. The test in `test_trtllm_mnnvl_allreduce_custom_comm.py` passes 4 positional arguments that map correctly to the new signature (comm → comm_backend_for_handle_transfer, explicit_workspace_bytes → buffer_size_in_bytes). The other test uses keyword arguments, which is resilient to signature changes. No breaking changes detected.

tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py (4)
70-78: LGTM! Robust port allocation with IPv6 fallback.

The function correctly attempts IPv4 first and falls back to IPv6 if needed, using proper context managers for socket cleanup.
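The pattern being praised is roughly this (a sketch; `find_free_port` is a hypothetical name, not necessarily the test's):

```python
import socket

def find_free_port() -> int:
    # Try IPv4 first; fall back to IPv6 if the host has no IPv4 stack.
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("127.0.0.1", 0))
            return s.getsockname()[1]
    except OSError:
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
            s.bind(("::1", 0))
            return s.getsockname()[1]
```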
81-98: LGTM! Proper multi-process test orchestration.

The function correctly spawns processes, waits for completion, and validates exit codes. The `force=True` in `set_start_method` is appropriate for pytest environments where the method may be set multiple times.
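The orchestration shape under review resembles this sketch (hypothetical names; the real per-rank body differs):

```python
import multiprocessing as mp

def _worker(rank: int, world_size: int, port: int) -> None:
    ...  # per-rank test body (init process group, run all-reduce, assert)

def run_multiprocess_test(world_size: int, port: int) -> None:
    # force=True tolerates pytest re-invoking this in the same interpreter.
    mp.set_start_method("spawn", force=True)
    procs = [
        mp.Process(target=_worker, args=(rank, world_size, port))
        for rank in range(world_size)
    ]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    # A non-zero exit code from any rank fails the test.
    assert all(p.exitcode == 0 for p in procs)
```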
217-217: Verify `allgather` handles boolean values correctly.

The code passes `rank_failed` (a bool) to `comm.allgather()`, but the `allgather` method implementation (lines 29-46) only explicitly handles `int` and `bytes` types. While Python booleans are a subtype of int and this will likely work, it's not explicitly supported in the type signature. Consider either:

- Explicitly casting to int: `comm.allgather(int(rank_failed))`
- Updating the `allgather` method to accept `int | bool | bytes` and handle bool explicitly

Apply this diff for the explicit cast:

```diff
-    all_failures = comm.allgather(rank_failed)
+    all_failures = comm.allgather(int(rank_failed))
```
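For reference, the reason the bool happens to work today is that `bool` is a subclass of `int` in Python:

```python
>>> isinstance(True, int)
True
>>> int(True), int(False)
(1, 0)
```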
241-263: LGTM! Well-structured test with proper GPU validation.

The test correctly validates available GPU resources before attempting to spawn processes, preventing cryptic failures. The parameterization over world sizes provides good coverage.
```python
    Args:
        mapping: Tensor parallel mapping configuration containing rank info
        dtype: Data type of the tensors being reduced
        comm: Optional communication backend for multi-node synchronization
```
Fix docstring parameter name mismatch.

The docstring references `comm:` but the actual parameter is named `comm_backend_for_handle_transfer`. This inconsistency may confuse users and tools that parse docstrings. Apply this diff:

```diff
-        comm: Optional communication backend for multi-node synchronization
+        comm_backend_for_handle_transfer: Optional communication backend for multi-node synchronization
```
🤖 Prompt for AI Agents

```
In flashinfer/comm/trtllm_mnnvl_ar.py around line 144, the docstring refers to a
parameter named "comm:" while the actual function parameter is
"comm_backend_for_handle_transfer"; update the docstring to use the exact
parameter name "comm_backend_for_handle_transfer" (and adjust its short
description if needed) so the parameter list matches the function signature and
docstring parsers/tools can correctly map the description to the parameter.
```
```python
    reference_output = (allreduce_result,)

    # Run the test with the same workspace
    from .test_trtllm_mnnvl_allreduce import row_linear_residual_norm_fusion_forward
```
🛠️ Refactor suggestion | Major

Extract shared test helper to common module.

As noted in a previous review, `row_linear_residual_norm_fusion_forward` is now used across multiple test files. Importing from another test file creates test-to-test dependencies, which can be fragile. Based on learnings.

Consider extracting this function to a shared test utilities module, e.g., `tests/comm/test_helpers.py` or `tests/comm/common.py`, and importing from there in both test files; a sketch follows the AI-agent prompt below.
🤖 Prompt for AI Agents

```
In tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py around line 185, the
test imports row_linear_residual_norm_fusion_forward from another test file
creating fragile test-to-test coupling; extract that helper into a new shared
test module (e.g., tests/comm/test_helpers.py or tests/comm/common.py), move the
function implementation there, update both test files to import the helper from
the new shared module, and ensure the new module has any required
imports/fixtures so tests remain self-contained.
```
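A minimal sketch of the suggested extraction (the module name `tests/comm/test_helpers.py` is one of the suggested options, not an existing file):

```python
# tests/comm/test_helpers.py  (new shared module)
def row_linear_residual_norm_fusion_forward(*args, **kwargs):
    ...  # body moved verbatim from test_trtllm_mnnvl_allreduce.py

# In both test files, replace the cross-test import with:
# from .test_helpers import row_linear_residual_norm_fusion_forward
```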
📌 Description

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

- I have installed pre-commit by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed, and all tests are passing (unittest, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Bug Fixes / Reliability
Tests