-
Notifications
You must be signed in to change notification settings - Fork 145
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(compiler): compute memory usage per location
- Loading branch information
Showing
8 changed files
with
280 additions
and
1 deletion.
There are no files selected for viewing
64 changes: 64 additions & 0 deletions
64
...s/concrete-compiler/compiler/include/concretelang/Dialect/Concrete/Analysis/MemoryUsage.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama | ||
// Exceptions. See | ||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt | ||
// for license information. | ||
|
||
#ifndef CONCRETELANG_DIALECT_CONCRETE_MEMORY_USAGE_H | ||
#define CONCRETELANG_DIALECT_CONCRETE_MEMORY_USAGE_H | ||
|
||
#include <mlir/IR/BuiltinOps.h> | ||
#include <mlir/IR/Operation.h> | ||
#include <mlir/Pass/Pass.h> | ||
|
||
#include <concretelang/Support/CompilationFeedback.h> | ||
|
||
namespace mlir { | ||
namespace concretelang { | ||
namespace Concrete { | ||
|
||
struct MemoryUsagePass | ||
: public PassWrapper<MemoryUsagePass, OperationPass<ModuleOp>> { | ||
|
||
CompilationFeedback &feedback; | ||
|
||
MemoryUsagePass(CompilationFeedback &feedback) : feedback{feedback} {}; | ||
|
||
void runOnOperation() override { | ||
WalkResult walk = | ||
getOperation()->walk([&](Operation *op, const WalkStage &stage) { | ||
if (stage.isBeforeAllRegions()) { | ||
std::optional<StringError> error = this->enter(op); | ||
if (error.has_value()) { | ||
op->emitError() << error->mesg; | ||
return WalkResult::interrupt(); | ||
} | ||
} | ||
|
||
if (stage.isAfterAllRegions()) { | ||
std::optional<StringError> error = this->exit(op); | ||
if (error.has_value()) { | ||
op->emitError() << error->mesg; | ||
return WalkResult::interrupt(); | ||
} | ||
} | ||
|
||
return WalkResult::advance(); | ||
}); | ||
|
||
if (walk.wasInterrupted()) { | ||
signalPassFailure(); | ||
} | ||
} | ||
|
||
std::optional<StringError> enter(Operation *op); | ||
|
||
std::optional<StringError> exit(Operation *op); | ||
|
||
size_t iterations = 1; | ||
}; | ||
|
||
} // namespace Concrete | ||
} // namespace concretelang | ||
} // namespace mlir | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
12 changes: 12 additions & 0 deletions
12
compilers/concrete-compiler/compiler/lib/Dialect/Concrete/Analysis/CMakeLists.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
add_mlir_library( | ||
ConcreteDialectAnalysis | ||
MemoryUsage.cpp | ||
ADDITIONAL_HEADER_DIRS | ||
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/Concrete | ||
DEPENDS | ||
ConcreteDialect | ||
mlir-headers | ||
LINK_LIBS | ||
PUBLIC | ||
MLIRIR | ||
ConcreteDialect) |
171 changes: 171 additions & 0 deletions
171
compilers/concrete-compiler/compiler/lib/Dialect/Concrete/Analysis/MemoryUsage.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,171 @@ | ||
#include <concretelang/Dialect/Concrete/Analysis/MemoryUsage.h> | ||
#include <concretelang/Dialect/Concrete/IR/ConcreteOps.h> | ||
#include <concretelang/Support/logging.h> | ||
#include <mlir/Dialect/Arith/IR/Arith.h> | ||
#include <mlir/Dialect/MemRef/IR/MemRef.h> | ||
#include <mlir/Dialect/SCF/IR/SCF.h> | ||
#include <numeric> | ||
|
||
using namespace mlir::concretelang; | ||
using namespace mlir; | ||
|
||
using Concrete::MemoryUsagePass; | ||
|
||
namespace { | ||
|
||
template <typename Op> std::string locationOf(Op op) { | ||
auto location = std::string(); | ||
auto locationStream = llvm::raw_string_ostream(location); | ||
op.getLoc()->print(locationStream); | ||
return location; | ||
} | ||
|
||
int64_t calculateNumberOfIterations(int64_t start, int64_t stop, int64_t step) { | ||
int64_t high; | ||
int64_t low; | ||
|
||
if (step > 0) { | ||
low = start; | ||
high = stop; | ||
} else { | ||
low = stop; | ||
high = start; | ||
step = -step; | ||
} | ||
|
||
if (low >= high) { | ||
return 0; | ||
} | ||
|
||
return ((high - low - 1) / step) + 1; | ||
} | ||
|
||
std::optional<StringError> calculateNumberOfIterations(scf::ForOp &op, | ||
int64_t &result) { | ||
mlir::Value startValue = op.getLowerBound(); | ||
mlir::Value stopValue = op.getUpperBound(); | ||
mlir::Value stepValue = op.getStep(); | ||
|
||
auto startOp = | ||
llvm::dyn_cast_or_null<arith::ConstantOp>(startValue.getDefiningOp()); | ||
auto stopOp = | ||
llvm::dyn_cast_or_null<arith::ConstantOp>(stopValue.getDefiningOp()); | ||
auto stepOp = | ||
llvm::dyn_cast_or_null<arith::ConstantOp>(stepValue.getDefiningOp()); | ||
|
||
if (!startOp || !stopOp || !stepOp) { | ||
return StringError("only static loops can be analyzed"); | ||
} | ||
|
||
auto startAttr = startOp.getValue().cast<mlir::IntegerAttr>(); | ||
auto stopAttr = stopOp.getValue().cast<mlir::IntegerAttr>(); | ||
auto stepAttr = stepOp.getValue().cast<mlir::IntegerAttr>(); | ||
|
||
if (!startOp || !stopOp || !stepOp) { | ||
return StringError("only integer loops can be analyzed"); | ||
} | ||
|
||
int64_t start = startAttr.getInt(); | ||
int64_t stop = stopAttr.getInt(); | ||
int64_t step = stepAttr.getInt(); | ||
|
||
result = calculateNumberOfIterations(start, stop, step); | ||
return std::nullopt; | ||
} | ||
|
||
static std::optional<StringError> on_enter(scf::ForOp &op, | ||
MemoryUsagePass &pass) { | ||
int64_t numberOfIterations; | ||
|
||
std::optional<StringError> error = | ||
calculateNumberOfIterations(op, numberOfIterations); | ||
if (error.has_value()) { | ||
return error; | ||
} | ||
|
||
assert(numberOfIterations > 0); | ||
pass.iterations *= (uint64_t)numberOfIterations; | ||
return std::nullopt; | ||
} | ||
|
||
static std::optional<StringError> on_exit(scf::ForOp &op, | ||
MemoryUsagePass &pass) { | ||
int64_t numberOfIterations; | ||
|
||
std::optional<StringError> error = | ||
calculateNumberOfIterations(op, numberOfIterations); | ||
if (error.has_value()) { | ||
return error; | ||
} | ||
|
||
assert(numberOfIterations > 0); | ||
pass.iterations /= (uint64_t)numberOfIterations; | ||
return std::nullopt; | ||
} | ||
|
||
int64_t getElementTypeSize(mlir::Type elementType) { | ||
if (auto integerType = mlir::dyn_cast<mlir::IntegerType>(elementType)) { | ||
auto width = integerType.getWidth(); | ||
return std::ceil(width / 8); | ||
} | ||
return -1; | ||
} | ||
|
||
static std::optional<StringError> on_enter(memref::AllocOp &op, | ||
MemoryUsagePass &pass) { | ||
auto bufferType = op.getType(); | ||
auto shape = bufferType.getShape(); | ||
auto elementType = bufferType.getElementType(); | ||
auto elementSize = getElementTypeSize(elementType); | ||
if (elementSize == -1) | ||
return StringError( | ||
"allocation of buffer with a non-supported element-type"); | ||
for (auto size : shape) { | ||
if (size == mlir::ShapedType::kDynamic) { | ||
log_verbose() << "warning: dynamic dimension found during computation of " | ||
"memory usage. Dynamic size will be ignored"; | ||
} | ||
} | ||
auto multiply_ignore_dyn_size = [](int64_t size_1, int64_t size_2) { | ||
// we don't want to multiply by a dynamic size | ||
return size_2 == mlir::ShapedType::kDynamic ? size_1 : size_1 * size_2; | ||
}; | ||
|
||
auto location = locationOf(op); | ||
// pass.iterations number of allocation of size: shape_1 * ... * shape_n * | ||
// element_size | ||
auto memoryUsage = | ||
pass.iterations * elementSize * | ||
std::accumulate(shape.begin(), shape.end(), 1, multiply_ignore_dyn_size); | ||
|
||
pass.feedback.memoryUsagePerLoc[location] += memoryUsage; | ||
|
||
return std::nullopt; | ||
} | ||
} // namespace | ||
|
||
std::optional<StringError> MemoryUsagePass::enter(mlir::Operation *op) { | ||
if (auto typedOp = llvm::dyn_cast<scf::ForOp>(op)) { | ||
std::optional<StringError> error = on_enter(typedOp, *this); | ||
if (error.has_value()) { | ||
return error; | ||
} | ||
} | ||
if (auto typedOp = llvm::dyn_cast<memref::AllocOp>(op)) { | ||
std::optional<StringError> error = on_enter(typedOp, *this); | ||
if (error.has_value()) { | ||
return error; | ||
} | ||
} | ||
return std::nullopt; | ||
} | ||
|
||
std::optional<StringError> MemoryUsagePass::exit(mlir::Operation *op) { | ||
if (auto typedOp = llvm::dyn_cast<scf::ForOp>(op)) { | ||
std::optional<StringError> error = on_exit(typedOp, *this); | ||
if (error.has_value()) { | ||
return error; | ||
} | ||
} | ||
return std::nullopt; | ||
} |
1 change: 1 addition & 0 deletions
1
compilers/concrete-compiler/compiler/lib/Dialect/Concrete/CMakeLists.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
add_subdirectory(Analysis) | ||
add_subdirectory(IR) | ||
add_subdirectory(Transforms) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters