Skip to content

Commit

Permalink
test(frontend): test tfhers integers
Browse files Browse the repository at this point in the history
  • Loading branch information
youben11 committed May 28, 2024
1 parent a3c8554 commit 6f35a8b
Show file tree
Hide file tree
Showing 5 changed files with 266 additions and 4 deletions.
5 changes: 5 additions & 0 deletions frontends/concrete-python/concrete/fhe/mlir/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,6 +894,11 @@ def tfhers_from_native(self, ctx: Context, node: Node, preds: List[Conversion])
# we reshape so that we can concatenate later over the last dim (ciphertext dim)
reshaped_native_int = ctx.reshape(native_int, native_int.shape + (1,))

# TODO: remove this when we want to optimize computation so that we don't compute
# on empty ciphertexts, based on the bit_width assignment. (e.g. if onlt two lsb
# ciphertexts are used, then we don't want to extract bits from the remaining ones)
reshaped_native_int.set_original_bit_width(input_bit_width)

# we want to extract `msg_width` bits at a time, and store them
# in a `msg_width + carry_width` bits eint
bits_shape = ctx.tensor(ctx.eint(msg_width + carry_width), reshaped_native_int.shape)
Expand Down
8 changes: 7 additions & 1 deletion frontends/concrete-python/concrete/fhe/tfhers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,10 @@
uint16_2_2,
)
from .tracing import from_native, to_native
from .values import int8_2_2_value, int16_2_2_value, uint8_2_2_value, uint16_2_2_value
from .values import (
TFHERSInteger,
int8_2_2_value,
int16_2_2_value,
uint8_2_2_value,
uint16_2_2_value,
)
58 changes: 57 additions & 1 deletion frontends/concrete-python/concrete/fhe/tfhers/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
"""

from functools import partial
from typing import Any
from typing import Any, Union

import numpy as np

from ..dtypes import Integer

Expand Down Expand Up @@ -35,6 +37,60 @@ def __str__(self) -> str:
f"{self.bit_width}, {self.carry_width}, {self.msg_width}>"
)

def encode(self, value: Union[int, np.integer, np.ndarray]) -> np.ndarray:
"""Encode a scalar or tensor to tfhers integers.
Args:
value (Union[int, np.ndarray]): scalar or tensor of integer to encode
Raises:
TypeError: wrong value type
Returns:
np.ndarray: encoded scalar or tensor
"""
bit_width = self.bit_width
msg_width = self.msg_width
if isinstance(value, (int, np.integer)):
value_bin = bin(value)[2:].zfill(bit_width)
# msb first
return np.array(
[int(value_bin[i : i + msg_width], 2) for i in range(0, bit_width, msg_width)]
)
if isinstance(value, np.ndarray):
return np.array([self.encode(int(v)) for v in value.flatten()]).reshape(
value.shape + (bit_width // msg_width,)
)
msg = f"can only encode int or ndarray, but got {type(value)}"
raise TypeError(msg)

def decode(self, value: np.ndarray) -> Union[int, np.ndarray]:
"""Decode a tfhers-encoded integer (scalar or tensor).
Args:
value (np.ndarray): encoded value
Raises:
ValueError: bad encoding
Returns:
Union[int, np.ndarray]: decoded value
"""
bit_width = self.bit_width
msg_width = self.msg_width
expected_ct_shape = bit_width // msg_width
if value.shape[-1] != expected_ct_shape:
msg = (
f"bad encoding: expected value with last shape being {expected_ct_shape} "
f"but got {value.shape[-1]}"
)
raise ValueError(msg)
if len(value.shape) == 1:
# reversed because it's msb first and we are computing powers lsb first
return sum(v << i * msg_width for i, v in enumerate(reversed(value)))
cts = value.reshape((-1, expected_ct_shape))
return np.array([self.decode(ct) for ct in cts]).reshape(value.shape[:-1])


int8 = partial(TFHERSIntegerType, True, 8)
uint8 = partial(TFHERSIntegerType, False, 8)
Expand Down
4 changes: 2 additions & 2 deletions frontends/concrete-python/concrete/fhe/tfhers/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(
msg = f"got error while trying to convert list value into a numpy array: {e}"
raise ValueError(msg) from e

if isinstance(value, int):
if isinstance(value, (int, np.integer)):
self._shape = ()
elif isinstance(value, np.ndarray):
if value.max() > dtype.max():
Expand All @@ -40,7 +40,7 @@ def __init__(
raise ValueError(msg)
self._shape = value.shape
else:
msg = "value can either be an int or ndarray"
msg = f"value can either be an int or ndarray, not a {type(value)}"
raise TypeError(msg)

self._value = value
Expand Down
195 changes: 195 additions & 0 deletions frontends/concrete-python/tests/execution/test_tfhers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
"""
Tests execution of tfhers conversion operations.
"""

import pytest

from concrete import fhe
from concrete.fhe import tfhers


def binary_tfhers(x, y, binary_op, tfhers_type):
"""wrap binary op in tfhers conversion (2 tfhers inputs)"""
x = tfhers.to_native(x)
y = tfhers.to_native(y)
return tfhers.from_native(binary_op(x, y), tfhers_type)


def one_tfhers_one_native(x, y, binary_op, tfhers_type):
"""wrap binary op in tfhers conversion (1 tfhers, 1 native input)"""
x = tfhers.to_native(x)
return tfhers.from_native(binary_op(x, y), tfhers_type)


@pytest.mark.parametrize(
"function, parameters, dtype",
[
pytest.param(
lambda x, y: x + y,
{
"x": {"range": [0, 2**14], "status": "encrypted"},
"y": {"range": [0, 2**14], "status": "encrypted"},
},
tfhers.uint16_2_2,
id="x + y",
),
pytest.param(
lambda x, y: x + y,
{
"x": {"range": [2**14, 2**15 - 1], "status": "encrypted"},
"y": {"range": [2**14, 2**15 - 1], "status": "encrypted"},
},
tfhers.uint16_2_2,
id="x + y big values",
),
pytest.param(
lambda x, y: x - y,
{
"x": {"range": [2**10, 2**14], "status": "encrypted"},
"y": {"range": [0, 2**10], "status": "encrypted"},
},
tfhers.uint16_2_2,
id="x - y",
),
pytest.param(
lambda x, y: x * y,
{
"x": {"range": [0, 2**3], "status": "encrypted"},
"y": {"range": [0, 2**3], "status": "encrypted"},
},
tfhers.uint8_2_2,
id="x * y",
),
],
)
def test_tfhers_conversion_binary_encrypted(
function, parameters, dtype: tfhers.TFHERSIntegerType, helpers
):
"""
Test different operations wrapped by tfhers conversion (2 tfhers inputs).
"""

parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
configuration = helpers.configuration()

compiler = fhe.Compiler(
lambda x, y: binary_tfhers(x, y, function, dtype),
parameter_encryption_statuses,
)

inputset = [
tuple(tfhers.TFHERSInteger(dtype, arg) for arg in inpt)
for inpt in helpers.generate_inputset(parameters)
]
circuit = compiler.compile(inputset, configuration)

sample = helpers.generate_sample(parameters)
encoded_sample = (dtype.encode(v) for v in sample)
encoded_result = circuit.encrypt_run_decrypt(*encoded_sample)

assert (dtype.decode(encoded_result) == function(*sample)).all()


@pytest.mark.parametrize(
"function, parameters, dtype",
[
pytest.param(
lambda x, y: x + y,
{
"x": {"range": [0, 2**14], "status": "encrypted"},
"y": {"range": [0, 2**14], "status": "encrypted"},
},
tfhers.uint16_2_2,
id="x + y",
),
pytest.param(
lambda x, y: x + y,
{
"x": {"range": [0, 2**14], "status": "encrypted"},
"y": {"range": [0, 2**14], "status": "clear"},
},
tfhers.uint16_2_2,
id="x + clear(y)",
),
pytest.param(
lambda x, y: x + y,
{
"x": {"range": [2**14, 2**15 - 1], "status": "encrypted"},
"y": {"range": [2**14, 2**15 - 1], "status": "encrypted"},
},
tfhers.uint16_2_2,
id="x + y big values",
),
pytest.param(
lambda x, y: x + y,
{
"x": {"range": [2**14, 2**15 - 1], "status": "encrypted"},
"y": {"range": [2**14, 2**15 - 1], "status": "clear"},
},
tfhers.uint16_2_2,
id="x + clear(y) big values",
),
pytest.param(
lambda x, y: x - y,
{
"x": {"range": [2**10, 2**14], "status": "encrypted"},
"y": {"range": [0, 2**10], "status": "encrypted"},
},
tfhers.uint16_2_2,
id="x - y",
),
pytest.param(
lambda x, y: x - y,
{
"x": {"range": [2**10, 2**14], "status": "encrypted"},
"y": {"range": [0, 2**10], "status": "clear"},
},
tfhers.uint16_2_2,
id="x - clear(y)",
),
pytest.param(
lambda x, y: x * y,
{
"x": {"range": [0, 2**3], "status": "encrypted"},
"y": {"range": [0, 2**3], "status": "encrypted"},
},
tfhers.uint8_2_2,
id="x * y",
),
pytest.param(
lambda x, y: x * y,
{
"x": {"range": [0, 2**3], "status": "encrypted"},
"y": {"range": [0, 2**3], "status": "clear"},
},
tfhers.uint8_2_2,
id="x * clear(y)",
),
],
)
def test_tfhers_conversion_one_encrypted_one_native(
function, parameters, dtype: tfhers.TFHERSIntegerType, helpers
):
"""
Test different operations wrapped by tfhers conversion (1 tfhers, 1 native input).
"""

parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
configuration = helpers.configuration()

compiler = fhe.Compiler(
lambda x, y: one_tfhers_one_native(x, y, function, dtype),
parameter_encryption_statuses,
)

inputset = [
(tfhers.TFHERSInteger(dtype, inpt[0]), inpt[1])
for inpt in helpers.generate_inputset(parameters)
]
circuit = compiler.compile(inputset, configuration)

sample = helpers.generate_sample(parameters)
encoded_sample = (dtype.encode(sample[0]), sample[1])
encoded_result = circuit.encrypt_run_decrypt(*encoded_sample)

assert (dtype.decode(encoded_result) == function(*sample)).all()

0 comments on commit 6f35a8b

Please sign in to comment.