feat(frontend-python): approximate mode for round_bit_pattern
rudy-6-4 committed Feb 26, 2024
1 parent bda568a commit 3b50c70
Showing 20 changed files with 645 additions and 36 deletions.
Binary file added docs/_static/rounding/approximate-speedup.png
9 changes: 9 additions & 0 deletions docs/howto/configure.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,12 @@ Additional kwargs to `compile` functions take higher precedence. So if you set t
* Chunk size of the ReLU extension when [fhe.bits](../tutorial/bit_extraction.md) implementation is used.
* **if_then_else_chunk_size**: int = 3
* Chunk size to use when converting `fhe.if_then_else` extension.
* **rounding_exactness** : Exactness = `fhe.Exactness.EXACT`
* Set default exactness mode for the rounding operation:
* `EXACT`: threshold for rounding up or down is exactly centered between upper and lower value,
* `APPROXIMATE`: faster but threshold for rounding up or down is approximately centered with pseudo-random shift.
* Precise and more complete behavior is described in [fhe.round_bit_pattern](../tutorial/rounding.md).
* **approximate_rounding_config** : ApproximateRoundingConfig = `fhe.ApproximateRoundingConfig()`
* Provide finer control over [approximate rounding](../tutorial/rounding.md#approximate-rounding-features):
* to enable exact clipping,
* and/or approximate clipping, which makes overflow protection faster.
71 changes: 68 additions & 3 deletions docs/tutorial/rounding.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ prints:

and displays:

![](../\_static/rounding/identity.png)
![](../_static/rounding/identity.png)

{% hint style="info" %}
If the rounded number is one of the last `2**(lsbs_to_remove - 1)` numbers in the input range `[0, 2**original_bit_width)`, an overflow **will** happen.
Expand Down Expand Up @@ -194,7 +194,7 @@ The reason why the speed-up is not increasing with `lsbs_to_remove` is because t

and displays:

![](../\_static/rounding/lsbs_to_remove.png)
![](../_static/rounding/lsbs_to_remove.png)

{% hint style="info" %}
Feel free to disable overflow protection and see what happens.
Expand Down Expand Up @@ -289,8 +289,73 @@ target_msbs=1 => 2.34x speedup

and displays:

![](../\_static/rounding/msbs_to_keep.png)
![](../_static/rounding/msbs_to_keep.png)

{% hint style="warning" %}
`AutoRounder`s should be defined outside the function that is being compiled. They are used to store the result of the adjustment process, so they shouldn't be created each time the function is called. Furthermore, each `AutoRounder` should be used with exactly one `round_bit_pattern` call.
{% endhint %}


## Exactness

One use of rounding is to speed up computation by ignoring the least significant bits.
For this use, you can get even faster results if you accept that the rounding itself is slightly inexact.
The speedup is usually around 2x-3x but can be higher for large precision reductions.
This also enables higher precision values that are not possible otherwise.

| ![approximate-speedup.png](../_static/rounding/approximate-speedup.png) |
|:--:|
| *Using the default configuration in approximate mode, for 3, 4, 5 and 6 reduced precision bits and accumulator precision up to 32 bits.* |



You can turn on this mode either globally on the configuration:
```python
from concrete import fhe

configuration = fhe.Configuration(
    # ... other options ...
    rounding_exactness=fhe.Exactness.APPROXIMATE,
)
```
or on/off locally:
```python
v = fhe.round_bit_pattern(v, lsbs_to_remove=2, exactness=fhe.Exactness.APPROXIMATE)
v = fhe.round_bit_pattern(v, lsbs_to_remove=2, exactness=fhe.Exactness.EXACT)
```
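
The `exactness` argument of `round_bit_pattern` defaults to `None`, which means the global `rounding_exactness` applies. The precedence can be sketched as follows. This is only an illustration: the `Exactness` class below is a stand-in that mirrors `fhe.Exactness` so the snippet runs without Concrete, and `resolve_exactness` is a hypothetical helper, not part of the library.

```python
from enum import IntEnum
from typing import Optional


class Exactness(IntEnum):
    """Stand-in mirroring fhe.Exactness, so the sketch runs without Concrete."""

    EXACT = 0
    APPROXIMATE = 1


def resolve_exactness(local: Optional[Exactness], global_default: Exactness) -> Exactness:
    """Hypothetical helper: a local setting wins, None falls back to the global one."""
    return global_default if local is None else local


# No local setting: the global configuration decides.
assert resolve_exactness(None, Exactness.APPROXIMATE) is Exactness.APPROXIMATE
# A local setting overrides the global one, in either direction.
assert resolve_exactness(Exactness.EXACT, Exactness.APPROXIMATE) is Exactness.EXACT
assert resolve_exactness(Exactness.APPROXIMATE, Exactness.EXACT) is Exactness.APPROXIMATE
```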

In approximate mode, the threshold for rounding up or down is not perfectly centered.
The off-centering:
* is bounded, i.e. at worst an off-by-one on the reduced precision value compared to the exact result,
* is pseudo-random, i.e. it will be different on each call,
* is almost symmetrically distributed,
* depends on cryptographic properties like the encryption mask, the encryption noise and the crypto-parameters.
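
The bounded, off-by-one nature of the error can be illustrated with a toy model in plain Python. This is a sketch under assumptions, not the library's implementation: the real off-centering comes from the encryption mask and noise, whereas here it is simulated with a seeded pseudo-random shift, and `round_approximate_toy` and `max_shift` are hypothetical names.

```python
import random


def round_exact(x: int, lsbs_to_remove: int) -> int:
    """Round x to the nearest multiple of 2**lsbs_to_remove (ties go up)."""
    if lsbs_to_remove == 0:
        return x
    half = 1 << (lsbs_to_remove - 1)
    return ((x + half) >> lsbs_to_remove) << lsbs_to_remove


def round_approximate_toy(
    x: int, lsbs_to_remove: int, rng: random.Random, max_shift: int
) -> int:
    """Toy model only: move the threshold by a random shift in [-max_shift, max_shift]."""
    if lsbs_to_remove == 0:
        return x
    half = 1 << (lsbs_to_remove - 1)
    if not 0 <= max_shift <= half:
        raise ValueError("max_shift must be between 0 and half a rounding step")
    shift = rng.randint(-max_shift, max_shift)
    return ((x + half + shift) >> lsbs_to_remove) << lsbs_to_remove


rng = random.Random(0)  # seeded stand-in for the mask- and noise-dependent shift
lsbs = 3
unit = 1 << lsbs

# Collect every difference between the toy approximate result and the exact one.
errors = {
    round_approximate_toy(x, lsbs, rng, max_shift=2) - round_exact(x, lsbs)
    for x in range(256)
}

# The shift is smaller than one rounding step, so the error is at most one step.
assert errors <= {-unit, 0, unit}
```

Inputs far from a rounding threshold give the same result in both modes; only inputs close to a threshold can land on the neighbouring step.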

| ![approximate-off-by-one-error.png](../_static/rounding/approximate-off-by-one-error.png) |
|:--:|
| *In blue the exact value, the red dots are approximate values due to off-centered transition in approximate mode.* |

| ![approximate-off-centering-distribution.png](../_static/rounding/approximate-off-centering-distribution.png) |
|:--:|
| *Histogram of transitions off-centering delta. Each count corresponds to a specific random mask and a specific encryption noise.* |

## Approximate rounding features

With approximate rounding, you can enable approximate clipping to further improve performance when overflow is handled. Approximate clipping makes it possible to discard the extra overflow-protection bit in the successor TLU. For consistency, logical clipping is available when this optimization is not suitable.
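
As a sketch of how these options might be set together, the snippet below builds a configuration using only names introduced in this commit (`fhe.Exactness`, `fhe.ApproximateRoundingConfig`, `rounding_exactness`, `approximate_rounding_config`). It assumes a Concrete version that includes this change, the values shown are simply the defaults, and the wrapper function `build_approximate_configuration` is hypothetical.

```python
def build_approximate_configuration():
    """Build a Configuration with approximate rounding and explicit clipping options."""
    # Imported lazily so this module can be loaded where Concrete is not installed.
    from concrete import fhe

    return fhe.Configuration(
        # Use approximate rounding by default for every round_bit_pattern call.
        rounding_exactness=fhe.Exactness.APPROXIMATE,
        approximate_rounding_config=fhe.ApproximateRoundingConfig(
            # Keep logical clipping on (the default).
            logical_clipping=True,
            # Enable approximate clipping from 5 bits of precision (the default).
            approximate_clipping_start_precision=5,
        ),
    )
```

Individual `round_bit_pattern` calls can still opt back into exact rounding with `exactness=fhe.Exactness.EXACT`.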

### Logical clipping

When fast approximate clipping is not suitable (i.e. slower), it's better to apply logical clipping for consistency and better resilience to code changes.
It has no extra cost since it is fused with the successor TLU.

| ![logical-clipping.png](../_static/rounding/approximate-off-by-one-error-logical-clipping.png) |
|:--:|
| *Only the last step is clipped.* |
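
The effect of logical clipping on the rounded values can be sketched with a toy model in plain Python. This assumes clipping simply saturates at the largest value representable on `k` bits. The helper names are hypothetical, and in Concrete the clipping is fused into the successor TLU rather than computed as a separate step.

```python
def round_to_units(x: int, lsbs_to_remove: int) -> int:
    """Exact rounding, expressed as a count of 2**lsbs_to_remove units."""
    half = (1 << lsbs_to_remove) >> 1
    return (x + half) >> lsbs_to_remove


def clip_to_precision(units: int, k: int) -> int:
    """Toy logical clipping: saturate at the largest value that fits in k bits."""
    return min(units, (1 << k) - 1)


input_bits, lsbs_to_remove = 8, 5
k = input_bits - lsbs_to_remove  # ideal output precision: 3 bits

# Without clipping, the last 2**(lsbs_to_remove - 1) inputs round up to 2**k,
# which needs k + 1 bits.
assert round_to_units(240, lsbs_to_remove) == 1 << k

# With clipping, only that last step is changed and everything fits in k bits.
steps = [
    clip_to_precision(round_to_units(x, lsbs_to_remove), k)
    for x in range(1 << input_bits)
]
assert max(steps) == (1 << k) - 1
```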


### Approximate clipping

`approximate_clipping_start_precision` sets the first precision at which approximate clipping is enabled. Starting from this precision, an extra small-precision TLU is introduced to safely remove the extra precision bit used to contain overflow, which makes the successor TLU faster.
E.g. for a rounding to 7 bits that ends in an 8-bit TLU due to overflow, forcing the use of a 7-bit TLU is 3x faster.

| ![approximate-clipping.png](../_static/rounding/approximate-off-by-one-error-approx-clipping.png) |
|:--:|
| *The last steps are decreased.* |
2 changes: 2 additions & 0 deletions frontends/concrete-python/concrete/fhe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .compilation import (
DEFAULT_GLOBAL_P_ERROR,
DEFAULT_P_ERROR,
ApproximateRoundingConfig,
BitwiseStrategy,
Circuit,
Client,
Expand All @@ -18,6 +19,7 @@
Configuration,
DebugArtifacts,
EncryptionStatus,
Exactness,
Keys,
MinMaxStrategy,
MultiParameterStrategy,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
from .configuration import (
DEFAULT_GLOBAL_P_ERROR,
DEFAULT_P_ERROR,
ApproximateRoundingConfig,
BitwiseStrategy,
ComparisonStrategy,
Configuration,
Exactness,
MinMaxStrategy,
MultiParameterStrategy,
MultivariateStrategy,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
"""

import platform
from enum import Enum
from dataclasses import dataclass
from enum import Enum, IntEnum
from pathlib import Path
from typing import List, Optional, Tuple, Union, get_type_hints

Expand Down Expand Up @@ -72,6 +73,54 @@ def parse(cls, string: str) -> "MultiParameterStrategy":
raise ValueError(message)


class Exactness(IntEnum):
"""
Exactness, to specify the implementation preference for specific operators (default and local).
"""

EXACT = 0
APPROXIMATE = 1


@dataclass
class ApproximateRoundingConfig:
"""
Controls the behavior of approximate rounding.
In the following `k` is the ideal rounding output precision.
Often the precision used after rounding is `k`+1 to avoid overflow.
`logical_clipping`, `approximate_clipping_start_precision` can be used to stay at precision `k`,
either logically or physically at the successor TLU.
See examples in https://github.com/zama-ai/concrete/blob/main/docs/tutorial/rounding.md.
"""

logical_clipping: bool = True
"""
Enable logical clipping to simulate a precision `k` in the successor TLU of precision `k`+1.
"""

approximate_clipping_start_precision: int = 5
"""Actively avoid the overflow using a `k`-1 precision TLU.
This is similar to logical clipping but less accurate and faster.
Effect on:
* accuracy: the upper values of the rounding range are slightly decreased,
* cost: adds an extra `k`-1 bits TLU to guarantee that the precision after rounding is `k`.
This is usually a win when `k` >= 5 .
This is enabled by default for `k` >= 5.
Due to the extra inaccuracy and cost, it is possible to disable it completely using False."""

reduce_precision_after_approximate_clipping: bool = True
"""Enable the reduction to `k` bits in the TLU.
Can be disabled for debugging/testing purposes.
When disabled along with logical_clipping, the result of approximate clipping is accessible.
"""

symetrize_deltas: bool = True
"""Enable symmetrization of the correction of deltas w.r.t. the exact rounding computation.
Can be disabled for debugging/testing purposes.
"""


class ComparisonStrategy(str, Enum):
"""
ComparisonStrategy, to specify implementation preference for comparisons.
Expand Down Expand Up @@ -931,6 +980,8 @@ class Configuration:
relu_on_bits_threshold: int
relu_on_bits_chunk_size: int
if_then_else_chunk_size: int
rounding_exactness: Exactness
approximate_rounding_config: ApproximateRoundingConfig

def __init__(
self,
Expand Down Expand Up @@ -987,6 +1038,8 @@ def __init__(
relu_on_bits_threshold: int = 7,
relu_on_bits_chunk_size: int = 3,
if_then_else_chunk_size: int = 3,
rounding_exactness: Exactness = Exactness.EXACT,
approximate_rounding_config: Optional[ApproximateRoundingConfig] = None,
):
self.verbose = verbose
self.compiler_debug_mode = compiler_debug_mode
Expand Down Expand Up @@ -1069,6 +1122,10 @@ def __init__(
self.relu_on_bits_threshold = relu_on_bits_threshold
self.relu_on_bits_chunk_size = relu_on_bits_chunk_size
self.if_then_else_chunk_size = if_then_else_chunk_size
self.rounding_exactness = rounding_exactness
self.approximate_rounding_config = (
approximate_rounding_config or ApproximateRoundingConfig()
)

self._validate()

Expand Down Expand Up @@ -1129,6 +1186,8 @@ def fork(
relu_on_bits_threshold: Union[Keep, int] = KEEP,
relu_on_bits_chunk_size: Union[Keep, int] = KEEP,
if_then_else_chunk_size: Union[Keep, int] = KEEP,
rounding_exactness: Union[Keep, Exactness] = KEEP,
approximate_rounding_config: Union[Keep, Optional[ApproximateRoundingConfig]] = KEEP,
) -> "Configuration":
"""
Get a new configuration from another one with specified changes.
Expand Down
5 changes: 3 additions & 2 deletions frontends/concrete-python/concrete/fhe/compilation/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

# mypy: disable-error-code=attr-defined
import concrete.compiler
import jsonpickle
from concrete.compiler import (
CompilationContext,
CompilationFeedback,
Expand Down Expand Up @@ -253,7 +254,7 @@ def save(self, path: Union[str, Path], via_mlir: bool = False):
f.write("1" if self.is_simulated else "0")

with open(Path(tmp) / "configuration.json", "w", encoding="utf-8") as f:
f.write(json.dumps(self._configuration.__dict__))
f.write(jsonpickle.dumps(self._configuration.__dict__))

shutil.make_archive(path, "zip", tmp)

Expand Down Expand Up @@ -300,7 +301,7 @@ def load(path: Union[str, Path]) -> "Server":
mlir = f.read()

with open(output_dir_path / "configuration.json", "r", encoding="utf-8") as f:
configuration = Configuration().fork(**json.load(f))
configuration = Configuration().fork(**jsonpickle.loads(f.read()))

return Server.create(mlir, configuration, is_simulated)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@

import threading
from copy import deepcopy
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np

from ..compilation.configuration import Exactness
from ..dtypes import Integer
from ..mlir.utils import MAXIMUM_TLU_BIT_WIDTH
from ..representation import Node
Expand Down Expand Up @@ -158,6 +159,7 @@ def round_bit_pattern(
x: Union[int, np.integer, List, np.ndarray, Tracer],
lsbs_to_remove: Union[int, AutoRounder],
overflow_protection: bool = True,
exactness: Optional[Exactness] = None,
) -> Union[int, np.integer, List, np.ndarray, Tracer]:
"""
Round the bit pattern of an integer.
Expand Down Expand Up @@ -212,6 +214,11 @@ def round_bit_pattern(
overflow_protection (bool, default = True)
whether to adjust bit widths and lsbs to remove to avoid overflows
exactness (Optional[Exactness], default = None)
select the exactness of the operation, None means use the global exactness.
The global exactness default is EXACT.
It can be changed on the Configuration object.
Returns:
Union[int, np.integer, np.ndarray, Tracer]:
Tracer that represents the operation during tracing
Expand Down Expand Up @@ -240,6 +247,8 @@ def round_bit_pattern(
def evaluator(
x: Union[int, np.integer, np.ndarray],
lsbs_to_remove: int,
overflow_protection: bool, # pylint: disable=unused-argument
exactness: Optional[Exactness], # pylint: disable=unused-argument
) -> Union[int, np.integer, np.ndarray]:
if lsbs_to_remove == 0:
return x
Expand All @@ -255,8 +264,11 @@ def evaluator(
[deepcopy(x.output)],
deepcopy(x.output),
evaluator,
kwargs={"lsbs_to_remove": lsbs_to_remove},
attributes={"overflow_protection": overflow_protection},
kwargs={
"lsbs_to_remove": lsbs_to_remove,
"overflow_protection": overflow_protection,
"exactness": exactness,
},
)
return Tracer(computation, [x])

Expand All @@ -276,6 +288,6 @@ def evaluator(
message = f"Expected input to be an int or a numpy array but it's {type(x).__name__}"
raise TypeError(message)

return evaluator(x, lsbs_to_remove)
return evaluator(x, lsbs_to_remove, overflow_protection, exactness)

# pylint: enable=protected-access,too-many-branches