Skip to content

Commit

Permalink
[RTG] Add BagType CAPI and Python Bindings
Browse files Browse the repository at this point in the history
  • Loading branch information
maerhart committed Dec 4, 2024
1 parent 69b551d commit 1036368
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 1 deletion.
6 changes: 6 additions & 0 deletions include/circt-c/Dialect/RTG.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ MLIR_CAPI_EXPORTED bool rtgTypeIsASet(MlirType type);
/// Creates an RTG set type in the context.
MLIR_CAPI_EXPORTED MlirType rtgSetTypeGet(MlirType elementType);

/// If the type is an RTG bag.
MLIR_CAPI_EXPORTED bool rtgTypeIsABag(MlirType type);

/// Creates an RTG bag type in the context.
MLIR_CAPI_EXPORTED MlirType rtgBagTypeGet(MlirType elementType);

/// If the type is an RTG dict.
MLIR_CAPI_EXPORTED bool rtgTypeIsADict(MlirType type);

Expand Down
17 changes: 16 additions & 1 deletion integration_test/Bindings/Python/dialects/rtg.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import circt

from circt.dialects import rtg, rtgtest
from circt.ir import Context, Location, Module, InsertionPoint, Block, StringAttr, TypeAttr
from circt.ir import Context, Location, Module, InsertionPoint, Block, StringAttr, TypeAttr, IndexType
from circt.passmanager import PassManager
from circt import rtgtool_support as rtgtool

Expand Down Expand Up @@ -77,3 +77,18 @@
# CHECK: rtg.test @test_name : !rtg.dict<> {
# CHECK-NEXT: }
print(m)

with Context() as ctx, Location.unknown():
circt.register_dialects(ctx)
m = Module.create()
with InsertionPoint(m.body):
indexTy = IndexType.get()
sequenceTy = rtg.SequenceType.get()
setTy = rtg.SetType.get(indexTy)
bagTy = rtg.BagType.get(indexTy)
seq = rtg.SequenceOp('seq')
Block.create_at_start(seq.bodyRegion, [sequenceTy, setTy, bagTy])

# CHECK: rtg.sequence @seq
# CHECK: (%{{.*}}: !rtg.sequence, %{{.*}}: !rtg.set<index>, %{{.*}}: !rtg.bag<index>):
print(m)
8 changes: 8 additions & 0 deletions lib/Bindings/Python/RTGModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,14 @@ void circt::python::populateDialectRTGSubmodule(py::module &m) {
},
py::arg("self"), py::arg("element_type"));

mlir_type_subclass(m, "BagType", rtgTypeIsABag)
.def_classmethod(
"get",
[](py::object cls, MlirType elementType) {
return cls(rtgBagTypeGet(elementType));
},
py::arg("self"), py::arg("element_type"));

mlir_type_subclass(m, "DictType", rtgTypeIsADict)
.def_classmethod(
"get",
Expand Down
10 changes: 10 additions & 0 deletions lib/CAPI/Dialect/RTG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,16 @@ MlirType rtgSetTypeGet(MlirType elementType) {
return wrap(SetType::get(ty.getContext(), ty));
}

// BagType
//===----------------------------------------------------------------------===//

bool rtgTypeIsABag(MlirType type) { return isa<BagType>(unwrap(type)); }

MlirType rtgBagTypeGet(MlirType elementType) {
auto ty = unwrap(elementType);
return wrap(BagType::get(ty.getContext(), ty));
}

// DictType
//===----------------------------------------------------------------------===//

Expand Down
11 changes: 11 additions & 0 deletions test/CAPI/rtg.c
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,16 @@ static void testSetType(MlirContext ctx) {
mlirTypeDump(setTy);
}

static void testBagType(MlirContext ctx) {
MlirType elTy = mlirIntegerTypeGet(ctx, 32);
MlirType bagTy = rtgBagTypeGet(elTy);

// CHECK: is_bag
fprintf(stderr, rtgTypeIsABag(bagTy) ? "is_bag\n" : "isnot_bag\n");
// CHECK: !rtg.bag<i32>
mlirTypeDump(bagTy);
}

static void testDictType(MlirContext ctx) {
MlirType elTy = mlirIntegerTypeGet(ctx, 32);
MlirAttribute name0 =
Expand Down Expand Up @@ -62,6 +72,7 @@ int main(int argc, char **argv) {

testSequenceType(ctx);
testSetType(ctx);
testBagType(ctx);
testDictType(ctx);

mlirContextDestroy(ctx);
Expand Down

0 comments on commit 1036368

Please sign in to comment.