From 785d21fc3c6b805ed8e700de55bb3ed82b21018e Mon Sep 17 00:00:00 2001 From: Fabian Mora <6982088+fabianmcg@users.noreply.github.com> Date: Mon, 1 Sep 2025 21:05:55 +0000 Subject: [PATCH 1/4] extend ptr_add op --- mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h | 1 + mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td | 103 +++++++++++------- mlir/lib/Dialect/Ptr/IR/CMakeLists.txt | 1 + mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp | 40 +++++++ .../Conversion/PtrToLLVM/ptr-to-llvm.mlir | 12 +- mlir/test/Dialect/Ptr/invalid.mlir | 16 +++ mlir/test/Dialect/Ptr/ops.mlir | 65 +++++++++++ mlir/test/Target/LLVMIR/ptr.mlir | 30 +++++ 8 files changed, 225 insertions(+), 43 deletions(-) diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h index 8686cc7d316d4..eaf1e6243a74d 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h @@ -18,6 +18,7 @@ #include "mlir/Dialect/Ptr/IR/PtrDialect.h" #include "mlir/Dialect/Ptr/IR/PtrTypes.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td index 5939c3646884c..3ac12978b947c 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td @@ -13,6 +13,7 @@ include "mlir/Dialect/Ptr/IR/PtrDialect.td" include "mlir/Dialect/Ptr/IR/PtrAttrDefs.td" include "mlir/Dialect/Ptr/IR/PtrEnums.td" include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td" +include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" include "mlir/IR/OpAsmInterface.td" @@ -34,8 +35,15 @@ class Ptr_ShapedValueType allowedTypes, list preds = []> : /*descr=*/[{A shaped type with value semantics and rank.}], /*cppType=*/"::mlir::ShapedType">; -// A shaped pointer type with value semantics and rank. -class Ptr_ShapedPtrType : Ptr_ShapedValueType<[Ptr_PtrType], [HasRankPred]>; +// A ptr-like type, either scalar or shaped type with value semantics. +def Ptr_PtrLikeType : + AnyTypeOf<[Ptr_ShapedValueType<[Ptr_PtrType], [HasRankPred]>, Ptr_PtrType]>; + +// An int-like type, either scalar or shaped type with value semantics. +def Ptr_IntLikeType :AnyTypeOf<[ + Ptr_ShapedValueType<[AnySignlessIntegerOrIndex], [HasRankPred]>, + AnySignlessIntegerOrIndex +]>; // A shaped value type of rank 1 of any element type. def Ptr_Any1DType : @@ -167,41 +175,6 @@ def Ptr_GetMetadataOp : Pointer_Op<"get_metadata", [ }]; } -//===----------------------------------------------------------------------===// -// PtrAddOp -//===----------------------------------------------------------------------===// - -def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [ - Pure, AllTypesMatch<["base", "result"]>, ViewLikeOpInterface - ]> { - let summary = "Pointer add operation"; - let description = [{ - The `ptr_add` operation adds an integer offset to a pointer to produce a new - pointer. The input and output pointer types are always the same. - - Example: - - ```mlir - %x_off = ptr.ptr_add %x, %off : !ptr.ptr<#ptr.generic_space>, i32 - %x_off0 = ptr.ptr_add nusw %x, %off : !ptr.ptr<#ptr.generic_space>, i32 - ``` - }]; - - let arguments = (ins - Ptr_PtrType:$base, - AnySignlessIntegerOrIndex:$offset, - DefaultValuedProp, "PtrAddFlags::none">:$flags); - let results = (outs Ptr_PtrType:$result); - let assemblyFormat = [{ - ($flags^)? $base `,` $offset attr-dict `:` type($base) `,` type($offset) - }]; - let hasFolder = 1; - let extraClassDeclaration = [{ - /// `ViewLikeOp::getViewSource` method. - Value getViewSource() { return getBase(); } - }]; -} - //===----------------------------------------------------------------------===// // LoadOp //===----------------------------------------------------------------------===// @@ -361,6 +334,62 @@ def Ptr_MaskedStoreOp : Pointer_Op<"masked_store", [ let hasVerifier = 1; } +//===----------------------------------------------------------------------===// +// PtrAddOp +//===----------------------------------------------------------------------===// + +def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [ + Pure, ViewLikeOpInterface, + DeclareOpInterfaceMethods + ]> { + let summary = "Pointer add operation"; + let description = [{ + The `ptr_add` operation adds an int-like offset to one or more pointers to produce one or more new pointers. + + The operation supports both scalar and shaped types with value semantics: + - When both base and offset are scalar: produces a single new pointer + - When base is shaped and offset is scalar: adds the same offset to each + pointer in the base + - When base is scalar and offset is shaped: adds the single pointer to each + offset in the shaped value + - When both are shaped: performs element-wise addition (shapes must be + compatible) + + Example: + + ```mlir + // Scalar base and offset + %x_off = ptr.ptr_add %x, %off : !ptr.ptr<#ptr.generic_space>, i32 + %x_off0 = ptr.ptr_add nusw %x, %off : !ptr.ptr<#ptr.generic_space>, i32 + + // Shaped base with scalar offset + %ptrs_off = ptr.ptr_add %ptrs, %off : vector<4x!ptr.ptr<#ptr.generic_space>>, i32 + + // Scalar base with shaped offset + %x_offs = ptr.ptr_add %x, %offs : !ptr.ptr<#ptr.generic_space>, vector<4xi32> + + // Both base and offset are shaped + %ptrs_offs = ptr.ptr_add %ptrs, %offs : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xi32> + ``` + }]; + let arguments = (ins + Ptr_PtrLikeType:$base, + Ptr_IntLikeType:$offset, + DefaultValuedProp, "PtrAddFlags::none">:$flags); + let results = (outs Ptr_PtrLikeType:$result); + let assemblyFormat = [{ + ($flags^)? $base `,` $offset attr-dict `:` type($base) `,` type($offset) + }]; + let hasFolder = 1; + let extraClassDeclaration = [{ + /// `ViewLikeOp::getViewSource` method. + Value getViewSource() { return getBase(); } + + /// Returns the ptr type of the operation. + ptr::PtrType getPtrType(); + }]; +} + //===----------------------------------------------------------------------===// // ScatterOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt index bd1e655fc6b5e..a6b0d416a4165 100644 --- a/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt @@ -33,6 +33,7 @@ add_mlir_dialect_library( MLIRIR MLIRDataLayoutInterfaces MLIRMemorySlotInterfaces + MLIRInferTypeOpInterface MLIRViewLikeInterface MLIRPtrMemorySpaceInterfaces ) diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp index 92ce9be97dd2c..6697f5382db6b 100644 --- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp +++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp @@ -346,6 +346,46 @@ OpFoldResult PtrAddOp::fold(FoldAdaptor adaptor) { return nullptr; } +LogicalResult PtrAddOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // Get the base pointer and offset types. + Type baseType = operands[0].getType(); + Type offsetType = operands[1].getType(); + + // If neither are shaped types, result is same as base type. + if (!isa(baseType) && !isa(offsetType)) { + inferredReturnTypes.push_back(baseType); + return success(); + } + + // Handle cases with shaped types. + if (auto baseTy = dyn_cast(baseType)) { + // If both shaped, they must have the same shape. + if (auto offTy = dyn_cast(offsetType)) { + if (offTy.getShape() != baseTy.getShape()) { + if (location) + mlir::emitError(*location) << "shapes of base and offset must match"; + return failure(); + } + // Make sure they are the same kind of shaped type. + if (baseType.getTypeID() != offsetType.getTypeID()) { + if (location) + mlir::emitError(*location) << "the shaped containers type must match"; + return failure(); + } + } + inferredReturnTypes.push_back(baseType); + return success(); + } + + // Base is scalar, offset is shaped. + auto offsetShapedType = cast(offsetType); + inferredReturnTypes.push_back(offsetShapedType.clone(baseType)); + return success(); +} + //===----------------------------------------------------------------------===// // ToPtrOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/PtrToLLVM/ptr-to-llvm.mlir b/mlir/test/Conversion/PtrToLLVM/ptr-to-llvm.mlir index dc645fe0480fa..5128fd8ccb265 100644 --- a/mlir/test/Conversion/PtrToLLVM/ptr-to-llvm.mlir +++ b/mlir/test/Conversion/PtrToLLVM/ptr-to-llvm.mlir @@ -16,10 +16,10 @@ // CHECK: llvm.return %[[VAL_8]] : !llvm.struct<(ptr, ptr, ptr, ptr)> // CHECK: } func.func @test_ptr_add(%arg0: !ptr.ptr<#ptr.generic_space>, %arg1: index) -> (!ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>) { - %0 = ptr.ptr_add %arg0, %arg1 : <#ptr.generic_space>, index - %1 = ptr.ptr_add nusw %arg0, %arg1 : <#ptr.generic_space>, index - %2 = ptr.ptr_add nuw %arg0, %arg1 : <#ptr.generic_space>, index - %3 = ptr.ptr_add inbounds %arg0, %arg1 : <#ptr.generic_space>, index + %0 = ptr.ptr_add %arg0, %arg1 : !ptr.ptr<#ptr.generic_space>, index + %1 = ptr.ptr_add nusw %arg0, %arg1 : !ptr.ptr<#ptr.generic_space>, index + %2 = ptr.ptr_add nuw %arg0, %arg1 : !ptr.ptr<#ptr.generic_space>, index + %3 = ptr.ptr_add inbounds %arg0, %arg1 : !ptr.ptr<#ptr.generic_space>, index return %0, %1, %2, %3 : !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space> } @@ -263,7 +263,7 @@ func.func @test_comprehensive_dynamic(%arg0: memref, #ptr.generic_space> -> <#ptr.generic_space> %1 = ptr.get_metadata %arg0 : memref, #ptr.generic_space> %2 = ptr.type_offset f32 : index - %3 = ptr.ptr_add inbounds %0, %2 : <#ptr.generic_space>, index + %3 = ptr.ptr_add inbounds %0, %2 : !ptr.ptr<#ptr.generic_space>, index %4 = ptr.from_ptr %3 metadata %1 : <#ptr.generic_space> -> memref, #ptr.generic_space> return %4 : memref, #ptr.generic_space> } @@ -313,6 +313,6 @@ func.func @test_memref_ptradd_indexing(%arg0: memref<10x?x30xf32, #ptr.generic_s %0 = ptr.to_ptr %arg0 : memref<10x?x30xf32, #ptr.generic_space> -> <#ptr.generic_space> %1 = ptr.type_offset f32 : index %2 = arith.muli %1, %arg1 : index - %3 = ptr.ptr_add %0, %2 : <#ptr.generic_space>, index + %3 = ptr.ptr_add %0, %2 : !ptr.ptr<#ptr.generic_space>, index return %3 : !ptr.ptr<#ptr.generic_space> } diff --git a/mlir/test/Dialect/Ptr/invalid.mlir b/mlir/test/Dialect/Ptr/invalid.mlir index 0c34ae43bf6be..cc1eeb3cb5744 100644 --- a/mlir/test/Dialect/Ptr/invalid.mlir +++ b/mlir/test/Dialect/Ptr/invalid.mlir @@ -54,3 +54,19 @@ func.func @llvm_store(%arg0: !ptr.ptr<#llvm.address_space<1>>, %arg1: memref, !ptr.ptr<#llvm.address_space<1>> return } + +// ----- + +func.func @ptr_add_mismatch(%ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %offsets: vector<8xi64>) -> tensor<8x!ptr.ptr<#ptr.generic_space>> { + // expected-error@+1 {{the shaped containers type must match}} + %res = ptr.ptr_add %ptrs, %offsets : tensor<8x!ptr.ptr<#ptr.generic_space>>, vector<8xi64> + return %res : tensor<8x!ptr.ptr<#ptr.generic_space>> +} + +// ----- + +func.func @ptr_add_shape_mismatch(%ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %offsets: tensor<4xi64>) -> tensor<8x!ptr.ptr<#ptr.generic_space>> { + // expected-error@+1 {{shapes of base and offset must match}} + %res = ptr.ptr_add %ptrs, %offsets : tensor<8x!ptr.ptr<#ptr.generic_space>>, tensor<4xi64> + return %res : tensor<8x!ptr.ptr<#ptr.generic_space>> +} diff --git a/mlir/test/Dialect/Ptr/ops.mlir b/mlir/test/Dialect/Ptr/ops.mlir index bde2fb22b6aa0..c008b858af0d7 100644 --- a/mlir/test/Dialect/Ptr/ops.mlir +++ b/mlir/test/Dialect/Ptr/ops.mlir @@ -11,6 +11,8 @@ func.func @ptr_add_type_offset(%ptr: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<# return %res : !ptr.ptr<#ptr.generic_space> } + + /// Check cast ops assembly. func.func @cast_ops(%mr: memref) -> memref { %ptr = ptr.to_ptr %mr : memref -> !ptr.ptr<#ptr.generic_space> @@ -126,3 +128,66 @@ func.func @llvm_masked_ops(%ptr: !ptr.ptr<#llvm.address_space<3>>, %ptrs: vector ptr.masked_store %value, %ptr, %mask alignment = 4 : vector<4xf32>, !ptr.ptr<#llvm.address_space<3>> return %0 : vector<4xf32> } + +/// Test ptr_add with shaped operands (vectors) +func.func @ptr_add_vector(%ptrs: vector<4x!ptr.ptr<#ptr.generic_space>>, %offsets: vector<4xindex>) -> vector<4x!ptr.ptr<#ptr.generic_space>> { + %res = ptr.ptr_add %ptrs, %offsets : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xindex> + %res0 = ptr.ptr_add none %ptrs, %offsets : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xindex> + %res1 = ptr.ptr_add nusw %ptrs, %offsets : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xindex> + %res2 = ptr.ptr_add nuw %ptrs, %offsets : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xindex> + %res3 = ptr.ptr_add inbounds %ptrs, %offsets : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xindex> + return %res : vector<4x!ptr.ptr<#ptr.generic_space>> +} + +/// Test ptr_add with shaped operands (tensors) +func.func @ptr_add_tensor(%ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %offsets: tensor<8xi64>) -> tensor<8x!ptr.ptr<#ptr.generic_space>> { + %res = ptr.ptr_add %ptrs, %offsets : tensor<8x!ptr.ptr<#ptr.generic_space>>, tensor<8xi64> + return %res : tensor<8x!ptr.ptr<#ptr.generic_space>> +} + +/// Test ptr_add with 2D tensors +func.func @ptr_add_tensor_2d(%ptrs: tensor<4x8x!ptr.ptr<#ptr.generic_space>>, %offsets: tensor<4x8xindex>) -> tensor<4x8x!ptr.ptr<#ptr.generic_space>> { + %res = ptr.ptr_add %ptrs, %offsets : tensor<4x8x!ptr.ptr<#ptr.generic_space>>, tensor<4x8xindex> + %res1 = ptr.ptr_add nuw %ptrs, %offsets : tensor<4x8x!ptr.ptr<#ptr.generic_space>>, tensor<4x8xindex> + return %res : tensor<4x8x!ptr.ptr<#ptr.generic_space>> +} + +/// Test ptr_add with scalar base and shaped offsets (vectors) +func.func @ptr_add_scalar_base_vector_offsets(%ptr: !ptr.ptr<#ptr.generic_space>, %offsets: vector<4xindex>) -> vector<4x!ptr.ptr<#ptr.generic_space>> { + %res = ptr.ptr_add %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, vector<4xindex> + %res0 = ptr.ptr_add none %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, vector<4xindex> + %res1 = ptr.ptr_add nusw %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, vector<4xindex> + %res2 = ptr.ptr_add nuw %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, vector<4xindex> + %res3 = ptr.ptr_add inbounds %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, vector<4xindex> + return %res : vector<4x!ptr.ptr<#ptr.generic_space>> +} + +/// Test ptr_add with scalar base and shaped offsets (tensors) +func.func @ptr_add_scalar_base_tensor_offsets(%ptr: !ptr.ptr<#ptr.generic_space>, %offsets: tensor<8xi64>) -> tensor<8x!ptr.ptr<#ptr.generic_space>> { + %res = ptr.ptr_add %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, tensor<8xi64> + %res0 = ptr.ptr_add none %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, tensor<8xi64> + %res1 = ptr.ptr_add nusw %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, tensor<8xi64> + %res2 = ptr.ptr_add nuw %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, tensor<8xi64> + %res3 = ptr.ptr_add inbounds %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, tensor<8xi64> + return %res : tensor<8x!ptr.ptr<#ptr.generic_space>> +} + +/// Test ptr_add with shaped base and scalar offset (vectors) +func.func @ptr_add_vector_base_scalar_offset(%ptrs: vector<4x!ptr.ptr<#ptr.generic_space>>, %offset: index) -> vector<4x!ptr.ptr<#ptr.generic_space>> { + %res = ptr.ptr_add %ptrs, %offset : vector<4x!ptr.ptr<#ptr.generic_space>>, index + %res0 = ptr.ptr_add none %ptrs, %offset : vector<4x!ptr.ptr<#ptr.generic_space>>, index + %res1 = ptr.ptr_add nusw %ptrs, %offset : vector<4x!ptr.ptr<#ptr.generic_space>>, index + %res2 = ptr.ptr_add nuw %ptrs, %offset : vector<4x!ptr.ptr<#ptr.generic_space>>, index + %res3 = ptr.ptr_add inbounds %ptrs, %offset : vector<4x!ptr.ptr<#ptr.generic_space>>, index + return %res : vector<4x!ptr.ptr<#ptr.generic_space>> +} + +/// Test ptr_add with shaped base and scalar offset (tensors) +func.func @ptr_add_tensor_base_scalar_offset(%ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %offset: i64) -> tensor<8x!ptr.ptr<#ptr.generic_space>> { + %res = ptr.ptr_add %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64 + %res0 = ptr.ptr_add none %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64 + %res1 = ptr.ptr_add nusw %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64 + %res2 = ptr.ptr_add nuw %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64 + %res3 = ptr.ptr_add inbounds %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64 + return %res : tensor<8x!ptr.ptr<#ptr.generic_space>> +} diff --git a/mlir/test/Target/LLVMIR/ptr.mlir b/mlir/test/Target/LLVMIR/ptr.mlir index 545bec5979b2d..4b29be813fa81 100644 --- a/mlir/test/Target/LLVMIR/ptr.mlir +++ b/mlir/test/Target/LLVMIR/ptr.mlir @@ -203,3 +203,33 @@ llvm.func @mixed_masked_ops_address_spaces(%ptr: !ptr.ptr<#llvm.address_space<3> ptr.masked_store %value, %ptr, %mask alignment = 8 : vector<4xf64>, !ptr.ptr<#llvm.address_space<3>> llvm.return } + +// CHECK-LABEL: define <4 x ptr> @ptr_add_vector +// CHECK-SAME: (<4 x ptr> %[[PTRS:.*]], <4 x i32> %[[OFFSETS:.*]]) { +// CHECK-NEXT: %[[RES:.*]] = getelementptr i8, <4 x ptr> %[[PTRS]], <4 x i32> %[[OFFSETS]] +// CHECK-NEXT: ret <4 x ptr> %[[RES]] +// CHECK-NEXT: } +llvm.func @ptr_add_vector(%ptrs: vector<4x!ptr.ptr<#llvm.address_space<0>>>, %offsets: vector<4xi32>) -> vector<4x!ptr.ptr<#llvm.address_space<0>>> { + %res = ptr.ptr_add %ptrs, %offsets : vector<4x!ptr.ptr<#llvm.address_space<0>>>, vector<4xi32> + llvm.return %res : vector<4x!ptr.ptr<#llvm.address_space<0>>> +} + +// CHECK-LABEL: define <4 x ptr> @ptr_add_scalar_base_vector_offsets +// CHECK-SAME: (ptr %[[PTR:.*]], <4 x i32> %[[OFFSETS:.*]]) { +// CHECK-NEXT: %[[RES:.*]] = getelementptr i8, ptr %[[PTR]], <4 x i32> %[[OFFSETS]] +// CHECK-NEXT: ret <4 x ptr> %[[RES]] +// CHECK-NEXT: } +llvm.func @ptr_add_scalar_base_vector_offsets(%ptr: !ptr.ptr<#llvm.address_space<0>>, %offsets: vector<4xi32>) -> vector<4x!ptr.ptr<#llvm.address_space<0>>> { + %res = ptr.ptr_add %ptr, %offsets : !ptr.ptr<#llvm.address_space<0>>, vector<4xi32> + llvm.return %res : vector<4x!ptr.ptr<#llvm.address_space<0>>> +} + +// CHECK-LABEL: define <4 x ptr> @ptr_add_vector_base_scalar_offset +// CHECK-SAME: (<4 x ptr> %[[PTRS:.*]], i32 %[[OFFSET:.*]]) { +// CHECK-NEXT: %[[RES:.*]] = getelementptr i8, <4 x ptr> %[[PTRS]], i32 %[[OFFSET]] +// CHECK-NEXT: ret <4 x ptr> %[[RES]] +// CHECK-NEXT: } +llvm.func @ptr_add_vector_base_scalar_offset(%ptrs: vector<4x!ptr.ptr<#llvm.address_space<0>>>, %offset: i32) -> vector<4x!ptr.ptr<#llvm.address_space<0>>> { + %res = ptr.ptr_add %ptrs, %offset : vector<4x!ptr.ptr<#llvm.address_space<0>>>, i32 + llvm.return %res : vector<4x!ptr.ptr<#llvm.address_space<0>>> +} From efd30cab07840764814de7ea0a55e29320b1d21d Mon Sep 17 00:00:00 2001 From: Fabian Mora Date: Wed, 3 Sep 2025 14:53:57 +0000 Subject: [PATCH 2/4] address comments --- mlir/test/Dialect/Ptr/ops.mlir | 2 -- 1 file changed, 2 deletions(-) diff --git a/mlir/test/Dialect/Ptr/ops.mlir b/mlir/test/Dialect/Ptr/ops.mlir index c008b858af0d7..51e5ac3ae691d 100644 --- a/mlir/test/Dialect/Ptr/ops.mlir +++ b/mlir/test/Dialect/Ptr/ops.mlir @@ -11,8 +11,6 @@ func.func @ptr_add_type_offset(%ptr: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<# return %res : !ptr.ptr<#ptr.generic_space> } - - /// Check cast ops assembly. func.func @cast_ops(%mr: memref) -> memref { %ptr = ptr.to_ptr %mr : memref -> !ptr.ptr<#ptr.generic_space> From 60bdbf9c8669a9539d1ea63659fde672c9bb60e3 Mon Sep 17 00:00:00 2001 From: Fabian Mora Date: Wed, 3 Sep 2025 11:32:50 -0400 Subject: [PATCH 3/4] Apply suggestion from @joker-eph Co-authored-by: Mehdi Amini --- mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp | 44 ++++++++++++-------------- 1 file changed, 21 insertions(+), 23 deletions(-) diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp index 6697f5382db6b..33366ec628659 100644 --- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp +++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp @@ -355,34 +355,32 @@ LogicalResult PtrAddOp::inferReturnTypes( Type offsetType = operands[1].getType(); // If neither are shaped types, result is same as base type. - if (!isa(baseType) && !isa(offsetType)) { + auto offTy = dyn_cast(offsetType); + if (!offTy) { + // If the offset isn't shaped, the result is always the base type. inferredReturnTypes.push_back(baseType); return success(); } - - // Handle cases with shaped types. - if (auto baseTy = dyn_cast(baseType)) { - // If both shaped, they must have the same shape. - if (auto offTy = dyn_cast(offsetType)) { - if (offTy.getShape() != baseTy.getShape()) { - if (location) - mlir::emitError(*location) << "shapes of base and offset must match"; - return failure(); - } - // Make sure they are the same kind of shaped type. - if (baseType.getTypeID() != offsetType.getTypeID()) { - if (location) - mlir::emitError(*location) << "the shaped containers type must match"; - return failure(); - } - } - inferredReturnTypes.push_back(baseType); - return success(); + auto baseTy = dyn_cast(baseType); + if (!baseTy) { + // Base isn't shaped, but offset is, use the ShapedType from offset with the base pointer as element type. + inferredReturnTypes.push_back(offsetShapedType.clone(baseType)); + return success(); } - // Base is scalar, offset is shaped. - auto offsetShapedType = cast(offsetType); - inferredReturnTypes.push_back(offsetShapedType.clone(baseType)); + // Both are shaped, their shape must match. + if (offTy.getShape() != baseTy.getShape()) { + if (location) + mlir::emitError(*location) << "shapes of base and offset must match"; + return failure(); + } + // Make sure they are the same kind of shaped type. + if (baseType.getTypeID() != offsetType.getTypeID()) { + if (location) + mlir::emitError(*location) << "the shaped containers type must match"; + return failure(); + } + inferredReturnTypes.push_back(baseType); return success(); } From 08b5079e4133ba7453b366b9fa8c7339864d272f Mon Sep 17 00:00:00 2001 From: Fabian Mora Date: Wed, 3 Sep 2025 15:41:15 +0000 Subject: [PATCH 4/4] fix formatting --- mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp index 33366ec628659..284c998690170 100644 --- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp +++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp @@ -354,7 +354,6 @@ LogicalResult PtrAddOp::inferReturnTypes( Type baseType = operands[0].getType(); Type offsetType = operands[1].getType(); - // If neither are shaped types, result is same as base type. auto offTy = dyn_cast(offsetType); if (!offTy) { // If the offset isn't shaped, the result is always the base type. @@ -363,9 +362,10 @@ LogicalResult PtrAddOp::inferReturnTypes( } auto baseTy = dyn_cast(baseType); if (!baseTy) { - // Base isn't shaped, but offset is, use the ShapedType from offset with the base pointer as element type. - inferredReturnTypes.push_back(offsetShapedType.clone(baseType)); - return success(); + // Base isn't shaped, but offset is, use the ShapedType from offset with the + // base pointer as element type. + inferredReturnTypes.push_back(offTy.clone(baseType)); + return success(); } // Both are shaped, their shape must match. @@ -374,6 +374,7 @@ LogicalResult PtrAddOp::inferReturnTypes( mlir::emitError(*location) << "shapes of base and offset must match"; return failure(); } + // Make sure they are the same kind of shaped type. if (baseType.getTypeID() != offsetType.getTypeID()) { if (location)