Skip to content

Commit

Permalink
docs(frontend): fixing
Browse files Browse the repository at this point in the history
  • Loading branch information
bcm-at-zama committed Sep 11, 2024
1 parent 26136e6 commit 33fee51
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 19 deletions.
22 changes: 7 additions & 15 deletions frontends/concrete-python/examples/pir/pir_utils.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,34 @@
import numpy as np

from concrete import fhe

import numpy as np

def make_one_hot_vector(index: int, size: int) -> np.array:

answer = np.zeros(shape=(size,), dtype=np.int8)
answer[index] = 1
return answer


@fhe.compiler({"one_hot_vector": "encrypted", "database": "clear"})
def get_ith_element_of_database(one_hot_vector: np.array, database: np.array) -> int:
return np.dot(one_hot_vector, database)


def compile_function(database, show_mlir=False, show_graph=False):
database_length = database.shape[0]
inputset_length = 100
inputset = [
(make_one_hot_vector(np.random.randint(database_length), database_length), database)
for _ in range(inputset_length)
]
circuit = get_ith_element_of_database.compile(inputset, show_mlir=show_mlir, show_graph=show_graph)
circuit = get_ith_element_of_database.compile(
inputset, show_mlir=show_mlir, show_graph=show_graph
)
return circuit

def test_encrypted_queries(database, circuit, how_many_tests=1, verbose=True):

times = []
def test_encrypted_queries(database, circuit, how_many_tests=1, verbose=True):

for _ in range(how_many_tests):
database_length = database.shape[0]
Expand All @@ -41,21 +44,10 @@ def test_encrypted_queries(database, circuit, how_many_tests=1, verbose=True):
encrypted_x, _ = circuit.encrypt(x, None)

# Run the FHE computation on the server side
time_begin = time.time()
encrypted_y = circuit.run(encrypted_x, database)
time_end = time.time()
times.append(time_end - time_begin)

if verbose:
print(
f"FHE computation done in {(time_end - time_begin) * 1000:.1f} milliseconds -- database is {database_length} (2**{log_database_length}) elements of {database_output_bits} bits"
)

# Decrypt the result on the client side
y = circuit.decrypt(encrypted_y)

# And check the computations worked fine
assert y == database[random_index]

return times

25 changes: 21 additions & 4 deletions frontends/concrete-python/tests/execution/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
)
from examples.pir import pir_utils


def test_static_kvdb(helpers):
"""
Test static key-value database example.
Expand Down Expand Up @@ -272,6 +273,7 @@ def test_levenshtein_distance_randomly(alphabet_name, max_length, helpers):
actual_distance = levenshtein_distance.calculate(str1, str2, "simulate", show_distance=True)
assert actual_distance == expected_distance


def test_pir_basics():

x = pir_utils.make_one_hot_vector(0, size=5)
Expand All @@ -286,9 +288,25 @@ def test_pir_basics():

assert database_output_bits == 5

assert pir_utils.get_ith_element_of_database(pir_utils.make_one_hot_vector(0, size=database_length), database) == database[0]
assert pir_utils.get_ith_element_of_database(pir_utils.make_one_hot_vector(3, size=database_length), database) == database[3]
assert pir_utils.get_ith_element_of_database(pir_utils.make_one_hot_vector(4, size=database_length), database) == database[4]
assert (
pir_utils.get_ith_element_of_database(
pir_utils.make_one_hot_vector(0, size=database_length), database
)
== database[0]
)
assert (
pir_utils.get_ith_element_of_database(
pir_utils.make_one_hot_vector(3, size=database_length), database
)
== database[3]
)
assert (
pir_utils.get_ith_element_of_database(
pir_utils.make_one_hot_vector(4, size=database_length), database
)
== database[4]
)


@pytest.mark.parametrize("database_input_bits, database_output_bits", [(4, 8), (6, 6), (9, 16)])
def test_pir_full(database_input_bits, database_output_bits):
Expand All @@ -304,4 +322,3 @@ def test_pir_full(database_input_bits, database_output_bits):

# And check it works as expected
pir_utils.test_encrypted_queries(database, circuit)

0 comments on commit 33fee51

Please sign in to comment.