diff --git a/docs/guides/deploy.md b/docs/guides/deploy.md index 7f6f787cb8..b45e769e2f 100644 --- a/docs/guides/deploy.md +++ b/docs/guides/deploy.md @@ -130,6 +130,10 @@ serialized_result: bytes = result.serialize() Then, send the serialized result back to the client. After this, the client can decrypt to receive the result of the computation. +{% hint style="info" %} +Clear arguments can directly be passed to `server.run` (e.g., `server.run(x, 10, z, evaluation_keys=...)`). +{% endhint %} + ## Decrypting the result (on the client) Once you have received the serialized result of the computation from the server, you can deserialize it: diff --git a/frontends/concrete-python/concrete/fhe/compilation/server.py b/frontends/concrete-python/concrete/fhe/compilation/server.py index 47b415a3ae..55e68d4639 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/server.py +++ b/frontends/concrete-python/concrete/fhe/compilation/server.py @@ -8,11 +8,12 @@ import shutil import tempfile from pathlib import Path -from typing import Dict, Iterable, List, Optional, Tuple, Union +from typing import Dict, Iterable, List, Optional, Set, Tuple, Union # mypy: disable-error-code=attr-defined import concrete.compiler import jsonpickle +import numpy as np from concrete.compiler import ( CompilationContext, CompilationOptions, @@ -23,6 +24,7 @@ ProgramCompilationFeedback, PublicArguments, ServerProgram, + SimulatedValueExporter, set_compiler_logging, set_llvm_debug_flag, ) @@ -45,6 +47,7 @@ ParameterSelectionStrategy, ) from .specs import ClientSpecs +from .utils import friendly_type_format from .value import Value # pylint: enable=import-error,no-member,no-name-in-module @@ -68,6 +71,9 @@ class Server: _configuration: Optional[Configuration] _composition_rules: Optional[List[CompositionRule]] + _clear_input_indices: Dict[str, Set[int]] + _clear_input_shapes: Dict[str, Dict[int, Tuple[int, ...]]] + def __init__( self, client_specs: ClientSpecs, @@ -89,6 +95,23 @@ def __init__( self._mlir = None self._composition_rules = composition_rules + self._clear_input_indices = {} + self._clear_input_shapes = {} + + functions_parameters = json.loads(client_specs.client_parameters.serialize())["circuits"] + for function_parameters in functions_parameters: + name = function_parameters["name"] + self._clear_input_indices[name] = { + index + for index, input_spec in enumerate(function_parameters["inputs"]) + if "plaintext" in input_spec["typeInfo"] + } + self._clear_input_shapes[name] = { + index: tuple(input_spec["rawInfo"]["shape"]["dimensions"]) + for index, input_spec in enumerate(function_parameters["inputs"]) + if "plaintext" in input_spec["typeInfo"] + } + assert_that( support.load_client_parameters(compilation_result).serialize() == client_specs.client_parameters.serialize() @@ -416,10 +439,32 @@ def run( raise ValueError(message) if not isinstance(arg, Value): - message = f"Expected argument {i} to be an fhe.Value but it's {type(arg).__name__}" - raise ValueError(message) + if i not in self._clear_input_indices[function_name]: + message = ( + f"Expected argument {i} to be an fhe.Value " + f"but it's {friendly_type_format(type(arg))}" + ) + raise ValueError(message) + + # Simulated value exporter can be used here + # as "clear" fhe.Values have the same + # internal representation as "simulation" fhe.Values + + exporter = SimulatedValueExporter.new( + self.client_specs.client_parameters, + function_name, + ) + + if isinstance(arg, (int, np.integer)): + arg = exporter.export_scalar(i, arg) + else: + arg = np.array(arg) + arg = exporter.export_tensor(i, arg.flatten().tolist(), arg.shape) - buffers.append(arg.inner) + if isinstance(arg, Value): + buffers.append(arg.inner) + else: + buffers.append(arg) public_args = PublicArguments.new(self.client_specs.client_parameters, buffers) server_circuit = self._server_program.get_server_circuit(function_name) diff --git a/frontends/concrete-python/concrete/fhe/compilation/utils.py b/frontends/concrete-python/concrete/fhe/compilation/utils.py index 2ba2c7c392..7010a86279 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/utils.py +++ b/frontends/concrete-python/concrete/fhe/compilation/utils.py @@ -82,10 +82,16 @@ def validate_input_args( Returns: List[Optional[Union[int, np.ndarray]]]: ordered validated args """ + functions_parameters = json.loads(client_specs.client_parameters.serialize())["circuits"] - client_parameters_json = next( - filter(lambda x: x["name"] == function_name, functions_parameters) - ) + for function_parameters in functions_parameters: + if function_parameters["name"] == function_name: + client_parameters_json = function_parameters + break + else: + message = f"Function `{function_name}` is not in the module" + raise ValueError(message) + assert "inputs" in client_parameters_json input_specs = client_parameters_json["inputs"] if len(args) != len(input_specs): diff --git a/frontends/concrete-python/tests/compilation/test_circuit.py b/frontends/concrete-python/tests/compilation/test_circuit.py index b08b22866b..816038cf19 100644 --- a/frontends/concrete-python/tests/compilation/test_circuit.py +++ b/frontends/concrete-python/tests/compilation/test_circuit.py @@ -291,6 +291,41 @@ def function(x): server.cleanup() +def test_client_server_api_run_with_clear(helpers): + """ + Test running server run API with a clear input. + """ + + configuration = helpers.configuration() + + @fhe.compiler({"x": "encrypted", "y": "clear", "z": "clear"}) + def function(x, y, z): + return x + y + z + + inputset = fhe.inputset(fhe.uint3, fhe.uint3, fhe.tensor[fhe.uint3, 2, 2]) # type: ignore + circuit = function.compile(inputset, configuration.fork()) + + client = circuit.client + server = circuit.server + + x, y, z = 3, 2, [[1, 2], [3, 4]] + + encrypted_x, _, _ = client.encrypt(x, None, None) + encrypted_result = server.run(encrypted_x, y, z, evaluation_keys=client.evaluation_keys) + result = client.decrypt(encrypted_result) + assert np.array_equal(result, x + y + np.array(z)) + + with pytest.raises(ValueError) as excinfo: + server.run(1, 2, 3, evaluation_keys=client.evaluation_keys) + + assert str(excinfo.value) == "Expected argument 0 to be an fhe.Value but it's int" + + with pytest.raises(RuntimeError) as excinfo: + server.run(encrypted_x, [2, 2], 3, evaluation_keys=client.evaluation_keys) + + assert str(excinfo.value) == "Tried to transform plaintext value with incompatible shape." + + def test_client_server_api_crt(helpers): """ Test client/server API on a CRT circuit.