Skip to content

Commit

Permalink
fix(frontend-python): change input bit-width only when tlu is optimiz…
Browse files Browse the repository at this point in the history
…ed based on original bit-width during table generation
  • Loading branch information
umut-sahin authored and BourgerieQuentin committed May 27, 2024
1 parent aa6e5d9 commit 3f1dc33
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
12 changes: 9 additions & 3 deletions frontends/concrete-python/concrete/fhe/mlir/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,8 @@ def construct_table(node: Node, preds: List[Node], configuration: Configuration)
assert_that(variable_input_index != -1)
variable_input = preds[variable_input_index]

variable_input_dtype = node.inputs[variable_input_index].dtype
variable_input_shape = node.inputs[variable_input_index].shape
variable_input_dtype = variable_input.output.dtype
variable_input_shape = variable_input.output.shape

assert_that(isinstance(variable_input_dtype, Integer))
variable_input_dtype = deepcopy(cast(Integer, variable_input_dtype))
Expand Down Expand Up @@ -209,7 +209,13 @@ def construct_table(node: Node, preds: List[Node], configuration: Configuration)

else:
original_bit_width = variable_input.properties["original_bit_width"]
variable_input_dtype.bit_width = original_bit_width
optimize = (
configuration.optimize_tlu_based_on_original_bit_width
if isinstance(configuration.optimize_tlu_based_on_original_bit_width, bool)
else original_bit_width <= configuration.optimize_tlu_based_on_original_bit_width
)
if optimize:
variable_input_dtype.bit_width = original_bit_width

if configuration.optimize_tlu_based_on_measured_bounds:
bounds = variable_input.bounds
Expand Down
6 changes: 3 additions & 3 deletions frontends/concrete-python/tests/mlir/test_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1573,7 +1573,7 @@ def test_converter_bad_convert(
module {
func.func @main(%arg0: !FHE.eint<6>) -> !FHE.eint<6> {
%c3_i3 = arith.constant 3 : i3
%cst = arith.constant dense<[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]> : tensor<64xi64>
%cst = arith.constant dense<[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 10, 11, 11, 11, 12, 12, 12, 13, 13, 13, 14, 14, 14, 15, 15, 15, 16, 16, 16, 17, 17, 17, 18, 18, 18, 19, 19, 19, 20, 20, 20, 21]> : tensor<64xi64>
%0 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<6>, tensor<64xi64>) -> !FHE.eint<6>
%1 = "FHE.add_eint"(%arg0, %0) : (!FHE.eint<6>, !FHE.eint<6>) -> !FHE.eint<6>
return %1 : !FHE.eint<6>
Expand Down Expand Up @@ -1645,7 +1645,7 @@ def test_converter_bad_convert(
module {
func.func @main(%arg0: !FHE.eint<6>) -> !FHE.eint<6> {
%c3_i3 = arith.constant 3 : i3
%cst = arith.constant dense<[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]> : tensor<64xi64>
%cst = arith.constant dense<[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 10, 11, 11, 11, 12, 12, 12, 13, 13, 13, 14, 14, 14, 15, 15, 15, 16, 16, 16, 17, 17, 17, 18, 18, 18, 19, 19, 19, 20, 20, 20, 21]> : tensor<64xi64>
%0 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<6>, tensor<64xi64>) -> !FHE.eint<6>
%1 = "FHE.add_eint"(%arg0, %0) : (!FHE.eint<6>, !FHE.eint<6>) -> !FHE.eint<6>
return %1 : !FHE.eint<6>
Expand All @@ -1667,7 +1667,7 @@ def test_converter_bad_convert(
module {
func.func @main(%arg0: !FHE.eint<6>) -> !FHE.eint<6> {
%c3_i3 = arith.constant 3 : i3
%cst = arith.constant dense<[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]> : tensor<64xi64>
%cst = arith.constant dense<[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 10, 11, 11, 11, 12, 12, 12, 13, 13, 13, 14, 14, 14, 15, 15, 15, 16, 16, 16, 17, 17, 17, 18, 18, 18, 19, 19, 19, 20, 20, 20, 21]> : tensor<64xi64>
%0 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<6>, tensor<64xi64>) -> !FHE.eint<6>
%1 = "FHE.add_eint"(%arg0, %0) : (!FHE.eint<6>, !FHE.eint<6>) -> !FHE.eint<6>
return %1 : !FHE.eint<6>
Expand Down

0 comments on commit 3f1dc33

Please sign in to comment.