Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lower MemRef GetGlobal and write data to json files #7301

Merged
merged 16 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion include/circt/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,9 @@ def SCFToCalyx : Pass<"lower-scf-to-calyx", "mlir::ModuleOp"> {
"Identifier of top-level function to be the entry-point component"
" of the Calyx program.">,
Option<"ciderSourceLocationMetadata", "cider-source-location-metadata", "bool", "",
"Whether to track source location for the Cider debugger.">
"Whether to track source location for the Cider debugger.">,
Option<"writeJsonOpt", "write-json", "std::string", "",
"Whether to write memory contents to the json file.">
];
}

Expand Down
37 changes: 37 additions & 0 deletions include/circt/Dialect/Calyx/CalyxLoweringUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/JSON.h"

#include <variant>

Expand Down Expand Up @@ -450,6 +451,38 @@ class ComponentLoweringStateInterface {
return builder.create<TLibraryOp>(loc, getUniqueName(name), resTypes);
}

llvm::json::Value &getExtMemData() { return extMemData; }

const llvm::json::Value &getExtMemData() const { return extMemData; }

void setDataField(StringRef name, llvm::json::Array data) {
auto *extMemDataObj = extMemData.getAsObject();
assert(extMemDataObj && "extMemData should be an object");

auto &value = (*extMemDataObj)[name.str()];
llvm::json::Object *obj = value.getAsObject();
if (!obj) {
value = llvm::json::Object{};
obj = value.getAsObject();
}
(*obj)["data"] = llvm::json::Value(std::move(data));
}

void setFormat(StringRef name, std::string numType, bool isSigned,
unsigned width) {
auto *extMemDataObj = extMemData.getAsObject();
assert(extMemDataObj && "extMemData should be an object");

auto &value = (*extMemDataObj)[name.str()];
llvm::json::Object *obj = value.getAsObject();
if (!obj) {
value = llvm::json::Object{};
obj = value.getAsObject();
}
(*obj)["format"] = llvm::json::Object{
{"numeric_type", numType}, {"is_signed", isSigned}, {"width", width}};
}

private:
/// The component which this lowering state is associated to.
calyx::ComponentOp component;
Expand Down Expand Up @@ -486,6 +519,10 @@ class ComponentLoweringStateInterface {

/// A mapping between the callee and the instance.
llvm::StringMap<calyx::InstanceOp> instanceMap;

/// A json file to store external global memory data. See
/// https://docs.calyxir.org/lang/data-format.html?highlight=json#the-data-format
llvm::json::Value extMemData;
};

/// An interface for conversion passes that lower Calyx programs. This handles
Expand Down
127 changes: 124 additions & 3 deletions lib/Conversion/SCFToCalyx/SCFToCalyx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,14 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/LogicalResult.h"

#include "llvm/Support/raw_os_ostream.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <filesystem>
#include <fstream>

#include <locale>
#include <numeric>
#include <variant>

namespace circt {
Expand Down Expand Up @@ -266,6 +273,14 @@ class ComponentLoweringState : public calyx::ComponentLoweringStateInterface,
/// Iterate through the operations of a source function and instantiate
/// components or primitives based on the type of the operations.
class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
public:
BuildOpGroups(MLIRContext *context, LogicalResult &resRef,
calyx::PatternApplicationState &patternState,
DenseMap<mlir::func::FuncOp, calyx::ComponentOp> &map,
calyx::CalyxLoweringState &state,
mlir::Pass::Option<std::string> &writeJsonOpt)
: FuncOpPartialLoweringPattern(context, resRef, patternState, map, state),
writeJson(writeJsonOpt) {}
using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;

LogicalResult
Expand All @@ -283,7 +298,7 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
scf::ParallelOp, scf::ReduceOp,
/// memref
memref::AllocOp, memref::AllocaOp, memref::LoadOp,
memref::StoreOp,
memref::StoreOp, memref::GetGlobalOp,
/// standard arithmetic
AddIOp, SubIOp, CmpIOp, ShLIOp, ShRUIOp, ShRSIOp,
AndIOp, XOrIOp, OrIOp, ExtUIOp, ExtSIOp, TruncIOp,
Expand All @@ -306,10 +321,32 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
: WalkResult::interrupt();
});

if (!writeJson.empty()) {
if (auto fileLoc = dyn_cast<mlir::FileLineColLoc>(funcOp->getLoc())) {
std::string filename = fileLoc.getFilename().str();
std::filesystem::path path(filename);
std::string jsonFileName = writeJson.append(".json");
auto outFileName = path.parent_path().append(jsonFileName);
std::ofstream outFile(outFileName);

if (!outFile.is_open()) {
llvm::errs() << "Unable to open file: " << outFileName
<< " for writing\n";
return failure();
}
llvm::raw_os_ostream llvmOut(outFile);
llvm::json::OStream jsonOS(llvmOut, 2);
jsonOS.value(getState<ComponentLoweringState>().getExtMemData());
jsonOS.flush();
outFile.close();
}
}

return success(opBuiltSuccessfully);
}

private:
mlir::Pass::Option<std::string> &writeJson;
/// Op builder specializations.
LogicalResult buildOp(PatternRewriter &rewriter, scf::YieldOp yieldOp) const;
LogicalResult buildOp(PatternRewriter &rewriter,
Expand Down Expand Up @@ -341,6 +378,8 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
LogicalResult buildOp(PatternRewriter &rewriter, IndexCastOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, memref::AllocOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, memref::AllocaOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter,
memref::GetGlobalOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, memref::LoadOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, memref::StoreOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, scf::WhileOp whileOp) const;
Expand Down Expand Up @@ -962,6 +1001,82 @@ static LogicalResult buildAllocOp(ComponentLoweringState &componentState,
IntegerAttr::get(rewriter.getI1Type(), llvm::APInt(1, 1)));
componentState.registerMemoryInterface(allocOp.getResult(),
calyx::MemoryInterface(memoryOp));

unsigned elmTyBitWidth = memtype.getElementTypeBitWidth();
assert(elmTyBitWidth <= 64 && "element bitwidth should not exceed 64");
bool isFloat = !memtype.getElementType().isInteger();
jiahanxie353 marked this conversation as resolved.
Show resolved Hide resolved

auto shape = allocOp.getType().getShape();
jiahanxie353 marked this conversation as resolved.
Show resolved Hide resolved
int totalSize =
std::reduce(shape.begin(), shape.end(), 1, std::multiplies<int>());
// The `totalSize <= 1` check is a hack to:
// https://github.com/llvm/circt/pull/2661, where a multi-dimensional memory
// whose size in some dimension equals 1, e.g. memref<1x1x1x1xi32>, will be
// collapsed to `memref<1xi32>` with `totalSize == 1`. While the above case is
// a trivial fix, Calyx expects 1-dimensional memories in general:
// https://github.com/calyxir/calyx/issues/907
if (!(shape.size() <= 1 || totalSize <= 1)) {
allocOp.emitError("input memory dimension must be empty or one.");
return failure();
}

std::vector<uint64_t> flattenedVals(totalSize, 0);
if (isa<memref::GetGlobalOp>(allocOp)) {
auto getGlobalOp = cast<memref::GetGlobalOp>(allocOp);
auto *symbolTableOp =
getGlobalOp->template getParentWithTrait<mlir::OpTrait::SymbolTable>();
auto globalOp = dyn_cast_or_null<memref::GlobalOp>(
SymbolTable::lookupSymbolIn(symbolTableOp, getGlobalOp.getNameAttr()));
// Flatten the values in the attribute
auto cstAttr = llvm::dyn_cast_or_null<DenseElementsAttr>(
globalOp.getConstantInitValue());
int sizeCount = 0;
for (auto attr : cstAttr.template getValues<Attribute>()) {
assert((isa<mlir::FloatAttr, mlir::IntegerAttr>(attr)) &&
"memory attributes must be float or int");
if (auto fltAttr = dyn_cast<mlir::FloatAttr>(attr)) {
flattenedVals[sizeCount++] =
bit_cast<uint64_t>(fltAttr.getValueAsDouble());
} else {
auto intAttr = dyn_cast<mlir::IntegerAttr>(attr);
cgyurgyik marked this conversation as resolved.
Show resolved Hide resolved
APInt value = intAttr.getValue();
flattenedVals[sizeCount++] = *value.getRawData();
}
}

rewriter.eraseOp(globalOp);
}

llvm::json::Array result;
result.reserve(std::max(static_cast<int>(shape.size()), 1));

Type elemType = memtype.getElementType();
bool isSigned =
!elemType.isSignlessInteger() && !elemType.isUnsignedInteger();
for (uint64_t bitValue : flattenedVals) {
llvm::json::Value value = 0;
if (isFloat) {
// We cast to `double` and let downstream calyx to deal with the actual
// value's precision handling.
value = bit_cast<double>(bitValue);
} else {
APInt apInt(/*numBits=*/elmTyBitWidth, bitValue, isSigned);
// The conditional ternary operation will cause the `value` to interpret
// the underlying data as unsigned regardless `isSigned` or not.
if (isSigned)
value = static_cast<int64_t>(apInt.getSExtValue());
else
value = apInt.getZExtValue();
}
result.push_back(std::move(value));
}

componentState.setDataField(memoryOp.getName(), result);
std::string numType =
memtype.getElementType().isInteger() ? "bitnum" : "ieee754_float";
componentState.setFormat(memoryOp.getName(), numType, isSigned,
elmTyBitWidth);

return success();
}

Expand All @@ -975,6 +1090,12 @@ LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
return buildAllocOp(getState<ComponentLoweringState>(), rewriter, allocOp);
}

LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
memref::GetGlobalOp getGlobalOp) const {
return buildAllocOp(getState<ComponentLoweringState>(), rewriter,
getGlobalOp);
}

LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
scf::YieldOp yieldOp) const {
if (yieldOp.getOperands().empty()) {
Expand Down Expand Up @@ -2644,7 +2765,7 @@ void SCFToCalyxPass::runOnOperation() {
/// having a distinct group for each operation, groups are analogous to SSA
/// values in the source program.
addOncePattern<BuildOpGroups>(loweringPatterns, patternState, funcMap,
*loweringState);
*loweringState, writeJsonOpt);

/// This pattern traverses the CFG of the program and generates a control
/// schedule based on the calyx::GroupOp's which were registered for each
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/Calyx/Transforms/CalyxLoweringUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ BasicLoopInterface::~BasicLoopInterface() = default;

ComponentLoweringStateInterface::ComponentLoweringStateInterface(
calyx::ComponentOp component)
: component(component) {}
: component(component), extMemData(llvm::json::Object{}) {}

ComponentLoweringStateInterface::~ComponentLoweringStateInterface() = default;

Expand Down
93 changes: 93 additions & 0 deletions test/Conversion/SCFToCalyx/write_memory.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// RUN: circt-opt %s --lower-scf-to-calyx="write-json=data" -canonicalize>/dev/null && cat $(dirname %s)/data.json | FileCheck %s

// CHECK-LABEL: "mem_0": {
// CHECK: "data": [
// CHECK: 0,
// CHECK: 0,
// CHECK: 0,
// CHECK: 0
// CHECK: ],
// CHECK: "format": {
// CHECK: "is_signed": true,
// CHECK: "numeric_type": "ieee754_float",
// CHECK: "width": 32
// CHECK: }
// CHECK: },

// CHECK-LABEL: "mem_1": {
// CHECK: "data": [
// CHECK: 0
// CHECK: ],
// CHECK: "format": {
// CHECK: "is_signed": true,
// CHECK: "numeric_type": "bitnum",
// CHECK: "width": 8
// CHECK: }
// CHECK: },

// CHECK-LABEL: "mem_2": {
// CHECK: "data": [
// CHECK: 43,
// CHECK: 8,
// CHECK: 4294967257,
// CHECK: 4294967277,
// CHECK: 70,
// CHECK: 4294967232,
// CHECK: 4294967289,
// CHECK: 4294967269,
// CHECK: 4294967239,
// CHECK: 5
// CHECK: ],
// CHECK: "format": {
// CHECK: "is_signed": false,
jiahanxie353 marked this conversation as resolved.
Show resolved Hide resolved
// CHECK: "numeric_type": "bitnum",
// CHECK: "width": 32
// CHECK: }
// CHECK: },

// CHECK-LABEL: "mem_3": {
// CHECK: "data": [
// CHECK: 0.69999998807907104,
// CHECK: -4.1999998092651367,
// CHECK: 0
// CHECK: ],
// CHECK: "format": {
// CHECK: "is_signed": true,
// CHECK: "numeric_type": "ieee754_float",
// CHECK: "width": 32
// CHECK: }
// CHECK: },

// CHECK-LABEL: "mem_4": {
// CHECK: "data": [
// CHECK: -42,
// CHECK: 35
// CHECK: ],
// CHECK: "format": {
// CHECK: "is_signed": true,
// CHECK: "numeric_type": "bitnum",
// CHECK: "width": 8
// CHECK: }
// CHECK: }

module {
memref.global "private" constant @constant_10xi32_0 : memref<10xi32> = dense<[43, 8, -39, -19, 70, -64, -7, -27, -57, 5]>
jiahanxie353 marked this conversation as resolved.
Show resolved Hide resolved
memref.global "private" constant @constant_2xsi8_0 : memref<2xsi8> = dense<[-42, 35]>
memref.global "private" constant @constant_3xf32 : memref<3xf32> = dense<[0.7, -4.2, 0.0]>
func.func @main(%arg_idx : index) -> i32 {
%alloc = memref.alloc() : memref<4xf32>
%zero_dim_mem = memref.alloca() : memref<si8>
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%0 = memref.get_global @constant_10xi32_0 : memref<10xi32>
%ret = memref.load %0[%arg_idx] : memref<10xi32>
%1 = memref.get_global @constant_3xf32 : memref<3xf32>
%2 = memref.load %1[%c1] : memref<3xf32>
memref.store %2, %alloc[%c2] : memref<4xf32>
%3 = memref.get_global @constant_2xsi8_0 : memref<2xsi8>
%4 = memref.load %3[%c1] : memref<2xsi8>
memref.store %4, %zero_dim_mem[] : memref<si8>
return %ret : i32
}
}

Loading