diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 22d4daf925..144f7368c4 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -33,6 +33,7 @@ * [Key Value Database](application-tutorial/key\_value\_database.md) * [SHA-256](application-tutorial/sha256.md) +* [Game of Life](application-tutorial/game_of_life.md) ## How To diff --git a/docs/_static/game-of-life.gif b/docs/_static/game-of-life.gif new file mode 100644 index 0000000000..6b3f6d420d Binary files /dev/null and b/docs/_static/game-of-life.gif differ diff --git a/docs/application-tutorial/game_of_life.md b/docs/application-tutorial/game_of_life.md new file mode 100644 index 0000000000..980952354a --- /dev/null +++ b/docs/application-tutorial/game_of_life.md @@ -0,0 +1,54 @@ +# Game of Life + +In the associated [Python file](https://github.com/zama-ai/concrete/blob/main/frontends/concrete-python/examples/game_of_life/game_of_life.py), you can run the Game of Life, written in Concrete Python. + +![ Game of Life](../_static/game_of_life.gif) + +### Installation + +In addition to Concrete, you must install `pygame` in your virtual environment: + +- `pip3 install pygame` + +Once done, if you go to `frontends/concrete-python/examples/game_of_life`, `python game_of_life.py --help` should give you the manpage: + +``` +Game of Life in Concrete Python. + +options: + -h, --help show this help message and exit + --dimension DIMENSION + Dimension of the grid + --refresh_every REFRESH_EVERY + Refresh the grid every X steps + --method {method_3b,method_4b,method_5b,method_basic} + Method for refreshing the grid + --log2_global_p_error LOG2_GLOBAL_P_ERROR + Probability of correctness issue (full circuit) + --log2_p_error LOG2_P_ERROR + Probability of correctness issue (individual TLU) + --simulate Simulate instead of running computations in FHE + --show_mlir Show the MLIR + --stop_after_compilation + Stop after compilation + --text_output Print a text output of the grid +``` + +### Running + +Then, you can play with the different options, and in particular: + +- `dimension`, to chose the size of the grid; the larger, the slower +- `method`, to chose which implementation is used for the grid update +- `log2_global_p_error` and `log2_p_error`, to chose the probability of correctness issue (see the Concrete documentation for more information) +- `simulate`, to do computations only in simulation, i.e., not in FHE + +### Typical Executions + +In simulation: `python3 game_of_life.py --dimension 100 --refresh_every 50 --simulate` + +In FHE: `python3 game_of_life.py --dimension 6 --refresh_every 8 --log2_p_error -40 --method method_4b` + +### Technical Explanations + +A blog is currently being under writening, and a link will be added it here when it's available. In the meantime, some explanations are given in the code. diff --git a/frontends/concrete-python/examples/game_of_life/game_of_life.py b/frontends/concrete-python/examples/game_of_life/game_of_life.py new file mode 100644 index 0000000000..d150723805 --- /dev/null +++ b/frontends/concrete-python/examples/game_of_life/game_of_life.py @@ -0,0 +1,500 @@ +import sys +import time + +# Hide pygame prompt +from os import environ + +import numpy as np + +environ["PYGAME_HIDE_SUPPORT_PROMPT"] = "1" + +# ruff: noqa:E402 +import argparse + +# ruff: noqa:E402 +import pygame + +# ruff: noqa:E402 +from concrete import fhe + + +# Function to workaround the miss of padding in CP +def by_hand_padding(original_grid, res): + padded_res = fhe.zeros(original_grid.shape) + + original_grid_shape = original_grid.shape + assert padded_res.shape[0:2] == (1, 1) + padded_res[0, 0, 1 : original_grid_shape[2] - 1, 1 : original_grid_shape[3] - 1] = res + + assert original_grid.shape == padded_res.shape + + return padded_res + + +# Function to workaround the miss of padding in CP +def conv_with_hand_padding(grid, weight, do_padded_fix): + convoluted_grid = fhe.conv( + grid, + weight.reshape(1, 1, *weight.shape), + strides=(1, 1), + dilations=(1, 1), + group=1, + pads=(1, 1, 1, 1) if not do_padded_fix else (0, 0, 0, 0), + ) + + if do_padded_fix: + convoluted_grid = by_hand_padding(grid, convoluted_grid) + + return convoluted_grid + + +# Function for Game of Life +@fhe.compiler({"grid": "encrypted"}) +def update_grid_method_3b(grid): + # Method which uses two first TLU of 3 bits and a third TLU of 2 bits + + weights_method_3b_a = np.array( + [ + [1, 1, 1], + [1, 0, 1], + [1, 1, 0], + ] + ) + weights_method_3b_b = np.array( + [ + [0, 0, 0], + [0, 0, 0], + [0, 0, 1], + ] + ) + table_next_cell_3b_a = [i if i <= 3 else 4 for i in range(8)] + table_next_cell_3b_b = [i - 1 if i in [2, 3] else 0 for i in range(6)] + table_next_cell_3b_c = [int(i in [2, 3]) for i in range(4)] + + table_cp_next_cell_3b_a = fhe.LookupTable(table_next_cell_3b_a) + table_cp_next_cell_3b_b = fhe.LookupTable(table_next_cell_3b_b) + table_cp_next_cell_3b_c = fhe.LookupTable(table_next_cell_3b_c) + + # This is to workaround the fact that we have no pad option in fhe.conv + do_padded_fix = True + + # Compute the sum of 7 elements + convoluted_grid = conv_with_hand_padding(grid, weights_method_3b_a, do_padded_fix) + + # Apply a TLU: input in [0, 7], output in [0, 4] + grid_a = table_cp_next_cell_3b_a[convoluted_grid] + + # Add the 8th one: output is in [0, 5] + convoluted_grid = conv_with_hand_padding(grid, weights_method_3b_b, do_padded_fix) + + grid_b = grid_a + convoluted_grid + + # Apply a TLU: input in [0, 5], output in [0, 4] + grid_c = table_cp_next_cell_3b_b[grid_b] + + # Add center + grid = grid_c + grid + + # And a last TLU: input in [0, 5] and output in [0, 1] + grid = table_cp_next_cell_3b_c[grid] + + return grid + + +@fhe.compiler({"grid": "encrypted"}) +def update_grid_method_4b(grid): + # Method which uses a first TLU of 4 bits and a second TLU of 2 bits + + weights_method_4b = np.array( + [ + [1, 1, 1], + [1, 0, 1], + [1, 1, 1], + ] + ) + table_next_cell_4b_a = [i - 1 if i in [2, 3] else 0 for i in range(9)] + table_next_cell_4b_b = [int(i in [2, 3]) for i in range(4)] + + table_cp_next_cell_4b_a = fhe.LookupTable(table_next_cell_4b_a) + table_cp_next_cell_4b_b = fhe.LookupTable(table_next_cell_4b_b) + + # This is to workaround the fact that we have no pad option in fhe.conv + do_padded_fix = True + + convoluted_grid = conv_with_hand_padding(grid, weights_method_4b, do_padded_fix) + + grid_a = table_cp_next_cell_4b_a[convoluted_grid] + grid = grid_a + grid + grid = table_cp_next_cell_4b_b[grid] + + return grid + + +@fhe.compiler({"grid": "encrypted"}) +def update_grid_method_5b(grid): + # Method which uses a single TLU of 5 bits + + weights_method_5b = np.array( + [ + [1, 1, 1], + [1, 9, 1], + [1, 1, 1], + ] + ) + table_next_cell_5b = [int(i in [3, 9 + 2, 9 + 3]) for i in range(18)] + + table_cp_next_cell_5b = fhe.LookupTable(table_next_cell_5b) + + # This is to workaround the fact that we have no pad option in fhe.conv + do_padded_fix = True + + convoluted_grid = conv_with_hand_padding(grid, weights_method_5b, do_padded_fix) + + grid = table_cp_next_cell_5b[convoluted_grid] + + return grid + + +@fhe.compiler({"grid": "encrypted"}) +def update_grid_basic(grid): + # Method which follows the naive approach + + weights_method_basic = np.array( + [ + [1, 1, 1], + [1, 0, 1], + [1, 1, 1], + ] + ) + table_next_cell_basic_a = [int(i in [3]) for i in range(9)] + table_next_cell_basic_b = [int(i in [2, 3]) for i in range(9)] + + table_cp_next_cell_basic_a = fhe.LookupTable(table_next_cell_basic_a) + table_cp_next_cell_basic_b = fhe.LookupTable(table_next_cell_basic_b) + + # This is to workaround the fact that we have no pad option in fhe.conv + do_padded_fix = True + + convoluted_grid = conv_with_hand_padding(grid, weights_method_basic, do_padded_fix) + + grid = table_cp_next_cell_basic_a[convoluted_grid] | ( + table_cp_next_cell_basic_b[convoluted_grid] & (grid == 1) + ) + + return grid + + +# Function for Game of Life +def update_grid(grid, method="method_3b"): + assert grid.ndim == 4 + + if method == "method_basic": + return update_grid_basic(grid) + + if method == "method_3b": + return update_grid_method_3b(grid) + + if method == "method_4b": + return update_grid_method_4b(grid) + + if method == "method_5b": + return update_grid_method_5b(grid) + + msg = "Bad method" + raise ValueError(msg) + + +# Graphic functions +# The graphical functions of this code were inspired by those of +# https://github.com/matheusgomes28/pygame-life/blob/main/pygame_life.py +def manage_graphics_and_refresh( + grid, + count, + dimension, + nb_initial_points, + border_size, + screen, + background_refresh_color, + background_color, + life_color, + time_new_grid_sleep, + time_sleep, + refresh_every, + do_text_output, +): + make_new_grid = count == 0 or (refresh_every > 0 and (count % refresh_every) == 0) + + count += 1 + + # Refresh the grid from time to time + if make_new_grid: + grid = np.random.randint(2, size=(1, 1, dimension, dimension), dtype=np.int8) + screen.fill(background_refresh_color) + pygame.display.flip() + time.sleep(time_new_grid_sleep) + + screen.fill(background_color) + + # Draw the grid + width = grid.shape[2 + 0] + height = grid.shape[2 + 1] + cell_width = screen.get_width() / width + cell_height = screen.get_height() / height + + for x in range(width): + for y in range(height): + if grid[0, 0, x, y]: + pygame.draw.rect( + screen, + life_color, + ( + x * cell_width + border_size, + y * cell_height + border_size, + cell_width - border_size, + cell_height - border_size, + ), + ) + + if do_text_output: + np.set_printoptions(threshold=sys.maxsize, linewidth=np.nan) + print( + str(grid[0, 0, :, :]) + .replace("[", " ") + .replace("]", " ") + .replace("0", ".") + .replace("1", "*") + .replace(" ", "") + ) + + pygame.display.flip() + + # Make a pause for controlled speed + time.sleep(time_sleep) + + return grid, count + + +def autotest(dimension): + # Check all our methods return the same result + + for _ in range(100): + # Take a random grid + grid = np.random.randint(2, size=(1, 1, dimension, dimension), dtype=np.int8) + + # Check the results are the same + results = {} + + for method in ["method_3b", "method_4b", "method_5b", "method_basic"]: + results[method] = update_grid(grid, method=method) + + keys = list(results.keys()) + + for k in keys[1:]: + assert np.array_equal( + results[keys[0]], results[k] + ), f"{results[keys[0]]} {results[k]} are different" + + print("Tests of methods looks ok") + + +def manage_args(): + parser = argparse.ArgumentParser(description="Game of Life in Concrete Python.") + parser.add_argument( + "--dimension", + dest="dimension", + action="store", + type=int, + default=100, + help="Dimension of the grid", + ) + parser.add_argument( + "--refresh_every", + dest="refresh_every", + action="store", + type=int, + default=None, + help="Refresh the grid every X steps", + ) + parser.add_argument( + "--method", + dest="method", + action="store", + choices=["method_3b", "method_4b", "method_5b", "method_basic"], + default="method_5b", + help="Method for refreshing the grid", + ) + parser.add_argument( + "--log2_global_p_error", + dest="log2_global_p_error", + action="store", + type=float, + default=None, + help="Probability of correctness issue (full circuit)", + ) + parser.add_argument( + "--log2_p_error", + dest="log2_p_error", + action="store", + type=float, + default=-16, + help="Probability of correctness issue (individual TLU)", + ) + parser.add_argument( + "--simulate", + action="store_true", + dest="fhe_simulation", + help="Simulate instead of running computations in FHE", + ) + parser.add_argument( + "--show_mlir", + action="store_true", + dest="show_mlir", + help="Show the MLIR", + ) + parser.add_argument( + "--stop_after_compilation", + action="store_true", + dest="stop_after_compilation", + help="Stop after compilation", + ) + parser.add_argument( + "--text_output", + action="store_true", + dest="text_output", + help="Print a text output of the grid", + ) + + args = parser.parse_args() + return args + + +def main(): + # Options by the user + args = manage_args() + + # Dimension of the grid. The larger, the slower, in FHE + dimension = args.dimension + + # Which method + which_method = args.method + + # Switch this off to not compile in FHE + do_compile = True + + # Activate to simulate + fhe_simulation = args.fhe_simulation + + # Refresh with a random grid every X steps + refresh_every = min(100, dimension) if args.refresh_every is None else args.refresh_every + + # To see the execution time + do_print_time = True + + # If there is no X server + do_text_output = args.text_output + + # Probability of failure + log2_global_p_error = args.log2_global_p_error + log2_p_error = args.log2_p_error + + # Options for graphics + nb_initial_points = dimension**2 + size = (1000, 700) + background_color = (20, 20, 20) + background_refresh_color = (150, 20, 20) + life_color = (55, 200, 200) + border_size = 1 + + time_sleep = 0.1 if not do_compile or fhe_simulation else 0 + + time_new_grid_sleep = 0.4 + + # Autotest + autotest(dimension=dimension) + + print(f"Using method {which_method}") + print(f"Using a grid {dimension} * {dimension}") + print(f"Refreshing every {refresh_every} steps") + print(f"Using 2**{log2_global_p_error} for global_p_error") + print(f"Using 2**{log2_p_error} for p_error") + + # Compile + if do_compile: + inputset = [ + np.random.randint(2, size=(1, 1, dimension, dimension), dtype=np.int8) + for _ in range(1000) + ] + + if which_method == "method_3b": + function = update_grid_method_3b + elif which_method == "method_4b": + function = update_grid_method_4b + elif which_method == "method_5b": + function = update_grid_method_5b + else: + assert which_method == "method_basic" + function = update_grid_basic + + circuit = function.compile( + inputset, + show_mlir=args.show_mlir, + fhe_simulation=fhe_simulation, + global_p_error=None, # 2**log2_global_p_error, + p_error=2**log2_p_error, + ) + + if args.stop_after_compilation: + sys.exit(0) + + # Set plot + pygame.init() + screen = pygame.display.set_mode(size) + pygame.display.set_caption("Game of Life in Concrete Python") + count = 0 + grid = None + + # Run the key generation, to avoid to have a first execution time which is slower + if do_compile and not fhe_simulation: + time_start = time.time() + grid = circuit.keygen() + time_end = time.time() + + if do_print_time: + print(f"Generating key in {time_end - time_start:.2f} seconds") + + while True: + if pygame.QUIT in [e.type for e in pygame.event.get()]: + sys.exit(0) + + grid, count = manage_graphics_and_refresh( + grid, + count, + dimension, + nb_initial_points, + border_size, + screen, + background_refresh_color, + background_color, + life_color, + time_new_grid_sleep, + time_sleep, + refresh_every, + do_text_output, + ) + + # Update the grid + time_start = time.time() + + if do_compile: + grid = circuit.simulate(grid) if fhe_simulation else circuit.encrypt_run_decrypt(grid) + else: + grid = update_grid(grid, method=which_method) + + time_end = time.time() + + if do_print_time: + print(f"Updating grid in {time_end - time_start:.2f} seconds") + + +if __name__ == "__main__": + main()