
Commit 781603c

Add custom all2all interface (#3454)

iamzainhuda authored and meta-codesync[bot] committed

Summary:
Pull Request resolved: #3454

Add the ability for users to pass in their own All2All implementation.

Reviewed By: aliafzal

Differential Revision: D80278019

fbshipit-source-id: 746fbf0f59182efa2c8a07a5bbdd237b3906e364

1 parent a4ca26f · commit 781603c
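The summary above describes letting callers supply their own All2All implementation. As a rough sketch of the intended call site (not taken from this commit; the shapes and values are illustrative, and it assumes `torch.distributed` is already initialized, e.g. under torchrun):

```python
# Sketch only: assumes a CUDA-capable rank with torch.distributed initialized;
# the batch sizes, dims, and tensor contents below are illustrative.
import torch
import torch.distributed as dist
from torchrec.distributed.comm_ops import DefaultAll2AllSingle, alltoall_pooled

world_size = dist.get_world_size()
batch_size_per_rank = [8] * world_size   # pooled batch size contributed by each rank
dim_sum_per_rank = [64] * world_size     # sum of embedding dims owned by each rank
pooled_embs = torch.randn(
    sum(batch_size_per_rank),            # B_global
    dim_sum_per_rank[dist.get_rank()],   # D_local_sum
    device="cuda",
)

# Any All2AllSingle subclass can be passed here; DefaultAll2AllSingle reproduces
# the previous torch.empty + dist.all_to_all_single behavior.
comm = DefaultAll2AllSingle()

awaitable = alltoall_pooled(
    pooled_embs,
    batch_size_per_rank,
    dim_sum_per_rank,
    all_to_all_single_comm=comm,  # new keyword introduced by this commit
)
output = awaitable.wait()
```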

File tree

1 file changed (+185, -23 lines)

torchrec/distributed/comm_ops.py

Lines changed: 185 additions & 23 deletions
@@ -7,14 +7,15 @@
 
 # pyre-strict
 
+from abc import ABC, abstractmethod
+from collections.abc import Sequence
 from contextlib import contextmanager
 from dataclasses import dataclass, field
-from typing import Any, List, Optional, Tuple, TypeVar
+from typing import Any, List, Optional, Tuple, TypeVar, Union
 
 import torch
 import torch.distributed as dist
 import torch.distributed._functional_collectives
-
 from torch import Tensor
 from torch.autograd import Function
 from torch.autograd.profiler import record_function
@@ -80,6 +81,45 @@ def torchrec_use_sync_collectives():
     """
 
 
+class Comm(ABC):
+    """
+    Interface for communication primitives.
+    A primitive primarily needs to handle 3 tasks:
+    1. **Memory allocation strategy**:
+       - Associate each call to a temporary buffer (flexible, simple)
+       - Reuse a persistent buffer (efficient, complex)
+    2. **Memory location**:
+       - NCCL memory pool
+       - Regular CUDA caching allocator
+    3. **Communication execution**:
+       - Actual communication primitive implementation
+    See `All2AllSingle` for a concrete example.
+    """
+
+    @abstractmethod
+    def allocate(
+        self,
+        size: Sequence[Union[int, torch.SymInt]],
+        *,
+        dtype: torch.dtype,
+        device: torch.device,
+    ) -> torch.Tensor:
+        """
+        Allocate memory for communication buffers.
+        Args:
+            size (Sequence[Union[int, torch.SymInt]]): size of the tensor buffer
+            dtype (torch.dtype): dtype of the tensor buffer
+            device (torch.device): which device to allocate the tensor onto
+        Returns:
+            torch.Tensor: Allocated tensor buffer
+        Example:
+            ```python
+            tensor = comm.allocate([1024, 512], dtype=torch.float32, device="cuda:0")
+            ```
+        """
+        ...
+
+
 class Request(Awaitable[W]):
     """
     Defines a collective operation request for a process group on a tensor.
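The `Comm` docstring above distinguishes per-call temporary buffers from a reused persistent buffer. As an illustration of the second strategy (not part of this commit; `CachedAllocMixin` is a hypothetical name), an `allocate` override might cache one buffer per shape/dtype/device:

```python
# Illustrative sketch, not from this commit: a persistent-buffer allocation
# strategy for Comm.allocate. Buffers are keyed by (shape, dtype, device) and
# reused across calls; a real implementation must also ensure the previous
# collective using a buffer has completed before handing it out again.
from collections.abc import Sequence
from typing import Dict, Tuple, Union

import torch


class CachedAllocMixin:
    """Hypothetical mixin caching one buffer per (shape, dtype, device)."""

    def __init__(self) -> None:
        self._buffers: Dict[Tuple, torch.Tensor] = {}

    def allocate(
        self,
        size: Sequence[Union[int, torch.SymInt]],
        *,
        dtype: torch.dtype,
        device: torch.device,
    ) -> torch.Tensor:
        key = (tuple(int(s) for s in size), dtype, str(device))
        buf = self._buffers.get(key)
        if buf is None:
            # First request for this shape: allocate once and keep it alive.
            buf = torch.empty(size, dtype=dtype, device=device)
            self._buffers[key] = buf
        return buf
```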
@@ -334,6 +374,83 @@ def _get_split_lengths_by_len(
     return (my_len, splits)
 
 
+class All2AllSingle(Comm):
+    """
+    Interface for all-to-all single tensor communication.
+    This primitive performs an all-to-all scatter/gather operation where each
+    rank sends and receives data from all other ranks in a single operation.
+    """
+
+    @abstractmethod
+    def __call__(
+        self,
+        output_tensor: torch.Tensor,
+        input_tensor: torch.Tensor,
+        output_split_sizes: Optional[list[int]] = None,
+        input_split_sizes: Optional[list[int]] = None,
+        group: Optional[dist.ProcessGroup] = None,
+        async_op: bool = False,
+    ) -> Optional[dist.Work]:
+        """
+        Execute all-to-all single operation.
+        Args:
+            output_tensor: Pre-allocated output tensor
+            input_tensor: Input tensor to scatter
+            output_split_sizes: Sizes for splitting the output (if None, split evenly)
+            input_split_sizes: Sizes for splitting the input (if None, split evenly)
+            group: Process group (if None, uses default group)
+            async_op: Whether to perform asynchronously
+        Returns:
+            Optional work handle if async_op=True, None otherwise
+        """
+        ...
+
+
+class DefaultAllocMixin:
+    """
+    Mixin providing default tensor allocation using PyTorch's standard allocator.
+    """
+
+    def allocate(
+        self,
+        size: Sequence[Union[int, torch.SymInt]],
+        *,
+        dtype: torch.dtype,
+        device: torch.device,
+    ) -> torch.Tensor:
+        """Allocate tensor using torch.empty with standard CUDA caching allocator."""
+        return torch.empty(size, dtype=dtype, device=device)
+
+
+class DefaultAll2AllSingle(DefaultAllocMixin, All2AllSingle):
+    """
+    Default implementation of all-to-all single using PyTorch's distributed primitives.
+
+    This implementation uses:
+    - Standard CUDA caching allocator for memory allocation
+    - PyTorch's dist.all_to_all_single for communication
+    """
+
+    def __call__(
+        self,
+        output_tensor: torch.Tensor,
+        input_tensor: torch.Tensor,
+        output_split_sizes: Optional[list[int]] = None,
+        input_split_sizes: Optional[list[int]] = None,
+        group: Optional[dist.ProcessGroup] = None,
+        async_op: bool = False,
+    ) -> Optional[dist.Work]:
+        """Execute all-to-all using PyTorch's native implementation."""
+        return dist.all_to_all_single(
+            output_tensor,
+            input_tensor,
+            output_split_sizes,
+            input_split_sizes,
+            group=group,
+            async_op=async_op,
+        )
+
+
 def alltoall_pooled(
     a2a_pooled_embs_tensor: Tensor,
     batch_size_per_rank: List[int],
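A user-supplied implementation only needs `allocate` and `__call__`. The sketch below (illustrative, not part of this commit; `TimedAll2AllSingle` is a made-up name) composes the commit's `DefaultAllocMixin` with a `__call__` that wraps `dist.all_to_all_single` in CUDA events for per-call timing; the mixin split is what lets the allocation strategy and the collective vary independently.

```python
# Illustrative sketch, not from this commit: a custom All2AllSingle that keeps
# the default allocation behavior but records CUDA events around the collective.
from typing import Optional

import torch
import torch.distributed as dist

from torchrec.distributed.comm_ops import All2AllSingle, DefaultAllocMixin


class TimedAll2AllSingle(DefaultAllocMixin, All2AllSingle):
    def __init__(self) -> None:
        self.events: list = []  # (start, end) CUDA event pairs, one per call

    def __call__(
        self,
        output_tensor: torch.Tensor,
        input_tensor: torch.Tensor,
        output_split_sizes: Optional[list[int]] = None,
        input_split_sizes: Optional[list[int]] = None,
        group: Optional[dist.ProcessGroup] = None,
        async_op: bool = False,
    ) -> Optional[dist.Work]:
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        work = dist.all_to_all_single(
            output_tensor,
            input_tensor,
            output_split_sizes,
            input_split_sizes,
            group=group,
            async_op=async_op,
        )
        end.record()
        # start.elapsed_time(end) can be read after a device synchronization.
        self.events.append((start, end))
        return work
```

An instance of such a class would then be passed through the new `all_to_all_single_comm` keyword shown in the diff below.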
@@ -342,12 +459,14 @@ def alltoall_pooled(
     cumsum_dim_sum_per_rank_tensor: Optional[Tensor] = None,
     group: Optional[dist.ProcessGroup] = None,
     codecs: Optional[QuantizedCommCodecs] = None,
+    all_to_all_single_comm: Optional[All2AllSingle] = None,
 ) -> Awaitable[Tensor]:
     """
-    Performs AlltoAll operation for a single pooled embedding tensor. Each process
-    splits the input pooled embeddings tensor based on the world size, and then scatters
-    the split list to all processes in the group. Then concatenates the received tensors
-    from all processes in the group and returns a single output tensor.
+    Performs AlltoAll operation for a single pooled embedding tensor.
+
+    Each process splits the input pooled embeddings tensor based on the world size,
+    and then scatters the split list to all processes in the group. Then concatenates
+    the received tensors from all processes in the group and returns a single output tensor.
 
     Args:
         a2a_pooled_embs_tensor (Tensor): input pooled embeddings. Must be pooled
@@ -367,7 +486,19 @@ def alltoall_pooled(
         codecs (Optional[QuantizedCommCodecs]): quantized communication codecs.
 
     Returns:
-        Awaitable[Tensor]: async work handle (`Awaitable`), which can be `wait()` later to get the resulting tensor.
+        Awaitable[Tensor]: Async work handle which can be `wait()`-ed later to
+            get the resulting tensor.
+
+    Example:
+        ```python
+        # Using default implementation
+        result = alltoall_pooled(embeddings, batch_sizes, dim_sums)
+        output = result.wait()
+        # Using custom implementation
+        custom_comm = MyCustomAll2All()
+        result = alltoall_pooled(embeddings, batch_sizes, dim_sums,
+                                 all_to_all_single_comm=custom_comm)
+        ```
 
     .. warning::
         `alltoall_pooled` is experimental and subject to change.
@@ -391,7 +522,15 @@ def alltoall_pooled(
         return NoWait(all2all_pooled_sync(group, a2ai, a2a_pooled_embs_tensor))
 
     myreq = Request(group, device=a2a_pooled_embs_tensor.device)
-    All2All_Pooled_Req.apply(group, myreq, a2ai, a2a_pooled_embs_tensor)
+    if all_to_all_single_comm is None:
+        all_to_all_single_comm = DefaultAll2AllSingle()
+    All2All_Pooled_Req.apply(
+        group,
+        myreq,
+        a2ai,
+        a2a_pooled_embs_tensor,
+        all_to_all_single_comm,
+    )
     return myreq
 
 
@@ -476,6 +615,7 @@ def variable_batch_alltoall_pooled(
     emb_dim_per_rank_per_feature: List[List[int]],
     group: Optional[dist.ProcessGroup] = None,
     codecs: Optional[QuantizedCommCodecs] = None,
+    all_to_all_single_comm: Optional[All2AllSingle] = None,
 ) -> Awaitable[Tensor]:
 
     if group is None:
@@ -497,7 +637,11 @@ def variable_batch_alltoall_pooled(
     )
 
     myreq = Request(group, device=a2a_pooled_embs_tensor.device)
-    Variable_Batch_All2All_Pooled_Req.apply(group, myreq, a2ai, a2a_pooled_embs_tensor)
+    if all_to_all_single_comm is None:
+        all_to_all_single_comm = DefaultAll2AllSingle()
+    Variable_Batch_All2All_Pooled_Req.apply(
+        group, myreq, a2ai, a2a_pooled_embs_tensor, all_to_all_single_comm
+    )
     return myreq
 
 
@@ -1138,6 +1282,7 @@ def forward(
         myreq: Request[Tensor],
         a2ai: All2AllPooledInfo,
         input_embeddings: Tensor,
+        all_to_all_single_comm: All2AllSingle,
     ) -> Tensor:
         my_rank = dist.get_rank(pg)
         (B_global, D_local_sum) = input_embeddings.shape
@@ -1191,16 +1336,24 @@ def forward(
         input_split_sizes = [D_local_sum * B_rank for B_rank in batch_size_per_rank]
         qcomm_ctx = None
 
-        sharded_output_embeddings = torch.empty(
-            sum(output_split_sizes),
+        sharded_output_embeddings = all_to_all_single_comm.allocate(
+            (sum(output_split_sizes),),
+            dtype=sharded_input_embeddings.dtype,
+            device=sharded_input_embeddings.device,
+        )
+
+        sharded_input_embeddings_registered = all_to_all_single_comm.allocate(
+            sharded_input_embeddings.shape,
             dtype=sharded_input_embeddings.dtype,
             device=sharded_input_embeddings.device,
         )
 
+        sharded_input_embeddings_registered.copy_(sharded_input_embeddings)
+
         with record_function("## All2All_Pooled_fwd ##"):
-            req = dist.all_to_all_single(
-                output=sharded_output_embeddings,
-                input=sharded_input_embeddings,
+            req = all_to_all_single_comm(
+                output_tensor=sharded_output_embeddings,
+                input_tensor=sharded_input_embeddings_registered,
                 output_split_sizes=output_split_sizes,
                 input_split_sizes=input_split_sizes,
                 group=pg,
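For intuition about the buffers being requested from `allocate` in this hunk, a small worked example of the flattened split sizes computed just above (values are illustrative, not from the commit):

```python
# Illustrative numbers only: with a local pooled dimension D_local_sum = 8 and
# per-rank batch sizes [3, 5], this rank sends 8 * 3 = 24 elements to rank 0 and
# 8 * 5 = 40 elements to rank 1; the comm then allocates a flat output buffer of
# sum(output_split_sizes) elements plus a comm-owned copy of the flattened input,
# which is why the input is copied into sharded_input_embeddings_registered.
D_local_sum = 8
batch_size_per_rank = [3, 5]
input_split_sizes = [D_local_sum * B_rank for B_rank in batch_size_per_rank]
assert input_split_sizes == [24, 40]
```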
@@ -1218,7 +1371,7 @@ def forward(
 
     @staticmethod
     # pyre-fixme[2]: Parameter must be annotated.
-    def backward(ctx, *unused) -> Tuple[None, None, None, Tensor]:
+    def backward(ctx, *unused) -> Tuple[None, None, None, Tensor, None]:
         pg = ctx.pg
         my_rank = dist.get_rank(pg)
         myreq = ctx.myreq
@@ -1241,7 +1394,7 @@ def backward(ctx, *unused) -> Tuple[None, None, None, Tensor]:
             grad_input.div_(dist.get_world_size(ctx.pg))
         myreq.tensor = None
         myreq.dummy_tensor = None
-        return (None, None, None, grad_input)
+        return (None, None, None, grad_input, None)
 
 
 class All2All_Pooled_Wait(Function):
@@ -1385,6 +1538,7 @@ def forward(
         myreq: Request[Tensor],
         a2ai: VariableBatchAll2AllPooledInfo,
         input_embeddings: Tensor,
+        all_to_all_single_comm: All2AllSingle,
     ) -> Tensor:
         my_rank = dist.get_rank(pg)
 
@@ -1438,16 +1592,24 @@ def forward(
             for split in input_split_sizes
         ]
 
-        sharded_output_embeddings = torch.empty(
-            sum(output_split_sizes),
+        sharded_output_embeddings = all_to_all_single_comm.allocate(
+            (sum(output_split_sizes),),
+            dtype=sharded_input_embeddings.dtype,
+            device=sharded_input_embeddings.device,
+        )
+
+        sharded_input_embeddings_registered = all_to_all_single_comm.allocate(
+            sharded_input_embeddings.shape,
             dtype=sharded_input_embeddings.dtype,
             device=sharded_input_embeddings.device,
         )
 
+        sharded_input_embeddings_registered.copy_(sharded_input_embeddings)
+
         with record_function("## Variable_Batch_All2All_Pooled_fwd ##"):
-            req = dist.all_to_all_single(
-                output=sharded_output_embeddings,
-                input=sharded_input_embeddings,
+            req = all_to_all_single_comm(
+                output_tensor=sharded_output_embeddings,
+                input_tensor=sharded_input_embeddings_registered,
                 output_split_sizes=output_split_sizes,
                 input_split_sizes=input_split_sizes,
                 group=pg,
@@ -1466,7 +1628,7 @@ def forward(
     @staticmethod
     # pyre-fixme[2]: Parameter must be annotated.
     # pyre-fixme[2]: Parameter must be annotated.
-    def backward(ctx, *unused) -> Tuple[None, None, None, Tensor]:
+    def backward(ctx, *unused) -> Tuple[None, None, None, Tensor, None]:
         myreq = ctx.myreq
         a2ai = myreq.a2ai
         assert myreq.req is not None
@@ -1486,7 +1648,7 @@ def backward(ctx, *unused) -> Tuple[None, None, None, Tensor]:
             grad_input.div_(dist.get_world_size(ctx.pg))
         myreq.tensor = None
         myreq.dummy_tensor = None
-        return (None, None, None, grad_input)
+        return (None, None, None, grad_input, None)
 
 
 class Variable_Batch_All2All_Pooled_Wait(Function):
