Skip to content

Commit

Permalink
fix(frontend-python): bad signed input tlu padding with extra bitwidth
Browse files Browse the repository at this point in the history
  • Loading branch information
rudy-6-4 committed Jun 21, 2024
1 parent f93611e commit 0715df2
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 1 deletion.
7 changes: 6 additions & 1 deletion frontends/concrete-python/concrete/fhe/mlir/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3776,7 +3776,12 @@ def tlu(self, resulting_type: ConversionType, on: Conversion, table: Sequence[in
result = self.add(resulting_type, result, constant)
return result

table += [0] * ((2**on.bit_width) - len(table))
padding = [0] * ((2**on.bit_width) - len(table))
if padding:
if on.is_unsigned:
table += padding
else:
table = table[:len(table)//2] + padding + table[-len(table)//2:]

dialect = fhe if on.is_scalar else fhelinalg
operation = dialect.ApplyLookupTableEintOp
Expand Down
22 changes: 22 additions & 0 deletions frontends/concrete-python/tests/execution/test_min_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,3 +233,25 @@ def test_minimum_maximum(
]
for sample in samples:
helpers.check_execution(circuit, function, sample, retries=5)


def test_internal_signed_tlu_padding(helpers):
"""Test that the signed input LUT is correctly padded in the case of substraction trick."""

inputset = [(i, j) for i in [0, 1] for j in [0, 1]]

@fhe.compiler({"a": "encrypted", "b": "encrypted"})
def min2(a, b):
min_12 = np.minimum(a, b)
return (min_12, a + 3, b + 3)

c = min2.compile(inputset, helpers.configuration())
min_0_1, _, _ = c.encrypt_run_decrypt(0, 1)

assert min_0_1 == 0

# Some extra checks to verify that the test is relevant (substraction trick).
assert c.mlir.count("to_signed") == 2 # check substraction trick is used
assert c.mlir.count("sub_eint") == 1 # check substraction trick is used
assert c.mlir.count("<[0, 0, -2, -1, 0, 0, 0, 0]>") == 0 # lut wrongly padded at the end
assert c.mlir.count("<[0, 0, 0, 0, 0, 0, -2, -1]>") == 1 # lut correctly padded in the middle

0 comments on commit 0715df2

Please sign in to comment.