From f61ea9db350420208fedd7fec7b2a13f9a53cafd Mon Sep 17 00:00:00 2001 From: Martin Erhart Date: Thu, 28 Nov 2024 11:49:18 +0000 Subject: [PATCH] [RTG] Add set union operation --- include/circt/Dialect/RTG/IR/RTGOps.td | 18 ++++++++++++++++++ include/circt/Dialect/RTG/IR/RTGVisitors.h | 3 ++- lib/Dialect/RTG/Transforms/ElaborationPass.cpp | 11 +++++++++++ test/Dialect/RTG/IR/basic.mlir | 4 +++- test/Dialect/RTG/IR/errors.mlir | 7 +++++++ test/Dialect/RTG/Transform/elaboration.mlir | 4 +++- 6 files changed, 44 insertions(+), 3 deletions(-) diff --git a/include/circt/Dialect/RTG/IR/RTGOps.td b/include/circt/Dialect/RTG/IR/RTGOps.td index 6b85fb76073c..a270137bf59d 100644 --- a/include/circt/Dialect/RTG/IR/RTGOps.td +++ b/include/circt/Dialect/RTG/IR/RTGOps.td @@ -16,6 +16,7 @@ include "mlir/IR/Properties.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/InferTypeOpInterface.td" include "circt/Dialect/RTG/IR/RTGInterfaces.td" // Base class for the operation in this dialect. @@ -141,6 +142,23 @@ def SetDifferenceOp : RTGOp<"set_difference", [ }]; } +def SetUnionOp : RTGOp<"set_union", [ + Pure, SameOperandsAndResultType, Commutative +]> { + let summary = "computes the union of sets"; + let description = [{ + Computes the union of the given sets. The list of sets must contain at + least one element. + }]; + + let arguments = (ins Variadic:$sets); + let results = (outs SetType:$result); + + let assemblyFormat = [{ + $sets `:` qualified(type($result)) attr-dict + }]; +} + //===- Bag Operations ------------------------------------------------------===// def BagCreateOp : RTGOp<"bag_create", [Pure, SameVariadicOperandSize]> { diff --git a/include/circt/Dialect/RTG/IR/RTGVisitors.h b/include/circt/Dialect/RTG/IR/RTGVisitors.h index a8382493e1dc..4c4babb96edc 100644 --- a/include/circt/Dialect/RTG/IR/RTGVisitors.h +++ b/include/circt/Dialect/RTG/IR/RTGVisitors.h @@ -32,7 +32,7 @@ class RTGOpVisitor { auto *thisCast = static_cast(this); return TypeSwitch(op) .template Case( [&](auto expr) -> ResultType { @@ -89,6 +89,7 @@ class RTGOpVisitor { HANDLE(SetCreateOp, Unhandled); HANDLE(SetSelectRandomOp, Unhandled); HANDLE(SetDifferenceOp, Unhandled); + HANDLE(SetUnionOp, Unhandled); HANDLE(BagCreateOp, Unhandled); HANDLE(BagSelectRandomOp, Unhandled); HANDLE(BagDifferenceOp, Unhandled); diff --git a/lib/Dialect/RTG/Transforms/ElaborationPass.cpp b/lib/Dialect/RTG/Transforms/ElaborationPass.cpp index 16a30d59cacf..20498be23a70 100644 --- a/lib/Dialect/RTG/Transforms/ElaborationPass.cpp +++ b/lib/Dialect/RTG/Transforms/ElaborationPass.cpp @@ -427,6 +427,17 @@ class Elaborator : public RTGOpVisitor, return DeletionKind::Delete; } + FailureOr + visitOp(SetUnionOp op, function_ref addToWorklist) { + SetVector result; + for (auto set : op.getSets()) + result.set_union(cast(state.at(set))->getSet()); + + internalizeResult(op.getResult(), std::move(result), + op.getType()); + return DeletionKind::Delete; + } + FailureOr dispatchOpVisitor(Operation *op, function_ref addToWorklist) { diff --git a/test/Dialect/RTG/IR/basic.mlir b/test/Dialect/RTG/IR/basic.mlir index 5dd9126e917b..85298992c4d8 100644 --- a/test/Dialect/RTG/IR/basic.mlir +++ b/test/Dialect/RTG/IR/basic.mlir @@ -30,11 +30,13 @@ func.func @sets(%arg0: i32, %arg1: i32) { // CHECK: [[SET:%.+]] = rtg.set_create %arg0, %arg1 : i32 // CHECK: [[R:%.+]] = rtg.set_select_random [[SET]] : !rtg.set // CHECK: [[EMPTY:%.+]] = rtg.set_create : i32 - // CHECK: rtg.set_difference [[SET]], [[EMPTY]] : !rtg.set + // CHECK: [[DIFF:%.+]] = rtg.set_difference [[SET]], [[EMPTY]] : !rtg.set + // CHECK: rtg.set_union [[SET]], [[DIFF]] : !rtg.set %set = rtg.set_create %arg0, %arg1 : i32 %r = rtg.set_select_random %set : !rtg.set %empty = rtg.set_create : i32 %diff = rtg.set_difference %set, %empty : !rtg.set + %union = rtg.set_union %set, %diff : !rtg.set return } diff --git a/test/Dialect/RTG/IR/errors.mlir b/test/Dialect/RTG/IR/errors.mlir index b2f23049c9a8..2d5f147f234b 100644 --- a/test/Dialect/RTG/IR/errors.mlir +++ b/test/Dialect/RTG/IR/errors.mlir @@ -66,3 +66,10 @@ rtg.sequence @seq { // expected-error @below {{operand types must match bag element type}} "rtg.bag_create"(%arg0, %arg1, %arg2, %arg2){} : (i64, i64, index, index) -> !rtg.bag } + +// ----- + +rtg.sequence @seq { + // expected-error @below {{expected 1 or more operands, but found 0}} + rtg.set_union : !rtg.set +} diff --git a/test/Dialect/RTG/Transform/elaboration.mlir b/test/Dialect/RTG/Transform/elaboration.mlir index 69000aa3714c..95ff3657e9f5 100644 --- a/test/Dialect/RTG/Transform/elaboration.mlir +++ b/test/Dialect/RTG/Transform/elaboration.mlir @@ -17,7 +17,9 @@ rtg.test @setOperations : !rtg.dict<> { %1 = arith.constant 3 : i32 %2 = arith.constant 4 : i32 %3 = arith.constant 5 : i32 - %set = rtg.set_create %0, %1, %2, %0 : i32 + %set0 = rtg.set_create %0, %1, %0 : i32 + %set1 = rtg.set_create %2, %0 : i32 + %set = rtg.set_union %set0, %set1 : !rtg.set %4 = rtg.set_select_random %set : !rtg.set {rtg.elaboration_custom_seed = 1} %new_set = rtg.set_create %3, %4 : i32 %diff = rtg.set_difference %set, %new_set : !rtg.set