Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/python multi circuits #724

Merged
merged 1 commit into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,19 @@ using concretelang::serverlib::ServerProgram;
using concretelang::values::TransportValue;
using concretelang::values::Value;

#define GET_OR_THROW_LLVM_EXPECTED(VARNAME, EXPECTED) \
auto VARNAME = EXPECTED; \
if (auto err = VARNAME.takeError()) { \
throw std::runtime_error(llvm::toString(std::move(err))); \
}

#define CONCAT(a, b) CONCAT_INNER(a, b)
#define CONCAT_INNER(a, b) a##b

#define GET_OR_THROW_EXPECTED_(VARNAME, RESULT, MAYBE) \
auto MAYBE = RESULT; \
if (auto err = MAYBE.takeError()) { \
throw std::runtime_error(llvm::toString(std::move(err))); \
} \
VARNAME = std::move(*MAYBE);

#define GET_OR_THROW_EXPECTED(VARNAME, RESULT) \
GET_OR_THROW_EXPECTED_(VARNAME, RESULT, CONCAT(maybe, __COUNTER__))

#define GET_OR_THROW_RESULT_(VARNAME, RESULT, MAYBE) \
auto MAYBE = RESULT; \
if (MAYBE.has_failure()) { \
Expand Down Expand Up @@ -370,6 +374,13 @@ class LibrarySupport {
useSimulation};
}

llvm::Expected<ServerProgram>
loadServerProgram(LibraryCompilationResult &result, bool useSimulation) {
EXPECTED_TRY(auto programInfo, getProgramInfo());
return outcomeToExpected(ServerProgram::load(
programInfo.asReader(), getSharedLibPath(), useSimulation));
}

/// Load the client parameters from the compilation result.
llvm::Expected<::concretelang::clientlib::ClientParameters>
loadClientParameters(LibraryCompilationResult &result) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class ServerCircuit {
Result<std::vector<TransportValue>> call(const ServerKeyset &serverKeyset,
std::vector<TransportValue> &args);

/// Simulate the circuit with public arguments.
Result<std::vector<TransportValue>>
simulate(std::vector<TransportValue> &args);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ declare_mlir_python_sources(
concrete/compiler/parameter.py
concrete/compiler/public_arguments.py
concrete/compiler/public_result.py
concrete/compiler/server_circuit.py
concrete/compiler/server_program.py
concrete/compiler/evaluation_keys.py
concrete/compiler/simulated_value_decrypter.py
concrete/compiler/simulated_value_exporter.py
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "concretelang/Common/Keysets.h"
#include "concretelang/Dialect/FHE/IR/FHEOpsDialect.h.inc"
#include "concretelang/Runtime/DFRuntime.hpp"
#include "concretelang/ServerLib/ServerLib.h"
#include "concretelang/Support/logging.h"
#include <llvm/Support/Debug.h>
#include <mlir-c/Bindings/Python/Interop.h>
Expand Down Expand Up @@ -79,65 +80,65 @@ library_compile(LibrarySupport_Py support, const char *module,
llvm::SourceMgr sm;
sm.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(module),
llvm::SMLoc());
GET_OR_THROW_LLVM_EXPECTED(compilationResult,
support.support.compile(sm, options));
return std::move(*compilationResult);
GET_OR_THROW_EXPECTED(auto compilationResult,
support.support.compile(sm, options));
return compilationResult;
}

std::unique_ptr<mlir::concretelang::LibraryCompilationResult>
library_compile_module(
LibrarySupport_Py support, mlir::ModuleOp module,
mlir::concretelang::CompilationOptions options,
std::shared_ptr<mlir::concretelang::CompilationContext> cctx) {
GET_OR_THROW_LLVM_EXPECTED(compilationResult,
support.support.compile(module, cctx, options));
return std::move(*compilationResult);
GET_OR_THROW_EXPECTED(auto compilationResult,
support.support.compile(module, cctx, options));
return compilationResult;
}

concretelang::clientlib::ClientParameters library_load_client_parameters(
LibrarySupport_Py support,
mlir::concretelang::LibraryCompilationResult &result) {
GET_OR_THROW_LLVM_EXPECTED(clientParameters,
support.support.loadClientParameters(result));
return *clientParameters;
GET_OR_THROW_EXPECTED(auto clientParameters,
support.support.loadClientParameters(result));
return clientParameters;
}

mlir::concretelang::ProgramCompilationFeedback
library_load_compilation_feedback(
LibrarySupport_Py support,
mlir::concretelang::LibraryCompilationResult &result) {
GET_OR_THROW_LLVM_EXPECTED(compilationFeedback,
support.support.loadCompilationFeedback(result));
return *compilationFeedback;
GET_OR_THROW_EXPECTED(auto compilationFeedback,
support.support.loadCompilationFeedback(result));
return compilationFeedback;
}

concretelang::serverlib::ServerLambda
library_load_server_lambda(LibrarySupport_Py support,
mlir::concretelang::LibraryCompilationResult &result,
std::string circuitName, bool useSimulation) {
GET_OR_THROW_LLVM_EXPECTED(
serverLambda,
GET_OR_THROW_EXPECTED(
auto serverLambda,
support.support.loadServerLambda(result, circuitName, useSimulation));
return *serverLambda;
return serverLambda;
}

std::unique_ptr<concretelang::clientlib::PublicResult>
library_server_call(LibrarySupport_Py support,
concretelang::serverlib::ServerLambda lambda,
concretelang::clientlib::PublicArguments &args,
concretelang::clientlib::EvaluationKeys &evaluationKeys) {
GET_OR_THROW_LLVM_EXPECTED(
publicResult, support.support.serverCall(lambda, args, evaluationKeys));
return std::move(*publicResult);
GET_OR_THROW_EXPECTED(auto publicResult, support.support.serverCall(
lambda, args, evaluationKeys));
return publicResult;
}

std::unique_ptr<concretelang::clientlib::PublicResult>
library_simulate(LibrarySupport_Py support,
concretelang::serverlib::ServerLambda lambda,
concretelang::clientlib::PublicArguments &args) {
GET_OR_THROW_LLVM_EXPECTED(publicResult,
support.support.simulate(lambda, args));
return std::move(*publicResult);
GET_OR_THROW_EXPECTED(auto publicResult,
support.support.simulate(lambda, args));
return publicResult;
}

std::string library_get_shared_lib_path(LibrarySupport_Py support) {
Expand Down Expand Up @@ -1190,6 +1191,48 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
return pybind11::bytes(valueSerialize(value));
});

pybind11::class_<ServerProgram>(m, "ServerProgram")
.def_static("load",
[](LibrarySupport_Py &support, bool useSimulation) {
GET_OR_THROW_EXPECTED(auto programInfo,
support.support.getProgramInfo());
auto sharedLibPath = support.support.getSharedLibPath();
GET_OR_THROW_RESULT(
auto result,
ServerProgram::load(programInfo.asReader(),
sharedLibPath, useSimulation));
return result;
})
.def("get_server_circuit", [](ServerProgram &program,
const std::string &circuitName) {
GET_OR_THROW_RESULT(auto result, program.getServerCircuit(circuitName));
return result;
});

pybind11::class_<ServerCircuit>(m, "ServerCircuit")
.def("call",
[](ServerCircuit &circuit,
::concretelang::clientlib::PublicArguments &publicArguments,
::concretelang::clientlib::EvaluationKeys &evaluationKeys) {
SignalGuard signalGuard;
auto keyset = evaluationKeys.keyset;
auto values = publicArguments.values;
GET_OR_THROW_RESULT(auto output, circuit.call(keyset, values));
::concretelang::clientlib::PublicResult res{output};
return std::make_unique<::concretelang::clientlib::PublicResult>(
std::move(res));
})
.def("simulate",
[](ServerCircuit &circuit,
::concretelang::clientlib::PublicArguments &publicArguments) {
pybind11::gil_scoped_release release;
auto values = publicArguments.values;
GET_OR_THROW_RESULT(auto output, circuit.simulate(values));
::concretelang::clientlib::PublicResult res{output};
return std::make_unique<::concretelang::clientlib::PublicResult>(
std::move(res));
});

pybind11::class_<::concretelang::clientlib::ValueExporter>(m, "ValueExporter")
.def_static(
"create",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from .simulated_value_decrypter import SimulatedValueDecrypter
from .simulated_value_exporter import SimulatedValueExporter
from .parameter import Parameter
from .server_program import ServerProgram


def init_dfr():
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
# See https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt for license information.

"""ServerCircuit."""

# pylint: disable=no-name-in-module,import-error
from mlir._mlir_libs._concretelang._compiler import (
ServerCircuit as _ServerCircuit,
)

# pylint: enable=no-name-in-module,import-error
from .wrapper import WrapperCpp
from .public_arguments import PublicArguments
from .public_result import PublicResult
from .evaluation_keys import EvaluationKeys


class ServerCircuit(WrapperCpp):
"""ServerCircuit references a circuit that can be called for execution and simulation."""

def __init__(self, server_circuit: _ServerCircuit):
"""Wrap the native Cpp object.
Args:
server_circuit (_ServerCircuit): object to wrap
Raises:
TypeError: if server_circuit is not of type _ServerCircuit
"""
if not isinstance(server_circuit, _ServerCircuit):
raise TypeError(
f"server_circuit must be of type _ServerCircuit, not {type(server_circuit)}"
)
super().__init__(server_circuit)

def call(
self,
public_arguments: PublicArguments,
evaluation_keys: EvaluationKeys,
) -> PublicResult:
"""Executes the circuit on the public arguments.
Args:
public_arguments (PublicArguments): public arguments to execute on
execution_keys (EvaluationKeys): evaluation keys to use for execution.
Raises:
TypeError: if public_arguments is not of type PublicArguments, or if evaluation_keys is
not of type EvaluationKeys
Returns:
PublicResult: A public result object containing the results.
"""
if not isinstance(public_arguments, PublicArguments):
raise TypeError(
f"public_arguments must be of type PublicArguments, not "
f"{type(public_arguments)}"
)
if not isinstance(evaluation_keys, EvaluationKeys):
raise TypeError(
f"simulation must be of type EvaluationKeys, not "
f"{type(evaluation_keys)}"
)
return PublicResult.wrap(
self.cpp().call(public_arguments.cpp(), evaluation_keys.cpp())
)

def simulate(
self,
public_arguments: PublicArguments,
) -> PublicResult:
"""Simulates the circuit on the public arguments.
Args:
public_arguments (PublicArguments): public arguments to execute on
Raises:
TypeError: if public_arguments is not of type PublicArguments
Returns:
PublicResult: A public result object containing the results.
"""
if not isinstance(public_arguments, PublicArguments):
raise TypeError(
f"public_arguments must be of type PublicArguments, not "
f"{type(public_arguments)}"
)
return PublicResult.wrap(self.cpp().simulate(public_arguments.cpp()))
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
# See https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt for license information.

"""ServerProgram."""

# pylint: disable=no-name-in-module,import-error
from mlir._mlir_libs._concretelang._compiler import (
ServerProgram as _ServerProgram,
)

# pylint: enable=no-name-in-module,import-error
from .wrapper import WrapperCpp
from .library_support import LibrarySupport
from .server_circuit import ServerCircuit


class ServerProgram(WrapperCpp):
"""ServerProgram references compiled circuit objects."""

def __init__(self, server_program: _ServerProgram):
"""Wrap the native Cpp object.
Args:
server_program (_ServerProgram): object to wrap
Raises:
TypeError: if server_program is not of type _ServerProgram
"""
if not isinstance(server_program, _ServerProgram):
raise TypeError(
f"server_program must be of type _ServerProgram, not {type(server_program)}"
)
super().__init__(server_program)

@staticmethod
def load(
library_support: LibrarySupport,
simulation: bool,
) -> "ServerProgram":
"""Loads the server program from a library support.
Args:
library_support (LibrarySupport): library support
simulation (bool): use simulation for execution
Raises:
TypeError: if library_support is not of type LibrarySupport, or if simulation is not of type bool
Returns:
ServerProgram: A server program object containing references to circuits for calls.
"""
if not isinstance(library_support, LibrarySupport):
raise TypeError(
f"library_support must be of type LibrarySupport, not "
f"{type(library_support)}"
)
if not isinstance(simulation, bool):
raise TypeError(
f"simulation must be of type bool, not " f"{type(simulation)}"
)
return ServerProgram.wrap(
_ServerProgram.load(library_support.cpp(), simulation)
)

def get_server_circuit(self, circuit_name: str) -> ServerCircuit:
"""Returns a given circuit if it is part of the program.
Args:
circuit_name (str): name of the circuit to retrieve.
Raises:
TypeError: if circuit_name is not of type str
RuntimeError: if the circuit is not part of the program
"""
if not isinstance(circuit_name, str):
raise TypeError(
f"circuit_name must be of type str, not {type(circuit_name)}"
)

return ServerCircuit.wrap(self.cpp().get_server_circuit(circuit_name))
2 changes: 1 addition & 1 deletion frontends/concrete-python/concrete/fhe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
Value,
inputset,
)
from .compilation.decorators import circuit, compiler
from .compilation.decorators import circuit, compiler, function, module
from .dtypes import Integer
from .extensions import (
AutoRounder,
Expand Down
Loading
Loading