From 0c1465dd92b224d61f3007d7e4792dab340a9c4c Mon Sep 17 00:00:00 2001 From: Hideto Ueno Date: Wed, 4 Dec 2024 23:29:06 -0800 Subject: [PATCH] [FIRRTL] Add a new FIRRTL annotation to specify type lowering behavior of module body (#7751) Add a new annotation to control type lowering behavior for internal signals within a module, separate from the port convention. This allows more fine-grained control over how aggregate types are handled inside modules. The new annotation works similarly to ConventionAnnotation but applies to internal signals rather than module ports. It supports the same conventions and includes an 'includeHierarchy' option to apply the setting to all modules in the hierarchy. --- docs/Dialects/FIRRTL/FIRRTLAnnotations.md | 28 ++++++++ .../circt/Dialect/FIRRTL/AnnotationDetails.h | 2 + .../FIRRTL/Transforms/LowerAnnotations.cpp | 67 +++++++++++++++++++ lib/Dialect/FIRRTL/Transforms/LowerTypes.cpp | 42 +++++++----- test/Dialect/FIRRTL/annotations.mlir | 23 ++++++- test/Dialect/FIRRTL/lower-types.mlir | 39 +++++++++++ 6 files changed, 184 insertions(+), 17 deletions(-) diff --git a/docs/Dialects/FIRRTL/FIRRTLAnnotations.md b/docs/Dialects/FIRRTL/FIRRTLAnnotations.md index 695f57ba482a..4ec9a5e730ba 100644 --- a/docs/Dialects/FIRRTL/FIRRTLAnnotations.md +++ b/docs/Dialects/FIRRTL/FIRRTLAnnotations.md @@ -345,6 +345,34 @@ The options are: } ``` +### BodyTypeLoweringAnnotation + +| Property | Type | Description | +| ------------------- | ------ | ---------------------------------------------------- | +| class | string | `circt.BodyTypeLoweringAnnotation` | +| convention | string | See `Convention` annotation | +| target | string | See `Convention` annotation | +| includeHierarchy | bool | Apply the convention to all modules in the hierarchy | + +Specify the type lowering option for module internal signals. +This is similar to the `Convention` annotation, but for internal signals +rather than module ports. Refer to the `Convention` annotation for each +property description. + +When `includeHierarchy` is `false`, it indicates the convention is applied only to +the specified module. If `includeHierarchy` is `true`, the convention is applied to +all modules in the hierarchy. If there are multiple annotation instances that specify +conventions, the `scalarized` convention takes precedence over the `internal` convention. + +```json +{ + "class": "circt.BodyTypeLoweringAnnotation", + "convention": "scalarized", + "target": "~Foo|Bar", + "includeHierarchy": true +} +``` + ### ElaborationArtefactsDirectory | Property | Type | Description | diff --git a/include/circt/Dialect/FIRRTL/AnnotationDetails.h b/include/circt/Dialect/FIRRTL/AnnotationDetails.h index a1cf73979839..186d5482f6a6 100644 --- a/include/circt/Dialect/FIRRTL/AnnotationDetails.h +++ b/include/circt/Dialect/FIRRTL/AnnotationDetails.h @@ -29,6 +29,8 @@ constexpr const char *rawAnnotations = "rawAnnotations"; //===----------------------------------------------------------------------===// constexpr const char *conventionAnnoClass = "circt.ConventionAnnotation"; +constexpr const char *typeLoweringAnnoClass = + "circt.BodyTypeLoweringAnnotation"; constexpr const char *dontTouchAnnoClass = "firrtl.transforms.DontTouchAnnotation"; constexpr const char *enumComponentAnnoClass = diff --git a/lib/Dialect/FIRRTL/Transforms/LowerAnnotations.cpp b/lib/Dialect/FIRRTL/Transforms/LowerAnnotations.cpp index 1ec3f630c515..178a257f6bba 100644 --- a/lib/Dialect/FIRRTL/Transforms/LowerAnnotations.cpp +++ b/lib/Dialect/FIRRTL/Transforms/LowerAnnotations.cpp @@ -311,6 +311,72 @@ static LogicalResult applyConventionAnno(const AnnoPathValue &target, return error() << "can only target to a module or extmodule"; } +static LogicalResult applyBodyTypeLoweringAnno(const AnnoPathValue &target, + DictionaryAttr anno, + ApplyState &state) { + auto *op = target.ref.getOp(); + auto loc = op->getLoc(); + auto error = [&]() { + auto diag = mlir::emitError(loc); + diag << typeLoweringAnnoClass; + return diag; + }; + + auto opTarget = dyn_cast(target.ref); + if (!opTarget) + return error() << "must target a module object"; + + if (!target.isLocal()) + return error() << "must be local"; + + auto moduleOp = dyn_cast(op); + + if (!moduleOp) + return error() << "can only target to a module"; + + auto conventionStrAttr = + tryGetAs(anno, anno, "convention", loc, conventionAnnoClass); + + if (!conventionStrAttr) + return failure(); + + auto conventionStr = conventionStrAttr.getValue(); + auto conventionOpt = parseConvention(conventionStr); + if (!conventionOpt) + return error() << "unknown convention " << conventionStr; + + auto convention = *conventionOpt; + + if (convention == Convention::Internal) + // Convention is internal by default so there is nothing to change + return success(); + + auto conventionAttr = ConventionAttr::get(op->getContext(), convention); + + // `includeHierarchy` only valid in BodyTypeLowering. + bool includeHierarchy = false; + if (auto includeHierarchyAttr = tryGetAs( + anno, anno, "includeHierarchy", loc, conventionAnnoClass)) + includeHierarchy = includeHierarchyAttr.getValue(); + + if (includeHierarchy) { + // If includeHierarchy is true, update the convention for all modules in + // the hierarchy. + for (auto *node : + llvm::post_order(state.instancePathCache.instanceGraph[moduleOp])) { + if (!node) + continue; + if (auto fmodule = dyn_cast(*node->getModule())) + fmodule->setAttr("body_type_lowering", conventionAttr); + } + } else { + // Update the convention. + moduleOp->setAttr("body_type_lowering", conventionAttr); + } + + return success(); +} + static LogicalResult applyModulePrefixAnno(const AnnoPathValue &target, DictionaryAttr anno, ApplyState &state) { @@ -553,6 +619,7 @@ static llvm::StringMap annotationRecords{{ {memTapBlackboxClass, {stdResolve, applyWithoutTarget}}, // Miscellaneous Annotations {conventionAnnoClass, {stdResolve, applyConventionAnno}}, + {typeLoweringAnnoClass, {stdResolve, applyBodyTypeLoweringAnno}}, {dontTouchAnnoClass, {stdResolve, applyWithoutTarget { TypeLoweringVisitor( MLIRContext *context, PreserveAggregate::PreserveMode preserveAggregate, + Convention bodyConvention, PreserveAggregate::PreserveMode memoryPreservationMode, SymbolTable &symTbl, const AttrCache &cache, const llvm::DenseMap &conventionTable) - : context(context), aggregatePreservationMode(preserveAggregate), + : context(context), defaultAggregatePreservationMode(preserveAggregate), memoryPreservationMode(memoryPreservationMode), symTbl(symTbl), - cache(cache), conventionTable(conventionTable) {} + cache(cache), conventionTable(conventionTable) { + bodyAggregatePreservationMode = bodyConvention == Convention::Scalarized + ? PreserveAggregate::None + : defaultAggregatePreservationMode; + } using FIRRTLVisitor::visitDecl; using FIRRTLVisitor::visitExpr; using FIRRTLVisitor::visitStmt; @@ -422,7 +427,7 @@ struct TypeLoweringVisitor : public FIRRTLVisitor { Location errorLoc); PreserveAggregate::PreserveMode - getPreservationModeForModule(FModuleLike moduleLike); + getPreservationModeForPorts(FModuleLike moduleLike); Value getSubWhatever(Value val, size_t index); size_t uniqueIdx = 0; @@ -434,7 +439,8 @@ struct TypeLoweringVisitor : public FIRRTLVisitor { MLIRContext *context; /// Aggregate preservation mode. - PreserveAggregate::PreserveMode aggregatePreservationMode; + PreserveAggregate::PreserveMode defaultAggregatePreservationMode; + PreserveAggregate::PreserveMode bodyAggregatePreservationMode; PreserveAggregate::PreserveMode memoryPreservationMode; /// The builder is set and maintained in the main loop. @@ -453,21 +459,21 @@ struct TypeLoweringVisitor : public FIRRTLVisitor { }; } // namespace -/// Return aggregate preservation mode for the module. If the module has a +/// Return aggregate preservation mode for the module ports. If the module has a /// scalarized linkage, then we may not preserve it's aggregate ports. PreserveAggregate::PreserveMode -TypeLoweringVisitor::getPreservationModeForModule(FModuleLike module) { +TypeLoweringVisitor::getPreservationModeForPorts(FModuleLike module) { auto lookup = conventionTable.find(module); if (lookup == conventionTable.end()) - return aggregatePreservationMode; + return defaultAggregatePreservationMode; switch (lookup->second) { case Convention::Scalarized: return PreserveAggregate::None; case Convention::Internal: - return aggregatePreservationMode; + return defaultAggregatePreservationMode; } llvm_unreachable("Unknown convention"); - return aggregatePreservationMode; + return defaultAggregatePreservationMode; } Value TypeLoweringVisitor::getSubWhatever(Value val, size_t index) { @@ -636,7 +642,7 @@ bool TypeLoweringVisitor::lowerProducer( return false; SmallVector fieldTypes; - if (!peelType(srcFType, fieldTypes, aggregatePreservationMode)) + if (!peelType(srcFType, fieldTypes, bodyAggregatePreservationMode)) return false; SmallVector lowered; @@ -805,7 +811,7 @@ bool TypeLoweringVisitor::lowerArg(FModuleLike module, size_t argIndex, // Flatten any bundle types. SmallVector fieldTypes; auto srcType = type_cast(newArgs[argIndex].pi.type); - if (!peelType(srcType, fieldTypes, getPreservationModeForModule(module))) + if (!peelType(srcType, fieldTypes, getPreservationModeForPorts(module))) return false; // Ports with internalPath set cannot be lowered. @@ -925,7 +931,7 @@ bool TypeLoweringVisitor::visitStmt(RefDefineOp op) { // Attempt to get the bundle types. SmallVector fields; - if (!peelType(op.getDest().getType(), fields, aggregatePreservationMode)) + if (!peelType(op.getDest().getType(), fields, bodyAggregatePreservationMode)) return false; // Loop over the leaf aggregates. @@ -1458,7 +1464,7 @@ bool TypeLoweringVisitor::visitDecl(InstanceOp op) { SmallVector newDirs; SmallVector newNames; SmallVector newPortAnno; - PreserveAggregate::PreserveMode mode = getPreservationModeForModule( + PreserveAggregate::PreserveMode mode = getPreservationModeForPorts( cast(op.getReferencedOperation(symTbl))); endFields.push_back(0); @@ -1662,9 +1668,15 @@ void LowerTypesPass::runOnOperation() { // This lambda, executes in parallel for each Op within the circt. auto lowerModules = [&](FModuleLike op) -> LogicalResult { + // Use body type lowering attribute if it exists, otherwise use internal. + Convention convention = Convention::Internal; + if (auto conventionAttr = dyn_cast_or_null( + op->getDiscardableAttr("body_type_lowering"))) + convention = conventionAttr.getValue(); + auto tl = - TypeLoweringVisitor(&getContext(), preserveAggregate, preserveMemories, - symTbl, cache, conventionTable); + TypeLoweringVisitor(&getContext(), preserveAggregate, convention, + preserveMemories, symTbl, cache, conventionTable); tl.lowerModule(op); return LogicalResult::failure(tl.isFailed()); diff --git a/test/Dialect/FIRRTL/annotations.mlir b/test/Dialect/FIRRTL/annotations.mlir index 3568fc4b498b..cb0075ead789 100644 --- a/test/Dialect/FIRRTL/annotations.mlir +++ b/test/Dialect/FIRRTL/annotations.mlir @@ -734,14 +734,33 @@ firrtl.circuit "Test" attributes {rawAnnotations = [ // ----- firrtl.circuit "Test" attributes {rawAnnotations =[ - {class = "circt.ConventionAnnotation", target = "~Test|Test", convention = "scalarized"} + {class = "circt.ConventionAnnotation", target = "~Test|Test", convention = "scalarized"}, + {class = "circt.BodyTypeLoweringAnnotation", target = "~Test|Test", convention = "scalarized", includeHierarchy = false} ]} { - // CHECK: attributes {convention = #firrtl} + // CHECK: attributes {body_type_lowering = #firrtl, convention = #firrtl} firrtl.module @Test() attributes {convention = #firrtl} {} } // ----- +firrtl.circuit "Test" attributes {rawAnnotations = [ + {class = "circt.ConventionAnnotation", target = "~Test|Test", convention = "scalarized"}, + {class = "circt.BodyTypeLoweringAnnotation", target = "~Test|Test", convention = "scalarized", includeHierarchy = true} + ]} { + // CHECK: @Test() attributes {body_type_lowering = #firrtl, convention = #firrtl} + firrtl.module @Test() attributes {convention = #firrtl} { + firrtl.instance child @Child() + } + + // CHECK: @Child() attributes {body_type_lowering = #firrtl} + firrtl.module @Child() attributes {convention = #firrtl} {} + + // CHECK: @Child2() { + firrtl.module @Child2() attributes {convention = #firrtl} {} +} + +// ----- + firrtl.circuit "Test" attributes {rawAnnotations =[ {class = "chisel3.ModulePrefixAnnotation", target = "~Test|Test>comb", prefix = "Prefix_"}, {class = "chisel3.ModulePrefixAnnotation", target = "~Test|Test>seq", prefix = "Prefix_"}, diff --git a/test/Dialect/FIRRTL/lower-types.mlir b/test/Dialect/FIRRTL/lower-types.mlir index 2de51c8d835f..bb38a5787c30 100644 --- a/test/Dialect/FIRRTL/lower-types.mlir +++ b/test/Dialect/FIRRTL/lower-types.mlir @@ -1405,6 +1405,45 @@ firrtl.circuit "UnrealizedConversion" { } } +firrtl.circuit "Conventions1" { + // COMMON-LABEL: @Conventions1 + // AGGREGATE-SAME: %input_0 + // AGGREGATE-NEXT: firrtl.reg + // AGGREGATE-SAME: !firrtl.vector, 1> + firrtl.module public @Conventions1(in %input: !firrtl.vector, 1>, in %clk: !firrtl.clock, out %port: !firrtl.vector, 1>) attributes {convention = #firrtl, body_type_lowering = #firrtl}{ + %r = firrtl.reg interesting_name %clk : !firrtl.clock, !firrtl.vector, 1> + firrtl.matchingconnect %r, %input : !firrtl.vector, 1> + firrtl.matchingconnect %port, %r : !firrtl.vector, 1> + } + // COMMON-LABEL: @Conventions2 + // AGGREGATE-SAME: %input_0: !firrtl.uint<8> + // AGGREGATE-NEXT: firrtl.reg + // AGGREGATE-SAME: !firrtl.uint<8> + firrtl.module private @Conventions2(in %input: !firrtl.vector, 1>, in %clk: !firrtl.clock, out %port: !firrtl.vector, 1>) attributes {convention = #firrtl, body_type_lowering = #firrtl}{ + %r = firrtl.reg interesting_name %clk : !firrtl.clock, !firrtl.vector, 1> + firrtl.matchingconnect %r, %input : !firrtl.vector, 1> + firrtl.matchingconnect %port, %r : !firrtl.vector, 1> + } + // COMMON-LABEL: @Conventions3 + // AGGREGATE-SAME: %input: !firrtl.vector, 1> + // AGGREGATE-NEXT: firrtl.reg + // AGGREGATE-SAME: !firrtl.vector, 1> + firrtl.module private @Conventions3(in %input: !firrtl.vector, 1>, in %clk: !firrtl.clock, out %port: !firrtl.vector, 1>) attributes {convention = #firrtl, body_type_lowering = #firrtl}{ + %r = firrtl.reg interesting_name %clk : !firrtl.clock, !firrtl.vector, 1> + firrtl.matchingconnect %r, %input : !firrtl.vector, 1> + firrtl.matchingconnect %port, %r : !firrtl.vector, 1> + } + // COMMON-LABEL: @Conventions4 + // AGGREGATE-SAME: %input: !firrtl.vector, 1> + // AGGREGATE-NEXT: firrtl.reg + // AGGREGATE-SAME: !firrtl.uint<8> + firrtl.module private @Conventions4(in %input: !firrtl.vector, 1>, in %clk: !firrtl.clock, out %port: !firrtl.vector, 1>) attributes {convention = #firrtl, body_type_lowering = #firrtl}{ + %r = firrtl.reg interesting_name %clk : !firrtl.clock, !firrtl.vector, 1> + firrtl.matchingconnect %r, %input : !firrtl.vector, 1> + firrtl.matchingconnect %port, %r : !firrtl.vector, 1> + } +} + // Test that memories have their prefixes copied when lowering. // See: https://github.com/llvm/circt/issues/7835 firrtl.circuit "MemoryPrefixCopying" {