Skip to content

Commit

Permalink
feat(frontend-python): support dynamic table lookups
Browse files Browse the repository at this point in the history
  • Loading branch information
umut-sahin committed Sep 18, 2023
1 parent 5d3e4bb commit 6b743c1
Show file tree
Hide file tree
Showing 7 changed files with 169 additions and 3 deletions.
31 changes: 31 additions & 0 deletions frontends/concrete-python/concrete/fhe/mlir/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -1558,6 +1558,37 @@ def dot(self, resulting_type: ConversionType, x: Conversion, y: Conversion) -> C

return self.operation(operation, resulting_type, x.result, y.result)

def dynamic_tlu(
self,
resulting_type: ConversionType,
on: Conversion,
table: Conversion,
) -> Conversion:
assert table.is_clear and on.is_encrypted

if table.shape != (2**on.bit_width,):
highlights: Dict[Node, Union[str, List[str]]] = {
table.origin: f"table has the shape {table.shape}",
on.origin: f"table lookup input is {on.bit_width}-bits",
self.converting: [
"so table cannot be looked up with this input",
f"table shape should have been {(2**on.bit_width,)}",
],
}
if on.bit_width != on.original_bit_width: # pragma: no cover
highlights[on.origin].append( # type: ignore
"("
f"note that it's assigned {on.bit_width}-bits "
f"during compilation because of its relation with other operations"
")"
)
self.error(highlights)

dialect = fhe if on.is_scalar else fhelinalg
operation = dialect.ApplyLookupTableEintOp

return self.operation(operation, resulting_type, on.result, table.result)

def encrypt(self, resulting_type: ConversionType, x: Conversion) -> Conversion:
assert self.is_bit_width_compatible(resulting_type, x)
assert resulting_type.is_encrypted and x.is_clear
Expand Down
4 changes: 4 additions & 0 deletions frontends/concrete-python/concrete/fhe/mlir/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,10 @@ def dot(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
assert len(preds) == 2
return ctx.dot(ctx.typeof(node), preds[0], preds[1])

def dynamic_tlu(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
assert len(preds) == 2
return ctx.dynamic_tlu(ctx.typeof(node), preds[0], preds[1])

def equal(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
assert len(preds) == 2

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ def converted_to_table_lookup(self) -> bool:
"conv2d",
"conv3d",
"dot",
"dynamic_tlu",
"expand_dims",
"index_static",
"matmul",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def format_constant(constant: Any, maximum_length: int = 45, keep_newlines: bool
return result


def format_indexing_element(indexing_element: Union[int, np.integer, slice]):
def format_indexing_element(indexing_element: Union[int, np.integer, slice, Any]):
"""
Format an indexing element.
Expand Down
15 changes: 13 additions & 2 deletions frontends/concrete-python/concrete/fhe/tracing/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,8 +738,19 @@ def transpose(self, axes: Optional[Tuple[int, ...]] = None) -> "Tracer":

def __getitem__(
self,
index: Union[int, np.integer, slice, Tuple[Union[int, np.integer, slice], ...]],
index: Union[
int, np.integer, slice, "Tracer", Tuple[Union[int, np.integer, slice, "Tracer"], ...]
],
) -> "Tracer":
if isinstance(index, Tracer) and index.output.is_encrypted and self.output.is_clear:
computation = Node.generic(
"dynamic_tlu",
[deepcopy(index.output), deepcopy(self.output)],
deepcopy(index.output),
lambda on, table: table[on],
)
return Tracer(computation, [index, self])

if not isinstance(index, tuple):
index = (index,)

Expand Down Expand Up @@ -770,7 +781,7 @@ def __getitem__(
raise ValueError(message)

output_value = deepcopy(self.output)
output_value.shape = np.zeros(output_value.shape)[index].shape
output_value.shape = np.zeros(output_value.shape)[index].shape # type: ignore

computation = Node.generic(
"index_static",
Expand Down
94 changes: 94 additions & 0 deletions frontends/concrete-python/tests/execution/test_dynamic_tlu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""
Tests of execution of dynamic tlu operation.
"""

import random

import numpy as np
import pytest

from concrete import fhe
from concrete.fhe.dtypes import Integer

cases = []
for input_bit_width in range(1, 3):
for input_is_signed in [False, True]:
for output_bit_width in range(1, 3):
for output_is_signed in [False, True]:
input_shape = random.choice([(), (2,), (3, 2)])
cases.append(
pytest.param(
input_bit_width,
input_is_signed,
input_shape,
output_bit_width,
output_is_signed,
id=(
f"{'' if input_is_signed else 'u'}int{input_bit_width}"
f" -> "
f"{'' if output_is_signed else 'u'}int{output_bit_width}"
f" {{ input_shape={input_shape} }}"
),
)
)

# pylint: disable=redefined-outer-name


@pytest.mark.parametrize(
"input_bit_width,input_is_signed,input_shape,output_bit_width,output_is_signed",
cases,
)
def test_dynamic_tlu(
input_bit_width,
input_is_signed,
input_shape,
output_bit_width,
output_is_signed,
helpers,
):
"""
Test dynamic tlu.
"""

input_dtype = Integer(is_signed=input_is_signed, bit_width=input_bit_width)
output_dtype = Integer(is_signed=output_is_signed, bit_width=output_bit_width)

def function(x, y):
return y[x]

compiler = fhe.Compiler(function, {"x": "encrypted", "y": "clear"})
inputset = [
(
np.random.randint(
input_dtype.min(),
input_dtype.max() + 1,
size=input_shape,
),
np.random.randint(
output_dtype.min(),
output_dtype.max() + 1,
size=(2**input_bit_width,),
),
)
for _ in range(100)
]
circuit = compiler.compile(inputset, helpers.configuration())

samples = [
[
np.random.randint(
input_dtype.min(),
input_dtype.max() + 1,
size=input_shape,
),
np.random.randint(
output_dtype.min(),
output_dtype.max() + 1,
size=(2**input_bit_width,),
),
]
for _ in range(5)
]
for sample in samples:
helpers.check_execution(circuit, function, sample, retries=3)
25 changes: 25 additions & 0 deletions frontends/concrete-python/tests/mlir/test_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,31 @@ def assign(x, y):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but only up to 16-bit encrypted matrix multiplications are supported
return %2
""", # noqa: E501
),
pytest.param(
lambda x, y: y[x],
{"x": "encrypted", "y": "clear"},
[
(
1,
[1, 2, 3, 4],
)
],
RuntimeError,
"""
Function you are trying to compile cannot be compiled
%0 = x # EncryptedScalar<uint1> ∈ [1, 1]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ table lookup input is 1-bits
%1 = y # ClearTensor<uint3, shape=(4,)> ∈ [1, 4]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ table has the shape (4,)
%2 = dynamic_tlu(%0, %1) # EncryptedScalar<uint2> ∈ [2, 2]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ so table cannot be looked up with this input
table shape should have been (2,)
return %2
""", # noqa: E501
),
],
Expand Down

0 comments on commit 6b743c1

Please sign in to comment.