From e5d073a0ebc46e4b28d5628e10592b392f8bff96 Mon Sep 17 00:00:00 2001 From: youben11 Date: Tue, 12 Nov 2024 14:00:23 +0100 Subject: [PATCH] fix(frontend): support higher bitwidth computation when using TFHE-rs --- .../concrete/fhe/mlir/converter.py | 17 ++++++++++++++--- .../tests/execution/test_tfhers.py | 10 ++++++++++ 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/frontends/concrete-python/concrete/fhe/mlir/converter.py b/frontends/concrete-python/concrete/fhe/mlir/converter.py index f5cf6211bb..6c244f670b 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/converter.py +++ b/frontends/concrete-python/concrete/fhe/mlir/converter.py @@ -985,9 +985,20 @@ def tfhers_to_native(self, ctx: Context, node: Node, preds: List[Conversion]) -> ] * (2 ** (carry_width + msg_width) - 2 ** (msg_width - 1)) padding_bits_inc = ctx.tlu(result_type, msbs, padding_bit_table) # set padding bits (where necessary) in the final result - return ctx.add(result_type, sum_result, padding_bits_inc) - - return sum_result + result = ctx.add(result_type, sum_result, padding_bits_inc) + else: + result = sum_result + + # even if TFHE-rs value are using non-variable bit-width, we want the output + # to be pluggable into the rest of the computation. For example, two 8bits TFHE-rs integers + # could be used in a 9bits addition. If we don't cast, it won't pass the bitwidth + # compatibility check. + output_bit_width = ctx.typeof(node).bit_width + casted_result_type = ctx.tensor( + ctx.esint(output_bit_width) if dtype.is_signed else ctx.eint(output_bit_width), + result_shape, + ) + return ctx.cast(casted_result_type, result) def tfhers_from_native(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion: assert len(preds) == 1 diff --git a/frontends/concrete-python/tests/execution/test_tfhers.py b/frontends/concrete-python/tests/execution/test_tfhers.py index 1ac6487ffd..fddc7713c3 100644 --- a/frontends/concrete-python/tests/execution/test_tfhers.py +++ b/frontends/concrete-python/tests/execution/test_tfhers.py @@ -327,6 +327,16 @@ def lut_add_lut(x, y): TFHERS_UINT_8_3_2_4096, id="x + y", ), + # make sure Concrete ciphertexts can use more than 8 bits + pytest.param( + lambda x, y: (x + y) % 213, + { + "x": {"range": [128, 255], "status": "encrypted"}, + "y": {"range": [128, 255], "status": "encrypted"}, + }, + TFHERS_UINT_8_3_2_4096, + id="mod(x + y)", + ), pytest.param( lambda x, y: x + y, {