# [RFC] Add async support to the StableHLO specification

Status: Under Review<br/>
Initial version: 09/17/2024<br/>
Last updated: 09/17/2024<br/>
Discussion thread: [GitHub](https://github.com/openxla/stablehlo/pull/...)

## Motivation

Today, StableHLO ops are designed to be executed sequentially, and any async dispatch or scheduling is left to the compiler to define.

However, getting XLA to generate optimized schedules has proven very challenging. Users have found that this leaves significant performance on the table, and they have asked for more control over scheduling.

There is an [excellent write-up](https://discuss.pytorch.org/t/distributed-w-torchtitan-introducing-async-tensor-parallelism-in-pytorch/209487)
by Yifu Wang that details the performance benefits of async tensor parallelism.

XLA already has async infrastructure that is used to create collective matmuls, so the main goal
is to expose it in StableHLO and make it accessible from JAX's `shard_map`.

## Proposed specification changes

### Types

```ebnf
AsyncType ::= 'async' '<' ValueType '>'
```

*Async types* represent tensor values that must be awaited before the underlying values can be used. Async operations
allow multiple operations to be in flight at once, as described in the Async Execution section.

Add `AsyncType` to `NonValueType`:

```ebnf
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType | AsyncType
```

### Ops

### async_start

#### Semantics

Produces the output of executing the `body` function, but runs all of its operations on a stream separate from the main compute stream.

The output of an `async_start` computation must first be processed by an `async_done` operation.

#### Inputs

| Label | Name      | Type                                                     | Constraints |
|-------|-----------|----------------------------------------------------------|-------------|
| (I1)  | `operand` | variadic number of tensors, quantized tensors, or tokens | (C1)        |
| (I2)  | `body`    | function                                                 | (C1)        |

#### Outputs

| Name      | Type        | Constraints |
|-----------|-------------|-------------|
| `results` | async value | (C1)        |

#### Constraints

* (C1) `body` has type `(T0, ..., TN-1) -> (R0, ..., RM-1)`, where
  `Ti = type(operand[i])`, and `results` has type `async<(R0, ..., RM-1)>`.

#### Examples

```mlir
// %init_i: 2
// %init_sum: 3
%future = "stablehlo.async_start"(
    %init_i as %arg0: tensor<i64>,
    %init_sum as %arg1: tensor<i64>)
{
  %new_sum = stablehlo.add %arg1, %arg0 : tensor<i64>
  stablehlo.return %new_sum : tensor<i64>
} : (tensor<i64>, tensor<i64>) -> async<tensor<i64>>
%result = "stablehlo.async_done"(%future) : async<tensor<i64>> -> tensor<i64>
// %result: 5
```

### async_done

#### Semantics

Waits for the values created by an `async_start` operation to be finalized. All values given to `async_done` must have type `async<T>`.

#### Inputs

| Label | Name      | Type        | Constraints |
|-------|-----------|-------------|-------------|
| (I1)  | `operand` | async value | (C1)        |

#### Outputs

| Name      | Type                                                     | Constraints |
|-----------|----------------------------------------------------------|-------------|
| `results` | variadic number of tensors, quantized tensors, or tokens | (C1)        |

#### Constraints

* (C1) `operand` has type `async<T>` and `results` has type `T`.

#### Examples

```mlir
// %init_i: 2
// %init_sum: 3
%future = "stablehlo.async_start"(
    %init_i as %arg0: tensor<i64>,
    %init_sum as %arg1: tensor<i64>)
{
  %new_sum = stablehlo.add %arg1, %arg0 : tensor<i64>
  stablehlo.return %new_sum : tensor<i64>
} : (tensor<i64>, tensor<i64>) -> async<tensor<i64>>
%result = "stablehlo.async_done"(%future) : async<tensor<i64>> -> tensor<i64>
// %result: 5
```

## Execution

### Async Execution

StableHLO programs are usually defined as sequential operations executed one after another,
with the exact ordering left implementation defined. The `async_start` and `async_done` ops let you
mark computations that should run in parallel with other operations. This can help you
better utilize your hardware or define your own communication scheduling.
Async operations are an advanced tool and should be used with care.
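
This execution model resembles ordinary futures. As a loose analogy only (Python's `concurrent.futures`, not StableHLO itself, and a worker thread standing in for a separate stream), `async_start` corresponds to submitting the body for concurrent execution, and `async_done` corresponds to waiting on its result:

```python
from concurrent.futures import ThreadPoolExecutor

# Loose analogy: the worker thread plays the role of the separate
# stream on which async_start runs the body.
executor = ThreadPoolExecutor(max_workers=1)

def body(i, acc):
    # Mirrors the stablehlo.add in the examples above.
    return acc + i

init_i, init_sum = 2, 3
future = executor.submit(body, init_i, init_sum)  # ~ async_start
# ... unrelated work may overlap with the body here ...
result = future.result()                          # ~ async_done
executor.shutdown()
print(result)  # 5
```

As in the MLIR examples, the main program may only observe the produced value after the await, while other work can overlap with the in-flight computation.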