 
 # pyre-strict
 
+from abc import ABC, abstractmethod
+from collections.abc import Sequence
 from contextlib import contextmanager
 from dataclasses import dataclass, field
-from typing import Any, List, Optional, Tuple, TypeVar
+from typing import Any, List, Optional, Tuple, TypeVar, Union
 
 import torch
 import torch.distributed as dist
 import torch.distributed._functional_collectives
-
 from torch import Tensor
 from torch.autograd import Function
 from torch.autograd.profiler import record_function
@@ -80,6 +81,45 @@ def torchrec_use_sync_collectives():
     """
 
 
+class Comm(ABC):
+    """
+    Interface for communication primitives.
+    A primitive primarily needs to handle 3 tasks:
+    1. **Memory allocation strategy**:
+        - Allocate a temporary buffer per call (flexible, simple)
+        - Reuse a persistent buffer (efficient, complex)
+    2. **Memory location**:
+        - NCCL memory pool
+        - Regular CUDA caching allocator
+    3. **Communication execution**:
+        - Actual communication primitive implementation
+    See `All2AllSingle` for a concrete example.
+    """
+
+    @abstractmethod
+    def allocate(
+        self,
+        size: Sequence[Union[int, torch.SymInt]],
+        *,
+        dtype: torch.dtype,
+        device: torch.device,
+    ) -> torch.Tensor:
+        """
+        Allocate memory for communication buffers.
+        Args:
+            size (Sequence[Union[int, torch.SymInt]]): size of the tensor buffer
+            dtype (torch.dtype): dtype of the tensor buffer
+            device (torch.device): which device to allocate the tensor onto
+        Returns:
+            torch.Tensor: Allocated tensor buffer
+        Example:
+            ```python
+            tensor = comm.allocate([1024, 512], dtype=torch.float32, device="cuda:0")
+            ```
+        """
+        ...
+
+
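The `Comm` contract above deliberately leaves the allocation strategy open. As a hedged illustration of option 1b ("reuse a persistent buffer"), the sketch below serves every `allocate` call from one preallocated byte buffer. The class name, fixed capacity, and the lack of offset/lifetime bookkeeping are assumptions for illustration only; they are not part of this change.

```python
from collections.abc import Sequence
from typing import Union

import torch


class PersistentBufferComm(Comm):  # `Comm` as introduced in the hunk above
    """Hypothetical allocator that serves every request from one persistent byte buffer."""

    def __init__(self, capacity_bytes: int, device: torch.device) -> None:
        # One persistent backing buffer; a real implementation would track
        # offsets, alignment, and lifetimes instead of always reusing offset 0.
        self._buf = torch.empty(capacity_bytes, dtype=torch.uint8, device=device)

    def allocate(
        self,
        size: Sequence[Union[int, torch.SymInt]],
        *,
        dtype: torch.dtype,
        device: torch.device,
    ) -> torch.Tensor:
        assert torch.device(device) == self._buf.device, "buffer lives on a fixed device"
        dims = [int(s) for s in size]
        numel = 1
        for d in dims:
            numel *= d
        nbytes = numel * torch.empty((), dtype=dtype).element_size()
        assert nbytes <= self._buf.numel(), "persistent buffer too small"
        # Reinterpret the first nbytes of the byte buffer as the requested tensor.
        return self._buf[:nbytes].view(dtype).view(dims)
```

The default path in this diff takes the simpler route instead: `DefaultAllocMixin` (further down) just calls `torch.empty` through the regular CUDA caching allocator.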
 class Request(Awaitable[W]):
     """
     Defines a collective operation request for a process group on a tensor.
@@ -334,6 +374,83 @@ def _get_split_lengths_by_len(
     return (my_len, splits)
 
 
+class All2AllSingle(Comm):
+    """
+    Interface for all-to-all single tensor communication.
+    This primitive performs an all-to-all scatter/gather operation where each
+    rank sends and receives data from all other ranks in a single operation.
+    """
+
+    @abstractmethod
+    def __call__(
+        self,
+        output_tensor: torch.Tensor,
+        input_tensor: torch.Tensor,
+        output_split_sizes: Optional[list[int]] = None,
+        input_split_sizes: Optional[list[int]] = None,
+        group: Optional[dist.ProcessGroup] = None,
+        async_op: bool = False,
+    ) -> Optional[dist.Work]:
+        """
+        Execute all-to-all single operation.
+        Args:
+            output_tensor: Pre-allocated output tensor
+            input_tensor: Input tensor to scatter
+            output_split_sizes: Sizes for splitting the output (if None, split evenly)
+            input_split_sizes: Sizes for splitting the input (if None, split evenly)
+            group: Process group (if None, uses default group)
+            async_op: Whether to perform asynchronously
+        Returns:
+            Optional work handle if async_op=True, None otherwise
+        """
+        ...
+
+
+class DefaultAllocMixin:
+    """
+    Mixin providing default tensor allocation using PyTorch's standard allocator.
+    """
+
+    def allocate(
+        self,
+        size: Sequence[Union[int, torch.SymInt]],
+        *,
+        dtype: torch.dtype,
+        device: torch.device,
+    ) -> torch.Tensor:
+        """Allocate tensor using torch.empty with the standard CUDA caching allocator."""
+        return torch.empty(size, dtype=dtype, device=device)
+
+
+class DefaultAll2AllSingle(DefaultAllocMixin, All2AllSingle):
+    """
+    Default implementation of all-to-all single using PyTorch's distributed primitives.
+
+    This implementation uses:
+    - Standard CUDA caching allocator for memory allocation
+    - PyTorch's dist.all_to_all_single for communication
+    """
+
+    def __call__(
+        self,
+        output_tensor: torch.Tensor,
+        input_tensor: torch.Tensor,
+        output_split_sizes: Optional[list[int]] = None,
+        input_split_sizes: Optional[list[int]] = None,
+        group: Optional[dist.ProcessGroup] = None,
+        async_op: bool = False,
+    ) -> Optional[dist.Work]:
+        """Execute all-to-all using PyTorch's native implementation."""
+        return dist.all_to_all_single(
+            output_tensor,
+            input_tensor,
+            output_split_sizes,
+            input_split_sizes,
+            group=group,
+            async_op=async_op,
+        )
+
+
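To show how these pieces compose, here is a hedged usage sketch: a custom `All2AllSingle` that keeps the default allocation mixin but wraps the collective, and is then injected through the new `all_to_all_single_comm` argument. Only `DefaultAllocMixin`, `All2AllSingle`, and that keyword argument come from this diff; `LoggingAll2AllSingle` and the commented-out call are illustrative assumptions.

```python
from typing import Optional

import torch
import torch.distributed as dist


class LoggingAll2AllSingle(DefaultAllocMixin, All2AllSingle):
    """Hypothetical comm: default allocation, default collective, plus a log line."""

    def __call__(
        self,
        output_tensor: torch.Tensor,
        input_tensor: torch.Tensor,
        output_split_sizes: Optional[list[int]] = None,
        input_split_sizes: Optional[list[int]] = None,
        group: Optional[dist.ProcessGroup] = None,
        async_op: bool = False,
    ) -> Optional[dist.Work]:
        print(f"all_to_all_single: {input_tensor.numel()} elements, async={async_op}")
        return dist.all_to_all_single(
            output_tensor,
            input_tensor,
            output_split_sizes,
            input_split_sizes,
            group=group,
            async_op=async_op,
        )


# Inside an initialized process group (e.g. NCCL), it would be injected like this:
# awaitable = alltoall_pooled(
#     pooled_embs,
#     batch_size_per_rank,
#     dim_sum_per_rank,
#     all_to_all_single_comm=LoggingAll2AllSingle(),
# )
# output = awaitable.wait()
```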
 def alltoall_pooled(
     a2a_pooled_embs_tensor: Tensor,
     batch_size_per_rank: List[int],
@@ -342,12 +459,14 @@ def alltoall_pooled(
     cumsum_dim_sum_per_rank_tensor: Optional[Tensor] = None,
     group: Optional[dist.ProcessGroup] = None,
     codecs: Optional[QuantizedCommCodecs] = None,
+    all_to_all_single_comm: Optional[All2AllSingle] = None,
 ) -> Awaitable[Tensor]:
     """
-    Performs AlltoAll operation for a single pooled embedding tensor. Each process
-    splits the input pooled embeddings tensor based on the world size, and then scatters
-    the split list to all processes in the group. Then concatenates the received tensors
-    from all processes in the group and returns a single output tensor.
+    Performs AlltoAll operation for a single pooled embedding tensor.
+
+    Each process splits the input pooled embeddings tensor based on the world size,
+    then scatters the split list to all processes in the group, and finally concatenates
+    the received tensors from all processes in the group into a single output tensor.
 
     Args:
         a2a_pooled_embs_tensor (Tensor): input pooled embeddings. Must be pooled
@@ -367,7 +486,19 @@ def alltoall_pooled(
         codecs (Optional[QuantizedCommCodecs]): quantized communication codecs.
 
     Returns:
-        Awaitable[Tensor]: async work handle (`Awaitable`), which can be `wait()` later to get the resulting tensor.
+        Awaitable[Tensor]: Async work handle which can be `wait()`-ed later to
+        get the resulting tensor.
+
+    Example:
+        ```python
+        # Using the default implementation
+        result = alltoall_pooled(embeddings, batch_sizes, dim_sums)
+        output = result.wait()
+        # Using a custom implementation
+        custom_comm = MyCustomAll2All()
+        result = alltoall_pooled(embeddings, batch_sizes, dim_sums,
+                                 all_to_all_single_comm=custom_comm)
+        ```
 
     .. warning::
         `alltoall_pooled` is experimental and subject to change.
@@ -391,7 +522,15 @@ def alltoall_pooled(
         return NoWait(all2all_pooled_sync(group, a2ai, a2a_pooled_embs_tensor))
 
     myreq = Request(group, device=a2a_pooled_embs_tensor.device)
-    All2All_Pooled_Req.apply(group, myreq, a2ai, a2a_pooled_embs_tensor)
+    if all_to_all_single_comm is None:
+        all_to_all_single_comm = DefaultAll2AllSingle()
+    All2All_Pooled_Req.apply(
+        group,
+        myreq,
+        a2ai,
+        a2a_pooled_embs_tensor,
+        all_to_all_single_comm,
+    )
     return myreq
 
 
@@ -476,6 +615,7 @@ def variable_batch_alltoall_pooled(
     emb_dim_per_rank_per_feature: List[List[int]],
     group: Optional[dist.ProcessGroup] = None,
     codecs: Optional[QuantizedCommCodecs] = None,
+    all_to_all_single_comm: Optional[All2AllSingle] = None,
 ) -> Awaitable[Tensor]:
 
     if group is None:
@@ -497,7 +637,11 @@ def variable_batch_alltoall_pooled(
         )
 
     myreq = Request(group, device=a2a_pooled_embs_tensor.device)
-    Variable_Batch_All2All_Pooled_Req.apply(group, myreq, a2ai, a2a_pooled_embs_tensor)
+    if all_to_all_single_comm is None:
+        all_to_all_single_comm = DefaultAll2AllSingle()
+    Variable_Batch_All2All_Pooled_Req.apply(
+        group, myreq, a2ai, a2a_pooled_embs_tensor, all_to_all_single_comm
+    )
     return myreq
 
 
@@ -1138,6 +1282,7 @@ def forward(
         myreq: Request[Tensor],
         a2ai: All2AllPooledInfo,
         input_embeddings: Tensor,
+        all_to_all_single_comm: All2AllSingle,
     ) -> Tensor:
         my_rank = dist.get_rank(pg)
         (B_global, D_local_sum) = input_embeddings.shape
@@ -1191,16 +1336,24 @@ def forward(
             input_split_sizes = [D_local_sum * B_rank for B_rank in batch_size_per_rank]
             qcomm_ctx = None
 
-        sharded_output_embeddings = torch.empty(
-            sum(output_split_sizes),
+        sharded_output_embeddings = all_to_all_single_comm.allocate(
+            (sum(output_split_sizes),),
+            dtype=sharded_input_embeddings.dtype,
+            device=sharded_input_embeddings.device,
+        )
+
+        sharded_input_embeddings_registered = all_to_all_single_comm.allocate(
+            sharded_input_embeddings.shape,
             dtype=sharded_input_embeddings.dtype,
             device=sharded_input_embeddings.device,
         )
 
+        sharded_input_embeddings_registered.copy_(sharded_input_embeddings)
+
         with record_function("## All2All_Pooled_fwd ##"):
-            req = dist.all_to_all_single(
-                output=sharded_output_embeddings,
-                input=sharded_input_embeddings,
+            req = all_to_all_single_comm(
+                output_tensor=sharded_output_embeddings,
+                input_tensor=sharded_input_embeddings_registered,
                 output_split_sizes=output_split_sizes,
                 input_split_sizes=input_split_sizes,
                 group=pg,
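The forward hunk above (and the matching variable-batch hunk later) swaps `torch.empty` / `dist.all_to_all_single` for the injected comm's `allocate` and `__call__`, and stages the input through an extra comm-allocated buffer via `copy_`. That extra copy is what lets a non-default comm decide where the collective reads from and writes to (for example, allocator-managed or registered memory). A hedged standalone sketch of the pattern, with illustrative names only:

```python
from typing import Optional

import torch
import torch.distributed as dist


def staged_all_to_all(
    comm,  # any All2AllSingle implementation from this diff
    inp: torch.Tensor,
    output_split_sizes: list[int],
    input_split_sizes: list[int],
    group: Optional[dist.ProcessGroup] = None,
) -> tuple[torch.Tensor, Optional[dist.Work]]:
    # Output buffer comes from the comm's allocator.
    out = comm.allocate((sum(output_split_sizes),), dtype=inp.dtype, device=inp.device)
    # Stage the input into a comm-owned buffer before communicating.
    staged = comm.allocate(inp.shape, dtype=inp.dtype, device=inp.device)
    staged.copy_(inp)
    # Run the collective on comm-owned tensors and return the async handle.
    work = comm(out, staged, output_split_sizes, input_split_sizes, group=group, async_op=True)
    return out, work
```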
@@ -1218,7 +1371,7 @@ def forward(
 
     @staticmethod
     # pyre-fixme[2]: Parameter must be annotated.
-    def backward(ctx, *unused) -> Tuple[None, None, None, Tensor]:
+    def backward(ctx, *unused) -> Tuple[None, None, None, Tensor, None]:
         pg = ctx.pg
         my_rank = dist.get_rank(pg)
         myreq = ctx.myreq
@@ -1241,7 +1394,7 @@ def backward(ctx, *unused) -> Tuple[None, None, None, Tensor]:
         grad_input.div_(dist.get_world_size(ctx.pg))
         myreq.tensor = None
         myreq.dummy_tensor = None
-        return (None, None, None, grad_input)
+        return (None, None, None, grad_input, None)
 
 
 class All2All_Pooled_Wait(Function):
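The trailing `None` added to both `backward` signatures in this change follows from the `torch.autograd.Function` contract: `backward` must return one gradient slot per positional argument of `forward`, and the new `all_to_all_single_comm` argument is a non-tensor, so its slot is `None`. A minimal self-contained illustration (not from this change; names are hypothetical):

```python
import torch
from torch.autograd import Function


class ScaleWithHelper(Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, factor: float, helper: object) -> torch.Tensor:
        # `helper` plays the role of a non-tensor argument such as a comm object.
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        # One entry per forward() input: grad for x, None for factor, None for helper.
        return grad_out * ctx.factor, None, None


x = torch.ones(3, requires_grad=True)
ScaleWithHelper.apply(x, 2.0, object()).sum().backward()
assert torch.equal(x.grad, torch.full((3,), 2.0))
```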
@@ -1385,6 +1538,7 @@ def forward(
         myreq: Request[Tensor],
         a2ai: VariableBatchAll2AllPooledInfo,
         input_embeddings: Tensor,
+        all_to_all_single_comm: All2AllSingle,
     ) -> Tensor:
         my_rank = dist.get_rank(pg)
 
@@ -1438,16 +1592,24 @@ def forward(
                 for split in input_split_sizes
             ]
 
-        sharded_output_embeddings = torch.empty(
-            sum(output_split_sizes),
+        sharded_output_embeddings = all_to_all_single_comm.allocate(
+            (sum(output_split_sizes),),
+            dtype=sharded_input_embeddings.dtype,
+            device=sharded_input_embeddings.device,
+        )
+
+        sharded_input_embeddings_registered = all_to_all_single_comm.allocate(
+            sharded_input_embeddings.shape,
             dtype=sharded_input_embeddings.dtype,
             device=sharded_input_embeddings.device,
         )
 
+        sharded_input_embeddings_registered.copy_(sharded_input_embeddings)
+
         with record_function("## Variable_Batch_All2All_Pooled_fwd ##"):
-            req = dist.all_to_all_single(
-                output=sharded_output_embeddings,
-                input=sharded_input_embeddings,
+            req = all_to_all_single_comm(
+                output_tensor=sharded_output_embeddings,
+                input_tensor=sharded_input_embeddings_registered,
                 output_split_sizes=output_split_sizes,
                 input_split_sizes=input_split_sizes,
                 group=pg,
@@ -1466,7 +1628,7 @@ def forward(
 
     @staticmethod
     # pyre-fixme[2]: Parameter must be annotated.
-    def backward(ctx, *unused) -> Tuple[None, None, None, Tensor]:
+    def backward(ctx, *unused) -> Tuple[None, None, None, Tensor, None]:
         myreq = ctx.myreq
         a2ai = myreq.a2ai
         assert myreq.req is not None
@@ -1486,7 +1648,7 @@ def backward(ctx, *unused) -> Tuple[None, None, None, Tensor]:
         grad_input.div_(dist.get_world_size(ctx.pg))
         myreq.tensor = None
         myreq.dummy_tensor = None
-        return (None, None, None, grad_input)
+        return (None, None, None, grad_input, None)
 
 
 class Variable_Batch_All2All_Pooled_Wait(Function):