Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compile in-memory MLIR Module #514

Merged
merged 2 commits into from
Jul 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ MLIR_CAPI_EXPORTED std::unique_ptr<mlir::concretelang::JitCompilationResult>
jit_compile(JITSupport_Py support, const char *module,
mlir::concretelang::CompilationOptions options);

MLIR_CAPI_EXPORTED std::unique_ptr<mlir::concretelang::JitCompilationResult>
jit_compile_module(
JITSupport_Py support, mlir::ModuleOp module,
mlir::concretelang::CompilationOptions options,
std::shared_ptr<mlir::concretelang::CompilationContext> cctx);

MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
jit_load_client_parameters(JITSupport_Py support,
mlir::concretelang::JitCompilationResult &);
Expand Down Expand Up @@ -76,6 +82,12 @@ MLIR_CAPI_EXPORTED std::unique_ptr<mlir::concretelang::LibraryCompilationResult>
library_compile(LibrarySupport_Py support, const char *module,
mlir::concretelang::CompilationOptions options);

MLIR_CAPI_EXPORTED std::unique_ptr<mlir::concretelang::LibraryCompilationResult>
library_compile_module(
LibrarySupport_Py support, mlir::ModuleOp module,
mlir::concretelang::CompilationOptions options,
std::shared_ptr<mlir::concretelang::CompilationContext> cctx);

MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
library_load_client_parameters(LibrarySupport_Py support,
mlir::concretelang::LibraryCompilationResult &);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,10 @@ class CompilerEngine {
compile(llvm::SourceMgr &sm, Target target,
std::optional<std::shared_ptr<Library>> lib = {});

llvm::Expected<CompilationResult>
compile(mlir::ModuleOp module, Target target,
std::optional<std::shared_ptr<Library>> lib = {});

llvm::Expected<CompilerEngine::Library>
compile(std::vector<std::string> inputs, std::string outputDirPath,
std::string runtimeLibraryPath = "", bool generateSharedLib = true,
Expand All @@ -298,6 +302,13 @@ class CompilerEngine {
bool generateCompilationFeedback = true,
bool generateCppHeader = true);

llvm::Expected<CompilerEngine::Library>
compile(mlir::ModuleOp module, std::string outputDirPath,
std::string runtimeLibraryPath = "", bool generateSharedLib = true,
bool generateStaticLib = true, bool generateClientParameters = true,
bool generateCompilationFeedback = true,
bool generateCppHeader = true);

void setCompilationOptions(CompilationOptions &options) {
compilerOptions = options;
if (options.v0FHEConstraints.has_value()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ class JITSupport

llvm::Expected<std::unique_ptr<JitCompilationResult>>
compile(llvm::SourceMgr &program, CompilationOptions options) override;
llvm::Expected<std::unique_ptr<JitCompilationResult>>
compile(mlir::ModuleOp program,
std::shared_ptr<mlir::concretelang::CompilationContext> cctx,
CompilationOptions options) override;
using LambdaSupport::compile;

llvm::Expected<std::shared_ptr<concretelang::JITLambda>>
Expand All @@ -63,6 +67,11 @@ class JITSupport
}

private:
template <typename T>
llvm::Expected<std::unique_ptr<JitCompilationResult>>
compileWithEngine(T program, CompilationOptions options,
concretelang::CompilerEngine &engine);

std::optional<std::string> runtimeLibPath;
llvm::function_ref<llvm::Error(llvm::Module *)> llvmOptPipeline;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,11 @@ template <typename Lambda, typename CompilationResult> class LambdaSupport {
llvm::SourceMgr &program,
CompilationOptions options = CompilationOptions("main")) = 0;

llvm::Expected<std::unique_ptr<CompilationResult>> virtual compile(
mlir::ModuleOp program,
std::shared_ptr<mlir::concretelang::CompilationContext> cctx,
CompilationOptions options = CompilationOptions("main")) = 0;

llvm::Expected<std::unique_ptr<CompilationResult>>
compile(llvm::StringRef program,
CompilationOptions options = CompilationOptions("main")) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,24 +52,17 @@ class LibrarySupport
auto context = CompilationContext::createShared();
concretelang::CompilerEngine engine(context);
engine.setCompilationOptions(options);
return compileWithEngine<llvm::SourceMgr &>(program, options, engine);
}

// Compile to a library
auto library = engine.compile(
program, outputPath, runtimeLibraryPath, generateSharedLib,
generateStaticLib, generateClientParameters,
generateCompilationFeedback, generateCppHeader);
if (auto err = library.takeError()) {
return std::move(err);
}

if (!options.clientParametersFuncName.has_value()) {
return StreamStringError("Need to have a funcname to compile library");
}

auto result = std::make_unique<LibraryCompilationResult>();
result->outputDirPath = outputPath;
result->funcName = *options.clientParametersFuncName;
return std::move(result);
llvm::Expected<std::unique_ptr<LibraryCompilationResult>>
compile(mlir::ModuleOp program,
std::shared_ptr<mlir::concretelang::CompilationContext> cctx,
CompilationOptions options) override {
// Setup the compiler engine
concretelang::CompilerEngine engine(cctx);
engine.setCompilationOptions(options);
return compileWithEngine<mlir::ModuleOp>(program, options, engine);
}
using LambdaSupport::compile;

Expand Down Expand Up @@ -169,6 +162,29 @@ class LibrarySupport
bool generateClientParameters;
bool generateCompilationFeedback;
bool generateCppHeader;

template <typename T>
llvm::Expected<std::unique_ptr<LibraryCompilationResult>>
compileWithEngine(T program, CompilationOptions options,
concretelang::CompilerEngine &engine) {
// Compile to a library
auto library = engine.compile(
program, outputPath, runtimeLibraryPath, generateSharedLib,
generateStaticLib, generateClientParameters,
generateCompilationFeedback, generateCppHeader);
if (auto err = library.takeError()) {
return std::move(err);
}

if (!options.clientParametersFuncName.has_value()) {
return StreamStringError("Need to have a funcname to compile library");
}

auto result = std::make_unique<LibraryCompilationResult>();
result->outputDirPath = outputPath;
result->funcName = *options.clientParametersFuncName;
return std::move(result);
}
};

} // namespace concretelang
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ declare_mlir_python_sources(
concrete/compiler/__init__.py
concrete/compiler/client_parameters.py
concrete/compiler/client_support.py
concrete/compiler/compilation_context.py
concrete/compiler/compilation_feedback.py
concrete/compiler/compilation_options.py
concrete/compiler/jit_compilation_result.py
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#include "concretelang/Dialect/FHE/IR/FHEOpsDialect.h.inc"
#include "concretelang/Support/JITSupport.h"
#include "concretelang/Support/Jit.h"
#include <mlir-c/Bindings/Python/Interop.h>
#include <mlir/CAPI/IR.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/ExecutionEngine/OptUtils.h>
Expand Down Expand Up @@ -129,6 +131,18 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
pybind11::class_<mlir::concretelang::JITLambda,
std::shared_ptr<mlir::concretelang::JITLambda>>(m,
"JITLambda");
pybind11::class_<mlir::concretelang::CompilationContext,
std::shared_ptr<mlir::concretelang::CompilationContext>>(
m, "CompilationContext")
.def(pybind11::init([]() {
return mlir::concretelang::CompilationContext::createShared();
}))
.def("mlir_context",
[](std::shared_ptr<mlir::concretelang::CompilationContext> cctx) {
auto mlirCtx = cctx->getMLIRContext();
return pybind11::reinterpret_steal<pybind11::object>(
mlirPythonContextToCapsule(wrap(mlirCtx)));
});
pybind11::class_<JITSupport_Py>(m, "JITSupport")
.def(pybind11::init([](std::string runtimeLibPath) {
return jit_support(runtimeLibPath);
Expand All @@ -139,6 +153,16 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
SignalGuard signalGuard;
return jit_compile(support, mlir_program.c_str(), options);
})
.def("compile",
[](JITSupport_Py &support, pybind11::object mlir_module,
CompilationOptions options,
std::shared_ptr<mlir::concretelang::CompilationContext> cctx) {
SignalGuard signalGuard;
return jit_compile_module(
support,
unwrap(mlirPythonCapsuleToModule(mlir_module.ptr())).clone(),
options, cctx);
})
.def("load_client_parameters",
[](JITSupport_Py &support,
mlir::concretelang::JitCompilationResult &result) {
Expand Down Expand Up @@ -191,6 +215,16 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
SignalGuard signalGuard;
return library_compile(support, mlir_program.c_str(), options);
})
.def("compile",
[](LibrarySupport_Py &support, pybind11::object mlir_module,
mlir::concretelang::CompilationOptions options,
std::shared_ptr<mlir::concretelang::CompilationContext> cctx) {
SignalGuard signalGuard;
return library_compile_module(
support,
unwrap(mlirPythonCapsuleToModule(mlir_module.ptr())).clone(),
options, cctx);
})
.def("load_client_parameters",
[](LibrarySupport_Py &support,
mlir::concretelang::LibraryCompilationResult &result) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,16 @@ jit_compile(JITSupport_Py support, const char *module,
return std::move(*compilationResult);
}

MLIR_CAPI_EXPORTED std::unique_ptr<mlir::concretelang::JitCompilationResult>
jit_compile_module(
JITSupport_Py support, mlir::ModuleOp module,
mlir::concretelang::CompilationOptions options,
std::shared_ptr<mlir::concretelang::CompilationContext> cctx) {
GET_OR_THROW_LLVM_EXPECTED(compilationResult,
support.support.compile(module, cctx, options));
return std::move(*compilationResult);
}

MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
jit_load_client_parameters(JITSupport_Py support,
mlir::concretelang::JitCompilationResult &result) {
Expand Down Expand Up @@ -88,6 +98,16 @@ library_compile(LibrarySupport_Py support, const char *module,
return std::move(*compilationResult);
}

MLIR_CAPI_EXPORTED std::unique_ptr<mlir::concretelang::LibraryCompilationResult>
library_compile_module(
LibrarySupport_Py support, mlir::ModuleOp module,
mlir::concretelang::CompilationOptions options,
std::shared_ptr<mlir::concretelang::CompilationContext> cctx) {
GET_OR_THROW_LLVM_EXPECTED(compilationResult,
support.support.compile(module, cctx, options));
return std::move(*compilationResult);
}

MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
library_load_client_parameters(
LibrarySupport_Py support,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# pylint: enable=no-name-in-module,import-error

from .compilation_options import CompilationOptions
from .compilation_context import CompilationContext
from .key_set_cache import KeySetCache
from .client_parameters import ClientParameters
from .compilation_feedback import CompilationFeedback
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
# See https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt for license information.

"""CompilationContext.

CompilationContext holds the MLIR Context supposed to be used during IR generation.
"""

# pylint: disable=no-name-in-module,import-error
from mlir._mlir_libs._concretelang._compiler import (
CompilationContext as _CompilationContext,
)
from mlir.ir import Context as MlirContext

# pylint: enable=no-name-in-module,import-error
from .wrapper import WrapperCpp


class CompilationContext(WrapperCpp):
"""Support class for compilation context.

CompilationContext is meant to outlive mlir_context().
Do not use the mlir_context after deleting the CompilationContext.
"""

def __init__(self, compilation_context: _CompilationContext):
"""Wrap the native Cpp object.

Args:
compilation_context (_CompilationContext): object to wrap

Raises:
TypeError: if compilation_context is not of type _CompilationContext
"""
if not isinstance(compilation_context, _CompilationContext):
raise TypeError(
f"compilation_context must be of type _CompilationContext, not "
f"{type(compilation_context)}"
)
super().__init__(compilation_context)

@staticmethod
# pylint: disable=arguments-differ
def new() -> "CompilationContext":
"""Build a CompilationContext.

Returns:
CompilationContext
"""
return CompilationContext.wrap(_CompilationContext())

def mlir_context(
self,
) -> MlirContext:
"""
Get the MLIR context used by the compilation context.

The Compilation Context should outlive the mlir_context.

Returns:
MlirContext: MLIR context of the compilation context
"""
# pylint: disable=protected-access
return MlirContext._CAPICreate(self.cpp().mlir_context())
# pylint: enable=protected-access
Loading