Skip to content

Commit

Permalink
docs(frontend): add FHE workarounds
Browse files Browse the repository at this point in the history
  • Loading branch information
aquint-zama committed Sep 29, 2023
1 parent 101a6da commit 25478ec
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/SUMMARY.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
* [Simulation](tutorial/simulation.md)
* [Direct Circuits](tutorial/direct\_circuits.md)
* [Statistics](tutorial/statistics.md)
* [Common Workarounds](tutorial/workarounds.md)

## Application Tutorials

Expand Down
142 changes: 142 additions & 0 deletions docs/tutorial/workarounds.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# Common tricks used in (T)FHE

As explained in the [Basics of FHE](../getting-started/fhe_basics.md), the challenge for developers
is to adapt their code to fit FHE constraints. In this document we have collected some common examples
to illustrate the kind of optimization one can do to get better performance.

{% hint style="info" %}
All code snippets provided here are for educational purpose. In future version of Concrete, some
functions described here could be directly available in a more generic and efficient form.
These code snippets are coming from support answers in our [community forum](https://community.zama.ai)
{% endhint %}


## Minimum for Two values

In this first example, we compute a minimum by creating a difference between the two numbers $$y$$ and $$x$$
and conditionally remove this diff from $$y$$ to either get $$x$$ if $$y>x$$ or $$y$$ if $$x>y$$:

```python
def min_two(x, y):
diff = y - x
min_x_y = y - np.maximum(y - x, 0)
return min_x_y
```

## Maximum for Two values

The companion example of above with the maximum value of two integers instead of the minimum:

```python
def max_two(x, y):
diff = y - x
max_x_y = y - np.minimum(y - x, 0)
return max_x_y
```

## Minimum for several values

And an extension for more than two values:

```python
def fhe_min(*args):
remaining = list(args)
while len(remaining) > 1:
a = remaining.pop()
b = remaining.pop()
min_a_b = b - np.maximum(b - a, 0)
remaining.insert(0, min_a_b)
return remaining[0]
```

## Retrieving a value with an encrypted index

This example show how to deal with an array and an encrypted index. It will create a "selection" array filled with `0` except for the requested index that will be `1`, and sum the products of all array values by this selection array:

```python

def indexed_value(array, index):
all_indices = np.arange(array.size)
index_selection = index == all_indices
selection_and_zeros = array * index_selection
selection = np.sum(selection_and_zeros)
return selection
```

## Filter an array with comparison (>)

This example will filter an encrypted array with an encrypted condition, here a `greater than` an encrypted value.
It will pack all values with a selection bit, resulting from the comparison and that will allow the unpacking of only the filtered values:

```python

def filtering(numbers, threshold):
is_greater = numbers > threshold

shifted_numbers = numbers * 2 # open space for a single bit at the end
combined_numbers_and_is_greater = shifted_numbers + is_greater # put is_greater to that bit

def extract(combination):
is_greater = (combination % 2) == 1 # extract is_greater back from packing
if_true = combination // 2 # if is greater is true, we unpack the number and use it
if_false = 0 # otherwise we set the element to zero
return np.where(is_greater, if_true, if_false) # and apply the operation

return fhe.univariate(extract)(combined_numbers_and_is_greater)

```

## Matrix Row/Col means

In this example of Matrix operation, we are introducing a key concept when using `Concrete`:
trying to maximize the parallelization. Here instead of sequentially sum all values to create a
mean value, we will split the values in sub-groups, and do the mean of the sub-groups means:

```python
def smallest_prime_divisor(n):
if n % 2 == 0:
return 2

for i in range(3, int(np.sqrt(n)) + 1):
if n % i == 0:
return i

return n

def mean_of_vector(x):
assert x.size != 0
if x.size == 1:
return x[0]

group_size = smallest_prime_divisor(x.size)
if x.size == group_size:
return np.round(np.sum(x) / x.size).astype(np.int64)

groups = []
for i in range(x.size // group_size):
start = i * group_size
end = start + group_size
groups.append(x[start:end])

mean_of_groups = []
for group in groups:
mean_of_groups.append(np.round(np.sum(group) / group_size).astype(np.int64))

return mean_of_vector(cnp.array(mean_of_groups))

def mean_of_matrix(x):
return mean_of_vector(x.flatten())

def mean_of_rows_of_matrix(x):
means = []
for i in range(x.shape[0]):
means.append(mean_of_vector(x[i]))
return cnp.array(means)

def mean_of_columns_of_matrix(x):
means = []
for i in range(x.shape[1]):
means.append(mean_of_vector(x[:, i]))
return cnp.array(means)

```

0 comments on commit 25478ec

Please sign in to comment.