diff --git a/rfcs/20231017-collective-broadcast.md b/rfcs/20231017-collective-broadcast.md
new file mode 100644
index 00000000000..dd7b64e9743
--- /dev/null
+++ b/rfcs/20231017-collective-broadcast.md
@@ -0,0 +1,87 @@

# [RFC] Add collective_broadcast to the StableHLO specification

Status: Review
Initial version: 10/17/2023
Last updated: 11/1/2023
Discussion thread: [GitHub](https://github.com/openxla/stablehlo/pull/1809)

## Motivation

StableHLO currently has [five collective communication primitives](https://github.com/openxla/stablehlo/blob/main/docs/spec.md#collective-ops): `collective_permute`, `all_gather`, `all_to_all`, `all_reduce`, and `reduce_scatter`. However, one of the major collective communication primitives, `broadcast`, is missing from this list. This primitive efficiently replicates a tensor from one device to many devices. `broadcast` is a primitive in [MPI](https://www.open-mpi.org/doc/v4.1/man3/MPI_Bcast.3.php), [NCCL](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/colls.html#c.ncclBroadcast), and [PyTorch](https://pytorch.org/docs/stable/distributed.html#torch.distributed.broadcast). From here on, we refer to this operation as `collective_broadcast`, for reasons discussed below.

While it is technically possible to emulate a broadcast with a conditional mask and a `psum`, that approach reduces to an `all_reduce`, which is significantly more expensive than a plain `collective_broadcast` (a rough JAX sketch of this emulation appears just before the proposed specification below). Additionally, in network-switch environments, explicit use of `collective_broadcast` lets the switch greatly optimize its throughput when replicating to many targets simultaneously. However, XLA currently has no way to lower directly to a mesh's `collective_broadcast` primitive, so much of that optimization is left on the table.

Additionally, a compiler pass that detects uses of the old `psum` hack and rewrites them to `collective_broadcast` could be implemented once and supported by all current and future hardware. This could have positive knock-on effects for users who don't even realize they are using the hack.

`collective_broadcast` can be used to quickly replicate a tensor across an entire mesh, and it uses fewer communication resources than `all_gather` or `psum`. `collective_broadcast` is also the base primitive of the [SUMMA](https://www.netlib.org/lapack/lawnspdf/lawn96.pdf) distributed GEMM algorithm. As AI workloads grow larger, the need for such 2D distributed GEMM algorithms is likely to grow as well. Adding support for one of the required primitives could help advance research in this area.

## Alternatives considered

Instead of adding `collective_broadcast` as a primitive, we considered loosening the restriction on `collective_permute` to allow one-to-many communication schedules instead of the current one-to-one schedules. Downstream compilers would then be responsible for detecting this pattern and calling their own `collective_broadcast` primitive. However, loosening this restriction makes defining the transposition rule for `collective_permute` significantly more complicated: computing it efficiently for arbitrary communication configurations, and doing so in SPMD, quickly becomes difficult. By contrast, the transposition rule for `collective_broadcast` is just a `psum` with a one-hot mask on the source device. This simplicity, together with the broad use of `collective_broadcast` in the wider ecosystem, ultimately led us to add the new primitive instead.

## Why call it collective_broadcast and not just broadcast?

Unfortunately, the op name `broadcast` is already taken by [an op in XLA proper](https://www.tensorflow.org/xla/operation_semantics#broadcast), so we cannot have the two names clash. `collective_broadcast` was the preferred alternative.
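Before turning to the proposed specification, here is the mask-plus-`psum` emulation referred to in the motivation above, written as a minimal JAX sketch. It is illustrative only and not part of this RFC; the function name `broadcast_via_psum`, the axis name `"i"`, and the use of `jax.pmap` are assumptions made for the example.

```python
import jax
import jax.numpy as jnp

def broadcast_via_psum(x, source=0, axis_name="i"):
    # Every device except `source` zeros out its copy, then an all_reduce
    # sum (`psum`) leaves the source's value on every device. This is the
    # emulation that lowers to the more expensive all_reduce.
    mask = jnp.where(jax.lax.axis_index(axis_name) == source, 1, 0)
    return jax.lax.psum(x * mask, axis_name)

# Example usage across all local devices (assumes at least one device).
n = jax.local_device_count()
xs = jnp.arange(n * 2, dtype=jnp.float32).reshape(n, 2)
out = jax.pmap(broadcast_via_psum, axis_name="i")(xs)
# Every row of `out` now equals xs[0].
```

A dedicated `collective_broadcast` would let the compiler emit the one-to-many transfer directly instead of the full reduction above.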
## Proposed Specification

### collective_broadcast

#### Semantics

Within each process group in the StableHLO process grid, send the value of the
`operand` tensor from the source process to the target processes and produce a
`result` tensor.

The operation splits the StableHLO process grid into `process_groups`, which is
defined as follows:

* `cross_replica(replica_groups)` if `channel_id <= 0`.
* `cross_partition(replica_groups)` if `channel_id > 0`.

Afterwards, `result@process` is given by:

* `operand@process_groups[i, 0]` if there exists an `i` such that
  the process is in `process_groups[i]`.
* `broadcast_in_dim(constant(0, element_type(result)), [], type(result))` otherwise.

#### Inputs

| Label | Name             | Type                                                              | Constraints |
|-------|------------------|-------------------------------------------------------------------|-------------|
| (I1)  | `operand`        | tensor                                                            | (C3)        |
| (I2)  | `replica_groups` | variadic number of 1-dimensional tensor constants of type `si64`  | (C1), (C2)  |
| (I3)  | `channel_id`     | constant of type `si64`                                           |             |

#### Outputs

| Name     | Type   | Constraints |
|----------|--------|-------------|
| `result` | tensor | (C3)        |

#### Constraints

* (C1) `is_unique(replica_groups)`.
* (C2) `0 <= replica_groups < N` where `N` is defined as:
  * `num_replicas` if `cross_replica` is used.
  * `num_partitions` if `cross_partition` is used.
* (C3) `type(result) = type(operand)`.

#### Examples

```mlir
// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
  replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]
```
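As a reading aid, the following Python sketch walks through the semantics above on the same data as the MLIR example. The helper name `collective_broadcast_ref` and the plain-list representation of per-process operands are assumptions made for illustration; this is not an implementation of the op.

```python
import numpy as np

def collective_broadcast_ref(operands, process_groups):
    # `operands[p]` plays the role of operand@process p; `process_groups`
    # stands in for the result of cross_replica / cross_partition.
    results = []
    for p, operand in enumerate(operands):
        value = np.zeros_like(operand)  # processes outside every group get zeros
        for group in process_groups:
            if p in group:
                # The first process in the group is the source.
                value = operands[group[0]]
                break
        results.append(value)
    return results

# Mirrors the example above: four replicas, replica_groups = [[2, 1]].
operands = [np.array([[1, 2]]), np.array([[3, 4]]),
            np.array([[5, 6]]), np.array([[7, 8]])]
for r in collective_broadcast_ref(operands, [[2, 1]]):
    print(r)
# Prints [[0 0]], [[5 6]], [[5 6]], [[0 0]] for processes 0..3.
```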