diff --git a/frontends/concrete-python/concrete/fhe/mlir/context.py b/frontends/concrete-python/concrete/fhe/mlir/context.py index 01466a6f1f..c127c09f81 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/context.py +++ b/frontends/concrete-python/concrete/fhe/mlir/context.py @@ -1558,6 +1558,37 @@ def dot(self, resulting_type: ConversionType, x: Conversion, y: Conversion) -> C return self.operation(operation, resulting_type, x.result, y.result) + def dynamic_tlu( + self, + resulting_type: ConversionType, + on: Conversion, + table: Conversion, + ) -> Conversion: + assert table.is_clear and on.is_encrypted + + if table.shape != (2**on.bit_width,): + highlights: Dict[Node, Union[str, List[str]]] = { + table.origin: f"table has the shape {table.shape}", + on.origin: f"table lookup input is {on.bit_width}-bits", + self.converting: [ + "so table cannot be looked up with this input", + f"table shape should have been {(2**on.bit_width,)}", + ], + } + if on.bit_width != on.original_bit_width: # pragma: no cover + highlights[on.origin].append( # type: ignore + "(" + f"note that it's assigned {on.bit_width}-bits " + f"during compilation because of its relation with other operations" + ")" + ) + self.error(highlights) + + dialect = fhe if on.is_scalar else fhelinalg + operation = dialect.ApplyLookupTableEintOp + + return self.operation(operation, resulting_type, on.result, table.result) + def encrypt(self, resulting_type: ConversionType, x: Conversion) -> Conversion: assert self.is_bit_width_compatible(resulting_type, x) assert resulting_type.is_encrypted and x.is_clear diff --git a/frontends/concrete-python/concrete/fhe/mlir/converter.py b/frontends/concrete-python/concrete/fhe/mlir/converter.py index dfa4d58cef..b4e7ab6455 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/converter.py +++ b/frontends/concrete-python/concrete/fhe/mlir/converter.py @@ -323,6 +323,10 @@ def dot(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion: assert len(preds) == 2 return ctx.dot(ctx.typeof(node), preds[0], preds[1]) + def dynamic_tlu(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion: + assert len(preds) == 2 + return ctx.dynamic_tlu(ctx.typeof(node), preds[0], preds[1]) + def equal(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion: assert len(preds) == 2 diff --git a/frontends/concrete-python/concrete/fhe/representation/node.py b/frontends/concrete-python/concrete/fhe/representation/node.py index 3987690dbd..d77f95ebc9 100644 --- a/frontends/concrete-python/concrete/fhe/representation/node.py +++ b/frontends/concrete-python/concrete/fhe/representation/node.py @@ -386,6 +386,7 @@ def converted_to_table_lookup(self) -> bool: "conv2d", "conv3d", "dot", + "dynamic_tlu", "expand_dims", "index_static", "matmul", diff --git a/frontends/concrete-python/concrete/fhe/representation/utils.py b/frontends/concrete-python/concrete/fhe/representation/utils.py index 9c819a46e5..66a0690a8e 100644 --- a/frontends/concrete-python/concrete/fhe/representation/utils.py +++ b/frontends/concrete-python/concrete/fhe/representation/utils.py @@ -81,7 +81,7 @@ def format_constant(constant: Any, maximum_length: int = 45, keep_newlines: bool return result -def format_indexing_element(indexing_element: Union[int, np.integer, slice]): +def format_indexing_element(indexing_element: Union[int, np.integer, slice, Any]): """ Format an indexing element. diff --git a/frontends/concrete-python/concrete/fhe/tracing/tracer.py b/frontends/concrete-python/concrete/fhe/tracing/tracer.py index 8dd7574af7..366e6f3b08 100644 --- a/frontends/concrete-python/concrete/fhe/tracing/tracer.py +++ b/frontends/concrete-python/concrete/fhe/tracing/tracer.py @@ -738,8 +738,19 @@ def transpose(self, axes: Optional[Tuple[int, ...]] = None) -> "Tracer": def __getitem__( self, - index: Union[int, np.integer, slice, Tuple[Union[int, np.integer, slice], ...]], + index: Union[ + int, np.integer, slice, "Tracer", Tuple[Union[int, np.integer, slice, "Tracer"], ...] + ], ) -> "Tracer": + if isinstance(index, Tracer) and index.output.is_encrypted and self.output.is_clear: + computation = Node.generic( + "dynamic_tlu", + [deepcopy(index.output), deepcopy(self.output)], + deepcopy(index.output), + lambda on, table: table[on], + ) + return Tracer(computation, [index, self]) + if not isinstance(index, tuple): index = (index,) @@ -770,7 +781,7 @@ def __getitem__( raise ValueError(message) output_value = deepcopy(self.output) - output_value.shape = np.zeros(output_value.shape)[index].shape + output_value.shape = np.zeros(output_value.shape)[index].shape # type: ignore computation = Node.generic( "index_static", diff --git a/frontends/concrete-python/tests/execution/test_dynamic_tlu.py b/frontends/concrete-python/tests/execution/test_dynamic_tlu.py new file mode 100644 index 0000000000..ec3f79a2c7 --- /dev/null +++ b/frontends/concrete-python/tests/execution/test_dynamic_tlu.py @@ -0,0 +1,94 @@ +""" +Tests of execution of dynamic tlu operation. +""" + +import random + +import numpy as np +import pytest + +from concrete import fhe +from concrete.fhe.dtypes import Integer + +cases = [] +for input_bit_width in range(1, 3): + for input_is_signed in [False, True]: + for output_bit_width in range(1, 3): + for output_is_signed in [False, True]: + input_shape = random.choice([(), (2,), (3, 2)]) + cases.append( + pytest.param( + input_bit_width, + input_is_signed, + input_shape, + output_bit_width, + output_is_signed, + id=( + f"{'' if input_is_signed else 'u'}int{input_bit_width}" + f" -> " + f"{'' if output_is_signed else 'u'}int{output_bit_width}" + f" {{ input_shape={input_shape} }}" + ), + ) + ) + +# pylint: disable=redefined-outer-name + + +@pytest.mark.parametrize( + "input_bit_width,input_is_signed,input_shape,output_bit_width,output_is_signed", + cases, +) +def test_dynamic_tlu( + input_bit_width, + input_is_signed, + input_shape, + output_bit_width, + output_is_signed, + helpers, +): + """ + Test dynamic tlu. + """ + + input_dtype = Integer(is_signed=input_is_signed, bit_width=input_bit_width) + output_dtype = Integer(is_signed=output_is_signed, bit_width=output_bit_width) + + def function(x, y): + return y[x] + + compiler = fhe.Compiler(function, {"x": "encrypted", "y": "clear"}) + inputset = [ + ( + np.random.randint( + input_dtype.min(), + input_dtype.max() + 1, + size=input_shape, + ), + np.random.randint( + output_dtype.min(), + output_dtype.max() + 1, + size=(2**input_bit_width,), + ), + ) + for _ in range(100) + ] + circuit = compiler.compile(inputset, helpers.configuration()) + + samples = [ + [ + np.random.randint( + input_dtype.min(), + input_dtype.max() + 1, + size=input_shape, + ), + np.random.randint( + output_dtype.min(), + output_dtype.max() + 1, + size=(2**input_bit_width,), + ), + ] + for _ in range(5) + ] + for sample in samples: + helpers.check_execution(circuit, function, sample, retries=3) diff --git a/frontends/concrete-python/tests/mlir/test_converter.py b/frontends/concrete-python/tests/mlir/test_converter.py index 02398a9d61..dfe21aef38 100644 --- a/frontends/concrete-python/tests/mlir/test_converter.py +++ b/frontends/concrete-python/tests/mlir/test_converter.py @@ -811,6 +811,31 @@ def assign(x, y): ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but only up to 16-bit encrypted matrix multiplications are supported return %2 + """, # noqa: E501 + ), + pytest.param( + lambda x, y: y[x], + {"x": "encrypted", "y": "clear"}, + [ + ( + 1, + [1, 2, 3, 4], + ) + ], + RuntimeError, + """ + +Function you are trying to compile cannot be compiled + +%0 = x # EncryptedScalar ∈ [1, 1] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ table lookup input is 1-bits +%1 = y # ClearTensor ∈ [1, 4] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ table has the shape (4,) +%2 = dynamic_tlu(%0, %1) # EncryptedScalar ∈ [2, 2] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ so table cannot be looked up with this input + table shape should have been (2,) +return %2 + """, # noqa: E501 ), ],