[RFC] Add collective_broadcast to the StableHLO specification #1809

Merged
87 changes: 87 additions & 0 deletions rfcs/20231017-collective-broadcast.md
# [RFC] Add collective_broadcast to the StableHLO specification

Status: Review<br/>
Initial version: 10/17/2023<br/>
Last updated: 11/1/2023<br/>
Discussion thread: [GitHub](https://github.com/openxla/stablehlo/pull/1809)

## Motivation

StableHLO currently has [five collective communication primitives](https://github.com/openxla/stablehlo/blob/main/docs/spec.md#collective-ops): `collective_permute`, `all_gather`, `all_to_all`, `all_reduce`, and `reduce_scatter`. However, one of the major collective communication primitives, `broadcast`, is missing from this list. This primitive efficiently replicates a tensor from one device to many devices. `broadcast` is a primitive in [MPI](https://www.open-mpi.org/doc/v4.1/man3/MPI_Bcast.3.php), [NCCL](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/colls.html#c.ncclBroadcast), and [PyTorch](https://pytorch.org/docs/stable/distributed.html#torch.distributed.broadcast). From here on, we refer to this operation as `collective_broadcast`, for reasons discussed later.

While it is technically possible to emulate a broadcast with a conditional mask and a `psum`, that lowers to an `all_reduce` communication primitive, which is significantly more expensive than a simple `collective_broadcast`. Additionally, in network-switch environments, the explicit use of `collective_broadcast` allows the switch to greatly optimize its throughput when replicating to many targets simultaneously. However, XLA currently has no ability to lower directly to a mesh's `collective_broadcast` primitive, so a lot of that optimization is left on the table.
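
For illustration, the mask-and-`psum` workaround can be simulated in plain Python. This is a minimal sketch, not XLA or JAX code: the helper names `psum` and `broadcast_via_masked_psum` are hypothetical, and each device's tensor is modeled as a flat list.

```python
def psum(per_device_values):
    """Simulated all-reduce: every device receives the elementwise sum."""
    total = [sum(column) for column in zip(*per_device_values)]
    return [list(total) for _ in per_device_values]


def broadcast_via_masked_psum(per_device_values, source):
    """Emulate a broadcast: zero every non-source contribution, then sum."""
    masked = [
        list(value) if device == source else [0] * len(value)
        for device, value in enumerate(per_device_values)
    ]
    return psum(masked)


# Four simulated devices, each holding a two-element tensor.
values = [[1, 2], [3, 4], [5, 6], [7, 8]]
emulated = broadcast_via_masked_psum(values, source=2)
print(emulated)  # [[5, 6], [5, 6], [5, 6], [5, 6]]
```

Every device still takes part in a full reduction even though only one contribution is non-zero, which is the overhead a dedicated `collective_broadcast` avoids.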

Additionally, a new compiler pass that detects usage of the old `psum` hack and replaces it with a `collective_broadcast` could be implemented only once and forever be supported by all hardware, future and current. This could have positive knock-on effects for users who don't even realize they're using it!

`collective_broadcast` can be used to quickly replicate a tensor across an entire mesh, and would use fewer communication resources than `all_gather` or `psum`. `collective_broadcast` is also the base primitive used in the [SUMMA](https://www.netlib.org/lapack/lawnspdf/lawn96.pdf) distributed GEMM algorithm. As AI computing grows larger, there will likely be a growing need for these 2D distributed GEMM algorithms. Adding support for one of the needed primitives could help advance research in these areas.
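
To show how SUMMA leans on broadcasts, here is a minimal single-process simulation. It assumes a square `p x p` process grid in which process `(i, j)` owns a single element of each matrix (block size 1); the helper names `broadcast_within` and `summa` are hypothetical and not part of any library.

```python
def broadcast_within(values, source):
    """Simulated broadcast inside one grid row or column."""
    return [values[source] for _ in values]


def summa(a, b):
    """Multiply two p x p matrices; process (i, j) owns a[i][j] and b[i][j]."""
    p = len(a)
    c = [[0] * p for _ in range(p)]
    for k in range(p):
        # In each grid row i, the process in column k broadcasts a[i][k].
        a_recv = [broadcast_within(a[i], source=k) for i in range(p)]
        # In each grid column j, the process in row k broadcasts b[k][j].
        b_recv = [
            broadcast_within([b[i][j] for i in range(p)], source=k)
            for j in range(p)
        ]
        # Each process accumulates the product of the two values it received.
        for i in range(p):
            for j in range(p):
                c[i][j] += a_recv[i][j] * b_recv[j][i]
    return c


product = summa([[1, 2], [3, 4]], [[5, 6], [7, 8]])
print(product)  # [[19, 22], [43, 50]]
```

Each of the `p` steps consists only of row and column broadcasts followed by a local multiply-accumulate, so the cost of the broadcast primitive dominates the communication.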

## Alternatives considered

Instead of adding `collective_broadcast` as a primitive, we considered loosening the restriction of `collective_permute` to allow a one-to-many communication schedule instead of the current one-to-one schedule. Downstream compilers would then be responsible for detecting this pattern and calling their own `collective_broadcast` primitive. However, loosening this restriction makes defining the transposition rule for `collective_permute` significantly more complicated: it became difficult to say how to compute the transpose, and how to do so efficiently, for an arbitrary communication configuration under SPMD. By contrast, the transposition rule for `collective_broadcast` is just `psum` with a source-device one-hot masking. This simplicity, plus the broad usage of `collective_broadcast` in the wider ecosystem, led us to add the new primitive instead.
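
The transposition rule can be checked numerically with a small simulation. This sketch assumes a single process group covering all devices and one scalar per device; `broadcast_forward`, `broadcast_transpose`, and `dot` are hypothetical helper names. It tests the defining property of a transpose, namely that `<B x, y> == <x, B^T y>`.

```python
def broadcast_forward(values, source):
    """Forward op: every device receives the source device's value."""
    return [values[source] for _ in values]


def broadcast_transpose(cotangents, source):
    """Transpose: sum all cotangents (psum), keep the sum only on the source."""
    total = sum(cotangents)
    return [total if device == source else 0 for device in range(len(cotangents))]


def dot(left, right):
    return sum(l * r for l, r in zip(left, right))


x = [1, 2, 3, 4]  # one scalar per device
y = [5, 6, 7, 8]  # one cotangent per device
source = 2

lhs = dot(broadcast_forward(x, source), y)
rhs = dot(x, broadcast_transpose(y, source))
print(lhs, rhs)  # 78 78
```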

## Why call it collective_broadcast and not just broadcast?
Unfortunately, the op name `broadcast` is already taken by [an op in XLA proper](https://www.tensorflow.org/xla/operation_semantics#broadcast), so we can't have the two names clash. `collective_broadcast` was the preferred alternative.

## Proposed Specification

### collective_broadcast

#### Semantics

Within each process group in the StableHLO process grid, send the value of the
`operand` tensor from the source process to the target processes and produce a
`result` tensor.

The operation splits the StableHLO process grid into `process_groups` which is
defined as follows:

* `cross_replica(replica_groups)` if `channel_id <= 0`.
* `cross_partition(replica_groups)` if `channel_id > 0`.

Afterwards, `result@process` is given by:

* `operand@process_groups[i, 0]` if there exists an `i` such that
the process is in `process_groups[i]`.
* `broadcast_in_dim(constant(0, element_type(result)), [], type(result))` otherwise.
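
As a rough illustration of these semantics, the following sketch simulates `result@process` in plain Python. It is not the StableHLO reference interpreter: process ids are simplified to integers (one partition), tensors are nested lists, and the names `zeros_like` and `collective_broadcast` are used here only for the simulation.

```python
import copy


def zeros_like(tensor):
    """Build a nested list of zeros with the same shape as `tensor`."""
    if isinstance(tensor, list):
        return [zeros_like(item) for item in tensor]
    return 0


def collective_broadcast(operands, process_groups):
    """operands: {process id: tensor}; process_groups: list of id lists."""
    results = {}
    for process, operand in operands.items():
        group = next((g for g in process_groups if process in g), None)
        if group is None:
            # The process is in no group: it produces a zero tensor.
            results[process] = zeros_like(operand)
        else:
            # The first entry of the group is the source process.
            results[process] = copy.deepcopy(operands[group[0]])
    return results


operands = {0: [[1, 2]], 1: [[3, 4]], 2: [[5, 6]], 3: [[7, 8]]}
result = collective_broadcast(operands, process_groups=[[2, 1]])
print(result)  # {0: [[0, 0]], 1: [[5, 6]], 2: [[5, 6]], 3: [[0, 0]]}
```

Note that the source is determined by position in the group, not by the smallest id: with the group `[2, 1]`, process 2 is the source.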


#### Inputs

| Label | Name | Type | Constraints |
|-------|-------------------------|------------------------------------------------------------------|-------------|
| (I1) | `operand` | tensor | (C3) |
| (I2) | `replica_groups` | variadic number of 1-dimensional tensor constants of type `si64` | (C1), (C2) |
| (I3) | `channel_id` | constant of type `si64` | |

#### Outputs

| Name | Type | Constraints |
|----------|--------|-------------|
| `result` | tensor | (C3) |

#### Constraints

* (C1) `is_unique(replica_groups)`.
* (C2) `0 <= replica_groups < N` where `N` is defined as:
  * `num_replicas` if `cross_replica` is used.
  * `num_partitions` if `cross_partition` is used.
* (C3) `type(result) = type(operand)`.
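
A checker for (C1) and (C2) might look like the sketch below. It assumes that `is_unique(replica_groups)` means no process id appears more than once across all groups, which is an interpretation rather than something this RFC spells out; the function name `violated_constraints` is hypothetical. (C3) is a type-level check and is not modeled here.

```python
def violated_constraints(replica_groups, n):
    """Return the labels of the constraints that `replica_groups` violates.

    `n` stands for num_replicas or num_partitions, depending on the mode.
    """
    flat = [process_id for group in replica_groups for process_id in group]
    violations = []
    if len(set(flat)) != len(flat):
        violations.append("C1")  # assumed meaning: an id is repeated
    if any(not 0 <= process_id < n for process_id in flat):
        violations.append("C2")  # an id falls outside [0, n)
    return violations


print(violated_constraints([[2, 1]], n=4))  # []
print(violated_constraints([[0, 0]], n=4))  # ['C1']
print(violated_constraints([[4, 1]], n=4))  # ['C2']
```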

#### Examples

```mlir
// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
  replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]
```