Skip to content

Commit

Permalink
[RTG] Add BagType and operations
Browse files Browse the repository at this point in the history
  • Loading branch information
maerhart committed Dec 2, 2024
1 parent 3a32905 commit cb34cbd
Show file tree
Hide file tree
Showing 6 changed files with 196 additions and 5 deletions.
65 changes: 65 additions & 0 deletions include/circt/Dialect/RTG/IR/RTGOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,71 @@ def SetDifferenceOp : RTGOp<"set_difference", [
}];
}

//===- Bag Operations ------------------------------------------------------===//

def BagCreateOp : RTGOp<"bag_create", [Pure, SameVariadicOperandSize]> {
let summary = "constructs a bag";
let description = [{
This operation constructs a bag with the provided values and associated
multiples. This means the bag constructed in the following example contains
two of each `%arg0` and `%arg0` (`{%arg0, %arg0, %arg1, %arg1}`).

```mlir
%0 = arith.constant 2 : index
%1 = rtg.bag_create (%0 x %arg0, %0 x %arg1) : i32
```
}];

let arguments = (ins Variadic<AnyType>:$elements,
Variadic<Index>:$multiples);
let results = (outs BagType:$bag);

let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

def BagSelectRandomOp : RTGOp<"bag_select_random", [
Pure,
TypesMatchWith<"output must be element type of input bag", "bag", "output",
"llvm::cast<rtg::BagType>($_self).getElementType()">
]> {
let summary = "select a random element from the bag";
let description = [{
This operation returns an element from the bag selected uniformely at
random. Therefore, the number of duplicates of each element can be used to
bias the distribution.
If the bag does not contain any elements, the behavior of this operation is
undefined.
}];

let arguments = (ins BagType:$bag);
let results = (outs AnyType:$output);

let assemblyFormat = "$bag `:` qualified(type($bag)) attr-dict";
}

def BagDifferenceOp : RTGOp<"bag_difference", [
Pure,
AllTypesMatch<["original", "diff", "output"]>
]> {
let summary = "computes the difference of two bags";
let description = [{
For each element the resulting bag will have as many fewer than the
'original' bag as there are in the 'diff' bag. However, if the 'inf'
attribute is attached, all elements of that kind will be removed (i.e., it
is assumed the 'diff' bag has infinitely many copies of each element).
}];

let arguments = (ins BagType:$original,
BagType:$diff,
UnitAttr:$inf);
let results = (outs BagType:$output);

let assemblyFormat = [{
$original `,` $diff (`inf` $inf^)? `:` qualified(type($output)) attr-dict
}];
}

//===- Test Specification Operations --------------------------------------===//

def TestOp : RTGOp<"test", [
Expand Down
17 changes: 17 additions & 0 deletions include/circt/Dialect/RTG/IR/RTGTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,27 @@ def SetType : RTGTypeDef<"Set"> {
let assemblyFormat = "`<` $elementType `>`";
}

def BagType : RTGTypeDef<"Bag"> {
let summary = "a bag of values";
let description = [{
This type represents a standard bag/multiset datastructure. It does not make
any assumptions about the underlying implementation.
}];

let parameters = (ins "::mlir::Type":$elementType);

let mnemonic = "bag";
let assemblyFormat = "`<` $elementType `>`";
}

class SetTypeOf<Type elementType> : ContainerType<
elementType, SetType.predicate,
"llvm::cast<rtg::SetType>($_self).getElementType()", "set">;

class BagTypeOf<Type elementType> : ContainerType<
elementType, BagType.predicate,
"llvm::cast<rtg::BagType>($_self).getElementType()", "bag">;

def DictType : RTGTypeDef<"Dict"> {
let summary = "a dictionary";
let description = [{
Expand Down
16 changes: 11 additions & 5 deletions include/circt/Dialect/RTG/IR/RTGVisitors.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@ class RTGOpVisitor {
auto *thisCast = static_cast<ConcreteType *>(this);
return TypeSwitch<Operation *, ResultType>(op)
.template Case<SequenceOp, SequenceClosureOp, SetCreateOp,
SetSelectRandomOp, SetDifferenceOp, InvokeSequenceOp,
TestOp, TargetOp, YieldOp>([&](auto expr) -> ResultType {
return thisCast->visitOp(expr, args...);
})
SetSelectRandomOp, SetDifferenceOp, TestOp,
InvokeSequenceOp, BagCreateOp, BagSelectRandomOp,
BagDifferenceOp, TargetOp, YieldOp>(
[&](auto expr) -> ResultType {
return thisCast->visitOp(expr, args...);
})
.template Case<ContextResourceOpInterface>(
[&](auto expr) -> ResultType {
return thisCast->visitContextResourceOp(expr, args...);
Expand Down Expand Up @@ -79,6 +81,9 @@ class RTGOpVisitor {
HANDLE(SetCreateOp, Unhandled);
HANDLE(SetSelectRandomOp, Unhandled);
HANDLE(SetDifferenceOp, Unhandled);
HANDLE(BagCreateOp, Unhandled);
HANDLE(BagSelectRandomOp, Unhandled);
HANDLE(BagDifferenceOp, Unhandled);
HANDLE(TestOp, Unhandled);
HANDLE(TargetOp, Unhandled);
HANDLE(YieldOp, Unhandled);
Expand All @@ -93,7 +98,7 @@ class RTGTypeVisitor {
ResultType dispatchTypeVisitor(Type type, ExtraArgs... args) {
auto *thisCast = static_cast<ConcreteType *>(this);
return TypeSwitch<Type, ResultType>(type)
.template Case<SequenceType, SetType, DictType>(
.template Case<SequenceType, SetType, BagType, DictType>(
[&](auto expr) -> ResultType {
return thisCast->visitType(expr, args...);
})
Expand Down Expand Up @@ -138,6 +143,7 @@ class RTGTypeVisitor {

HANDLE(SequenceType, Unhandled);
HANDLE(SetType, Unhandled);
HANDLE(BagType, Unhandled);
HANDLE(DictType, Unhandled);
#undef HANDLE
};
Expand Down
72 changes: 72 additions & 0 deletions lib/Dialect/RTG/IR/RTGOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,78 @@ LogicalResult SetCreateOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// BagCreateOp
//===----------------------------------------------------------------------===//

ParseResult BagCreateOp::parse(OpAsmParser &parser, OperationState &result) {
llvm::SmallVector<OpAsmParser::UnresolvedOperand, 16> elementOperands,
multipleOperands;
Type elemType;

if (!parser.parseOptionalLParen()) {
while (true) {
OpAsmParser::UnresolvedOperand elementOperand, multipleOperand;
if (parser.parseOperand(multipleOperand) || parser.parseKeyword("x") ||
parser.parseOperand(elementOperand))
return failure();

elementOperands.push_back(elementOperand);
multipleOperands.push_back(multipleOperand);

if (parser.parseOptionalComma()) {
if (parser.parseRParen())
return failure();
break;
}
}
}

if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
parser.parseType(elemType))
return failure();

result.addTypes({BagType::get(result.getContext(), elemType)});

for (auto operand : elementOperands)
if (parser.resolveOperand(operand, elemType, result.operands))
return failure();

for (auto operand : multipleOperands)
if (parser.resolveOperand(operand, IndexType::get(result.getContext()),
result.operands))
return failure();

return success();
}

void BagCreateOp::print(OpAsmPrinter &p) {
p << " ";
if (!getElements().empty())
p << "(";
llvm::interleaveComma(llvm::zip(getElements(), getMultiples()), p,
[&](auto elAndMultiple) {
auto [el, multiple] = elAndMultiple;
p << multiple << " x " << el;
});
if (!getElements().empty())
p << ")";

p.printOptionalAttrDict((*this)->getAttrs());
p << " : " << getBag().getType().getElementType();
}

LogicalResult BagCreateOp::verify() {
if (!llvm::all_equal(getElements().getTypes()))
return emitOpError() << "types of all elements must match";

if (getElements().size() > 0)
if (getElements()[0].getType() != getBag().getType().getElementType())
return emitOpError() << "operand types must match bag element type";

return success();
}

//===----------------------------------------------------------------------===//
// TestOp
//===----------------------------------------------------------------------===//
Expand Down
15 changes: 15 additions & 0 deletions test/Dialect/RTG/IR/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,21 @@ func.func @sets(%arg0: i32, %arg1: i32) {
return
}

// CHECK-LABEL: @bags
rtg.sequence @bags {
^bb0(%arg0: i32, %arg1: i32, %arg2: index):
// CHECK: [[BAG:%.+]] = rtg.bag_create (%arg2 x %arg0, %arg2 x %arg1) : i32
// CHECK: [[R:%.+]] = rtg.bag_select_random [[BAG]] : !rtg.bag<i32>
// CHECK: [[EMPTY:%.+]] = rtg.bag_create : i32
// CHECK: rtg.bag_difference [[BAG]], [[EMPTY]] : !rtg.bag<i32>
// CHECK: rtg.bag_difference [[BAG]], [[EMPTY]] inf : !rtg.bag<i32>
%bag = rtg.bag_create (%arg2 x %arg0, %arg2 x %arg1) : i32
%r = rtg.bag_select_random %bag : !rtg.bag<i32>
%empty = rtg.bag_create : i32
%diff = rtg.bag_difference %bag, %empty : !rtg.bag<i32>
%diff2 = rtg.bag_difference %bag, %empty inf : !rtg.bag<i32>
}

// CHECK-LABEL: rtg.target @empty_target : !rtg.dict<> {
// CHECK-NOT: rtg.yield
rtg.target @empty_target : !rtg.dict<> {
Expand Down
16 changes: 16 additions & 0 deletions test/Dialect/RTG/IR/errors.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,19 @@ rtg.test @test : !rtg.dict<b: i32, a: i32> {
rtg.test @test : !rtg.dict<"": i32> {
^bb0(%arg0: i32):
}

// -----

rtg.sequence @seq {
^bb0(%arg0: i32, %arg1: i64, %arg2: index):
// expected-error @below {{types of all elements must match}}
"rtg.bag_create"(%arg0, %arg1, %arg2, %arg2){} : (i32, i64, index, index) -> !rtg.bag<i32>
}

// -----

rtg.sequence @seq {
^bb0(%arg0: i64, %arg1: i64, %arg2: index):
// expected-error @below {{operand types must match bag element type}}
"rtg.bag_create"(%arg0, %arg1, %arg2, %arg2){} : (i64, i64, index, index) -> !rtg.bag<i32>
}

0 comments on commit cb34cbd

Please sign in to comment.