Skip to content

Commit

Permalink
[RTG] Add set union operation (#7916)
Browse files Browse the repository at this point in the history
  • Loading branch information
maerhart authored Dec 9, 2024
1 parent de64a10 commit 2aecdf7
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 3 deletions.
18 changes: 18 additions & 0 deletions include/circt/Dialect/RTG/IR/RTGOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ include "mlir/IR/Properties.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "circt/Dialect/RTG/IR/RTGInterfaces.td"

// Base class for the operation in this dialect.
Expand Down Expand Up @@ -141,6 +142,23 @@ def SetDifferenceOp : RTGOp<"set_difference", [
}];
}

def SetUnionOp : RTGOp<"set_union", [
Pure, SameOperandsAndResultType, Commutative
]> {
let summary = "computes the union of sets";
let description = [{
Computes the union of the given sets. The list of sets must contain at
least one element.
}];

let arguments = (ins Variadic<SetType>:$sets);
let results = (outs SetType:$result);

let assemblyFormat = [{
$sets `:` qualified(type($result)) attr-dict
}];
}

//===- Bag Operations ------------------------------------------------------===//

def BagCreateOp : RTGOp<"bag_create", [Pure, SameVariadicOperandSize]> {
Expand Down
3 changes: 2 additions & 1 deletion include/circt/Dialect/RTG/IR/RTGVisitors.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class RTGOpVisitor {
auto *thisCast = static_cast<ConcreteType *>(this);
return TypeSwitch<Operation *, ResultType>(op)
.template Case<SequenceOp, SequenceClosureOp, SetCreateOp,
SetSelectRandomOp, SetDifferenceOp, TestOp,
SetSelectRandomOp, SetDifferenceOp, SetUnionOp, TestOp,
InvokeSequenceOp, BagCreateOp, BagSelectRandomOp,
BagDifferenceOp, TargetOp, YieldOp>(
[&](auto expr) -> ResultType {
Expand Down Expand Up @@ -89,6 +89,7 @@ class RTGOpVisitor {
HANDLE(SetCreateOp, Unhandled);
HANDLE(SetSelectRandomOp, Unhandled);
HANDLE(SetDifferenceOp, Unhandled);
HANDLE(SetUnionOp, Unhandled);
HANDLE(BagCreateOp, Unhandled);
HANDLE(BagSelectRandomOp, Unhandled);
HANDLE(BagDifferenceOp, Unhandled);
Expand Down
11 changes: 11 additions & 0 deletions lib/Dialect/RTG/Transforms/ElaborationPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,17 @@ class Elaborator : public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>,
return DeletionKind::Delete;
}

FailureOr<DeletionKind>
visitOp(SetUnionOp op, function_ref<void(Operation *)> addToWorklist) {
SetVector<ElaboratorValue *> result;
for (auto set : op.getSets())
result.set_union(cast<SetValue>(state.at(set))->getSet());

internalizeResult<SetValue>(op.getResult(), std::move(result),
op.getType());
return DeletionKind::Delete;
}

FailureOr<DeletionKind>
visitOp(BagCreateOp op, function_ref<void(Operation *)> addToWorklist) {
MapVector<ElaboratorValue *, uint64_t> bag;
Expand Down
4 changes: 3 additions & 1 deletion test/Dialect/RTG/IR/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,13 @@ func.func @sets(%arg0: i32, %arg1: i32) {
// CHECK: [[SET:%.+]] = rtg.set_create %arg0, %arg1 : i32
// CHECK: [[R:%.+]] = rtg.set_select_random [[SET]] : !rtg.set<i32>
// CHECK: [[EMPTY:%.+]] = rtg.set_create : i32
// CHECK: rtg.set_difference [[SET]], [[EMPTY]] : !rtg.set<i32>
// CHECK: [[DIFF:%.+]] = rtg.set_difference [[SET]], [[EMPTY]] : !rtg.set<i32>
// CHECK: rtg.set_union [[SET]], [[DIFF]] : !rtg.set<i32>
%set = rtg.set_create %arg0, %arg1 : i32
%r = rtg.set_select_random %set : !rtg.set<i32>
%empty = rtg.set_create : i32
%diff = rtg.set_difference %set, %empty : !rtg.set<i32>
%union = rtg.set_union %set, %diff : !rtg.set<i32>

return
}
Expand Down
7 changes: 7 additions & 0 deletions test/Dialect/RTG/IR/errors.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,10 @@ rtg.sequence @seq {
// expected-error @below {{operand types must match bag element type}}
"rtg.bag_create"(%arg0, %arg1, %arg2, %arg2){} : (i64, i64, index, index) -> !rtg.bag<i32>
}

// -----

rtg.sequence @seq {
// expected-error @below {{expected 1 or more operands, but found 0}}
rtg.set_union : !rtg.set<i32>
}
4 changes: 3 additions & 1 deletion test/Dialect/RTG/Transform/elaboration.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ rtg.test @setOperations : !rtg.dict<> {
%1 = arith.constant 3 : i32
%2 = arith.constant 4 : i32
%3 = arith.constant 5 : i32
%set = rtg.set_create %0, %1, %2, %0 : i32
%set0 = rtg.set_create %0, %1, %0 : i32
%set1 = rtg.set_create %2, %0 : i32
%set = rtg.set_union %set0, %set1 : !rtg.set<i32>
%4 = rtg.set_select_random %set : !rtg.set<i32> {rtg.elaboration_custom_seed = 1}
%new_set = rtg.set_create %3, %4 : i32
%diff = rtg.set_difference %set, %new_set : !rtg.set<i32>
Expand Down

0 comments on commit 2aecdf7

Please sign in to comment.