Skip to content

Commit

Permalink
[RTG] Elaboration support for get_size operations
Browse files Browse the repository at this point in the history
  • Loading branch information
maerhart committed Nov 29, 2024
1 parent 135c9d8 commit 31e652b
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 4 deletions.
10 changes: 6 additions & 4 deletions include/circt/Dialect/RTG/IR/RTGVisitors.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,10 @@ class RTGTypeVisitor {
ResultType dispatchTypeVisitor(Type type, ExtraArgs... args) {
auto *thisCast = static_cast<ConcreteType *>(this);
return TypeSwitch<Type, ResultType>(type)
.template Case<SequenceType, SetType, BagType, DictType>(
[&](auto expr) -> ResultType {
return thisCast->visitType(expr, args...);
})
.template Case<SequenceType, SetType, BagType, DictType, IndexType,
IntegerType>([&](auto expr) -> ResultType {
return thisCast->visitType(expr, args...);
})
.template Case<ContextResourceTypeInterface>(
[&](auto expr) -> ResultType {
return thisCast->visitContextResourceType(expr, args...);
Expand Down Expand Up @@ -158,6 +158,8 @@ class RTGTypeVisitor {
HANDLE(SetType, Unhandled);
HANDLE(BagType, Unhandled);
HANDLE(DictType, Unhandled);
HANDLE(IndexType, Unhandled);
HANDLE(IntegerType, Unhandled);
#undef HANDLE
};

Expand Down
1 change: 1 addition & 0 deletions include/circt/Dialect/RTG/Transforms/RTGPasses.td
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def ElaborationPass : Pass<"rtg-elaborate", "mlir::ModuleOp"> {
// Define a custom constructor to have more control over the pass options
// (e.g., std::optional options are not handled very well).
let constructor = "::circt::rtg::createElaborationPass()";
let dependentDialects = ["mlir::arith::ArithDialect"];
}

#endif // CIRCT_DIALECT_RTG_RTGPASSES_TD
61 changes: 61 additions & 0 deletions lib/Dialect/RTG/Transforms/ElaborationPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,43 @@ struct InternMapInfo : public DenseMapInfo<ElaboratorValue *> {
// Main Elaborator Implementation
//===----------------------------------------------------------------------===//

/// Construct an SSA value from a given elaborated value.
class Materializer : public RTGTypeVisitor<Materializer, Value, OpBuilder &,
Location, ElaboratorValue *> {
public:
using Base = RTGTypeVisitor<Materializer, Value, OpBuilder &, Location,
ElaboratorValue *>;
using Base::visitType;

Value visitUnhandledType(Type type, OpBuilder &builder, Location loc,
ElaboratorValue *val) {
return Value();
}

Value visitType(IndexType type, OpBuilder &builder, Location loc,
ElaboratorValue *val) {
auto res = builder.create<arith::ConstantOp>(
loc, IntegerAttr::get(type, cast<IntegerValue>(val)->getInt()));
materializedValues[{val, builder.getBlock()}] = res;
return res;
}

Value materialize(Block *block, Location loc, ElaboratorValue *val) {
if (val->isOpaqueValue())
return val->getOpaqueValue();

auto iter = materializedValues.find({val, block});
if (iter != materializedValues.end())
return iter->second;

OpBuilder builder = OpBuilder::atBlockBegin(block);
return dispatchTypeVisitor(val->getType(), builder, loc, val);
}

private:
DenseMap<std::pair<ElaboratorValue *, Block *>, Value> materializedValues;
};

/// Used to signal to the elaboration driver whether the operation should be
/// removed.
enum class DeletionKind { Keep, Delete };
Expand Down Expand Up @@ -444,6 +481,15 @@ class Elaborator
FailureOr<DeletionKind>
visitExternalOp(Operation *op,
function_ref<void(Operation *)> addToWorklist) {
for (auto &operand : op->getOpOperands()) {
auto val = materializer.materialize(op->getBlock(), op->getLoc(),
state.at(operand.get()));
if (!val)
return op->emitError("failed to materialize value for operand #")
<< operand.getOperandNumber();
operand.set(val);
}

// Treat values defined by external ops as opaque, non-elaborated values.
for (auto res : op->getResults())
internalizeResult<ElaboratorValue>(res);
Expand Down Expand Up @@ -607,6 +653,13 @@ class Elaborator
return DeletionKind::Delete;
}

FailureOr<DeletionKind>
visitOp(SetGetSizeOp op, function_ref<void(Operation *)> addToWorklist) {
auto size = cast<SetValue>(state.at(op.getSet()))->getAsArrayRef().size();
internalizeResult<IntegerValue>(op.getResult(), size);
return DeletionKind::Delete;
}

FailureOr<DeletionKind>
visitOp(BagCreateOp op, function_ref<void(Operation *)> addToWorklist) {
DenseMap<ElaboratorValue *, uint64_t> bag;
Expand Down Expand Up @@ -749,6 +802,13 @@ class Elaborator
return DeletionKind::Delete;
}

FailureOr<DeletionKind>
visitOp(BagGetSizeOp op, function_ref<void(Operation *)> addToWorklist) {
auto size = cast<BagValue>(state.at(op.getBag()))->getBag().size();
internalizeResult<IntegerValue>(op.getResult(), size);
return DeletionKind::Delete;
}

FailureOr<DeletionKind>
dispatchOpVisitor(Operation *op,
function_ref<void(Operation *)> addToWorklist) {
Expand Down Expand Up @@ -858,6 +918,7 @@ class Elaborator

// A map from SSA values to a pointer of an interned elaborator value.
DenseMap<Value, ElaboratorValue *> state;
Materializer materializer;

SymbolTable symTable;
};
Expand Down
23 changes: 23 additions & 0 deletions test/Dialect/RTG/Transform/elaboration.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,29 @@ rtg.test @bagOperations : !rtg.dict<> {
rtg.invoke_sequence %seq2
}

// CHECK-LABEL: rtg.test @setSize
rtg.test @setSize : !rtg.dict<> {
// CHECK-NEXT: [[C:%.+]] = arith.constant 1 : index
// CHECK-NEXT: index.add [[C]], [[C]]
// CHECK-NEXT: }
%c5_i32 = arith.constant 5 : i32
%set = rtg.set_create %c5_i32 : i32
%size = rtg.set_get_size %set : !rtg.set<i32>
index.add %size, %size
}

// CHECK-LABEL: rtg.test @bagSize
rtg.test @bagSize : !rtg.dict<> {
// CHECK-NEXT: [[C:%.+]] = arith.constant 1 : index
// CHECK-NEXT: index.add [[C]], [[C]]
// CHECK-NEXT: }
%c8 = arith.constant 8 : index
%c5_i32 = arith.constant 5 : i32
%bag = rtg.bag_create (%c8 x %c5_i32) : i32
%size = rtg.bag_get_size %bag : !rtg.bag<i32>
index.add %size, %size
}

// CHECK-LABEL: rtg.sequence @seq3
rtg.sequence @seq3 {
^bb0(%arg0: !rtg.set<!rtg.sequence>):
Expand Down
1 change: 1 addition & 0 deletions tools/circt-opt/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ target_link_libraries(circt-opt
MLIREmitCDialect
MLIRFuncInlinerExtension
MLIRVectorDialect
MLIRIndexDialect
)

export_executable_symbols_for_plugins(circt-opt)
2 changes: 2 additions & 0 deletions tools/circt-opt/circt-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/Func/Extensions/InlinerExtension.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
Expand Down Expand Up @@ -55,6 +56,7 @@ int main(int argc, char **argv) {
registry.insert<mlir::scf::SCFDialect>();
registry.insert<mlir::emitc::EmitCDialect>();
registry.insert<mlir::vector::VectorDialect>();
registry.insert<mlir::index::IndexDialect>();

circt::registerAllDialects(registry);
circt::registerAllPasses();
Expand Down

0 comments on commit 31e652b

Please sign in to comment.