Skip to content

Commit

Permalink
feat(frontend-python): simulate encrypt run decrypt option
Browse files Browse the repository at this point in the history
  • Loading branch information
umut-sahin committed Jun 14, 2024
1 parent f2b6a83 commit ba46de0
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 3 deletions.
5 changes: 5 additions & 0 deletions docs/guides/configure.md
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,8 @@ Additional kwargs to `compile` functions take higher precedence. So if you set t
* False disables it for all cases.
* Integer value enables or disables it depending on the original bit width.
* With the default value of 8, only the values with original bit width <= 8 will be converted to their original precision.
* **simulate\_encrypt\_run\_decrypt**: bool = False
* Whether to use simulate encrypt/run/decrypt methods of the circuit/module instead of doing the actual encryption/evaluation/decryption.
* When this option is set to `True`, encrypt and decrypt are identity functions, and run is a wrapper around simulation. In other words, this option allows to switch off the encryption to quickly test if a function has expected semantic (without paying the price of FHE execution).
* This is extremely unsafe and should only be used during development.
* For this reason, it requires **enable\_unsafe\_features** to be set to `True`.
9 changes: 9 additions & 0 deletions frontends/concrete-python/concrete/fhe/compilation/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,9 @@ def encrypt(
encrypted argument(s) for evaluation
"""

if self.configuration.simulate_encrypt_run_decrypt:
return args if len(args) != 1 else args[0] # type: ignore

if not hasattr(self, "client"): # pragma: no cover
self.enable_fhe_execution()

Expand All @@ -270,6 +273,9 @@ def run(
result(s) of evaluation
"""

if self.configuration.simulate_encrypt_run_decrypt:
return self.simulate(*args)

if not hasattr(self, "server"): # pragma: no cover
self.enable_fhe_execution()

Expand All @@ -292,6 +298,9 @@ def decrypt(
decrypted result(s) of evaluation
"""

if self.configuration.simulate_encrypt_run_decrypt:
return results if len(results) != 1 else results[0] # type: ignore

if not hasattr(self, "client"): # pragma: no cover
self.enable_fhe_execution()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -992,6 +992,7 @@ class Configuration:
detect_overflow_in_simulation: bool
dynamic_indexing_check_out_of_bounds: bool
dynamic_assignment_check_out_of_bounds: bool
simulate_encrypt_run_decrypt: bool

def __init__(
self,
Expand Down Expand Up @@ -1060,6 +1061,7 @@ def __init__(
detect_overflow_in_simulation: bool = False,
dynamic_indexing_check_out_of_bounds: bool = True,
dynamic_assignment_check_out_of_bounds: bool = True,
simulate_encrypt_run_decrypt: bool = False,
):
self.verbose = verbose
self.compiler_debug_mode = compiler_debug_mode
Expand Down Expand Up @@ -1165,6 +1167,8 @@ def __init__(
self.dynamic_indexing_check_out_of_bounds = dynamic_indexing_check_out_of_bounds
self.dynamic_assignment_check_out_of_bounds = dynamic_assignment_check_out_of_bounds

self.simulate_encrypt_run_decrypt = simulate_encrypt_run_decrypt

self._validate()

class Keep:
Expand Down Expand Up @@ -1239,6 +1243,7 @@ def fork(
detect_overflow_in_simulation: Union[Keep, bool] = KEEP,
dynamic_indexing_check_out_of_bounds: Union[Keep, bool] = KEEP,
dynamic_assignment_check_out_of_bounds: Union[Keep, bool] = KEEP,
simulate_encrypt_run_decrypt: Union[Keep, bool] = KEEP,
) -> "Configuration":
"""
Get a new configuration from another one specified changes.
Expand Down Expand Up @@ -1307,6 +1312,11 @@ def _validate(self):
if self.use_insecure_key_cache:
message = "Insecure key cache cannot be used without enabling unsafe features"
raise RuntimeError(message)
if self.simulate_encrypt_run_decrypt:
message = (
"Simulating encrypt/run/decrypt cannot be used without enabling unsafe features"
)
raise RuntimeError(message)

if self.use_insecure_key_cache and self.insecure_key_cache_location is None:
message = "Insecure key cache cannot be enabled without specifying its location"
Expand Down
27 changes: 24 additions & 3 deletions frontends/concrete-python/concrete/fhe/compilation/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,19 @@ class FheFunction:
runtime: Union[ExecutionRt, SimulationRt]
graph: Graph
name: str
configuration: Configuration

def __init__(self, name: str, runtime: Union[ExecutionRt, SimulationRt], graph: Graph):
def __init__(
self,
name: str,
runtime: Union[ExecutionRt, SimulationRt],
graph: Graph,
configuration: Configuration,
):
self.name = name
self.runtime = runtime
self.graph = graph
self.configuration = configuration

def __call__(
self,
Expand Down Expand Up @@ -169,6 +177,10 @@ def encrypt(
Optional[Union[Value, Tuple[Optional[Value], ...]]]:
encrypted argument(s) for evaluation
"""

if self.configuration.simulate_encrypt_run_decrypt:
return args if len(args) != 1 else args[0] # type: ignore

assert isinstance(self.runtime, ExecutionRt)
return self.runtime.client.encrypt(*args, function_name=self.name)

Expand All @@ -187,6 +199,10 @@ def run(
Union[Value, Tuple[Value, ...]]:
result(s) of evaluation
"""

if self.configuration.simulate_encrypt_run_decrypt:
return self.simulate(*args)

assert isinstance(self.runtime, ExecutionRt)
return self.runtime.server.run(
*args, evaluation_keys=self.runtime.client.evaluation_keys, function_name=self.name
Expand All @@ -207,6 +223,10 @@ def decrypt(
Optional[Union[int, np.ndarray, Tuple[Optional[Union[int, np.ndarray]], ...]]]:
decrypted result(s) of evaluation
"""

if self.configuration.simulate_encrypt_run_decrypt:
return results if len(results) != 1 else results[0] # type: ignore

assert isinstance(self.runtime, ExecutionRt)
return self.runtime.client.decrypt(*results, function_name=self.name)

Expand Down Expand Up @@ -682,7 +702,8 @@ def functions(self) -> Dict[str, FheFunction]:
Return a dictionnary containing all the functions of the module.
"""
return {
name: FheFunction(name, self.runtime, self.graphs[name]) for name in self.graphs.keys()
name: FheFunction(name, self.runtime, self.graphs[name], self.configuration)
for name in self.graphs.keys()
}

@property
Expand All @@ -705,4 +726,4 @@ def __getattr__(self, item):
if item not in list(self.graphs.keys()):
error = f"No attribute {item}"
raise AttributeError(error)
return FheFunction(item, self.runtime, self.graphs[item])
return FheFunction(item, self.runtime, self.graphs[item], self.configuration)
33 changes: 33 additions & 0 deletions frontends/concrete-python/tests/compilation/test_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,3 +740,36 @@ def f(x, y):
circuit.keys = fhe.Keys.deserialize(keys2)
result = circuit.decrypt(output)
assert np.array_equal(result, (sample_x + sample_y) ** 2)


def test_simulate_encrypt_run_decrypt(helpers):
"""
Test `simulate_encrypt_run_decrypt` configuration option.
"""

def f(x, y):
return x + y

inputset = fhe.inputset(fhe.uint3, fhe.uint3)
configuration = helpers.configuration().fork(
fhe_execution=False,
fhe_simulation=True,
simulate_encrypt_run_decrypt=True,
)

compiler = fhe.Compiler(f, {"x": "encrypted", "y": "encrypted"})
circuit = compiler.compile(inputset, configuration)

sample_x, sample_y = 3, 4
encrypted_x, encrypted_y = circuit.encrypt(sample_x, sample_y)
encrypted_result = circuit.run(encrypted_x, encrypted_y)
result = circuit.decrypt(encrypted_result)

assert result == sample_x + sample_y

# Make sure computation happened in simulation.
assert isinstance(encrypted_x, int)
assert isinstance(encrypted_y, int)
assert hasattr(circuit, "simulator")
assert not hasattr(circuit, "server")
assert isinstance(encrypted_result, int)
47 changes: 47 additions & 0 deletions frontends/concrete-python/tests/compilation/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pytest

from concrete import fhe
from concrete.fhe.compilation.module import SimulationRt

# pylint: disable=missing-class-docstring, missing-function-docstring, no-self-argument, unused-variable, no-member, unused-argument, function-redefined, expression-not-assigned
# same disables for ruff:
Expand Down Expand Up @@ -555,3 +556,49 @@ def c(x):
assert module.b.decrypt(b_enc) == 20
c_enc = module.c.run(b_enc)
assert module.c.decrypt(c_enc) == 40


def test_simulate_encrypt_run_decrypt(helpers):
"""
Test `simulate_encrypt_run_decrypt` configuration option.
"""

@fhe.module()
class Module:
@fhe.function({"x": "encrypted"})
def inc(x):
return x + 1 % 20

@fhe.function({"x": "encrypted"})
def dec(x):
return x - 1 % 20

inputset = [np.random.randint(1, 20, size=()) for _ in range(100)]
module = Module.compile(
{"inc": inputset, "dec": inputset},
helpers.configuration().fork(
fhe_execution=False,
fhe_simulation=True,
simulate_encrypt_run_decrypt=True,
),
)

sample_x = 10
encrypted_x = module.inc.encrypt(sample_x)

encrypted_result = module.inc.run(encrypted_x)
result = module.inc.decrypt(encrypted_result)
assert result == 11

# Make sure computation happened in simulation.
assert isinstance(encrypted_x, int)
assert isinstance(module.inc.runtime, SimulationRt)
assert isinstance(encrypted_result, int)

encrypted_result = module.dec.run(encrypted_result)
result = module.dec.decrypt(encrypted_result)
assert result == 10

# Make sure computation happened in simulation.
assert isinstance(module.dec.runtime, SimulationRt)
assert isinstance(encrypted_result, int)

0 comments on commit ba46de0

Please sign in to comment.