From fcd8925bde041227666453e78cc98f14cdb58e84 Mon Sep 17 00:00:00 2001 From: Umut Date: Mon, 27 May 2024 13:11:51 +0300 Subject: [PATCH] fix(frontend-python): change input bit-width only when tlu is optimized based on original bit-width during table generation --- frontends/concrete-python/concrete/fhe/mlir/utils.py | 12 +++++++++--- .../concrete-python/tests/mlir/test_converter.py | 6 +++--- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/frontends/concrete-python/concrete/fhe/mlir/utils.py b/frontends/concrete-python/concrete/fhe/mlir/utils.py index 5e6d0397ec..ddc7ec5371 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/utils.py +++ b/frontends/concrete-python/concrete/fhe/mlir/utils.py @@ -170,8 +170,8 @@ def construct_table(node: Node, preds: List[Node], configuration: Configuration) assert_that(variable_input_index != -1) variable_input = preds[variable_input_index] - variable_input_dtype = node.inputs[variable_input_index].dtype - variable_input_shape = node.inputs[variable_input_index].shape + variable_input_dtype = variable_input.output.dtype + variable_input_shape = variable_input.output.shape assert_that(isinstance(variable_input_dtype, Integer)) variable_input_dtype = deepcopy(cast(Integer, variable_input_dtype)) @@ -209,7 +209,13 @@ def construct_table(node: Node, preds: List[Node], configuration: Configuration) else: original_bit_width = variable_input.properties["original_bit_width"] - variable_input_dtype.bit_width = original_bit_width + optimize = ( + configuration.optimize_tlu_based_on_original_bit_width + if isinstance(configuration.optimize_tlu_based_on_original_bit_width, bool) + else original_bit_width <= configuration.optimize_tlu_based_on_original_bit_width + ) + if optimize: + variable_input_dtype.bit_width = original_bit_width if configuration.optimize_tlu_based_on_measured_bounds: bounds = variable_input.bounds diff --git a/frontends/concrete-python/tests/mlir/test_converter.py b/frontends/concrete-python/tests/mlir/test_converter.py index 2d663db6f9..82b9d34f99 100644 --- a/frontends/concrete-python/tests/mlir/test_converter.py +++ b/frontends/concrete-python/tests/mlir/test_converter.py @@ -1573,7 +1573,7 @@ def test_converter_bad_convert( module { func.func @main(%arg0: !FHE.eint<6>) -> !FHE.eint<6> { %c3_i3 = arith.constant 3 : i3 - %cst = arith.constant dense<[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]> : tensor<64xi64> + %cst = arith.constant dense<[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 10, 11, 11, 11, 12, 12, 12, 13, 13, 13, 14, 14, 14, 15, 15, 15, 16, 16, 16, 17, 17, 17, 18, 18, 18, 19, 19, 19, 20, 20, 20, 21]> : tensor<64xi64> %0 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<6>, tensor<64xi64>) -> !FHE.eint<6> %1 = "FHE.add_eint"(%arg0, %0) : (!FHE.eint<6>, !FHE.eint<6>) -> !FHE.eint<6> return %1 : !FHE.eint<6> @@ -1645,7 +1645,7 @@ def test_converter_bad_convert( module { func.func @main(%arg0: !FHE.eint<6>) -> !FHE.eint<6> { %c3_i3 = arith.constant 3 : i3 - %cst = arith.constant dense<[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]> : tensor<64xi64> + %cst = arith.constant dense<[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 10, 11, 11, 11, 12, 12, 12, 13, 13, 13, 14, 14, 14, 15, 15, 15, 16, 16, 16, 17, 17, 17, 18, 18, 18, 19, 19, 19, 20, 20, 20, 21]> : tensor<64xi64> %0 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<6>, tensor<64xi64>) -> !FHE.eint<6> %1 = "FHE.add_eint"(%arg0, %0) : (!FHE.eint<6>, !FHE.eint<6>) -> !FHE.eint<6> return %1 : !FHE.eint<6> @@ -1667,7 +1667,7 @@ def test_converter_bad_convert( module { func.func @main(%arg0: !FHE.eint<6>) -> !FHE.eint<6> { %c3_i3 = arith.constant 3 : i3 - %cst = arith.constant dense<[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]> : tensor<64xi64> + %cst = arith.constant dense<[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 10, 11, 11, 11, 12, 12, 12, 13, 13, 13, 14, 14, 14, 15, 15, 15, 16, 16, 16, 17, 17, 17, 18, 18, 18, 19, 19, 19, 20, 20, 20, 21]> : tensor<64xi64> %0 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<6>, tensor<64xi64>) -> !FHE.eint<6> %1 = "FHE.add_eint"(%arg0, %0) : (!FHE.eint<6>, !FHE.eint<6>) -> !FHE.eint<6> return %1 : !FHE.eint<6>