Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "mlir/Dialect/Ptr/IR/PtrDialect.h"
#include "mlir/Dialect/Ptr/IR/PtrTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"

Expand Down
103 changes: 66 additions & 37 deletions mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ include "mlir/Dialect/Ptr/IR/PtrDialect.td"
include "mlir/Dialect/Ptr/IR/PtrAttrDefs.td"
include "mlir/Dialect/Ptr/IR/PtrEnums.td"
include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
include "mlir/IR/OpAsmInterface.td"
Expand All @@ -34,8 +35,15 @@ class Ptr_ShapedValueType<list<Type> allowedTypes, list<Pred> preds = []> :
/*descr=*/[{A shaped type with value semantics and rank.}],
/*cppType=*/"::mlir::ShapedType">;

// A shaped pointer type with value semantics and rank.
class Ptr_ShapedPtrType : Ptr_ShapedValueType<[Ptr_PtrType], [HasRankPred]>;
// A ptr-like type, either scalar or shaped type with value semantics.
def Ptr_PtrLikeType :
AnyTypeOf<[Ptr_ShapedValueType<[Ptr_PtrType], [HasRankPred]>, Ptr_PtrType]>;

// An int-like type, either scalar or shaped type with value semantics.
def Ptr_IntLikeType :AnyTypeOf<[
Ptr_ShapedValueType<[AnySignlessIntegerOrIndex], [HasRankPred]>,
AnySignlessIntegerOrIndex
]>;

// A shaped value type of rank 1 of any element type.
def Ptr_Any1DType :
Expand Down Expand Up @@ -167,41 +175,6 @@ def Ptr_GetMetadataOp : Pointer_Op<"get_metadata", [
}];
}

//===----------------------------------------------------------------------===//
// PtrAddOp
//===----------------------------------------------------------------------===//

def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
Pure, AllTypesMatch<["base", "result"]>, ViewLikeOpInterface
]> {
let summary = "Pointer add operation";
let description = [{
The `ptr_add` operation adds an integer offset to a pointer to produce a new
pointer. The input and output pointer types are always the same.

Example:

```mlir
%x_off = ptr.ptr_add %x, %off : !ptr.ptr<#ptr.generic_space>, i32
%x_off0 = ptr.ptr_add nusw %x, %off : !ptr.ptr<#ptr.generic_space>, i32
```
}];

let arguments = (ins
Ptr_PtrType:$base,
AnySignlessIntegerOrIndex:$offset,
DefaultValuedProp<EnumProp<Ptr_PtrAddFlags>, "PtrAddFlags::none">:$flags);
let results = (outs Ptr_PtrType:$result);
let assemblyFormat = [{
($flags^)? $base `,` $offset attr-dict `:` type($base) `,` type($offset)
}];
let hasFolder = 1;
let extraClassDeclaration = [{
/// `ViewLikeOp::getViewSource` method.
Value getViewSource() { return getBase(); }
}];
}

//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -361,6 +334,62 @@ def Ptr_MaskedStoreOp : Pointer_Op<"masked_store", [
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// PtrAddOp
//===----------------------------------------------------------------------===//

def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
Pure, ViewLikeOpInterface,
DeclareOpInterfaceMethods<InferTypeOpInterface>
]> {
let summary = "Pointer add operation";
let description = [{
The `ptr_add` operation adds an int-like offset to one or more pointers to produce one or more new pointers.

The operation supports both scalar and shaped types with value semantics:
- When both base and offset are scalar: produces a single new pointer
- When base is shaped and offset is scalar: adds the same offset to each
pointer in the base
- When base is scalar and offset is shaped: adds the single pointer to each
offset in the shaped value
- When both are shaped: performs element-wise addition (shapes must be
compatible)

Example:

```mlir
// Scalar base and offset
%x_off = ptr.ptr_add %x, %off : !ptr.ptr<#ptr.generic_space>, i32
%x_off0 = ptr.ptr_add nusw %x, %off : !ptr.ptr<#ptr.generic_space>, i32

// Shaped base with scalar offset
%ptrs_off = ptr.ptr_add %ptrs, %off : vector<4x!ptr.ptr<#ptr.generic_space>>, i32

// Scalar base with shaped offset
%x_offs = ptr.ptr_add %x, %offs : !ptr.ptr<#ptr.generic_space>, vector<4xi32>

// Both base and offset are shaped
%ptrs_offs = ptr.ptr_add %ptrs, %offs : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xi32>
```
}];
let arguments = (ins
Ptr_PtrLikeType:$base,
Ptr_IntLikeType:$offset,
DefaultValuedProp<EnumProp<Ptr_PtrAddFlags>, "PtrAddFlags::none">:$flags);
let results = (outs Ptr_PtrLikeType:$result);
let assemblyFormat = [{
($flags^)? $base `,` $offset attr-dict `:` type($base) `,` type($offset)
}];
let hasFolder = 1;
let extraClassDeclaration = [{
/// `ViewLikeOp::getViewSource` method.
Value getViewSource() { return getBase(); }

/// Returns the ptr type of the operation.
ptr::PtrType getPtrType();
}];
}

//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Ptr/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ add_mlir_dialect_library(
MLIRIR
MLIRDataLayoutInterfaces
MLIRMemorySlotInterfaces
MLIRInferTypeOpInterface
MLIRViewLikeInterface
MLIRPtrMemorySpaceInterfaces
)
39 changes: 39 additions & 0 deletions mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,45 @@ OpFoldResult PtrAddOp::fold(FoldAdaptor adaptor) {
return nullptr;
}

LogicalResult PtrAddOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
// Get the base pointer and offset types.
Type baseType = operands[0].getType();
Type offsetType = operands[1].getType();

auto offTy = dyn_cast<ShapedType>(offsetType);
if (!offTy) {
// If the offset isn't shaped, the result is always the base type.
inferredReturnTypes.push_back(baseType);
return success();
}
auto baseTy = dyn_cast<ShapedType>(baseType);
if (!baseTy) {
// Base isn't shaped, but offset is, use the ShapedType from offset with the
// base pointer as element type.
inferredReturnTypes.push_back(offTy.clone(baseType));
return success();
}

// Both are shaped, their shape must match.
if (offTy.getShape() != baseTy.getShape()) {
if (location)
mlir::emitError(*location) << "shapes of base and offset must match";
return failure();
}

// Make sure they are the same kind of shaped type.
if (baseType.getTypeID() != offsetType.getTypeID()) {
if (location)
mlir::emitError(*location) << "the shaped containers type must match";
return failure();
}
inferredReturnTypes.push_back(baseType);
return success();
}

//===----------------------------------------------------------------------===//
// ToPtrOp
//===----------------------------------------------------------------------===//
Expand Down
12 changes: 6 additions & 6 deletions mlir/test/Conversion/PtrToLLVM/ptr-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
// CHECK: llvm.return %[[VAL_8]] : !llvm.struct<(ptr, ptr, ptr, ptr)>
// CHECK: }
func.func @test_ptr_add(%arg0: !ptr.ptr<#ptr.generic_space>, %arg1: index) -> (!ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>) {
%0 = ptr.ptr_add %arg0, %arg1 : <#ptr.generic_space>, index
%1 = ptr.ptr_add nusw %arg0, %arg1 : <#ptr.generic_space>, index
%2 = ptr.ptr_add nuw %arg0, %arg1 : <#ptr.generic_space>, index
%3 = ptr.ptr_add inbounds %arg0, %arg1 : <#ptr.generic_space>, index
%0 = ptr.ptr_add %arg0, %arg1 : !ptr.ptr<#ptr.generic_space>, index
%1 = ptr.ptr_add nusw %arg0, %arg1 : !ptr.ptr<#ptr.generic_space>, index
%2 = ptr.ptr_add nuw %arg0, %arg1 : !ptr.ptr<#ptr.generic_space>, index
%3 = ptr.ptr_add inbounds %arg0, %arg1 : !ptr.ptr<#ptr.generic_space>, index
return %0, %1, %2, %3 : !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>
}

Expand Down Expand Up @@ -263,7 +263,7 @@ func.func @test_comprehensive_dynamic(%arg0: memref<?x?xf32, strided<[?, ?], off
%0 = ptr.to_ptr %arg0 : memref<?x?xf32, strided<[?, ?], offset: ?>, #ptr.generic_space> -> <#ptr.generic_space>
%1 = ptr.get_metadata %arg0 : memref<?x?xf32, strided<[?, ?], offset: ?>, #ptr.generic_space>
%2 = ptr.type_offset f32 : index
%3 = ptr.ptr_add inbounds %0, %2 : <#ptr.generic_space>, index
%3 = ptr.ptr_add inbounds %0, %2 : !ptr.ptr<#ptr.generic_space>, index
%4 = ptr.from_ptr %3 metadata %1 : <#ptr.generic_space> -> memref<?x?xf32, strided<[?, ?], offset: ?>, #ptr.generic_space>
return %4 : memref<?x?xf32, strided<[?, ?], offset: ?>, #ptr.generic_space>
}
Expand Down Expand Up @@ -313,6 +313,6 @@ func.func @test_memref_ptradd_indexing(%arg0: memref<10x?x30xf32, #ptr.generic_s
%0 = ptr.to_ptr %arg0 : memref<10x?x30xf32, #ptr.generic_space> -> <#ptr.generic_space>
%1 = ptr.type_offset f32 : index
%2 = arith.muli %1, %arg1 : index
%3 = ptr.ptr_add %0, %2 : <#ptr.generic_space>, index
%3 = ptr.ptr_add %0, %2 : !ptr.ptr<#ptr.generic_space>, index
return %3 : !ptr.ptr<#ptr.generic_space>
}
16 changes: 16 additions & 0 deletions mlir/test/Dialect/Ptr/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,19 @@ func.func @llvm_store(%arg0: !ptr.ptr<#llvm.address_space<1>>, %arg1: memref<f32
ptr.store %arg1, %arg0 : memref<f32>, !ptr.ptr<#llvm.address_space<1>>
return
}

// -----

func.func @ptr_add_mismatch(%ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %offsets: vector<8xi64>) -> tensor<8x!ptr.ptr<#ptr.generic_space>> {
// expected-error@+1 {{the shaped containers type must match}}
%res = ptr.ptr_add %ptrs, %offsets : tensor<8x!ptr.ptr<#ptr.generic_space>>, vector<8xi64>
return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
}

// -----

func.func @ptr_add_shape_mismatch(%ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %offsets: tensor<4xi64>) -> tensor<8x!ptr.ptr<#ptr.generic_space>> {
// expected-error@+1 {{shapes of base and offset must match}}
%res = ptr.ptr_add %ptrs, %offsets : tensor<8x!ptr.ptr<#ptr.generic_space>>, tensor<4xi64>
return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
}
63 changes: 63 additions & 0 deletions mlir/test/Dialect/Ptr/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,66 @@ func.func @llvm_masked_ops(%ptr: !ptr.ptr<#llvm.address_space<3>>, %ptrs: vector
ptr.masked_store %value, %ptr, %mask alignment = 4 : vector<4xf32>, !ptr.ptr<#llvm.address_space<3>>
return %0 : vector<4xf32>
}

/// Test ptr_add with shaped operands (vectors)
func.func @ptr_add_vector(%ptrs: vector<4x!ptr.ptr<#ptr.generic_space>>, %offsets: vector<4xindex>) -> vector<4x!ptr.ptr<#ptr.generic_space>> {
%res = ptr.ptr_add %ptrs, %offsets : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xindex>
%res0 = ptr.ptr_add none %ptrs, %offsets : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xindex>
%res1 = ptr.ptr_add nusw %ptrs, %offsets : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xindex>
%res2 = ptr.ptr_add nuw %ptrs, %offsets : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xindex>
%res3 = ptr.ptr_add inbounds %ptrs, %offsets : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xindex>
return %res : vector<4x!ptr.ptr<#ptr.generic_space>>
}

/// Test ptr_add with shaped operands (tensors)
func.func @ptr_add_tensor(%ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %offsets: tensor<8xi64>) -> tensor<8x!ptr.ptr<#ptr.generic_space>> {
%res = ptr.ptr_add %ptrs, %offsets : tensor<8x!ptr.ptr<#ptr.generic_space>>, tensor<8xi64>
return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
}

/// Test ptr_add with 2D tensors
func.func @ptr_add_tensor_2d(%ptrs: tensor<4x8x!ptr.ptr<#ptr.generic_space>>, %offsets: tensor<4x8xindex>) -> tensor<4x8x!ptr.ptr<#ptr.generic_space>> {
%res = ptr.ptr_add %ptrs, %offsets : tensor<4x8x!ptr.ptr<#ptr.generic_space>>, tensor<4x8xindex>
%res1 = ptr.ptr_add nuw %ptrs, %offsets : tensor<4x8x!ptr.ptr<#ptr.generic_space>>, tensor<4x8xindex>
return %res : tensor<4x8x!ptr.ptr<#ptr.generic_space>>
}

/// Test ptr_add with scalar base and shaped offsets (vectors)
func.func @ptr_add_scalar_base_vector_offsets(%ptr: !ptr.ptr<#ptr.generic_space>, %offsets: vector<4xindex>) -> vector<4x!ptr.ptr<#ptr.generic_space>> {
%res = ptr.ptr_add %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, vector<4xindex>
%res0 = ptr.ptr_add none %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, vector<4xindex>
%res1 = ptr.ptr_add nusw %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, vector<4xindex>
%res2 = ptr.ptr_add nuw %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, vector<4xindex>
%res3 = ptr.ptr_add inbounds %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, vector<4xindex>
return %res : vector<4x!ptr.ptr<#ptr.generic_space>>
}

/// Test ptr_add with scalar base and shaped offsets (tensors)
func.func @ptr_add_scalar_base_tensor_offsets(%ptr: !ptr.ptr<#ptr.generic_space>, %offsets: tensor<8xi64>) -> tensor<8x!ptr.ptr<#ptr.generic_space>> {
%res = ptr.ptr_add %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, tensor<8xi64>
%res0 = ptr.ptr_add none %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, tensor<8xi64>
%res1 = ptr.ptr_add nusw %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, tensor<8xi64>
%res2 = ptr.ptr_add nuw %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, tensor<8xi64>
%res3 = ptr.ptr_add inbounds %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, tensor<8xi64>
return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
}

/// Test ptr_add with shaped base and scalar offset (vectors)
func.func @ptr_add_vector_base_scalar_offset(%ptrs: vector<4x!ptr.ptr<#ptr.generic_space>>, %offset: index) -> vector<4x!ptr.ptr<#ptr.generic_space>> {
%res = ptr.ptr_add %ptrs, %offset : vector<4x!ptr.ptr<#ptr.generic_space>>, index
%res0 = ptr.ptr_add none %ptrs, %offset : vector<4x!ptr.ptr<#ptr.generic_space>>, index
%res1 = ptr.ptr_add nusw %ptrs, %offset : vector<4x!ptr.ptr<#ptr.generic_space>>, index
%res2 = ptr.ptr_add nuw %ptrs, %offset : vector<4x!ptr.ptr<#ptr.generic_space>>, index
%res3 = ptr.ptr_add inbounds %ptrs, %offset : vector<4x!ptr.ptr<#ptr.generic_space>>, index
return %res : vector<4x!ptr.ptr<#ptr.generic_space>>
}

/// Test ptr_add with shaped base and scalar offset (tensors)
func.func @ptr_add_tensor_base_scalar_offset(%ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %offset: i64) -> tensor<8x!ptr.ptr<#ptr.generic_space>> {
%res = ptr.ptr_add %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64
%res0 = ptr.ptr_add none %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64
%res1 = ptr.ptr_add nusw %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64
%res2 = ptr.ptr_add nuw %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64
%res3 = ptr.ptr_add inbounds %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64
return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
}
30 changes: 30 additions & 0 deletions mlir/test/Target/LLVMIR/ptr.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -203,3 +203,33 @@ llvm.func @mixed_masked_ops_address_spaces(%ptr: !ptr.ptr<#llvm.address_space<3>
ptr.masked_store %value, %ptr, %mask alignment = 8 : vector<4xf64>, !ptr.ptr<#llvm.address_space<3>>
llvm.return
}

// CHECK-LABEL: define <4 x ptr> @ptr_add_vector
// CHECK-SAME: (<4 x ptr> %[[PTRS:.*]], <4 x i32> %[[OFFSETS:.*]]) {
// CHECK-NEXT: %[[RES:.*]] = getelementptr i8, <4 x ptr> %[[PTRS]], <4 x i32> %[[OFFSETS]]
// CHECK-NEXT: ret <4 x ptr> %[[RES]]
// CHECK-NEXT: }
llvm.func @ptr_add_vector(%ptrs: vector<4x!ptr.ptr<#llvm.address_space<0>>>, %offsets: vector<4xi32>) -> vector<4x!ptr.ptr<#llvm.address_space<0>>> {
%res = ptr.ptr_add %ptrs, %offsets : vector<4x!ptr.ptr<#llvm.address_space<0>>>, vector<4xi32>
llvm.return %res : vector<4x!ptr.ptr<#llvm.address_space<0>>>
}

// CHECK-LABEL: define <4 x ptr> @ptr_add_scalar_base_vector_offsets
// CHECK-SAME: (ptr %[[PTR:.*]], <4 x i32> %[[OFFSETS:.*]]) {
// CHECK-NEXT: %[[RES:.*]] = getelementptr i8, ptr %[[PTR]], <4 x i32> %[[OFFSETS]]
// CHECK-NEXT: ret <4 x ptr> %[[RES]]
// CHECK-NEXT: }
llvm.func @ptr_add_scalar_base_vector_offsets(%ptr: !ptr.ptr<#llvm.address_space<0>>, %offsets: vector<4xi32>) -> vector<4x!ptr.ptr<#llvm.address_space<0>>> {
%res = ptr.ptr_add %ptr, %offsets : !ptr.ptr<#llvm.address_space<0>>, vector<4xi32>
llvm.return %res : vector<4x!ptr.ptr<#llvm.address_space<0>>>
}

// CHECK-LABEL: define <4 x ptr> @ptr_add_vector_base_scalar_offset
// CHECK-SAME: (<4 x ptr> %[[PTRS:.*]], i32 %[[OFFSET:.*]]) {
// CHECK-NEXT: %[[RES:.*]] = getelementptr i8, <4 x ptr> %[[PTRS]], i32 %[[OFFSET]]
// CHECK-NEXT: ret <4 x ptr> %[[RES]]
// CHECK-NEXT: }
llvm.func @ptr_add_vector_base_scalar_offset(%ptrs: vector<4x!ptr.ptr<#llvm.address_space<0>>>, %offset: i32) -> vector<4x!ptr.ptr<#llvm.address_space<0>>> {
%res = ptr.ptr_add %ptrs, %offset : vector<4x!ptr.ptr<#llvm.address_space<0>>>, i32
llvm.return %res : vector<4x!ptr.ptr<#llvm.address_space<0>>>
}