diff --git a/docs/_static/rounding/approximate-off-by-one-error-approx-clipping.png b/docs/_static/rounding/approximate-off-by-one-error-approx-clipping.png new file mode 100644 index 0000000000..9364d8f4af Binary files /dev/null and b/docs/_static/rounding/approximate-off-by-one-error-approx-clipping.png differ diff --git a/docs/_static/rounding/approximate-off-by-one-error-logical-clipping.png b/docs/_static/rounding/approximate-off-by-one-error-logical-clipping.png new file mode 100644 index 0000000000..ca70ea3aa0 Binary files /dev/null and b/docs/_static/rounding/approximate-off-by-one-error-logical-clipping.png differ diff --git a/docs/_static/rounding/approximate-off-by-one-error.png b/docs/_static/rounding/approximate-off-by-one-error.png new file mode 100644 index 0000000000..3983d3409c Binary files /dev/null and b/docs/_static/rounding/approximate-off-by-one-error.png differ diff --git a/docs/_static/rounding/approximate-off-centering-distribution.png b/docs/_static/rounding/approximate-off-centering-distribution.png new file mode 100644 index 0000000000..447fba6f5c Binary files /dev/null and b/docs/_static/rounding/approximate-off-centering-distribution.png differ diff --git a/docs/_static/rounding/approximate-speedup.png b/docs/_static/rounding/approximate-speedup.png new file mode 100644 index 0000000000..164ba7e058 Binary files /dev/null and b/docs/_static/rounding/approximate-speedup.png differ diff --git a/docs/howto/configure.md b/docs/howto/configure.md index 9d2fc74896..616ef274b6 100644 --- a/docs/howto/configure.md +++ b/docs/howto/configure.md @@ -118,3 +118,12 @@ Additional kwargs to `compile` functions take higher precedence. So if you set t * Chunk size of the ReLU extension when [fhe.bits](../tutorial/bit_extraction.md) implementation is used. * **if_then_else_chunk_size**: int = 3 * Chunk size to use when converting `fhe.if_then_else` extension. +* **rounding_exactness** : Exactness = `fhe.Exactness.EXACT` + * Set default exactness mode for the rounding operation: + * `EXACT`: threshold for rounding up or down is exactly centered between upper and lower value, + * `APPROXIMATE`: faster but threshold for rounding up or down is approximately centered with pseudo-random shift. + * Precise and more complete behavior is described in [fhe.rounding_bit_pattern](../tutorial/rounding.md). +* **approximate_rounding_config** : ApproximateRoundingConfig = `fhe.ApproximateRoundingConfig()`: + * Provide more fine control on [approximate rounding](../tutorial/rounding.md#approximate-rounding-features): + * to enable exact cliping, + * or/and approximate clipping which make overflow protection faster. diff --git a/docs/tutorial/rounding.md b/docs/tutorial/rounding.md index 74967852b1..948b9b6669 100644 --- a/docs/tutorial/rounding.md +++ b/docs/tutorial/rounding.md @@ -97,7 +97,7 @@ prints: and displays: -![](../\_static/rounding/identity.png) +![](../_static/rounding/identity.png) {% hint style="info" %} If the rounded number is one of the last `2**(lsbs_to_remove - 1)` numbers in the input range `[0, 2**original_bit_width)`, an overflow **will** happen. @@ -194,7 +194,7 @@ The reason why the speed-up is not increasing with `lsbs_to_remove` is because t and displays: -![](../\_static/rounding/lsbs_to_remove.png) +![](../_static/rounding/lsbs_to_remove.png) {% hint style="info" %} Feel free to disable overflow protection and see what happens. @@ -289,8 +289,73 @@ target_msbs=1 => 2.34x speedup and displays: -![](../\_static/rounding/msbs_to_keep.png) +![](../_static/rounding/msbs_to_keep.png) {% hint style="warning" %} `AutoRounder`s should be defined outside the function that is being compiled. They are used to store the result of the adjustment process, so they shouldn't be created each time the function is called. Furthermore, each `AutoRounder` should be used with exactly one `round_bit_pattern` call. {% endhint %} + + +## Exactness + +One use of rounding is doing faster computation by ignoring the lower significant bits. +For this usage, you can even get faster results if you accept the rounding it-self to be slighlty inexact. +The speedup is usually around 2x-3x but can be higher for big precision reduction. +This also enable higher precisions values that are not possible otherwise. + +| ![approximate-speedup.png](../_static/rounding/approximate-speedup.png) | +|:--:| +| *Using the default configuration in approximate mode. For 3, 4, 5 and 6 reduced precision bits and accumulator precision up to 32bits | + + + +You can turn on this mode either globally on the configuration: +```python +configuration = fhe.Configuration( + ... + rounding_exactness=fhe.Exactness.APPROXIMATE +) +``` +or on/off locally: +```python +v = fhe.round_bit_pattern(v, lsbs_to_remove=2, exactness=fhe.Exactness.APPROXIMATE) +v = fhe.round_bit_pattern(v, lsbs_to_remove=2, exactness=fhe.Exactness.EXACT) +``` + +In approximate mode the rounding threshold up or down is not perfectly centered: +The off-centering is: +* is bounded, i.e. at worst an off-by-one on the reduced precision value compared to the exact result, +* is pseudo-random, i.e. it will be different on each call, +* almost symetrically distributed, +* depends on cryptographic properties like the encryption mask, the encryption noise and the crypto-parameters. + +| ![approximate-off-by-one-error.png](../_static/rounding/approximate-off-by-one-error.png) | +|:--:| +| *In blue the exact value, the red dots are approximate values due to off-centered transition in approximate mode.* | + +| ![approximate-off-centering-distribution.png](../_static/rounding/approximate-off-centering-distribution.png) | +|:--:| +| *Histogram of transitions off-centering delta. Each count correspond to a specific random mask and a specific encryption noise.* | + +## Approximate rounding features + +With approximate rounding, you can enable an approximate clipping to get further improve performance in the case of overflow handling. Approximate clipping enable to discard the extra bit of overflow protection bit in the successor TLU. For consistency a logical clipping is available when this optimization is not suitable. + +### Logical clipping + +When fast approximate clipping is not suitable (i.e. slower), it's better to apply logical clipping for consistency and better resilience to code change. +It has no extra cost since it's fuzed with the successor TLU. + +| ![logical-clipping.png](../_static/rounding/approximate-off-by-one-error-logical-clipping.png) | +|:--:| +| *Only the last step is clipped.* | + + +### Approximate clipping + +This set the first precision where approximate clipping is enabled, starting from this precision, an extra small precision TLU is introduced to safely remove the extra precision bit used to contain overflow. This way the successor TLU is faster. +E.g. for a rounding to 7bits, that finishes to a TLU of 8bits due to overflow, forcing to use a TLU of 7bits is 3x faster. + +| ![approximate-clipping.png](../_static/rounding/approximate-off-by-one-error-approx-clipping.png) | +|:--:| +| *The last steps are decreased.* | diff --git a/frontends/concrete-python/concrete/fhe/__init__.py b/frontends/concrete-python/concrete/fhe/__init__.py index b381037422..6e091114a5 100644 --- a/frontends/concrete-python/concrete/fhe/__init__.py +++ b/frontends/concrete-python/concrete/fhe/__init__.py @@ -9,6 +9,7 @@ from .compilation import ( DEFAULT_GLOBAL_P_ERROR, DEFAULT_P_ERROR, + ApproximateRoundingConfig, BitwiseStrategy, Circuit, Client, @@ -18,6 +19,7 @@ Configuration, DebugArtifacts, EncryptionStatus, + Exactness, Keys, MinMaxStrategy, MultiParameterStrategy, diff --git a/frontends/concrete-python/concrete/fhe/compilation/__init__.py b/frontends/concrete-python/concrete/fhe/compilation/__init__.py index 2043e751ef..5f134c2a99 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/__init__.py +++ b/frontends/concrete-python/concrete/fhe/compilation/__init__.py @@ -9,9 +9,11 @@ from .configuration import ( DEFAULT_GLOBAL_P_ERROR, DEFAULT_P_ERROR, + ApproximateRoundingConfig, BitwiseStrategy, ComparisonStrategy, Configuration, + Exactness, MinMaxStrategy, MultiParameterStrategy, MultivariateStrategy, diff --git a/frontends/concrete-python/concrete/fhe/compilation/configuration.py b/frontends/concrete-python/concrete/fhe/compilation/configuration.py index 6065b58f09..ada658bd9f 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/configuration.py +++ b/frontends/concrete-python/concrete/fhe/compilation/configuration.py @@ -3,7 +3,8 @@ """ import platform -from enum import Enum +from dataclasses import dataclass +from enum import Enum, IntEnum from pathlib import Path from typing import List, Optional, Tuple, Union, get_type_hints @@ -72,6 +73,54 @@ def parse(cls, string: str) -> "MultiParameterStrategy": raise ValueError(message) +class Exactness(IntEnum): + """ + Exactness, to specify for specific operator the implementation preference (default and local). + """ + + EXACT = 0 + APPROXIMATE = 1 + + +@dataclass +class ApproximateRoundingConfig: + """ + Controls the behavior of approximate rounding. + + In the following `k` is the ideal rounding output precision. + Often the precision used after rounding is `k`+1 to avoid overflow. + `logical_clipping`, `approximate_clipping_start_precision` can be used to stay at precision `k`, + either logically or physically at the successor TLU. + See examples in https://github.com/zama-ai/concrete/blob/main/docs/tutorial/rounding.md. + """ + + logical_clipping: bool = True + """ + Enable logical clipping to simulate a precision `k` in the successor TLU of precision `k`+1. + """ + + approximate_clipping_start_precision: int = 5 + """Actively avoid the overflow using a `k`-1 precision TLU. + This is similar to logical clipping but less accurate and faster. + Effect on: + * accuracy: the upper values of the rounding range are sligtly decreased, + * cost: adds an extra `k`-1 bits TLU to guarantee that the precision after rounding is `k`. + This is usually a win when `k` >= 5 . + This is enabled by default for `k` >= 5. + Due to the extra inaccuracy and cost, it is possible to disable it completely using False.""" + + reduce_precision_after_approximate_clipping: bool = True + """Enable the reduction to `k` bits in the TLU. + Can be disabled for debugging/testing purposes. + When disabled along with logical_clipping, the result of approximate clipping is accessible. + """ + + symetrize_deltas: bool = True + """Enable asymetry of correction of deltas w.r.t. the exact rounding computation. + Can be disabled for debugging/testing purposes. + """ + + class ComparisonStrategy(str, Enum): """ ComparisonStrategy, to specify implementation preference for comparisons. @@ -931,6 +980,8 @@ class Configuration: relu_on_bits_threshold: int relu_on_bits_chunk_size: int if_then_else_chunk_size: int + rounding_exactness: Exactness + approximate_rounding_config: ApproximateRoundingConfig def __init__( self, @@ -987,6 +1038,8 @@ def __init__( relu_on_bits_threshold: int = 7, relu_on_bits_chunk_size: int = 3, if_then_else_chunk_size: int = 3, + rounding_exactness: Exactness = Exactness.EXACT, + approximate_rounding_config: Optional[ApproximateRoundingConfig] = None, ): self.verbose = verbose self.compiler_debug_mode = compiler_debug_mode @@ -1069,6 +1122,10 @@ def __init__( self.relu_on_bits_threshold = relu_on_bits_threshold self.relu_on_bits_chunk_size = relu_on_bits_chunk_size self.if_then_else_chunk_size = if_then_else_chunk_size + self.rounding_exactness = rounding_exactness + self.approximate_rounding_config = ( + approximate_rounding_config or ApproximateRoundingConfig() + ) self._validate() @@ -1129,6 +1186,8 @@ def fork( relu_on_bits_threshold: Union[Keep, int] = KEEP, relu_on_bits_chunk_size: Union[Keep, int] = KEEP, if_then_else_chunk_size: Union[Keep, int] = KEEP, + rounding_exactness: Union[Keep, Exactness] = KEEP, + approximate_rounding_config: Union[Keep, Optional[ApproximateRoundingConfig]] = KEEP, ) -> "Configuration": """ Get a new configuration from another one specified changes. diff --git a/frontends/concrete-python/concrete/fhe/compilation/server.py b/frontends/concrete-python/concrete/fhe/compilation/server.py index 34755781ec..bc47a41b04 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/server.py +++ b/frontends/concrete-python/concrete/fhe/compilation/server.py @@ -12,6 +12,7 @@ # mypy: disable-error-code=attr-defined import concrete.compiler +import jsonpickle from concrete.compiler import ( CompilationContext, CompilationFeedback, @@ -253,7 +254,7 @@ def save(self, path: Union[str, Path], via_mlir: bool = False): f.write("1" if self.is_simulated else "0") with open(Path(tmp) / "configuration.json", "w", encoding="utf-8") as f: - f.write(json.dumps(self._configuration.__dict__)) + f.write(jsonpickle.dumps(self._configuration.__dict__)) shutil.make_archive(path, "zip", tmp) @@ -300,7 +301,7 @@ def load(path: Union[str, Path]) -> "Server": mlir = f.read() with open(output_dir_path / "configuration.json", "r", encoding="utf-8") as f: - configuration = Configuration().fork(**json.load(f)) + configuration = Configuration().fork(**jsonpickle.loads(f.read())) return Server.create(mlir, configuration, is_simulated) diff --git a/frontends/concrete-python/concrete/fhe/extensions/round_bit_pattern.py b/frontends/concrete-python/concrete/fhe/extensions/round_bit_pattern.py index 48a3d9be27..c67380831d 100644 --- a/frontends/concrete-python/concrete/fhe/extensions/round_bit_pattern.py +++ b/frontends/concrete-python/concrete/fhe/extensions/round_bit_pattern.py @@ -4,10 +4,11 @@ import threading from copy import deepcopy -from typing import Any, Callable, Dict, Iterable, List, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import numpy as np +from ..compilation.configuration import Exactness from ..dtypes import Integer from ..mlir.utils import MAXIMUM_TLU_BIT_WIDTH from ..representation import Node @@ -158,6 +159,7 @@ def round_bit_pattern( x: Union[int, np.integer, List, np.ndarray, Tracer], lsbs_to_remove: Union[int, AutoRounder], overflow_protection: bool = True, + exactness: Optional[Exactness] = None, ) -> Union[int, np.integer, List, np.ndarray, Tracer]: """ Round the bit pattern of an integer. @@ -212,6 +214,11 @@ def round_bit_pattern( overflow_protection (bool, default = True) whether to adjust bit widths and lsbs to remove to avoid overflows + exactness (Optional[Exactness], default = None) + select the exactness of the operation, None means use the global exactness. + The global exactnessdefault is EXACT. + It can be changed on the Configuration object. + Returns: Union[int, np.integer, np.ndarray, Tracer]: Tracer that respresents the operation during tracing @@ -240,6 +247,8 @@ def round_bit_pattern( def evaluator( x: Union[int, np.integer, np.ndarray], lsbs_to_remove: int, + overflow_protection: bool, # pylint: disable=unused-argument + exactness: Optional[Exactness], # pylint: disable=unused-argument ) -> Union[int, np.integer, np.ndarray]: if lsbs_to_remove == 0: return x @@ -255,8 +264,11 @@ def evaluator( [deepcopy(x.output)], deepcopy(x.output), evaluator, - kwargs={"lsbs_to_remove": lsbs_to_remove}, - attributes={"overflow_protection": overflow_protection}, + kwargs={ + "lsbs_to_remove": lsbs_to_remove, + "overflow_protection": overflow_protection, + "exactness": exactness, + }, ) return Tracer(computation, [x]) @@ -276,6 +288,6 @@ def evaluator( message = f"Expected input to be an int or a numpy array but it's {type(x).__name__}" raise TypeError(message) - return evaluator(x, lsbs_to_remove) + return evaluator(x, lsbs_to_remove, overflow_protection, exactness) # pylint: enable=protected-access,too-many-branches diff --git a/frontends/concrete-python/concrete/fhe/mlir/context.py b/frontends/concrete-python/concrete/fhe/mlir/context.py index 2d8f339381..0903e2d0ad 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/context.py +++ b/frontends/concrete-python/concrete/fhe/mlir/context.py @@ -28,6 +28,7 @@ BitwiseStrategy, ComparisonStrategy, Configuration, + Exactness, MinMaxStrategy, ) from ..dtypes import Integer @@ -3144,6 +3145,8 @@ def round_bit_pattern( resulting_type: ConversionType, x: Conversion, lsbs_to_remove: int, + exactness: Exactness, + overflow_detected: bool, ) -> Conversion: if x.is_clear: highlights = { @@ -3154,19 +3157,121 @@ def round_bit_pattern( assert x.bit_width > lsbs_to_remove + if exactness is None: + exactness = self.configuration.rounding_exactness + + intermediate_bit_width = x.bit_width - lsbs_to_remove intermediate_type = self.typeof( ValueDescription( - dtype=Integer(is_signed=x.is_signed, bit_width=(x.bit_width - lsbs_to_remove)), + dtype=Integer(is_signed=x.is_signed, bit_width=intermediate_bit_width), shape=x.shape, is_encrypted=x.is_encrypted, ) ) + if exactness is Exactness.APPROXIMATE: + approx_conf = self.configuration.approximate_rounding_config + # 1. Unskew TLU's futur error distribution on approximated value + # this balances agains all leading zeros in the noise (ignoring symetric noise) + unskewed = x + if approx_conf.symetrize_deltas: + highest_supported_precision = 62 + delta_precision = highest_supported_precision - x.type.bit_width + full_precision = x.type.bit_width + delta_precision + half_in_extra_precision = ( + 1 << (delta_precision - 1) + ) - 1 # slightly smaller then half + half_in_extra_precision = self.constant( + self.i(full_precision + 1), half_in_extra_precision + ) + x_high_precision = self.reinterpret(x, bit_width=full_precision) + unskewed = self.add( + x_high_precision.type, x_high_precision, half_in_extra_precision + ) - rounded = self.operation( - fhe.RoundEintOp if x.is_scalar else fhelinalg.RoundOp, - intermediate_type, - x.result, - ) + # 2. Cancel overflow to have a TLU at exactly target_precision + # starting from 5 bits, the extra overflow bit in the TLU is too costly + # a smaller precision TLU can detect approximately the overflow to cancel it + # this is only possible because of the extra-bit from overflow protection + target_precision = x.bit_width - lsbs_to_remove - overflow_detected + if ( + overflow_detected + and target_precision >= approx_conf.approximate_clipping_start_precision + and approx_conf.approximate_clipping_start_precision is not False + ): + unskew_pre_overflow = self.reinterpret(unskewed, bit_width=x.type.bit_width) + overflow_precision = max(2, target_precision - 1) + # The last half-cell values in overflow_precision will naturally overflow. + # But there can also be an off by minus 1 to the previous cell in the worst case + # and an overflow in the successor TLU. + # We sliglty decrease the value of the rounding output on theses cells. + # `realign_cell_by` defines where the decrease starts to apply. + step_high = 1 << (x.type.bit_width - intermediate_bit_width) + step_wide = step_high + full_decrease_by = step_high + realign_cell_by = step_wide // 2 + realign_cell_by = self.constant(self.i(x.type.bit_width + 1), realign_cell_by) + overflow_candidate = self.sub( + unskew_pre_overflow.type, unskew_pre_overflow, realign_cell_by + ) + overflow_candidate = self.reinterpret( + overflow_candidate, bit_width=overflow_precision + ) + half_tlu_size = 2 ** (overflow_precision - 1) + if x.is_signed: + negative_size = half_tlu_size + positive_size = negative_size + used_positive_size = half_tlu_size // 2 + # this is oriented for precision higher than 3 + # it will work with smaller precision but with more invasive effects + prevent_overflow_positive = ( + # pre-overflow + [0] * used_positive_size + # overflow part + + [3 * full_decrease_by // 4, full_decrease_by] + # unused + + [0] * (positive_size - used_positive_size - 2) + )[:half_tlu_size] + prevent_overflow = prevent_overflow_positive + [0] * negative_size + else: + prevent_overflow = ( + # pre-overflow + [0] * half_tlu_size + # overflow part + + [3 * full_decrease_by // 4, full_decrease_by] + # unused + + [0] * (half_tlu_size - 2) + )[: 2 * half_tlu_size] + signed_type = self.to_signed(x).type + overflow_cancel = self.reinterpret( + self.tlu( + signed_type, + overflow_candidate, + table=prevent_overflow, + ), + bit_width=x.type.bit_width, + signed=x.is_signed, + ) + unskewed = self.sub(unskew_pre_overflow.type, unskew_pre_overflow, overflow_cancel) + if approx_conf.reduce_precision_after_approximate_clipping: + # a minimum bitwith 3 is required to multiply by 2 in signed case + if unskewed.bit_width < 3: + # pragma: no-cover + self.reinterpret(unskewed, bit_width=3) + unskewed = self.mul( + unskewed.type, unskewed, self.constant(self.i(unskewed.bit_width + 1), 2) + ) + rounded = self.reinterpret(unskewed, bit_width=intermediate_type.bit_width - 1) + # The TLU after may be adjusted to the right precision (see `Converter.tlu`) + else: + rounded = self.reinterpret(unskewed, bit_width=intermediate_type.bit_width) + else: + rounded = self.reinterpret(unskewed, bit_width=intermediate_type.bit_width) + else: + rounded = self.operation( + fhe.RoundEintOp if x.is_scalar else fhelinalg.RoundOp, + intermediate_type, + x.result, + ) return self.to_signedness(rounded, of=resulting_type) @@ -3594,13 +3699,16 @@ def truncate_bit_pattern(self, x: Conversion, lsbs_to_remove: int) -> Conversion return x - def reinterpret(self, x: Conversion, *, bit_width: int) -> Conversion: + def reinterpret( + self, x: Conversion, *, bit_width: int, signed: Optional[bool] = None + ) -> Conversion: assert x.is_encrypted if x.bit_width == bit_width: return x - resulting_element_type = (self.eint if x.is_unsigned else self.esint)(bit_width) + result_signed = x.is_unsigned if signed is None else signed + resulting_element_type = (self.eint if result_signed else self.esint)(bit_width) resulting_type = self.tensor(resulting_element_type, shape=x.shape) operation = ( diff --git a/frontends/concrete-python/concrete/fhe/mlir/converter.py b/frontends/concrete-python/concrete/fhe/mlir/converter.py index 66f9f59e7e..abee8adc39 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/converter.py +++ b/frontends/concrete-python/concrete/fhe/mlir/converter.py @@ -4,6 +4,7 @@ # pylint: disable=import-error,no-name-in-module +import math import sys from typing import Dict, List, Tuple, Union @@ -17,8 +18,7 @@ from mlir.ir import Location as MlirLocation from mlir.ir import Module as MlirModule -from concrete.fhe.compilation.configuration import Configuration - +from ..compilation.configuration import Configuration, Exactness from ..representation import Graph, Node, Operation from .context import Context from .conversion import Conversion @@ -195,7 +195,9 @@ def process(self, graph: Graph, configuration: Configuration): multivariate_strategy_preference=configuration.multivariate_strategy_preference, min_max_strategy_preference=configuration.min_max_strategy_preference, ), - ProcessRounding(), + ProcessRounding( + rounding_exactness=configuration.rounding_exactness, + ), ] for processor in pipeline: @@ -485,9 +487,9 @@ def round_bit_pattern(self, ctx: Context, node: Node, preds: List[Conversion]) - assert len(preds) == 1 pred = preds[0] + overflow_detected = node.properties["overflow_detected"] if pred.is_encrypted and pred.bit_width != pred.original_bit_width: overflow_protection = node.properties["overflow_protection"] - overflow_detected = node.properties["overflow_detected"] shifter = 2 ** (pred.bit_width - pred.original_bit_width) if overflow_protection and overflow_detected: @@ -500,6 +502,8 @@ def round_bit_pattern(self, ctx: Context, node: Node, preds: List[Conversion]) - ctx.typeof(node), pred, node.properties["final_lsbs_to_remove"], + node.properties["exactness"], + overflow_detected, ) def subtract(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion: @@ -531,6 +535,61 @@ def squeeze(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversi # otherwise, a simple reshape would work as we already have the correct shape return ctx.reshape(preds[0], shape=node.output.shape) + @classmethod + def tlu_adjust(cls, table, variable_input, target_bit_width, clipping, reduce_precision): + target_bit_width = min( + variable_input.bit_width, target_bit_width + ) # inconsistency due to more precise bound vs precision + table_bit_width = math.log2(len(table)) + assert table_bit_width.is_integer() + table_bit_width = int(table_bit_width) + table_has_right_size = variable_input.bit_width == table_bit_width + if table_has_right_size and not clipping: + return table + half_rounded_bit_width = target_bit_width - 1 + if variable_input.is_signed: + # upper = positive part, lower = negative part + upper_clipping_index = 2**half_rounded_bit_width - 1 + lower_clipping_index = 2**table_bit_width - 2**half_rounded_bit_width + positive_clipped_card = 2 ** (table_bit_width - 1) - upper_clipping_index - 1 + negative_clipped_card = 2 ** (table_bit_width - 1) - 2**half_rounded_bit_width + else: + upper_clipping_index = 2**target_bit_width - 1 + lower_clipping_index = 0 + positive_clipped_card = 2**table_bit_width - upper_clipping_index - 1 + lower_clipping = table[lower_clipping_index] + upper_clipping = table[upper_clipping_index] + if table_has_right_size: + # value clipping + assert clipping + if variable_input.is_signed: + table = ( + list(table[: upper_clipping_index + 1]) + + [upper_clipping] * positive_clipped_card + + [lower_clipping] * negative_clipped_card + + list(table[lower_clipping_index:]) + ) + else: + table = ( + list(table[lower_clipping_index : upper_clipping_index + 1]) + + [upper_clipping] * positive_clipped_card + ) + assert len(table) == 2**table_bit_width, ( + len(table), + 2**table_bit_width, + table, + upper_clipping, + lower_clipping, + ) + return np.array(table, dtype=np.uint64) # negative value are in unsigned representation + + # adjust tlu size + assert reduce_precision + if variable_input.is_signed: + return np.concatenate((table[: upper_clipping_index + 1], table[lower_clipping_index:])) + + return table[lower_clipping_index : upper_clipping_index + 1] + def tlu(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion: assert node.converted_to_table_lookup @@ -654,6 +713,33 @@ def tlu(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion: variable_input = ctx.mul(variable_input.type, variable_input, shifter) variable_input = ctx.reinterpret(variable_input, bit_width=truncated_bit_width) + elif variable_input.origin.properties.get("name") == "round_bit_pattern": + exactness = ( + variable_input.origin.properties["exactness"] + or ctx.configuration.rounding_exactness + ) + if exactness == Exactness.APPROXIMATE: + # we clip values to enforce input precision exactly as queried + original_bit_width = variable_input.origin.properties["original_bit_width"] + lsbs_to_remove = variable_input.origin.properties["kwargs"]["lsbs_to_remove"] + overflow = variable_input.origin.properties["overflow_detected"] + rounded_bit_width = original_bit_width - lsbs_to_remove - overflow + approx_config = ctx.configuration.approximate_rounding_config + clipping = approx_config.logical_clipping + reduce_precision = approx_config.reduce_precision_after_approximate_clipping + if len(tables) == 1: + lut_values = self.tlu_adjust( + lut_values, variable_input, rounded_bit_width, clipping, reduce_precision + ) + else: + for sub_i, sub_lut_values in enumerate(lut_values): + lut_values[sub_i] = self.tlu_adjust( + sub_lut_values, + variable_input, + rounded_bit_width, + clipping, + reduce_precision, + ) if len(tables) == 1: return ctx.tlu(ctx.typeof(node), on=variable_input, table=lut_values.tolist()) diff --git a/frontends/concrete-python/concrete/fhe/mlir/processors/assign_bit_widths.py b/frontends/concrete-python/concrete/fhe/mlir/processors/assign_bit_widths.py index 3a997f184a..ee927e2c0c 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/processors/assign_bit_widths.py +++ b/frontends/concrete-python/concrete/fhe/mlir/processors/assign_bit_widths.py @@ -231,7 +231,7 @@ def some_inputs_are_clear(self, node: Node, preds: List[Node]) -> bool: return any(pred.output.is_clear for pred in preds) def has_overflow_protection(self, node: Node, preds: List[Node]) -> bool: - return node.properties["attributes"]["overflow_protection"] is True + return node.properties["kwargs"]["overflow_protection"] is True # =========== # Constraints diff --git a/frontends/concrete-python/concrete/fhe/mlir/processors/process_rounding.py b/frontends/concrete-python/concrete/fhe/mlir/processors/process_rounding.py index 6aace8bdab..c115075697 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/processors/process_rounding.py +++ b/frontends/concrete-python/concrete/fhe/mlir/processors/process_rounding.py @@ -8,6 +8,7 @@ import numpy as np +from ...compilation.configuration import Exactness from ...dtypes import Integer from ...extensions.table import LookupTable from ...representation import Graph, Node @@ -19,6 +20,14 @@ class ProcessRounding(GraphProcessor): ProcessRounding graph processor, to analyze rounding and support regular operations on it. """ + rounding_exactness: Exactness + + def __init__( + self, + rounding_exactness: Exactness, + ): + self.rounding_exactness = rounding_exactness + def apply(self, graph: Graph): rounding_nodes = graph.query_nodes(operation_filter="round_bit_pattern") for node in rounding_nodes: @@ -27,8 +36,13 @@ def apply(self, graph: Graph): original_lsbs_to_remove = node.properties["kwargs"]["lsbs_to_remove"] final_lsbs_to_remove = node.properties["final_lsbs_to_remove"] + exactness = node.properties["exactness"] + if exactness is None: + exactness = self.rounding_exactness + if original_lsbs_to_remove != 0 and final_lsbs_to_remove == 0: - self.replace_with_tlu(graph, node) + if exactness != Exactness.APPROXIMATE: + self.replace_with_tlu(graph, node) continue self.process_successors(graph, node) @@ -44,12 +58,14 @@ def process_predecessors(self, graph: Graph, node: Node): pred = preds[0] assert isinstance(pred.output.dtype, Integer) - overflow_protection = node.properties["attributes"]["overflow_protection"] + exactness = node.properties["kwargs"]["exactness"] + overflow_protection = node.properties["kwargs"]["overflow_protection"] overflow_detected = ( overflow_protection and pred.properties["original_bit_width"] != node.properties["original_bit_width"] ) + node.properties["exactness"] = exactness node.properties["overflow_protection"] = overflow_protection node.properties["overflow_detected"] = overflow_detected diff --git a/frontends/concrete-python/mypy.ini b/frontends/concrete-python/mypy.ini index 35159d2f17..62bf5c4508 100644 --- a/frontends/concrete-python/mypy.ini +++ b/frontends/concrete-python/mypy.ini @@ -1,3 +1,4 @@ [mypy] plugins = numpy.typing.mypy_plugin disable_error_code = annotation-unchecked +allow_redefinition = True diff --git a/frontends/concrete-python/requirements.txt b/frontends/concrete-python/requirements.txt index 364f7d1927..1d612353fd 100644 --- a/frontends/concrete-python/requirements.txt +++ b/frontends/concrete-python/requirements.txt @@ -1,4 +1,5 @@ importlib-resources>=6.1 +jsonpickle>=3.0.3 networkx>=2.6 numpy>=1.23 scipy>=1.10 diff --git a/frontends/concrete-python/tests/execution/test_round_bit_pattern.py b/frontends/concrete-python/tests/execution/test_round_bit_pattern.py index 964472aa06..045b3327a0 100644 --- a/frontends/concrete-python/tests/execution/test_round_bit_pattern.py +++ b/frontends/concrete-python/tests/execution/test_round_bit_pattern.py @@ -6,6 +6,7 @@ import pytest from concrete import fhe +from concrete.fhe.compilation.configuration import Exactness from concrete.fhe.representation.utils import format_constant @@ -432,14 +433,14 @@ def function3(x): helpers.check_str( f""" -%0 = x # EncryptedScalar -%1 = round_bit_pattern(%0, lsbs_to_remove=3) # EncryptedScalar -%2 = tlu(%1, table={table3_formatted_string}) # EncryptedScalar -%3 = round_bit_pattern(%2, lsbs_to_remove=2) # EncryptedScalar -%4 = tlu(%3, table={table4_formatted_string}) # EncryptedScalar +%0 = x # EncryptedScalar +%1 = round_bit_pattern(%0, lsbs_to_remove=3, overflow_protection=True, exactness=None) # EncryptedScalar +%2 = tlu(%1, table={table3_formatted_string}) # EncryptedScalar +%3 = round_bit_pattern(%2, lsbs_to_remove=2, overflow_protection=True, exactness=None) # EncryptedScalar +%4 = tlu(%3, table={table4_formatted_string}) # EncryptedScalar return %4 - """, + """, # noqa: E501 str(circuit3.graph.format(show_bounds=False)), ) @@ -611,3 +612,249 @@ def function(x): for x in inputset: helpers.check_execution(circuit, function, x, retries=3) + + +def test_round_bit_pattern_approximate_enabling(helpers): + """ + Test round bit pattern various activation paths. + """ + + @fhe.compiler({"x": "encrypted"}) + def function_default(x): + return fhe.round_bit_pattern(x, lsbs_to_remove=8) + + @fhe.compiler({"x": "encrypted"}) + def function_exact(x): + return fhe.round_bit_pattern(x, lsbs_to_remove=8, exactness=Exactness.EXACT) + + @fhe.compiler({"x": "encrypted"}) + def function_approx(x): + return fhe.round_bit_pattern(x, lsbs_to_remove=8, exactness=Exactness.APPROXIMATE) + + inputset = [-(2**10), 2**10 - 1] + configuration = helpers.configuration() + + circuit_default_default = function_default.compile(inputset, configuration) + circuit_default_exact = function_default.compile( + inputset, configuration.fork(rounding_exactness=Exactness.EXACT) + ) + circuit_default_approx = function_default.compile( + inputset, configuration.fork(rounding_exactness=Exactness.APPROXIMATE) + ) + circuit_exact = function_exact.compile( + inputset, configuration.fork(rounding_exactness=Exactness.APPROXIMATE) + ) + circuit_approx = function_approx.compile( + inputset, configuration.fork(rounding_exactness=Exactness.EXACT) + ) + + assert circuit_approx.complexity < circuit_exact.complexity + assert circuit_exact.complexity == circuit_default_default.complexity + assert circuit_exact.complexity == circuit_default_exact.complexity + assert circuit_approx.complexity == circuit_default_approx.complexity + + +@pytest.mark.parametrize( + "accumulator_precision,reduced_precision,signed,conf", + [ + (8, 4, True, fhe.ApproximateRoundingConfig(False, 4)), + (7, 4, False, fhe.ApproximateRoundingConfig(False, 4)), + (9, 3, True, fhe.ApproximateRoundingConfig(True, False)), + (8, 3, False, fhe.ApproximateRoundingConfig(True, False)), + (7, 3, False, fhe.ApproximateRoundingConfig(True, 3)), + (7, 2, True, fhe.ApproximateRoundingConfig(False, 2)), + (7, 2, False, fhe.ApproximateRoundingConfig(False, False, False, False)), + (8, 1, True, fhe.ApproximateRoundingConfig(False, 1)), + (8, 1, False, fhe.ApproximateRoundingConfig(True, False)), + (6, 5, False, fhe.ApproximateRoundingConfig(True, 6)), + (6, 5, False, fhe.ApproximateRoundingConfig(True, 5)), + ], +) +def test_round_bit_pattern_approximate_off_by_one_errors( + accumulator_precision, reduced_precision, signed, conf, helpers +): + """ + Test round bit pattern off by 1 errors. + """ + lsbs_to_remove = accumulator_precision - reduced_precision + + @fhe.compiler({"x": "encrypted"}) + def function(x): + x = fhe.univariate(lambda x: x)(x) + x = fhe.round_bit_pattern(x, lsbs_to_remove=lsbs_to_remove) + x = x // 2**lsbs_to_remove + return x + + if signed: + inputset = [-(2 ** (accumulator_precision - 1)), 2 ** (accumulator_precision - 1) - 1] + else: + inputset = [0, 2**accumulator_precision - 1] + + configuration = helpers.configuration() + circuit_exact = function.compile(inputset, configuration) + circuit_approx = function.compile( + inputset, + configuration.fork( + approximate_rounding_config=conf, rounding_exactness=Exactness.APPROXIMATE + ), + ) + # check it's better even with bad conf + assert circuit_approx.complexity < circuit_exact.complexity + + testset = range(*inputset) + + nb_error = 0 + for x in testset: + approx = circuit_approx.encrypt_run_decrypt(x) + approx_simu = circuit_approx.simulate(x) + exact = circuit_exact.simulate(x) + assert abs(approx_simu - exact) <= 1 + assert abs(approx_simu - approx) <= 1 + delta = abs(approx - approx_simu) + assert delta <= 1 + nb_error += delta > 0 + + nb_transitions = 2 ** (accumulator_precision - reduced_precision) + assert nb_error <= 3 * nb_transitions # of the same order as transitions but small sample size + + +@pytest.mark.parametrize( + "signed,physical", + [(signed, physical) for signed in (True, False) for physical in (True, False)], +) +def test_round_bit_pattern_approximate_clippping(signed, physical, helpers): + """ + Test round bit pattern clipping. + """ + accumulator_precision = 6 + reduced_precision = 3 + lsbs_to_remove = accumulator_precision - reduced_precision + + @fhe.compiler({"x": "encrypted"}) + def function(x): + x = fhe.univariate(lambda x: x)(x) + x = fhe.round_bit_pattern(x, lsbs_to_remove=lsbs_to_remove) + x = x // 2**lsbs_to_remove + return x + + if signed: + input_domain = range(-(2 ** (accumulator_precision - 1)), 2 ** (accumulator_precision - 1)) + else: + input_domain = range(0, 2 ** (accumulator_precision)) + + configuration = helpers.configuration() + approx_conf = fhe.ApproximateRoundingConfig( + logical_clipping=not physical, + approximate_clipping_start_precision=physical and reduced_precision, + reduce_precision_after_approximate_clipping=False, + ) + no_clipping_conf = fhe.ApproximateRoundingConfig( + logical_clipping=False, approximate_clipping_start_precision=False + ) + assert approx_conf.logical_clipping or approx_conf.approximate_clipping_start_precision + circuit_clipping = function.compile( + input_domain, + configuration.fork( + approximate_rounding_config=approx_conf, rounding_exactness=Exactness.APPROXIMATE + ), + ) + circuit_no_clipping = function.compile( + input_domain, + configuration.fork( + approximate_rounding_config=no_clipping_conf, rounding_exactness=Exactness.APPROXIMATE + ), + ) + + if signed: + clipped_output_domain = range(-(2 ** (reduced_precision - 1)), 2 ** (reduced_precision - 1)) + else: + clipped_output_domain = range(0, 2**reduced_precision) + + # With clipping + for x in input_domain: + assert ( + circuit_clipping.encrypt_run_decrypt(x) in clipped_output_domain + ), circuit_clipping.mlir # no overflow + assert circuit_clipping.simulate(x) in clipped_output_domain + + # Without clipping + # overflow + assert circuit_no_clipping.simulate(input_domain[-1]) not in clipped_output_domain + + +@pytest.mark.parametrize( + "signed,accumulator_precision", + [ + (signed, accumulator_precision) + for signed in (True, False) + for accumulator_precision in (13, 24) + ], +) +def test_round_bit_pattern_approximate_acc_to_6_costs(signed, accumulator_precision, helpers): + """ + Test round bit pattern speedup when approximatipn is activated. + """ + reduced_precision = 6 + lsbs_to_remove = accumulator_precision - reduced_precision + + @fhe.compiler({"x": "encrypted"}) + def function(x): + x = fhe.round_bit_pattern(x, lsbs_to_remove=lsbs_to_remove, overflow_protection=True) + x = x // 2**lsbs_to_remove + return x + + # with overflow + if signed: + input_domain = [-(2 ** (accumulator_precision - 1)), 2 ** (accumulator_precision - 1) - 1] + else: + input_domain = [0, 2 ** (accumulator_precision) - 1] + + configuration = helpers.configuration().fork(composable=True) + circuit_exact = function.compile(input_domain, configuration) + approx_conf_fastest = fhe.ApproximateRoundingConfig(approximate_clipping_start_precision=6) + approx_conf_safest = fhe.ApproximateRoundingConfig(approximate_clipping_start_precision=100) + circuit_approx_fastest = function.compile( + input_domain, + configuration.fork( + approximate_rounding_config=approx_conf_fastest, + rounding_exactness=Exactness.APPROXIMATE, + ), + ) + circuit_approx_safest = function.compile( + input_domain, + configuration.fork( + approximate_rounding_config=approx_conf_safest, rounding_exactness=Exactness.APPROXIMATE + ), + ) + assert circuit_approx_safest.complexity < circuit_exact.complexity + assert circuit_approx_fastest.complexity < circuit_approx_safest.complexity + + @fhe.compiler({"x": "encrypted"}) + def function(x): # pylint: disable=function-redefined + x = fhe.round_bit_pattern(x, lsbs_to_remove=lsbs_to_remove, overflow_protection=False) + x = x // 2**lsbs_to_remove + return x + + # without overflow + if signed: + input_domain = [-(2 ** (accumulator_precision - 1)), 2 ** (accumulator_precision - 2) - 2] + else: + input_domain = [0, 2 ** (accumulator_precision - 1) - 2] + + circuit_exact_no_ovf = function.compile(input_domain, configuration) + circuit_approx_fastest_no_ovf = function.compile( + input_domain, + configuration.fork( + approximate_rounding_config=approx_conf_fastest, + rounding_exactness=Exactness.APPROXIMATE, + ), + ) + circuit_approx_safest_no_ovf = function.compile( + input_domain, + configuration.fork( + approximate_rounding_config=approx_conf_safest, rounding_exactness=Exactness.APPROXIMATE + ), + ) + assert circuit_approx_fastest_no_ovf.complexity == circuit_approx_safest_no_ovf.complexity + assert circuit_approx_safest_no_ovf.complexity < circuit_exact_no_ovf.complexity + assert circuit_exact_no_ovf.complexity < circuit_exact.complexity diff --git a/frontends/concrete-python/tests/mlir/test_converter.py b/frontends/concrete-python/tests/mlir/test_converter.py index 99369f710e..7426934130 100644 --- a/frontends/concrete-python/tests/mlir/test_converter.py +++ b/frontends/concrete-python/tests/mlir/test_converter.py @@ -512,11 +512,11 @@ def assign(x, y): Function you are trying to compile cannot be compiled -%0 = x # ClearScalar ∈ [10, 30] -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ operand is clear -%1 = round_bit_pattern(%0, lsbs_to_remove=2) # ClearScalar ∈ [12, 32] -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but clear round bit pattern is not supported -%2 = reinterpret(%1) # ClearScalar +%0 = x # ClearScalar ∈ [10, 30] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ operand is clear +%1 = round_bit_pattern(%0, lsbs_to_remove=2, overflow_protection=True, exactness=None) # ClearScalar ∈ [12, 32] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but clear round bit pattern is not supported +%2 = reinterpret(%1) # ClearScalar return %2 """, # noqa: E501