diff --git a/frontends/concrete-python/concrete/fhe/compilation/configuration.py b/frontends/concrete-python/concrete/fhe/compilation/configuration.py index 57871125e7..0c7b1c5fd6 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/configuration.py +++ b/frontends/concrete-python/concrete/fhe/compilation/configuration.py @@ -992,6 +992,7 @@ class Configuration: optimize_tlu_based_on_original_bit_width: Union[bool, int] detect_overflow_in_simulation: bool dynamic_indexing_check_out_of_bounds: bool + dynamic_assignment_check_out_of_bounds: bool def __init__( self, @@ -1059,6 +1060,7 @@ def __init__( optimize_tlu_based_on_original_bit_width: Union[bool, int] = 8, detect_overflow_in_simulation: bool = False, dynamic_indexing_check_out_of_bounds: bool = True, + dynamic_assignment_check_out_of_bounds: bool = True, ): self.verbose = verbose self.compiler_debug_mode = compiler_debug_mode @@ -1162,6 +1164,7 @@ def __init__( self.detect_overflow_in_simulation = detect_overflow_in_simulation self.dynamic_indexing_check_out_of_bounds = dynamic_indexing_check_out_of_bounds + self.dynamic_assignment_check_out_of_bounds = dynamic_assignment_check_out_of_bounds self._validate() @@ -1236,6 +1239,7 @@ def fork( optimize_tlu_based_on_original_bit_width: Union[Keep, bool, int] = KEEP, detect_overflow_in_simulation: Union[Keep, bool] = KEEP, dynamic_indexing_check_out_of_bounds: Union[Keep, bool] = KEEP, + dynamic_assignment_check_out_of_bounds: Union[Keep, bool] = KEEP, ) -> "Configuration": """ Get a new configuration from another one specified changes. diff --git a/frontends/concrete-python/concrete/fhe/mlir/context.py b/frontends/concrete-python/concrete/fhe/mlir/context.py index 0b028fd1d1..20296f8dfc 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/context.py +++ b/frontends/concrete-python/concrete/fhe/mlir/context.py @@ -15,7 +15,6 @@ from mlir.ir import BoolAttr as MlirBoolAttr from mlir.ir import Context as MlirContext from mlir.ir import DenseElementsAttr as MlirDenseElementsAttr -from mlir.ir import DenseI64ArrayAttr as MlirDenseI64ArrayAttr from mlir.ir import IndexType from mlir.ir import InsertionPoint as MlirInsertionPoint from mlir.ir import IntegerAttr as MlirIntegerAttr @@ -1913,154 +1912,21 @@ def array(self, resulting_type: ConversionType, elements: List[Conversion]) -> C original_bit_width=original_bit_width, ) - def assign_static( + def assign( self, resulting_type: ConversionType, x: Conversion, y: Conversion, - index: Sequence[Union[int, np.integer, slice]], + index: Sequence[Union[int, np.integer, slice, np.ndarray, list, Conversion]], ): - if x.is_clear and y.is_encrypted: - highlights = { - x.origin: "tensor is clear", - y.origin: "assigned value is encrypted", - self.converting: "but encrypted values cannot be assigned to clear tensors", - } - self.error(highlights) - - assert self.is_bit_width_compatible(resulting_type, x, y) - - if any(isinstance(indexing_element, (list, np.ndarray)) for indexing_element in index): - return self.assign_static_fancy(resulting_type, x, y, index) - - index = list(index) - while len(index) < len(x.shape): - index.append(slice(None, None, None)) - - offsets = [] - sizes = [] - strides = [] - - for indexing_element, dimension_size in zip(index, x.shape): - if isinstance(indexing_element, slice): - size = int(np.zeros(dimension_size)[indexing_element].shape[0]) - stride = int(indexing_element.step if indexing_element.step is not None else 1) - offset = int( - ( - indexing_element.start - if indexing_element.start >= 0 - else indexing_element.start + dimension_size - ) - if indexing_element.start is not None - else (0 if stride > 0 else dimension_size - 1) - ) - - else: - size = 1 - stride = 1 - offset = int( - indexing_element if indexing_element >= 0 else indexing_element + dimension_size - ) - - offsets.append(offset) - sizes.append(size) - strides.append(stride) - - if x.is_encrypted and y.is_clear: - encrypted_type = self.typeof( - ValueDescription( - dtype=Integer(is_signed=x.is_signed, bit_width=x.bit_width), - shape=y.shape, - is_encrypted=True, - ) - ) - y = self.encrypt(encrypted_type, y) - - required_y_shape_list = [] - for i, indexing_element in enumerate(index): - if isinstance(indexing_element, slice): - n = len(np.zeros(x.shape[i])[indexing_element]) - required_y_shape_list.append(n) - else: - required_y_shape_list.append(1) - - required_y_shape = tuple(required_y_shape_list) - try: - np.reshape(np.zeros(y.shape), required_y_shape) - y = self.reshape(y, required_y_shape) - except Exception: # pylint: disable=broad-except - np.broadcast_to(np.zeros(y.shape), required_y_shape) - y = self.broadcast_to(y, required_y_shape) - - x = self.to_signedness(x, of=resulting_type) - y = self.to_signedness(y, of=resulting_type) - - return self.operation( - tensor.InsertSliceOp, - resulting_type, - y.result, - x.result, - (), - (), - (), - MlirDenseI64ArrayAttr.get(offsets), - MlirDenseI64ArrayAttr.get(sizes), - MlirDenseI64ArrayAttr.get(strides), - original_bit_width=x.original_bit_width, - ) + # This import needs to happen here to avoid circular imports. + # pylint: disable=import-outside-toplevel - def assign_static_fancy( - self, - resulting_type: ConversionType, - x: Conversion, - y: Conversion, - index: Sequence[Union[int, np.integer, slice, np.ndarray, list]], - ) -> Conversion: - resulting_element_type = (self.eint if resulting_type.is_unsigned else self.esint)( - resulting_type.bit_width - ) + from .operations.assignment import assignment - indices = [] - indices_shape = () - for indexing_element in index: - if isinstance(indexing_element, (int, np.integer, list, np.ndarray)): - indexing_element_array = np.array(indexing_element) - indices_shape = np.broadcast_shapes( - indices_shape, - indexing_element_array.shape, - ) # type:ignore - indices.append(indexing_element_array) - - else: # pragma: no cover - message = f"invalid indexing element of type {type(indexing_element)}" - raise AssertionError(message) - values_shape = indices_shape - - indices = [np.broadcast_to(index, shape=indices_shape) for index in indices] - - concrete_indices = [] - for i in np.ndindex(*indices_shape): - concrete_index = [index[i] for index in indices] - concrete_indices.append(concrete_index) - - if len(x.shape) > 1: - indices_shape = (*indices_shape, len(x.shape)) # type: ignore - - if x.is_clear and y.is_encrypted: - raise NotImplementedError - if x.is_encrypted and y.is_clear: - y = self.encrypt(self.tensor(resulting_element_type, shape=y.shape), y) + # pylint: enable=import-outside-toplevel - return self.operation( - fhelinalg.FancyAssignOp, - resulting_type, - x.result, - self.constant( - self.tensor(self.index_type(), indices_shape), - np.array(concrete_indices).reshape(indices_shape), - ).result, - self.broadcast_to(y, shape=values_shape).result, - ) + return assignment(self, resulting_type, x, y, index) def bitwise( self, @@ -2713,6 +2579,7 @@ def index( ) -> Conversion: # This import needs to happen here to avoid circular imports. # pylint: disable=import-outside-toplevel + from .operations.indexing import indexing # pylint: enable=import-outside-toplevel diff --git a/frontends/concrete-python/concrete/fhe/mlir/converter.py b/frontends/concrete-python/concrete/fhe/mlir/converter.py index 9a5e948b47..6245c2fa18 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/converter.py +++ b/frontends/concrete-python/concrete/fhe/mlir/converter.py @@ -290,13 +290,34 @@ def array(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion assert len(preds) > 0 return ctx.array(ctx.typeof(node), elements=preds) + def assign_dynamic(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion: + assert len(preds) >= 3 + + x = preds[0] + y = preds[-1] + + dynamic_indices = preds[1:-1] + static_indices = node.properties["kwargs"]["static_indices"] + + indices = [] + + cursor = 0 + for index in static_indices: + if index is None: + indices.append(dynamic_indices[cursor]) + cursor += 1 + else: + indices.append(index) + + return ctx.assign(ctx.typeof(node), x, y, indices) + def assign_static(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion: assert len(preds) == 2 - return ctx.assign_static( + return ctx.assign( ctx.typeof(node), preds[0], preds[1], - index=node.properties["kwargs"]["index"], + node.properties["kwargs"]["index"], ) def bitwise_and(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion: diff --git a/frontends/concrete-python/concrete/fhe/mlir/operations/assignment.py b/frontends/concrete-python/concrete/fhe/mlir/operations/assignment.py new file mode 100644 index 0000000000..bfc68b6315 --- /dev/null +++ b/frontends/concrete-python/concrete/fhe/mlir/operations/assignment.py @@ -0,0 +1,221 @@ +""" +Conversion of assignment operation. +""" + +# pylint: disable=import-error,no-name-in-module + +from typing import Any, List, Sequence, Union + +import numpy as np +from concrete.lang.dialects import fhelinalg +from mlir.dialects import tensor +from mlir.ir import DenseI64ArrayAttr as MlirDenseI64ArrayAttr +from mlir.ir import ShapedType as MlirShapedType + +from ..context import Context +from ..conversion import Conversion, ConversionType +from .indexing import generate_fancy_indices, process_indexing_element + +# pylint: enable=import-error,no-name-in-module + + +def fancy_assignment( + ctx: Context, + resulting_type: ConversionType, + x: Conversion, + y: Conversion, + index: Sequence[Union[int, np.integer, slice, np.ndarray, list, Conversion]], +) -> Conversion: + """ + Convert fancy assignment operation. + + Args: + ctx (Context): + conversion context + + resulting_type (ConversionType): + resulting type of the operation + + x (Conversion): + tensor to assign to + + y (Conversion): + tensor to assign + + index (Sequence[Union[int, np.integer, slice, np.ndarray, list, Conversion]]): + fancy index to use + + Returns: + Conversion: + x after fancy assignment + """ + + sample_index = [] + for indexing_element in index: + sample_index.append( + np.zeros(indexing_element.shape, dtype=np.int64) + if isinstance(indexing_element, Conversion) + else indexing_element + ) + + indexing_element_shape = np.zeros(resulting_type.shape, dtype=np.int8)[ + tuple(sample_index) + ].shape + + indices = generate_fancy_indices( + ctx, + indexing_element_shape, + x, + index, + check_out_of_bounds=ctx.configuration.dynamic_assignment_check_out_of_bounds, + ) + + if y.shape != indexing_element_shape: + y = ctx.broadcast_to(y, indexing_element_shape) + + return ctx.operation( + fhelinalg.FancyAssignOp, + resulting_type, + x.result, + indices.result, + y.result, + ) + + +def assignment( + ctx: Context, + resulting_type: ConversionType, + x: Conversion, + y: Conversion, + index: Sequence[Union[int, np.integer, slice, np.ndarray, list, Conversion]], +) -> Conversion: + """ + Convert assignment operation. + + Args: + ctx (Context): + conversion context + + resulting_type (ConversionType): + resulting type of the operation + + x (Conversion): + tensor to assign to + + y (Conversion): + tensor to assign + + index (Sequence[Union[int, np.integer, slice, np.ndarray, list, Conversion]]): + index to use + + Returns: + Conversion: + x after assignment + """ + + if x.is_clear and y.is_encrypted: + highlights = { + x.origin: "tensor is clear", + y.origin: "assigned value is encrypted", + ctx.converting: "but encrypted values cannot be assigned to clear tensors", + } + ctx.error(highlights) + + assert ctx.is_bit_width_compatible(resulting_type, x, y) + + index = list(index) + while len(index) < len(x.shape): + index.append(slice(None, None, None)) + + if x.is_encrypted and y.is_clear: + encrypted_type = ctx.tensor(ctx.element_typeof(x), y.shape) + y = ctx.encrypt(encrypted_type, y) + + is_fancy = any( + ( + isinstance(indexing_element, (list, np.ndarray)) + or (isinstance(indexing_element, Conversion) and indexing_element.is_tensor) + ) + for indexing_element in index + ) + if is_fancy: + return fancy_assignment(ctx, resulting_type, x, y, index) + + static_offsets: List[Any] = [] + static_sizes: List[Any] = [] + static_strides: List[Any] = [] + + dynamic_offsets: List[Any] = [] + + for indexing_element, dimension_size in zip(index, x.shape): + offset: Any + size: Any + stride: Any + + if isinstance(indexing_element, slice): + size = int(np.zeros(dimension_size)[indexing_element].shape[0]) + stride = int(indexing_element.step if indexing_element.step is not None else 1) + offset = int( + process_indexing_element( + ctx, + indexing_element.start, # type: ignore + dimension_size, + check_out_of_bounds=ctx.configuration.dynamic_assignment_check_out_of_bounds, + ) + if indexing_element.start is not None + else (0 if stride > 0 else dimension_size - 1) + ) + else: + assert isinstance(indexing_element, (int, np.integer)) or ( + isinstance(indexing_element, Conversion) and indexing_element.is_scalar + ) + + size = 1 + stride = 1 + offset = process_indexing_element( + ctx, + indexing_element, + dimension_size, + check_out_of_bounds=ctx.configuration.dynamic_assignment_check_out_of_bounds, + ) + + if isinstance(offset, Conversion): + dynamic_offsets.append(offset) + offset = MlirShapedType.get_dynamic_size() + + static_offsets.append(offset) + static_sizes.append(size) + static_strides.append(stride) + + required_y_shape_list = [] + for i, indexing_element in enumerate(index): + if isinstance(indexing_element, slice): + n = len(np.zeros(x.shape[i])[indexing_element]) + required_y_shape_list.append(n) + else: + required_y_shape_list.append(1) + + required_y_shape = tuple(required_y_shape_list) + try: + np.reshape(np.zeros(y.shape), required_y_shape) + y = ctx.reshape(y, required_y_shape) + except Exception: # pylint: disable=broad-except + np.broadcast_to(np.zeros(y.shape), required_y_shape) + y = ctx.broadcast_to(y, required_y_shape) + + x = ctx.to_signedness(x, of=resulting_type) + y = ctx.to_signedness(y, of=resulting_type) + + return ctx.operation( + tensor.InsertSliceOp, + resulting_type, + y.result, + x.result, + tuple(item.result for item in dynamic_offsets), + (), + (), + MlirDenseI64ArrayAttr.get(static_offsets), + MlirDenseI64ArrayAttr.get(static_sizes), + MlirDenseI64ArrayAttr.get(static_strides), + original_bit_width=x.original_bit_width, + ) diff --git a/frontends/concrete-python/concrete/fhe/mlir/operations/indexing.py b/frontends/concrete-python/concrete/fhe/mlir/operations/indexing.py index c589c600f6..b84a01d7bf 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/operations/indexing.py +++ b/frontends/concrete-python/concrete/fhe/mlir/operations/indexing.py @@ -130,6 +130,7 @@ def process_indexing_element( ctx: Context, indexing_element: Union[int, np.integer, slice, np.ndarray, list, Conversion], dimension_size: int, + check_out_of_bounds: bool, ) -> Union[int, np.integer, slice, np.ndarray, list, Conversion]: """ Process indexing element. @@ -148,6 +149,9 @@ def process_indexing_element( dimension_size (int): size of the indexed dimension + check_out_of_bounds (int): + whether to check for out of bounds access in runtime + Returns: Union[int, np.integer, slice, np.ndarray, list, Conversion]: processed indexing element @@ -220,7 +224,7 @@ def process_indexing_element( assert new_indexing_element is not None indexing_element = new_indexing_element - if ctx.configuration.dynamic_indexing_check_out_of_bounds: + if check_out_of_bounds: check_out_of_bounds_in_runtime(ctx, indexing_element, dimension_size) else: element_type = ctx.element_typeof(indexing_element) @@ -263,7 +267,7 @@ def process_indexing_element( ) assert sanitized_index is not None - if ctx.configuration.dynamic_indexing_check_out_of_bounds: + if check_out_of_bounds: check_out_of_bounds_in_runtime(ctx, sanitized_index, dimension_size) assert sanitized_index is not None @@ -279,7 +283,7 @@ def process_indexing_element( sanitized_indexing_element.result, ) - elif ctx.configuration.dynamic_indexing_check_out_of_bounds: + elif check_out_of_bounds: check_out_of_bounds_in_runtime(ctx, indexing_element, dimension_size) return ctx.operation( @@ -292,21 +296,22 @@ def process_indexing_element( return 0 # pragma: no cover -def fancy_indexing( +def generate_fancy_indices( ctx: Context, - resulting_type: ConversionType, + indexing_element_shape: Tuple[int, ...], x: Conversion, index: Sequence[Union[int, np.integer, slice, np.ndarray, list, Conversion]], + check_out_of_bounds: bool, ) -> Conversion: """ - Convert fancy indexing operation. + Generate indices to use for fancy indexing. Args: ctx (Context): conversion context - resulting_type (ConversionType): - resulting type of the operation + indexing_element_shape (Tuple[int, ...]): + individual shape of indexing elements x (Conversion): tensor to fancy index @@ -314,9 +319,12 @@ def fancy_indexing( index (Sequence[Union[int, np.integer, slice, np.ndarray, list, Conversion]]): fancy index to use + check_out_of_bounds (int): + whether to check for out of bounds access in runtime + Returns: Conversion: - result of fancy indexing operation + indices to use for fancy indexing operation """ # refer to @@ -343,6 +351,7 @@ def fancy_indexing( ctx, indexing_element, dimension_size, + check_out_of_bounds, ) if isinstance(indexing_element, Conversion): @@ -369,17 +378,17 @@ def fancy_indexing( assert len(x.shape) == 1 indices = processed_index[0] else: - expected_indexing_element_shape = resulting_type.shape + (1,) + expanded_indexing_element_shape = indexing_element_shape + (1,) to_concat = [] for dimension, indexing_element in enumerate(processed_index): if indexing_element.is_scalar: to_concat.append( - ctx.broadcast_to(indexing_element, expected_indexing_element_shape) + ctx.broadcast_to(indexing_element, expanded_indexing_element_shape) ) - elif indexing_element.shape == expected_indexing_element_shape[:-1]: - to_concat.append(ctx.reshape(indexing_element, expected_indexing_element_shape)) + elif indexing_element.shape == indexing_element_shape: + to_concat.append(ctx.reshape(indexing_element, expanded_indexing_element_shape)) else: if len(fancy_indices_start_positions) == 1: @@ -392,7 +401,8 @@ def fancy_indexing( if isinstance(original_indexing_element, list): original_indexing_element = np.array(original_indexing_element) broadcast_shape = np.broadcast_shapes( - broadcast_shape, original_indexing_element.shape + broadcast_shape, + original_indexing_element.shape, ) extra_dimensions = 1 @@ -428,15 +438,53 @@ def fancy_indexing( indexing_element.shape + (1,) * extra_dimensions, ) to_concat.append( - ctx.broadcast_to(indexing_element, expected_indexing_element_shape) + ctx.broadcast_to(indexing_element, expanded_indexing_element_shape) ) indices = ctx.concatenate( - ctx.tensor(ctx.index_type(), resulting_type.shape + (len(to_concat),)), + ctx.tensor(ctx.index_type(), indexing_element_shape + (len(to_concat),)), to_concat, axis=-1, ) + return indices + + +def fancy_indexing( + ctx: Context, + resulting_type: ConversionType, + x: Conversion, + index: Sequence[Union[int, np.integer, slice, np.ndarray, list, Conversion]], +) -> Conversion: + """ + Convert fancy indexing operation. + + Args: + ctx (Context): + conversion context + + resulting_type (ConversionType): + resulting type of the operation + + x (Conversion): + tensor to fancy index + + index (Sequence[Union[int, np.integer, slice, np.ndarray, list, Conversion]]): + fancy index to use + + Returns: + Conversion: + result of fancy indexing operation + """ + + indices = generate_fancy_indices( + ctx, + resulting_type.shape, + x, + index, + check_out_of_bounds=ctx.configuration.dynamic_indexing_check_out_of_bounds, + ) + return ctx.operation( fhelinalg.FancyIndexOp, resulting_type, @@ -498,6 +546,7 @@ def indexing( ctx, indexing_element, dimension_size, + check_out_of_bounds=ctx.configuration.dynamic_indexing_check_out_of_bounds, ) if not isinstance(indexing_element, Conversion): indexing_element = ctx.constant( @@ -535,6 +584,7 @@ def indexing( ctx, indexing_element.start, # type: ignore dimension_size, + check_out_of_bounds=ctx.configuration.dynamic_indexing_check_out_of_bounds, ) if indexing_element.start is not None else (0 if stride > 0 else dimension_size - 1) @@ -547,7 +597,12 @@ def indexing( size = 1 stride = 1 - offset = process_indexing_element(ctx, indexing_element, dimension_size) + offset = process_indexing_element( + ctx, + indexing_element, + dimension_size, + check_out_of_bounds=ctx.configuration.dynamic_indexing_check_out_of_bounds, + ) if isinstance(offset, Conversion): dynamic_offsets.append(offset) diff --git a/frontends/concrete-python/concrete/fhe/mlir/processors/assign_bit_widths.py b/frontends/concrete-python/concrete/fhe/mlir/processors/assign_bit_widths.py index e9a3cf7649..da5a96028a 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/processors/assign_bit_widths.py +++ b/frontends/concrete-python/concrete/fhe/mlir/processors/assign_bit_widths.py @@ -404,6 +404,10 @@ def min_max(self, node: Node, preds: List[Node]): inputs_and_output_share_precision, } + assign_dynamic = { + inputs_and_output_share_precision, + } + assign_static = { inputs_and_output_share_precision, } diff --git a/frontends/concrete-python/concrete/fhe/representation/node.py b/frontends/concrete-python/concrete/fhe/representation/node.py index 125bfc4cfb..2c9b467353 100644 --- a/frontends/concrete-python/concrete/fhe/representation/node.py +++ b/frontends/concrete-python/concrete/fhe/representation/node.py @@ -325,6 +325,23 @@ def format(self, predecessors: List[str], maximum_constant_length: int = 45) -> elements = [format_indexing_element(element) for element in index] return f"{predecessors[0]}[{', '.join(elements)}]" + if name == "assign_dynamic": + dynamic_indices = predecessors[1:-1] + static_indices = self.properties["kwargs"]["static_indices"] + + indices = [] + + cursor = 0 + for index in static_indices: + if index is None: + indices.append(dynamic_indices[cursor]) + cursor += 1 + else: + indices.append(index) + + elements = [format_indexing_element(element) for element in indices] + return f"({predecessors[0]}[{', '.join(elements)}] = {predecessors[-1]})" + if name == "assign_static": index = self.properties["kwargs"]["index"] elements = [format_indexing_element(element) for element in index] @@ -387,6 +404,9 @@ def label(self) -> str: if name == "index_static": name = self.format(["□"]) + if name == "assign_dynamic": + name = self.format(["□"] * len(self.inputs))[1:-1] + if name == "assign_static": name = self.format(["□", "□"])[1:-1] @@ -427,6 +447,7 @@ def converted_to_table_lookup(self) -> bool: return self.operation == Operation.Generic and self.properties["name"] not in [ "add", "array", + "assign_dynamic", "assign_static", "broadcast_to", "concatenate", diff --git a/frontends/concrete-python/concrete/fhe/tracing/tracer.py b/frontends/concrete-python/concrete/fhe/tracing/tracer.py index 18a98d2630..0cd805304a 100644 --- a/frontends/concrete-python/concrete/fhe/tracing/tracer.py +++ b/frontends/concrete-python/concrete/fhe/tracing/tracer.py @@ -785,9 +785,7 @@ def __getitem__( reject = True break - if isinstance(indexing_element, np.ndarray) or ( - isinstance(indexing_element, Tracer) and not indexing_element.output.is_scalar - ): + if isinstance(indexing_element, np.ndarray): reject = not np.issubdtype(indexing_element.dtype, np.integer) continue @@ -896,27 +894,26 @@ def __setitem__( if not isinstance(index, tuple): index = (index,) - is_fancy = False - has_slices = False - reject = False for indexing_element in index: + if isinstance(indexing_element, Tracer): + reject = reject or indexing_element.output.is_encrypted + continue + if isinstance(indexing_element, list): try: indexing_element = np.array(indexing_element) - except Exception: # pragma: no cover # pylint: disable=broad-except + except Exception: # pylint: disable=broad-except reject = True break if isinstance(indexing_element, np.ndarray): - is_fancy = True reject = not np.issubdtype(indexing_element.dtype, np.integer) continue valid = isinstance(indexing_element, (int, np.integer, slice)) if isinstance(indexing_element, slice): # noqa: SIM102 - has_slices = True if ( not ( indexing_element.start is None @@ -937,7 +934,7 @@ def __setitem__( reject = True break - if reject or (is_fancy and has_slices): + if reject: indexing_elements = [ format_indexing_element(indexing_element) for indexing_element in index ] @@ -949,21 +946,75 @@ def __setitem__( message = f"{self}[{formatted_index}] cannot be assigned {value}" raise ValueError(message) - np.zeros(self.output.shape)[index] = 1 + output_value = deepcopy(self.output) - def assign(x, value, index): - x[index] = value - return x + sample_index = [] + for indexing_element in index: + sample_index.append( + np.zeros(indexing_element.shape, dtype=np.int64) + if isinstance(indexing_element, Tracer) + else indexing_element + ) - sanitized_value = self.sanitize(value) - computation = Node.generic( - "assign_static", - [deepcopy(self.output), deepcopy(sanitized_value.output)], - deepcopy(self.output), - assign, - kwargs={"index": index}, - ) - new_version = Tracer(computation, [self, sanitized_value]) + np.zeros(self.output.shape)[tuple(sample_index)] = 1 + + if any(isinstance(indexing_element, Tracer) for indexing_element in index): + dynamic_indices = [] + static_indices: List[Any] = [] + + for indexing_element in index: + if isinstance(indexing_element, Tracer): + static_indices.append(None) + dynamic_indices.append(indexing_element) + else: + static_indices.append(indexing_element) + + def assign_dynamic(tensor, *dynamic_indices_and_value, static_indices): + dynamic_indices = dynamic_indices_and_value[:-1] + value = dynamic_indices_and_value[-1] + + final_indices = [] + + cursor = 0 + for index in static_indices: + if index is None: + final_indices.append(dynamic_indices[cursor]) + cursor += 1 + else: + final_indices.append(index) + + tensor[tuple(final_indices)] = value + return tensor + + sanitized_value = self.sanitize(value) + computation = Node.generic( + "assign_dynamic", + [deepcopy(self.output)] + + [deepcopy(index.output) for index in dynamic_indices] + + [sanitized_value.output], + output_value, + assign_dynamic, + kwargs={"static_indices": static_indices}, + ) + new_version = Tracer( + computation, [self] + [index for index in dynamic_indices] + [sanitized_value] + ) + + else: + + def assign(x, value, index): + x[index] = value + return x + + sanitized_value = self.sanitize(value) + computation = Node.generic( + "assign_static", + [deepcopy(self.output), deepcopy(sanitized_value.output)], + deepcopy(self.output), + assign, + kwargs={"index": index}, + ) + new_version = Tracer(computation, [self, sanitized_value]) self.last_version = new_version diff --git a/frontends/concrete-python/mypy.ini b/frontends/concrete-python/mypy.ini index 1e8ef81b62..34ca2253fb 100644 --- a/frontends/concrete-python/mypy.ini +++ b/frontends/concrete-python/mypy.ini @@ -2,4 +2,4 @@ plugins = numpy.typing.mypy_plugin disable_error_code = annotation-unchecked allow_redefinition = True -exclude = test_dynamic_indexing\.py +exclude = test_dynamic_(indexing|assignment)\.py diff --git a/frontends/concrete-python/tests/compilation/test_circuit.py b/frontends/concrete-python/tests/compilation/test_circuit.py index 2ef2d3d9d3..ec05693680 100644 --- a/frontends/concrete-python/tests/compilation/test_circuit.py +++ b/frontends/concrete-python/tests/compilation/test_circuit.py @@ -408,7 +408,9 @@ def test_circuit_run_with_unused_arg(helpers): def f(x, y): # pylint: disable=unused-argument return x + 10 - inputset = [(np.random.randint(2**3, 2**4), np.random.randint(2**4, 2**5)) for _ in range(100)] + inputset = [ + (np.random.randint(2**3, 2**4), np.random.randint(2**4, 2**5)) for _ in range(100) + ] circuit = f.compile(inputset, configuration) with pytest.raises(ValueError, match="Expected 2 inputs but got 1"): diff --git a/frontends/concrete-python/tests/execution/test_dynamic_assignment.py b/frontends/concrete-python/tests/execution/test_dynamic_assignment.py new file mode 100644 index 0000000000..298877be15 --- /dev/null +++ b/frontends/concrete-python/tests/execution/test_dynamic_assignment.py @@ -0,0 +1,496 @@ +""" +Tests of execution of dynamic assignment operation. +""" + +import random + +import numpy as np +import pytest + +from concrete import fhe + + +@pytest.mark.parametrize( + "dtype,index,value_status,value", + [ + pytest.param( + fhe.tensor[fhe.int6, 5], + (lambda _: np.random.randint(0, 5),), + "clear", + 42, + id="x[i] = 42 where x.shape = (5,) | 0 < i < 5", + ), + pytest.param( + fhe.tensor[fhe.int6, 5], + (lambda _: np.random.randint(-5, 5),), + "clear", + 42, + id="x[i] = 42 where x.shape = (5,) | -5 < i < 5", + ), + pytest.param( + fhe.tensor[fhe.int6, 5], + (lambda _: np.random.randint(0, 5),), + "encrypted", + lambda _: np.random.randint(-10, 10), + id="x[i] = y where x.shape = (5,) | 0 < i < 5 | -10 < y < 10", + ), + pytest.param( + fhe.tensor[fhe.int6, 5], + (lambda _: np.random.randint(-5, 5),), + "encrypted", + lambda _: np.random.randint(-10, 10), + id="x[i] = y where x.shape = (5,) | -5 < i < 5 | -10 < y < 10", + ), + pytest.param( + fhe.tensor[fhe.int6, 50], + (lambda _: np.random.randint(0, 5),), + "clear", + 42, + id="x[i] = 42 where x.shape = (50,) | 0 < i < 5", + ), + pytest.param( + fhe.tensor[fhe.int6, 50], + (lambda _: np.random.randint(-5, 5),), + "clear", + 42, + id="x[i] = 42 where x.shape = (50,) | -5 < i < 5", + ), + pytest.param( + fhe.tensor[fhe.int6, 50], + (lambda _: np.random.randint(0, 5),), + "encrypted", + lambda _: np.random.randint(-10, 10), + id="x[i] = y where x.shape = (50,) | 0 < i < 5 | -10 < y < 10", + ), + pytest.param( + fhe.tensor[fhe.int6, 50], + (lambda _: np.random.randint(-5, 5),), + "encrypted", + lambda _: np.random.randint(-10, 10), + id="x[i] = y where x.shape = (50,) | -5 < i < 5 | -10 < y < 10", + ), + pytest.param( + fhe.tensor[fhe.int6, 5, 3], + (lambda _: np.random.randint(0, 5), 0), + "clear", + 42, + id="x[i, 0] = 42 where x.shape = (5, 3) | 0 < i < 5", + ), + pytest.param( + fhe.tensor[fhe.int6, 5, 3], + (lambda _: np.random.randint(-5, 5), 0), + "clear", + 42, + id="x[i, 0] = 42 where x.shape = (5, 3) | -5 < i < 5", + ), + pytest.param( + fhe.tensor[fhe.int6, 5, 3], + (lambda _: np.random.randint(0, 5), 0), + "encrypted", + lambda _: np.random.randint(-10, 10), + id="x[i, 0] = y where x.shape = (5, 3) | 0 < i < 5 | -10 < y < 10", + ), + pytest.param( + fhe.tensor[fhe.int6, 5, 3], + (lambda _: np.random.randint(-5, 5), 0), + "encrypted", + lambda _: np.random.randint(-10, 10), + id="x[i, 0] = y where x.shape = (5, 3) | -5 < i < 5 | -10 < y < 10", + ), + pytest.param( + fhe.tensor[fhe.int6, 3, 5], + (1, lambda _: np.random.randint(0, 5)), + "clear", + 42, + id="x[1, i] = 42 where x.shape = (5, 3) | 0 < i < 5", + ), + pytest.param( + fhe.tensor[fhe.int6, 3, 5], + (1, lambda _: np.random.randint(-5, 5)), + "clear", + 42, + id="x[1, i] = 42 where x.shape = (5, 3) | -5 < i < 5", + ), + pytest.param( + fhe.tensor[fhe.int6, 3, 5], + (1, lambda _: np.random.randint(0, 5)), + "encrypted", + lambda _: np.random.randint(-10, 10), + id="x[1, i] = y where x.shape = (5, 3) | 0 < i < 5 | -10 < y < 10", + ), + pytest.param( + fhe.tensor[fhe.int6, 3, 5], + (1, lambda _: np.random.randint(-5, 5)), + "encrypted", + lambda _: np.random.randint(-10, 10), + id="x[1, i] = y where x.shape = (5, 3) | -5 < i < 5 | -10 < y < 10", + ), + pytest.param( + fhe.tensor[fhe.int6, 5, 3], + (lambda _: np.random.randint(0, 5), lambda _: np.random.randint(0, 3)), + "clear", + 42, + id="x[i, j] = 42 where x.shape = (5, 3) | 0 < i < 5 | 0 < j < 3", + ), + pytest.param( + fhe.tensor[fhe.int6, 5, 3], + (lambda _: np.random.randint(0, 5), lambda _: np.random.randint(0, 3)), + "encrypted", + lambda _: np.random.randint(-10, 10), + id="x[i, j] = y where x.shape = (5, 3) | 0 < i < 5 | 0 < j < 3 | -10 < y < 10", + ), + pytest.param( + fhe.tensor[fhe.int6, 5, 3], + (lambda _: np.random.randint(0, 5), lambda _: np.random.randint(-3, 3)), + "clear", + 42, + id="x[i, j] = 42 where x.shape = (5, 3) | 0 < i < 5 | -3 < j < 3", + ), + pytest.param( + fhe.tensor[fhe.int6, 5, 3], + (lambda _: np.random.randint(0, 5), lambda _: np.random.randint(-3, 3)), + "encrypted", + lambda _: np.random.randint(-10, 10), + id="x[i, j] = y where x.shape = (5, 3) | 0 < i < 5 | -3 < j < 3 | -10 < y < 10", + ), + pytest.param( + fhe.tensor[fhe.int6, 5, 3], + (lambda _: np.random.randint(-5, 5), lambda _: np.random.randint(0, 3)), + "clear", + 42, + id="x[i, j] = 42 where x.shape = (5, 3) | -5 < i < 5 | 0 < j < 3", + ), + pytest.param( + fhe.tensor[fhe.int6, 5, 3], + (lambda _: np.random.randint(-5, 5), lambda _: np.random.randint(0, 3)), + "encrypted", + lambda _: np.random.randint(-10, 10), + id="x[i, j] = y where x.shape = (5, 3) | -5 < i < 5 | 0 < j < 3 | -10 < y < 10", + ), + pytest.param( + fhe.tensor[fhe.int6, 5, 3], + (lambda _: np.random.randint(-5, 5), lambda _: np.random.randint(-3, 3)), + "clear", + 42, + id="x[i, j] = 42 where x.shape = (5, 3) | -5 < i < 5 | -3 < j < 3", + ), + pytest.param( + fhe.tensor[fhe.int6, 5, 3], + (lambda _: np.random.randint(-5, 5), lambda _: np.random.randint(-3, 3)), + "encrypted", + lambda _: np.random.randint(-10, 10), + id="x[i, j] = y where x.shape = (5, 3) | -5 < i < 5 | -3 < j < 3 | -10 < y < 10", + ), + pytest.param( + fhe.tensor[fhe.int6, 5, 3], + (lambda _: np.random.randint(0, 5),), + "clear", + 42, + id="x[i] = 42 where x.shape = (5, 3) | 0 < i < 5", + ), + pytest.param( + fhe.tensor[fhe.int6, 5, 3], + (lambda _: np.random.randint(0, 5),), + "encrypted", + lambda _: np.random.randint(-10, 10), + id="x[i] = y where x.shape = (5, 3) | 0 < i < 5 | -10 < y < 10", + ), + pytest.param( + fhe.tensor[fhe.int6, 5, 3], + (lambda _: np.random.randint(-5, 5),), + "clear", + 42, + id="x[i] = 42 where x.shape = (5, 3) | -5 < i < 5", + ), + pytest.param( + fhe.tensor[fhe.int6, 5, 3], + (lambda _: np.random.randint(-5, 5),), + "encrypted", + lambda _: np.random.randint(-10, 10), + id="x[i] = y where x.shape = (5, 3) | -5 < i < 5 | -10 < y < 10", + ), + pytest.param( + fhe.tensor[fhe.int6, 5, 3], + (lambda _: np.random.randint(0, 5),), + "clear", + [10, 20, 30], + id="x[i] = [10, 20, 30] where x.shape = (5, 3) | 0 < i < 5", + ), + pytest.param( + fhe.tensor[fhe.int6, 5, 3], + (lambda _: np.random.randint(0, 5),), + "encrypted", + lambda _: np.random.randint(-10, 10, size=(3,)), + id="x[i] = y where x.shape = (5, 3) | 0 < i < 5 | -10 < y < 10 | y.shape = (3,)", + ), + pytest.param( + fhe.tensor[fhe.int6, 5, 3], + (lambda _: np.random.randint(-5, 5),), + "clear", + [10, 20, 30], + id="x[i] = [10, 20, 30] where x.shape = (5, 3) | -5 < i < 5", + ), + pytest.param( + fhe.tensor[fhe.int6, 5, 3], + (lambda _: np.random.randint(-5, 5),), + "encrypted", + lambda _: np.random.randint(-10, 10, size=(3,)), + id="x[i] = y where x.shape = (5, 3) | -5 < i < 5 | -10 < y < 10 | y.shape = (3,)", + ), + pytest.param( + fhe.tensor[fhe.int6, 3, 5], + (slice(None, None, None), lambda _: np.random.randint(0, 5)), + "clear", + 42, + id="x[:, i] = 42 where x.shape = (3, 5) | 0 < i < 5", + ), + pytest.param( + fhe.tensor[fhe.int6, 3, 5], + (slice(None, None, None), lambda _: np.random.randint(0, 5)), + "encrypted", + lambda _: np.random.randint(-10, 10), + id="x[:, i] = y where x.shape = (3, 5) | 0 < i < 5 | -10 < y < 10", + ), + pytest.param( + fhe.tensor[fhe.int6, 3, 5], + (slice(None, None, None), lambda _: np.random.randint(-5, 5)), + "clear", + 42, + id="x[:, i] = 42 where x.shape = (3, 5) | -5 < i < 5", + ), + pytest.param( + fhe.tensor[fhe.int6, 3, 5], + (slice(None, None, None), lambda _: np.random.randint(-5, 5)), + "encrypted", + lambda _: np.random.randint(-10, 10), + id="x[:, i] = y where x.shape = (3, 5) | -5 < i < 5 | -10 < y < 10", + ), + pytest.param( + fhe.tensor[fhe.int6, 3, 5], + (slice(None, None, None), lambda _: np.random.randint(0, 5)), + "clear", + [10, 20, 30], + id="x[:, i] = [10, 20, 30] where x.shape = (3, 5) | 0 < i < 5", + ), + pytest.param( + fhe.tensor[fhe.int6, 3, 5], + (slice(None, None, None), lambda _: np.random.randint(0, 5)), + "encrypted", + lambda _: np.random.randint(-10, 10, size=(3,)), + id="x[:, i] = y where x.shape = (3, 5) | 0 < i < 5 | -10 < y < 10 | y.shape = (3,)", + ), + pytest.param( + fhe.tensor[fhe.int6, 3, 5], + (slice(None, None, None), lambda _: np.random.randint(-5, 5)), + "clear", + [10, 20, 30], + id="x[:, i] = [10, 20, 30] where x.shape = (3, 5) | -5 < i < 5", + ), + pytest.param( + fhe.tensor[fhe.int6, 3, 5], + (slice(None, None, None), lambda _: np.random.randint(-5, 5)), + "encrypted", + lambda _: np.random.randint(-10, 10, size=(3,)), + id="x[:, i] = y where x.shape = (3, 5) | -5 < i < 5 | -10 < y < 10 | y.shape = (3,)", + ), + pytest.param( + fhe.tensor[fhe.int6, 10, 9, 8], + (slice(1, 3, None), lambda _: np.random.randint(0, 9), slice(4, 6, None)), + "clear", + 42, + id="x[1:3, i, 4:6] = 42 where x.shape = (10, 9, 8) | 0 < i < 9", + ), + pytest.param( + fhe.tensor[fhe.int6, 10, 9, 8], + (slice(1, 3, None), lambda _: np.random.randint(-9, 9), slice(4, 6, None)), + "clear", + 42, + id="x[1:3, i, 4:6] = 42 where x.shape = (10, 9, 8) | -9 < i < 9", + ), + pytest.param( + fhe.tensor[fhe.int6, 10, 9, 8], + ( + lambda _: np.random.randint(0, 10), + slice(2, 5, None), + lambda _: np.random.randint(0, 8), + ), + "clear", + 42, + id="x[i, 2:5, j] = 42 where x.shape = (10, 9, 8) | 0 < i < 10 | 0 < j < 8", + ), + pytest.param( + fhe.tensor[fhe.int6, 5], + (lambda _: np.random.randint(0, 5, size=(3,)),), + "clear", + 42, + id="x[i] = 42 where x.shape = (5,) | 0 < i < 5 | i.shape = (3,)", + ), + pytest.param( + fhe.tensor[fhe.int6, 5], + (lambda _: np.random.randint(-5, 5, size=(3,)),), + "clear", + 42, + id="x[i] = 42 where x.shape = (5,) | -5 < i < 5 | i.shape = (3,)", + ), + pytest.param( + fhe.tensor[fhe.int6, 5], + (lambda _: np.random.randint(0, 5, size=(3,)),), + "encrypted", + lambda _: np.random.randint(-10, 10), + id="x[i] = y where x.shape = (5,) | 0 < i < 5 | i.shape = (3,) | -10 < y < 10", + ), + pytest.param( + fhe.tensor[fhe.int6, 5], + (lambda _: np.random.randint(-5, 5, size=(3,)),), + "encrypted", + lambda _: np.random.randint(-10, 10), + id="x[i] = y where x.shape = (5,) | -5 < i < 5 | i.shape = (3,) | -10 < y < 10", + ), + pytest.param( + fhe.tensor[fhe.int6, 5], + (lambda _: np.random.randint(0, 5, size=(3,)),), + "clear", + [10, 20, 30], + id="x[i] = [10, 20, 30] where x.shape = (5,) | 0 < i < 5 | i.shape = (3,)", + ), + pytest.param( + fhe.tensor[fhe.int6, 5], + (lambda _: np.random.randint(-5, 5, size=(3,)),), + "clear", + [10, 20, 30], + id="x[i] = [10, 20, 30] where x.shape = (5,) | -5 < i < 5 | i.shape = (3,)", + ), + pytest.param( + fhe.tensor[fhe.int6, 5], + (lambda _: np.random.randint(0, 5, size=(3,)),), + "encrypted", + lambda _: np.random.randint(-10, 10, size=(3,)), + id=( + "x[i] = y where x.shape = (5,) " + "| 0 < i < 5 | i.shape = (3,) " + "| -10 < y < 10 | y.shape = (3,)" + ), + ), + pytest.param( + fhe.tensor[fhe.int6, 5], + (lambda _: np.random.randint(-5, 5, size=(3,)),), + "encrypted", + lambda _: np.random.randint(-10, 10, size=(3,)), + id=( + "x[i] = y where x.shape = (5,) " + "| -5 < i < 5 | i.shape = (3,) " + "| -10 < y < 10 | y.shape = (3,)" + ), + ), + pytest.param( + fhe.tensor[fhe.int6, 20], + (lambda _: np.random.randint(0, 20, size=(3, 2)),), + "clear", + [[10, 11], [20, 21], [30, 31]], + id="x[i] = [[10, 11], [20, 21], [30, 31]] where x.shape = (20,) | 0 < i < 20 | i.shape = (3, 2)", + ), + pytest.param( + fhe.tensor[fhe.int6, 20], + (lambda _: np.random.randint(0, 20, size=(3, 2)),), + "clear", + [42, 24], + id="x[i] = [42, 24] where x.shape = (20,) | 0 < i < 20 | i.shape = (3, 2)", + ), + ], +) +def test_dynamic_assignment(dtype, index, value_status, value, helpers): + """ + Test dynamic assignment. + """ + + dynamic_index_positions = [] + dynamic_indices = [] + + for position, indexing_element in enumerate(index): + if callable(indexing_element): + dynamic_indices.append(indexing_element) + dynamic_index_positions.append(position) + + processed_index = list(index) + + def f(tensor, *args, value): + cursor = 0 + for position in dynamic_index_positions: + processed_index[position] = args[cursor] + cursor += 1 + + tensor[tuple(processed_index)] = value + return tensor + + if len(dynamic_index_positions) == 1: + if callable(value): + + def function(tensor, i0, value=value): + return f(tensor, i0, value=value) + + else: + + def function(tensor, i0): + return f(tensor, i0, value=value) + + elif len(dynamic_index_positions) == 2: + if callable(value): + + def function(tensor, i0, i1, value=value): + return f(tensor, i0, i1, value=value) + + else: + + def function(tensor, i0, i1): + return f(tensor, i0, i1, value=value) + + elif len(dynamic_index_positions) == 3: + if callable(value): + + def function(tensor, i0, i1, i2, value=value): + return f(tensor, i0, i1, i2, value=value) + + else: + + def function(tensor, i0, i1, i2): + return f(tensor, i0, i1, i2, value=value) + + elif len(dynamic_index_positions) == 4: + if callable(value): + + def function(tensor, i0, i1, i2, i3, value=value): + return f(tensor, i0, i1, i2, i3, value=value) + + else: + + def function(tensor, i0, i1, i2, i3): + return f(tensor, i0, i1, i2, i3, value=value) + + else: + message = ( + f"expected at least 1 at most 4 dynamic indexing elements " + f"but got {len(dynamic_index_positions)}" + ) + raise RuntimeError(message) + + encryption_status = {"tensor": "encrypted"} + inputset_types = [dtype] + + cursor = 0 + for indexing_element in index: + if callable(indexing_element): + encryption_status[f"i{cursor}"] = "clear" + inputset_types.append(indexing_element) + cursor += 1 + if callable(value): + encryption_status[f"value"] = value_status + inputset_types.append(value) + + configuration = helpers.configuration() + compiler = fhe.Compiler(function, encryption_status) + + inputset = fhe.inputset(*inputset_types) + circuit = compiler.compile(inputset, configuration, show_mlir=True) + + for sample in random.sample(inputset, 8): + helpers.check_execution(circuit, function, list(sample))