Skip to content
This repository was archived by the owner on Apr 23, 2021. It is now read-only.

Commit a8d42cc

Browse files
joker-ephjpienaar
authored andcommitted
Add a HasParent operation trait to enforce a specific parent on an operation (NFC)
PiperOrigin-RevId: 260532592
1 parent fae8644 commit a8d42cc

File tree

7 files changed

+35
-5
lines changed

7 files changed

+35
-5
lines changed

include/mlir/IR/OpBase.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1018,6 +1018,10 @@ def Terminator : NativeOpTrait<"IsTerminator">;
10181018
class SingleBlockImplicitTerminator<string op>
10191019
: ParamNativeOpTrait<"SingleBlockImplicitTerminator", op>;
10201020

1021+
// Op's parent operation is the provided one.
1022+
class HasParent<string op>
1023+
: ParamNativeOpTrait<"HasParent", op>;
1024+
10211025
// Op result type is derived from the first attribute. If the attribute is an
10221026
// subclass of `TypeAttrBase`, its value is used, otherwise, the type of the
10231027
// attribute content is used.

include/mlir/IR/OpDefinition.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -845,6 +845,20 @@ template <typename TerminatorOpType> struct SingleBlockImplicitTerminator {
845845
};
846846
};
847847

848+
/// This class provides a verifier for ops that are expecting a specific parent.
849+
template <typename ParentOpType> struct HasParent {
850+
template <typename ConcreteType>
851+
class Impl : public TraitBase<ConcreteType, Impl> {
852+
public:
853+
static LogicalResult verifyTrait(Operation *op) {
854+
if (isa<ParentOpType>(op->getParentOp()))
855+
return success();
856+
return op->emitOpError() << "expects parent op '"
857+
<< ParentOpType::getOperationName() << "'";
858+
}
859+
};
860+
};
861+
848862
} // end namespace OpTrait
849863

850864
//===----------------------------------------------------------------------===//

include/mlir/StandardOps/Ops.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -765,7 +765,7 @@ def RemIUOp : IntArithmeticOp<"remiu"> {
765765
let hasFolder = 1;
766766
}
767767

768-
def ReturnOp : Std_Op<"return", [Terminator]> {
768+
def ReturnOp : Std_Op<"return", [Terminator, HasParent<"FuncOp">]> {
769769
let summary = "return operation";
770770
let description = [{
771771
The "return" operation represents a return operation within a function.

lib/StandardOps/Ops.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1868,9 +1868,7 @@ static void print(OpAsmPrinter *p, ReturnOp op) {
18681868
}
18691869

18701870
static LogicalResult verify(ReturnOp op) {
1871-
auto function = dyn_cast_or_null<FuncOp>(op.getParentOp());
1872-
if (!function)
1873-
return op.emitOpError() << "must be nested within a 'func' region";
1871+
auto function = cast<FuncOp>(op.getParentOp());
18741872

18751873
// The operand number and types must match the function signature.
18761874
const auto &results = function.getType().getResults();

test/IR/invalid-ops.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -720,7 +720,7 @@ func @sitofp_f32_to_i32(%arg0 : f32) {
720720

721721
func @return_not_in_function() {
722722
"foo.region"() ({
723-
// expected-error@+1 {{must be nested within a 'func' region}}
723+
// expected-error@+1 {{'std.return' op expects parent op 'func'}}
724724
return
725725
}): () -> ()
726726
return

test/IR/traits.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,13 @@ func @succeededSameOperandAndResultShape(%t10x10 : tensor<10x10xf32>, %t1: tenso
4040
// expected-error@+1 {{requires the same shape for all operands and results}}
4141
%0 = "test.same_operand_and_result_shape"(%t1, %t10x10) : (tensor<1xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
4242
}
43+
44+
// -----
45+
46+
func @hasParent() {
47+
"some.op"() ({
48+
// expected-error@+1 {{'test.child' op expects parent op 'test.parent'}}
49+
"test.child"() : () -> ()
50+
}) : () -> ()
51+
}
52+

test/lib/TestDialect/TestOps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,10 @@ def BroadcastableOp : TEST_Op<"broadcastable", [Broadcastable]> {
201201
let results = (outs AnyTensor:$res);
202202
}
203203

204+
// There the "HasParent" trait.
205+
def ParentOp : TEST_Op<"parent">;
206+
def ChildOp : TEST_Op<"child", [HasParent<"ParentOp">]>;
207+
204208
//===----------------------------------------------------------------------===//
205209
// Test Patterns
206210
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)