From 4f5070abb14fd6474b7e838f934e1a0aa408e48b Mon Sep 17 00:00:00 2001 From: rudy Date: Mon, 29 Jul 2024 10:44:17 +0200 Subject: [PATCH] feat(frontend-python): module run are scheduled and parallelized in a worker pool --- docs/guides/configure.md | 8 + .../concrete/fhe/compilation/configuration.py | 5 + .../concrete/fhe/compilation/module.py | 137 +++++++++++++++++- .../tests/compilation/test_modules.py | 79 ++++++++++ 4 files changed, 221 insertions(+), 8 deletions(-) diff --git a/docs/guides/configure.md b/docs/guides/configure.md index 88e0b5b3f0..49e7470a12 100644 --- a/docs/guides/configure.md +++ b/docs/guides/configure.md @@ -210,3 +210,11 @@ When options are specified both in the `configuration` and as kwargs in the `com #### verbose: bool = False - Print details related to compilation. +#### auto_schedule_run: bool = False + - Enable automatic scheduling of `run` method calls. When enabled, fhe function are computated in parallel in a background threads pool. When several `run` are composed, they are automatically synchronized. + - For now, it only works for the `run` method of a `FheModule`, in that case you obtain a `Future[Value]` immediately instead of a `Value` when computation is finished. + - E.g. `my_module.f3.run( my_module.f1.run(a), my_module.f1.run(b) )` will runs `f1` and `f2` in parallel in the background and `f3` in background when both `f1` and `f2` intermediate results are available. + - If you want to manually synchronize on the termination of a full computation, e.g. you want to return the encrypted result, you can call explicitely `value.result()` to wait for the result. To simplify testing, decryption does it automatically. + - Automatic scheduling behavior can be override locally by calling directly a variant of `run`: + - `run_sync`: forces the fhe function to occur in the current thread, not in the background, + - `run_async`: forces the fhe function to occur in a background thread, returning immediately a `Future[Value]` diff --git a/frontends/concrete-python/concrete/fhe/compilation/configuration.py b/frontends/concrete-python/concrete/fhe/compilation/configuration.py index 421dda91d6..5cfe97b721 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/configuration.py +++ b/frontends/concrete-python/concrete/fhe/compilation/configuration.py @@ -997,6 +997,7 @@ class Configuration: composable: bool range_restriction: Optional[RangeRestriction] keyset_restriction: Optional[KeysetRestriction] + auto_schedule_run: bool def __init__( self, @@ -1068,6 +1069,7 @@ def __init__( simulate_encrypt_run_decrypt: bool = False, range_restriction: Optional[RangeRestriction] = None, keyset_restriction: Optional[KeysetRestriction] = None, + auto_schedule_run: bool = False, ): self.verbose = verbose self.compiler_debug_mode = compiler_debug_mode @@ -1177,6 +1179,8 @@ def __init__( self.range_restriction = range_restriction self.keyset_restriction = keyset_restriction + self.auto_schedule_run = auto_schedule_run + self._validate() class Keep: @@ -1254,6 +1258,7 @@ def fork( simulate_encrypt_run_decrypt: Union[Keep, bool] = KEEP, range_restriction: Union[Keep, Optional[RangeRestriction]] = KEEP, keyset_restriction: Union[Keep, Optional[KeysetRestriction]] = KEEP, + auto_schedule_run: Union[Keep, bool] = KEEP, ) -> "Configuration": """ Get a new configuration from another one specified changes. diff --git a/frontends/concrete-python/concrete/fhe/compilation/module.py b/frontends/concrete-python/concrete/fhe/compilation/module.py index 279c26dff6..161ba9f06a 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/module.py +++ b/frontends/concrete-python/concrete/fhe/compilation/module.py @@ -4,7 +4,10 @@ # pylint: disable=import-error,no-member,no-name-in-module +import asyncio +from concurrent.futures import Future, ThreadPoolExecutor from pathlib import Path +from threading import Thread from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union import numpy as np @@ -24,13 +27,40 @@ # pylint: enable=import-error,no-member,no-name-in-module -class ExecutionRt(NamedTuple): +class ExecutionRt: """ Runtime object class for execution. """ client: Client server: Server + auto_schedule_run: bool + fhe_executor_pool: ThreadPoolExecutor + fhe_waiter_loop: asyncio.BaseEventLoop + fhe_waiter_thread: Thread # daemon thread + + def __init__(self, client, server, auto_schedule_run): + self.client = client + self.server = server + self.auto_schedule_run = auto_schedule_run + if auto_schedule_run: + self.fhe_executor_pool = ThreadPoolExecutor() + self.fhe_waiter_loop = asyncio.new_event_loop() + + def loop_thread(): + asyncio.set_event_loop(self.fhe_waiter_loop) + self.fhe_waiter_loop.run_forever() + + self.fhe_waiter_thread = Thread(target=loop_thread, args=(), daemon=True) + self.fhe_waiter_thread.start() + else: + self.fhe_executor_pool = None + self.fhe_waiter_loop = None + self.fhe_waiter_thread = None + + def __del__(self): + if self.fhe_waiter_loop: + self.fhe_waiter_loop.stop() # daemon cleanup class SimulationRt(NamedTuple): @@ -177,10 +207,48 @@ def encrypt( return tuple(args) if len(args) > 1 else args[0] # type: ignore return self.execution_runtime.val.client.encrypt(*args, function_name=self.name) - def run( + def run_sync( self, *args: Optional[Union[Value, Tuple[Optional[Value], ...]]], ) -> Union[Value, Tuple[Value, ...]]: + """ + Evaluate the function synchronuously. + + Args: + *args (Value): + argument(s) for evaluation + + Returns: + Union[Value, Tuple[Value, ...]]: + result(s) of evaluation + """ + + return self._run(True, *args) + + def run_async( + self, *args: Optional[Union[Value, Tuple[Optional[Value], ...]]] + ) -> "Union[Future[Value], Future[Tuple[Value, ...]]]": + """ + Evaluate the function asynchronuously. + + Args: + *args (Value): + argument(s) for evaluation + + Returns: + Union[Value, Tuple[Value, ...]]: + result(s) of evaluation + """ + if isinstance(self.runtime, ExecutionRt) and not self.runtime.fhe_executor_pool: + self.runtime = ExecutionRt(self.runtime.client, self.runtime.server, True) + self.runtime.auto_schedule_run = False + + return self._run(False, *args) + + def run( + self, + *args: Optional[Union[Value, Tuple[Optional[Value], ...]]], + ) -> Union[Value, Tuple[Value, ...], Future]: """ Evaluate the function. @@ -192,15 +260,65 @@ def run( Union[Value, Tuple[Value, ...]]: result(s) of evaluation """ + if isinstance(self.runtime, ExecutionRt): + auto_schedule_run = self.runtime.auto_schedule_run + else: + auto_schedule_run = False + return self._run(not auto_schedule_run, *args) + + def _run( + self, + sync, + *args: Optional[Union[Value, Tuple[Optional[Value], ...]]], + ) -> Union[Value, Tuple[Value, ...], Future]: + """ + Evaluate the function. + Args: + *args (Value): + argument(s) for evaluation + + Returns: + Union[Value, Tuple[Value, ...]]: + result(s) of evaluation + """ if self.configuration.simulate_encrypt_run_decrypt: return self._simulate_decrypt(self._simulate_run(*args)) # type: ignore - return self.execution_runtime.val.server.run( - *args, - evaluation_keys=self.execution_runtime.val.client.evaluation_keys, - function_name=self.name, + + assert isinstance(self.runtime, ExecutionRt) + + fhe_work = lambda *args: self.execution_runtime.server.run( + *args, evaluation_keys=self.execution_runtime.client.evaluation_keys, function_name=self.name ) + def args_ready(args): + return [arg.result() if isinstance(arg, Future) else arg for arg in args] + + if sync: + return fhe_work(*args_ready(args)) + + all_args_done = all(not isinstance(arg, Future) or arg.done() for arg in args) + + fhe_work_future = lambda *args: self.runtime.fhe_executor_pool.submit(fhe_work, *args) + if all_args_done: + return fhe_work_future(*args_ready(args)) + + # waiting args to be ready with async coroutines + # it only required one thread to run unlimited waits vs unlimited sync threads + async def wait_async(arg): + if not isinstance(arg, Future): + return arg + if arg.done(): + return arg.result() + return await asyncio.wrap_future(arg, loop=self.runtime.fhe_waiter_loop) + + async def args_ready_and_submit(*args): + args = [await wait_async(arg) for arg in args] + return await wait_async(fhe_work_future(*args)) + + run_async = args_ready_and_submit(*args) + return asyncio.run_coroutine_threadsafe(run_async, self.runtime.fhe_waiter_loop) + def decrypt( self, *results: Union[Value, Tuple[Value, ...]], @@ -220,7 +338,9 @@ def decrypt( if self.configuration.simulate_encrypt_run_decrypt: return tuple(results) if len(results) > 1 else results[0] # type: ignore - return self.execution_runtime.val.client.decrypt(*results, function_name=self.name) + assert isinstance(self.runtime, ExecutionRt) + results = [res.result() if isinstance(res, Future) else res for res in results] + return self.execution_runtime.client.decrypt(*results, function_name=self.name) def encrypt_run_decrypt(self, *args: Any) -> Any: """ @@ -620,12 +740,13 @@ def init_execution(): execution_client = Client( execution_server.client_specs, keyset_cache_directory, is_simulated=False ) - return ExecutionRt(execution_client, execution_server) + return ExecutionRt(execution_client, execution_server, self.configuration.auto_schedule_run) self.execution_runtime = Lazy(init_execution) if configuration.fhe_execution: self.execution_runtime.init() + @property def mlir(self) -> str: """Textual representation of the MLIR module. diff --git a/frontends/concrete-python/tests/compilation/test_modules.py b/frontends/concrete-python/tests/compilation/test_modules.py index e3ea233c2d..f6e171f237 100644 --- a/frontends/concrete-python/tests/compilation/test_modules.py +++ b/frontends/concrete-python/tests/compilation/test_modules.py @@ -4,6 +4,7 @@ import inspect import tempfile +from concurrent.futures import Future from pathlib import Path import numpy as np @@ -952,3 +953,81 @@ def inc(x, y): }, helpers.configuration().fork(), ) + +class IncDec: + @fhe.module() + class Module: + @fhe.function({"x": "encrypted"}) + def inc(x): + return fhe.refresh(x + 1) + + @fhe.function({"x": "encrypted"}) + def dec(x): + return fhe.refresh(x - 1) + + precision = 4 + + inputset = list(range(1, 2**precision - 1)) + to_compile = {"inc": inputset, "dec": inputset} + + +def test_run_async(): + """ + Test `run_async` with `auto_schedule_run=False` configuration option. + """ + + module = IncDec.Module.compile(IncDec.to_compile) + + sample_x = 2 + encrypted_x = module.inc.encrypt(sample_x) + + a = module.inc.run_async(encrypted_x) + assert isinstance(a, Future) + + b = module.dec.run(a) + assert isinstance(b, type(encrypted_x)) + + result = module.inc.decrypt(b) + assert result == sample_x + + +def test_run_sync(): + """ + Test `run_sync` with `auto_schedule_run=True` configuration option. + """ + + conf = fhe.Configuration(auto_schedule_run=True) + module = IncDec.Module.compile(IncDec.to_compile, conf) + + sample_x = 2 + encrypted_x = module.inc.encrypt(sample_x) + + a = module.inc.run(encrypted_x) + assert isinstance(a, Future) + + b = module.dec.run_sync(a) + assert isinstance(b, type(encrypted_x)) + + result = module.inc.decrypt(b) + assert result == sample_x + + +def test_run_auto_schedule(): + """ + Test `run` with `auto_schedule_run=True` configuration option. + """ + + conf = fhe.Configuration(auto_schedule_run=True) + module = IncDec.Module.compile(IncDec.to_compile, conf) + + sample_x = 2 + encrypted_x = module.inc.encrypt(sample_x) + + a = module.inc.run(encrypted_x) + assert isinstance(a, Future) + + b = module.dec.run(a) + assert isinstance(b, Future) + + result = module.inc.decrypt(b) + assert result == sample_x