Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compute memory usage per location #538

Merged
merged 3 commits into from
Aug 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.

#ifndef CONCRETELANG_DIALECT_CONCRETE_MEMORY_USAGE_H
#define CONCRETELANG_DIALECT_CONCRETE_MEMORY_USAGE_H

#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/Operation.h>
#include <mlir/Pass/Pass.h>

#include <concretelang/Support/CompilationFeedback.h>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So here and same reflexion for statistics, why not follow the common pass pattern?


namespace mlir {
namespace concretelang {
namespace Concrete {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This enter/exit pass could be factorized

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or even more the nested loop logic too


struct MemoryUsagePass
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Following the common pass pattern, this should be hidden in cpp and expose uniquely a createAnalysisConcreteMemoryUsage function.

: public PassWrapper<MemoryUsagePass, OperationPass<ModuleOp>> {

CompilationFeedback &feedback;

MemoryUsagePass(CompilationFeedback &feedback) : feedback{feedback} {};

void runOnOperation() override {
WalkResult walk =
getOperation()->walk([&](Operation *op, const WalkStage &stage) {
if (stage.isBeforeAllRegions()) {
std::optional<StringError> error = this->enter(op);
if (error.has_value()) {
op->emitError() << error->mesg;
return WalkResult::interrupt();
}
}

if (stage.isAfterAllRegions()) {
std::optional<StringError> error = this->exit(op);
if (error.has_value()) {
op->emitError() << error->mesg;
return WalkResult::interrupt();
}
}

return WalkResult::advance();
});

if (walk.wasInterrupted()) {
signalPassFailure();
}
}

std::optional<StringError> enter(Operation *op);

std::optional<StringError> exit(Operation *op);

std::map<std::string, std::vector<mlir::Value>> visitedValuesPerLoc;

size_t iterations = 1;
};

} // namespace Concrete
} // namespace concretelang
} // namespace mlir

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ struct CompilationFeedback {
/// @brief statistics
std::vector<Statistic> statistics;

/// @brief memory usage per location
std::map<std::string, int64_t> memoryUsagePerLoc;

/// Fill the sizes from the client parameters.
void
fillFromClientParameters(::concretelang::clientlib::ClientParameters params);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ mlir::LogicalResult
lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass);

mlir::LogicalResult
computeMemoryUsage(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
CompilationFeedback &feedback);

mlir::LogicalResult
lowerConcreteLinalgToLoops(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
Expand All @@ -100,18 +105,27 @@ mlir::LogicalResult extractSDFGOps(mlir::MLIRContext &context,
bool unrollLoops);

mlir::LogicalResult
lowerConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
bool simulation);
addRuntimeContext(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
bool simulation);

mlir::LogicalResult
lowerSDFGToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass);

mlir::LogicalResult
lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
bool parallelizeLoops, bool gpu);
std::function<bool(mlir::Pass *)> enablePass);

mlir::LogicalResult lowerToStd(mlir::MLIRContext &context,
mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
bool parallelizeLoops);

mlir::LogicalResult lowerToCAPI(mlir::MLIRContext &context,
mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
bool gpu);

mlir::LogicalResult optimizeLLVMModule(llvm::LLVMContext &llvmContext,
llvm::Module &module);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,10 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
"crt_decompositions_of_outputs",
&mlir::concretelang::CompilationFeedback::crtDecompositionsOfOutputs)
.def_readonly("statistics",
&mlir::concretelang::CompilationFeedback::statistics);
&mlir::concretelang::CompilationFeedback::statistics)
.def_readonly(
"memory_usage_per_location",
&mlir::concretelang::CompilationFeedback::memoryUsagePerLoc);

pybind11::class_<mlir::concretelang::JitCompilationResult>(
m, "JITCompilationResult");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(self, compilation_feedback: _CompilationFeedback):
compilation_feedback.crt_decompositions_of_outputs
)
self.statistics = compilation_feedback.statistics
self.memory_usage_per_location = compilation_feedback.memory_usage_per_location

super().__init__(compilation_feedback)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ const CONCRETE_COMPILER_STATIC_LIBS: &[&str] = &[
"ExtractSDFGOps",
"SDFGToStreamEmulator",
"TFHEDialectAnalysis",
"ConcreteDialectAnalysis",
];

fn main() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
add_mlir_library(
ConcreteDialectAnalysis
MemoryUsage.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/Concrete
DEPENDS
ConcreteDialect
mlir-headers
LINK_LIBS
PUBLIC
MLIRIR
ConcreteDialect)
Loading
Loading