From ba46de03383936c02be0bb2b3f49c68b30d0779a Mon Sep 17 00:00:00 2001 From: Umut Date: Fri, 14 Jun 2024 13:08:23 +0300 Subject: [PATCH] feat(frontend-python): simulate encrypt run decrypt option --- docs/guides/configure.md | 5 ++ .../concrete/fhe/compilation/circuit.py | 9 ++++ .../concrete/fhe/compilation/configuration.py | 10 ++++ .../concrete/fhe/compilation/module.py | 27 +++++++++-- .../tests/compilation/test_circuit.py | 33 +++++++++++++ .../tests/compilation/test_modules.py | 47 +++++++++++++++++++ 6 files changed, 128 insertions(+), 3 deletions(-) diff --git a/docs/guides/configure.md b/docs/guides/configure.md index a4b74d42fb..a72240db65 100644 --- a/docs/guides/configure.md +++ b/docs/guides/configure.md @@ -146,3 +146,8 @@ Additional kwargs to `compile` functions take higher precedence. So if you set t * False disables it for all cases. * Integer value enables or disables it depending on the original bit width. * With the default value of 8, only the values with original bit width <= 8 will be converted to their original precision. +* **simulate\_encrypt\_run\_decrypt**: bool = False + * Whether to use simulate encrypt/run/decrypt methods of the circuit/module instead of doing the actual encryption/evaluation/decryption. + * When this option is set to `True`, encrypt and decrypt are identity functions, and run is a wrapper around simulation. In other words, this option allows to switch off the encryption to quickly test if a function has expected semantic (without paying the price of FHE execution). + * This is extremely unsafe and should only be used during development. + * For this reason, it requires **enable\_unsafe\_features** to be set to `True`. diff --git a/frontends/concrete-python/concrete/fhe/compilation/circuit.py b/frontends/concrete-python/concrete/fhe/compilation/circuit.py index 994c17cefd..2aadc0c27f 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/circuit.py +++ b/frontends/concrete-python/concrete/fhe/compilation/circuit.py @@ -249,6 +249,9 @@ def encrypt( encrypted argument(s) for evaluation """ + if self.configuration.simulate_encrypt_run_decrypt: + return args if len(args) != 1 else args[0] # type: ignore + if not hasattr(self, "client"): # pragma: no cover self.enable_fhe_execution() @@ -270,6 +273,9 @@ def run( result(s) of evaluation """ + if self.configuration.simulate_encrypt_run_decrypt: + return self.simulate(*args) + if not hasattr(self, "server"): # pragma: no cover self.enable_fhe_execution() @@ -292,6 +298,9 @@ def decrypt( decrypted result(s) of evaluation """ + if self.configuration.simulate_encrypt_run_decrypt: + return results if len(results) != 1 else results[0] # type: ignore + if not hasattr(self, "client"): # pragma: no cover self.enable_fhe_execution() diff --git a/frontends/concrete-python/concrete/fhe/compilation/configuration.py b/frontends/concrete-python/concrete/fhe/compilation/configuration.py index e8438984b8..91f88c13b3 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/configuration.py +++ b/frontends/concrete-python/concrete/fhe/compilation/configuration.py @@ -992,6 +992,7 @@ class Configuration: detect_overflow_in_simulation: bool dynamic_indexing_check_out_of_bounds: bool dynamic_assignment_check_out_of_bounds: bool + simulate_encrypt_run_decrypt: bool def __init__( self, @@ -1060,6 +1061,7 @@ def __init__( detect_overflow_in_simulation: bool = False, dynamic_indexing_check_out_of_bounds: bool = True, dynamic_assignment_check_out_of_bounds: bool = True, + simulate_encrypt_run_decrypt: bool = False, ): self.verbose = verbose self.compiler_debug_mode = compiler_debug_mode @@ -1165,6 +1167,8 @@ def __init__( self.dynamic_indexing_check_out_of_bounds = dynamic_indexing_check_out_of_bounds self.dynamic_assignment_check_out_of_bounds = dynamic_assignment_check_out_of_bounds + self.simulate_encrypt_run_decrypt = simulate_encrypt_run_decrypt + self._validate() class Keep: @@ -1239,6 +1243,7 @@ def fork( detect_overflow_in_simulation: Union[Keep, bool] = KEEP, dynamic_indexing_check_out_of_bounds: Union[Keep, bool] = KEEP, dynamic_assignment_check_out_of_bounds: Union[Keep, bool] = KEEP, + simulate_encrypt_run_decrypt: Union[Keep, bool] = KEEP, ) -> "Configuration": """ Get a new configuration from another one specified changes. @@ -1307,6 +1312,11 @@ def _validate(self): if self.use_insecure_key_cache: message = "Insecure key cache cannot be used without enabling unsafe features" raise RuntimeError(message) + if self.simulate_encrypt_run_decrypt: + message = ( + "Simulating encrypt/run/decrypt cannot be used without enabling unsafe features" + ) + raise RuntimeError(message) if self.use_insecure_key_cache and self.insecure_key_cache_location is None: message = "Insecure key cache cannot be enabled without specifying its location" diff --git a/frontends/concrete-python/concrete/fhe/compilation/module.py b/frontends/concrete-python/concrete/fhe/compilation/module.py index ee4dce215c..c10317276b 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/module.py +++ b/frontends/concrete-python/concrete/fhe/compilation/module.py @@ -54,11 +54,19 @@ class FheFunction: runtime: Union[ExecutionRt, SimulationRt] graph: Graph name: str + configuration: Configuration - def __init__(self, name: str, runtime: Union[ExecutionRt, SimulationRt], graph: Graph): + def __init__( + self, + name: str, + runtime: Union[ExecutionRt, SimulationRt], + graph: Graph, + configuration: Configuration, + ): self.name = name self.runtime = runtime self.graph = graph + self.configuration = configuration def __call__( self, @@ -169,6 +177,10 @@ def encrypt( Optional[Union[Value, Tuple[Optional[Value], ...]]]: encrypted argument(s) for evaluation """ + + if self.configuration.simulate_encrypt_run_decrypt: + return args if len(args) != 1 else args[0] # type: ignore + assert isinstance(self.runtime, ExecutionRt) return self.runtime.client.encrypt(*args, function_name=self.name) @@ -187,6 +199,10 @@ def run( Union[Value, Tuple[Value, ...]]: result(s) of evaluation """ + + if self.configuration.simulate_encrypt_run_decrypt: + return self.simulate(*args) + assert isinstance(self.runtime, ExecutionRt) return self.runtime.server.run( *args, evaluation_keys=self.runtime.client.evaluation_keys, function_name=self.name @@ -207,6 +223,10 @@ def decrypt( Optional[Union[int, np.ndarray, Tuple[Optional[Union[int, np.ndarray]], ...]]]: decrypted result(s) of evaluation """ + + if self.configuration.simulate_encrypt_run_decrypt: + return results if len(results) != 1 else results[0] # type: ignore + assert isinstance(self.runtime, ExecutionRt) return self.runtime.client.decrypt(*results, function_name=self.name) @@ -682,7 +702,8 @@ def functions(self) -> Dict[str, FheFunction]: Return a dictionnary containing all the functions of the module. """ return { - name: FheFunction(name, self.runtime, self.graphs[name]) for name in self.graphs.keys() + name: FheFunction(name, self.runtime, self.graphs[name], self.configuration) + for name in self.graphs.keys() } @property @@ -705,4 +726,4 @@ def __getattr__(self, item): if item not in list(self.graphs.keys()): error = f"No attribute {item}" raise AttributeError(error) - return FheFunction(item, self.runtime, self.graphs[item]) + return FheFunction(item, self.runtime, self.graphs[item], self.configuration) diff --git a/frontends/concrete-python/tests/compilation/test_circuit.py b/frontends/concrete-python/tests/compilation/test_circuit.py index 2ef2d3d9d3..8e6e5bc8fa 100644 --- a/frontends/concrete-python/tests/compilation/test_circuit.py +++ b/frontends/concrete-python/tests/compilation/test_circuit.py @@ -740,3 +740,36 @@ def f(x, y): circuit.keys = fhe.Keys.deserialize(keys2) result = circuit.decrypt(output) assert np.array_equal(result, (sample_x + sample_y) ** 2) + + +def test_simulate_encrypt_run_decrypt(helpers): + """ + Test `simulate_encrypt_run_decrypt` configuration option. + """ + + def f(x, y): + return x + y + + inputset = fhe.inputset(fhe.uint3, fhe.uint3) + configuration = helpers.configuration().fork( + fhe_execution=False, + fhe_simulation=True, + simulate_encrypt_run_decrypt=True, + ) + + compiler = fhe.Compiler(f, {"x": "encrypted", "y": "encrypted"}) + circuit = compiler.compile(inputset, configuration) + + sample_x, sample_y = 3, 4 + encrypted_x, encrypted_y = circuit.encrypt(sample_x, sample_y) + encrypted_result = circuit.run(encrypted_x, encrypted_y) + result = circuit.decrypt(encrypted_result) + + assert result == sample_x + sample_y + + # Make sure computation happened in simulation. + assert isinstance(encrypted_x, int) + assert isinstance(encrypted_y, int) + assert hasattr(circuit, "simulator") + assert not hasattr(circuit, "server") + assert isinstance(encrypted_result, int) diff --git a/frontends/concrete-python/tests/compilation/test_modules.py b/frontends/concrete-python/tests/compilation/test_modules.py index a820bcf331..da09ed2ebf 100644 --- a/frontends/concrete-python/tests/compilation/test_modules.py +++ b/frontends/concrete-python/tests/compilation/test_modules.py @@ -8,6 +8,7 @@ import pytest from concrete import fhe +from concrete.fhe.compilation.module import SimulationRt # pylint: disable=missing-class-docstring, missing-function-docstring, no-self-argument, unused-variable, no-member, unused-argument, function-redefined, expression-not-assigned # same disables for ruff: @@ -555,3 +556,49 @@ def c(x): assert module.b.decrypt(b_enc) == 20 c_enc = module.c.run(b_enc) assert module.c.decrypt(c_enc) == 40 + + +def test_simulate_encrypt_run_decrypt(helpers): + """ + Test `simulate_encrypt_run_decrypt` configuration option. + """ + + @fhe.module() + class Module: + @fhe.function({"x": "encrypted"}) + def inc(x): + return x + 1 % 20 + + @fhe.function({"x": "encrypted"}) + def dec(x): + return x - 1 % 20 + + inputset = [np.random.randint(1, 20, size=()) for _ in range(100)] + module = Module.compile( + {"inc": inputset, "dec": inputset}, + helpers.configuration().fork( + fhe_execution=False, + fhe_simulation=True, + simulate_encrypt_run_decrypt=True, + ), + ) + + sample_x = 10 + encrypted_x = module.inc.encrypt(sample_x) + + encrypted_result = module.inc.run(encrypted_x) + result = module.inc.decrypt(encrypted_result) + assert result == 11 + + # Make sure computation happened in simulation. + assert isinstance(encrypted_x, int) + assert isinstance(module.inc.runtime, SimulationRt) + assert isinstance(encrypted_result, int) + + encrypted_result = module.dec.run(encrypted_result) + result = module.dec.decrypt(encrypted_result) + assert result == 10 + + # Make sure computation happened in simulation. + assert isinstance(module.dec.runtime, SimulationRt) + assert isinstance(encrypted_result, int)