Skip to content

Commit

Permalink
fix(frontend-python): change input bit-width only when tlu is optimiz…
Browse files Browse the repository at this point in the history
…ed based on original bit-width during table generation
  • Loading branch information
umut-sahin committed May 27, 2024
1 parent 08a1871 commit 5c704dc
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions frontends/concrete-python/concrete/fhe/mlir/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,8 @@ def construct_table(node: Node, preds: List[Node], configuration: Configuration)
assert_that(variable_input_index != -1)
variable_input = preds[variable_input_index]

variable_input_dtype = node.inputs[variable_input_index].dtype
variable_input_shape = node.inputs[variable_input_index].shape
variable_input_dtype = variable_input.output.dtype
variable_input_shape = variable_input.output.shape

assert_that(isinstance(variable_input_dtype, Integer))
variable_input_dtype = deepcopy(cast(Integer, variable_input_dtype))
Expand Down Expand Up @@ -209,7 +209,13 @@ def construct_table(node: Node, preds: List[Node], configuration: Configuration)

else:
original_bit_width = variable_input.properties["original_bit_width"]
variable_input_dtype.bit_width = original_bit_width
optimize = (
configuration.optimize_tlu_based_on_original_bit_width
if isinstance(configuration.optimize_tlu_based_on_original_bit_width, bool)
else original_bit_width <= configuration.optimize_tlu_based_on_original_bit_width
)
if optimize:
variable_input_dtype.bit_width = original_bit_width

if configuration.optimize_tlu_based_on_measured_bounds:
bounds = variable_input.bounds
Expand Down

0 comments on commit 5c704dc

Please sign in to comment.