refactor(frontend-python): reduce memory usage for table construction of non-multi table lookups
umut-sahin committed Sep 27, 2023
1 parent f988ecc commit 2b45fce
Showing 2 changed files with 24 additions and 4 deletions.
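
The idea behind the change: when a lookup is applied uniformly to every cell of its input (a non-multi table lookup), each evaluated table entry is constant across cells, so there is no need to keep a full output-shaped array per possible input value. A minimal sketch of the footprint difference; the bit-width, shape, and scalar function below are purely illustrative and are not taken from the commit:

    import numpy as np

    bit_width, shape = 7, (100, 100)   # hypothetical input bit-width and tensor shape

    # Before: one output-shaped array per possible input value, even when constant.
    old_table = [np.full(shape, (v**2) % 128, dtype=np.int64) for v in range(2**bit_width)]

    # After: a constant evaluation is kept as a plain Python int.
    new_table = [(v**2) % 128 for v in range(2**bit_width)]

    print(sum(a.nbytes for a in old_table))  # 10_240_000 bytes of duplicated cells
    print(len(new_table))                    # 128 scalars
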
frontends/concrete-python/concrete/fhe/mlir/converter.py: 3 additions & 0 deletions
@@ -541,6 +541,7 @@ def tlu(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
 
         if len(tables) == 1:
             table = tables[0][0]
+            assert tables[0][1] is None
 
             # The reduction on 63b is to avoid problems like doing a TLU of
             # the form T[j] = 2<<j, for j which is supposed to be 7b as per
@@ -563,6 +564,8 @@ def tlu(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
         for i, (table, indices) in enumerate(tables):
             assert len(table) == individual_table_size
             lut_values[i, :] = table
+
+            assert indices is not None
             for index in indices:
                 map_values[index] = i
 
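The two asserts added above document the contract the converter now assumes from construct_deduplicated_tables: a single (table, indices) pair whose indices are None means one shared table for every cell, while multi-table lookups always carry the cell indices they apply to. A small sketch of that dispatch; the describe helper and the example tables are hypothetical, not part of the codebase:

    from typing import List, Optional, Tuple

    import numpy as np

    def describe(tables: Tuple[Tuple[np.ndarray, Optional[List[Tuple[int, ...]]]], ...]) -> str:
        # Mirrors the converter's two cases on the deduplicated tables.
        if len(tables) == 1 and tables[0][1] is None:
            return f"single shared table with {len(tables[0][0])} entries"
        for _table, indices in tables:
            assert indices is not None  # every multi-table entry maps concrete cells
        return f"{len(tables)} tables assigned to cells"

    print(describe(((np.array([0, 1, 4, 9]), None),)))
    print(describe(((np.array([0, 1, 4, 9]), [(0, 0)]), (np.array([9, 4, 1, 0]), [(0, 1)]))))
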
frontends/concrete-python/concrete/fhe/mlir/utils.py: 21 additions & 4 deletions
@@ -132,11 +132,14 @@ def construct_table(node: Node, preds: List[Node]) -> List[Any]:
     np.seterr(divide="ignore")
 
     inputs: List[Any] = [pred() if pred.operation == Operation.Constant else None for pred in preds]
-    table: List[Optional[Union[np.bool_, np.integer, np.floating, np.ndarray]]] = []
+    table: List[Optional[Union[int, np.bool_, np.integer, np.floating, np.ndarray]]] = []
     for value in values:
         try:
             inputs[variable_input_index] = np.ones(variable_input_shape, dtype=np.int64) * value
-            table.append(node(*inputs))
+            evaluation = node(*inputs)
+            table.append(
+                evaluation if evaluation.min() != evaluation.max() else int(evaluation.min())
+            )
         except Exception:  # pylint: disable=broad-except
             # here we try our best to fill the table
             # if it fails, we append None and let flooding algoritm replace None values below
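The collapse above relies on the evaluation being constant over the whole tensor: if evaluation.min() equals evaluation.max(), a single int carries the same information as the full array. A standalone sketch of both outcomes of that check, with a made-up stand-in for node(*inputs) and illustrative shapes:

    import numpy as np

    shape = (2, 3)                          # hypothetical variable_input_shape
    offsets = np.arange(6).reshape(shape)   # hypothetical constant operand that varies per cell

    def evaluate(value, per_cell):
        # Stand-in for node(*inputs).
        base = np.ones(shape, dtype=np.int64) * value
        return base + offsets if per_cell else base * 2

    for per_cell in (False, True):
        evaluation = evaluate(3, per_cell)
        entry = evaluation if evaluation.min() != evaluation.max() else int(evaluation.min())
        print(type(entry).__name__)   # "int" when constant, "ndarray" when cells differ
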
@@ -152,7 +155,7 @@
 def construct_deduplicated_tables(
     node: Node,
     preds: List[Node],
-) -> Tuple[Tuple[np.ndarray, List[Tuple[int, ...]]], ...]:
+) -> Tuple[Tuple[np.ndarray, Optional[List[Tuple[int, ...]]]], ...]:
     """
     Construct lookup tables for each cell of the input for an Operation.Generic node.
@@ -187,8 +190,22 @@ def construct_deduplicated_tables(
         [ [5, 8, 6, 7][input[2, 0]] , [3, 1, 2, 4][input[2, 1]] ]
     """
 
+    raw_table = construct_table(node, preds)
+    if all(isinstance(value, int) for value in raw_table):
+        return ((np.array(raw_table), None),)
+
     node_complete_table = np.concatenate(
-        tuple(np.expand_dims(array, -1) for array in construct_table(node, preds)),
+        tuple(
+            np.expand_dims(
+                (
+                    array
+                    if isinstance(array, np.ndarray)
+                    else np.broadcast_to(array, node.output.shape)
+                ),
+                -1,
+            )
+            for array in raw_table
+        ),
         axis=-1,
     )
 
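With construct_table now returning plain ints for constant evaluations, construct_deduplicated_tables gains a fast path: when every entry is an int, the node is a non-multi lookup and a single (table, None) pair is returned without ever materialising node_complete_table; otherwise the collapsed scalars are broadcast back to the output shape (np.broadcast_to yields a read-only view, not a copy) before concatenation. A self-contained sketch of both paths on made-up raw tables; the output shape and values are illustrative:

    import numpy as np

    output_shape = (2,)                                       # assumed node.output.shape
    constant_raw = [0, 1, 4, 9]                               # every evaluation was constant
    mixed_raw = [np.array([0, 7]), 1, np.array([4, 2]), 9]    # some evaluations vary per cell

    def deduplicate_sketch(raw_table):
        # Fast path: one shared table, signalled by None in place of cell indices.
        if all(isinstance(value, int) for value in raw_table):
            return ((np.array(raw_table), None),)
        # Otherwise, broadcast collapsed scalars back to the output shape (a view,
        # not a copy) and stack the per-cell complete table.
        node_complete_table = np.concatenate(
            tuple(
                np.expand_dims(
                    array if isinstance(array, np.ndarray) else np.broadcast_to(array, output_shape),
                    -1,
                )
                for array in raw_table
            ),
            axis=-1,
        )
        return node_complete_table.shape

    print(deduplicate_sketch(constant_raw))  # ((array([0, 1, 4, 9]), None),)
    print(deduplicate_sketch(mixed_raw))     # (2, 4): a 4-entry table per cell
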
