Skip to content

Commit

Permalink
feat(compiler): compute memory usage per location
Browse files Browse the repository at this point in the history
  • Loading branch information
youben11 committed Aug 8, 2023
1 parent fa66f99 commit 1d44d89
Show file tree
Hide file tree
Showing 8 changed files with 280 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.

#ifndef CONCRETELANG_DIALECT_CONCRETE_MEMORY_USAGE_H
#define CONCRETELANG_DIALECT_CONCRETE_MEMORY_USAGE_H

#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/Operation.h>
#include <mlir/Pass/Pass.h>

#include <concretelang/Support/CompilationFeedback.h>

namespace mlir {
namespace concretelang {
namespace Concrete {

struct MemoryUsagePass
: public PassWrapper<MemoryUsagePass, OperationPass<ModuleOp>> {

CompilationFeedback &feedback;

MemoryUsagePass(CompilationFeedback &feedback) : feedback{feedback} {};

void runOnOperation() override {
WalkResult walk =
getOperation()->walk([&](Operation *op, const WalkStage &stage) {
if (stage.isBeforeAllRegions()) {
std::optional<StringError> error = this->enter(op);
if (error.has_value()) {
op->emitError() << error->mesg;
return WalkResult::interrupt();
}
}

if (stage.isAfterAllRegions()) {
std::optional<StringError> error = this->exit(op);
if (error.has_value()) {
op->emitError() << error->mesg;
return WalkResult::interrupt();
}
}

return WalkResult::advance();
});

if (walk.wasInterrupted()) {
signalPassFailure();
}
}

std::optional<StringError> enter(Operation *op);

std::optional<StringError> exit(Operation *op);

size_t iterations = 1;
};

} // namespace Concrete
} // namespace concretelang
} // namespace mlir

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ struct CompilationFeedback {
/// @brief statistics
std::vector<Statistic> statistics;

/// @brief memory usage per location
std::map<std::string, int64_t> memoryUsagePerLoc;

/// Fill the sizes from the client parameters.
void
fillFromClientParameters(::concretelang::clientlib::ClientParameters params);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ const CONCRETE_COMPILER_STATIC_LIBS: &[&str] = &[
"ExtractSDFGOps",
"SDFGToStreamEmulator",
"TFHEDialectAnalysis",
"ConcreteDialectAnalysis",
];

fn main() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
add_mlir_library(
ConcreteDialectAnalysis
MemoryUsage.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/Concrete
DEPENDS
ConcreteDialect
mlir-headers
LINK_LIBS
PUBLIC
MLIRIR
ConcreteDialect)
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
#include <concretelang/Dialect/Concrete/Analysis/MemoryUsage.h>
#include <concretelang/Dialect/Concrete/IR/ConcreteOps.h>
#include <concretelang/Support/logging.h>
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/Dialect/SCF/IR/SCF.h>
#include <numeric>

using namespace mlir::concretelang;
using namespace mlir;

using Concrete::MemoryUsagePass;

namespace {

template <typename Op> std::string locationOf(Op op) {
auto location = std::string();
auto locationStream = llvm::raw_string_ostream(location);
op.getLoc()->print(locationStream);
return location;
}

int64_t calculateNumberOfIterations(int64_t start, int64_t stop, int64_t step) {
int64_t high;
int64_t low;

if (step > 0) {
low = start;
high = stop;
} else {
low = stop;
high = start;
step = -step;
}

if (low >= high) {
return 0;
}

return ((high - low - 1) / step) + 1;
}

std::optional<StringError> calculateNumberOfIterations(scf::ForOp &op,
int64_t &result) {
mlir::Value startValue = op.getLowerBound();
mlir::Value stopValue = op.getUpperBound();
mlir::Value stepValue = op.getStep();

auto startOp =
llvm::dyn_cast_or_null<arith::ConstantOp>(startValue.getDefiningOp());
auto stopOp =
llvm::dyn_cast_or_null<arith::ConstantOp>(stopValue.getDefiningOp());
auto stepOp =
llvm::dyn_cast_or_null<arith::ConstantOp>(stepValue.getDefiningOp());

if (!startOp || !stopOp || !stepOp) {
return StringError("only static loops can be analyzed");
}

auto startAttr = startOp.getValue().cast<mlir::IntegerAttr>();
auto stopAttr = stopOp.getValue().cast<mlir::IntegerAttr>();
auto stepAttr = stepOp.getValue().cast<mlir::IntegerAttr>();

if (!startOp || !stopOp || !stepOp) {
return StringError("only integer loops can be analyzed");
}

int64_t start = startAttr.getInt();
int64_t stop = stopAttr.getInt();
int64_t step = stepAttr.getInt();

result = calculateNumberOfIterations(start, stop, step);
return std::nullopt;
}

static std::optional<StringError> on_enter(scf::ForOp &op,
MemoryUsagePass &pass) {
int64_t numberOfIterations;

std::optional<StringError> error =
calculateNumberOfIterations(op, numberOfIterations);
if (error.has_value()) {
return error;
}

assert(numberOfIterations > 0);
pass.iterations *= (uint64_t)numberOfIterations;
return std::nullopt;
}

static std::optional<StringError> on_exit(scf::ForOp &op,
MemoryUsagePass &pass) {
int64_t numberOfIterations;

std::optional<StringError> error =
calculateNumberOfIterations(op, numberOfIterations);
if (error.has_value()) {
return error;
}

assert(numberOfIterations > 0);
pass.iterations /= (uint64_t)numberOfIterations;
return std::nullopt;
}

int64_t getElementTypeSize(mlir::Type elementType) {
if (auto integerType = mlir::dyn_cast<mlir::IntegerType>(elementType)) {
auto width = integerType.getWidth();
return std::ceil(width / 8);
}
return -1;
}

static std::optional<StringError> on_enter(memref::AllocOp &op,
MemoryUsagePass &pass) {
auto bufferType = op.getType();
auto shape = bufferType.getShape();
auto elementType = bufferType.getElementType();
auto elementSize = getElementTypeSize(elementType);
if (elementSize == -1)
return StringError(
"allocation of buffer with a non-supported element-type");
for (auto size : shape) {
if (size == mlir::ShapedType::kDynamic) {
log_verbose() << "warning: dynamic dimension found during computation of "
"memory usage. Dynamic size will be ignored";
}
}
auto multiply_ignore_dyn_size = [](int64_t size_1, int64_t size_2) {
// we don't want to multiply by a dynamic size
return size_2 == mlir::ShapedType::kDynamic ? size_1 : size_1 * size_2;
};

auto location = locationOf(op);
// pass.iterations number of allocation of size: shape_1 * ... * shape_n *
// element_size
auto memoryUsage =
pass.iterations * elementSize *
std::accumulate(shape.begin(), shape.end(), 1, multiply_ignore_dyn_size);

pass.feedback.memoryUsagePerLoc[location] += memoryUsage;

return std::nullopt;
}
} // namespace

std::optional<StringError> MemoryUsagePass::enter(mlir::Operation *op) {
if (auto typedOp = llvm::dyn_cast<scf::ForOp>(op)) {
std::optional<StringError> error = on_enter(typedOp, *this);
if (error.has_value()) {
return error;
}
}
if (auto typedOp = llvm::dyn_cast<memref::AllocOp>(op)) {
std::optional<StringError> error = on_enter(typedOp, *this);
if (error.has_value()) {
return error;
}
}
return std::nullopt;
}

std::optional<StringError> MemoryUsagePass::exit(mlir::Operation *op) {
if (auto typedOp = llvm::dyn_cast<scf::ForOp>(op)) {
std::optional<StringError> error = on_exit(typedOp, *this);
if (error.has_value()) {
return error;
}
}
return std::nullopt;
}
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
add_subdirectory(Analysis)
add_subdirectory(IR)
add_subdirectory(Transforms)
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ add_mlir_library(
ConcretelangRuntime
ConcretelangClientLib
ConcretelangServerLib
TFHEDialectAnalysis)
TFHEDialectAnalysis
ConcreteDialectAnalysis)

target_include_directories(ConcretelangSupport PUBLIC ${CONCRETE_CPU_INCLUDE_DIR})
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,15 @@ llvm::json::Value toJSON(const mlir::concretelang::CompilationFeedback &v) {
{"crtDecompositionsOfOutputs", v.crtDecompositionsOfOutputs},
};

auto memoryUsageJsonArray = llvm::json::Array();
for (auto key : v.memoryUsagePerLoc) {
auto memoryUsageJson = llvm::json::Object();
memoryUsageJson.insert({"location", key.first});
memoryUsageJson.insert({"usage", key.second});
memoryUsageJsonArray.push_back(std::move(memoryUsageJson));
}
object.insert({"memoryUsagePerLoc", std::move(memoryUsageJsonArray)});

auto statisticsJson = llvm::json::Array();
for (auto statistic : v.statistics) {
auto statisticJson = llvm::json::Object();
Expand Down Expand Up @@ -177,6 +186,23 @@ bool fromJSON(const llvm::json::Value j,
return false;
}

auto memoryUsageArray = object->getArray("memoryUsagePerLoc");
if (!memoryUsageArray) {
return false;
}
for (auto memoryUsageValue : *memoryUsageArray) {
auto memoryUsage = memoryUsageValue.getAsObject();
if (!memoryUsage) {
return false;
}
auto loc = memoryUsage->getString("location");
auto usage = memoryUsage->getInteger("usage");
if (!loc || !usage) {
return false;
}
v.memoryUsagePerLoc[loc->str()] = *usage;
}

auto statistics = object->getArray("statistics");
if (!statistics) {
return false;
Expand Down

0 comments on commit 1d44d89

Please sign in to comment.