From 2b612a2dcd1cc0eeb91dcbcfc0cb288f77f20ff3 Mon Sep 17 00:00:00 2001 From: Hideto Ueno Date: Tue, 7 Feb 2023 19:56:28 +0900 Subject: [PATCH] [HWLegalizeModules] Legalize aggregate constant (#4626) This PR extends HWLegalizeModules to support aggregate constants. Also `signalPassFailure` is called when we fail to legalize expressions. Fix https://github.com/llvm/circt/issues/4623. --- .../SV/Transforms/HWLegalizeModules.cpp | 41 +++++++++++++------ .../SV/hw-legalize-modules-packed-arrays.mlir | 34 +++++++++++++-- 2 files changed, 59 insertions(+), 16 deletions(-) diff --git a/lib/Dialect/SV/Transforms/HWLegalizeModules.cpp b/lib/Dialect/SV/Transforms/HWLegalizeModules.cpp index 79784d02640f..025d866250ba 100644 --- a/lib/Dialect/SV/Transforms/HWLegalizeModules.cpp +++ b/lib/Dialect/SV/Transforms/HWLegalizeModules.cpp @@ -56,33 +56,46 @@ struct HWLegalizeModulesPass /// This returns a replacement operation if lowering was successful, null /// otherwise. Operation *HWLegalizeModulesPass::tryLoweringArrayGet(hw::ArrayGetOp getOp) { - // If the operand is an array_create, then we can lower this into a casez. - auto createOp = getOp.getInput().getDefiningOp(); - if (!createOp) + SmallVector caseValues; + OpBuilder builder(&thisHWModule.getBodyBlock()->front()); + // If the operand is an array_create or aggregate constant, then we can lower + // this into a casez. + if (auto createOp = getOp.getInput().getDefiningOp()) + caseValues = SmallVector(llvm::reverse(createOp.getOperands())); + else if (auto aggregateConstant = + getOp.getInput().getDefiningOp()) { + for (auto elem : llvm::reverse(aggregateConstant.getFields())) { + if (auto intAttr = dyn_cast(elem)) + caseValues.push_back(builder.create( + aggregateConstant.getLoc(), intAttr)); + else + caseValues.push_back(builder.create( + aggregateConstant.getLoc(), getOp.getType(), + elem.cast())); + } + } else { return nullptr; + } // array_get(idx, array_create(a,b,c,d)) ==> casez(idx). Value index = getOp.getIndex(); // Create the wire for the result of the casez in the hw.module. - OpBuilder builder(&thisHWModule.getBodyBlock()->front()); - auto theWire = builder.create(getOp.getLoc(), getOp.getType(), builder.getStringAttr("casez_tmp")); builder.setInsertionPoint(getOp); + auto loc = getOp.getInput().getDefiningOp()->getLoc(); // A casez is a procedural operation, so if we're in a non-procedural region // we need to inject an always_comb block. if (!getOp->getParentOp()->hasTrait()) { - auto alwaysComb = builder.create(createOp.getLoc()); + auto alwaysComb = builder.create(loc); builder.setInsertionPointToEnd(alwaysComb.getBodyBlock()); } // If we are missing elements in the array (it is non-power of two), then // add a default 'X' value. - SmallVector caseValues(llvm::reverse(createOp.getOperands())); - if (1ULL << index.getType().getIntOrFloatBitWidth() != - createOp.getNumOperands()) { + if (1ULL << index.getType().getIntOrFloatBitWidth() != caseValues.size()) { caseValues.push_back( builder.create(getOp.getLoc(), getOp.getType())); } @@ -92,7 +105,7 @@ Operation *HWLegalizeModulesPass::tryLoweringArrayGet(hw::ArrayGetOp getOp) { // Create the casez itself. builder.create( - createOp.getLoc(), CaseStmtType::CaseZStmt, index, caseValues.size(), + loc, CaseStmtType::CaseZStmt, index, caseValues.size(), [&](size_t caseIdx) -> std::unique_ptr { // Use a default pattern for the last value, even if we are complete. // This avoids tools thinking they need to insert a latch due to @@ -106,7 +119,7 @@ Operation *HWLegalizeModulesPass::tryLoweringArrayGet(hw::ArrayGetOp getOp) { else thePattern = std::make_unique(caseValue, context); ++caseValue; - builder.create(createOp.getLoc(), theWire, theValue); + builder.create(loc, theWire, theValue); return thePattern; }); @@ -149,9 +162,10 @@ void HWLegalizeModulesPass::processPostOrder(Block &body) { continue; } - // If this is a dead array_create, then we can just delete it. This is + // If this is a dead array, then we can just delete it. This is // probably left over from get/create lowering. - if (isa(op) && op.use_empty()) { + if (isa(op) && + op.use_empty()) { op.erase(); continue; } @@ -163,6 +177,7 @@ void HWLegalizeModulesPass::processPostOrder(Block &body) { for (auto value : op.getResults()) { if (value.getType().isa()) { op.emitError("unsupported packed array expression"); + signalPassFailure(); } } } diff --git a/test/Dialect/SV/hw-legalize-modules-packed-arrays.mlir b/test/Dialect/SV/hw-legalize-modules-packed-arrays.mlir index b1030bb253e0..682d2c702316 100644 --- a/test/Dialect/SV/hw-legalize-modules-packed-arrays.mlir +++ b/test/Dialect/SV/hw-legalize-modules-packed-arrays.mlir @@ -1,8 +1,6 @@ -// RUN: circt-opt -hw-legalize-modules -verify-diagnostics %s | FileCheck %s +// RUN: circt-opt -split-input-file -hw-legalize-modules -verify-diagnostics %s | FileCheck %s module attributes {circt.loweringOptions = "disallowPackedArrays"} { - -// CHECK-LABEL: hw.module @reject_arrays hw.module @reject_arrays(%arg0: i8, %arg1: i8, %arg2: i8, %arg3: i8, %sel: i2, %clock: i1) -> (a: !hw.array<4xi8>) { @@ -19,7 +17,10 @@ hw.module @reject_arrays(%arg0: i8, %arg1: i8, %arg2: i8, %1 = sv.read_inout %reg : !hw.inout> hw.output %1 : !hw.array<4xi8> } +} +// ----- +module attributes {circt.loweringOptions = "disallowPackedArrays"} { // CHECK-LABEL: hw.module @array_create_get_comb hw.module @array_create_get_comb(%arg0: i8, %arg1: i8, %arg2: i8, %arg3: i8, %sel: i2) @@ -83,4 +84,31 @@ hw.module @array_create_get_default(%arg0: i8, %arg1: i8, %arg2: i8, %arg3: i8, } } +// CHECK-LABEL: hw.module @array_constant_get_comb +hw.module @array_constant_get_comb(%sel: i2) + -> (a: i8) { + // CHECK: %casez_tmp = sv.reg : !hw.inout + // CHECK: sv.alwayscomb { + // CHECK: sv.case casez %sel : i2 + // CHECK: case b00: { + // CHECK: sv.bpassign %casez_tmp, %c3_i8 : i8 + // CHECK: } + // CHECK: case b01: { + // CHECK: sv.bpassign %casez_tmp, %c2_i8 : i8 + // CHECK: } + // CHECK: case b10: { + // CHECK: sv.bpassign %casez_tmp, %c1_i8 : i8 + // CHECK: } + // CHECK: default: { + // CHECK: sv.bpassign %casez_tmp, %c0_i8 : i8 + // CHECK: } + // CHECK: } + %0 = hw.aggregate_constant [0 : i8, 1 : i8, 2 : i8, 3 : i8] : !hw.array<4xi8> + // CHECK: %0 = sv.read_inout %casez_tmp : !hw.inout + %1 = hw.array_get %0[%sel] : !hw.array<4xi8>, i2 + + // CHECK: hw.output %0 : i8 + hw.output %1 : i8 +} + } // end builtin.module