Skip to content

Conversation

silee2
Copy link
Contributor

@silee2 silee2 commented Aug 21, 2025

This PR tightens some loose ends in some XeGPU op definitions.
Changes are backward compatible except for

  • Enforcing previous implicit assumption of load/store/prefetch offsets is required if source/dest is not a scatter tensor descriptor.
  • Likewise, enforce offsets is not allowed if source/dest is a scatter tensor descriptor.
  • Additionally, allow i64, i32 and ui32 as source/dest for load/store/prefetch. This matches behavior of tensor descriptor which allows i64, i32 and ui32 base address in addition to ui64
  • Explicitly state that create_tdesc and update_offset ops are not valid in SIMT mode. create_tdesc and update_offset ops are still available for subgroup level non SIMT mode.

New test cases are added for the new enforced checks.

Other minor implementation change:
XeGPU scatter tensor descriptor only allows 1D base memref. This was check in op verify() method. Now moved to tablegen - ODS - definition.

@llvmbot
Copy link
Member

llvmbot commented Aug 21, 2025

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Sang Ik Lee (silee2)

Changes

This PR tightens some loose ends in some XeGPU op definitions.
Changes are backward compatible except for

  • Enforcing previous implicit assumption of load/store/prefetch offsets is required if source/dest is not a scatter tensor descriptor.
  • Likewise, enforce offsets is not allowed if source/dest is a scatter tensor descriptor.
  • Additionally, allow i64, i32 and ui32 as source/dest for load/store/prefetch. This matches behavior of tensor descriptor which allows i64, i32 and ui32 base address in addition to ui64
  • Explicitly state that create_tdesc and update_offset ops are not valid in SIMT mode. create_tdesc and update_offset ops are still available for subgroup level non SIMT mode.

New test cases are added for the new enforced checks.

Other minor implementation change:
XeGPU scatter tensor descriptor only allows 1D base memref. This was check in op verify() method. Now moved to tablegen - ODS - definition.


Full diff: https://github.com/llvm/llvm-project/pull/154653.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td (+65-25)
  • (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td (+4-2)
  • (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (+18-23)
  • (modified) mlir/test/Dialect/XeGPU/invalid.mlir (+58-3)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index eb54d6887681d..8fd04a5d4cdcf 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -500,7 +500,8 @@ def XeGPU_CreateDescOp: XeGPU_Op<"create_tdesc", [Pure, ViewLikeOpInterface]> {
     (scattered) subviews, allowing each work-item in a subgroup specifying their own offset.
     It accepts the following parameters:
 
-    * source: a 1D memref or pointer (uint64_t) represents the flattened memory object.
+    * source: a 1D memref or pointer (i64, i32, ui64, ui32) represents the flattened
+      memory object.
     * offsets: a vector containing offsets of each access point. Its size
       is fixed to the hardware supportted subgroup size, e.g., 16 on PVC,
       implying each element in the vector corresponds to a work-item (SIMT lane)
@@ -510,6 +511,8 @@ def XeGPU_CreateDescOp: XeGPU_Op<"create_tdesc", [Pure, ViewLikeOpInterface]> {
     match the dimension of offsets. It may also has a second dimension corresponding to
     the chunk_size if the chunk size is larger than 1.
 
+    This op is not available in SIMT mode.
+
     Example 1: It assumes subgroup size is 4, and accesses a[0], a[16], a[32], a[64]
     ```mlir
     %a = memref.alloc() : memref<1024xf32>
@@ -536,7 +539,7 @@ def XeGPU_CreateDescOp: XeGPU_Op<"create_tdesc", [Pure, ViewLikeOpInterface]> {
     ```
   }];
 
-  let arguments = (ins XeGPU_BaseAddrType: $source,
+  let arguments = (ins XeGPU_GatherScatterBaseAddrType: $source,
                        XeGPU_OffsetType: $offsets);
   let results = (outs XeGPU_TensorDesc:$TensorDesc);
 
@@ -617,6 +620,15 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
         : memref<1024xf32>, vector<4xindex>
     ```
 
+    Example 3 (SIMT mode):
+    SIMT mode only accepts the offsets variant.
+    ```mlir
+      xegpu.prefetch %0[%1] {l1_hint = #xegpu.cache_hint<cached>,
+                             l2_hint = #xegpu.cache_hint<cached>,
+                             l3_hint = #xegpu.cache_hint<cached>}
+        : memref<256xf32>, vector<1xindex>
+    ```
+
   }];
 
   let arguments = (ins XeGPU_GatherScatterSourceType: $source,
@@ -670,8 +682,19 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
     The mask operand masks out memory access so that it is safe to pass out-of-boundary
     addresses/offsets as long as they are masked. It applies to slots of SIMD lanes.
 
-    In SIMT mode, the result vector represents the data to be loaded by each work-item.
-    Each work-item recieves a `chunk_size` number of elements.
+    In SIMT mode, the result is a 1D vector that represents the data to be loaded by
+    each work-item. If size is not 1, size should be equal to the chunk size,
+
+    `source` represents the memory region to be loaded from, which can be either a
+        tensor_desc or a 1D memref or pointer (ui64, ui32, i64 or i32).
+        In case of tensor_desc, offsets come from the producer create_tdesc op.
+        tensor_desc cannot be used in SIMT mode.
+    `offsets` represents offsets from source. required if `source` in not a TensorDescType.
+        offsets is a vector of `index` type and vector length is either the subgroup size
+        or 1 in SIMT mode.
+    `mask` is a vector of `i1` type, which is used to mask out the memory access.
+        mask is a vector of size equal to the subgroup size, or 1 in SIMT mode.
+    `chunk_size` (optional) represents contiguous number of elements to load from per work item.
 
   Example 1:
   ```mlir
@@ -691,16 +714,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
             vector<16xi1> -> vector<16x8xf32>
   ```
 
-  Example 3 (SIMT mode):
-  ```mlir
-    %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>,
-                             l2_hint = #xegpu.cache_hint<uncached>,
-                             l3_hint = #xegpu.cache_hint<uncached>}>
-          : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>
-            vector<16xi1> -> vector<8xf32>
-  ```
-
-  Example 4:
+  Example 3:
   A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
   It combines "create scattered TensorTdesc" and "load with scattered TensorTdesc".
   The source operand could be a raw pointer (uint64_t). Please refer to create_tdesc
@@ -715,6 +729,16 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
       : memref<1024xf32>, vector<16xi1>, vector<16xindex> -> vector<16xf32>
   ```
 
+  Example 4 (SIMT mode):
+  SIMT mode only accepts the offsets variant. chunk_size can be inferred from result
+  type. In this example, chunk_size is 8.
+  ```mlir
+    %2 = xegpu.load %1[%2], %0 <{l1_hint = #xegpu.cache_hint<cached>,
+                             l2_hint = #xegpu.cache_hint<uncached>,
+                             l3_hint = #xegpu.cache_hint<uncached>}>
+          : memref<128xf32>, vector<1xindex>, vector<1xi1> -> vector<8xf32>
+  ```
+
   }];
 
   let arguments = (ins XeGPU_GatherScatterSourceType: $source,
@@ -784,8 +808,20 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
   has transpose effect, which is similar to `load_gather`. Therefore, a transpose attribute is
   introduced on purpose, making sure users are aware of this implicit transformation.
 
-  In SIMT mode, the input vector represents the data to be stored by each work-item.
-  Each work-item stores a `chunk_size` number of elements.
+  In SIMT mode, the result is a 1D vector that represents the data to be stored by
+  each work-item. If size is not 1, size should be equal to the chunk size.
+
+    `value` represents the data to be stored.
+    `dest` represents the memory region to be stored to, which can be either a
+        tensor_desc or a 1D memref or pointer (ui64, ui32, i64 or i32).
+        In case of tensor_desc, offsets come from the producer create_tdesc op.
+        tensor_desc cannot be used in SIMT mode.
+    `offsets` represents offsets from dest. required if `source` in not a TensorDescType.
+        offsets is a vector of `index` type and vector length is either the subgroup size
+        or 1 in SIMT mode.
+    `mask` is a vector of `i1` type, which is used to mask out the memory access.
+        mask is a vector of size equal to the subgroup size, or 1 in SIMT mode.
+    `chunk_size` (optional) represents contiguous number of elements to store to per work item.
 
   Example 1:
   ```mlir
@@ -803,15 +839,7 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
           : vector<16x8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>>, vector<16xi1>
   ```
 
-  Example 3 (SIMT mode):
-  ```mlir
-    xegpu.store %0, %1, %2 <{l1_hint = #xegpu.cache_hint<uncached>,
-                             l2_hint = #xegpu.cache_hint<write_back>,
-                             l3_hint = #xegpu.cache_hint<write_through>}>
-          : vector<8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>> vector<16xi1>
-  ```
-
-  Example 4:
+  Example 3:
   A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
   It combines "create scattered TensorTdesc" and "store with scattered TensorTdesc".
   The dest operand could be a raw pointer (uint64_t).
@@ -827,6 +855,16 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
       : memref<1024xf32>, vector<16xi1>, vector<16xindex> -> vector<16xf32>
   ```
 
+  Example 4 (SIMT mode):
+  SIMT mode only accepts the offsets variant. chunk_size can be inferred from value
+  type. In this example, chunk_size is 8.
+  ```mlir
+    xegpu.store %0, %1[%2], %3 <{l1_hint = #xegpu.cache_hint<uncached>,
+                             l2_hint = #xegpu.cache_hint<write_back>,
+                             l3_hint = #xegpu.cache_hint<write_through>}>
+          : vector<8xf32>, memref<256xf32>, vector<1xindex>, vector<1xi1>
+  ```
+
   }];
 
   let arguments = (ins
@@ -895,6 +933,8 @@ def XeGPU_UpdateOffsetOp: XeGPU_Op<"update_offset",
     update the offset per work-item, so its offsets contains values representing
     shifts for each work-item.
 
+    This op is not available in SIMT mode.
+
     Example:
     ```mlir
       %off = arith.constant dense<[32, 32, 32, 32]> : vector<4xindex>
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index f8b371db498e8..53ecedab5406d 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -16,13 +16,15 @@ include "mlir/IR/BuiltinTypes.td"
 def XeGPU_IntType: AnyTypeOf<[I1, I8, I16, I32, I64, SI1, SI8, SI16, SI32, SI64, UI1, UI8, UI16, UI32, UI64]>;
 def XeGPU_FloatType: AnyTypeOf<[F16, F32, F64, BF16, TF32]>;
 def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>;
-def XeGPU_BaseAddrType: AnyTypeOf<[Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64, UI32, I64, I32]>;
+def XeGPU_PointerType: AnyTypeOf<[UI64, UI32, I64, I32]>;
+def XeGPU_BaseAddrType: AnyTypeOf<[Non0RankedMemRefOf<[XeGPU_ScalarType]>, XeGPU_PointerType]>;
 def XeGPU_DpasOprType: FixedVectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>;
 def XeGPU_DpasResType: FixedVectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>;
 def XeGPU_OffsetType: FixedVectorOfNonZeroRankOf<[Index]>;
 def XeGPU_MaskType: FixedVectorOfNonZeroRankOf<[I1]>;
 def XeGPU_ValueType: FixedVectorOfNonZeroRankOf<[XeGPU_ScalarType]>;
 def XeGPU_VectorType: VectorOfRankAndType<[1,2,3,4,5,6], [XeGPU_ScalarType]>;
+def XeGPU_GatherScatterBaseAddrType: AnyTypeOf<[MemRefRankOf<[XeGPU_ScalarType], [1]>, XeGPU_PointerType]>;
 
 // common base class for types in XeGPU dialect
 class XeGPUTypeDef<string name, string typeMnemonic, list<Trait> traits = [],
@@ -189,7 +191,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
   let genVerifyDecl = 1;
 }
 
-def XeGPU_GatherScatterSourceType : AnyTypeOf<[XeGPU_TensorDesc,Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64]>;
+def XeGPU_GatherScatterSourceType : AnyTypeOf<[XeGPU_TensorDesc,XeGPU_GatherScatterBaseAddrType]>;
 
 def XeGPU_Nbarrier: XeGPUTypeDef<"Nbarrier", "nbarrier", [], "mlir::Type"> {
   let summary = "!xegpu.nbarrier a custom XeGPU type representing a barrier.";
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 906c71d8b8dad..cf5da7a416846 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -58,13 +58,6 @@ static SmallVector<int64_t> getShapeOf(Type type) {
   return shape;
 }
 
-static int64_t getRankOf(Value val) {
-  auto type = val.getType();
-  if (auto ty = llvm::dyn_cast<ShapedType>(type))
-    return ty.getRank();
-  return 0;
-}
-
 static bool isReadHintOrNone(const CachePolicyAttr &attr) {
   if (!attr)
     return true;
@@ -685,10 +678,6 @@ void CreateDescOp::build(OpBuilder &builder, OperationState &state,
 LogicalResult CreateDescOp::verify() {
   auto tdescTy = getTensorDescType();
 
-  if (getRankOf(getSource()) > 1)
-    return emitOpError(
-        "Expecting the source is a 1D memref or pointer (uint64_t).");
-
   if (!tdescTy.isScattered())
     return emitOpError("Expects a scattered TensorDesc.\n");
 
@@ -723,13 +712,15 @@ LogicalResult CreateDescOp::verify() {
 LogicalResult PrefetchOp::verify() {
   auto tdescTy = getTensorDescType();
 
+  if (!tdescTy && !getOffsets())
+    return emitOpError("Expects offsets.");
+
+  if (tdescTy && getOffsets())
+    return emitOpError("offsets not allowed.");
+
   if (tdescTy && !tdescTy.isScattered())
     return emitOpError("Expects a scattered TensorDesc.");
 
-  if (!tdescTy && getRankOf(getSource()) > 1)
-    return emitOpError(
-        "Expecting the source is a 1D memref or pointer (uint64_t).");
-
   if (!isReadHintOrNone(getL1HintAttr()))
     return emitOpError("invalid l1_hint: ") << getL1HintAttr();
 
@@ -757,13 +748,15 @@ LogicalResult LoadGatherOp::verify() {
   auto maskTy = getMaskType();
   auto valueTy = getValueType();
 
+  if (!tdescTy && !getOffsets())
+    return emitOpError("Expects offsets.");
+
+  if (tdescTy && getOffsets())
+    return emitOpError("offsets not allowed.");
+
   if (tdescTy && !tdescTy.isScattered())
     return emitOpError("Expects a scattered TensorDesc.");
 
-  if (!tdescTy && getRankOf(getSource()) > 1)
-    return emitOpError(
-        "Expecting the source is a 1D memref or pointer (uint64_t).");
-
   if (!isReadHintOrNone(getL1HintAttr()))
     return emitOpError("invalid l1_hint: ") << getL1HintAttr();
 
@@ -804,13 +797,15 @@ LogicalResult StoreScatterOp::verify() {
   auto maskTy = getMaskType();
   auto valueTy = getValueType();
 
+  if (!tdescTy && !getOffsets())
+    return emitOpError("Expects offsets.");
+
+  if (tdescTy && getOffsets())
+    return emitOpError("offsets not allowed.");
+
   if (tdescTy && !tdescTy.isScattered())
     return emitOpError("Expects a scattered TensorDesc.");
 
-  if (!tdescTy && getRankOf(getDest()) > 1)
-    return emitOpError(
-        "Expecting the dest is a 1D memref or pointer (uint64_t).");
-
   if (!isWriteHintOrNone(getL1HintAttr()))
     return emitOpError("invalid l1_hint: ") << getL1HintAttr();
 
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 93a5a055b08c6..c076ac78b9edd 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -387,11 +387,28 @@ func.func @load_gather_vc_3(%src: ui64) {
 // -----
 func.func @prefetch_offset_wi_1(%src: memref<4x4xf32>) {
   %offsets = arith.constant dense<[0]> : vector<1xindex>
-  // expected-error@+1 {{Expecting the source is a 1D memref or pointer}}
+  // expected-error@+1 {{op operand #0 must be TensorDesc describing regions of interested data}}
   xegpu.prefetch %src[%offsets]: memref<4x4xf32>, vector<1xindex>
   return
 }
 
+// -----
+func.func @prefetch_offset_wi_2(%src: memref<16xf32>) {
+  %offsets = arith.constant dense<[0]> : vector<1xindex>
+  %1 = xegpu.create_tdesc %src, %offsets : memref<16xf32>, vector<1xindex>
+          -> !xegpu.tensor_desc<1x3xf32, #xegpu.scatter_tdesc_attr<chunk_size = 3>>
+  // expected-error@+1 {{offsets not allowed}}
+  xegpu.prefetch %1[%offsets]: !xegpu.tensor_desc<1x3xf32, #xegpu.scatter_tdesc_attr<chunk_size = 3>>, vector<1xindex>
+  return
+}
+
+// -----
+func.func @prefetch_offset_wi_3(%src: memref<16xf32>) {
+  // expected-error@+1 {{Expects offsets}}
+  xegpu.prefetch %src: memref<16xf32>
+  return
+}
+
 // -----
 func.func @load_gather_offset_sg(%src: memref<?xf16>) {
   %offsets = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
@@ -428,12 +445,50 @@ func.func @store_scatter_offset_wi_2(%src: memref<4x4xf16>) {
   %val = arith.constant dense<2.9>: vector<4xf16>
   %offsets = arith.constant dense<[0]> : vector<1xindex>
   %mask = arith.constant dense<1>: vector<1xi1>
-  // expected-error@+1 {{Expecting the dest is a 1D memref or pointer}}
+  // expected-error@+1 {{op operand #1 must be TensorDesc describing regions of interested data}}
   xegpu.store %val, %src[%offsets], %mask
         : vector<4xf16>, memref<4x4xf16>, vector<1xindex>, vector<1xi1>
   return
 }
 
+// -----
+func.func @store_scatter_offset_wi_3(%src: memref<16xf16>) {
+  %val = arith.constant dense<2.9>: vector<1xf16>
+  %mask = arith.constant dense<1>: vector<1xi1>
+  // expected-error@+1 {{Expects offsets}}
+  xegpu.store %val, %src, %mask
+        : vector<1xf16>, memref<16xf16>, vector<1xi1>
+  return
+}
+
+// -----
+func.func @store_scatter_offset_wi_4(%src: !xegpu.tensor_desc<1x1xf32, #xegpu.scatter_tdesc_attr<>>) {
+  %val = arith.constant dense<2.9>: vector<1xf16>
+  %offsets = arith.constant dense<[0]> : vector<1xindex>
+  %mask = arith.constant dense<1>: vector<1xi1>
+  // expected-error@+1 {{offsets not allowed}}
+  xegpu.store %val, %src[%offsets], %mask
+        : vector<1xf16>, !xegpu.tensor_desc<1x1xf32, #xegpu.scatter_tdesc_attr<>>, vector<1xindex>, vector<1xi1>
+  return
+}
+
+// -----
+func.func @load_gather_offset_wi_4(%src: !xegpu.tensor_desc<1x2xf16, #xegpu.scatter_tdesc_attr<>>) {
+  %mask = arith.constant dense<1>: vector<1xi1>
+  %offsets = arith.constant dense<[0]> : vector<1xindex>
+  // expected-error@+1 {{offsets not allowed}}
+  %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : !xegpu.tensor_desc<1x2xf16, #xegpu.scatter_tdesc_attr<>>, vector<1xindex>, vector<1xi1> -> vector<2xf16>
+  return
+}
+
+// -----
+func.func @load_gather_offset_wi_3(%src: ui64) {
+  %mask = arith.constant dense<1>: vector<1xi1>
+  // expected-error@+1 {{Expects offsets}}
+  %2 = xegpu.load %src, %mask <{chunk_size = 2}> : ui64, vector<1xi1> -> vector<2xf16>
+  return
+}
+
 // -----
 func.func @load_gather_offset_wi_2(%src: ui64) {
   %mask = arith.constant dense<1>: vector<1xi1>
@@ -447,7 +502,7 @@ func.func @load_gather_offset_wi_2(%src: ui64) {
 func.func @load_gather_offset_wi_1(%src: memref<4x4xf32>) {
   %mask = arith.constant dense<1>: vector<1xi1>
   %offsets = arith.constant dense<[0]> : vector<1xindex>
-  // expected-error@+1 {{Expecting the source is a 1D memref or pointer}}
+  // expected-error@+1 {{op operand #0 must be TensorDesc describing regions of interested data}}
   %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : memref<4x4xf32>,  vector<1xindex>, vector<1xi1> -> vector<2xf32>
   return
 }

@silee2
Copy link
Contributor Author

silee2 commented Aug 21, 2025

@akroviakov

Copy link
Contributor

@adam-smnk adam-smnk left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great cleanup 👍
One design question

Copy link
Contributor

@akroviakov akroviakov left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, makes #154556 less confusing

```

Example 4:
Example 3:
A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
It combines "create scattered TensorTdesc" and "load with scattered TensorTdesc".
The source operand could be a raw pointer (uint64_t). Please refer to create_tdesc
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
The source operand could be a raw pointer (uint64_t). Please refer to create_tdesc
The source operand could be a raw pointer (ui64, ui32, i64 or i32). Please refer to create_tdesc

```

Example 4:
Example 3:
A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
It combines "create scattered TensorTdesc" and "store with scattered TensorTdesc".
The dest operand could be a raw pointer (uint64_t).
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
The dest operand could be a raw pointer (uint64_t).
The dest operand could be a raw pointer (ui64, ui32, i64 or i32).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank! Missed that one.

@akroviakov
Copy link
Contributor

Btw, what is the reason to demand vector offsets in the op definition? For SIMT distribution, given a vector of SIMD offsets, we would need to extract an element from 1D vector at lane idx, to my understanding. vector.extract says:

the result degenerates to a scalar element.

Why do we need to make extra steps for wrapping vector.extract result into a vector and then get the scalar back via materializations, if we could allow scalar elements in the first place?
Do we prohibit users from supplying a single-element offset/mask vector at WG level code somehow? If not, then how is it different from passing a scalar?

@silee2
Copy link
Contributor Author

silee2 commented Aug 21, 2025

Btw, what is the reason to demand vector offsets in the op definition? For SIMT distribution, given a vector of SIMD offsets, we would need to extract an element from 1D vector at lane idx, to my understanding. vector.extract says:

the result degenerates to a scalar element.

Why do we need to make extra steps for wrapping vector.extract result into a vector and then get the scalar back via materializations, if we could allow scalar elements in the first place? Do we prohibit users from supplying a single-element offset/mask vector at WG level code somehow? If not, then how is it different from passing a scalar?

I tried not to change op definition as much as possible but agree with you on this one. vector.extract and vector.insert may get optimized away later, but better to not introduce the problem to begin with. And the change will not break any exsiting cases. Updated PR to allow scalar offset for load/store/prefetch.

@akroviakov
Copy link
Contributor

akroviakov commented Aug 21, 2025

Another possible inconsistency for SIMT code in the current upstream. Loadgather can have a memref source, in this case, the chunk size needs to be part of the op. In verification, when we check the mask shape against the chunk from the op (not from tdesc), we have:

  llvm::SmallVector<int64_t> expectedMaskShape(valueShape);
  if (chunkSize > 1)
    expectedMaskShape.pop_back();
  if (expectedMaskShape != maskShape)
    return emitError() << "Mask should match value except the chunk size dim.";

For SIMT (i.e., 1D valueShape) code with non-1 chunk size, this means that we always hit (expectedMaskShape != maskShape), because we pop from the 1D mask shape, but mask is still 1D.

Example IR:

    gpu.func @scatter_ops(%arg0: memref<256xf16>, %arg1: vector<16xindex>) {
      %cst = arith.constant dense<true> : vector<1xi1>
      %0 = gpu.lane_id
      %1 = vector.extract %arg1[%0] : index from vector<16xindex>
      %2 = vector.broadcast %1 : index to vector<1xindex>
      %3 = xegpu.load %arg0[%2], %cst <{chunk_size = 8 : i64}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
      xegpu.store %3, %arg0[%2], %cst <{chunk_size = 8 : i64}> : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
      gpu.return
    }

Could you please add this example as a test (maybe a simplified version of it) to see whether this PR addresses the SIMT verification inconsistency?

Copy link
Contributor

@charithaintc charithaintc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

But I am not sure about the need for scalar offsets. Maybe you can clarify that a bit.

OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
let arguments = (ins XeGPU_GatherScatterSourceType:$source,
Optional<AnyTypeOf<[XeGPU_OffsetType, Index]>>:$offsets,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not move Index to XeGPU_OffsetType?

Copy link
Contributor Author

@silee2 silee2 Aug 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Scalar offset is provided for simplifying layout distribution to SIMT.
But single element vector of offset is also valid.
Scalar is an exception.
I think keeping XeGPU_OffsetType as vector type and explicitly spelling out scalar Index type fits that idea better and helps code readability.
Also, maybe XeGPU may introduce other ops in the future that only need vector offsets and not scalar offsets.
Such usage case would require scalar and vector to be separated.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will supporting both vector<1xindex> and index incur more if-else check in patterns or it can be handled in one place, so patterns don't need to care about this?

tensor_desc cannot be used in SIMT mode.
- `offsets`: represents offsets from source. required if `source` in not a TensorDescType.
offsets is a vector of `index` type and vector length is either the subgroup size
or 1 in SIMT mode. scalar offset is also valid for SIMT mode.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is the scalar offset case needed? I would imagine the offsets will always be a vector if coming from SG/WG level. So I don't think this is required (but nice to have if it does not complicate our logic too much).

I think a vector<1xindex> would also do the exact same thing (DCE, canonicalize kick in) with much less maintainence effort on our side.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Keeping single element vector and relying on clean up passes to remove vector.extract and vector.insert would also work.
But keeping and maintaining scalar offsets and mask is not too much of a burden.
Only extra burden is op validation.
XeGPU to XeVM conversion will materialize single element vector to scalar. In case of scalar, materialization just isn't needed. Conversion logic is identical.

OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
let arguments = (ins XeGPU_GatherScatterSourceType:$source,
Optional<AnyTypeOf<[XeGPU_OffsetType, Index]>>:$offsets,
AnyTypeOf<[XeGPU_MaskType, I1]>:$mask, OptionalAttr<I64Attr>:$chunk_size,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same comment as above, consider moving I1 inside MarkTy?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think keeping it separate is better for now. Same reason as stated above for offsets.

@silee2
Copy link
Contributor Author

silee2 commented Aug 22, 2025

Another possible inconsistency for SIMT code in the current upstream. Loadgather can have a memref source, in this case, the chunk size needs to be part of the op. In verification, when we check the mask shape against the chunk from the op (not from tdesc), we have:

  llvm::SmallVector<int64_t> expectedMaskShape(valueShape);
  if (chunkSize > 1)
    expectedMaskShape.pop_back();
  if (expectedMaskShape != maskShape)
    return emitError() << "Mask should match value except the chunk size dim.";

For SIMT (i.e., 1D valueShape) code with non-1 chunk size, this means that we always hit (expectedMaskShape != maskShape), because we pop from the 1D mask shape, but mask is still 1D.

Example IR:

    gpu.func @scatter_ops(%arg0: memref<256xf16>, %arg1: vector<16xindex>) {
      %cst = arith.constant dense<true> : vector<1xi1>
      %0 = gpu.lane_id
      %1 = vector.extract %arg1[%0] : index from vector<16xindex>
      %2 = vector.broadcast %1 : index to vector<1xindex>
      %3 = xegpu.load %arg0[%2], %cst <{chunk_size = 8 : i64}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
      xegpu.store %3, %arg0[%2], %cst <{chunk_size = 8 : i64}> : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
      gpu.return
    }

Could you please add this example as a test (maybe a simplified version of it) to see whether this PR addresses the SIMT verification inconsistency?

Updated validation logic for SIMT mode and added valid SIMT mode op tests.

@silee2
Copy link
Contributor Author

silee2 commented Aug 22, 2025

LGTM.

But I am not sure about the need for scalar offsets. Maybe you can clarify that a bit.

Scalar offsets and masks are provided for ease of use.
Layout distribution to SIMT will insert vector.extract based on lane id to get per lane offset and mask.
Extracted offset and mask will be scalar values.
Allowing scalar offsets and mask helps distribution pass avoid creating vector.insert to create single element vector for offsets and mask.

@@ -595,6 +604,16 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
As compared to prefetch_nd, which works on non-scattered TensorDesc,
it works on scattered TensorDesc instead.
Copy link
Contributor

@chencha3 chencha3 Aug 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: could you help to update the doc here? It seems the definition has switched to use memref/pointer instead of TensorDesc. I didn't see XeGPU_GatherScatterSourceType contains TensorDesc, or maybe I miss it. But anyway, we are retiring the support for scattered TensorDesc

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, my bad. just realized that XeGPU_GatherScatterSourceType is different from XeGPU_GatherScatterBaseSourceType

Optional<AnyTypeOf<[XeGPU_OffsetType, Index]>>:$offsets,
OptionalAttr<XeGPU_CacheHintAttr>:$l1_hint,
OptionalAttr<XeGPU_CacheHintAttr>:$l2_hint,
OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint);

let extraClassDeclaration = extraBaseClassDeclaration # [{
Type getSourceType() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you help to get rid of getTensorDesc() and getTensorDescType() methods if XeGPU_GatherScatterSourceType doesn't support TensorDesc.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As mentioned above, XeGPU_GatherScatterSourceType includes XeGPU_TensorType

OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
let arguments = (ins XeGPU_GatherScatterSourceType:$source,
Optional<AnyTypeOf<[XeGPU_OffsetType, Index]>>:$offsets,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will supporting both vector<1xindex> and index incur more if-else check in patterns or it can be handled in one place, so patterns don't need to care about this?

tensor_desc or a 1D memref or pointer (ui64, ui32, i64 or i32).
In case of tensor_desc, offsets come from the producer create_tdesc op.
tensor_desc cannot be used in SIMT mode.
- `offsets`: represents offsets from source. required if `source` in not a TensorDescType.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
- `offsets`: represents offsets from source. required if `source` in not a TensorDescType.
- `offsets`: represents offsets from source. required if `source` in not a TensorDesc.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

vector<1xindex> can auto convert to index by materialization cast.
So lowering pattern will always see scalar index

@@ -89,13 +82,18 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
if (!tdescTy.isScattered())
return emitError() << "Expects a scattered TensorDesc.";

if (!valueTy)
return emitError() << "Expecting a vector type result.";
auto chunkSize = tdescTy.getChunkSizeAsInt();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since support for scattered TensorDesc is removed, is this method still needed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Support is not removed.

if (!valueTy) {
if (chunkSize > 1)
return emitError() << "Expecting chunk size == 1 for scalar result";
if (maskVecTy || offsetsVecTy)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe using llvm::all_of to check all of mask, value, and offsets are VectorType or ScalarType help to increase the code readability.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

all of mask and offset are Vector or Scalar type
But value depends on chunk size.
I'm not too familiar with using llvm::all_of and available predicate lambda to use with it.
PR is getting a bit big so will optimize in a follow up PR.

Copy link
Contributor

@chencha3 chencha3 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Comment on lines +166 to +167
if (maskSize == 1)
return success();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: it might be less error-prone to perform remaining checks when maskSize is non-unit instead of an early exist.
As in:

if (maskSize != 1) {
  // perform checks
}

Copy link
Contributor Author

@silee2 silee2 Aug 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, code below that line is for maskSize != 1
Instead of If-then-else,
I'm doing If with early exit and fallthrough is the else clause.

@silee2 silee2 merged commit 1bf31c3 into llvm:main Aug 22, 2025
9 checks passed
charithaintc pushed a commit that referenced this pull request Sep 3, 2025
This PR adds distribution patterns for scattered load and store ops,
chunk size included.

XeGPU moves toward offsets being part of the load/store ops, so the pass
only supports this case. Manipulating a vector of offsets indirectly
through create_tdesc is complex and soon to become obsolete anyway.
This PR assumes the SIMT-adapted scatter ops verification introduced in
#154653. The distribution
itself can be reviewed in the meantime.
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Sep 3, 2025
This PR adds distribution patterns for scattered load and store ops,
chunk size included.

XeGPU moves toward offsets being part of the load/store ops, so the pass
only supports this case. Manipulating a vector of offsets indirectly
through create_tdesc is complex and soon to become obsolete anyway.
This PR assumes the SIMT-adapted scatter ops verification introduced in
llvm/llvm-project#154653. The distribution
itself can be reviewed in the meantime.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants