Skip to content

Commit

Permalink
feat(frontend-python): dynamic assignment
Browse files Browse the repository at this point in the history
  • Loading branch information
umut-sahin committed Jun 5, 2024
1 parent 226ee27 commit 1700cc6
Show file tree
Hide file tree
Showing 11 changed files with 925 additions and 184 deletions.
1 change: 1 addition & 0 deletions .github/workflows/concrete_python_checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ jobs:
- uses: actions/checkout@v3
- name: Pre-Commit Checks
run: |
sudo apt install -y graphviz libgraphviz-dev
cd frontends/concrete-python
make venv
source .venv/bin/activate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -992,6 +992,7 @@ class Configuration:
optimize_tlu_based_on_original_bit_width: Union[bool, int]
detect_overflow_in_simulation: bool
dynamic_indexing_check_out_of_bounds: bool
dynamic_assignment_check_out_of_bounds: bool

def __init__(
self,
Expand Down Expand Up @@ -1059,6 +1060,7 @@ def __init__(
optimize_tlu_based_on_original_bit_width: Union[bool, int] = 8,
detect_overflow_in_simulation: bool = False,
dynamic_indexing_check_out_of_bounds: bool = True,
dynamic_assignment_check_out_of_bounds: bool = True,
):
self.verbose = verbose
self.compiler_debug_mode = compiler_debug_mode
Expand Down Expand Up @@ -1162,6 +1164,7 @@ def __init__(
self.detect_overflow_in_simulation = detect_overflow_in_simulation

self.dynamic_indexing_check_out_of_bounds = dynamic_indexing_check_out_of_bounds
self.dynamic_assignment_check_out_of_bounds = dynamic_assignment_check_out_of_bounds

self._validate()

Expand Down Expand Up @@ -1236,6 +1239,7 @@ def fork(
optimize_tlu_based_on_original_bit_width: Union[Keep, bool, int] = KEEP,
detect_overflow_in_simulation: Union[Keep, bool] = KEEP,
dynamic_indexing_check_out_of_bounds: Union[Keep, bool] = KEEP,
dynamic_assignment_check_out_of_bounds: Union[Keep, bool] = KEEP,
) -> "Configuration":
"""
Get a new configuration from another one specified changes.
Expand Down
149 changes: 8 additions & 141 deletions frontends/concrete-python/concrete/fhe/mlir/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from mlir.ir import BoolAttr as MlirBoolAttr
from mlir.ir import Context as MlirContext
from mlir.ir import DenseElementsAttr as MlirDenseElementsAttr
from mlir.ir import DenseI64ArrayAttr as MlirDenseI64ArrayAttr
from mlir.ir import IndexType
from mlir.ir import InsertionPoint as MlirInsertionPoint
from mlir.ir import IntegerAttr as MlirIntegerAttr
Expand Down Expand Up @@ -1913,154 +1912,21 @@ def array(self, resulting_type: ConversionType, elements: List[Conversion]) -> C
original_bit_width=original_bit_width,
)

def assign_static(
def assign(
self,
resulting_type: ConversionType,
x: Conversion,
y: Conversion,
index: Sequence[Union[int, np.integer, slice]],
index: Sequence[Union[int, np.integer, slice, np.ndarray, list, Conversion]],
):
if x.is_clear and y.is_encrypted:
highlights = {
x.origin: "tensor is clear",
y.origin: "assigned value is encrypted",
self.converting: "but encrypted values cannot be assigned to clear tensors",
}
self.error(highlights)

assert self.is_bit_width_compatible(resulting_type, x, y)

if any(isinstance(indexing_element, (list, np.ndarray)) for indexing_element in index):
return self.assign_static_fancy(resulting_type, x, y, index)

index = list(index)
while len(index) < len(x.shape):
index.append(slice(None, None, None))

offsets = []
sizes = []
strides = []

for indexing_element, dimension_size in zip(index, x.shape):
if isinstance(indexing_element, slice):
size = int(np.zeros(dimension_size)[indexing_element].shape[0])
stride = int(indexing_element.step if indexing_element.step is not None else 1)
offset = int(
(
indexing_element.start
if indexing_element.start >= 0
else indexing_element.start + dimension_size
)
if indexing_element.start is not None
else (0 if stride > 0 else dimension_size - 1)
)

else:
size = 1
stride = 1
offset = int(
indexing_element if indexing_element >= 0 else indexing_element + dimension_size
)

offsets.append(offset)
sizes.append(size)
strides.append(stride)

if x.is_encrypted and y.is_clear:
encrypted_type = self.typeof(
ValueDescription(
dtype=Integer(is_signed=x.is_signed, bit_width=x.bit_width),
shape=y.shape,
is_encrypted=True,
)
)
y = self.encrypt(encrypted_type, y)

required_y_shape_list = []
for i, indexing_element in enumerate(index):
if isinstance(indexing_element, slice):
n = len(np.zeros(x.shape[i])[indexing_element])
required_y_shape_list.append(n)
else:
required_y_shape_list.append(1)

required_y_shape = tuple(required_y_shape_list)
try:
np.reshape(np.zeros(y.shape), required_y_shape)
y = self.reshape(y, required_y_shape)
except Exception: # pylint: disable=broad-except
np.broadcast_to(np.zeros(y.shape), required_y_shape)
y = self.broadcast_to(y, required_y_shape)

x = self.to_signedness(x, of=resulting_type)
y = self.to_signedness(y, of=resulting_type)

return self.operation(
tensor.InsertSliceOp,
resulting_type,
y.result,
x.result,
(),
(),
(),
MlirDenseI64ArrayAttr.get(offsets),
MlirDenseI64ArrayAttr.get(sizes),
MlirDenseI64ArrayAttr.get(strides),
original_bit_width=x.original_bit_width,
)
# This import needs to happen here to avoid circular imports.
# pylint: disable=import-outside-toplevel

def assign_static_fancy(
self,
resulting_type: ConversionType,
x: Conversion,
y: Conversion,
index: Sequence[Union[int, np.integer, slice, np.ndarray, list]],
) -> Conversion:
resulting_element_type = (self.eint if resulting_type.is_unsigned else self.esint)(
resulting_type.bit_width
)
from .operations.assignment import assignment

indices = []
indices_shape = ()
for indexing_element in index:
if isinstance(indexing_element, (int, np.integer, list, np.ndarray)):
indexing_element_array = np.array(indexing_element)
indices_shape = np.broadcast_shapes(
indices_shape,
indexing_element_array.shape,
) # type:ignore
indices.append(indexing_element_array)

else: # pragma: no cover
message = f"invalid indexing element of type {type(indexing_element)}"
raise AssertionError(message)
values_shape = indices_shape

indices = [np.broadcast_to(index, shape=indices_shape) for index in indices]

concrete_indices = []
for i in np.ndindex(*indices_shape):
concrete_index = [index[i] for index in indices]
concrete_indices.append(concrete_index)

if len(x.shape) > 1:
indices_shape = (*indices_shape, len(x.shape)) # type: ignore

if x.is_clear and y.is_encrypted:
raise NotImplementedError
if x.is_encrypted and y.is_clear:
y = self.encrypt(self.tensor(resulting_element_type, shape=y.shape), y)
# pylint: enable=import-outside-toplevel

return self.operation(
fhelinalg.FancyAssignOp,
resulting_type,
x.result,
self.constant(
self.tensor(self.index_type(), indices_shape),
np.array(concrete_indices).reshape(indices_shape),
).result,
self.broadcast_to(y, shape=values_shape).result,
)
return assignment(self, resulting_type, x, y, index)

def bitwise(
self,
Expand Down Expand Up @@ -2713,6 +2579,7 @@ def index(
) -> Conversion:
# This import needs to happen here to avoid circular imports.
# pylint: disable=import-outside-toplevel

from .operations.indexing import indexing

# pylint: enable=import-outside-toplevel
Expand Down
25 changes: 23 additions & 2 deletions frontends/concrete-python/concrete/fhe/mlir/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,13 +290,34 @@ def array(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion
assert len(preds) > 0
return ctx.array(ctx.typeof(node), elements=preds)

def assign_dynamic(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
assert len(preds) >= 3

x = preds[0]
y = preds[-1]

dynamic_indices = preds[1:-1]
static_indices = node.properties["kwargs"]["static_indices"]

indices = []

cursor = 0
for index in static_indices:
if index is None:
indices.append(dynamic_indices[cursor])
cursor += 1
else:
indices.append(index)

return ctx.assign(ctx.typeof(node), x, y, indices)

def assign_static(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
assert len(preds) == 2
return ctx.assign_static(
return ctx.assign(
ctx.typeof(node),
preds[0],
preds[1],
index=node.properties["kwargs"]["index"],
node.properties["kwargs"]["index"],
)

def bitwise_and(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
Expand Down
Loading

0 comments on commit 1700cc6

Please sign in to comment.