Skip to content

Commit

Permalink
feat(frontend-python): add support for multi-circuits
Browse files Browse the repository at this point in the history
  • Loading branch information
aPere3 committed Mar 5, 2024
1 parent 05bd8cc commit 36a56f6
Show file tree
Hide file tree
Showing 22 changed files with 2,024 additions and 342 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,19 @@ using concretelang::serverlib::ServerProgram;
using concretelang::values::TransportValue;
using concretelang::values::Value;

#define GET_OR_THROW_LLVM_EXPECTED(VARNAME, EXPECTED) \
auto VARNAME = EXPECTED; \
if (auto err = VARNAME.takeError()) { \
throw std::runtime_error(llvm::toString(std::move(err))); \
}

#define CONCAT(a, b) CONCAT_INNER(a, b)
#define CONCAT_INNER(a, b) a##b

#define GET_OR_THROW_EXPECTED_(VARNAME, RESULT, MAYBE) \
auto MAYBE = RESULT; \
if (auto err = MAYBE.takeError()) { \
throw std::runtime_error(llvm::toString(std::move(err))); \
} \
VARNAME = std::move(*MAYBE);

#define GET_OR_THROW_EXPECTED(VARNAME, RESULT) \
GET_OR_THROW_EXPECTED_(VARNAME, RESULT, CONCAT(maybe, __COUNTER__))

#define GET_OR_THROW_RESULT_(VARNAME, RESULT, MAYBE) \
auto MAYBE = RESULT; \
if (MAYBE.has_failure()) { \
Expand Down Expand Up @@ -370,6 +374,13 @@ class LibrarySupport {
useSimulation};
}

llvm::Expected<ServerProgram>
loadServerProgram(LibraryCompilationResult &result, bool useSimulation) {
EXPECTED_TRY(auto programInfo, getProgramInfo());
return outcomeToExpected(ServerProgram::load(
programInfo.asReader(), getSharedLibPath(), useSimulation));
}

/// Load the client parameters from the compilation result.
llvm::Expected<::concretelang::clientlib::ClientParameters>
loadClientParameters(LibraryCompilationResult &result) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class ServerCircuit {
Result<std::vector<TransportValue>> call(const ServerKeyset &serverKeyset,
std::vector<TransportValue> &args);

/// Simulate the circuit with public arguments.
Result<std::vector<TransportValue>>
simulate(std::vector<TransportValue> &args);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ declare_mlir_python_sources(
concrete/compiler/parameter.py
concrete/compiler/public_arguments.py
concrete/compiler/public_result.py
concrete/compiler/server_circuit.py
concrete/compiler/server_program.py
concrete/compiler/evaluation_keys.py
concrete/compiler/simulated_value_decrypter.py
concrete/compiler/simulated_value_exporter.py
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "concretelang/Common/Keysets.h"
#include "concretelang/Dialect/FHE/IR/FHEOpsDialect.h.inc"
#include "concretelang/Runtime/DFRuntime.hpp"
#include "concretelang/ServerLib/ServerLib.h"
#include "concretelang/Support/logging.h"
#include <llvm/Support/Debug.h>
#include <mlir-c/Bindings/Python/Interop.h>
Expand Down Expand Up @@ -79,65 +80,65 @@ library_compile(LibrarySupport_Py support, const char *module,
llvm::SourceMgr sm;
sm.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(module),
llvm::SMLoc());
GET_OR_THROW_LLVM_EXPECTED(compilationResult,
support.support.compile(sm, options));
return std::move(*compilationResult);
GET_OR_THROW_EXPECTED(auto compilationResult,
support.support.compile(sm, options));
return std::move(compilationResult);
}

std::unique_ptr<mlir::concretelang::LibraryCompilationResult>
library_compile_module(
LibrarySupport_Py support, mlir::ModuleOp module,
mlir::concretelang::CompilationOptions options,
std::shared_ptr<mlir::concretelang::CompilationContext> cctx) {
GET_OR_THROW_LLVM_EXPECTED(compilationResult,
support.support.compile(module, cctx, options));
return std::move(*compilationResult);
GET_OR_THROW_EXPECTED(auto compilationResult,
support.support.compile(module, cctx, options));
return std::move(compilationResult);
}

concretelang::clientlib::ClientParameters library_load_client_parameters(
LibrarySupport_Py support,
mlir::concretelang::LibraryCompilationResult &result) {
GET_OR_THROW_LLVM_EXPECTED(clientParameters,
support.support.loadClientParameters(result));
return *clientParameters;
GET_OR_THROW_EXPECTED(auto clientParameters,
support.support.loadClientParameters(result));
return clientParameters;
}

mlir::concretelang::ProgramCompilationFeedback
library_load_compilation_feedback(
LibrarySupport_Py support,
mlir::concretelang::LibraryCompilationResult &result) {
GET_OR_THROW_LLVM_EXPECTED(compilationFeedback,
support.support.loadCompilationFeedback(result));
return *compilationFeedback;
GET_OR_THROW_EXPECTED(auto compilationFeedback,
support.support.loadCompilationFeedback(result));
return compilationFeedback;
}

concretelang::serverlib::ServerLambda
library_load_server_lambda(LibrarySupport_Py support,
mlir::concretelang::LibraryCompilationResult &result,
std::string circuitName, bool useSimulation) {
GET_OR_THROW_LLVM_EXPECTED(
serverLambda,
GET_OR_THROW_EXPECTED(
auto serverLambda,
support.support.loadServerLambda(result, circuitName, useSimulation));
return *serverLambda;
return serverLambda;
}

std::unique_ptr<concretelang::clientlib::PublicResult>
library_server_call(LibrarySupport_Py support,
concretelang::serverlib::ServerLambda lambda,
concretelang::clientlib::PublicArguments &args,
concretelang::clientlib::EvaluationKeys &evaluationKeys) {
GET_OR_THROW_LLVM_EXPECTED(
publicResult, support.support.serverCall(lambda, args, evaluationKeys));
return std::move(*publicResult);
GET_OR_THROW_EXPECTED(auto publicResult, support.support.serverCall(
lambda, args, evaluationKeys));
return std::move(publicResult);
}

std::unique_ptr<concretelang::clientlib::PublicResult>
library_simulate(LibrarySupport_Py support,
concretelang::serverlib::ServerLambda lambda,
concretelang::clientlib::PublicArguments &args) {
GET_OR_THROW_LLVM_EXPECTED(publicResult,
support.support.simulate(lambda, args));
return std::move(*publicResult);
GET_OR_THROW_EXPECTED(auto publicResult,
support.support.simulate(lambda, args));
return std::move(publicResult);
}

std::string library_get_shared_lib_path(LibrarySupport_Py support) {
Expand Down Expand Up @@ -1186,6 +1187,48 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
return pybind11::bytes(valueSerialize(value));
});

pybind11::class_<ServerProgram>(m, "ServerProgram")
.def_static("load",
[](LibrarySupport_Py &support, bool useSimulation) {
GET_OR_THROW_EXPECTED(auto programInfo,
support.support.getProgramInfo());
auto sharedLibPath = support.support.getSharedLibPath();
GET_OR_THROW_RESULT(
auto result,
ServerProgram::load(programInfo.asReader(),
sharedLibPath, useSimulation));
return result;
})
.def("get_server_circuit", [](ServerProgram &program,
const std::string &circuitName) {
GET_OR_THROW_RESULT(auto result, program.getServerCircuit(circuitName));
return result;
});

pybind11::class_<ServerCircuit>(m, "ServerCircuit")
.def("call",
[](ServerCircuit &circuit,
::concretelang::clientlib::PublicArguments &publicArguments,
::concretelang::clientlib::EvaluationKeys &evaluationKeys) {
SignalGuard signalGuard;
auto keyset = evaluationKeys.keyset;
auto values = publicArguments.values;
GET_OR_THROW_RESULT(auto output, circuit.call(keyset, values));
::concretelang::clientlib::PublicResult res{output};
return std::make_unique<::concretelang::clientlib::PublicResult>(
std::move(res));
})
.def("simulate",
[](ServerCircuit &circuit,
::concretelang::clientlib::PublicArguments &publicArguments) {
pybind11::gil_scoped_release release;
auto values = publicArguments.values;
GET_OR_THROW_RESULT(auto output, circuit.simulate(values));
::concretelang::clientlib::PublicResult res{output};
return std::make_unique<::concretelang::clientlib::PublicResult>(
std::move(res));
});

pybind11::class_<::concretelang::clientlib::ValueExporter>(m, "ValueExporter")
.def_static(
"create",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from .simulated_value_decrypter import SimulatedValueDecrypter
from .simulated_value_exporter import SimulatedValueExporter
from .parameter import Parameter
from .server_program import ServerProgram


def init_dfr():
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
# See https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt for license information.

"""ServerCircuit."""

# pylint: disable=no-name-in-module,import-error
from mlir._mlir_libs._concretelang._compiler import (
ServerCircuit as _ServerCircuit,
)

# pylint: enable=no-name-in-module,import-error
from .wrapper import WrapperCpp
from .public_arguments import PublicArguments
from .public_result import PublicResult
from .evaluation_keys import EvaluationKeys


class ServerCircuit(WrapperCpp):
"""ServerCircuit references a circuit that can be called for execution and simulation."""

def __init__(self, server_circuit: _ServerCircuit):
"""Wrap the native Cpp object.
Args:
server_circuit (_ServerCircuit): object to wrap
Raises:
TypeError: if server_circuit is not of type _ServerCircuit
"""
if not isinstance(server_circuit, _ServerCircuit):
raise TypeError(
f"server_circuit must be of type _ServerCircuit, not {type(server_circuit)}"
)
super().__init__(server_circuit)

def call(
self,
public_arguments: PublicArguments,
evaluation_keys: EvaluationKeys,
) -> PublicResult:
"""Executes the circuit on the public arguments.
Args:
public_arguments (PublicArguments): public arguments to execute on
execution_keys (EvaluationKeys): evaluation keys to use for execution.
Raises:
TypeError: if public_arguments is not of type PublicArguments, or if evaluation_keys is not of type EvaluationKeys
Returns:
PublicResult: A public result object containing the results.
"""
if not isinstance(public_arguments, PublicArguments):
raise TypeError(
f"public_arguments must be of type PublicArguments, not "
f"{type(public_arguments)}"
)
if not isinstance(evaluation_keys, EvaluationKeys):
raise TypeError(
f"simulation must be of type EvaluationKeys, not " f"{type(evaluation_keys)}"
)
return PublicResult.wrap(
self.cpp().call(
public_arguments.cpp(), evaluation_keys.cpp()
)
)

def simulate(
self,
public_arguments: PublicArguments,
) -> PublicResult:
"""Simulates the circuit on the public arguments.
Args:
public_arguments (PublicArguments): public arguments to execute on
Raises:
TypeError: if public_arguments is not of type PublicArguments
Returns:
PublicResult: A public result object containing the results.
"""
if not isinstance(public_arguments, PublicArguments):
raise TypeError(
f"public_arguments must be of type PublicArguments, not "
f"{type(public_arguments)}"
)
return PublicResult.wrap(
self.cpp().simulate(
public_arguments.cpp()
)
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
# See https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt for license information.

"""ServerProgram."""

# pylint: disable=no-name-in-module,import-error
from mlir._mlir_libs._concretelang._compiler import (
ServerProgram as _ServerProgram,
)

# pylint: enable=no-name-in-module,import-error
from .wrapper import WrapperCpp
from .library_support import LibrarySupport
from .server_circuit import ServerCircuit


class ServerProgram(WrapperCpp):
"""ServerProgram references compiled circuit objects."""

def __init__(self, server_program: _ServerProgram):
"""Wrap the native Cpp object.
Args:
server_program (_ServerProgram): object to wrap
Raises:
TypeError: if server_program is not of type _ServerProgram
"""
if not isinstance(server_program, _ServerProgram):
raise TypeError(
f"server_program must be of type _ServerProgram, not {type(server_program)}"
)
super().__init__(server_program)

def load(
library_support: LibrarySupport,
simulation: bool,
) -> "ServerProgram":
"""Loads the server program from a library support.
Args:
library_support (LibrarySupport): library support
simulation (bool): use simulation for execution
Raises:
TypeError: if library_support is not of type LibrarySupport, or if simulation is not of type bool
Returns:
ServerProgram: A server program object containing references to circuits for calls.
"""
if not isinstance(library_support, LibrarySupport):
raise TypeError(
f"library_support must be of type LibrarySupport, not "
f"{type(library_support)}"
)
if not isinstance(simulation, bool):
raise TypeError(
f"simulation must be of type bool, not " f"{type(simulation)}"
)
return ServerProgram.wrap(
_ServerProgram.load(
library_support.cpp(), simulation
)
)


def get_server_circuit(self, circuit_name: str) -> ServerCircuit:
"""Returns a given circuit if it is part of the program.
Args:
circuit_name (str): name of the circuit to retrieve.
Raises:
TypeError: if circuit_name is not of type str
RuntimeError: if the circuit is not part of the program
"""
if not isinstance(circuit_name, str):
raise TypeError(
f"circuit_name must be of type str, not {type(circuit_name)}"
)

return ServerCircuit.wrap(
self.cpp().get_server_circuit(circuit_name)
)
2 changes: 1 addition & 1 deletion frontends/concrete-python/concrete/fhe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
Server,
Value,
)
from .compilation.decorators import circuit, compiler
from .compilation.decorators import circuit, compiler, program
from .extensions import (
AutoRounder,
AutoTruncator,
Expand Down
Loading

0 comments on commit 36a56f6

Please sign in to comment.