Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Calyx ConstantOp Support #7086

Merged
merged 7 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/circt/Dialect/Calyx/CalyxHelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ calyx::RegisterOp createRegister(Location loc, OpBuilder &builder,
ComponentOp component, size_t width,
Twine prefix);

calyx::RegisterOp createRegister(Location loc, OpBuilder &builder,
ComponentOp component, Type type,
Twine prefix);

/// A helper function to create constants in the HW dialect.
hw::ConstantOp createConstant(Location loc, OpBuilder &builder,
ComponentOp component, size_t width,
Expand Down
36 changes: 36 additions & 0 deletions include/circt/Dialect/Calyx/CalyxPrimitives.td
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
//
//===----------------------------------------------------------------------===//

include "mlir/IR/BuiltinAttributeInterfaces.td"

/// Base class for Calyx primitives.
class CalyxPrimitive<string mnemonic, list<Trait> traits = []> :
CalyxCell<mnemonic, traits> {
Expand All @@ -18,6 +20,40 @@ class CalyxPrimitive<string mnemonic, list<Trait> traits = []> :
let skipDefaultBuilders = 1;
}

def ConstantOp: CalyxPrimitive<"constant",
[ConstantLike, FirstAttrDerivedResultType,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
AllTypesMatch<["value", "out"]>
]> {
let summary = "integer or floating point constant";
let description = [{
The `constant` operation produces an SSA value equal to some integer or
floating-point constant specified by an attribute.

Example:

```
// Integer constant
%1 = calyx.constant 42 : i32

// Floating point constant
%1 = calyx.constant 42.00+e00 : f32
jiahanxie353 marked this conversation as resolved.
Show resolved Hide resolved
```
}];
let arguments = (ins TypedAttrInterface:$value);

let results = (outs SignlessIntegerOrFloatLike:$out);

let builders = [
/// Build a ConstantOp from a prebuilt attribute.
OpBuilder <(ins "StringRef":$sym_name, "TypedAttr":$attr)>,
];

let hasFolder = 1;
let assemblyFormat = "attr-dict $value";
let hasVerifier = 1;
}

/// The n-bit, undef op which only provides the out signal
def UndefLibOp: CalyxPrimitive<"undefined", []> {
let summary = "An undefined signal";
Expand Down
23 changes: 17 additions & 6 deletions lib/Conversion/SCFToCalyx/SCFToCalyx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -891,12 +891,23 @@ LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,

LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
arith::ConstantOp constOp) const {
/// Move constant operations to the compOp body as hw::ConstantOp's.
APInt value;
calyx::matchConstantOp(constOp, value);
auto hwConstOp = rewriter.replaceOpWithNewOp<hw::ConstantOp>(constOp, value);
hwConstOp->moveAfter(getComponent().getBodyBlock(),
getComponent().getBodyBlock()->begin());
if (isa<IntegerType>(constOp.getType())) {
/// Move constant operations to the compOp body as hw::ConstantOp's.
APInt value;
calyx::matchConstantOp(constOp, value);
auto hwConstOp =
rewriter.replaceOpWithNewOp<hw::ConstantOp>(constOp, value);
hwConstOp->moveAfter(getComponent().getBodyBlock(),
getComponent().getBodyBlock()->begin());
} else {
std::string name = getState<ComponentLoweringState>().getUniqueName("cst");
auto calyxConstOp = rewriter.create<calyx::ConstantOp>(
constOp.getLoc(), name, constOp.getValueAttr());
calyxConstOp->moveAfter(getComponent().getBodyBlock(),
getComponent().getBodyBlock()->begin());
rewriter.replaceAllUsesWith(constOp, calyxConstOp.getOut());
}

return success();
}

Expand Down
68 changes: 68 additions & 0 deletions lib/Dialect/Calyx/CalyxOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1953,6 +1953,74 @@ ParseResult GroupDoneOp::parse(OpAsmParser &parser, OperationState &result) {
return parseGroupPort(parser, result);
}

//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
void ConstantOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
if (isa<FloatAttr>(getValue())) {
setNameFn(getResult(), "cst");
jiahanxie353 marked this conversation as resolved.
Show resolved Hide resolved
return;
}
auto intCst = llvm::dyn_cast<IntegerAttr>(getValue());
auto intType = llvm::dyn_cast<IntegerType>(getType());

// Sugar i1 constants with 'true' and 'false'.
if (intType && intType.getWidth() == 1)
return setNameFn(getResult(), intCst.getInt() > 0 ? "true" : "false");

// Otherwise, build a complex name with the value and type.
SmallString<32> specialNameBuffer;
llvm::raw_svector_ostream specialName(specialNameBuffer);
specialName << 'c' << intCst.getValue();
if (intType)
specialName << '_' << getType();
setNameFn(getResult(), specialName.str());
}

LogicalResult ConstantOp::verify() {
auto type = getType();
// The value's type must match the return type.
if (auto valType = getValue().getType(); valType != type) {
return emitOpError() << "value type " << valType
<< " must match return type: " << type;
}
// Integer values must be signless.
if (llvm::isa<IntegerType>(type) &&
!llvm::cast<IntegerType>(type).isSignless())
return emitOpError("integer return type must be signless");
// Any float or integers attribute are acceptable.
if (!llvm::isa<IntegerAttr, FloatAttr>(getValue())) {
return emitOpError("value must be an integer or float attribute");
}

return success();
}

OpFoldResult calyx::ConstantOp::fold(FoldAdaptor adaptor) {
return getValueAttr();
}

void calyx::ConstantOp::build(OpBuilder &builder, OperationState &state,
StringRef symName, TypedAttr attr) {
state.addAttribute(SymbolTable::getSymbolAttrName(),
builder.getStringAttr(symName));
state.addAttribute("value", attr);
SmallVector<Type> types;
types.push_back(attr.getType()); // Out
state.addTypes(types);
}

SmallVector<StringRef> ConstantOp::portNames() { return {"out"}; }

SmallVector<Direction> ConstantOp::portDirections() { return {Output}; }

SmallVector<DictionaryAttr> ConstantOp::portAttributes() {
return {DictionaryAttr::get(getContext())};
}

bool ConstantOp::isCombinational() { return true; }

//===----------------------------------------------------------------------===//
// RegisterOp
//===----------------------------------------------------------------------===//
Expand Down
34 changes: 31 additions & 3 deletions lib/Dialect/Calyx/Export/CalyxEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
#include "mlir/Tools/mlir-translate/Translation.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include <bitset>
#include <string>

using namespace circt;
using namespace calyx;
Expand Down Expand Up @@ -142,6 +145,10 @@ struct ImportTracker {
static constexpr std::string_view sMemories = "memories/seq";
return {sMemories};
})
.Case<ConstantOp>([&](auto op) -> FailureOr<StringRef> {
static constexpr std::string_view sFloat = "float";
return {sFloat};
})
.Default([&](auto op) {
auto diag = op->emitOpError() << "not supported for emission";
return diag;
Expand Down Expand Up @@ -253,6 +260,9 @@ struct Emitter {
// Invoke emission
void emitInvoke(InvokeOp invoke);

// Floating point Constant emission
void emitConstant(ConstantOp constant);

// Emits a library primitive with template parameters based on all in- and
// output ports.
// e.g.:
Expand Down Expand Up @@ -445,7 +455,7 @@ struct Emitter {
return;
}

auto definingOp = value.getDefiningOp();
auto *definingOp = value.getDefiningOp();
assert(definingOp && "Value does not have a defining operation.");

TypeSwitch<Operation *>(definingOp)
Expand Down Expand Up @@ -638,6 +648,7 @@ void Emitter::emitComponent(ComponentInterface op) {
.Case<MemoryOp>([&](auto op) { emitMemory(op); })
.Case<SeqMemoryOp>([&](auto op) { emitSeqMemory(op); })
.Case<hw::ConstantOp>([&](auto op) { /*Do nothing*/ })
.Case<calyx::ConstantOp>([&](auto op) { emitConstant(op); })
.Case<SliceLibOp, PadLibOp, ExtSILibOp>(
[&](auto op) { emitLibraryPrimTypedByAllPorts(op); })
.Case<LtLibOp, GtLibOp, EqLibOp, NeqLibOp, GeLibOp, LeLibOp, SltLibOp,
Expand Down Expand Up @@ -899,6 +910,23 @@ void Emitter::emitInvoke(InvokeOp invoke) {
os << RParen() << semicolonEndL();
}

void Emitter::emitConstant(ConstantOp constantOp) {
TypedAttr attr = constantOp.getValueAttr();
assert(isa<FloatAttr>(attr) && "must be a floating point constant");
auto fltAttr = cast<FloatAttr>(attr);
APFloat value = fltAttr.getValue();
auto type = cast<FloatType>(fltAttr.getType());
double doubleValue = value.convertToDouble();
auto floatBits = value.getSizeInBits(type.getFloatSemantics());
indent() << constantOp.getName().str() << space() << equals() << space()
<< "std_float_const";
// Currently defaults to IEEE-754 representation [1].
// [1]: https://github.com/calyxir/calyx/blob/main/primitives/float.futil
static constexpr int32_t IEEE754 = 0;
os << LParen() << std::to_string(IEEE754) << comma() << floatBits << comma()
<< std::to_string(doubleValue) << RParen() << semicolonEndL();
}

/// Calling getName() on a calyx operation will return "calyx.${opname}". This
/// function returns whatever is left after the first '.' in the string,
/// removing the 'calyx' prefix.
Expand Down Expand Up @@ -954,8 +982,8 @@ void Emitter::emitWires(WiresOp op) {
TypeSwitch<Operation *>(&bodyOp)
.Case<GroupInterface>([&](auto op) { emitGroup(op); })
.Case<AssignOp>([&](auto op) { emitAssignment(op); })
.Case<hw::ConstantOp, comb::AndOp, comb::OrOp, comb::XorOp, CycleOp>(
[&](auto op) { /* Do nothing. */ })
.Case<hw::ConstantOp, calyx::ConstantOp, comb::AndOp, comb::OrOp,
comb::XorOp, CycleOp>([&](auto op) { /* Do nothing. */ })
.Default([&](auto op) {
emitOpError(op, "not supported for emission inside wires section");
});
Expand Down
8 changes: 8 additions & 0 deletions lib/Dialect/Calyx/Transforms/CalyxHelpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ calyx::RegisterOp createRegister(Location loc, OpBuilder &builder,
return builder.create<RegisterOp>(loc, (prefix + "_reg").str(), width);
}

calyx::RegisterOp createRegister(Location loc, OpBuilder &builder,
ComponentOp component, Type type,
Twine prefix) {
OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPointToStart(component.getBodyBlock());
return builder.create<RegisterOp>(loc, (prefix + "_reg").str(), type);
}

hw::ConstantOp createConstant(Location loc, OpBuilder &builder,
ComponentOp component, size_t width,
size_t value) {
Expand Down
16 changes: 8 additions & 8 deletions lib/Dialect/Calyx/Transforms/CalyxLoweringUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -657,10 +657,10 @@ void InlineCombGroups::recurseInlineCombGroups(
// LateSSAReplacement)
if (isa<BlockArgument>(src) ||
isa<calyx::RegisterOp, calyx::MemoryOp, calyx::SeqMemoryOp,
hw::ConstantOp, mlir::arith::ConstantOp, calyx::MultPipeLibOp,
calyx::DivUPipeLibOp, calyx::DivSPipeLibOp, calyx::RemSPipeLibOp,
calyx::RemUPipeLibOp, mlir::scf::WhileOp, calyx::InstanceOp>(
src.getDefiningOp()))
calyx::ConstantOp, hw::ConstantOp, mlir::arith::ConstantOp,
calyx::MultPipeLibOp, calyx::DivUPipeLibOp, calyx::DivSPipeLibOp,
calyx::RemSPipeLibOp, calyx::RemUPipeLibOp, mlir::scf::WhileOp,
calyx::InstanceOp>(src.getDefiningOp()))
continue;

auto srcCombGroup = dyn_cast<calyx::CombGroupOp>(
Expand Down Expand Up @@ -753,11 +753,11 @@ BuildReturnRegs::partiallyLowerFuncToComp(mlir::func::FuncOp funcOp,

for (auto argType : enumerate(funcOp.getResultTypes())) {
auto convArgType = calyx::convIndexType(rewriter, argType.value());
assert(isa<IntegerType>(convArgType) && "unsupported return type");
unsigned width = convArgType.getIntOrFloatBitWidth();
assert((isa<IntegerType>(convArgType) || isa<FloatType>(convArgType)) &&
"unsupported return type");
std::string name = "ret_arg" + std::to_string(argType.index());
auto reg =
createRegister(funcOp.getLoc(), rewriter, getComponent(), width, name);
auto reg = createRegister(funcOp.getLoc(), rewriter, getComponent(),
convArgType, name);
getState().addReturnReg(reg, argType.index());

rewriter.setInsertionPointToStart(
Expand Down
24 changes: 24 additions & 0 deletions test/Conversion/SCFToCalyx/convert_simple.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -209,3 +209,27 @@ module {
return %0, %1 : i8, i8
}
}

// -----

// Test integer and floating point constant

// CHECK: calyx.group @ret_assign_0 {
// CHECK-DAG: calyx.assign %ret_arg0_reg.in = %in0 : f32
// CHECK-DAG: calyx.assign %ret_arg0_reg.write_en = %true : i1
// CHECK-DAG: calyx.assign %ret_arg1_reg.in = %c42_i32 : i32
// CHECK-DAG: calyx.assign %ret_arg1_reg.write_en = %true : i1
// CHECK-DAG: calyx.assign %ret_arg2_reg.in = %cst : f32
// CHECK-DAG: calyx.assign %ret_arg2_reg.write_en = %true : i1
// CHECK-DAG: %0 = comb.and %ret_arg2_reg.done, %ret_arg1_reg.done, %ret_arg0_reg.done : i1
// CHECK-DAG: calyx.group_done %0 ? %true : i1
// CHECK-DAG: }

module {
func.func @main(%arg0 : f32) -> (f32, i32, f32) {
%0 = arith.constant 42 : i32
%1 = arith.constant 4.2e+1 : f32

return %arg0, %0, %1 : f32, i32, f32
}
}
39 changes: 39 additions & 0 deletions test/Dialect/Calyx/emit.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -241,3 +241,42 @@ module attributes {calyx.entrypoint = "main"} {
}
}
}

// -----

module attributes {calyx.entrypoint = "main"} {
calyx.component @main(%clk: i1 {clk}, %reset: i1 {reset}, %go: i1 {go}) -> (%out0: i32, %out1: f32, %done: i1 {done}) {
// CHECK: cst_0 = std_float_const(0, 32, 4.200000);
%c42_i32 = hw.constant 42 : i32
%cst = calyx.constant {sym_name = "cst_0"} 4.200000e+00 : f32
%true = hw.constant true
%ret_arg1_reg.in, %ret_arg1_reg.write_en, %ret_arg1_reg.clk, %ret_arg1_reg.reset, %ret_arg1_reg.out, %ret_arg1_reg.done = calyx.register @ret_arg1_reg : f32, i1, i1, i1, f32, i1
%ret_arg0_reg.in, %ret_arg0_reg.write_en, %ret_arg0_reg.clk, %ret_arg0_reg.reset, %ret_arg0_reg.out, %ret_arg0_reg.done = calyx.register @ret_arg0_reg : i32, i1, i1, i1, i32, i1
calyx.wires {
calyx.assign %out1 = %ret_arg1_reg.out : f32
calyx.assign %out0 = %ret_arg0_reg.out : i32

// CHECK-LABEL: group ret_assign_0 {
// CHECK-NEXT: ret_arg0_reg.in = 32'd42;
// CHECK-NEXT: ret_arg0_reg.write_en = 1'd1;
// CHECK-NEXT: ret_arg1_reg.in = cst_0.out;
// CHECK-NEXT: ret_arg1_reg.write_en = 1'd1;
// CHECK-NEXT: ret_assign_0[done] = (ret_arg1_reg.done & ret_arg0_reg.done) ? 1'd1;
// CHECK-NEXT: }
calyx.group @ret_assign_0 {
calyx.assign %ret_arg0_reg.in = %c42_i32 : i32
calyx.assign %ret_arg0_reg.write_en = %true : i1
calyx.assign %ret_arg1_reg.in = %cst : f32
calyx.assign %ret_arg1_reg.write_en = %true : i1
%0 = comb.and %ret_arg1_reg.done, %ret_arg0_reg.done : i1
calyx.group_done %0 ? %true : i1
}
}
calyx.control {
calyx.seq {
calyx.enable @ret_assign_0
}
}
} {toplevel}
}

Loading