Skip to content

Commit 8d4e9f5

Browse files
andrey-golubevermilindwalekar
authored andcommitted
[mlir][bufferization] Refine tensor-buffer compatibility checks (llvm#167705)
Generally, to_tensor and to_buffer already perform sufficient verification. However, there are some unnecessarily strict constraints: * builtin tensor requires its buffer counterpart to always be memref * to_buffer on ranked tensor requires to always return memref These checks are assertions (i.e. preconditions), however, they actually prevent an apparently useful bufferization where builtin tensors could become custom buffers. Lift these assertions, maintaining the verification procedure unchanged, to allow builtin -> custom bufferizations at operation boundary level.
1 parent 6d51edf commit 8d4e9f5

File tree

5 files changed

+124
-34
lines changed

5 files changed

+124
-34
lines changed

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -683,16 +683,6 @@ bool AnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
683683
return false;
684684
}
685685

686-
// bufferization.to_buffer is not allowed to change the rank.
687-
static void ensureToBufferOpIsValid(Value tensor, Type memrefType) {
688-
#ifndef NDEBUG
689-
auto rankedTensorType = llvm::dyn_cast<RankedTensorType>(tensor.getType());
690-
assert((!rankedTensorType || llvm::cast<MemRefType>(memrefType).getRank() ==
691-
rankedTensorType.getRank()) &&
692-
"to_buffer would be invalid: mismatching ranks");
693-
#endif
694-
}
695-
696686
FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
697687
const BufferizationOptions &options,
698688
const BufferizationState &state) {
@@ -711,7 +701,7 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
711701
FailureOr<BufferLikeType> bufferType = getBufferType(value, options, state);
712702
if (failed(bufferType))
713703
return failure();
714-
ensureToBufferOpIsValid(value, *bufferType);
704+
715705
return rewriter
716706
.create<bufferization::ToBufferOp>(value.getLoc(), *bufferType, value)
717707
.getResult();

mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,6 @@ struct BuiltinTensorExternalModel
7373
mlir::LogicalResult verifyCompatibleBufferType(
7474
mlir::Type tensor, BufferLikeType bufferType,
7575
llvm::function_ref<mlir::InFlightDiagnostic()> emitError) const {
76-
assert(isa<TensorType>(tensor) && "expected tensor type");
77-
assert(isa<BaseMemRefType>(bufferType) && "expected memref type");
78-
7976
auto tensorType = cast<ShapedType>(tensor);
8077
auto memrefType = cast<ShapedType>(bufferType);
8178

mlir/test/Dialect/Bufferization/invalid.mlir

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,63 @@ func.func @invalid_manual_deallocation() {
127127
// expected-error @below{{op attribute 'bufferization.manual_deallocation' can be used only on ops that have an allocation and/or free side effect}}
128128
arith.constant {bufferization.manual_deallocation} 0 : index
129129
}
130+
131+
// -----
132+
133+
func.func @invalid_rank_to_buffer(%t: tensor<1x2x3x4xf32>) {
134+
// expected-error @below{{'bufferization.to_buffer' op failed to verify that specified tensor and buffer types match}}
135+
// expected-error @below{{shapes do not match}}
136+
%b = bufferization.to_buffer %t
137+
: tensor<1x2x3x4xf32> to memref<1x2x3xf32>
138+
return
139+
}
140+
141+
// -----
142+
143+
func.func @invalid_rank_to_tensor(%b: memref<1x2x3xf32>) {
144+
// expected-error @below{{'bufferization.to_tensor' op failed to verify that specified tensor and buffer types match}}
145+
// expected-error @below{{shapes do not match}}
146+
%t = bufferization.to_tensor %b
147+
: memref<1x2x3xf32> to tensor<1x2x3x4xf32>
148+
return
149+
}
150+
151+
// -----
152+
153+
func.func @invalid_shape_to_buffer(%t: tensor<1x2x3x4xf32>) {
154+
// expected-error @below{{'bufferization.to_buffer' op failed to verify that specified tensor and buffer types match}}
155+
// expected-error @below{{shapes do not match}}
156+
%b = bufferization.to_buffer %t
157+
: tensor<1x2x3x4xf32> to memref<1x2x4x3xf32>
158+
return
159+
}
160+
161+
// -----
162+
163+
func.func @invalid_shape_to_tensor(%b: memref<1x2x4x3xf32>) {
164+
// expected-error @below{{'bufferization.to_tensor' op failed to verify that specified tensor and buffer types match}}
165+
// expected-error @below{{shapes do not match}}
166+
%t = bufferization.to_tensor %b
167+
: memref<1x2x4x3xf32> to tensor<1x2x3x4xf32>
168+
return
169+
}
170+
171+
// -----
172+
173+
func.func @invalid_type_to_buffer(%t: tensor<1x2x3x4xf32>) {
174+
// expected-error @below{{'bufferization.to_buffer' op failed to verify that specified tensor and buffer types match}}
175+
// expected-error @below{{element types do not match}}
176+
%b = bufferization.to_buffer %t
177+
: tensor<1x2x3x4xf32> to memref<1x2x3x4xf16>
178+
return
179+
}
180+
181+
// -----
182+
183+
func.func @invalid_type_to_tensor(%b: memref<1x2x3x4xf16>) {
184+
// expected-error @below{{'bufferization.to_tensor' op failed to verify that specified tensor and buffer types match}}
185+
// expected-error @below{{element types do not match}}
186+
%t2 = bufferization.to_tensor %b
187+
: memref<1x2x3x4xf16> to tensor<1x2x3x4xf32>
188+
return
189+
}

mlir/test/Dialect/Bufferization/ops.mlir

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,40 @@ func.func @test_dealloc_op(%arg0: memref<2xf32>, %arg1: memref<4xi32>,
8383
bufferization.dealloc
8484
return %0#0, %0#1 : i1, i1
8585
}
86+
87+
// CHECK: func.func @test_builtin_custom_builtin_type_conversion
88+
// CHECK-SAME: (%[[t:.*]]: tensor<42xf32>) -> tensor<42xf32>
89+
func.func @test_builtin_custom_builtin_type_conversion(%t: tensor<42xf32>)
90+
-> tensor<42xf32> {
91+
// CHECK: %[[buffer:.*]] = bufferization.to_buffer %[[t]]
92+
// CHECK-SAME: to !test.test_memref<[42], f32>
93+
%buffer = bufferization.to_buffer %t
94+
: tensor<42xf32> to !test.test_memref<[42], f32>
95+
96+
// CHECK: %[[tensor:.*]] = bufferization.to_tensor %[[buffer]]
97+
// CHECK-SAME: to tensor<42xf32>
98+
%tensor = bufferization.to_tensor %buffer
99+
: !test.test_memref<[42], f32> to tensor<42xf32>
100+
101+
// CHECK: return %[[tensor]]
102+
return %tensor : tensor<42xf32>
103+
}
104+
105+
// CHECK: func.func @test_custom_builtin_custom_type_conversion
106+
// CHECK-SAME: (%[[t:.*]]: !test.test_tensor<[42], f32>)
107+
// CHECK-SAME: -> !test.test_tensor<[42], f32>
108+
func.func @test_custom_builtin_custom_type_conversion(%t: !test.test_tensor<[42], f32>)
109+
-> !test.test_tensor<[42], f32> {
110+
// CHECK: %[[buffer:.*]] = bufferization.to_buffer %[[t]]
111+
// CHECK-SAME: to memref<42xf32>
112+
%buffer = bufferization.to_buffer %t
113+
: !test.test_tensor<[42], f32> to memref<42xf32>
114+
115+
// CHECK: %[[tensor:.*]] = bufferization.to_tensor %[[buffer]]
116+
// CHECK-SAME: to !test.test_tensor<[42], f32>
117+
%tensor = bufferization.to_tensor %buffer
118+
: memref<42xf32> to !test.test_tensor<[42], f32>
119+
120+
// CHECK: return %[[tensor]]
121+
return %tensor : !test.test_tensor<[42], f32>
122+
}

mlir/test/lib/Dialect/Test/TestTypes.cpp

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -554,26 +554,6 @@ TestTypeOpAsmTypeInterfaceType::getAlias(::llvm::raw_ostream &os) const {
554554
return ::mlir::OpAsmDialectInterface::AliasResult::FinalAlias;
555555
}
556556

557-
::mlir::FailureOr<::mlir::bufferization::BufferLikeType>
558-
TestTensorType::getBufferType(
559-
const ::mlir::bufferization::BufferizationOptions &,
560-
::llvm::function_ref<::mlir::InFlightDiagnostic()>) {
561-
return cast<bufferization::BufferLikeType>(
562-
TestMemrefType::get(getContext(), getShape(), getElementType(), nullptr));
563-
}
564-
565-
::mlir::LogicalResult TestTensorType::verifyCompatibleBufferType(
566-
::mlir::bufferization::BufferLikeType bufferType,
567-
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) {
568-
auto testMemref = dyn_cast<TestMemrefType>(bufferType);
569-
if (!testMemref)
570-
return emitError() << "expected TestMemrefType";
571-
572-
const bool valid = getShape() == testMemref.getShape() &&
573-
getElementType() == testMemref.getElementType();
574-
return mlir::success(valid);
575-
}
576-
577557
//===----------------------------------------------------------------------===//
578558
// TestTypeNewlineAndIndent
579559
//===----------------------------------------------------------------------===//
@@ -600,3 +580,29 @@ void TestTypeNewlineAndIndentType::print(::mlir::AsmPrinter &printer) const {
600580
printer.printNewline();
601581
printer << ">";
602582
}
583+
584+
::mlir::FailureOr<::mlir::bufferization::BufferLikeType>
585+
TestTensorType::getBufferType(
586+
const ::mlir::bufferization::BufferizationOptions &,
587+
::llvm::function_ref<::mlir::InFlightDiagnostic()>) {
588+
return llvm::cast<bufferization::BufferLikeType>(
589+
TestMemrefType::get(getContext(), getShape(), getElementType(), nullptr));
590+
}
591+
592+
::mlir::LogicalResult TestTensorType::verifyCompatibleBufferType(
593+
::mlir::bufferization::BufferLikeType bufferType,
594+
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) {
595+
if (auto testMemref = llvm::dyn_cast<TestMemrefType>(bufferType)) {
596+
const bool valid = getShape() == testMemref.getShape() &&
597+
getElementType() == testMemref.getElementType();
598+
return mlir::success(valid);
599+
}
600+
601+
if (auto builtinMemref = llvm::dyn_cast<MemRefType>(bufferType)) {
602+
const bool valid = getShape() == builtinMemref.getShape() &&
603+
getElementType() == builtinMemref.getElementType();
604+
return mlir::success(valid);
605+
}
606+
607+
return emitError() << "expected MemRefType or TestMemrefType";
608+
}

0 commit comments

Comments
 (0)