Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIRRTL] Add a new FIRRTL annotation to specify type lowering behavior of module body #7751

Merged
merged 5 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions docs/Dialects/FIRRTL/FIRRTLAnnotations.md
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,34 @@ The options are:
}
```

### BodyTypeLoweringAnnotation

| Property | Type | Description |
| ------------------- | ------ | ---------------------------------------------------- |
| class | string | `circt.BodyTypeLoweringAnnotation` |
| convention | string | See `Convention` annotation |
| target | string | See `Convention` annotation |
| includeHierarchy | bool | Apply the convention to all modules in the hierarchy |

Specify the type lowering option for module internal signals.
This is similar to the `Convention` annotation, but for internal signals
rather than module ports. Refer to the `Convention` annotation for each
property description.

When `includeHierarchy` is `false`, it indicates the convention is applied only to
the specified module. If `includeHierarchy` is `true`, the convention is applied to
all modules in the hierarchy. If there are multiple annotation instances that specify
conventions, the `scalarized` convention takes precedence over the `internal` convention.

```json
{
"class": "circt.BodyTypeLoweringAnnotation",
"convention": "scalarized",
"target": "~Foo|Bar",
"includeHierarchy": true
}
```

### ElaborationArtefactsDirectory

| Property | Type | Description |
Expand Down
2 changes: 2 additions & 0 deletions include/circt/Dialect/FIRRTL/AnnotationDetails.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ constexpr const char *rawAnnotations = "rawAnnotations";
//===----------------------------------------------------------------------===//

constexpr const char *conventionAnnoClass = "circt.ConventionAnnotation";
constexpr const char *typeLoweringAnnoClass =
"circt.BodyTypeLoweringAnnotation";
constexpr const char *dontTouchAnnoClass =
"firrtl.transforms.DontTouchAnnotation";
constexpr const char *enumComponentAnnoClass =
Expand Down
67 changes: 67 additions & 0 deletions lib/Dialect/FIRRTL/Transforms/LowerAnnotations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,72 @@ static LogicalResult applyConventionAnno(const AnnoPathValue &target,
return error() << "can only target to a module or extmodule";
}

static LogicalResult applyBodyTypeLoweringAnno(const AnnoPathValue &target,
DictionaryAttr anno,
ApplyState &state) {
auto *op = target.ref.getOp();
auto loc = op->getLoc();
auto error = [&]() {
auto diag = mlir::emitError(loc);
diag << typeLoweringAnnoClass;
return diag;
};

auto opTarget = dyn_cast<OpAnnoTarget>(target.ref);
if (!opTarget)
return error() << "must target a module object";

if (!target.isLocal())
return error() << "must be local";

auto moduleOp = dyn_cast<FModuleOp>(op);

if (!moduleOp)
return error() << "can only target to a module";

auto conventionStrAttr =
tryGetAs<StringAttr>(anno, anno, "convention", loc, conventionAnnoClass);

if (!conventionStrAttr)
return failure();

auto conventionStr = conventionStrAttr.getValue();
auto conventionOpt = parseConvention(conventionStr);
if (!conventionOpt)
return error() << "unknown convention " << conventionStr;

auto convention = *conventionOpt;

if (convention == Convention::Internal)
// Convention is internal by default so there is nothing to change
return success();

auto conventionAttr = ConventionAttr::get(op->getContext(), convention);

// `includeHierarchy` only valid in BodyTypeLowering.
bool includeHierarchy = false;
if (auto includeHierarchyAttr = tryGetAs<BoolAttr>(
anno, anno, "includeHierarchy", loc, conventionAnnoClass))
includeHierarchy = includeHierarchyAttr.getValue();

if (includeHierarchy) {
// If includeHierarchy is true, update the convention for all modules in
// the hierarchy.
for (auto *node :
llvm::post_order(state.instancePathCache.instanceGraph[moduleOp])) {
if (!node)
continue;
if (auto fmodule = dyn_cast<FModuleOp>(*node->getModule()))
fmodule->setAttr("body_type_lowering", conventionAttr);
}
} else {
// Update the convention.
moduleOp->setAttr("body_type_lowering", conventionAttr);
}

return success();
}

static LogicalResult applyModulePrefixAnno(const AnnoPathValue &target,
DictionaryAttr anno,
ApplyState &state) {
Expand Down Expand Up @@ -553,6 +619,7 @@ static llvm::StringMap<AnnoRecord> annotationRecords{{
{memTapBlackboxClass, {stdResolve, applyWithoutTarget<true>}},
// Miscellaneous Annotations
{conventionAnnoClass, {stdResolve, applyConventionAnno}},
{typeLoweringAnnoClass, {stdResolve, applyBodyTypeLoweringAnno}},
{dontTouchAnnoClass,
{stdResolve, applyWithoutTarget<true, true, WireOp, NodeOp, RegOp,
RegResetOp, InstanceOp, MemOp, CombMemOp,
Expand Down
42 changes: 27 additions & 15 deletions lib/Dialect/FIRRTL/Transforms/LowerTypes.cpp
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is interesting about these changes is that we could move in a direction of removing the CLI options for aggregate preservation and instead rely on the annotations / attributes. As a middle ground, it may be better to change the CLI (in a follow-on) to apply the annotations or do it as a part of parsing.

This is more of a thought than a definite direction we should go.

Original file line number Diff line number Diff line change
Expand Up @@ -332,12 +332,17 @@ struct TypeLoweringVisitor : public FIRRTLVisitor<TypeLoweringVisitor, bool> {

TypeLoweringVisitor(
MLIRContext *context, PreserveAggregate::PreserveMode preserveAggregate,
Convention bodyConvention,
PreserveAggregate::PreserveMode memoryPreservationMode,
SymbolTable &symTbl, const AttrCache &cache,
const llvm::DenseMap<FModuleLike, Convention> &conventionTable)
: context(context), aggregatePreservationMode(preserveAggregate),
: context(context), defaultAggregatePreservationMode(preserveAggregate),
memoryPreservationMode(memoryPreservationMode), symTbl(symTbl),
cache(cache), conventionTable(conventionTable) {}
cache(cache), conventionTable(conventionTable) {
bodyAggregatePreservationMode = bodyConvention == Convention::Scalarized
? PreserveAggregate::None
: defaultAggregatePreservationMode;
}
using FIRRTLVisitor<TypeLoweringVisitor, bool>::visitDecl;
using FIRRTLVisitor<TypeLoweringVisitor, bool>::visitExpr;
using FIRRTLVisitor<TypeLoweringVisitor, bool>::visitStmt;
Expand Down Expand Up @@ -422,7 +427,7 @@ struct TypeLoweringVisitor : public FIRRTLVisitor<TypeLoweringVisitor, bool> {
Location errorLoc);

PreserveAggregate::PreserveMode
getPreservationModeForModule(FModuleLike moduleLike);
getPreservationModeForPorts(FModuleLike moduleLike);
Value getSubWhatever(Value val, size_t index);

size_t uniqueIdx = 0;
Expand All @@ -434,7 +439,8 @@ struct TypeLoweringVisitor : public FIRRTLVisitor<TypeLoweringVisitor, bool> {
MLIRContext *context;

/// Aggregate preservation mode.
PreserveAggregate::PreserveMode aggregatePreservationMode;
PreserveAggregate::PreserveMode defaultAggregatePreservationMode;
PreserveAggregate::PreserveMode bodyAggregatePreservationMode;
PreserveAggregate::PreserveMode memoryPreservationMode;

/// The builder is set and maintained in the main loop.
Expand All @@ -453,21 +459,21 @@ struct TypeLoweringVisitor : public FIRRTLVisitor<TypeLoweringVisitor, bool> {
};
} // namespace

/// Return aggregate preservation mode for the module. If the module has a
/// Return aggregate preservation mode for the module ports. If the module has a
/// scalarized linkage, then we may not preserve it's aggregate ports.
PreserveAggregate::PreserveMode
TypeLoweringVisitor::getPreservationModeForModule(FModuleLike module) {
TypeLoweringVisitor::getPreservationModeForPorts(FModuleLike module) {
auto lookup = conventionTable.find(module);
if (lookup == conventionTable.end())
return aggregatePreservationMode;
return defaultAggregatePreservationMode;
switch (lookup->second) {
case Convention::Scalarized:
return PreserveAggregate::None;
case Convention::Internal:
return aggregatePreservationMode;
return defaultAggregatePreservationMode;
}
llvm_unreachable("Unknown convention");
return aggregatePreservationMode;
return defaultAggregatePreservationMode;
}

Value TypeLoweringVisitor::getSubWhatever(Value val, size_t index) {
Expand Down Expand Up @@ -636,7 +642,7 @@ bool TypeLoweringVisitor::lowerProducer(
return false;
SmallVector<FlatBundleFieldEntry, 8> fieldTypes;

if (!peelType(srcFType, fieldTypes, aggregatePreservationMode))
if (!peelType(srcFType, fieldTypes, bodyAggregatePreservationMode))
return false;

SmallVector<Value> lowered;
Expand Down Expand Up @@ -802,7 +808,7 @@ bool TypeLoweringVisitor::lowerArg(FModuleLike module, size_t argIndex,
// Flatten any bundle types.
SmallVector<FlatBundleFieldEntry> fieldTypes;
auto srcType = type_cast<FIRRTLType>(newArgs[argIndex].pi.type);
if (!peelType(srcType, fieldTypes, getPreservationModeForModule(module)))
if (!peelType(srcType, fieldTypes, getPreservationModeForPorts(module)))
return false;

// Ports with internalPath set cannot be lowered.
Expand Down Expand Up @@ -922,7 +928,7 @@ bool TypeLoweringVisitor::visitStmt(RefDefineOp op) {
// Attempt to get the bundle types.
SmallVector<FlatBundleFieldEntry> fields;

if (!peelType(op.getDest().getType(), fields, aggregatePreservationMode))
if (!peelType(op.getDest().getType(), fields, bodyAggregatePreservationMode))
return false;

// Loop over the leaf aggregates.
Expand Down Expand Up @@ -1455,7 +1461,7 @@ bool TypeLoweringVisitor::visitDecl(InstanceOp op) {
SmallVector<Direction> newDirs;
SmallVector<Attribute> newNames;
SmallVector<Attribute> newPortAnno;
PreserveAggregate::PreserveMode mode = getPreservationModeForModule(
PreserveAggregate::PreserveMode mode = getPreservationModeForPorts(
cast<FModuleLike>(op.getReferencedOperation(symTbl)));

endFields.push_back(0);
Expand Down Expand Up @@ -1668,9 +1674,15 @@ void LowerTypesPass::runOnOperation() {

// This lambda, executes in parallel for each Op within the circt.
auto lowerModules = [&](FModuleLike op) -> LogicalResult {
// Use body type lowering attribute if it exists, otherwise use internal.
Convention convention = Convention::Internal;
if (auto conventionAttr = dyn_cast_or_null<ConventionAttr>(
op->getDiscardableAttr("body_type_lowering")))
convention = conventionAttr.getValue();

auto tl =
TypeLoweringVisitor(&getContext(), preserveAggregate, preserveMemories,
symTbl, cache, conventionTable);
TypeLoweringVisitor(&getContext(), preserveAggregate, convention,
preserveMemories, symTbl, cache, conventionTable);
tl.lowerModule(op);

return LogicalResult::failure(tl.isFailed());
Expand Down
23 changes: 21 additions & 2 deletions test/Dialect/FIRRTL/annotations.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -734,14 +734,33 @@ firrtl.circuit "Test" attributes {rawAnnotations = [
// -----

firrtl.circuit "Test" attributes {rawAnnotations =[
{class = "circt.ConventionAnnotation", target = "~Test|Test", convention = "scalarized"}
{class = "circt.ConventionAnnotation", target = "~Test|Test", convention = "scalarized"},
{class = "circt.BodyTypeLoweringAnnotation", target = "~Test|Test", convention = "scalarized", includeHierarchy = false}
]} {
// CHECK: attributes {convention = #firrtl<convention scalarized>}
// CHECK: attributes {body_type_lowering = #firrtl<convention scalarized>, convention = #firrtl<convention scalarized>}
firrtl.module @Test() attributes {convention = #firrtl<convention internal>} {}
}

// -----

firrtl.circuit "Test" attributes {rawAnnotations = [
{class = "circt.ConventionAnnotation", target = "~Test|Test", convention = "scalarized"},
{class = "circt.BodyTypeLoweringAnnotation", target = "~Test|Test", convention = "scalarized", includeHierarchy = true}
]} {
// CHECK: @Test() attributes {body_type_lowering = #firrtl<convention scalarized>, convention = #firrtl<convention scalarized>}
firrtl.module @Test() attributes {convention = #firrtl<convention internal>} {
firrtl.instance child @Child()
}

// CHECK: @Child() attributes {body_type_lowering = #firrtl<convention scalarized>}
firrtl.module @Child() attributes {convention = #firrtl<convention internal>} {}

// CHECK: @Child2() {
firrtl.module @Child2() attributes {convention = #firrtl<convention internal>} {}
}

// -----

firrtl.circuit "Test" attributes {rawAnnotations =[
{class = "chisel3.ModulePrefixAnnotation", target = "~Test|Test>comb", prefix = "Prefix_"},
{class = "chisel3.ModulePrefixAnnotation", target = "~Test|Test>seq", prefix = "Prefix_"},
Expand Down
39 changes: 39 additions & 0 deletions test/Dialect/FIRRTL/lower-types.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1405,6 +1405,45 @@ firrtl.circuit "UnrealizedConversion" {
}
}

firrtl.circuit "Conventions1" {
// COMMON-LABEL: @Conventions1
// AGGREGATE-SAME: %input_0
// AGGREGATE-NEXT: firrtl.reg
// AGGREGATE-SAME: !firrtl.vector<uint<8>, 1>
firrtl.module public @Conventions1(in %input: !firrtl.vector<uint<8>, 1>, in %clk: !firrtl.clock, out %port: !firrtl.vector<uint<8>, 1>) attributes {convention = #firrtl<convention scalarized>, body_type_lowering = #firrtl<convention internal>}{
%r = firrtl.reg interesting_name %clk : !firrtl.clock, !firrtl.vector<uint<8>, 1>
firrtl.matchingconnect %r, %input : !firrtl.vector<uint<8>, 1>
firrtl.matchingconnect %port, %r : !firrtl.vector<uint<8>, 1>
}
// COMMON-LABEL: @Conventions2
// AGGREGATE-SAME: %input_0: !firrtl.uint<8>
// AGGREGATE-NEXT: firrtl.reg
// AGGREGATE-SAME: !firrtl.uint<8>
firrtl.module private @Conventions2(in %input: !firrtl.vector<uint<8>, 1>, in %clk: !firrtl.clock, out %port: !firrtl.vector<uint<8>, 1>) attributes {convention = #firrtl<convention scalarized>, body_type_lowering = #firrtl<convention scalarized>}{
%r = firrtl.reg interesting_name %clk : !firrtl.clock, !firrtl.vector<uint<8>, 1>
firrtl.matchingconnect %r, %input : !firrtl.vector<uint<8>, 1>
firrtl.matchingconnect %port, %r : !firrtl.vector<uint<8>, 1>
}
// COMMON-LABEL: @Conventions3
// AGGREGATE-SAME: %input: !firrtl.vector<uint<8>, 1>
// AGGREGATE-NEXT: firrtl.reg
// AGGREGATE-SAME: !firrtl.vector<uint<8>, 1>
firrtl.module private @Conventions3(in %input: !firrtl.vector<uint<8>, 1>, in %clk: !firrtl.clock, out %port: !firrtl.vector<uint<8>, 1>) attributes {convention = #firrtl<convention internal>, body_type_lowering = #firrtl<convention internal>}{
%r = firrtl.reg interesting_name %clk : !firrtl.clock, !firrtl.vector<uint<8>, 1>
firrtl.matchingconnect %r, %input : !firrtl.vector<uint<8>, 1>
firrtl.matchingconnect %port, %r : !firrtl.vector<uint<8>, 1>
}
// COMMON-LABEL: @Conventions4
// AGGREGATE-SAME: %input: !firrtl.vector<uint<8>, 1>
// AGGREGATE-NEXT: firrtl.reg
// AGGREGATE-SAME: !firrtl.uint<8>
firrtl.module private @Conventions4(in %input: !firrtl.vector<uint<8>, 1>, in %clk: !firrtl.clock, out %port: !firrtl.vector<uint<8>, 1>) attributes {convention = #firrtl<convention internal>, body_type_lowering = #firrtl<convention scalarized>}{
%r = firrtl.reg interesting_name %clk : !firrtl.clock, !firrtl.vector<uint<8>, 1>
firrtl.matchingconnect %r, %input : !firrtl.vector<uint<8>, 1>
firrtl.matchingconnect %port, %r : !firrtl.vector<uint<8>, 1>
}
}

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, exhaustive test. 👍

// Test that memories have their prefixes copied when lowering.
// See: https://github.com/llvm/circt/issues/7835
firrtl.circuit "MemoryPrefixCopying" {
Expand Down
Loading