Skip to content

Commit

Permalink
refactor(compiler): reorganize passes and add memory usage pass
Browse files Browse the repository at this point in the history
  • Loading branch information
youben11 committed Aug 7, 2023
1 parent 2baed92 commit eb9e646
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ mlir::LogicalResult
lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass);

mlir::LogicalResult
computeMemoryUsage(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
CompilationFeedback &feedback);

mlir::LogicalResult
lowerConcreteLinalgToLoops(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
Expand All @@ -100,18 +105,22 @@ mlir::LogicalResult extractSDFGOps(mlir::MLIRContext &context,
bool unrollLoops);

mlir::LogicalResult
lowerConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
bool simulation);
addRuntimeContext(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
bool simulation);

mlir::LogicalResult
lowerSDFGToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass);

mlir::LogicalResult
lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
bool parallelizeLoops, bool gpu);
std::function<bool(mlir::Pass *)> enablePass);

mlir::LogicalResult lowerToStd(mlir::MLIRContext &context,
mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
bool parallelizeLoops, bool gpu);

mlir::LogicalResult optimizeLLVMModule(llvm::LLVMContext &llvmContext,
llvm::Module &module);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -521,13 +521,11 @@ CompilerEngine::compile(mlir::ModuleOp moduleOp, Target target,
return std::move(res);
}

// Concrete -> Canonical dialects
if (mlir::concretelang::pipeline::lowerConcreteToStd(
// Add runtime context in Concrete
if (mlir::concretelang::pipeline::addRuntimeContext(
mlirContext, module, enablePass, options.simulate)
.failed()) {
return StreamStringError(
"Lowering from Bufferized Concrete to canonical MLIR "
"dialects failed");
return StreamStringError("Adding Runtime Context failed");
}

// SDFG -> Canonical dialects
Expand All @@ -538,12 +536,27 @@ CompilerEngine::compile(mlir::ModuleOp moduleOp, Target target,
"Lowering from SDFG to canonical MLIR dialects failed");
}

// bufferize and related passes
if (mlir::concretelang::pipeline::lowerToStd(
mlirContext, module, enablePass, loopParallelize, options.emitGPUOps)
.failed()) {
return StreamStringError("Failed to lower to std");
}

if (target == Target::STD)
return std::move(res);

if (res.feedback) {
if (mlir::concretelang::pipeline::computeMemoryUsage(
mlirContext, module, this->enablePass, res.feedback.value())
.failed()) {
return StreamStringError("Computing memory usage failed");
}
}

// MLIR canonical dialects -> LLVM Dialect
if (mlir::concretelang::pipeline::lowerStdToLLVMDialect(
mlirContext, module, enablePass, loopParallelize, options.emitGPUOps)
if (mlir::concretelang::pipeline::lowerStdToLLVMDialect(mlirContext, module,
enablePass)
.failed()) {
return StreamStringError("Failed to lower to LLVM dialect");
}
Expand Down
41 changes: 32 additions & 9 deletions compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include "concretelang/Support/CompilerEngine.h"
#include "concretelang/Support/Error.h"
#include <concretelang/Conversion/Passes.h>
#include <concretelang/Dialect/Concrete/Analysis/MemoryUsage.h>
#include <concretelang/Dialect/Concrete/Transforms/Passes.h>
#include <concretelang/Dialect/FHE/Analysis/ConcreteOptimizer.h>
#include <concretelang/Dialect/FHE/Analysis/MANP.h>
Expand Down Expand Up @@ -349,6 +350,19 @@ lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
return pm.run(module.getOperation());
}

mlir::LogicalResult
computeMemoryUsage(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
CompilationFeedback &feedback) {
mlir::PassManager pm(&context);
pipelinePrinting("Computing Memory Usage", pm, context);

addPotentiallyNestedPass(
pm, std::make_unique<Concrete::MemoryUsagePass>(feedback), enablePass);

return pm.run(module.getOperation());
}

mlir::LogicalResult optimizeTFHE(mlir::MLIRContext &context,
mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass) {
Expand Down Expand Up @@ -385,11 +399,11 @@ mlir::LogicalResult extractSDFGOps(mlir::MLIRContext &context,
}

mlir::LogicalResult
lowerConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
bool simulation) {
addRuntimeContext(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
bool simulation) {
mlir::PassManager pm(&context);
pipelinePrinting("ConcreteToStd", pm, context);
pipelinePrinting("Adding Runtime Context", pm, context);
if (!simulation) {
addPotentiallyNestedPass(pm, mlir::concretelang::createAddRuntimeContext(),
enablePass);
Expand All @@ -408,12 +422,12 @@ lowerSDFGToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
return pm.run(module.getOperation());
}

mlir::LogicalResult
lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
bool parallelizeLoops, bool gpu) {
mlir::LogicalResult lowerToStd(mlir::MLIRContext &context,
mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
bool parallelizeLoops, bool gpu) {
mlir::PassManager pm(&context);
pipelinePrinting("StdToLLVM", pm, context);
pipelinePrinting("Lowering to Std", pm, context);

// Bufferize
mlir::bufferization::OneShotBufferizationOptions bufferizationOptions;
Expand Down Expand Up @@ -472,6 +486,15 @@ lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
addPotentiallyNestedPass(
pm, mlir::concretelang::createConvertTracingToCAPIPass(), enablePass);

return pm.run(module);
}

mlir::LogicalResult
lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass) {
mlir::PassManager pm(&context);
pipelinePrinting("StdToLLVM", pm, context);

// Convert to MLIR LLVM Dialect
addPotentiallyNestedPass(
pm, mlir::concretelang::createConvertMLIRLowerableDialectsToLLVMPass(),
Expand Down

0 comments on commit eb9e646

Please sign in to comment.