Skip to content

Commit

Permalink
feat(frontend): accept clear arguments in server.run
Browse files Browse the repository at this point in the history
  • Loading branch information
umut-sahin committed Aug 12, 2024
1 parent c511bd8 commit bb52dbb
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 7 deletions.
4 changes: 4 additions & 0 deletions docs/guides/deploy.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,10 @@ serialized_result: bytes = result.serialize()

Then, send the serialized result back to the client. After this, the client can decrypt to receive the result of the computation.

{% hint style="info" %}
Clear arguments can directly be passed to `server.run` (e.g., `server.run(x, 10, z, evaluation_keys=...)`).
{% endhint %}

## Decrypting the result (on the client)

Once you have received the serialized result of the computation from the server, you can deserialize it:
Expand Down
53 changes: 49 additions & 4 deletions frontends/concrete-python/concrete/fhe/compilation/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
import shutil
import tempfile
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple, Union
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union

# mypy: disable-error-code=attr-defined
import concrete.compiler
import jsonpickle
import numpy as np
from concrete.compiler import (
CompilationContext,
CompilationOptions,
Expand All @@ -23,6 +24,7 @@
ProgramCompilationFeedback,
PublicArguments,
ServerProgram,
SimulatedValueExporter,
set_compiler_logging,
set_llvm_debug_flag,
)
Expand All @@ -45,6 +47,7 @@
ParameterSelectionStrategy,
)
from .specs import ClientSpecs
from .utils import friendly_type_format
from .value import Value

# pylint: enable=import-error,no-member,no-name-in-module
Expand All @@ -68,6 +71,9 @@ class Server:
_configuration: Optional[Configuration]
_composition_rules: Optional[List[CompositionRule]]

_clear_input_indices: Dict[str, Set[int]]
_clear_input_shapes: Dict[str, Dict[int, Tuple[int, ...]]]

def __init__(
self,
client_specs: ClientSpecs,
Expand All @@ -89,6 +95,23 @@ def __init__(
self._mlir = None
self._composition_rules = composition_rules

self._clear_input_indices = {}
self._clear_input_shapes = {}

functions_parameters = json.loads(client_specs.client_parameters.serialize())["circuits"]
for function_parameters in functions_parameters:
name = function_parameters["name"]
self._clear_input_indices[name] = {
index
for index, input_spec in enumerate(function_parameters["inputs"])
if "plaintext" in input_spec["typeInfo"]
}
self._clear_input_shapes[name] = {
index: tuple(input_spec["rawInfo"]["shape"]["dimensions"])
for index, input_spec in enumerate(function_parameters["inputs"])
if "plaintext" in input_spec["typeInfo"]
}

assert_that(
support.load_client_parameters(compilation_result).serialize()
== client_specs.client_parameters.serialize()
Expand Down Expand Up @@ -416,10 +439,32 @@ def run(
raise ValueError(message)

if not isinstance(arg, Value):
message = f"Expected argument {i} to be an fhe.Value but it's {type(arg).__name__}"
raise ValueError(message)
if i not in self._clear_input_indices[function_name]:
message = (
f"Expected argument {i} to be an fhe.Value "
f"but it's {friendly_type_format(type(arg))}"
)
raise ValueError(message)

# Simulated value exporter can be used here
# as "clear" fhe.Values have the same
# internal representation as "simulation" fhe.Values

exporter = SimulatedValueExporter.new(
self.client_specs.client_parameters,
function_name,
)

if isinstance(arg, (int, np.integer)):
arg = exporter.export_scalar(i, arg)
else:
arg = np.array(arg)
arg = exporter.export_tensor(i, arg.flatten().tolist(), arg.shape)

buffers.append(arg.inner)
if isinstance(arg, Value):
buffers.append(arg.inner)
else:
buffers.append(arg)

public_args = PublicArguments.new(self.client_specs.client_parameters, buffers)
server_circuit = self._server_program.get_server_circuit(function_name)
Expand Down
12 changes: 9 additions & 3 deletions frontends/concrete-python/concrete/fhe/compilation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,16 @@ def validate_input_args(
Returns:
List[Optional[Union[int, np.ndarray]]]: ordered validated args
"""

functions_parameters = json.loads(client_specs.client_parameters.serialize())["circuits"]
client_parameters_json = next(
filter(lambda x: x["name"] == function_name, functions_parameters)
)
for function_parameters in functions_parameters:
if function_parameters["name"] == function_name:
client_parameters_json = function_parameters
break
else:
message = f"Function `{function_name}` is not in the module"
raise ValueError(message)

assert "inputs" in client_parameters_json
input_specs = client_parameters_json["inputs"]
if len(args) != len(input_specs):
Expand Down
35 changes: 35 additions & 0 deletions frontends/concrete-python/tests/compilation/test_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,41 @@ def function(x):
server.cleanup()


def test_client_server_api_run_with_clear(helpers):
"""
Test running server run API with a clear input.
"""

configuration = helpers.configuration()

@fhe.compiler({"x": "encrypted", "y": "clear", "z": "clear"})
def function(x, y, z):
return x + y + z

inputset = fhe.inputset(fhe.uint3, fhe.uint3, fhe.tensor[fhe.uint3, 2, 2]) # type: ignore
circuit = function.compile(inputset, configuration.fork())

client = circuit.client
server = circuit.server

x, y, z = 3, 2, [[1, 2], [3, 4]]

encrypted_x, _, _ = client.encrypt(x, None, None)
encrypted_result = server.run(encrypted_x, y, z, evaluation_keys=client.evaluation_keys)
result = client.decrypt(encrypted_result)
assert np.array_equal(result, x + y + np.array(z))

with pytest.raises(ValueError) as excinfo:
server.run(1, 2, 3, evaluation_keys=client.evaluation_keys)

assert str(excinfo.value) == "Expected argument 0 to be an fhe.Value but it's int"

with pytest.raises(RuntimeError) as excinfo:
server.run(encrypted_x, [2, 2], 3, evaluation_keys=client.evaluation_keys)

assert str(excinfo.value) == "Tried to transform plaintext value with incompatible shape."


def test_client_server_api_crt(helpers):
"""
Test client/server API on a CRT circuit.
Expand Down

0 comments on commit bb52dbb

Please sign in to comment.