Skip to content

Commit

Permalink
add addFN to Calyx
Browse files Browse the repository at this point in the history
  • Loading branch information
jiahanxie353 committed May 2, 2024
1 parent 0051651 commit 3c375e3
Show file tree
Hide file tree
Showing 5 changed files with 181 additions and 8 deletions.
9 changes: 9 additions & 0 deletions include/circt/Dialect/Calyx/CalyxPrimitives.td
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,15 @@ class CombinationalArithBinaryLibraryOp<string mnemonic> :
let results = (outs AnyType:$left, AnyType:$right, AnyType:$out);
}

class ArithBinaryFloatingPointLibraryOp<string mnemonic> : ArithBinaryLibraryOp<mnemonic, [
SameTypeConstraint<"left", "out">
]> {
let results = (outs I1:$clk, I1:$reset, I1:$go, I1:$control, I1:$subOp,
AnyFloat:$left, AnyFloat:$right, AnySignlessInteger:$roundingMode, AnyFloat:$out,
AnySignlessInteger:$execptionalFlags, I1:$done);
}

def AddFNOp : ArithBinaryFloatingPointLibraryOp<"addFN"> {}
def AddLibOp : CombinationalArithBinaryLibraryOp<"add"> {}
def SubLibOp : CombinationalArithBinaryLibraryOp<"sub"> {}
def ShruLibOp : CombinationalArithBinaryLibraryOp<"shru"> {}
Expand Down
93 changes: 88 additions & 5 deletions lib/Conversion/SCFToCalyx/SCFToCalyx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "circt/Dialect/Calyx/CalyxOps.h"
#include "circt/Dialect/Comb/CombOps.h"
#include "circt/Dialect/HW/HWOps.h"
#include "circt/Support/LLVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
Expand All @@ -28,6 +29,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/TypeSwitch.h"

#include <variant>
Expand Down Expand Up @@ -379,7 +381,66 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {

return success();
}


/// buildLibraryBinaryFloatingPointOp will build a
/// TCalyxLibBinaryFloatingPointOp, to deal with AddFNOp
template <typename TOpType, typename TSrcOp>
LogicalResult buildLibraryBinaryFloatingPointOp(PatternRewriter &rewriter,
TSrcOp op, TOpType opFN,
Value out) const {
StringRef opName = TSrcOp::getOperationName().split(".").second;
Location loc = op.getLoc();
Type width = op.getResult().getType();

// Pass the result from the Operation to the Calyx primitive.
op.getResult().replaceAllUsesWith(out);
auto reg = createRegister(
op.getLoc(), rewriter, getComponent(), width,
getState<ComponentLoweringState>().getUniqueName(opName));
// Floating point number calculations are not combinational, so a GroupOp is
// required.
auto group = createGroupForOp<calyx::GroupOp>(rewriter, op);
OpBuilder builder(group->getRegion(0));
getState<ComponentLoweringState>().addBlockScheduleable(op->getBlock(),
group);

rewriter.setInsertionPointToEnd(group.getBodyBlock());
rewriter.create<calyx::AssignOp>(loc, opFN.getLeft(), op.getLhs());
rewriter.create<calyx::AssignOp>(loc, opFN.getRight(), op.getRhs());
// Write the output to this register.
rewriter.create<calyx::AssignOp>(loc, reg.getIn(), out);
// The write enable port is high when the calculation is done.
rewriter.create<calyx::AssignOp>(loc, reg.getWriteEn(), opFN.getDone());
// Set opFN to high as long as its done signal is not high.
// This prevents the opFN from executing for the cycle that we write
// to register. To get !(opFN.done) we do 1 xor opFN.done
hw::ConstantOp c1 = createConstant(loc, rewriter, getComponent(), 1, 1);
rewriter.create<calyx::AssignOp>(
loc, opFN.getGo(), c1,
comb::createOrFoldNot(group.getLoc(), opFN.getDone(), builder));
// The group is done when the register write is complete.
rewriter.create<calyx::GroupDoneOp>(loc, reg.getDone());

if (isa<calyx::AddFNOp>(opFN)) {
hw::ConstantOp subOp;
if (isa<arith::AddFOp>(op)) {
subOp = createConstant(loc, rewriter, getComponent(), 1, 0);
} else {
subOp = createConstant(loc, rewriter, getComponent(), 1, 1);
}
rewriter.create<calyx::AssignOp>(loc, opFN.getSubOp(), subOp);
}

// Register the values for the calculation.
getState<ComponentLoweringState>().registerEvaluatingGroup(out, group);
getState<ComponentLoweringState>().registerEvaluatingGroup(opFN.getLeft(),
group);
getState<ComponentLoweringState>().registerEvaluatingGroup(opFN.getRight(),
group);

return success();
}

/// Creates assignments within the provided group to the address ports of the
/// memoryOp based on the provided addressValues.
void assignAddressPorts(PatternRewriter &rewriter, Location loc,
Expand Down Expand Up @@ -499,10 +560,17 @@ LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
// for sequential memories will cause a read to take at least 2 cycles,
// but it will usually be better because combinational reads on memories
// can significantly decrease the maximum achievable frequency.
auto reg = createRegister(
loadOp.getLoc(), rewriter, getComponent(),
loadOp.getMemRefType().getElementType(),
getState<ComponentLoweringState>().getUniqueName("load"));
calyx::RegisterOp reg;
if (loadOp.getMemRefType().isa<IntegerType>())
reg = createRegister(
loadOp.getLoc(), rewriter, getComponent(),
loadOp.getMemRefType().getElementTypeBitWidth(),
getState<ComponentLoweringState>().getUniqueName("load"));
else
reg = createRegister(
loadOp.getLoc(), rewriter, getComponent(),
loadOp.getMemRefType().getElementType(),
getState<ComponentLoweringState>().getUniqueName("load"));
rewriter.setInsertionPointToEnd(group.getBodyBlock());
rewriter.create<calyx::AssignOp>(loadOp.getLoc(), reg.getIn(),
memoryInterface.readData());
Expand Down Expand Up @@ -823,6 +891,21 @@ LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
return success();
}

LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
AddFOp addf) const {
Location loc = addf.getLoc();
Type width = addf.getResult().getType();
IntegerType one = rewriter.getI1Type(), three = rewriter.getIntegerType(3),
five = rewriter.getIntegerType(5);
auto addFN =
getState<ComponentLoweringState>()
.getNewLibraryOpInstance<calyx::AddFNOp>(
rewriter, loc,
{one, one, one, one, one, width, width, three, width, five, one});
return buildLibraryBinaryFloatingPointOp<calyx::AddFNOp>(
rewriter, addf, addFN, addFN.getOut());
}

LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
AddIOp op) const {
return buildLibraryOp<calyx::CombGroupOp, calyx::AddLibOp>(rewriter, op);
Expand Down
46 changes: 46 additions & 0 deletions lib/Dialect/Calyx/CalyxOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2989,7 +2989,53 @@ LogicalResult SliceLibOp::verify() {
DictionaryAttr::get(getContext())}; \
}

#define ImplBinFloatingPointOpCellInterface(OpType) \
SmallVector<StringRef> OpType::portNames() { \
return { \
clkPort, resetPort, goPort, "control", "subOp", "left", \
"right", "roundingMode", "out", "exceptionFlags", donePort}; \
} \
\
SmallVector<Direction> OpType::portDirections() { \
return {Input, Input, Input, Input, Input, Input, \
Input, Input, Output, Output, Output}; \
} \
\
void OpType::getAsmResultNames(OpAsmSetValueNameFn setNameFn) { \
getCellAsmResultNames(setNameFn, *this, this->portNames()); \
} \
\
SmallVector<DictionaryAttr> OpType::portAttributes() { \
MLIRContext *context = getContext(); \
IntegerAttr isSet = IntegerAttr::get(IntegerType::get(context, 1), 1); \
NamedAttrList go, clk, reset, done; \
go.append(goPort, isSet); \
clk.append(clkPort, isSet); \
reset.append(resetPort, isSet); \
done.append(donePort, isSet); \
return { \
clk.getDictionary(context), /* Clk */ \
reset.getDictionary(context), /* Reset */ \
go.getDictionary(context), /* Go */ \
DictionaryAttr::get(context), /* Control */ \
DictionaryAttr::get(context), /* subOp */ \
DictionaryAttr::get(context), /* roundingMode */ \
DictionaryAttr::get(context), /* Lhs */ \
DictionaryAttr::get(context), /* Rhs */ \
DictionaryAttr::get(context), /* Out */ \
done.getDictionary(context), /* Done */ \
DictionaryAttr::get(context) /* exceptionFlags */ \
}; \
} \
\
\
bool \
OpType::isCombinational() { \
return false; \
}

// clang-format off
ImplBinFloatingPointOpCellInterface(AddFNOp)
ImplBinPipeOpCellInterface(MultPipeLibOp, "out")
ImplBinPipeOpCellInterface(DivUPipeLibOp, "out_quotient")
ImplBinPipeOpCellInterface(DivSPipeLibOp, "out_quotient")
Expand Down
39 changes: 37 additions & 2 deletions lib/Dialect/Calyx/Export/CalyxEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,10 @@ struct ImportTracker {
static constexpr std::string_view sMemories = "memories/seq";
return {sMemories};
})
.Case<AddFNOp>([&](auto op) -> FailureOr<StringRef> {
static constexpr std::string_view sFloatingPoint = "float/addFN";
return {sFloatingPoint};
})
.Default([&](auto op) {
auto diag = op->emitOpError() << "not supported for emission";
return diag;
Expand Down Expand Up @@ -277,7 +281,9 @@ struct Emitter {
// f = std_foo(1);
void emitLibraryPrimTypedByFirstOutputPort(
Operation *op, std::optional<StringRef> calyxLibName = {});


void emitLibraryFloatingPoint(Operation *op);

private:
/// Used to track which imports are required for this program.
ImportTracker importTracker;
Expand Down Expand Up @@ -650,14 +656,16 @@ void Emitter::emitComponent(ComponentInterface op) {
.Case<RegisterOp>([&](auto op) { emitRegister(op); })
.Case<MemoryOp>([&](auto op) { emitMemory(op); })
.Case<SeqMemoryOp>([&](auto op) { emitSeqMemory(op); })
.Case<hw::ConstantOp, calyx::ConstantOp>([&](auto op) { /*Do nothing*/ })
.Case<hw::ConstantOp, calyx::ConstantOp>(
[&](auto op) { /*Do nothing*/ })
.Case<SliceLibOp, PadLibOp, ExtSILibOp>(
[&](auto op) { emitLibraryPrimTypedByAllPorts(op); })
.Case<LtLibOp, GtLibOp, EqLibOp, NeqLibOp, GeLibOp, LeLibOp, SltLibOp,
SgtLibOp, SeqLibOp, SneqLibOp, SgeLibOp, SleLibOp, AddLibOp,
SubLibOp, ShruLibOp, RshLibOp, SrshLibOp, LshLibOp, AndLibOp,
NotLibOp, OrLibOp, XorLibOp, WireLibOp>(
[&](auto op) { emitLibraryPrimTypedByFirstInputPort(op); })
.Case<AddFNOp>([&](auto op) { emitLibraryFloatingPoint(op); })
.Case<MuxLibOp>(
[&](auto op) { emitLibraryPrimTypedByFirstOutputPort(op); })
.Case<MultPipeLibOp>(
Expand Down Expand Up @@ -949,6 +957,33 @@ void Emitter::emitLibraryPrimTypedByFirstOutputPort(
<< LParen() << bitWidth << RParen() << semicolonEndL();
}

void Emitter::emitLibraryFloatingPoint(Operation *op) {
auto cell = cast<CellInterface>(op);
unsigned bitWidth =
cell.getOutputPorts()[0].getType().getIntOrFloatBitWidth();
unsigned expWidth, sigWidth;
if (bitWidth == 16) {
expWidth = 5;
sigWidth = 11;
} else if (bitWidth == 32) {
expWidth = 8;
sigWidth = 24;
} else if (bitWidth == 64) {
expWidth = 11;
sigWidth = 53;
} else if (bitWidth == 128) {
expWidth = 15;
sigWidth = 113;
} else {
op->emitError("Unsupported floating point width");
}
StringRef opName = op->getName().getStringRef();
indent() << getAttributes(op, /*atFormat=*/true) << cell.instanceName()
<< space() << equals() << space() << removeCalyxPrefix(opName)
<< LParen() << expWidth << comma() << sigWidth << comma() << bitWidth
<< RParen() << semicolonEndL();
}

void Emitter::emitAssignment(AssignOp op) {

emitValue(op.getDest(), /*isIndented=*/true);
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/Calyx/Transforms/CalyxLoweringUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -667,7 +667,7 @@ void InlineCombGroups::recurseInlineCombGroups(
calyx::ConstantOp, hw::ConstantOp, mlir::arith::ConstantOp,
calyx::MultPipeLibOp, calyx::DivUPipeLibOp, calyx::DivSPipeLibOp,
calyx::RemSPipeLibOp, calyx::RemUPipeLibOp, mlir::scf::WhileOp,
calyx::InstanceOp>(src.getDefiningOp()))
calyx::InstanceOp, calyx::AddFNOp>(src.getDefiningOp()))
continue;

auto srcCombGroup = dyn_cast<calyx::CombGroupOp>(
Expand Down

0 comments on commit 3c375e3

Please sign in to comment.