@@ -339,6 +339,185 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
339339 let hasCanonicalizer = 1;
340340}
341341
342+ def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [
343+ AllShapesMatch<["input", "result"]>,
344+ AllElementTypesMatch<["input", "result"]>
345+ ]> {
346+ let summary = "Broadcast over a device mesh.";
347+ let description = [{
348+ Broadcast the tensor on `root` to all devices in each respective group.
349+ The operation broadcasts along mesh axes `mesh_axes`.
350+ The `root` device specifies the in-group multi-index that is broadcast to
351+ all other devices in the group.
352+
353+ Example:
354+ ```
355+ mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
356+
357+ %1 = mesh.broadcast %0 on @mesh0
358+ mesh_axes = [0]
359+ root = [0]
360+ : (tensor<2xi8>) -> tensor<2xi8>
361+ ```
362+
363+ Input:
364+ ```
365+ +-------+-------+ | broadcast
366+ device (0, 0) -> | 1 2 | 3 4 | <- device (0, 1) | along axis 0
367+ +-------+-------+ ↓
368+ device (1, 0) -> | | | <- device (1, 1)
369+ +-------+-------+
370+ ```
371+
372+ Output:
373+ ```
374+ +-------+-------+
375+ device (0, 0) -> | 1 2 | 3 4 | <- device (0, 1)
376+ +-------+-------+
377+ device (1, 0) -> | 1 2 | 3 4 | <- device (1, 1)
378+ +-------+-------+
379+ ```
380+ }];
381+ let arguments = !con(commonArgs, (ins
382+ AnyRankedTensor:$input,
383+ DenseI64ArrayAttr:$root,
384+ Variadic<Index>:$root_dynamic
385+ ));
386+ let results = (outs
387+ AnyRankedTensor:$result
388+ );
389+ let assemblyFormat = [{
390+ $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
391+ `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
392+ attr-dict `:` functional-type(operands, results)
393+ }];
394+ }
395+
396+ def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
397+ AllRanksMatch<["input", "result"]>,
398+ AllElementTypesMatch<["input", "result"]>
399+ ]> {
400+ let summary = "Gather over a device mesh.";
401+ let description = [{
402+ Gathers on device `root` along the `gather_axis` tensor axis.
403+ `root` specifies the coordinates of a device along `mesh_axes`.
404+ It uniquely identifies the root device for each device group.
405+ The result tensor on non-root devices is undefined.
406+ Using it will result in undefined behavior.
407+
408+ Example:
409+ ```mlir
410+ mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
411+ ...
412+ %1 = mesh.gather %0 on @mesh0 mesh_axes = [1]
413+ gather_axis = 1 root = [1]
414+ : (tensor<2x2xi8>) -> tensor<2x4xi8>
415+ ```
416+ Input:
417+ ```
418+ gather tensor
419+ axis 1
420+ ------------>
421+ +-------+-------+
422+ device (0, 0) -> | 1 2 | 5 6 | <- device (0, 1)
423+ | 3 4 | 7 8 |
424+ +-------+-------+
425+ device (1, 0) -> | 9 10 | 13 14 | <- device (1, 1)
426+ | 11 12 | 15 16 |
427+ +-------+-------+
428+ ```
429+ Result:
430+ ```
431+ +-------------+
432+ | 1 2 5 6 | <- devices (0, 1)
433+ | 3 4 7 8 |
434+ +-------------+
435+ | 9 10 13 14 | <- devices (1, 1)
436+ | 11 12 15 16 |
437+ +-------------+
438+ ```
439+ Devices `(0, 0)` and `(1, 0)` have undefined result.
440+ }];
441+ let arguments = !con(commonArgs, (ins
442+ AnyNon0RankedTensor:$input,
443+ IndexAttr:$gather_axis,
444+ DenseI64ArrayAttr:$root,
445+ Variadic<Index>:$root_dynamic
446+ ));
447+ let results = (outs
448+ AnyNon0RankedTensor:$result
449+ );
450+ let assemblyFormat = [{
451+ $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
452+ `gather_axis` `=` $gather_axis
453+ `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
454+ attr-dict `:` functional-type(operands, results)
455+ }];
456+ }
457+
458+ def Mesh_RecvOp : Mesh_CollectiveCommunicationOpBase<"recv", [
459+ AllShapesMatch<["input", "result"]>,
460+ AllElementTypesMatch<["input", "result"]>
461+ ]> {
462+ let summary = "Send over a device mesh.";
463+ let description = [{
464+ Receive from a device within a device group.
465+ }];
466+ let arguments = !con(commonArgs, (ins
467+ AnyNon0RankedTensor:$input,
468+ OptionalAttr<DenseI64ArrayAttr>:$source,
469+ Variadic<Index>:$source_dynamic
470+ ));
471+ let results = (outs
472+ AnyRankedTensor:$result
473+ );
474+ let assemblyFormat = [{
475+ $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
476+ (`source` `=` custom<DynamicIndexList>($source_dynamic, $source)^)?
477+ attr-dict `:` functional-type(operands, results)
478+ }];
479+ }
480+
481+ def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
482+ AllShapesMatch<["input", "result"]>
483+ ]> {
484+ let summary = "Reduce over a device mesh.";
485+ let description = [{
486+ Reduces on device `root` within each device group.
487+ `root` specifies the coordinates of a device along `mesh_axes`.
488+ It uniquely identifies the root device within its device group.
489+ The accumulation element type is specified by the result type and
490+ it does not need to match the input element type.
491+ The input element is converted to the result element type before
492+ performing the reduction.
493+
494+ Attributes:
495+ `reduction`: Indicates the reduction method.
496+
497+ Example:
498+ ```
499+ %1 = mesh.reduce %0 on @mesh0 mesh_axes = [1, 0]
500+ reduction = <max> root = [2, 3]
501+ : (tensor<3x4xf32>) -> tensor<3x4xf64>
502+ ```
503+ }];
504+ let arguments = !con(commonArgs, (ins
505+ AnyRankedTensor:$input,
506+ DefaultValuedAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction,
507+ DenseI64ArrayAttr:$root,
508+ Variadic<Index>:$root_dynamic
509+ ));
510+ let results = (outs
511+ AnyRankedTensor:$result
512+ );
513+ let assemblyFormat = [{
514+ $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
515+ (`reduction` `=` $reduction^)?
516+ `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
517+ attr-dict `:` functional-type(operands, results)
518+ }];
519+ }
520+
342521def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", [
343522 SameOperandsAndResultRank]> {
344523 let summary = "Reduce-scatter over a device mesh.";
@@ -400,4 +579,154 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter",
400579 let hasCanonicalizer = 1;
401580}
402581
582+ def Mesh_ScatterOp : Mesh_CollectiveCommunicationOpBase<"scatter", [
583+ AllRanksMatch<["input", "result"]>,
584+ AllElementTypesMatch<["input", "result"]>
585+ ]> {
586+ let summary = "Scatter over a device mesh.";
587+ let description = [{
588+ For each device group split the input tensor on the `root` device along
589+ axis `scatter_axis` and scatter the parts across the group devices.
590+
591+ Example:
592+ ```
593+ mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
594+ %1 = mesh.scatter %0 on @mesh0 mesh_axes = [0]
595+ scatter_axis = 0
596+ root = [1]
597+ : (tensor<2x2xi8>) -> tensor<1x2xi8>
598+ ```
599+
600+ Input:
601+ ```
602+ device
603+ (0, 1)
604+ ↓
605+ +-------+-------+ | scatter tensor
606+ device (0, 0) -> | | | | axis 0
607+ | | | ↓
608+ +-------+-------+
609+ device (1, 0) -> | 1 2 | 5 6 |
610+ | 3 4 | 7 8 |
611+ +-------+-------+
612+ ↑
613+ device
614+ (1, 1)
615+ ```
616+
617+ Result:
618+ ```
619+ device
620+ (0, 1)
621+ ↓
622+ +-------+-------+
623+ device (0, 0) -> | 1 2 | 5 6 |
624+ +-------+-------+
625+ device (1, 0) -> | 3 4 | 7 8 |
626+ +-------+-------+
627+ ↑
628+ device
629+ (1, 1)
630+ ```
631+ }];
632+ let arguments = !con(commonArgs, (ins
633+ AnyNon0RankedTensor:$input,
634+ IndexAttr:$scatter_axis,
635+ DenseI64ArrayAttr:$root,
636+ Variadic<Index>:$root_dynamic
637+ ));
638+ let results = (outs
639+ AnyRankedTensor:$result
640+ );
641+ let assemblyFormat = [{
642+ $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
643+ `scatter_axis` `=` $scatter_axis
644+ `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
645+ attr-dict `:` functional-type(operands, results)
646+ }];
647+ }
648+
649+ def Mesh_SendOp : Mesh_CollectiveCommunicationOpBase<"send", [
650+ AllShapesMatch<["input", "result"]>,
651+ AllElementTypesMatch<["input", "result"]>
652+ ]> {
653+ let summary = "Send over a device mesh.";
654+ let description = [{
655+ Send from one device to another within a device group.
656+ }];
657+ let arguments = !con(commonArgs, (ins
658+ AnyNon0RankedTensor:$input,
659+ DenseI64ArrayAttr:$destination,
660+ Variadic<Index>:$destination_dynamic
661+ ));
662+ let results = (outs
663+ AnyRankedTensor:$result
664+ );
665+ let assemblyFormat = [{
666+ $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
667+ `destination` `=` custom<DynamicIndexList>($destination_dynamic, $destination)
668+ attr-dict `:` functional-type(operands, results)
669+ }];
670+ }
671+
672+ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
673+ SameOperandsAndResultElementType,
674+ SameOperandsAndResultShape
675+ ]> {
676+ let summary = "Sift over a device mesh.";
677+ let description = [{
678+ Within each device group shift along mesh axis `shift_axis` by an offset
679+ `offset`.
680+ The result on devices that do not have a corresponding source is undefined.
681+ `shift_axis` must be one of `mesh_axes`.
682+ If the `rotate` attribute is present,
683+ instead of a shift a rotation is done.
684+
685+ Example:
686+ ```
687+ mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
688+ %1 = mesh.shift on @mesh0 mesh_axes = [1]
689+ shift_axis = 1 offset = 2 rotate
690+ : tensor<2xi8> -> tensor<2xi8>
691+ ```
692+
693+ Input:
694+ ```
695+ mesh axis 1
696+ ----------->
697+
698+ +----+----+----+----+
699+ | 1 | 2 | 3 | 4 |
700+ +----+----+----+----+
701+ | 5 | 6 | 7 | 8 |
702+ +----+----+----+----+
703+ ```
704+
705+ Result:
706+ ```
707+ +----+----+----+----+
708+ | 3 | 4 | 1 | 2 |
709+ +----+----+----+----+
710+ | 7 | 8 | 5 | 6 |
711+ +----+----+----+----+
712+ ```
713+ }];
714+ let arguments = !con(commonArgs, (ins
715+ AnyNon0RankedTensor:$input,
716+ IndexAttr:$shift_axis,
717+ I64Attr:$offset,
718+ UnitAttr:$rotate
719+ ));
720+ let results = (outs
721+ AnyRankedTensor:$result
722+ );
723+ let assemblyFormat = [{
724+ $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
725+ `shift_axis` `=` $shift_axis
726+ `offset` `=` $offset
727+ (`rotate` $rotate^)?
728+ attr-dict `:` type($input) `->` type($result)
729+ }];
730+ }
731+
403732#endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD
0 commit comments