Integration of collective_broadcast into spec #1856

Merged 8 commits on Nov 30, 2023
Changes from 6 commits
72 changes: 67 additions & 5 deletions docs/spec.md
@@ -1647,6 +1647,68 @@ for this operation ([#560](https://github.com/openxla/stablehlo/issues/560)).

 [More Examples](../stablehlo/tests/interpret_clamp.mlir)

### collective_broadcast

#### Semantics

Within each process group in the StableHLO process grid, send the value of the
`operand` tensor from the source process to the target processes and produce a
`result` tensor.

The operation splits the StableHLO process grid into `process_groups` which is
defined as follows:

* `cross_replica(replica_groups)` if `channel_id <= 0`.
* `cross_partition(replica_groups)` if `channel_id > 0`.

Afterwards, `result@process` is given by:

* `operand@process_groups[i, 0]` if there exists an `i` such that the process is
  in `process_groups[i]`.
* `broadcast_in_dim(constant(0, element_type(result)), [], type(result))`
  otherwise.
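
The semantics above can be modelled in a few lines of plain Python. This is a minimal sketch, not the StableHLO interpreter: the helper names (`cross_replica`, `cross_partition`, `zeros_like`, `collective_broadcast`) and their signatures are written for this illustration, process ids are assumed to be `(replica_id, partition_id)` tuples, and tensors are represented as nested lists.

```python
def cross_replica(replica_groups, num_partitions):
    """Each group of replica ids forms one process group per partition."""
    return [
        [(replica_id, partition_id) for replica_id in group]
        for group in replica_groups
        for partition_id in range(num_partitions)
    ]


def cross_partition(partition_groups, num_replicas):
    """Each group of partition ids forms one process group per replica."""
    return [
        [(replica_id, partition_id) for partition_id in group]
        for group in partition_groups
        for replica_id in range(num_replicas)
    ]


def zeros_like(tensor):
    """Zero-filled nested list with the same shape as `tensor`."""
    if isinstance(tensor, list):
        return [zeros_like(item) for item in tensor]
    return 0


def collective_broadcast(operands, replica_groups, channel_id,
                         num_replicas, num_partitions):
    """Maps {process_id: operand} to {process_id: result}."""
    if channel_id <= 0:
        process_groups = cross_replica(replica_groups, num_partitions)
    else:
        process_groups = cross_partition(replica_groups, num_replicas)

    results = {}
    for process, operand in operands.items():
        # Processes outside every group receive a zero tensor.
        results[process] = zeros_like(operand)
        for group in process_groups:
            if process in group:
                # The first entry of the group is the source process.
                results[process] = operands[group[0]]
                break
    return results


# The 4-replica, 1-partition case from the Examples section of this op.
example = collective_broadcast(
    operands={(0, 0): [[1, 2]], (1, 0): [[3, 4]],
              (2, 0): [[5, 6]], (3, 0): [[7, 8]]},
    replica_groups=[[2, 1]],
    channel_id=0,
    num_replicas=4,
    num_partitions=1,
)
# example[(1, 0)] == example[(2, 0)] == [[5, 6]]; the other two are [[0, 0]].
```

Note that the source is whichever id is listed first in a group, so `[[2, 1]]` broadcasts from replica 2, not from the lowest id.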

#### Inputs

| Label | Name | Type | Constraints |
|-------|-------------------------|------------------------------------------------------------------|-------------|
| (I1) | `operand` | tensor | (C3) |
| (I2) | `replica_groups` | variadic number of 1-dimensional tensor constants of type `si64` | (C1), (C2) |
| (I3) | `channel_id` | constant of type `si64` | |

#### Outputs

| Name | Type | Constraints |
|----------|--------|-------------|
| `result` | tensor | (C3) |

#### Constraints

* (C1) `is_unique(replica_groups)`.
* (C2) `0 <= replica_groups < N` where `N` is defined as:
  * `num_replicas` if `cross_replica` is used.
  * `num_partitions` if `cross_partition` is used.
* (C3) `type(result) = type(operand)`.
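
As an illustration of how (C1) to (C3) interact, the sketch below checks all three for a concrete configuration. It is a hypothetical helper, not part of StableHLO: the function name, the use of plain strings for tensor types, and the returned list of labels are assumptions made for this example. Unlike the static verifier added in this PR, which can only reject negative ids, it takes `num_replicas` and `num_partitions` as inputs and so can check the upper bound in (C2) as well.

```python
def check_collective_broadcast(replica_groups, channel_id, num_replicas,
                               num_partitions, operand_type, result_type):
    """Returns the labels of the violated constraints (empty list if valid)."""
    violations = []
    ids = [i for group in replica_groups for i in group]

    # (C1) every id appears at most once across all groups.
    if len(set(ids)) != len(ids):
        violations.append("C1")

    # (C2) ids are bounded by the size of the axis that is being grouped:
    # replicas for cross_replica, partitions for cross_partition.
    n = num_replicas if channel_id <= 0 else num_partitions
    if any(i < 0 or i >= n for i in ids):
        violations.append("C2")

    # (C3) the result has exactly the operand's type.
    if result_type != operand_type:
        violations.append("C3")

    return violations
```

For instance, `replica_groups = [[0, 1]]` is valid with 4 replicas and `channel_id = 0`, but with `channel_id = 1` and a single partition the same groups violate (C2), because the ids are then interpreted as partition ids.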

#### Examples

```mlir
// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
  replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]
```

### collective_permute

#### Semantics
@@ -5949,11 +6011,11 @@ order and what kind of synchronization is introduced by it, is TBD

### Collective ops

-There are five collective ops in StableHLO: `all_gather`, `all_reduce`,
-`all_to_all`, `collective_permute` and `reduce_scatter`. All these ops split
-the processes in the StableHLO process grid into **StableHLO process groups**
-and execute a joint computation within each process group, independently from
-other process groups.
+There are six collective ops in StableHLO: `all_gather`, `all_reduce`,
+`all_to_all`, `collective_broadcast`, `collective_permute`, and
+`reduce_scatter`. All these ops split the processes in the StableHLO process
+grid into **StableHLO process groups** and execute a joint computation within
+each process group, independently from other process groups.

Within each process group, collective ops may introduce a synchronization
barrier. Further formalization, e.g. elaborating on when exactly this
5 changes: 3 additions & 2 deletions stablehlo/dialect/StablehloAttrs.td
@@ -109,8 +109,9 @@ def StableHLO_OutputOperandAlias : AttrDef<StableHLO_Dialect, "OutputOperandAlia
}

// Represents a unique identifier for each Send/Recv instruction pair or
-// optionally for collective instructions (AllReduce, CollectivePermute,
-// AllToAll). Non-positive channel_id handle is equivalent to no channel id.
+// optionally for collective instructions (AllToAll, AllReduce,
+// CollectiveBroadcast, and CollectivePermute). Non-positive channel_id
+// handle is equivalent to no channel id.
def StableHLO_ChannelHandle : AttrDef<StableHLO_Dialect, "ChannelHandle"> {
let cppNamespace = "::mlir::stablehlo";
let mnemonic = "channel_handle";
34 changes: 23 additions & 11 deletions stablehlo/dialect/StablehloOps.cpp
@@ -122,17 +122,6 @@ LogicalResult TypeExtensionsAttr::verifyEncoding(
getBounds(), RankedTensorType::get(shape, elementType), emitError);
}

-//===----------------------------------------------------------------------===//
-// CollectivePermuteOp
-//===----------------------------------------------------------------------===//
-
-void CollectivePermuteOp::build(OpBuilder& odsBuilder, OperationState& odsState,
-                                Type resultType, Value operand,
-                                DenseIntElementsAttr sourceTargetPairs) {
-  CollectivePermuteOp::build(odsBuilder, odsState, resultType, operand,
-                             sourceTargetPairs, /*channel_handle=*/nullptr);
-}
-
//===----------------------------------------------------------------------===//
// ReduceScatterOp
//===----------------------------------------------------------------------===//
@@ -171,6 +160,7 @@ INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(Atan2Op)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CbrtOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CeilOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ClzOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CollectiveBroadcastOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CollectivePermuteOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CosineOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CrossReplicaSumOp)
@@ -807,10 +797,32 @@ LogicalResult AbsOp::inferReturnTypes(
return hlo::inferAbsOp(location, adaptor.getOperand(), inferredReturnTypes);
}

//===----------------------------------------------------------------------===//
// CollectiveBroadcastOp
//===----------------------------------------------------------------------===//

void CollectiveBroadcastOp::build(OpBuilder& odsBuilder,
                                  OperationState& odsState, Type resultType,
                                  Value operand,
                                  DenseIntElementsAttr replicaGroups) {
  CollectiveBroadcastOp::build(odsBuilder, odsState, resultType, operand,
                               replicaGroups, /*channel_handle=*/nullptr);
}

LogicalResult CollectiveBroadcastOp::verify() {
return hlo::verifyCollectiveBroadcastOp(getLoc(), getReplicaGroups());
}

//===----------------------------------------------------------------------===//
// CollectivePermuteOp
//===----------------------------------------------------------------------===//

void CollectivePermuteOp::build(OpBuilder& odsBuilder, OperationState& odsState,
Type resultType, Value operand,
DenseIntElementsAttr sourceTargetPairs) {
CollectivePermuteOp::build(odsBuilder, odsState, resultType, operand,
sourceTargetPairs, /*channel_handle=*/nullptr);
}

LogicalResult CollectivePermuteOp::verify() {
return hlo::verifyCollectivePermuteOp(getLoc(), getSourceTargetPairs());
}
5 changes: 5 additions & 0 deletions stablehlo/dialect/StablehloOps.h
@@ -77,6 +77,11 @@ class TokenType : public Type::TypeBase<TokenType, Type, TypeStorage> {
using Base::Base;
};


// Verifies the attribute attached to collective broadcast.
LogicalResult verifyBroadcastSourceTargetPairs(
Operation *op, DenseIntElementsAttr attr);

// Verifies the source target pairs attached to collective permute.
LogicalResult verifyCollectivePermuteSourceTargetPairs(
Operation *op, DenseIntElementsAttr attr);
36 changes: 36 additions & 0 deletions stablehlo/dialect/StablehloOps.td
@@ -2020,6 +2020,42 @@ def StableHLO_ConcatenateOp : StableHLO_ShapedInterfaceOp<"concatenate",
}];
}


def StableHLO_CollectiveBroadcastOp: StableHLO_Op<"collective_broadcast",
[HLO_CompatibleOperandsAndResultType /*collective_broadcast_c3*/]> {
let summary = "CollectiveBroadcast operation";
let description = [{
Within each process group in the process grid, send the value of the
`operand` tensor from the source process to the target processes and produce a
`result` tensor.

See:
https://github.com/openxla/stablehlo/blob/main/docs/spec.md#collective_broadcast

Example:
```mlir
%result = "stablehlo.collective_broadcast"(%operand) {
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<1x2xi64>) -> tensor<1x2xi64>
```
}];

let arguments = (ins
HLO_Tensor:$operand, /*collective_broadcast_i1*/
I64ElementsAttr:$replica_groups, /*collective_broadcast_i2*/
OptionalAttr<StableHLO_ChannelHandle>:$channel_handle /*collective_broadcast_i3*/
);
let results = (outs HLO_Tensor);
let hasVerifier = 1;
// channel_handle is only used for the SPMD partitioner, so we add a
// simplified builder method for convenience.
let builders = [
OpBuilder<(ins
"::mlir::Type":$result_type, "::mlir::Value":$operand,
"::mlir::DenseIntElementsAttr":$replica_groups)>];
}

def StableHLO_CollectivePermuteOp: StableHLO_Op<"collective_permute",
[HLO_CompatibleOperandsAndResultType /*collective_permute_c5*/]> {
let summary = "CollectivePermute operation";
32 changes: 32 additions & 0 deletions stablehlo/dialect/TypeInference.cpp
@@ -3289,6 +3289,38 @@ LogicalResult verifyBroadcastInDimOp(std::optional<Location> location,
return success();
}


LogicalResult verifyCollectiveBroadcastOp(std::optional<Location> location,
                                          DenseIntElementsAttr replicaGroups) {
  // collective_broadcast_i2
  auto replicaGroupType = replicaGroups.getType().cast<RankedTensorType>();
  if (replicaGroupType.getRank() != 2)
    return emitOptionalError(location,
                             "replica groups should be a rank 2 tensor, ",
                             "but instead it is of rank ",
                             replicaGroupType.getRank());

  auto replicaIds = replicaGroups.getValues<int64_t>();
  llvm::SmallSet<int64_t, 8> replicaIdsSeen;
  for (int64_t replicaId : replicaIds) {
    // collective_broadcast_c2
    // We only check that it is not negative, as it is impossible
    // to statically know `num_replicas` or `num_partitions`.
    if (replicaId < 0)
      return emitOptionalError(
          location,
          "replica_groups values must be non-negative, but was given ",
          replicaId);

    // collective_broadcast_c1
    if (!replicaIdsSeen.insert(replicaId).second)
      return emitOptionalError(location, "replica id #", replicaId,
                               " seen more than once");
  }

  return success();
}


LogicalResult verifyCollectivePermuteOp(
std::optional<Location> location, DenseIntElementsAttr sourceTargetPairs) {
auto type = sourceTargetPairs.getType().dyn_cast<RankedTensorType>();
3 changes: 3 additions & 0 deletions stablehlo/dialect/TypeInference.h
@@ -388,6 +388,9 @@ LogicalResult verifyBroadcastInDimOp(std::optional<Location> location,
DenseIntElementsAttr broadcastDimensions,
Value result);

LogicalResult verifyCollectiveBroadcastOp(std::optional<Location> location,
DenseIntElementsAttr replicaGroups);

LogicalResult verifyCollectivePermuteOp(std::optional<Location> location,
DenseIntElementsAttr sourceTargetPairs);

2 changes: 1 addition & 1 deletion stablehlo/dialect/Version.h
@@ -38,7 +38,7 @@ class Version {
static FailureOr<Version> fromString(llvm::StringRef versionRef);

/// Return a Version representing the current VHLO dialect version.
-static Version getCurrentVersion() { return Version(0, 15, 5); }
+static Version getCurrentVersion() { return Version(0, 16, 0); }

/// Return a Version representing the minimum supported VHLO dialect version.
static Version getMinimumVersion() { return Version(0, 9, 0); }
1 change: 1 addition & 0 deletions stablehlo/dialect/VhloDialect.td
@@ -33,6 +33,7 @@ def VHLO_Dialect : Dialect {
0.12.0: MLIR bytecode version 1 => 3.
0.14.0: MLIR bytecode version 3 => 5 (revised to 4 in #1827).
0.15.0: MLIR bytecode version 5 => 6, use properties in VHLO.
0.16.0: Introduce `collective_broadcast` operation.
}];

let useDefaultAttributePrinterParser = 0;
9 changes: 9 additions & 0 deletions stablehlo/dialect/VhloOps.td
@@ -240,6 +240,15 @@ def VHLO_ClzOpV1 : VHLO_Op<"count_leading_zeros_v1", "0.9.0", "current"> {
let results = (outs VHLO_AnyType:$result);
}

def VHLO_CollectiveBroadcastOpV1 : VHLO_Op<"collective_broadcast_v1", "0.16.0", "current"> {
let arguments = (ins
VHLO_AnyType:$operand,
VHLO_AnyAttr:$replica_groups,
VHLO_AnyAttr:$channel_id
);
let results = (outs VHLO_AnyType:$result);
}
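
To make the effect of the `"0.16.0", "current"` pair concrete, here is a small Python sketch of the availability check it implies. It assumes the two strings bound the range of VHLO versions in which an op exists, and that a target version must lie between the minimum (0.9.0) and current (0.16.0) versions shown in `Version.h` in this PR. The helper names are hypothetical and do not correspond to functions in the StableHLO codebase.

```python
def parse_version(text):
    """Turns "0.16.0" into the comparable tuple (0, 16, 0)."""
    major, minor, patch = (int(part) for part in text.split("."))
    return (major, minor, patch)


# Values taken from Version.h as changed in this PR.
CURRENT_VERSION = parse_version("0.16.0")
MINIMUM_VERSION = parse_version("0.9.0")


def is_op_available(target, introduced, last="current"):
    """True if an op tagged <introduced, last> exists at version `target`."""
    target_version = parse_version(target)
    if not MINIMUM_VERSION <= target_version <= CURRENT_VERSION:
        return False  # the target itself is outside the supported window
    upper = CURRENT_VERSION if last == "current" else parse_version(last)
    return parse_version(introduced) <= target_version <= upper
```

Under these assumptions, `collective_broadcast_v1` is available when targeting 0.16.0 but not 0.15.5, whereas `collective_permute_v1` (tagged `"0.9.0", "current"`) is available across the whole supported window.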

def VHLO_CollectivePermuteOpV1 : VHLO_Op<"collective_permute_v1", "0.9.0", "current"> {
let arguments = (ins
VHLO_AnyType:$operand,