From a5fbcce9d7d2ff2549ebe4ad4faec23fa3d4873c Mon Sep 17 00:00:00 2001 From: Olli Saarikivi Date: Thu, 19 Nov 2020 00:53:19 -0800 Subject: [PATCH] Initial commit --- .clang-format | 3 + .gitignore | 19 + .gitmodules | 8 + CMakeLists.txt | 39 ++ CODE_OF_CONDUCT.md | 9 + CONTRIBUTING.md | 12 + LICENSE.txt | 21 ++ README.md | 177 +++++++++ SECURITY.md | 41 +++ SUPPORT.md | 15 + azure-pipelines.yml | 108 ++++++ eva/CMakeLists.txt | 27 ++ eva/ckks/CMakeLists.txt | 6 + eva/ckks/always_rescaler.h | 65 ++++ eva/ckks/ckks_compiler.h | 309 ++++++++++++++++ eva/ckks/ckks_config.cpp | 109 ++++++ eva/ckks/ckks_config.h | 41 +++ eva/ckks/ckks_parameters.h | 23 ++ eva/ckks/ckks_signature.h | 37 ++ eva/ckks/eager_relinearizer.h | 56 +++ eva/ckks/eager_waterline_rescaler.h | 95 +++++ eva/ckks/encode_inserter.h | 62 ++++ eva/ckks/encryption_parameter_selector.h | 210 +++++++++++ eva/ckks/lazy_relinearizer.h | 98 +++++ eva/ckks/lazy_waterline_rescaler.h | 433 ++++++++++++++++++++++ eva/ckks/levels_checker.h | 71 ++++ eva/ckks/minimum_rescaler.h | 124 +++++++ eva/ckks/mod_switcher.h | 94 +++++ eva/ckks/parameter_checker.h | 106 ++++++ eva/ckks/rescaler.h | 60 ++++ eva/ckks/scales_checker.h | 96 +++++ eva/ckks/seal_lowering.h | 32 ++ eva/common/CMakeLists.txt | 6 + eva/common/constant_folder.h | 192 ++++++++++ eva/common/multicore_program_traversal.h | 151 ++++++++ eva/common/program_traversal.h | 102 ++++++ eva/common/reduction_balancer.h | 148 ++++++++ eva/common/reference_executor.cpp | 117 ++++++ eva/common/reference_executor.h | 80 +++++ eva/common/rotation_keys_selector.h | 57 +++ eva/common/type_deducer.h | 40 +++ eva/common/valuation.h | 14 + eva/eva.cpp | 23 ++ eva/eva.h | 16 + eva/ir/CMakeLists.txt | 9 + eva/ir/attribute_list.cpp | 106 ++++++ eva/ir/attribute_list.h | 79 ++++ eva/ir/attributes.cpp | 29 ++ eva/ir/attributes.h | 38 ++ eva/ir/constant_value.h | 146 ++++++++ eva/ir/ops.h | 56 +++ eva/ir/program.cpp | 221 ++++++++++++ eva/ir/program.h | 156 ++++++++ eva/ir/term.cpp | 153 ++++++++ eva/ir/term.h | 62 ++++ eva/ir/term_map.h | 125 +++++++ eva/ir/types.h | 35 ++ eva/seal/CMakeLists.txt | 6 + eva/seal/seal.cpp | 205 +++++++++++ eva/seal/seal.h | 99 +++++ eva/seal/seal_executor.h | 438 +++++++++++++++++++++++ eva/serialization/CMakeLists.txt | 16 + eva/serialization/ckks.proto | 23 ++ eva/serialization/ckks_serialization.cpp | 82 +++++ eva/serialization/eva.proto | 46 +++ eva/serialization/eva_format_version.h | 13 + eva/serialization/eva_serialization.cpp | 291 +++++++++++++++ eva/serialization/known_type.cpp | 35 ++ eva/serialization/known_type.h | 37 ++ eva/serialization/known_type.proto | 13 + eva/serialization/save_load.cpp | 36 ++ eva/serialization/save_load.h | 64 ++++ eva/serialization/seal.proto | 41 +++ eva/serialization/seal_serialization.cpp | 232 ++++++++++++ eva/util/CMakeLists.txt | 12 + eva/util/galois.cpp | 14 + eva/util/galois.h | 15 + eva/util/logging.cpp | 68 ++++ eva/util/logging.h | 22 ++ eva/util/overloaded.h | 20 ++ eva/version.cpp | 12 + eva/version.h | 12 + examples/.gitignore | 8 + examples/baboon.png | Bin 0 -> 11917 bytes examples/image_processing.py | 131 +++++++ examples/requirements.txt | 2 + examples/serialization.py | 73 ++++ python/.gitignore | 3 + python/CMakeLists.txt | 15 + python/eva/CMakeLists.txt | 13 + python/eva/__init__.py | 162 +++++++++ python/eva/ckks/__init__.py | 4 + python/eva/metric.py | 19 + python/eva/seal/__init__.py | 4 + python/eva/std/numeric.py | 21 ++ python/eva/wrapper.cpp | 225 ++++++++++++ python/setup.py.in | 23 ++ scripts/clang-format-all.sh | 10 + tests/all.py | 10 + tests/bug_fixes.py | 71 ++++ tests/common.py | 36 ++ tests/features.py | 220 ++++++++++++ tests/large_programs.py | 149 ++++++++ tests/std.py | 38 ++ third_party/Galois | 1 + third_party/pybind11 | 1 + 106 files changed, 7858 insertions(+) create mode 100644 .clang-format create mode 100644 .gitignore create mode 100644 .gitmodules create mode 100644 CMakeLists.txt create mode 100644 CODE_OF_CONDUCT.md create mode 100644 CONTRIBUTING.md create mode 100644 LICENSE.txt create mode 100644 README.md create mode 100644 SECURITY.md create mode 100644 SUPPORT.md create mode 100644 azure-pipelines.yml create mode 100644 eva/CMakeLists.txt create mode 100644 eva/ckks/CMakeLists.txt create mode 100644 eva/ckks/always_rescaler.h create mode 100644 eva/ckks/ckks_compiler.h create mode 100644 eva/ckks/ckks_config.cpp create mode 100644 eva/ckks/ckks_config.h create mode 100644 eva/ckks/ckks_parameters.h create mode 100644 eva/ckks/ckks_signature.h create mode 100644 eva/ckks/eager_relinearizer.h create mode 100644 eva/ckks/eager_waterline_rescaler.h create mode 100644 eva/ckks/encode_inserter.h create mode 100644 eva/ckks/encryption_parameter_selector.h create mode 100644 eva/ckks/lazy_relinearizer.h create mode 100644 eva/ckks/lazy_waterline_rescaler.h create mode 100644 eva/ckks/levels_checker.h create mode 100644 eva/ckks/minimum_rescaler.h create mode 100644 eva/ckks/mod_switcher.h create mode 100644 eva/ckks/parameter_checker.h create mode 100644 eva/ckks/rescaler.h create mode 100644 eva/ckks/scales_checker.h create mode 100644 eva/ckks/seal_lowering.h create mode 100644 eva/common/CMakeLists.txt create mode 100644 eva/common/constant_folder.h create mode 100644 eva/common/multicore_program_traversal.h create mode 100644 eva/common/program_traversal.h create mode 100644 eva/common/reduction_balancer.h create mode 100644 eva/common/reference_executor.cpp create mode 100644 eva/common/reference_executor.h create mode 100644 eva/common/rotation_keys_selector.h create mode 100644 eva/common/type_deducer.h create mode 100644 eva/common/valuation.h create mode 100644 eva/eva.cpp create mode 100644 eva/eva.h create mode 100644 eva/ir/CMakeLists.txt create mode 100644 eva/ir/attribute_list.cpp create mode 100644 eva/ir/attribute_list.h create mode 100644 eva/ir/attributes.cpp create mode 100644 eva/ir/attributes.h create mode 100644 eva/ir/constant_value.h create mode 100644 eva/ir/ops.h create mode 100644 eva/ir/program.cpp create mode 100644 eva/ir/program.h create mode 100644 eva/ir/term.cpp create mode 100644 eva/ir/term.h create mode 100644 eva/ir/term_map.h create mode 100644 eva/ir/types.h create mode 100644 eva/seal/CMakeLists.txt create mode 100644 eva/seal/seal.cpp create mode 100644 eva/seal/seal.h create mode 100644 eva/seal/seal_executor.h create mode 100644 eva/serialization/CMakeLists.txt create mode 100644 eva/serialization/ckks.proto create mode 100644 eva/serialization/ckks_serialization.cpp create mode 100644 eva/serialization/eva.proto create mode 100644 eva/serialization/eva_format_version.h create mode 100644 eva/serialization/eva_serialization.cpp create mode 100644 eva/serialization/known_type.cpp create mode 100644 eva/serialization/known_type.h create mode 100644 eva/serialization/known_type.proto create mode 100644 eva/serialization/save_load.cpp create mode 100644 eva/serialization/save_load.h create mode 100644 eva/serialization/seal.proto create mode 100644 eva/serialization/seal_serialization.cpp create mode 100644 eva/util/CMakeLists.txt create mode 100644 eva/util/galois.cpp create mode 100644 eva/util/galois.h create mode 100644 eva/util/logging.cpp create mode 100644 eva/util/logging.h create mode 100644 eva/util/overloaded.h create mode 100644 eva/version.cpp create mode 100644 eva/version.h create mode 100644 examples/.gitignore create mode 100644 examples/baboon.png create mode 100644 examples/image_processing.py create mode 100644 examples/requirements.txt create mode 100644 examples/serialization.py create mode 100644 python/.gitignore create mode 100644 python/CMakeLists.txt create mode 100644 python/eva/CMakeLists.txt create mode 100644 python/eva/__init__.py create mode 100644 python/eva/ckks/__init__.py create mode 100644 python/eva/metric.py create mode 100644 python/eva/seal/__init__.py create mode 100644 python/eva/std/numeric.py create mode 100644 python/eva/wrapper.cpp create mode 100644 python/setup.py.in create mode 100755 scripts/clang-format-all.sh create mode 100644 tests/all.py create mode 100644 tests/bug_fixes.py create mode 100644 tests/common.py create mode 100644 tests/features.py create mode 100644 tests/large_programs.py create mode 100644 tests/std.py create mode 160000 third_party/Galois create mode 160000 third_party/pybind11 diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000..4c18f65 --- /dev/null +++ b/.clang-format @@ -0,0 +1,3 @@ +AllowShortIfStatementsOnASingleLine: true +BasedOnStyle: llvm +FixNamespaceComments: true diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ecaea24 --- /dev/null +++ b/.gitignore @@ -0,0 +1,19 @@ +/.*/ +/build/ +/dist/ +eva.egg-info/ +__pycache__/ + +# In-source build files +Makefile +CMakeCache.txt +cmake_install.cmake +CPackConfig.cmake +CPackSourceConfig.cmake +compile_commands.json +CMakeFiles/ +*.pb.cc +*.pb.h +*.a +*.so +*.whl \ No newline at end of file diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..526cb48 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,8 @@ +[submodule "third_party/pybind11"] + path = third_party/pybind11 + url = https://github.com/pybind/pybind11 + ignore = untracked +[submodule "third_party/Galois"] + path = third_party/Galois + url = https://github.com/IntelligentSoftwareSystems/Galois.git + ignore = untracked diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..5216ec1 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,39 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +cmake_minimum_required(VERSION 3.13) +cmake_policy(SET CMP0079 NEW) +cmake_policy(SET CMP0076 NEW) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_EXTENSIONS OFF) +set(CMAKE_POSITION_INDEPENDENT_CODE ON) + +project(eva + VERSION 1.0.0 + LANGUAGES CXX +) + +option(USE_GALOIS "Use the Galois library for multicore homomorphic evaluation" OFF) +if(USE_GALOIS) + message("Galois based multicore support enabled") + add_definitions(-DEVA_USE_GALOIS) +endif() + +find_package(SEAL 3.6 REQUIRED) +find_package(Protobuf 3.6 REQUIRED) +find_package(Python COMPONENTS Interpreter Development) + +if(NOT Python_VERSION_MAJOR EQUAL 3) + message(FATAL_ERROR "EVA requires Python 3. Please ensure you have it + installed in a location searched by CMake.") +endif() + +add_subdirectory(third_party/pybind11) +if(USE_GALOIS) + add_subdirectory(third_party/Galois EXCLUDE_FROM_ALL) +endif() + +add_subdirectory(eva) +add_subdirectory(python) diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..6257f2e --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,9 @@ +# Microsoft Open Source Code of Conduct + +This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). + +Resources: + +- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) +- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) +- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..b140305 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,12 @@ +## Contributing + +The EVA project welcomes contributions and suggestions. Most contributions require you to +agree to a Contributor License Agreement (CLA) declaring that you have the right to, +and actually do, grant us the rights to use your contribution. For details, visit +https://cla.microsoft.com. + +Please submit all pull requests on the **contrib** branch. We will handle the final merge onto the main branch. + +When you submit a pull request, a CLA-bot will automatically determine whether you need +to provide a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the +instructions provided by the bot. You will only need to do this once across all repositories using our CLA. \ No newline at end of file diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 0000000..b2f52a2 --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,21 @@ +Copyright (c) Microsoft Corporation. + +MIT License + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..ae72662 --- /dev/null +++ b/README.md @@ -0,0 +1,177 @@ +# EVA - Compiler for Microsoft SEAL + +EVA is a compiler for homomorphic encryption, that automates away the parts that require cryptographic expertise. +This gives you a simple way to write programs that operate on encrypted data without having access to the secret key. + +Think of EVA as the "C compiler" of the homomorphic world. Homomorphic computations written in EVA IR (Encrypted Vector Arithmetic Intermediate Representation) get compiled to the "assembly" of the homomorphic encryption library API. Just like C compilers free you from tricky tasks like register allocation, EVA frees you from *encryption parameter selection, rescaling insertion, relinearization*... + +EVA targets [Microsoft SEAL](https://github.com/microsoft/SEAL) — the industry leading library for fully-homomorphic encryption — and currently supports the CKKS scheme for deep computations on encrypted approximate fixed-point arithmetic. + +## Getting Started + +EVA is a native library written in C++17 with bindings for Python. Both Linux and Windows are supported. The instructions below show how to get started with EVA on Ubuntu. For building on Windows [EVA's Azure Pipelines script](azure-pipelines.yml) is a useful reference. + +### Installing Dependencies + +To install dependencies on Ubuntu 20.04: +``` +sudo apt install cmake libboost-all-dev libprotobuf-dev protobuf-compiler +``` + +Clang is recommended for compilation, as SEAL is faster when compiled with it. To install clang and set it as default: +``` +sudo apt install clang +sudo update-alternatives --install /usr/bin/cc cc /usr/bin/clang 100 +sudo update-alternatives --install /usr/bin/c++ c++ /usr/bin/clang++ 100 +``` + +Next install Microsoft SEAL version 3.6: +``` +git clone -b v3.6.0 https://github.com/microsoft/SEAL.git +cd SEAL +cmake -DSEAL_THROW_ON_TRANSPARENT_CIPHERTEXT=OFF . +make -j +sudo make install +``` +*Note that SEAL has to be installed with transparent ciphertext checking turned off, as it is not possible in general to statically ensure a program will not produce a transparent ciphertext. This does not affect the security of ciphertexts encrypted with SEAL.* + +### Building and Installing EVA + +#### Building EVA + +EVA builds with CMake version ≥ 3.13: +``` +git submodule update --init +cmake . +make -j +``` +The build process creates a `setup.py` file in `python/`. To install the package for development with PIP: +``` +python3 -m pip install -e python/ +``` +To create a Python Wheel package for distribution in `dist/`: +``` +python3 python/setup.py bdist_wheel --dist-dir='.' +``` + +To check that the installed Python package is working correctly, run all tests with: +``` +python3 tests/all.py +``` + +EVA does not yet support installing the native library for use in other CMake projects (contributions very welcome). + +#### Multicore Support + +EVA features highly scalable multicore support using the [Galois library](https://github.com/IntelligentSoftwareSystems/Galois). It is included as a submodule, but is turned off by default for faster builds and easier debugging. To build EVA with Galois configure with `USE_GALOIS=ON`: +``` +cmake -DUSE_GALOIS=ON . +``` + +### Running the Examples + +The examples use EVA's Python APIs. To install dependencies with PIP: +``` +python3 -m pip install -r examples/requirements.txt +``` + +To run for example the image processing example in EVA/examples: +``` +cd examples/ +python3 image_processing.py +``` +This will compile and run homomorphic evaluations of a Sobel edge detection filter and a Harris corner detection filter on `examples/baboon.png`, producing results of homomorphic evaluation in `*_encrypted.png` and reference results from normal execution in `*_reference.png`. +The script also reports the mean squared error between these for each filter. + +## Programming with PyEVA + +PyEVA is a thin Python-embedded DSL for producing EVA programs. +We will walk you through compiling a PyEVA program with EVA and running it on top of SEAL. + +### Writing and Compiling Programs + +A program to evaluate a fixed polynomial 3x2+5x-2 on 1024 encrypted values can be written: +``` +from eva import * +poly = EvaProgram('Polynomial', vec_size=1024) +with poly: + x = Input('x') + Output('y', 3*x**2 + 5*x - 2) +``` +Next we will compile this program for the [CKKS encryption scheme](https://eprint.iacr.org/2016/421.pdf). +Two additional pieces of information EVA currently requires to compile for CKKS are the *fixed-point scale for inputs* and the *maximum ranges of coefficients in outputs*, both represented in number of bits: +``` +poly.set_output_ranges(30) +poly.set_input_scales(30) +``` +Now the program can be compiled: +``` +from eva.ckks import * +compiler = CKKSCompiler() +compiled_poly, params, signature = compiler.compile(poly) +``` +The `compile` method transforms the program in-place and returns: + +1. the compiled program; +2. encryption parameters for Microsoft SEAL with which the program can be executed; +3. a signature object, that specifies how inputs and outputs need to be encoded and decoded. + +The compiled program can be inspected by printing it in the DOT format for the [Graphviz](https://graphviz.org/) visualization software: +``` +print(compiled_poly.to_DOT()) +``` +The output can be viewed as a graph in, for example, a number of Graphviz editors available online. + +### Generating Keys and Encrypting Inputs + +Encryption keys can now be generated using the encryption parameters: +``` +from eva.seal import * +public_ctx, secret_ctx = generate_keys(params) +``` +Next a dictionary of inputs is created and encrypted using the public context and the program signature: +``` +inputs = { 'x': [i for i in range(compiled_poly.vec_size)] } +encInputs = public_ctx.encrypt(inputs, signature) +``` + +### Homomorphic Execution + +Everything is now in place for executing the program with Microsoft SEAL: +``` +encOutputs = public_ctx.execute(compiled_poly, encInputs) +``` + +### Decrypting Results + +Finally, the outputs can be decrypted using the secret context: +``` +outputs = secret_ctx.decrypt(encOutputs, signature) +``` +For debugging it is often useful to compare homomorphic results to unencrypted computation. +The `evaluate` method can be used to execute an EVA program on unencrypted data. +The two sets of results can then be compared with for example Mean Squared Error: + +``` +from eva.metric import valuation_mse +reference = evaluate(compiled_poly, inputs) +print('MSE', valuation_mse(outputs, reference)) +``` + +## Contributing + +The EVA project welcomes contributions and suggestions. Please see [CONTRIBUTING.md](CONTRIBUTING.md) for details. + +## Credits + +This project is a collaboration between the Microsoft Research's Research in Software Engineering (RiSE) group and Cryptography and Privacy Research group. + +A huge credit goes to [Dr. Roshan Dathathri](https://roshandathathri.github.io/), who as an intern built the first version of EVA, along with all the transformations required for targeting the CKKS scheme efficiently and the parallelizing runtime required to make execution scale. + +Many thanks to [Sangeeta Chowdhary](https://www.ilab.cs.rutgers.edu/~sc1696/), who as an intern put a huge amount of work into making EVA ready for release. + +## Publications + +Roshan Dathathri, Blagovesta Kostova, Olli Saarikivi, Wei Dai, Kim Laine, Madanlal Musuvathi. *EVA: An Encrypted Vector Arithmetic Language and Compiler for Efficient Homomorphic Computation*. PLDI 2020. [arXiv](https://arxiv.org/abs/1912.11951) [DOI](https://doi.org/10.1145/3385412.3386023) + +Roshan Dathathri, Olli Saarikivi, Hao Chen, Kim Laine, Kristin Lauter, Saeed Maleki, Madanlal Musuvathi, Todd Mytkowicz. *CHET: An Optimizing Compiler for Fully-Homomorphic Neural-Network Inferencing*. PLDI 2019. [DOI](https://doi.org/10.1145/3314221.3314628) diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 0000000..f7b8998 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,41 @@ + + +## Security + +Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). + +If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below. + +## Reporting Security Issues + +**Please do not report security vulnerabilities through public GitHub issues.** + +Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report). + +If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc). + +You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). + +Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: + + * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) + * Full paths of source file(s) related to the manifestation of the issue + * The location of the affected source code (tag/branch/commit or direct URL) + * Any special configuration required to reproduce the issue + * Step-by-step instructions to reproduce the issue + * Proof-of-concept or exploit code (if possible) + * Impact of the issue, including how an attacker might exploit the issue + +This information will help us triage your report more quickly. + +If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs. + +## Preferred Languages + +We prefer all communications to be in English. + +## Policy + +Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). + + \ No newline at end of file diff --git a/SUPPORT.md b/SUPPORT.md new file mode 100644 index 0000000..61af943 --- /dev/null +++ b/SUPPORT.md @@ -0,0 +1,15 @@ +# Support + +## How to file issues and get help + +This project uses GitHub Issues to track bugs and feature requests. Please search the existing +issues before filing new issues to avoid duplicates. For new issues, file your bug or +feature request as a new Issue. + +For help and questions about using this project, you can contact the EVA compiler team at +[evacompiler@microsoft.com](mailto:evacompiler@microsoft.com). +We are very interested in helping early adopters start to use EVA. + +## Microsoft Support Policy + +Support for EVA is limited to the resources listed above. \ No newline at end of file diff --git a/azure-pipelines.yml b/azure-pipelines.yml new file mode 100644 index 0000000..2f48817 --- /dev/null +++ b/azure-pipelines.yml @@ -0,0 +1,108 @@ +# EVA pipeline + +trigger: +- master + +pool: + vmImage: 'windows-latest' + +steps: +- task: UsePythonVersion@0 + displayName: 'Ensure Python 3.x' + inputs: + versionSpec: '3.x' + addToPath: true + architecture: 'x64' + +- task: securedevelopmentteam.vss-secure-development-tools.build-task-credscan.CredScan@2 + displayName: 'Run CredScan' + inputs: + toolMajorVersion: 'V2' + outputFormat: sarif + debugMode: false + +- task: CmdLine@2 + displayName: 'Get SEAL source code' + inputs: + script: | + rem Use github repo + git clone https://github.com/microsoft/SEAL.git + cd SEAL + rem Use 3.6.0 specifically + git checkout 3.6.0 + workingDirectory: '$(Build.SourcesDirectory)/third_party' + +- task: CMake@1 + displayName: 'Configure SEAL' + inputs: + cmakeArgs: '-DSEAL_THROW_ON_TRANSPARENT_CIPHERTEXT=OFF -DALLOW_COMMAND_LINE_BUILD=ON -DSEAL_USE_MSGSL=OFF -DSEAL_USE_ZLIB=OFF -DSEAL_USE_ZSTD=OFF .' + workingDirectory: $(Build.SourcesDirectory)/third_party/SEAL + +- task: MSBuild@1 + displayName: 'Build SEAL' + inputs: + solution: '$(Build.SourcesDirectory)/third_party/SEAL/SEAL.sln' + msbuildArchitecture: 'x64' + platform: 'x64' + configuration: 'Debug' + +- task: CmdLine@2 + displayName: 'Get vcpkg' + inputs: + script: 'git clone https://github.com/microsoft/vcpkg.git' + workingDirectory: '$(Build.SourcesDirectory)/third_party' + +- task: CmdLine@2 + displayName: 'Bootstrap vcpkg' + inputs: + script: '$(Build.SourcesDirectory)/third_party/vcpkg/bootstrap-vcpkg.bat' + workingDirectory: '$(Build.SourcesDirectory)/third_party/vcpkg' + +- task: PowerShell@2 + displayName: 'Get protobuf compiler' + inputs: + targetType: 'inline' + script: | + mkdir protobuf + cd protobuf + Invoke-WebRequest -Uri "https://github.com/protocolbuffers/protobuf/releases/download/v3.13.0/protoc-3.13.0-win64.zip" -OutFile protobufc.zip + Expand-Archive -LiteralPath protobufc.zip -DestinationPath protobufc + workingDirectory: '$(Build.SourcesDirectory)/third_party' + +- task: CmdLine@2 + displayName: 'Install protobuf library' + inputs: + script: '$(Build.SourcesDirectory)/third_party/vcpkg/vcpkg.exe install protobuf[zlib]:x64-windows' + workingDirectory: '$(Build.SourcesDirectory)/third_party/vcpkg' + +- task: CmdLine@2 + displayName: 'Create build directory' + inputs: + script: 'mkdir build' + workingDirectory: '$(Build.SourcesDirectory)' + +- task: CMake@1 + displayName: 'Configure EVA' + inputs: + cmakeArgs: .. -DSEAL_DIR=$(Build.SourcesDirectory)/third_party/SEAL/cmake -DProtobuf_INCLUDE_DIR=$(Build.SourcesDirectory)/third_party/vcpkg/packages/protobuf_x64-windows/include -DProtobuf_LIBRARY=$(Build.SourcesDirectory)/third_party/vcpkg/packages/protobuf_x64-windows/lib/libprotobuf.lib -DProtobuf_PROTOC_EXECUTABLE=$(Build.SourcesDirectory)/third_party/protobuf/protobufc/bin/protoc.exe + workingDirectory: '$(Build.SourcesDirectory)/build' + +- task: MSBuild@1 + displayName: 'Build EVA' + inputs: + solution: '$(Build.SourcesDirectory)/build/eva.sln' + msbuildArchitecture: 'x64' + platform: 'x64' + configuration: 'Debug' + +- task: ms.vss-governance-buildtask.governance-build-task-component-detection.ComponentGovernanceComponentDetection@0 + displayName: 'Component Detection' + +- task: securedevelopmentteam.vss-secure-development-tools.build-task-publishsecurityanalysislogs.PublishSecurityAnalysisLogs@2 + displayName: 'Publish Security Analysis Logs' + +- task: PublishBuildArtifacts@1 + displayName: 'Publish build artifacts' + inputs: + PathtoPublish: '$(Build.ArtifactStagingDirectory)' + artifactName: windows-drop diff --git a/eva/CMakeLists.txt b/eva/CMakeLists.txt new file mode 100644 index 0000000..bfee443 --- /dev/null +++ b/eva/CMakeLists.txt @@ -0,0 +1,27 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +add_library(eva STATIC + eva.cpp + version.cpp +) + +# TODO: everything except SEAL::seal should be make PRIVATE +target_link_libraries(eva PUBLIC SEAL::seal protobuf::libprotobuf) +if(USE_GALOIS) + target_link_libraries(eva PUBLIC Galois::shmem numa) +endif() +target_include_directories(eva + PUBLIC + $ + $ + $ +) +target_compile_definitions(eva PRIVATE EVA_VERSION_STR="${PROJECT_VERSION}") + +add_subdirectory(util) +add_subdirectory(serialization) +add_subdirectory(ir) +add_subdirectory(common) +add_subdirectory(ckks) +add_subdirectory(seal) diff --git a/eva/ckks/CMakeLists.txt b/eva/ckks/CMakeLists.txt new file mode 100644 index 0000000..3c325a4 --- /dev/null +++ b/eva/ckks/CMakeLists.txt @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +target_sources(eva PRIVATE + ckks_config.cpp +) diff --git a/eva/ckks/always_rescaler.h b/eva/ckks/always_rescaler.h new file mode 100644 index 0000000..b1ffa98 --- /dev/null +++ b/eva/ckks/always_rescaler.h @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include "eva/ckks/rescaler.h" + +namespace eva { + +class AlwaysRescaler : public Rescaler { + std::uint32_t minScale; + +public: + AlwaysRescaler(Program &g, TermMap &type, + TermMapOptional &scale) + : Rescaler(g, type, scale) { + // ASSUME: minScale is max among all the cipher inputs' scale + minScale = 0; + for (auto &source : program.getSources()) { + if (scale[source] > minScale) minScale = scale[source]; + } + assert(minScale != 0); + } + + void + operator()(Term::Ptr &term) { // must only be used with forward pass traversal + if (term->numOperands() == 0) return; // inputs + if (type[term] == Type::Raw) { + handleRawScale(term); + return; + } + + if (isRescaleOp(term->op)) return; // already processed + + if (!isMultiplicationOp(term->op)) { + // Op::Add, Op::Sub, NEGATE, COPY, Op::RotateLeftConst, + // Op::RotateRightConst copy scale of the first operand + scale[term] = scale[term->operandAt(0)]; + if (isAdditionOp(term->op)) { + // Op::Add, Op::Sub + // assert that all operands have the same scale + for (auto &operand : term->getOperands()) { + if (type[operand] != Type::Raw) { + assert(scale[term] == scale[operand] || type[operand] == Type::Raw); + } + } + } + return; + } + + // Op::Mul only + // ASSUME: only two operands + std::uint32_t multScale = 0; + for (auto &operand : term->getOperands()) { + multScale += scale[operand]; + } + assert(multScale != 0); + scale[term] = multScale; + + // always rescale + insertRescale(term, multScale - minScale); + } +}; + +} // namespace eva diff --git a/eva/ckks/ckks_compiler.h b/eva/ckks/ckks_compiler.h new file mode 100644 index 0000000..e52f3e5 --- /dev/null +++ b/eva/ckks/ckks_compiler.h @@ -0,0 +1,309 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include "eva/ckks/always_rescaler.h" +#include "eva/ckks/ckks_config.h" +#include "eva/ckks/ckks_parameters.h" +#include "eva/ckks/ckks_signature.h" +#include "eva/ckks/eager_relinearizer.h" +#include "eva/ckks/eager_waterline_rescaler.h" +#include "eva/ckks/encode_inserter.h" +#include "eva/ckks/encryption_parameter_selector.h" +#include "eva/ckks/lazy_relinearizer.h" +#include "eva/ckks/lazy_waterline_rescaler.h" +#include "eva/ckks/levels_checker.h" +#include "eva/ckks/minimum_rescaler.h" +#include "eva/ckks/mod_switcher.h" +#include "eva/ckks/parameter_checker.h" +#include "eva/ckks/scales_checker.h" +#include "eva/ckks/seal_lowering.h" +#include "eva/common/constant_folder.h" +#include "eva/common/program_traversal.h" +#include "eva/common/reduction_balancer.h" +#include "eva/common/rotation_keys_selector.h" +#include "eva/common/type_deducer.h" +#include "eva/util/logging.h" +#include +#include + +namespace eva { + +class CKKSCompiler { + CKKSConfig config; + + void transform(Program &program, TermMap &types, + TermMapOptional &scales) { + auto programRewrite = ProgramTraversal(program); + + log(Verbosity::Debug, "Running TypeDeducer pass"); + programRewrite.forwardPass(TypeDeducer(program, types)); + log(Verbosity::Debug, "Running ConstantFolder pass"); + programRewrite.forwardPass(ConstantFolder( + program, scales)); // currently required because executor/runtime + // does not handle this + if (config.balanceReductions) { + log(Verbosity::Debug, "Running ReductionCombiner pass"); + programRewrite.forwardPass(ReductionCombiner(program)); + log(Verbosity::Debug, "Running ReductionLogExpander pass"); + programRewrite.forwardPass(ReductionLogExpander(program, types)); + } + switch (config.rescaler) { + case CKKSRescaler::Minimum: + log(Verbosity::Debug, "Running MinimumRescaler pass"); + programRewrite.forwardPass(MinimumRescaler(program, types, scales)); + break; + case CKKSRescaler::Always: + log(Verbosity::Debug, "Running AlwaysRescaler pass"); + programRewrite.forwardPass(AlwaysRescaler(program, types, scales)); + break; + case CKKSRescaler::EagerWaterline: + log(Verbosity::Debug, "Running EagerWaterlineRescaler pass"); + programRewrite.forwardPass( + EagerWaterlineRescaler(program, types, scales)); + break; + case CKKSRescaler::LazyWaterline: + log(Verbosity::Debug, "Running LazyWaterlineRescaler pass"); + programRewrite.forwardPass(LazyWaterlineRescaler(program, types, scales)); + break; + default: + throw std::logic_error("Unhandled rescaler in CKKSCompiler."); + } + log(Verbosity::Debug, "Running TypeDeducer pass"); + programRewrite.forwardPass(TypeDeducer(program, types)); + + log(Verbosity::Debug, "Running EncodeInserter pass"); + programRewrite.forwardPass(EncodeInserter(program, types, scales)); + log(Verbosity::Debug, "Running TypeDeducer pass"); + programRewrite.forwardPass(TypeDeducer(program, types)); + // TODO: rerunning the type deducer at every step is wasteful, but also + // forcing other passes to always keep type information up to date isn't + // something they should need to do. Type deduction should be changed + // into a thing that is done as needed locally. + if (config.lazyRelinearize) { + log(Verbosity::Debug, "Running LazyRelinearizer pass"); + programRewrite.forwardPass(LazyRelinearizer(program, types, scales)); + } else { + log(Verbosity::Debug, "Running EagerRelinearizer pass"); + programRewrite.forwardPass(EagerRelinearizer(program, types, scales)); + } + log(Verbosity::Debug, "Running TypeDeducer pass"); + programRewrite.forwardPass(TypeDeducer(program, types)); + log(Verbosity::Debug, "Running ModSwitcher pass"); + programRewrite.backwardPass(ModSwitcher(program, types, scales)); + log(Verbosity::Debug, "Running TypeDeducer pass"); + programRewrite.forwardPass(TypeDeducer(program, types)); + log(Verbosity::Debug, "Running SEALLowering pass"); + programRewrite.forwardPass(SEALLowering(program, types)); + } + + void validate(Program &program, TermMap &types, + TermMapOptional &scales) { + auto programTraverse = ProgramTraversal(program); + log(Verbosity::Debug, "Running LevelsChecker pass"); + LevelsChecker lc(program, types); + programTraverse.forwardPass(lc); + try { + log(Verbosity::Debug, "Running ParameterChecker pass"); + ParameterChecker pc(program, types); + programTraverse.forwardPass(pc); + } catch (const InconsistentParameters &e) { + switch (config.rescaler) { + case CKKSRescaler::Minimum: + throw std::runtime_error( + "The 'minimum' rescaler produced inconsistent parameters. Note " + "that this rescaling policy is not general and thus will not work " + "for all programs. Please use a different rescaler for this " + "program."); + case CKKSRescaler::Always: + throw std::runtime_error( + "The 'always' rescaler produced inconsistent parameters. Note that " + "this rescaling policy is not general. It is only guaranteed to " + "work for programs that have equal scale for all inputs and " + "constants."); + default: + throw std::runtime_error( + "The current rescaler produced inconsistent parameters. This is a " + "bug, as this rescaler should be able to handle all programs."); + } + } + log(Verbosity::Debug, "Running ScalesChecker pass"); + ScalesChecker sc(program, scales, types); + programTraverse.forwardPass(sc); + } + + std::size_t getMinDegreeForBitCount(int (*MaxBitsFun)(std::size_t), + int bitCount) { + std::size_t degree = 1024; + int maxBitsSeen = 0; + while (true) { + auto maxBitsForDegree = MaxBitsFun(degree); + maxBitsSeen = std::max(maxBitsSeen, maxBitsForDegree); + if (maxBitsForDegree == 0) { + throw std::runtime_error( + "Program requires a " + std::to_string(bitCount) + + " bit modulus, but parameters are available for a maximum of " + + std::to_string(maxBitsSeen)); + } + if (maxBitsForDegree >= bitCount) { + return degree; + } + degree *= 2; + } + } + + void determineEncryptionParameters(Program &program, + CKKSParameters &encParams, + TermMapOptional &scales, + TermMap types) { + auto programTraverse = ProgramTraversal(program); + log(Verbosity::Debug, "Running EncryptionParametersSelector pass"); + EncryptionParametersSelector eps(program, scales, types); + programTraverse.forwardPass(eps); + log(Verbosity::Debug, "Running RotationKeysSelector pass"); + RotationKeysSelector rks(program, types); + programTraverse.forwardPass(rks); + encParams.primeBits = eps.getEncryptionParameters(); + encParams.rotations = rks.getRotationKeys(); + + int bitCount = 0; + for (auto &logQ : encParams.primeBits) + bitCount += logQ; + if (config.securityLevel <= 128) { + if (config.quantumSafe) + encParams.polyModulusDegree = getMinDegreeForBitCount( + &seal::util::seal_he_std_parms_128_tq, bitCount); + else + encParams.polyModulusDegree = getMinDegreeForBitCount( + &seal::util::seal_he_std_parms_128_tc, bitCount); + } else if (config.securityLevel <= 192) { + if (config.quantumSafe) + encParams.polyModulusDegree = getMinDegreeForBitCount( + &seal::util::seal_he_std_parms_192_tq, bitCount); + else + encParams.polyModulusDegree = getMinDegreeForBitCount( + &seal::util::seal_he_std_parms_192_tc, bitCount); + } else if (config.securityLevel <= 256) { + if (config.quantumSafe) + encParams.polyModulusDegree = getMinDegreeForBitCount( + &seal::util::seal_he_std_parms_256_tq, bitCount); + else + encParams.polyModulusDegree = getMinDegreeForBitCount( + &seal::util::seal_he_std_parms_256_tc, bitCount); + } else { + throw std::runtime_error( + "EVA has support for up to 256 bit security, but " + + std::to_string(config.securityLevel) + + " bit security was requested."); + } + + auto slots = encParams.polyModulusDegree / 2; + if (config.warnVecSize && slots > program.getVecSize()) { + warn("Program specifies vector size %i while at least %i slots are " + "required for security. " + "This does not affect correctness, as the smaller vector size will " + "be transparently emulated. " + "However, using a vector size up to %i would come at no additional " + "cost.", + program.getVecSize(), slots, slots); + } + if (slots < program.getVecSize()) { + if (config.warnVecSize) { + warn("Program uses vector size %i while only %i slots are required for " + "security. " + "This does not affect correctness, but higher performance may be " + "available " + "with a smaller vector size.", + program.getVecSize(), slots); + } + encParams.polyModulusDegree = 2 * program.getVecSize(); + } + + if (verbosityAtLeast(Verbosity::Info)) { + printf("EVA: Encryption parameters for %s are:\n Q = [", + program.getName().c_str()); + bool first = true; + for (auto &logQ : encParams.primeBits) { + if (first) { + first = false; + printf("%i", logQ); + } else { + printf(",%i", logQ); + } + } + int n = encParams.polyModulusDegree; + int nexp = 0; + while (n >>= 1) + ++nexp; + printf("] (total bits %i)\n N = 2^%i (available slots %i)\n Rotation " + "keys: ", + bitCount, nexp, encParams.polyModulusDegree / 2); + first = true; + for (auto &rotation : encParams.rotations) { + if (first) { + first = false; + printf("%i", rotation); + } else { + printf(", %i", rotation); + } + } + printf(" (count %lu)\n", encParams.rotations.size()); + } + } + + CKKSSignature extractSignature(const Program &program) { + std::unordered_map inputs; + for (auto &input : program.getInputs()) { + Type type = input.second->get(); + assert(type != Type::Undef); + + inputs.emplace( + input.first, + CKKSEncodingInfo(type, input.second->get(), + input.second->get())); + } + return CKKSSignature(program.getVecSize(), std::move(inputs)); + } + +public: + CKKSCompiler() {} + CKKSCompiler(CKKSConfig config) : config(config) {} + + std::tuple, CKKSParameters, CKKSSignature> + compile(Program &inputProgram) { + auto program = inputProgram.deepCopy(); + + log(Verbosity::Info, "Compiling %s for CKKS with:\n%s", + program->getName().c_str(), config.toString(2).c_str()); + + TermMap types(*program); + TermMapOptional scales(*program); + for (auto &source : program->getSources()) { + // Error out if the scale attribute doesn't exist + if (!source->has()) { + for (auto &entry : program->getInputs()) { + if (source == entry.second) { + throw std::runtime_error("The scale for input " + entry.first + + " was not set."); + } + } + throw std::runtime_error("The scale for a constant was not set."); + } + // Copy the scale from the attribute into the scales TermMap + scales[source] = source->get(); + } + + CKKSParameters encParams; + transform(*program, types, scales); + validate(*program, types, scales); + determineEncryptionParameters(*program, encParams, scales, types); + + auto signature = extractSignature(*program); + + return std::make_tuple(std::move(program), std::move(encParams), + std::move(signature)); + } +}; + +} // namespace eva diff --git a/eva/ckks/ckks_config.cpp b/eva/ckks/ckks_config.cpp new file mode 100644 index 0000000..7419e9a --- /dev/null +++ b/eva/ckks/ckks_config.cpp @@ -0,0 +1,109 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "eva/ckks/ckks_config.h" +#include "eva/util/logging.h" +#include + +namespace eva { + +CKKSConfig::CKKSConfig( + const std::unordered_map &configMap) { + for (const auto &entry : configMap) { + const auto &option = entry.first; + const auto &valueStr = entry.second; + if (option == "balance_reductions") { + std::istringstream is(valueStr); + is >> std::boolalpha >> balanceReductions; + if (is.bad()) { + warn("Could not parse boolean in balance_reductions=%s. Falling back " + "to default.", + valueStr.c_str()); + } + } else if (option == "rescaler") { + if (valueStr == "lazy_waterline") { + rescaler = CKKSRescaler::LazyWaterline; + } else if (valueStr == "eager_waterline") { + rescaler = CKKSRescaler::EagerWaterline; + } else if (valueStr == "always") { + rescaler = CKKSRescaler::Always; + } else if (valueStr == "minimum") { + rescaler = CKKSRescaler::Minimum; + } else { + // Please update this warning message when adding new options to the + // cases above + warn("Unknown value rescaler=%s. Available rescalers are " + "lazy_waterline, eager_waterline, always, minimum. Falling back " + "to default.", + valueStr.c_str()); + } + } else if (option == "lazy_relinearize") { + std::istringstream is(valueStr); + is >> std::boolalpha >> lazyRelinearize; + if (is.bad()) { + warn("Could not parse boolean in lazy_relinearize=%s. Falling back to " + "default.", + valueStr.c_str()); + } + } else if (option == "security_level") { + std::istringstream is(valueStr); + is >> securityLevel; + if (is.bad()) { + throw std::runtime_error( + "Could not parse unsigned int in security_level=" + valueStr); + } + } else if (option == "quantum_safe") { + std::istringstream is(valueStr); + is >> std::boolalpha >> quantumSafe; + if (is.bad()) { + throw std::runtime_error("Could not parse boolean in quantum_safe=" + + valueStr); + } + } else if (option == "warn_vec_size") { + std::istringstream is(valueStr); + is >> std::boolalpha >> warnVecSize; + if (is.bad()) { + warn("Could not parse boolean in warn_vec_size=%s. Falling " + "back to default.", + valueStr.c_str()); + } + } else { + warn("Unknown option %s. Available options are:\n%s", option.c_str(), + OPTIONS_HELP_MESSAGE); + } + } +} + +std::string CKKSConfig::toString(int indent) const { + auto indentStr = std::string(indent, ' '); + std::stringstream s; + s << std::boolalpha; + s << indentStr << "balance_reductions = " << balanceReductions; + s << '\n'; + s << indentStr << "rescaler = "; + switch (rescaler) { + case CKKSRescaler::LazyWaterline: + s << "lazy_waterline"; + break; + case CKKSRescaler::EagerWaterline: + s << "eager_waterline"; + break; + case CKKSRescaler::Always: + s << "always"; + break; + case CKKSRescaler::Minimum: + s << "minimum"; + break; + } + s << '\n'; + s << indentStr << "lazy_relinearize = " << lazyRelinearize; + s << '\n'; + s << indentStr << "security_level = " << securityLevel; + s << '\n'; + s << indentStr << "quantum_safe = " << quantumSafe; + s << '\n'; + s << indentStr << "warn_vec_size = " << warnVecSize; + return s.str(); +} + +} // namespace eva diff --git a/eva/ckks/ckks_config.h b/eva/ckks/ckks_config.h new file mode 100644 index 0000000..de9dc3e --- /dev/null +++ b/eva/ckks/ckks_config.h @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include +#include + +namespace eva { + +// clang-format off +const char *const OPTIONS_HELP_MESSAGE = + "balance_reductions - Balance trees of mul, add or sub operations. bool (default=true)\n" + "rescaler - Rescaling policy. One of: lazy_waterline (default), eager_waterline, always, minimum\n" + "lazy_relinearize - Relinearize as late as possible. bool (default=true)\n" + "security_level - How many bits of security parameters should be selected for. int (default=128)\n" + "quantum_safe - Select quantum safe parameters. bool (default=false)\n" + "warn_vec_size - Warn about possibly inefficient vector size selection. bool (default=true)"; +// clang-format on + +enum class CKKSRescaler { LazyWaterline, EagerWaterline, Always, Minimum }; + +// Controls the behavior of CKKSCompiler +class CKKSConfig { +public: + CKKSConfig() {} + CKKSConfig(const std::unordered_map &configMap); + + std::string toString(int indent = 0) const; + + bool balanceReductions = true; + CKKSRescaler rescaler = CKKSRescaler::LazyWaterline; + bool lazyRelinearize = true; + uint32_t securityLevel = 128; + bool quantumSafe = false; + + // Warnings + bool warnVecSize = true; +}; + +} // namespace eva diff --git a/eva/ckks/ckks_parameters.h b/eva/ckks/ckks_parameters.h new file mode 100644 index 0000000..82355d6 --- /dev/null +++ b/eva/ckks/ckks_parameters.h @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include "eva/serialization/ckks.pb.h" +#include +#include +#include +#include + +namespace eva { + +struct CKKSParameters { + std::vector primeBits; // in log-scale + std::set rotations; + std::uint32_t polyModulusDegree; +}; + +std::unique_ptr serialize(const CKKSParameters &); +std::unique_ptr deserialize(const msg::CKKSParameters &); + +} // namespace eva diff --git a/eva/ckks/ckks_signature.h b/eva/ckks/ckks_signature.h new file mode 100644 index 0000000..da66658 --- /dev/null +++ b/eva/ckks/ckks_signature.h @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include "eva/ir/types.h" +#include "eva/serialization/ckks.pb.h" +#include +#include +#include + +namespace eva { + +// TODO: make these structs immutable + +struct CKKSEncodingInfo { + Type inputType; + int scale; + int level; + + CKKSEncodingInfo(Type inputType, int scale, int level) + : inputType(inputType), scale(scale), level(level) {} +}; + +struct CKKSSignature { + int vecSize; + std::unordered_map inputs; + + CKKSSignature(int vecSize, + std::unordered_map inputs) + : vecSize(vecSize), inputs(inputs) {} +}; + +std::unique_ptr serialize(const CKKSSignature &); +std::unique_ptr deserialize(const msg::CKKSSignature &); + +} // namespace eva diff --git a/eva/ckks/eager_relinearizer.h b/eva/ckks/eager_relinearizer.h new file mode 100644 index 0000000..64d5a99 --- /dev/null +++ b/eva/ckks/eager_relinearizer.h @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include "eva/ir/program.h" +#include "eva/ir/term_map.h" + +namespace eva { + +class EagerRelinearizer { + Program &program; + TermMap &type; + TermMapOptional &scale; + + bool isMultiplicationOp(const Op &op_code) { return (op_code == Op::Mul); } + + bool isUnencryptedType(const Type &type) { return type != Type::Cipher; } + + bool areAllOperandsEncrypted(Term::Ptr &term) { + for (auto &op : term->getOperands()) { + if (isUnencryptedType(type[op])) { + return false; + } + } + return true; + } + +public: + EagerRelinearizer(Program &g, TermMap &type, + TermMapOptional &scale) + : program(g), type(type), scale(scale) {} + + void + operator()(Term::Ptr &term) { // must only be used with forward pass traversal + auto &operands = term->getOperands(); + if (operands.size() == 0) return; // inputs + + auto op = term->op; + + if (!isMultiplicationOp(op)) return; + + // Op::Multiply only + // ASSUME: only two operands + bool encryptedOps = areAllOperandsEncrypted(term); + if (!encryptedOps) return; + + auto relinNode = program.makeTerm(Op::Relinearize, {term}); + type[relinNode] = type[term]; + scale[relinNode] = scale[term]; + + term->replaceOtherUsesWith(relinNode); + } +}; + +} // namespace eva diff --git a/eva/ckks/eager_waterline_rescaler.h b/eva/ckks/eager_waterline_rescaler.h new file mode 100644 index 0000000..1fee979 --- /dev/null +++ b/eva/ckks/eager_waterline_rescaler.h @@ -0,0 +1,95 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include "eva/ckks/rescaler.h" +#include "eva/util/logging.h" + +namespace eva { + +class EagerWaterlineRescaler : public Rescaler { + std::uint32_t minScale; + const std::uint32_t fixedRescale = 60; + +public: + EagerWaterlineRescaler(Program &g, TermMap &type, + TermMapOptional &scale) + : Rescaler(g, type, scale) { + // ASSUME: minScale is max among all the inputs' scale + minScale = 0; + for (auto &source : program.getSources()) { + if (scale[source] > minScale) minScale = scale[source]; + } + assert(minScale != 0); + } + + void + operator()(Term::Ptr &term) { // must only be used with forward pass traversal + if (term->numOperands() == 0) return; // inputs + if (type[term] == Type::Raw) { + handleRawScale(term); + return; + } + + if (isRescaleOp(term->op)) return; // already processed + + if (!isMultiplicationOp(term->op)) { + // Op::Add, Op::Sub, NEGATE, COPY, Op::RotateLeftConst, + // Op::RotateRightConst copy scale of the first operand + scale[term] = scale[term->operandAt(0)]; + if (isAdditionOp(term->op)) { + // Op::Add, Op::Sub + auto maxScale = scale[term]; + for (auto &operand : term->getOperands()) { + // Here we allow raw operands to possibly raise the scale + if (scale[operand] > maxScale) maxScale = scale[operand]; + } + for (auto &operand : term->getOperands()) { + if (scale[operand] < maxScale && type[operand] != Type::Raw) { + log(Verbosity::Trace, + "Scaling up t%i from scale %i to match other addition operands " + "at scale %i", + operand->index, scale[operand], maxScale); + + auto scaleConstant = program.makeUniformConstant(1); + scale[scaleConstant] = maxScale - scale[operand]; + scaleConstant->set(scale[scaleConstant]); + + auto mulNode = program.makeTerm(Op::Mul, {operand, scaleConstant}); + scale[mulNode] = maxScale; + + // TODO: Not obviously correct as it's modifying inside + // iteration. Refine API to make this less surprising. + term->replaceOperand(operand, mulNode); + } + } + // assert that all operands have the same scale + for (auto &operand : term->getOperands()) { + assert(maxScale == scale[operand] || type[operand] == Type::Raw); + } + scale[term] = maxScale; + } + return; + } + + // Op::Mul only + // ASSUME: only two operands + std::uint32_t multScale = 0; + for (auto &operand : term->getOperands()) { + multScale += scale[operand]; + } + assert(multScale != 0); + scale[term] = multScale; + + // rescale only if above the waterline + auto temp = term; + while (multScale >= (fixedRescale + minScale)) { + temp = insertRescale(temp, fixedRescale); + multScale -= fixedRescale; + assert(multScale == scale[temp]); + } + } +}; + +} // namespace eva diff --git a/eva/ckks/encode_inserter.h b/eva/ckks/encode_inserter.h new file mode 100644 index 0000000..f42fe90 --- /dev/null +++ b/eva/ckks/encode_inserter.h @@ -0,0 +1,62 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include "eva/ir/program.h" +#include "eva/ir/term_map.h" + +namespace eva { + +class EncodeInserter { + Program &program; + TermMap &type; + TermMapOptional &scale; + + bool isRawType(const Type &type) { return type == Type::Raw; } + bool isCipherType(const Type &type) { return type == Type::Cipher; } + bool isAdditionOp(const Op &op_code) { + return ((op_code == Op::Add) || (op_code == Op::Sub)); + } + + auto insertEncodeNode(Op op, const Term::Ptr &other, const Term::Ptr &term) { + auto newNode = program.makeTerm(Op::Encode, {term}); + type[newNode] = Type::Plain; + if (isAdditionOp(op)) { + scale[newNode] = scale[other]; + } else { + scale[newNode] = scale[term]; + } + newNode->set(scale[newNode]); + return newNode; + } + +public: + EncodeInserter(Program &g, TermMap &type, + TermMapOptional &scale) + : program(g), type(type), scale(scale) {} + + void + operator()(Term::Ptr &term) { // must only be used with forward pass traversal + auto &operands = term->getOperands(); + if (operands.size() == 0) return; // inputs + + assert(operands.size() <= 2); + if (operands.size() == 2) { + auto &leftOperand = operands[0]; + auto &rightOperand = operands[1]; + auto op1 = leftOperand->op; + if (isCipherType(type[leftOperand]) && isRawType(type[rightOperand])) { + auto newTerm = insertEncodeNode(term->op, leftOperand, rightOperand); + term->replaceOperand(rightOperand, newTerm); + } + + if (isCipherType(type[rightOperand]) && isRawType(type[leftOperand])) { + auto newTerm = insertEncodeNode(term->op, rightOperand, leftOperand); + term->replaceOperand(leftOperand, newTerm); + } + } + } +}; + +} // namespace eva diff --git a/eva/ckks/encryption_parameter_selector.h b/eva/ckks/encryption_parameter_selector.h new file mode 100644 index 0000000..8002a7e --- /dev/null +++ b/eva/ckks/encryption_parameter_selector.h @@ -0,0 +1,210 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include "eva/ir/program.h" +#include "eva/ir/term_map.h" +#include +#include +#include +#include + +namespace eva { + +class EncryptionParametersSelector { +public: + EncryptionParametersSelector(Program &g, + TermMapOptional &scales, + TermMap &types) + : program_(g), scales_(scales), terms_(g), types(types) {} + + void operator()(const Term::Ptr &term) { + // This function computes, for each term, the set of coeff_modulus primes + // needed to reach that term, taking only into account rescalings. Primes + // needed to hold the output values are not included. For example, input + // terms require no extra primes, so for input terms this function will + // assign an empty set of primes. The example below shows parameters + // assigned in a simple example computation, where we rescale by 40 bits: + // + // In_1:{} In_2:{} In_3:{} + // \ \ / + // \ \ / + // \ * MULTIPLY:{} + // \ | + // \ | + // \ * RESCALE:{40} + // \ | + // \ | + // -----* ADD:{40} + // | + // | + // Out_1:{40} + + // This function must only be used with forward pass traversal, as it + // expects operand terms to have been processed already. + if (types[term] == Type::Raw || term->op == Op::Encode) { + return; + } + auto &operands = term->getOperands(); + + // Nothing to do for inputs + if (operands.size() > 0) { + // Get the parameters for this term + auto &parms = terms_[term]; + + for (auto &operand : operands) { + // Get the parameters for each operand (forward pass) + auto &operandParms = terms_[operand]; + + // Set the parameters for this term to be the maximum over operands + if (operandParms.size() > parms.size()) { + parms = operandParms; + } + } + + // Adjust the parameters if this term is a rescale operation + // NOTE: This is ignoring modulus switches, but still works because there + // is always a longest path with no modulus switches. + // TODO: Validate this claim and generalize to include modulus switches. + if (isRescaleOp(term->op)) { + auto newSize = parms.size() + 1; + + // By how much are we rescaling? + auto divisor = term->get(); + assert(divisor != 0); + + // Add the required scaling factor to the parameters + parms.push_back(divisor); + assert(parms.size() == newSize); + } + } + } + + inline void free(const Term::Ptr &term) { terms_[term].clear(); } + + auto getEncryptionParameters() { + // This function returns the encryption parameters (really just a list of + // prime bit counts for the coeff_modulus) needed for this computation. It + // can be called after forward pass traversal has computed the rescaling + // primes for all terms. + // + // The logic is simple: we loop over each output term as those have the + // largest (largest number of primes) parameter sets after forward + // traversal, and find the largest parameter set among those. This set will + // work globally for the computation. Since the parameters are not taking + // into account the need for storing the result for the output terms, we + // need to add one or more additional primes to the parameters, depending on + // the scales and the ranges of the terms. For example, if the output term + // has a parameter set {40} after forward traversal, with a scale and range + // of 40 and 16 bits, respectively, the result requires an additional 56-bit + // prime in the parameter set. This prime is always added in the set before + // the rescaling primes, so in this case the function would return {56,40}. + // If the scale and range are very large, this function will add more than + // one extra prime. + + std::vector parms; + + // The size in bits needed for the output value; this includes the scale and + // the range + std::uint32_t maxOutputSize = 0; + + // The bit count of the largest prime appearing in the parameters + std::uint32_t maxParm = 0; + + // The largest (largest number of primes) set of parameters required among + // all output terms + std::uint32_t maxLen = 0; + + // Loop over each output term + for (auto &entry : program_.getOutputs()) { + auto &output = entry.second; + + // The size for this output term equals the range attribute (bits) plus + // the scale (bits) + auto size = output->get(); + size += scales_[output]; + + // Update maxOutputSize + if (size > maxOutputSize) maxOutputSize = size; + + // Get the parameters for the current output term + auto &oParms = terms_[output]; + + // Update maxLen (number of primes) + if (maxLen < oParms.size()) maxLen = oParms.size(); + + // Update maxParm (largest prime) + for (auto &parm : oParms) { + if (parm > maxParm) maxParm = parm; + } + } + + // Ensure that the output size is non-zero + assert(maxOutputSize != 0); + + if (maxOutputSize > 60) { + // If the required output size is larger than 60 bits, we need to increase + // the parameters with more than one additional primes. + + // In this case maxPrime is always 60 bits + maxParm = 60; + + // Add 60-bit primes for as long as needed + while (maxOutputSize >= 60) { + parms.push_back(60); + maxOutputSize -= 60; + } + + // Add one more prime if needed + if (maxOutputSize > 0) { + // TODO: The minimum should probably depend on poly_modulus_degree + parms.push_back(std::max(20u, maxOutputSize)); + } + } else { + // The output size is less than 60 bits so the output parameters require + // only one additional prime. + + // Update maxParm + if (maxOutputSize > maxParm) maxParm = maxOutputSize; + + // Add the required prime to the parameters for this term + parms.push_back(maxParm); + } + + // Finally, loop over all output terms and add the largest parameter set to + // parms after what was pushed above. + for (auto &entry : program_.getOutputs()) { + auto &output = entry.second; + + // Get the parameters for the current output term + auto &oParms = terms_[output]; + + // If this output node has the longest parameter set, use it + if (maxLen == oParms.size()) { + parms.insert(parms.end(), oParms.rbegin(), oParms.rend()); + + // Exit the for loop; we have our parameter set + break; + } + } + + // Add maxParm to result parameters; this is the "key prime". + // TODO: This might be too aggressive. We can try smaller primes here as + // well, which in some cases is advantageous as it may result in smaller + // poly_modulus_degree, even though the noise growth may be a bit larger. + parms.push_back(maxParm); + + return parms; + } + +private: + Program &program_; + TermMapOptional &scales_; + TermMap> terms_; + TermMap &types; + + inline bool isRescaleOp(const Op &op_code) { return op_code == Op::Rescale; } +}; + +} // namespace eva diff --git a/eva/ckks/lazy_relinearizer.h b/eva/ckks/lazy_relinearizer.h new file mode 100644 index 0000000..f05a092 --- /dev/null +++ b/eva/ckks/lazy_relinearizer.h @@ -0,0 +1,98 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include "eva/ir/program.h" +#include "eva/ir/term_map.h" + +namespace eva { + +class LazyRelinearizer { + Program &program; + TermMap &type; + TermMapOptional &scale; + TermMap pending; // maintains whether relinearization is pending + std::uint32_t count; + std::uint32_t countTotal; + + bool isMultiplicationOp(const Op &op_code) { return (op_code == Op::Mul); } + + bool isRotationOp(const Op &op_code) { + return ((op_code == Op::RotateLeftConst) || + (op_code == Op::RotateRightConst)); + } + + bool isUnencryptedType(const Type &type) { return type != Type::Cipher; } + + bool areAllOperandsEncrypted(Term::Ptr &term) { + for (auto &op : term->getOperands()) { + assert(type[op] != Type::Undef); + if (isUnencryptedType(type[op])) { + return false; + } + } + return true; + } + + bool isEncryptedMultOp(Term::Ptr &term) { + return (isMultiplicationOp(term->op) && areAllOperandsEncrypted(term)); + } + +public: + LazyRelinearizer(Program &g, TermMap &type, + TermMapOptional &scale) + : program(g), type(type), scale(scale), pending(g) { + count = 0; + countTotal = 0; + } + + ~LazyRelinearizer() { + // TODO: move these to a logging system + // std::cout << "Number of delayed relin: " << count << "\n"; + // std::cout << "Number of relin: " << countTotal << "\n"; + } + + void + operator()(Term::Ptr &term) { // must only be used with forward pass traversal + auto &operands = term->getOperands(); + if (operands.size() == 0) return; // inputs + + bool delayed = false; + + if (isEncryptedMultOp(term)) { + assert(pending[term] == false); + pending[term] = true; + delayed = true; + } else if (pending[term] == false) { + return; + } + + bool mustInsert = false; + assert(term->numUses() > 0); + auto firstUse = term->getUses()[0]; + for (auto &use : term->getUses()) { + if (isEncryptedMultOp(use) || isRotationOp(use->op) || + use->op == Op::Output || (firstUse != use)) { // different uses + mustInsert = true; + break; + } + } + + if (mustInsert) { + auto relinNode = program.makeTerm(Op::Relinearize, {term}); + ++countTotal; + + type[relinNode] = type[term]; + scale[relinNode] = scale[term]; + term->replaceOtherUsesWith(relinNode); + } else { + if (delayed) ++count; + for (auto &use : term->getUses()) { + pending[use] = true; + } + } + } +}; + +} // namespace eva diff --git a/eva/ckks/lazy_waterline_rescaler.h b/eva/ckks/lazy_waterline_rescaler.h new file mode 100644 index 0000000..5f0a893 --- /dev/null +++ b/eva/ckks/lazy_waterline_rescaler.h @@ -0,0 +1,433 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include "eva/ckks/rescaler.h" +#include "eva/util/logging.h" + +namespace eva { + +class LazyWaterlineRescaler : public Rescaler { + std::uint32_t minScale; + const std::uint32_t fixedRescale = 60; + TermMap pending; // maintains whether rescaling is pending + // TODO: level is no longer used. Should be removed. + TermMap level; // maintains the number of rescalings (levels) + + void insertRescaleRecursive(Term::Ptr &term) { + auto temp = term; + auto termScale = scale[temp]; + std::size_t num = 0; + while (termScale >= (fixedRescale + minScale)) { + temp = insertRescale(temp, fixedRescale); + ++num; + termScale -= fixedRescale; + assert(termScale == scale[temp]); + } + level[temp] = level[term] + num; + } + +public: + LazyWaterlineRescaler(Program &g, TermMap &type, + TermMapOptional &scale) + : Rescaler(g, type, scale), pending(g), level(g) { + // ASSUME: minScale is max among all the inputs' scale + minScale = 0; + for (auto &source : program.getSources()) { + if (scale[source] > minScale) minScale = scale[source]; + } + assert(minScale != 0); + } + + void + operator()(Term::Ptr &term) { // must only be used with forward pass traversal + if (term->numOperands() == 0) return; // inputs + if (type[term] == Type::Raw) { + handleRawScale(term); + return; + } + + auto op = term->op; + + bool delayed = false; + + if (isRescaleOp(op)) { + return; // already processed + } else if (isMultiplicationOp(op)) { + assert(pending[term] == false); + + std::uint32_t multScale = 0; + std::uint32_t maxLevel = 0; + for (auto &operand : term->getOperands()) { + multScale += scale[operand]; + if (level[operand] > maxLevel) maxLevel = level[operand]; + + // The following assertion does not currently hold, as the + // multiplications added for matching addition operand scales can + // sometimes scale up a pending term. + // assert(pending[operand] == false); + } + assert(multScale != 0); + scale[term] = multScale; + level[term] = maxLevel; + + // rescale only if above the waterline + auto temp = term; + if (multScale >= (fixedRescale + minScale)) { + pending[term] = true; + delayed = true; + } else { + return; + } + } else { + // Op::Add, Op::Sub, NEGATE, COPY, Op::RotateLeftConst, + // Op::RotateRightConst copy scale of the first operand + scale[term] = scale[term->operandAt(0)]; + level[term] = level[term->operandAt(0)]; + if (isAdditionOp(op)) { + // Op::Add, Op::Sub + std::uint32_t maxLevel = 0; + for (auto &operand : term->getOperands()) { + if (level[operand] > maxLevel) maxLevel = level[operand]; + } + level[term] = maxLevel; + + auto maxScale = scale[term]; + for (auto &operand : term->getOperands()) { + if (scale[operand] > maxScale) maxScale = scale[operand]; + } + scale[term] = maxScale; + + // ensure that all operands have same scale + for (auto &operand : term->getOperands()) { + if (scale[operand] < maxScale && type[operand] != Type::Raw) { + log(Verbosity::Trace, + "Scaling up t%i from scale %i to match other addition operands " + "at scale %i", + operand->index, scale[operand], maxScale); + + auto scaleConstant = program.makeUniformConstant(1); + scale[scaleConstant] = maxScale - scale[operand]; + scaleConstant->set(scale[scaleConstant]); + + auto mulNode = program.makeTerm(Op::Mul, {operand, scaleConstant}); + scale[mulNode] = maxScale; + + term->replaceOperand(operand, mulNode); + } + } + // assert that all operands have the same scale + for (auto &operand : term->getOperands()) { + assert(maxScale == scale[operand] || type[operand] == Type::Raw); + } + } + + if (pending[term] == false) { + return; + } + } + + assert(pending[term] == true); + + bool mustInsert = false; + assert(term->numUses() > 0); + auto firstUse = term->getUses()[0]; + for (auto &use : term->getUses()) { + if (isMultiplicationOp(use->op) || use->op == Op::Output || + (firstUse != use)) { // different uses + mustInsert = true; + break; + } + } + + if (mustInsert) { + pending[term] = false; + insertRescaleRecursive(term); + } else { + for (auto &use : term->getUses()) { + pending[use] = true; + } + } + } +}; + +// This is a legacy rescaler from before the reduction balancing had support +// for estimating operand levels. +/* +class LazyWaterlineRescalerReductionLogExpander : public Rescaler { + std::uint32_t minScale; + const std::uint32_t fixedRescale = 60; + TermMap pending; // maintains whether rescaling is pending + TermMap level; // maintains the number of rescalings (levels) + std::uint32_t count; + std::uint32_t countTotal; + + bool isPlainType(const Type &type) const { return type == Type::Plain; } + bool isRawType(const Type &type) const { return type == Type::Raw; } + + bool isCipherType(const Type &type) const { return type == Type::Cipher; } + + Term::Ptr expand(Term::Ptr &leftOperand, Term::Ptr &rightOperand, + Term::Ptr &use) { + // create new operand + auto newOperand = program.makeTerm(use->op, {leftOperand, rightOperand}); + if(isCipherType(type[leftOperand]) || +isCipherType(type[rightOperand])){ type[newOperand] = Type::Cipher; + } + else if(isPlainType(type[leftOperand]) || +isPlainType(type[rightOperand])){ type[newOperand] = Type::Plain; + } + else{ + type[newOperand] = Type::Raw; + } + return newOperand; + } + + auto insertRescaleRecursive(Term::Ptr &term) { + auto temp = term; + auto termScale = scale[temp]; + std::size_t num = 0; + while (termScale >= (fixedRescale + minScale)) { + temp = insertRescale(temp, fixedRescale); + ++num; + termScale -= fixedRescale; + assert(termScale == scale[temp]); + } + level[temp] = level[term] + num; + countTotal += num; + return temp; + } + + void expandRecursive(Term::Ptr &term) { + std::vector operands; + std::map> sortedOperands; + for (auto &operand : term->getOperands()) { + auto operandLevel = level[operand]; + if (isCipherType(type[operand])) { + operandLevel += 2; + } else { + if (isPlainType(type[operand])) { + operandLevel = 1; + } + } + sortedOperands[operandLevel].push_back(operand); + } + + // SORTED ORDER: constants, plaintext, then ciphertext ordered by level + for (auto &op : sortedOperands) { + operands.insert(operands.end(), op.second.begin(), op.second.end()); + } + + bool isMultOp = isMultiplicationOp(term->op); + + std::vector nextOperands; + assert(operands.size() >= 2); + while (operands.size() > 2) { + std::size_t i = 0; + // TODO: fix this to include levels + while ((i + 1) < operands.size()) { + auto &leftOperand = operands[i]; + auto &rightOperand = operands[i + 1]; + auto newOperand = expand(leftOperand, rightOperand, term); + if (isMultOp) { + processMultiplicationOp(newOperand); + if (scale[newOperand] >= (fixedRescale + minScale)) { + newOperand = insertRescaleRecursive(newOperand); + } + } else { + processAdditionOp(newOperand); + } + nextOperands.push_back(newOperand); + i += 2; + } + if (i < operands.size()) { + assert((i + 1) == operands.size()); + nextOperands.push_back(operands[i]); + } + operands = nextOperands; + nextOperands.clear(); + } + + // clear term's operands and set it to the current operands + assert(operands.size() == 2); + term->setOperands({operands[0], operands[1]}); + } + + void processMultiplicationOp(Term::Ptr &term) { + if(type[term] == Type::Raw){ + scale[term] = fixedRescale; + return; + } + auto &operands = term->getOperands(); + + std::uint32_t multScale = 0; + std::uint32_t maxLevel = 0; + for (auto &operand : operands) { + multScale += scale[operand]; + if (level[operand] > maxLevel) maxLevel = level[operand]; + assert(pending[operand] == false); + assert(scale[operand] < (fixedRescale + minScale)); + } + assert(multScale != 0); + scale[term] = multScale; + level[term] = maxLevel; + } + + void processAdditionOp(Term::Ptr &term) { + auto &operands = term->getOperands(); + + std::uint32_t maxLevel = 0; + for (auto &operand : operands) { + if (level[operand] > maxLevel) maxLevel = level[operand]; + } + level[term] = maxLevel; + + auto maxScale = 0; + for (auto &operand : operands) { + if (scale[operand] > maxScale) maxScale = scale[operand]; + } + scale[term] = maxScale; + + // assert that pending (delaying rescaling) does not increase level + // TODO: this is insufficient + std::uint32_t maxLevel2 = 0; + for (auto &operand : operands) { + auto operandScale = scale[operand]; + auto operandLevel = level[operand]; + while (operandScale >= (fixedRescale + minScale)) { + operandScale -= fixedRescale; + ++operandLevel; + } + if (operandLevel > maxLevel2) maxLevel2 = operandLevel; + } + { + auto scale2 = maxScale; + auto level2 = maxLevel; + while (scale2 >= (fixedRescale + minScale)) { + scale2 -= fixedRescale; + ++level2; + } + assert(level2 <= maxLevel2); + } + + // ensure that all operands have same scale + for (auto &operand : operands) { + if (scale[operand] < maxScale) { + auto scaleConstant = program.makeUniformConstant(1); + scale[scaleConstant] = maxScale - scale[operand]; + scaleConstant->set(scale[scaleConstant]); + auto mulNode = program.makeTerm(Op::Mul, {operand, scaleConstant}); + scale[mulNode] = maxScale; + + // TODO: Not obviously correct as it's modifying inside + // iteration. Refine API to make this less surprising. + term->replaceOperand(operand, mulNode); + } + } + // assert that all operands have the same scale + for (auto &operand : operands) { + assert(maxScale == scale[operand]); + } + } + +public: + LazyWaterlineRescalerReductionLogExpander( + Program &g, TermMap &type, TermMapOptional &scale) + : Rescaler(g, type, scale), pending(g), level(g) { + // ASSUME: minScale is max among all the inputs' scale + minScale = 0; + for (auto &source : program.getSources()) { + if (scale[source] > minScale) minScale = scale[source]; + } + assert(minScale != 0); + count = 0; + countTotal = 0; + } + + ~LazyWaterlineRescalerReductionLogExpander() { + // TODO: move these to a logging system + // std::cout << "Number of delayed rescales: " << count << "\n"; + // std::cout << "Number of rescales: " << countTotal << "\n"; + } + + // Must only be used with forward pass traversal + void operator()(Term::Ptr &term) { + auto &operands = term->getOperands(); + if (operands.size() == 0) return; // inputs + + auto op = term->op; + + bool delayed = false; + + if (isRescaleOp(op)) { + return; // already processed + } else if (isMultiplicationOp(op)) { + assert(pending[term] == false); + if (operands.size() > 2) { + expandRecursive(term); + } + + processMultiplicationOp(term); + + // rescale only if above the waterline + auto temp = term; + if (scale[term] >= (fixedRescale + minScale)) { + pending[term] = true; + delayed = true; + } else { + return; + } + } else { + // Op::Add, Op::Sub, NEGATE, COPY, Op::RotateLeftConst, + // Op::RotateRightConst + if (isAdditionOp(op)) { // Op::Add, Op::Sub + if ((op == Op::Add) && (operands.size() > 2)) { + expandRecursive(term); + } + processAdditionOp(term); + } else { + // copy scale of the first operand + for (auto &operand : operands) { + scale[term] = scale[operand]; + level[term] = level[operand]; + break; + } + } + + if (pending[term] == false) { + return; + } + } + + // assert(pending[term] == true); + + bool mustInsert = false; + if (term->numUses() > 0) { + auto firstUse = term->getUses()[0]; + for (auto &use : term->getUses()) { + if (isMultiplicationOp(use->op) || + use->op == Op::Output || + (firstUse != use)) { // different uses + mustInsert = true; + break; + } + } + } else { + assert(term->op == Op::Output); + } + + if (mustInsert) { + pending[term] = false; + insertRescaleRecursive(term); + } else { + if (delayed) ++count; + for (auto &use : term->getUses()) { + pending[use] = true; + } + } + } +}; +*/ + +} // namespace eva diff --git a/eva/ckks/levels_checker.h b/eva/ckks/levels_checker.h new file mode 100644 index 0000000..986e380 --- /dev/null +++ b/eva/ckks/levels_checker.h @@ -0,0 +1,71 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include "eva/ir/program.h" +#include "eva/ir/term_map.h" +#include +#include +#include + +namespace eva { + +class LevelsChecker { +public: + LevelsChecker(Program &g, TermMap &types) + : program_(g), types_(types), levels_(g) {} + + void operator()(const Term::Ptr &term) { + // This function verifies that the levels are compatibile. It assumes the + // operand terms are processed already, so it must only be used with forward + // pass traversal. + + if (term->numOperands() == 0) { + // If this is a source node, get the encoding level + levels_[term] = term->get(); + } else { + // For other terms, the operands must all have matching level. First find + // the level of any of the ciphertext operands. + std::size_t operandLevel; + for (auto &operand : term->getOperands()) { + if (types_[operand] == Type::Cipher) { + operandLevel = levels_[operand]; + break; + } + } + + // Next verify that all operands have the same level. + for (auto &operand : term->getOperands()) { + if (types_[operand] == Type::Cipher) { + auto operandLevel2 = levels_[operand]; + assert(operandLevel == operandLevel2); + } + } + + // Incremenet the level for a rescale or modulus switch + std::size_t level = operandLevel; + if (isRescaleOp(term->op) || isModSwitchOp(term->op)) { + ++level; + } + levels_[term] = level; + } + } + + void free(const Term::Ptr &term) { + // No-op + } + +private: + Program &program_; + TermMap &types_; + + // Maintains the reverse level (leaves have 0, roots have max) + TermMap levels_; + + bool isModSwitchOp(const Op &op_code) { return (op_code == Op::ModSwitch); } + + bool isRescaleOp(const Op &op_code) { return (op_code == Op::Rescale); } +}; + +} // namespace eva diff --git a/eva/ckks/minimum_rescaler.h b/eva/ckks/minimum_rescaler.h new file mode 100644 index 0000000..31f59e6 --- /dev/null +++ b/eva/ckks/minimum_rescaler.h @@ -0,0 +1,124 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include "eva/ckks/rescaler.h" +#include "eva/util/logging.h" + +namespace eva { + +class MinimumRescaler : public Rescaler { + std::uint32_t minScale; + const std::uint32_t maxRescale = 60; + +public: + MinimumRescaler(Program &g, TermMap &type, + TermMapOptional &scale) + : Rescaler(g, type, scale) { + // ASSUME: minScale is max among all the inputs' scale + minScale = 0; + for (auto &source : program.getSources()) { + if (scale[source] > minScale) minScale = scale[source]; + } + assert(minScale != 0); + } + + void + operator()(Term::Ptr &term) { // must only be used with forward pass traversal + auto &operands = term->getOperands(); + if (operands.size() == 0) return; // inputs + if (type[term] == Type::Raw) { + handleRawScale(term); + return; + } + + auto op = term->op; + + if (isRescaleOp(op)) return; // already processed + + if (!isMultiplicationOp(op)) { + // Op::Add, Op::Sub, NEGATE, COPY, Op::RotateLeftConst, + // Op::RotateRightConst copy scale of the first operand + for (auto &operand : operands) { + assert(operand->op != Op::Constant); + assert(scale[operand] != 0); + scale[term] = scale[operand]; + break; + } + if (isAdditionOp(op)) { + // Op::Add, Op::Sub + auto maxScale = scale[term]; + for (auto &operand : operands) { + // Here we allow raw operands to possibly raise the scale + if (scale[operand] > maxScale) maxScale = scale[operand]; + } + for (auto &operand : operands) { + if (scale[operand] < maxScale && type[operand] != Type::Raw) { + log(Verbosity::Trace, + "Scaling up t%i from scale %i to match other addition operands " + "at scale %i", + operand->index, scale[operand], maxScale); + + auto scaleConstant = program.makeUniformConstant(1); + scale[scaleConstant] = maxScale - scale[operand]; + scaleConstant->set(scale[scaleConstant]); + + auto mulNode = program.makeTerm(Op::Mul, {operand, scaleConstant}); + scale[mulNode] = maxScale; + + // TODO: Not obviously correct as it's modifying inside + // iteration. + // Refine API to make this less surprising. + term->replaceOperand(operand, mulNode); + } + } + // assert that all operands have the same scale + for (auto &operand : operands) { + assert(maxScale == scale[operand] || type[operand] == Type::Raw); + } + scale[term] = maxScale; + } + return; + } + + // Op::Multiply only + // ASSUME: only two operands + std::vector operandsCopy; + for (auto &operand : operands) { + operandsCopy.push_back(operand); + } + assert(operandsCopy.size() == 2); + std::uint32_t multScale = scale[operandsCopy[0]] + scale[operandsCopy[1]]; + assert(multScale != 0); + scale[term] = multScale; + + auto minOfScales = scale[operandsCopy[0]]; + if (minOfScales > scale[operandsCopy[1]]) + minOfScales = scale[operandsCopy[1]]; + auto rescaleBy = minOfScales - minScale; + if (rescaleBy > maxRescale) rescaleBy = maxRescale; + if ((2 * rescaleBy) >= maxRescale) { + // rescale after multiplication is inevitable + // to reduce the growth of scale, rescale both operands before + // multiplication + assert(rescaleBy <= maxRescale); + insertRescaleBetween(operandsCopy[0], term, rescaleBy); + if (operandsCopy[0] != operandsCopy[1]) { + insertRescaleBetween(operandsCopy[1], term, rescaleBy); + } + + scale[term] = multScale - (2 * rescaleBy); + } else { + // rescale only if above the waterline + auto temp = term; + while (multScale >= (maxRescale + minScale)) { + temp = insertRescale(temp, maxRescale); + multScale -= maxRescale; + assert(multScale == scale[temp]); + } + } + } +}; + +} // namespace eva diff --git a/eva/ckks/mod_switcher.h b/eva/ckks/mod_switcher.h new file mode 100644 index 0000000..33bed5b --- /dev/null +++ b/eva/ckks/mod_switcher.h @@ -0,0 +1,94 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include "eva/ir/program.h" +#include "eva/ir/term_map.h" + +namespace eva { + +class ModSwitcher { + Program &program; + TermMap &type; + TermMapOptional &scale; + TermMap + level; // maintains the reverse level (leaves have 0, roots have max) + std::vector encodeNodes; + + Term::Ptr insertModSwitchNode(Term::Ptr &term, std::uint32_t termLevel) { + auto newNode = program.makeTerm(Op::ModSwitch, {term}); + scale[newNode] = scale[term]; + level[newNode] = termLevel; + return newNode; + } + + bool isRescaleOp(const Op &op_code) { return (op_code == Op::Rescale); } + + bool isCipherType(const Term::Ptr &term) const { + return type[term] == Type::Cipher; + } + +public: + ModSwitcher(Program &g, TermMap &type, + TermMapOptional &scale) + : program(g), type(type), scale(scale), level(g) {} + + ~ModSwitcher() { + auto sources = program.getSources(); + std::uint32_t maxLevel = 0; + for (auto &source : sources) { + if (level[source] > maxLevel) maxLevel = level[source]; + } + for (auto &source : sources) { + auto curLevel = maxLevel - level[source]; + source->set(curLevel); + } + + for (auto &encode : encodeNodes) { + encode->set(maxLevel - level[encode]); + } + } + + void operator()( + Term::Ptr &term) { // must only be used with backward pass traversal + if (term->numUses() == 0) return; + if (term->op == Op::Encode) { + encodeNodes.push_back(term); + } + std::map> useLevels; // ordered map + for (auto &use : term->getUses()) { + useLevels[level[use]].push_back(use); + } + + std::uint32_t termLevel = 0; + if (useLevels.size() > 1) { + auto useLevel = useLevels.rbegin(); // max to min + termLevel = useLevel->first; + ++useLevel; + + auto temp = term; + auto tempLevel = termLevel; + while (useLevel != useLevels.rend()) { + auto expectedLevel = useLevel->first; + while (tempLevel > expectedLevel) { + temp = insertModSwitchNode(temp, tempLevel); + --tempLevel; + } + for (auto &use : useLevel->second) { + use->replaceOperand(term, temp); + } + ++useLevel; + } + } else { + assert(useLevels.size() == 1); + termLevel = useLevels.begin()->first; + } + if (isRescaleOp(term->op)) { + ++termLevel; + } + level[term] = termLevel; + } +}; + +} // namespace eva diff --git a/eva/ckks/parameter_checker.h b/eva/ckks/parameter_checker.h new file mode 100644 index 0000000..52440d8 --- /dev/null +++ b/eva/ckks/parameter_checker.h @@ -0,0 +1,106 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include "eva/ir/program.h" +#include "eva/ir/term_map.h" +#include +#include +#include +#include + +namespace eva { + +class InconsistentParameters : public std::runtime_error { +public: + InconsistentParameters(const std::string &msg) : std::runtime_error(msg) {} +}; + +class ParameterChecker { + TermMap &types; + +public: + ParameterChecker(Program &g, TermMap &types) + : program_(g), parms_(g), types(types) {} + + void operator()(const Term::Ptr &term) { + // Must only be used with forward pass traversal + auto &operands = term->getOperands(); + if (types[term] == Type::Raw || term->op == Op::Encode) { + return; + } + if (operands.size() > 0) { + // Get the parameters for this term + auto &parms = parms_[term]; + // Loop over operands + for (auto &operand : operands) { + // Get the parameters for the operand + auto &operandParms = parms_[operand]; + + // Nothing to do if the operand parameters are empty; the operand sets + // no requirements on this node + if (operandParms.size() > 0) { + if (parms.size() > 0) { + // If the parameters for this term are already set (from a different + // operand), they must match the current operand's parameters + if (operandParms.size() != parms.size()) { + throw InconsistentParameters( + "Two operands require different number of primes"); + } + + // Loop over the primes in the parameters for this term + for (std::size_t i = 0; i < parms.size(); ++i) { + if (parms[i] == 0) { + // If any of the primes is zero (indicating a previous modulus + // switch operand term, fill in its true value from the current + // operand + parms[i] = operandParms[i]; + } else if (operandParms[i] != 0) { + // If the operand prime is non-zero, require equality + if (parms[i] != operandParms[i]) { + throw InconsistentParameters( + "Primes required by two operands do not match"); + } + } + } + } else { + // This is the first operand to impose conditions on this term; + // copy the parameters from the operand + parms = operandParms; + } + } + } + + if (isModSwitchOp(term->op)) { + // Is this a modulus switch? If so, add an extra (placeholder) zero + parms.push_back(0); + } else if (isRescaleOp(term->op)) { + // Is this a rescale? Then add a prime of the requested size + auto divisor = term->get(); + assert(divisor != 0); + parms.push_back(divisor); + } + } else { + // Get the parameters for this term + auto &parms = parms_[term]; + std::uint32_t level = term->get(); + while (level > 0) { + parms.push_back(0); + level--; + } + } + } + + void free(const Term::Ptr &term) { parms_[term].clear(); } + +private: + Program &program_; + TermMap> parms_; + + bool isModSwitchOp(const Op &op_code) { return (op_code == Op::ModSwitch); } + + bool isRescaleOp(const Op &op_code) { return (op_code == Op::Rescale); } +}; + +} // namespace eva diff --git a/eva/ckks/rescaler.h b/eva/ckks/rescaler.h new file mode 100644 index 0000000..e20e5f5 --- /dev/null +++ b/eva/ckks/rescaler.h @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include "eva/ir/program.h" +#include "eva/ir/term_map.h" + +namespace eva { + +class Rescaler { +protected: + Program &program; + TermMap &type; + TermMapOptional &scale; + + Rescaler(Program &g, TermMap &type, + TermMapOptional &scale) + : program(g), type(type), scale(scale) {} + + bool isRescaleOp(const Op &op_code) { return (op_code == Op::Rescale); } + + bool isMultiplicationOp(const Op &op_code) { return (op_code == Op::Mul); } + + bool isAdditionOp(const Op &op_code) { + return ((op_code == Op::Add) || (op_code == Op::Sub)); + } + + auto insertRescale(Term::Ptr term, std::uint32_t rescaleBy) { + // auto scale = term->getScale(); + auto rescaleNode = program.makeRescale(term, rescaleBy); + type[rescaleNode] = type[term]; + scale[rescaleNode] = scale[term] - rescaleBy; + + term->replaceOtherUsesWith(rescaleNode); + + return rescaleNode; + } + + void insertRescaleBetween(Term::Ptr term1, Term::Ptr term2, + std::uint32_t rescaleBy) { + auto rescaleNode = program.makeRescale(term1, rescaleBy); + type[rescaleNode] = type[term1]; + scale[rescaleNode] = scale[term1] - rescaleBy; + + term2->replaceOperand(term1, rescaleNode); + } + + void handleRawScale(Term::Ptr term) { + if (term->numOperands() > 0) { + int maxScale = 0; + for (auto &operand : term->getOperands()) { + if (scale.at(operand) > maxScale) maxScale = scale.at(operand); + } + scale[term] = maxScale; + } + } +}; + +} // namespace eva diff --git a/eva/ckks/scales_checker.h b/eva/ckks/scales_checker.h new file mode 100644 index 0000000..e567aa4 --- /dev/null +++ b/eva/ckks/scales_checker.h @@ -0,0 +1,96 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include "eva/ir/program.h" +#include "eva/ir/term_map.h" +#include +#include +#include + +namespace eva { + +class ScalesChecker { +public: + ScalesChecker(Program &g, TermMapOptional &scales, + TermMap &types) + : program_(g), scales_(g), types_(types) {} + + void operator()(const Term::Ptr &term) { + // Must only be used with forward pass traversal + if (types_[term] == Type::Raw) { + return; + } + auto &operands = term->getOperands(); + + // Nothing to do for source terms + if (term->op == Op::Input || term->op == Op::Encode) { + scales_[term] = term->get(); + if (scales_.at(term) == 0) { + if (term->op == Op::Input) { + throw std::runtime_error("Program has an input with 0 scale"); + } else { + throw std::logic_error("Compiled program results in a 0 scale term"); + } + } + } else if (term->op == Op::Mul) { + assert(term->numOperands() == 2); + std::uint32_t scale = 0; + for (auto &operand : operands) { + scale += scales_.at(operand); + } + if (scale == 0) { + throw std::logic_error("Compiled program results in a 0 scale term"); + } + scales_[term] = scale; + } else if (term->op == Op::Rescale) { + assert(term->numOperands() == 1); + auto divisor = term->get(); + auto operandScale = scales_.at(term->operandAt(0)); + std::uint32_t scale = operandScale - divisor; + if (scale == 0) { + throw std::logic_error("Compiled program results in a 0 scale term"); + } + scales_[term] = scale; + + } else if (isAdditionOp(term->op)) { + std::uint32_t scale = 0; + for (auto &operand : operands) { + if (scale == 0) { + scale = scales_.at(operand); + } else { + if (scale != scales_.at(operand)) { + throw std::logic_error("Addition or subtraction in program has " + "operands of non-equal scale"); + } + } + } + if (scale == 0) { + throw std::logic_error("Compiled program results in a 0 scale term"); + } + scales_[term] = scale; + } else { + auto scale = scales_.at(term->operandAt(0)); + if (scale == 0) { + throw std::logic_error("Compiled program results in a 0 scale term"); + } + scales_[term] = scale; + } + } + + void free(const Term::Ptr &term) { + // No-op + } + +private: + Program &program_; + TermMapOptional scales_; + TermMap &types_; + + bool isAdditionOp(const Op &op_code) { + return ((op_code == Op::Add) || (op_code == Op::Sub)); + } +}; + +} // namespace eva diff --git a/eva/ckks/seal_lowering.h b/eva/ckks/seal_lowering.h new file mode 100644 index 0000000..59ca51c --- /dev/null +++ b/eva/ckks/seal_lowering.h @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include "eva/ir/program.h" +#include "eva/ir/term_map.h" + +namespace eva { + +class SEALLowering { + Program &program; + TermMap &type; + +public: + SEALLowering(Program &g, TermMap &type) : program(g), type(type) {} + + void + operator()(Term::Ptr &term) { // must only be used with forward pass traversal + + // SEAL does not support plaintext subtraction with a plaintext on the left + // hand side, so lower to a negation and addition. + if (term->op == Op::Sub && type[term->operandAt(0)] != Type::Cipher && + type[term->operandAt(1)] == Type::Cipher) { + auto negation = program.makeTerm(Op::Negate, {term->operandAt(1)}); + auto addition = program.makeTerm(Op::Add, {term->operandAt(0), negation}); + term->replaceAllUsesWith(addition); + } + } +}; + +} // namespace eva diff --git a/eva/common/CMakeLists.txt b/eva/common/CMakeLists.txt new file mode 100644 index 0000000..07e8197 --- /dev/null +++ b/eva/common/CMakeLists.txt @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +target_sources(eva PRIVATE + reference_executor.cpp +) diff --git a/eva/common/constant_folder.h b/eva/common/constant_folder.h new file mode 100644 index 0000000..a5c2c62 --- /dev/null +++ b/eva/common/constant_folder.h @@ -0,0 +1,192 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include "eva/ir/program.h" +#include "eva/ir/term_map.h" + +namespace eva { + +class ConstantFolder { + Program &program; + TermMapOptional &scale; + std::vector scratch1, scratch2; + + bool isRescaleOp(const Op &op_code) { return (op_code == Op::Rescale); } + + bool isMultiplicationOp(const Op &op_code) { return (op_code == Op::Mul); } + + bool isAdditionOp(const Op &op_code) { + return ((op_code == Op::Add) || (op_code == Op::Sub)); + } + + void replaceNodeWithConstant(Term::Ptr term, + const std::vector &output, + double termScale) { + // TODO: optimize output representations + auto constant = program.makeDenseConstant(output); + scale[constant] = termScale; + constant->set(scale[constant]); + + term->replaceAllUsesWith(constant); + assert(term->numUses() == 0); + } + + void add(Term::Ptr output, const Term::Ptr &args1, const Term::Ptr &args2) { + auto &input1 = args1->get()->expand( + scratch1, program.getVecSize()); + auto &input2 = args2->get()->expand( + scratch2, program.getVecSize()); + + std::vector outputValue(input1.size()); + for (std::uint64_t i = 0; i < outputValue.size(); ++i) { + outputValue[i] = input1[i] + input2[i]; + } + + replaceNodeWithConstant(output, outputValue, + std::max(scale[args1], scale[args2])); + } + + void sub(Term::Ptr output, const Term::Ptr &args1, const Term::Ptr &args2) { + auto &input1 = args1->get()->expand( + scratch1, program.getVecSize()); + auto &input2 = args2->get()->expand( + scratch2, program.getVecSize()); + + std::vector outputValue(input1.size()); + for (std::uint64_t i = 0; i < outputValue.size(); ++i) { + outputValue[i] = input1[i] - input2[i]; + } + + replaceNodeWithConstant(output, outputValue, + std::max(scale[args1], scale[args2])); + } + + void mul(Term::Ptr output, const Term::Ptr &args1, const Term::Ptr &args2) { + auto &input1 = args1->get()->expand( + scratch1, program.getVecSize()); + auto &input2 = args2->get()->expand( + scratch2, program.getVecSize()); + + std::vector outputValue(input1.size()); + for (std::uint64_t i = 0; i < outputValue.size(); ++i) { + outputValue[i] = input1[i] * input2[i]; + } + + replaceNodeWithConstant(output, outputValue, + std::max(scale[args1], scale[args2])); + } + + void leftRotate(Term::Ptr output, const Term::Ptr &args1, + std::int32_t shift) { + auto &input1 = args1->get()->expand( + scratch1, program.getVecSize()); + + while (shift > 0 && shift >= input1.size()) + shift -= input1.size(); + while (shift < 0) + shift += input1.size(); + + std::vector outputValue(input1.size()); + for (std::uint64_t i = 0; i < (outputValue.size() - shift); ++i) { + outputValue[i] = input1[i + shift]; + } + for (std::uint64_t i = 0; i < shift; ++i) { + outputValue[outputValue.size() - shift + i] = input1[i]; + } + + replaceNodeWithConstant(output, outputValue, scale[args1]); + } + + void rightRotate(Term::Ptr output, const Term::Ptr &args1, + std::int32_t shift) { + auto &input1 = args1->get()->expand( + scratch1, program.getVecSize()); + + while (shift > 0 && shift >= input1.size()) + shift -= input1.size(); + while (shift < 0) + shift += input1.size(); + + std::vector outputValue(input1.size()); + for (std::uint64_t i = 0; i < (outputValue.size() - shift); ++i) { + outputValue[i + shift] = input1[i]; + } + for (std::uint64_t i = 0; i < shift; ++i) { + outputValue[i] = input1[outputValue.size() - shift + i]; + } + + replaceNodeWithConstant(output, outputValue, scale[args1]); + } + + void negate(Term::Ptr output, const Term::Ptr &args1) { + auto &input1 = args1->get()->expand( + scratch1, program.getVecSize()); + + std::vector outputValue(input1.size()); + for (std::uint64_t i = 0; i < outputValue.size(); ++i) { + outputValue[i] = -input1[i]; + } + + replaceNodeWithConstant(output, outputValue, scale[args1]); + } + +public: + ConstantFolder(Program &g, TermMapOptional &scale) + : program(g), scale(scale) {} + + void + operator()(Term::Ptr &term) { // must only be used with forward pass traversal + auto &args = term->getOperands(); + if (args.size() == 0) return; // inputs + + for (auto &arg : args) { + if (arg->op != Op::Constant) return; + } + + auto op_code = term->op; + switch (op_code) { + case Op::Add: + assert(args.size() == 2); + add(term, args[0], args[1]); + break; + case Op::Sub: + assert(args.size() == 2); + sub(term, args[0], args[1]); + break; + case Op::Mul: + assert(args.size() == 2); + mul(term, args[0], args[1]); + break; + case Op::RotateLeftConst: + assert(args.size() == 1); + leftRotate(term, args[0], term->get()); + break; + case Op::RotateRightConst: + assert(args.size() == 1); + rightRotate(term, args[0], term->get()); + break; + case Op::Negate: + assert(args.size() == 1); + negate(term, args[0]); + break; + case Op::Output: + [[fallthrough]]; + case Op::Encode: + break; + case Op::Relinearize: + [[fallthrough]]; + case Op::ModSwitch: + [[fallthrough]]; + case Op::Rescale: + throw std::logic_error("Encountered HE specific operation " + + getOpName(op_code) + + " in unencrypted computation"); + default: + throw std::logic_error("Unhandled op " + getOpName(op_code)); + } + } +}; + +} // namespace eva diff --git a/eva/common/multicore_program_traversal.h b/eva/common/multicore_program_traversal.h new file mode 100644 index 0000000..90666d5 --- /dev/null +++ b/eva/common/multicore_program_traversal.h @@ -0,0 +1,151 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include "eva/ir/program.h" +#include "eva/ir/term_map.h" +#include "eva/util/galois.h" +#include +#include +#include +#include +#include +#include +#include +#include + +namespace eva { + +class MulticoreProgramTraversal { +public: + MulticoreProgramTraversal(Program &g) : program_(g) {} + + template void forwardPass(Evaluator &eval) { + TermMap predecessors(program_); + TermMap successors(program_); + + // Add the source terms + galois::InsertBag readyNodes; + for (auto source : program_.getSources()) { + readyNodes.push_back(source); + } + + // Enumerate predecessors and successors + galois::for_each( + galois::iterate(readyNodes), + [&](const Term::Ptr &term, auto &ctx) { + // For each term, iterate over its uses + for (auto &use : term->getUses()) { + // Increment the number of successors + ++successors[term]; + + // Increment the number of predecessors + if ((++predecessors[use]) == 1) { + // Only first predecessor will push so each use is added once + ctx.push_back(use); + } + } + }, + galois::wl>(), + galois::no_stats(), + galois::loopname("ForwardCountPredecessorsSuccessors")); + + // Traverse the program + galois::for_each( + galois::iterate(readyNodes), + [&](const Term::Ptr &term, auto &ctx) { + // Process the current term + eval(term); + + // Free operands if their successors are done + for (auto &operand : term->getOperands()) { + if ((--successors[operand]) == 0) { + // Only last successor will free + eval.free(operand); + } + } + + // Execute (ready) uses if their predecessors are done + for (auto &use : term->getUses()) { + if ((--predecessors[use]) == 0) { + // Only last predecessor will push + ctx.push_back(use); + } + } + }, + galois::wl>(), + galois::no_stats(), galois::loopname("ForwardTraversal")); + + // TODO: Reinstate these checks + // for (auto& predecessor : predecessors) assert(predecessor == 0); + // for (auto& successor : successors) assert(successor == 0); + } + + template void backwardPass(Evaluator &eval) { + TermMap predecessors(program_); + TermMap successors(program_); + + // Add the sink terms + galois::InsertBag readyNodes; + for (auto &sink : program_.getSinks()) { + readyNodes.push_back(sink); + } + + // Enumerate predecessors and successors + galois::for_each( + galois::iterate(readyNodes), + [&](const Term::Ptr &term, auto &ctx) { + // For each term, iterate over its operands + for (auto &operand : term->getOperands()) { + // Increment the number of predecessors + ++predecessors[term]; + + // Increment the number of successors for the operand + if ((++successors[operand]) == 1) { + // Only first successor will push so each operand is added once + ctx.push_back(operand); + } + } + }, + galois::wl>(), + galois::no_stats(), + galois::loopname("BackwardCountPredecessorsSuccessors")); + + // Traverse the program + galois::for_each( + galois::iterate(readyNodes), + [&](const Term::Ptr &term, auto &ctx) { + // Process the current term + eval(term); + + // Free uses if their predecessors are done + for (auto &use : term->getUses()) { + if ((--predecessors[use]) == 0) { + // Only last predecessor will free + eval.free(use); + } + } + + // Execute (ready) operands if their successors are done + for (auto &operand : term->getOperands()) { + if ((--successors[operand]) == 0) { + // Only last successor will push + ctx.push_back(operand); + } + } + }, + galois::wl>(), + galois::no_stats(), galois::loopname("BackwardTraversal")); + + // TODO: Reinstate these checks + // for (auto& predecessor : predecessors) assert(predecessor == 0); + // for (auto& successor : successors) assert(successor == 0); + } + +private: + Program &program_; + GaloisGuard galoisGuard_; +}; + +} // namespace eva diff --git a/eva/common/program_traversal.h b/eva/common/program_traversal.h new file mode 100644 index 0000000..4cd7812 --- /dev/null +++ b/eva/common/program_traversal.h @@ -0,0 +1,102 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include "eva/ir/program.h" +#include "eva/ir/term_map.h" +#include "eva/util/logging.h" +#include +#include + +namespace eva { + +/* +Implements efficient forward and backward traversals of Program in the +presence of modifications during traversal. +The rewriter is called for each term in the Program exactly once. +Rewriters must not modify the Program in such a way that terms that are +not uses/operands (for forward/backward traversal, respectively) of the +current term are enabled. With such modifications the whole program is +not guaranteed to be traversed. +*/ +class ProgramTraversal { + Program &program; + + TermMap ready; + TermMap processed; + + template bool arePredecessorsDone(const Term::Ptr &term) { + for (auto &operand : isForward ? term->getOperands() : term->getUses()) { + if (!processed[operand]) return false; + } + return true; + } + + template + void traverse(Rewriter &&rewrite) { + processed.clear(); + ready.clear(); + + std::vector readyNodes = + isForward ? program.getSources() : program.getSinks(); + for (auto &term : readyNodes) { + ready[term] = true; + } + // Used for remembering uses/operands before rewrite is called. Using a + // vector here is fine because duplicates in the list are handled + // gracefully. + std::vector checkList; + + while (readyNodes.size() != 0) { + // Pop term to transform + auto term = readyNodes.back(); + readyNodes.pop_back(); + + // If this term is removed, we will lose uses/operands of this term. + // Remember them here for checking readyness after the rewrite. + checkList.clear(); + for (auto &succ : isForward ? term->getUses() : term->getOperands()) { + checkList.push_back(succ); + } + + log(Verbosity::Trace, "Processing term with index=%lu", term->index); + rewrite(term); + processed[term] = true; + + // If transform adds new sources/sinks add them to ready terms. + for (auto &leaf : isForward ? program.getSources() : program.getSinks()) { + if (!ready[leaf]) { + readyNodes.push_back(leaf); + ready[leaf] = true; + } + } + + // Also check current uses/operands in case any new ones were added. + for (auto &succ : isForward ? term->getUses() : term->getOperands()) { + checkList.push_back(succ); + } + + // Push and mark uses/operands that are ready to be processed. + for (auto &succ : checkList) { + if (!ready[succ] && arePredecessorsDone(succ)) { + readyNodes.push_back(succ); + ready[succ] = true; + } + } + } + } + +public: + ProgramTraversal(Program &g) : program(g), processed(g), ready(g) {} + + template void forwardPass(Rewriter &&rewrite) { + traverse(std::forward(rewrite)); + } + + template void backwardPass(Rewriter &&rewrite) { + traverse(std::forward(rewrite)); + } +}; + +} // namespace eva diff --git a/eva/common/reduction_balancer.h b/eva/common/reduction_balancer.h new file mode 100644 index 0000000..cc7a6b6 --- /dev/null +++ b/eva/common/reduction_balancer.h @@ -0,0 +1,148 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include "eva/ir/program.h" +#include "eva/ir/term_map.h" +#include + +namespace eva { + +/* +This pass combines nodes to reduce the depth of the tree. +Suppose you have expression tree + + * + / \ + * t5(c) => * + / \ / \ \ +t1(c) t2(c) t1 t2 t5 + + +Before combining first it checks if some node have only one use and both +of these nodes have same op then these two nodes are combined into one node +with children of both the nodes. +This pass helps to get the flat form of an expression so that later on it can +be expanded to get a expression in a balanced form. +For example (a * (b * (c * d))) => (a * b * c * d) => (a * b) * (c * d) +*/ +class ReductionCombiner { + Program &program; + + bool isReductionOp(const Op &op_code) { + return ((op_code == Op::Add) || (op_code == Op::Mul)); + } + +public: + ReductionCombiner(Program &g) : program(g) {} + + void + operator()(Term::Ptr &term) { // must only be used with forward pass traversal + if (!term->isInternal() || !isReductionOp(term->op)) return; + + auto uses = term->getUses(); + if (uses.size() == 1) { + auto &use = uses[0]; + if (use->op == term->op) { + // combine term and its use + while (use->eraseOperand(term)) { + for (auto &operand : term->getOperands()) { + // add term's operands to use's operands + use->addOperand(operand); + } + } + } + } + } +}; + +class ReductionLogExpander { + Program &program; + TermMap &type; + TermMapOptional scale; + std::vector operands, nextOperands; + std::map> sortedOperands; + + bool isReductionOp(const Op &op_code) { + return ((op_code == Op::Add) || (op_code == Op::Mul)); + } + +public: + ReductionLogExpander(Program &g, TermMap &type) + : program(g), type(type), scale(g) {} + + void operator()(Term::Ptr &term) { + if (term->op == Op::Rescale || term->op == Op::ModSwitch) { + throw std::logic_error("Rescale or ModSwitch encountered, but " + "ReductionLogExpander uses scale as" + " a proxy for level and assumes rescaling has not " + "been performed yet."); + } + + // Calculate the scales that we would get without any rescaling. Terms at a + // similar scale will likely end up having the same level in typical + // rescaling policies, which helps the sorting group terms of the same level + // together. + if (term->numOperands() == 0) { + scale[term] = term->get(); + } else if (term->op == Op::Mul) { + scale[term] = std::accumulate( + term->getOperands().begin(), term->getOperands().end(), 0, + [&](auto &sum, auto &operand) { return sum + scale.at(operand); }); + } else { + scale[term] = std::accumulate(term->getOperands().begin(), + term->getOperands().end(), 0, + [&](auto &max, auto &operand) { + return std::max(max, scale.at(operand)); + }); + } + + if (isReductionOp(term->op) && term->numOperands() > 2) { + // We sort operands into constants, plaintext and raw, then ciphertexts by + // scale. This helps avoid unnecessary accumulation of scale. + for (auto &operand : term->getOperands()) { + auto order = 0; + if (type[operand] == Type::Plain || type[operand] == Type::Raw) { + order = 1; + } else if (type[operand] == Type::Cipher) { + order = 2 + scale.at(operand); + } + sortedOperands[order].push_back(operand); + } + for (auto &op : sortedOperands) { + operands.insert(operands.end(), op.second.begin(), op.second.end()); + } + + // Expand the sorted operands into a balanced reduction tree by pairing + // adjacent operands until only one remains. + assert(operands.size() >= 2); + while (operands.size() > 2) { + std::size_t i = 0; + while ((i + 1) < operands.size()) { + auto &leftOperand = operands[i]; + auto &rightOperand = operands[i + 1]; + auto newTerm = + program.makeTerm(term->op, {leftOperand, rightOperand}); + nextOperands.push_back(newTerm); + i += 2; + } + if (i < operands.size()) { + assert((i + 1) == operands.size()); + nextOperands.push_back(operands[i]); + } + operands = nextOperands; + nextOperands.clear(); + } + + assert(operands.size() == 2); + term->setOperands(operands); + + operands.clear(); + nextOperands.clear(); + sortedOperands.clear(); + } + } +}; + +} // namespace eva diff --git a/eva/common/reference_executor.cpp b/eva/common/reference_executor.cpp new file mode 100644 index 0000000..4f7567d --- /dev/null +++ b/eva/common/reference_executor.cpp @@ -0,0 +1,117 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "eva/common/reference_executor.h" +#include +#include +#include +#include + +using namespace std; + +namespace eva { + +void ReferenceExecutor::leftRotate(vector &output, + const Term::Ptr &args, int32_t shift) { + auto &input = terms_.at(args); + + // Reserve enough space for output + output.clear(); + output.reserve(input.size()); + + while (shift > 0 && shift >= input.size()) + shift -= input.size(); + while (shift < 0) + shift += input.size(); + + // Shift left and copy to output + copy_n(input.cbegin() + shift, input.size() - shift, back_inserter(output)); + copy_n(input.cbegin(), shift, back_inserter(output)); +} + +void ReferenceExecutor::rightRotate(vector &output, + const Term::Ptr &args, int32_t shift) { + auto &input = terms_.at(args); + + // Reserve enough space for output + output.clear(); + output.reserve(input.size()); + + while (shift > 0 && shift >= input.size()) + shift -= input.size(); + while (shift < 0) + shift += input.size(); + + // Shift right and copy to output + copy_n(input.cend() - shift, shift, back_inserter(output)); + copy_n(input.cbegin(), input.size() - shift, back_inserter(output)); +} + +void ReferenceExecutor::negate(vector &output, const Term::Ptr &args) { + auto &input = terms_.at(args); + + // Reserve enough space for output + output.clear(); + output.reserve(input.size()); + transform(input.cbegin(), input.cend(), back_inserter(output), + std::negate()); +} + +void ReferenceExecutor::operator()(const Term::Ptr &term) { + // Must only be used with forward pass traversal + auto &output = terms_[term]; + + auto op = term->op; + auto args = term->getOperands(); + + switch (op) { + case Op::Input: + // Nothing to do for inputs + break; + case Op::Constant: + // A constant (vector) is expanded to the number of slots (vecSize_ here) + term->get()->expandTo(output, vecSize_); + break; + case Op::Add: + assert(args.size() == 2); + binOp>(output, args[0], args[1]); + break; + case Op::Sub: + assert(args.size() == 2); + binOp>(output, args[0], args[1]); + break; + case Op::Mul: + assert(args.size() == 2); + binOp>(output, args[0], args[1]); + break; + case Op::RotateLeftConst: + assert(args.size() == 1); + leftRotate(output, args[0], term->get()); + break; + case Op::RotateRightConst: + assert(args.size() == 1); + rightRotate(output, args[0], term->get()); + break; + case Op::Negate: + assert(args.size() == 1); + negate(output, args[0]); + break; + case Op::Encode: + [[fallthrough]]; + case Op::Output: + [[fallthrough]]; + case Op::Relinearize: + [[fallthrough]]; + case Op::ModSwitch: + [[fallthrough]]; + case Op::Rescale: + // Copy argument value for outputs + assert(args.size() == 1); + output = terms_[args[0]]; + break; + default: + assert(false); + } +} + +} // namespace eva diff --git a/eva/common/reference_executor.h b/eva/common/reference_executor.h new file mode 100644 index 0000000..072b5c0 --- /dev/null +++ b/eva/common/reference_executor.h @@ -0,0 +1,80 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include "eva/common/valuation.h" +#include "eva/ir/program.h" +#include "eva/ir/term_map.h" +#include +#include +#include +#include + +namespace eva { + +// Executes unencrypted computation +class ReferenceExecutor { +public: + ReferenceExecutor(Program &g) + : program_(g), vecSize_(g.getVecSize()), terms_(g) {} + + ReferenceExecutor(const ReferenceExecutor ©) = delete; + + ReferenceExecutor &operator=(const ReferenceExecutor &assign) = delete; + + template + void setInputs(const std::unordered_map &inputs) { + for (auto &in : inputs) { + auto term = program_.getInput(in.first); + terms_[term] = in.second; // TODO: can we avoid this copy? + if (terms_[term].size() != vecSize_) { + throw std::runtime_error( + "The length of all inputs must be the same as program's vector " + "size. Input " + + in.first + " has length " + std::to_string(terms_[term].size()) + + ", but vector size is " + std::to_string(vecSize_)); + } + } + } + + void operator()(const Term::Ptr &term); + + void free(const Term::Ptr &term) { + if (term->op == Op::Output) return; + terms_[term].clear(); + } + + void getOutputs(Valuation &outputs) { + for (auto &out : program_.getOutputs()) { + outputs[out.first] = terms_[out.second]; + } + } + +private: + Program &program_; + std::uint64_t vecSize_; + TermMapOptional> terms_; + + template + void binOp(std::vector &out, const Term::Ptr &args1, + const Term::Ptr &args2) { + auto &in1 = terms_.at(args1); + auto &in2 = terms_.at(args2); + assert(in1.size() == in2.size()); + + out.clear(); + out.reserve(in1.size()); + transform(in1.cbegin(), in1.cend(), in2.cbegin(), back_inserter(out), Op()); + } + + void leftRotate(std::vector &output, const Term::Ptr &args, + std::int32_t shift); + + void rightRotate(std::vector &output, const Term::Ptr &args, + std::int32_t shift); + + void negate(std::vector &output, const Term::Ptr &args); +}; + +} // namespace eva diff --git a/eva/common/rotation_keys_selector.h b/eva/common/rotation_keys_selector.h new file mode 100644 index 0000000..5ca405d --- /dev/null +++ b/eva/common/rotation_keys_selector.h @@ -0,0 +1,57 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include "eva/ir/program.h" +#include "eva/ir/term_map.h" +#include +#include +#include +#include + +namespace eva { + +class RotationKeysSelector { +public: + RotationKeysSelector(Program &g, const TermMap &type) + : program_(g), type(type) {} + + void operator()(const Term::Ptr &term) { + auto op = term->op; + + // Nothing to do if this is not a rotation + if (!isLeftRotationOp(op) && !isRightRotationOp(op)) return; + + // No rotation keys needed for raw computation + if (type[term] == Type::Raw) return; + + // Add the rotation count + auto rotation = term->get(); + keys_.insert(isRightRotationOp(op) ? -rotation : rotation); + } + + void free(const Term::Ptr &term) { + // No-op + } + + auto getRotationKeys() { + // Return the set of rotations needed + return keys_; + } + +private: + Program &program_; + const TermMap &type; + std::set keys_; + + bool isLeftRotationOp(const Op &op_code) { + return (op_code == Op::RotateLeftConst); + } + + bool isRightRotationOp(const Op &op_code) { + return (op_code == Op::RotateRightConst); + } +}; + +} // namespace eva diff --git a/eva/common/type_deducer.h b/eva/common/type_deducer.h new file mode 100644 index 0000000..1bdc748 --- /dev/null +++ b/eva/common/type_deducer.h @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include "eva/ir/program.h" +#include "eva/ir/term_map.h" + +namespace eva { + +class TypeDeducer { + Program &program; + TermMap &types; + +public: + TypeDeducer(Program &g, TermMap &types) : program(g), types(types) {} + + void + operator()(Term::Ptr &term) { // must only be used with forward pass traversal + auto &operands = term->getOperands(); + if (operands.size() > 0) { // not an input/root + Type inferred = Type::Raw; // Plain if not Cipher + for (auto &operand : operands) { + if (types[operand] == Type::Cipher) + inferred = Type::Cipher; // Cipher if any operand is Cipher + } + if (term->op == Op::Encode) { + types[term] = Type::Plain; + } else { + types[term] = inferred; + } + } else if (term->op == Op::Constant) { + types[term] = Type::Raw; + } else { + types[term] = term->get(); + } + } +}; + +} // namespace eva diff --git a/eva/common/valuation.h b/eva/common/valuation.h new file mode 100644 index 0000000..374459d --- /dev/null +++ b/eva/common/valuation.h @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +namespace eva { + +using Valuation = std::unordered_map>; + +} diff --git a/eva/eva.cpp b/eva/eva.cpp new file mode 100644 index 0000000..5f75397 --- /dev/null +++ b/eva/eva.cpp @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "eva/eva.h" +#include "eva/common/program_traversal.h" +#include "eva/common/reference_executor.h" +#include "eva/common/valuation.h" + +namespace eva { + +Valuation evaluate(Program &program, const Valuation &inputs) { + Valuation outputs; + ProgramTraversal programTraverse(program); + ReferenceExecutor ge(program); + + ge.setInputs(inputs); + programTraverse.forwardPass(ge); + ge.getOutputs(outputs); + + return outputs; +} + +} // namespace eva diff --git a/eva/eva.h b/eva/eva.h new file mode 100644 index 0000000..9031690 --- /dev/null +++ b/eva/eva.h @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include "eva/ckks/ckks_compiler.h" +#include "eva/ir/program.h" +#include "eva/seal/seal.h" +#include "eva/serialization/save_load.h" +#include "eva/version.h" + +namespace eva { + +Valuation evaluate(Program &program, const Valuation &inputs); + +} diff --git a/eva/ir/CMakeLists.txt b/eva/ir/CMakeLists.txt new file mode 100644 index 0000000..c54617b --- /dev/null +++ b/eva/ir/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +target_sources(eva PRIVATE + term.cpp + program.cpp + attribute_list.cpp + attributes.cpp +) diff --git a/eva/ir/attribute_list.cpp b/eva/ir/attribute_list.cpp new file mode 100644 index 0000000..ef65ce7 --- /dev/null +++ b/eva/ir/attribute_list.cpp @@ -0,0 +1,106 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "eva/ir/attribute_list.h" +#include "eva/ir/attributes.h" +#include "eva/util/overloaded.h" +#include + +using namespace std; + +namespace eva { + +bool AttributeList::has(AttributeKey k) const { + const AttributeList *curr = this; + while (true) { + if (curr->key < k) { + if (curr->tail) { + curr = curr->tail.get(); + } else { + return false; + } + } else { + return curr->key == k; + } + } +} + +const AttributeValue &AttributeList::get(AttributeKey k) const { + const AttributeList *curr = this; + while (true) { + if (curr->key == k) { + return curr->value; + } else if (curr->key < k && curr->tail) { + curr = curr->tail.get(); + } else { + throw out_of_range("Attribute not in list: " + getAttributeName(k)); + } + } +} + +void AttributeList::set(AttributeKey k, AttributeValue v) { + if (this->key == 0) { + this->key = k; + this->value = move(v); + } else { + AttributeList *curr = this; + AttributeList *prev = nullptr; + while (true) { + if (curr->key < k) { + if (curr->tail) { + prev = curr; + curr = curr->tail.get(); + } else { // Insert at end + // AttributeList constructor is private + curr->tail = unique_ptr{new AttributeList(k, move(v))}; + return; + } + } else if (curr->key > k) { + if (prev) { // Insert between + // AttributeList constructor is private + auto newList = + unique_ptr{new AttributeList(k, move(v))}; + newList->tail = move(prev->tail); + prev->tail = move(newList); + } else { // Insert at beginning + // AttributeList constructor is private + curr->tail = + unique_ptr{new AttributeList(move(*curr))}; + curr->key = k; + curr->value = move(v); + } + return; + } else { + assert(curr->key == k); + curr->value = move(v); + return; + } + } + } +} + +void AttributeList::assignAttributesFrom(const AttributeList &other) { + if (this->key != 0) { + this->key = 0; + this->value = std::monostate(); + this->tail = nullptr; + } + if (other.key == 0) { + return; + } + AttributeList *lhs = this; + const AttributeList *rhs = &other; + while (true) { + lhs->key = rhs->key; + lhs->value = rhs->value; + if (rhs->tail) { + rhs = rhs->tail.get(); + lhs->tail = std::make_unique(); + lhs = lhs->tail.get(); + } else { + return; + } + } +} + +} // namespace eva diff --git a/eva/ir/attribute_list.h b/eva/ir/attribute_list.h new file mode 100644 index 0000000..2caf537 --- /dev/null +++ b/eva/ir/attribute_list.h @@ -0,0 +1,79 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include "eva/ir/constant_value.h" +#include "eva/ir/types.h" +#include "eva/serialization/eva.pb.h" +#include +#include +#include +#include +#include +#include +#include + +namespace eva { + +using AttributeValue = std::variant>; + +template struct IsInVariant; +template +struct IsInVariant> + : std::bool_constant<(... || std::is_same{})> {}; + +using AttributeKey = std::uint8_t; + +template struct Attribute { + static_assert(IsInVariant::value, + "Attribute type not in AttributeValue"); + static_assert(Key > 0, "Keys must be strictly positive"); + static_assert(Key <= std::numeric_limits::max(), + "Key larger than current AttributeKey type"); + + using Value = T; + static constexpr AttributeKey key = Key; + + static bool isValid(AttributeKey k, const AttributeValue &v) { + return k == Key && std::holds_alternative(v); + } +}; + +class AttributeList { +public: + AttributeList() : key(0), tail(nullptr) {} + + // This function is defined in eva/serialization/eva_serialization.cpp + void loadAttribute(const msg::Attribute &msg); + + // This function is defined in eva/serialization/eva_serialization.cpp + void serializeAttributes(std::function addMsg) const; + + template bool has() const { return has(TAttr::key); } + + template const typename TAttr::Value &get() const { + return std::get(get(TAttr::key)); + } + + template void set(typename TAttr::Value value) { + set(TAttr::key, std::move(value)); + } + + void assignAttributesFrom(const AttributeList &other); + +private: + AttributeKey key; + AttributeValue value; + std::unique_ptr tail; + + AttributeList(AttributeKey k, AttributeValue v) + : key(k), value(std::move(v)) {} + + bool has(AttributeKey k) const; + const AttributeValue &get(AttributeKey k) const; + void set(AttributeKey k, AttributeValue v); +}; + +} // namespace eva diff --git a/eva/ir/attributes.cpp b/eva/ir/attributes.cpp new file mode 100644 index 0000000..6e0476a --- /dev/null +++ b/eva/ir/attributes.cpp @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "eva/ir/attributes.h" +#include + +using namespace std; + +namespace eva { + +#define X(name, type) name::isValid(k, v) || +bool isValidAttribute(AttributeKey k, const AttributeValue &v) { + return EVA_ATTRIBUTES false; +} +#undef X + +#define X(name, type) \ + case detail::name##Index: \ + return #name; +string getAttributeName(AttributeKey k) { + switch (k) { + EVA_ATTRIBUTES + default: + throw runtime_error("Unknown attribute key"); + } +} +#undef X + +} // namespace eva diff --git a/eva/ir/attributes.h b/eva/ir/attributes.h new file mode 100644 index 0000000..f5eb604 --- /dev/null +++ b/eva/ir/attributes.h @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include "eva/ir/attribute_list.h" +#include +#include + +namespace eva { + +#define EVA_ATTRIBUTES \ + X(RescaleDivisorAttribute, std::uint32_t) \ + X(RotationAttribute, std::int32_t) \ + X(ConstantValueAttribute, std::shared_ptr) \ + X(TypeAttribute, Type) \ + X(RangeAttribute, std::uint32_t) \ + X(EncodeAtScaleAttribute, std::uint32_t) \ + X(EncodeAtLevelAttribute, std::uint32_t) + +namespace detail { +enum AttributeIndex { + RESERVE_EMPTY_ATTRIBUTE_KEY = 0, +#define X(name, type) name##Index, + EVA_ATTRIBUTES +#undef X +}; +} // namespace detail + +#define X(name, type) using name = Attribute; +EVA_ATTRIBUTES +#undef X + +bool isValidAttribute(AttributeKey k, const AttributeValue &v); + +std::string getAttributeName(AttributeKey k); + +} // namespace eva diff --git a/eva/ir/constant_value.h b/eva/ir/constant_value.h new file mode 100644 index 0000000..46c77c4 --- /dev/null +++ b/eva/ir/constant_value.h @@ -0,0 +1,146 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include "eva/serialization/eva.pb.h" +#include +#include +#include +#include +#include +#include + +namespace eva { + +class ConstantValue { +public: + ConstantValue(std::size_t size) : size(size) {} + virtual ~ConstantValue() {} + virtual const std::vector &expand(std::vector &scratch, + std::size_t slots) const = 0; + virtual void expandTo(std::vector &result, + std::size_t slots) const = 0; + virtual bool isZero() const = 0; + virtual void serialize(msg::ConstantValue &msg) const = 0; + +protected: + std::size_t size; + + void validateSlots(std::size_t slots) const { + if (slots < size) { + throw std::runtime_error("Slots must be at least size of constant"); + } + if (slots % size != 0) { + throw std::runtime_error("Size must exactly divide slots"); + } + } +}; + +class DenseConstantValue : public ConstantValue { +public: + DenseConstantValue(std::size_t size, std::vector values) + : ConstantValue(size), values(values) { + if (size % values.size() != 0) { + throw std::runtime_error( + "DenseConstantValue size must exactly divide size"); + } + } + + const std::vector &expand(std::vector &scratch, + std::size_t slots) const override { + validateSlots(slots); + if (values.size() == slots) { + return values; + } else { + scratch.clear(); + for (int r = slots / values.size(); r > 0; --r) { + scratch.insert(scratch.end(), values.begin(), values.end()); + } + return scratch; + } + } + + void expandTo(std::vector &result, std::size_t slots) const override { + validateSlots(slots); + result.clear(); + result.reserve(slots); + for (int r = slots / values.size(); r > 0; --r) { + result.insert(result.end(), values.begin(), values.end()); + } + } + + bool isZero() const override { + for (double value : values) { + if (value != 0) return false; + } + return true; + } + + void serialize(msg::ConstantValue &msg) const override { + msg.set_size(size); + auto valuesMsg = msg.mutable_values(); + valuesMsg->Reserve(values.size()); + for (const auto &value : values) { + valuesMsg->Add(value); + } + } + +private: + std::vector values; +}; + +class SparseConstantValue : public ConstantValue { +public: + SparseConstantValue(std::size_t size, + std::vector> values) + : ConstantValue(size), values(values) {} + + const std::vector &expand(std::vector &scratch, + std::size_t slots) const override { + validateSlots(slots); + scratch.clear(); + scratch.resize(slots); + for (auto &entry : values) { + for (int i = 0; i < slots; i += values.size()) { + scratch.at(entry.first + i) = entry.second; + } + } + return scratch; + } + + void expandTo(std::vector &result, std::size_t slots) const override { + validateSlots(slots); + result.clear(); + result.resize(slots); + for (auto &entry : values) { + for (int i = 0; i < slots; i += values.size()) { + result.at(entry.first + i) = entry.second; + } + } + } + + bool isZero() const override { + // TODO: this assumes no repeated indices + for (auto entry : values) { + if (entry.second != 0) return false; + } + return true; + } + + void serialize(msg::ConstantValue &msg) const override { + msg.set_size(size); + for (const auto &pair : values) { + msg.add_sparse_indices(pair.first); + msg.add_values(pair.second); + } + } + +private: + std::vector> values; +}; + +std::unique_ptr serialize(const ConstantValue &obj); +std::shared_ptr deserialize(const msg::ConstantValue &msg); + +} // namespace eva diff --git a/eva/ir/ops.h b/eva/ir/ops.h new file mode 100644 index 0000000..64452a1 --- /dev/null +++ b/eva/ir/ops.h @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include +#include + +namespace eva { + +#define EVA_OPS \ + X(Undef, 0) \ + X(Input, 1) \ + X(Output, 2) \ + X(Constant, 3) \ + X(Negate, 10) \ + X(Add, 11) \ + X(Sub, 12) \ + X(Mul, 13) \ + X(RotateLeftConst, 14) \ + X(RotateRightConst, 15) \ + X(Relinearize, 20) \ + X(ModSwitch, 21) \ + X(Rescale, 22) \ + X(Encode, 23) + +enum class Op { +#define X(op, code) op = code, + EVA_OPS +#undef X +}; + +inline bool isValidOp(Op op) { + switch (op) { +#define X(op, code) case Op::op: + EVA_OPS +#undef X + return true; + default: + return false; + } +} + +inline std::string getOpName(Op op) { + switch (op) { +#define X(op, code) \ + case Op::op: \ + return #op; + EVA_OPS +#undef X + default: + throw std::runtime_error("Invalid op"); + } +} + +} // namespace eva diff --git a/eva/ir/program.cpp b/eva/ir/program.cpp new file mode 100644 index 0000000..be705f9 --- /dev/null +++ b/eva/ir/program.cpp @@ -0,0 +1,221 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "eva/ir/program.h" +#include "eva/common/program_traversal.h" +#include "eva/ir/term_map.h" +#include "eva/util/logging.h" +#include + +using namespace std; + +namespace eva { + +// TODO: maybe replace with smart iterator to avoid allocation +vector toTermPtrs(const unordered_set &terms) { + vector termPtrs; + termPtrs.reserve(terms.size()); + for (auto &term : terms) { + termPtrs.emplace_back(term->shared_from_this()); + } + return termPtrs; +} + +vector Program::getSources() const { + return toTermPtrs(this->sources); +} + +vector Program::getSinks() const { return toTermPtrs(this->sinks); } + +std::unique_ptr Program::deepCopy() { + auto newProg = std::make_unique(getName(), getVecSize()); + TermMap oldToNew(*this); + ProgramTraversal traversal(*this); + traversal.forwardPass([&](Term::Ptr &term) { + auto newTerm = newProg->makeTerm(term->op); + oldToNew[term] = newTerm; + newTerm->assignAttributesFrom(*term); + for (auto &operand : term->getOperands()) { + newTerm->addOperand(oldToNew[operand]); + } + }); + for (auto &entry : inputs) { + newProg->inputs[entry.first] = oldToNew[entry.second]; + } + for (auto &entry : outputs) { + newProg->outputs[entry.first] = oldToNew[entry.second]; + } + return newProg; +} + +uint64_t Program::allocateIndex() { + // TODO: reuse released indices to save space in TermMap instances + uint64_t index = nextTermIndex++; + for (TermMapBase *termMap : termMaps) { + termMap->resize(nextTermIndex); + } + return index; +} + +void Program::initTermMap(TermMapBase &termMap) { + termMap.resize(nextTermIndex); +} + +void Program::registerTermMap(TermMapBase *termMap) { + termMaps.emplace_back(termMap); +} + +void Program::unregisterTermMap(TermMapBase *termMap) { + auto iter = find(termMaps.begin(), termMaps.end(), termMap); + if (iter == termMaps.end()) { + throw runtime_error("TermMap to unregister not found"); + } else { + termMaps.erase(iter); + } +} + +template +void dumpAttribute(stringstream &s, Term *term, std::string label) { + if (term->has()) { + s << ", " << label << "=" << term->get(); + } +} + +// Print an attribute in DOT format as a box outside the term +template +void toDOTAttributeAsNode(stringstream &s, Term *term, std::string label) { + if (term->has()) { + s << "t" << term->index << "_" << getAttributeName(Attr::key) + << " [shape=box label=\"" << label << "=" << term->get() + << "\"];\n"; + s << "t" << term->index << "_" << getAttributeName(Attr::key) << " -> t" + << term->index << ";\n"; + } +} + +string Program::dump(TermMapOptional &scales, + TermMap &types, + TermMap &level) const { + // TODO: switch to use a non-parallel generic traversal + stringstream s; + s << getName() << "(){\n"; + + // Add all terms in topologically sorted order + uint64_t nextIndex = 0; + unordered_map indices; + stack> work; + for (const auto &sink : getSinks()) { + work.emplace(true, sink.get()); + } + while (!work.empty()) { + bool visit = work.top().first; + auto term = work.top().second; + work.pop(); + if (indices.count(term)) { + continue; + } + if (visit) { + work.emplace(false, term); + for (const auto &operand : term->getOperands()) { + work.emplace(true, operand.get()); + } + } else { + auto index = nextIndex; + nextIndex += 1; + indices[term] = index; + s << "t" << term->index << " = " << getOpName(term->op); + if (term->has()) { + s << "(" << term->get() << ")"; + } + if (term->has()) { + s << "(" << term->get() << ")"; + } + if (term->has()) { + s << ":" << getTypeName(term->get()); + } + for (int i = 0; i < term->numOperands(); ++i) { + s << " t" << term->operandAt(i)->index; + } + dumpAttribute(s, term, "range"); + dumpAttribute(s, term, "level"); + if (types[*term] == Type::Cipher) + s << ", " + << "s" + << "=" << scales[*term] << ", t=cipher "; + else if (types[*term] == Type::Raw) + s << ", " + << "s" + << "=" << scales[*term] << ", t=raw "; + else + s << ", " + << "s" + << "=" << scales[*term] << ", t=plain "; + s << "\n"; + // ConstantValue TODO: printing constant values for simple cases + } + } + + s << "}\n"; + return s.str(); +} + +string Program::toDOT() const { + // TODO: switch to use a non-parallel generic traversal + stringstream s; + + s << "digraph \"" << getName() << "\" {\n"; + + // Add all terms in topologically sorted order + uint64_t nextIndex = 0; + unordered_map indices; + stack> work; + for (const auto &sink : getSinks()) { + work.emplace(true, sink.get()); + } + while (!work.empty()) { + bool visit = work.top().first; + auto term = work.top().second; + work.pop(); + if (indices.count(term)) { + continue; + } + if (visit) { + work.emplace(false, term); + for (const auto &operand : term->getOperands()) { + work.emplace(true, operand.get()); + } + } else { + auto index = nextIndex; + nextIndex += 1; + indices[term] = index; + + // Operands are guaranteed to have been added + s << "t" << term->index << " [label=\"" << getOpName(term->op); + if (term->has()) { + s << "(" << term->get() << ")"; + } + if (term->has()) { + s << "(" << term->get() << ")"; + } + if (term->has()) { + s << " : " << getTypeName(term->get()); + } + s << "\""; // End label + s << "];\n"; + for (int i = 0; i < term->numOperands(); ++i) { + s << "t" << term->operandAt(i)->index << " -> t" << term->index + << " [label=\"" << i << "\"];\n"; + } + toDOTAttributeAsNode(s, term, "range"); + toDOTAttributeAsNode(s, term, "scale"); + toDOTAttributeAsNode(s, term, "level"); + // ConstantValue TODO: printing constant values for simple cases + } + } + + s << "}\n"; + + return s.str(); +} + +} // namespace eva diff --git a/eva/ir/program.h b/eva/ir/program.h new file mode 100644 index 0000000..924dd99 --- /dev/null +++ b/eva/ir/program.h @@ -0,0 +1,156 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include "eva/ir/constant_value.h" +#include "eva/ir/term.h" +#include "eva/serialization/eva.pb.h" +#include +#include +#include +#include +#include +#include +#include +#include + +namespace eva { + +template class TermMapOptional; +template class TermMap; +class TermMapBase; + +class Program { +public: + Program(std::string name, std::uint64_t vecSize) + : name(name), vecSize(vecSize), nextTermIndex(0) { + if (vecSize == 0) { + throw std::runtime_error("Vector size must be non-zero"); + } + if ((vecSize & (vecSize - 1)) != 0) { + throw std::runtime_error("Vector size must be a power-of-two"); + } + } + + Program(const Program ©) = delete; + + Program &operator=(const Program &assign) = delete; + + Term::Ptr makeTerm(Op op, const std::vector &operands = {}) { + auto term = std::make_shared(op, *this); + if (operands.size() > 0) { + term->setOperands(operands); + } + return term; + } + + Term::Ptr makeConstant(std::unique_ptr value) { + auto term = makeTerm(Op::Constant); + term->set(std::move(value)); + return term; + } + + Term::Ptr makeDenseConstant(std::vector values) { + return makeConstant(std::make_unique(vecSize, values)); + } + + Term::Ptr makeUniformConstant(double value) { + return makeDenseConstant({value}); + } + + Term::Ptr makeInput(const std::string &name, Type type = Type::Cipher) { + auto term = makeTerm(Op::Input); + term->set(type); + inputs.emplace(name, term); + return term; + } + + Term::Ptr makeOutput(std::string name, const Term::Ptr &term) { + auto output = makeTerm(Op::Output, {term}); + outputs.emplace(name, output); + return output; + } + + Term::Ptr makeLeftRotation(const Term::Ptr &term, std::int32_t slots) { + auto rotation = makeTerm(Op::RotateLeftConst, {term}); + rotation->set(slots); + return rotation; + } + + Term::Ptr makeRightRotation(const Term::Ptr &term, std::int32_t slots) { + auto rotation = makeTerm(Op::RotateRightConst, {term}); + rotation->set(slots); + return rotation; + } + + Term::Ptr makeRescale(const Term::Ptr &term, std::uint32_t rescaleBy) { + auto rescale = makeTerm(Op::Rescale, {term}); + rescale->set(rescaleBy); + return rescale; + } + + Term::Ptr getInput(std::string name) const { + if (inputs.find(name) == inputs.end()) { + std::stringstream s; + s << "No input named " << name; + throw std::out_of_range(s.str()); + } + return inputs.at(name); + } + + const auto &getInputs() const { return inputs; } + + const auto &getOutputs() const { return outputs; } + + std::string getName() const { return name; } + void setName(std::string newName) { name = newName; } + + std::uint32_t getVecSize() const { return vecSize; } + + std::vector getSources() const; + + std::vector getSinks() const; + + // Make a deep copy of this program + std::unique_ptr deepCopy(); + + std::string toDOT() const; + std::string dump(TermMapOptional &scales, + TermMap &types, + TermMap &level) const; + +private: + std::uint64_t allocateIndex(); + void initTermMap(TermMapBase &termMap); + void registerTermMap(TermMapBase *annotation); + void unregisterTermMap(TermMapBase *annotation); + + std::string name; + std::uint32_t vecSize; + + // These are managed automatically by Term + std::unordered_set sources; + std::unordered_set sinks; + + std::uint64_t nextTermIndex; + std::vector termMaps; + + // These members must currently be last, because their destruction triggers + // associated Terms to be destructed, which still use the sources and sinks + // structures above. + // TODO: move away from shared ownership for Terms and have Program own them + // uniquely. It is an error to hold onto a Term longer than a Program, but + // the shared_ptr is misleading on this regard. + std::unordered_map outputs; + std::unordered_map inputs; + + friend class Term; + friend class TermMapBase; + friend std::unique_ptr serialize(const Program &); + friend std::unique_ptr deserialize(const msg::Program &); +}; + +std::unique_ptr deserialize(const msg::Program &); + +} // namespace eva diff --git a/eva/ir/term.cpp b/eva/ir/term.cpp new file mode 100644 index 0000000..fd42dba --- /dev/null +++ b/eva/ir/term.cpp @@ -0,0 +1,153 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "eva/ir/term.h" +#include "eva/ir/program.h" +#include +#include + +using namespace std; + +namespace eva { + +Term::Term(Op op, Program &program) + : op(op), program(program), index(program.allocateIndex()) { + program.sources.insert(this); + program.sinks.insert(this); +} + +Term::~Term() { + for (Ptr &operand : operands) { + operand->eraseUse(this); + } + if (operands.empty()) { + program.sources.erase(this); + } + assert(uses.empty()); + program.sinks.erase(this); +} + +void Term::addOperand(const Term::Ptr &term) { + if (operands.empty()) { + program.sources.erase(this); + } + operands.emplace_back(term); + term->addUse(this); +} + +bool Term::eraseOperand(const Ptr &term) { + auto iter = find(operands.begin(), operands.end(), term); + if (iter != operands.end()) { + term->eraseUse(this); + operands.erase(iter); + if (operands.empty()) { + program.sources.insert(this); + } + return true; + } + return false; +} + +bool Term::replaceOperand(Ptr oldTerm, Ptr newTerm) { + bool replaced = false; + for (Ptr &operand : operands) { + if (operand == oldTerm) { + operand = newTerm; + oldTerm->eraseUse(this); + newTerm->addUse(this); + replaced = true; + } + } + return replaced; +} + +void Term::replaceUsesWithIf(Ptr term, function predicate) { + auto thisPtr = shared_from_this(); // TODO: avoid this and similar + // unnecessary reference counting + for (auto &use : getUses()) { + if (predicate(use)) { + use->replaceOperand(thisPtr, term); + } + } +} + +void Term::replaceAllUsesWith(Ptr term) { + replaceUsesWithIf(term, [](const Ptr &) { return true; }); +} + +void Term::replaceOtherUsesWith(Ptr term) { + replaceUsesWithIf(term, [&](const Ptr &use) { return use != term; }); +} + +void Term::setOperands(vector o) { + if (operands.empty()) { + program.sources.erase(this); + } + + for (auto &operand : operands) { + operand->eraseUse(this); + } + operands = move(o); + for (auto &operand : operands) { + operand->addUse(this); + } + + if (operands.empty()) { + program.sources.insert(this); + } +} + +size_t Term::numOperands() const { return operands.size(); } + +Term::Ptr Term::operandAt(size_t i) { return operands.at(i); } + +const vector &Term::getOperands() const { return operands; } + +size_t Term::numUses() { return uses.size(); } + +vector Term::getUses() { + vector u; + for (Term *use : uses) { + u.emplace_back(use->shared_from_this()); + } + return u; +} + +bool Term::isInternal() const { + return ((operands.size() != 0) && (uses.size() != 0)); +} + +void Term::addUse(Term *term) { + if (uses.empty()) { + program.sinks.erase(this); + } + uses.emplace_back(term); +} + +bool Term::eraseUse(Term *term) { + auto iter = find(uses.begin(), uses.end(), term); + assert(iter != uses.end()); + uses.erase(iter); + if (uses.empty()) { + program.sinks.insert(this); + return true; + } + return false; +} + +ostream &operator<<(ostream &s, const Term &term) { + s << term.index << ':' << getOpName(term.op) << '('; + bool first = true; + for (const auto &operand : term.getOperands()) { + if (first) { + first = false; + } else { + s << ','; + } + s << operand->index; + } + s << ')'; + return s; +} + +} // namespace eva diff --git a/eva/ir/term.h b/eva/ir/term.h new file mode 100644 index 0000000..49a437a --- /dev/null +++ b/eva/ir/term.h @@ -0,0 +1,62 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include "eva/ir/attributes.h" +#include "eva/ir/ops.h" +#include "eva/ir/types.h" +#include +#include +#include +#include +#include +#include +#include + +namespace eva { + +class Program; + +class Term : public AttributeList, public std::enable_shared_from_this { +public: + using Ptr = std::shared_ptr; + + Term(Op opcode, Program &program); + ~Term(); + + void addOperand(const Ptr &term); + bool eraseOperand(const Ptr &term); + bool replaceOperand(Ptr oldTerm, Ptr newTerm); + void setOperands(std::vector o); + std::size_t numOperands() const; + Ptr operandAt(size_t i); + const std::vector &getOperands() const; + + void replaceUsesWithIf(Ptr term, std::function); + void replaceAllUsesWith(Ptr term); + void replaceOtherUsesWith(Ptr term); + + std::size_t numUses(); + std::vector getUses(); + + bool isInternal() const; + + const Op op; + Program &program; + + // Unique index for this Term in the owning Program. Managed by Program + // and used to index into TermMap instances. + std::uint64_t index; + + friend std::ostream &operator<<(std::ostream &s, const Term &term); + +private: + std::vector operands; // use->def chain (unmanaged pointers) + std::vector uses; // def->use chain (managed pointers) + + void addUse(Term *term); + bool eraseUse(Term *term); +}; + +} // namespace eva diff --git a/eva/ir/term_map.h b/eva/ir/term_map.h new file mode 100644 index 0000000..33ed8ff --- /dev/null +++ b/eva/ir/term_map.h @@ -0,0 +1,125 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include "eva/ir/program.h" +#include "eva/ir/term.h" +#include +#include +#include +#include +#include +#include + +namespace eva { + +class TermMapBase { +public: + TermMapBase(Program &p) : program(&p) { program->registerTermMap(this); } + ~TermMapBase() { program->unregisterTermMap(this); } + TermMapBase(const TermMapBase &other) : program(other.program) { + program->registerTermMap(this); + } + TermMapBase &operator=(const TermMapBase &other) = default; + + friend class Program; + +protected: + void init() { program->initTermMap(*this); } + + std::uint64_t getIndex(const Term &term) const { return term.index; } + +private: + virtual void resize(std::size_t size) = 0; + + Program *program; +}; + +template class TermMap : TermMapBase { +public: + TermMap(Program &p) : TermMapBase(p) { init(); } + + TValue &operator[](const Term &term) { return values.at(getIndex(term)); } + + const TValue &operator[](const Term &term) const { + return values.at(getIndex(term)); + } + + TValue &operator[](const Term::Ptr &term) { return this->operator[](*term); } + + const TValue &operator[](const Term::Ptr &term) const { + return this->operator[](*term); + } + + void clear() { values.assign(values.size(), {}); } + +private: + void resize(std::size_t size) override { values.resize(size); } + + std::deque values; +}; + +template <> class TermMap : TermMapBase { +public: + TermMap(Program &p) : TermMapBase(p) { init(); } + + std::vector::reference operator[](const Term &term) { + return values.at(getIndex(term)); + } + + bool operator[](const Term &term) const { return values.at(getIndex(term)); } + + std::vector::reference operator[](const Term::Ptr &term) { + return this->operator[](*term); + } + + bool operator[](const Term::Ptr &term) const { + return this->operator[](*term); + } + + void clear() { values.assign(values.size(), false); } + +private: + void resize(std::size_t size) override { values.resize(size); } + + std::vector values; +}; + +template class TermMapOptional : TermMapBase { +public: + TermMapOptional(Program &p) : TermMapBase(p) { init(); } + + TOptionalValue &operator[](const Term &term) { + auto &value = values.at(getIndex(term)); + if (!value.has_value()) { + value.emplace(); + } + return *value; + } + + TOptionalValue &operator[](const Term::Ptr &term) { + return this->operator[](*term); + } + + TOptionalValue &at(const Term &term) { + return values.at(getIndex(term)).value(); + } + + TOptionalValue &at(const Term::Ptr &term) { return this->at(*term); } + + bool has(const Term &term) const { + return values.at(getIndex(term)).has_value(); + } + + bool has(const Term::Ptr &term) const { return has(*term); } + + void clear() { values.assign(values.size(), std::nullopt); } + +private: + void resize(std::size_t size) override { values.resize(size); } + + std::deque> values; +}; + +} // namespace eva diff --git a/eva/ir/types.h b/eva/ir/types.h new file mode 100644 index 0000000..f305776 --- /dev/null +++ b/eva/ir/types.h @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include +#include + +namespace eva { + +#define EVA_TYPES \ + X(Undef, 0) \ + X(Cipher, 1) \ + X(Raw, 2) \ + X(Plain, 3) + +enum class Type : std::int32_t { +#define X(type, code) type = code, + EVA_TYPES +#undef X +}; + +inline std::string getTypeName(Type type) { + switch (type) { +#define X(type, code) \ + case Type::type: \ + return #type; + EVA_TYPES +#undef X + default: + throw std::runtime_error("Invalid type"); + } +} + +} // namespace eva diff --git a/eva/seal/CMakeLists.txt b/eva/seal/CMakeLists.txt new file mode 100644 index 0000000..005051a --- /dev/null +++ b/eva/seal/CMakeLists.txt @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +target_sources(eva PRIVATE + seal.cpp +) diff --git a/eva/seal/seal.cpp b/eva/seal/seal.cpp new file mode 100644 index 0000000..06bccc0 --- /dev/null +++ b/eva/seal/seal.cpp @@ -0,0 +1,205 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "eva/seal/seal.h" +#include "eva/common/program_traversal.h" +#include "eva/common/valuation.h" +#include "eva/seal/seal_executor.h" +#include "eva/util/logging.h" +#include +#include +#include +#include +#include + +#ifdef EVA_USE_GALOIS +#include "eva/common/multicore_program_traversal.h" +#include "eva/util/galois.h" +#endif + +using namespace std; + +namespace eva { + +SEALValuation SEALPublic::encrypt(const Valuation &inputs, + const CKKSSignature &signature) { + size_t slotCount = encoder.slot_count(); + if (slotCount < signature.vecSize) { + throw runtime_error("Vector size cannot be larger than slot count"); + } + if (slotCount % signature.vecSize != 0) { + throw runtime_error("Vector size must exactly divide the slot count"); + } + + SEALValuation sealInputs(context); + for (auto &in : inputs) { + + // With multicore sealInputs is initialized first, so that multiple threads + // can be used to encode and encrypt values into it at the same time without + // making structural changes. +#ifdef EVA_USE_GALOIS + sealInputs[in.first] = {}; + } + + // Start a second parallel loop to encrypt inputs. + GaloisGuard galois; + galois::do_all( + galois::iterate(inputs), + [&](auto &in) { +#endif + auto name = in.first; + auto &v = in.second; + auto vSize = v.size(); + // TODO remove this check + if (vSize != signature.vecSize) { + throw runtime_error("Input size does not match program vector size"); + } + auto info = signature.inputs.at(name); + + auto ctxData = context.first_context_data(); + for (size_t i = 0; i < info.level; ++i) { + ctxData = ctxData->next_context_data(); + } + + if (info.inputType == Type::Cipher || info.inputType == Type::Plain) { + seal::Plaintext plain; + + if (vSize == 1) { + encoder.encode(v[0], ctxData->parms_id(), pow(2.0, info.scale), + plain); + } else { + vector vec(slotCount); + assert(vSize <= slotCount); + assert((slotCount % vSize) == 0); + auto replicas = (slotCount / vSize); + for (uint32_t r = 0; r < replicas; ++r) { + for (uint64_t i = 0; i < vSize; ++i) { + vec[(r * vSize) + i] = v[i]; + } + } + encoder.encode(vec, ctxData->parms_id(), pow(2.0, info.scale), + plain); + } + if (info.inputType == Type::Cipher) { + seal::Ciphertext cipher; + encryptor.encrypt(plain, cipher); + sealInputs[name] = move(cipher); + } else if (info.inputType == Type::Plain) { + sealInputs[name] = move(plain); + } + } else { + sealInputs[name] = std::shared_ptr( + new DenseConstantValue(signature.vecSize, v)); + } + } +#ifdef EVA_USE_GALOIS + // Finish the parallel loop if using multicore support + , + galois::no_stats(), galois::loopname("EncryptInputs")); +#endif + + return sealInputs; +} + +SEALValuation SEALPublic::execute(Program &program, + const SEALValuation &inputs) { +#ifdef EVA_USE_GALOIS + // Do multicore evaluation if multicore support is available + GaloisGuard galois; + MulticoreProgramTraversal programTraverse(program); +#else + // Otherwise fall back to singlecore evaluation + ProgramTraversal programTraverse(program); +#endif + auto sealExecutor = SEALExecutor(program, context, encoder, encryptor, + evaluator, galoisKeys, relinKeys); + sealExecutor.setInputs(inputs); + programTraverse.forwardPass(sealExecutor); + + SEALValuation encOutputs(context); + sealExecutor.getOutputs(encOutputs); + return encOutputs; +} + +Valuation SEALSecret::decrypt(const SEALValuation &encOutputs, + const CKKSSignature &signature) { + Valuation outputs; + std::vector tempVec; + for (auto &out : encOutputs) { + auto name = out.first; + visit(Overloaded{[&](const seal::Ciphertext &cipher) { + seal::Plaintext plain; + decryptor.decrypt(cipher, plain); + encoder.decode(plain, outputs[name]); + }, + [&](const seal::Plaintext &plain) { + encoder.decode(plain, outputs[name]); + }, + [&](const std::shared_ptr &raw) { + auto &scratch = tempVec; + outputs[name] = raw->expand(scratch, signature.vecSize); + }}, + out.second); + outputs.at(name).resize(signature.vecSize); + } + return outputs; +} + +seal::SEALContext getSEALContext(const seal::EncryptionParameters ¶ms) { + static unordered_map cache; + + // clean cache except for the required entry + for (auto iter = cache.begin(); iter != cache.end();) { + // accessing the context data increases the reference count by one + // Another reference is incremented by cache entry + if (iter->second.key_context_data().use_count() == 2 && + iter->first != params) { + iter = cache.erase(iter); + } else { + ++iter; + } + } + + // find SEALContext + if (cache.count(params) != 0) { + seal::SEALContext result = cache.at(params); + return result; + } else { + auto result = cache.emplace(make_pair( + params, seal::SEALContext(params, true, seal::sec_level_type::none))); + return result.first->second; + } +} + +tuple, unique_ptr> +generateKeys(const CKKSParameters &abstractParams) { + vector logQs(abstractParams.primeBits.begin(), + abstractParams.primeBits.end()); + + auto params = seal::EncryptionParameters(seal::scheme_type::ckks); + params.set_poly_modulus_degree(abstractParams.polyModulusDegree); + params.set_coeff_modulus( + seal::CoeffModulus::Create(abstractParams.polyModulusDegree, logQs)); + + auto context = getSEALContext(params); + + seal::KeyGenerator keygen(context); + vector rotationsVec(abstractParams.rotations.begin(), + abstractParams.rotations.end()); + + seal::PublicKey public_key; + seal::GaloisKeys galois_keys; + seal::RelinKeys relin_keys; + + keygen.create_public_key(public_key); + keygen.create_galois_keys(rotationsVec, galois_keys); + keygen.create_relin_keys(relin_keys); + + auto secretCtx = make_unique(context, keygen.secret_key()); + auto publicCtx = + make_unique(context, public_key, galois_keys, relin_keys); + + return make_tuple(move(publicCtx), move(secretCtx)); +} + +} // namespace eva diff --git a/eva/seal/seal.h b/eva/seal/seal.h new file mode 100644 index 0000000..c6d5795 --- /dev/null +++ b/eva/seal/seal.h @@ -0,0 +1,99 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include "eva/ckks/ckks_parameters.h" +#include "eva/ckks/ckks_signature.h" +#include "eva/common/valuation.h" +#include "eva/ir/program.h" +#include "eva/serialization/seal.pb.h" +#include +#include +#include +#include +#include +#include +#include + +namespace eva { + +using SchemeValue = std::variant>; + +class SEALValuation { +public: + SEALValuation(const seal::EncryptionParameters ¶ms) : params(params) {} + SEALValuation(const seal::SEALContext &context) + : params(context.key_context_data()->parms()) {} + + auto &operator[](const std::string &name) { return values[name]; } + auto begin() { return values.begin(); } + auto begin() const { return values.begin(); } + auto end() { return values.end(); } + auto end() const { return values.end(); } + +private: + seal::EncryptionParameters params; + std::unordered_map values; + + friend std::unique_ptr serialize(const SEALValuation &); +}; + +std::unique_ptr deserialize(const msg::SEALValuation &); + +class SEALPublic { +public: + SEALPublic(seal::SEALContext ctx, seal::PublicKey pk, seal::GaloisKeys gk, + seal::RelinKeys rk) + : context(ctx), publicKey(pk), galoisKeys(gk), relinKeys(rk), + encoder(ctx), encryptor(ctx, publicKey), evaluator(ctx) {} + + SEALValuation encrypt(const Valuation &inputs, + const CKKSSignature &signature); + + SEALValuation execute(Program &program, const SEALValuation &inputs); + +private: + seal::SEALContext context; + + seal::PublicKey publicKey; + seal::GaloisKeys galoisKeys; + seal::RelinKeys relinKeys; + + seal::CKKSEncoder encoder; + seal::Encryptor encryptor; + seal::Evaluator evaluator; + + friend std::unique_ptr serialize(const SEALPublic &); +}; + +std::unique_ptr deserialize(const msg::SEALPublic &); + +class SEALSecret { +public: + SEALSecret(seal::SEALContext ctx, seal::SecretKey sk) + : context(ctx), secretKey(sk), encoder(ctx), decryptor(ctx, secretKey) {} + + Valuation decrypt(const SEALValuation &encOutputs, + const CKKSSignature &signature); + +private: + seal::SEALContext context; + + seal::SecretKey secretKey; + + seal::CKKSEncoder encoder; + seal::Decryptor decryptor; + + friend std::unique_ptr serialize(const SEALSecret &); +}; + +std::unique_ptr deserialize(const msg::SEALSecret &); + +seal::SEALContext getSEALContext(const seal::EncryptionParameters ¶ms); + +std::tuple, std::unique_ptr> +generateKeys(const CKKSParameters &abstractParams); + +} // namespace eva diff --git a/eva/seal/seal_executor.h b/eva/seal/seal_executor.h new file mode 100644 index 0000000..b4e196d --- /dev/null +++ b/eva/seal/seal_executor.h @@ -0,0 +1,438 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include "eva/ir/constant_value.h" +#include "eva/ir/program.h" +#include "eva/ir/term_map.h" +#include "eva/util/logging.h" +#include "eva/util/overloaded.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Galois per-thread storage is used when EVA is compiled for multicore support +#ifdef EVA_USE_GALOIS +#include +#endif + +namespace eva { + +// executes unencrypted computation +class SEALExecutor { + using RuntimeValue = + std::variant>; + + Program &program; + seal::SEALContext context; + seal::CKKSEncoder &encoder; + seal::Encryptor &encryptor; + seal::Evaluator &evaluator; + seal::GaloisKeys &galoisKeys; + seal::RelinKeys &relinKeys; + TermMapOptional Objects; + + // Each thread has a separate scratch space into which constants are expanded + // for encoding. +#ifdef EVA_USE_GALOIS + galois::substrate::PerThreadStorage> tempVec; +#else + // Without multicore support only one scratch vector is needed + std::vector tempVec; +#endif + + bool isCipher(const Term::Ptr &t) { + return std::holds_alternative(Objects.at(t)); + } + bool isPlain(const Term::Ptr &t) { + return std::holds_alternative(Objects.at(t)); + } + bool isRaw(const Term::Ptr &t) { + return std::holds_alternative>(Objects.at(t)); + } + + void rightRotateRaw(std::vector &out, const Term::Ptr &args1, + std::int32_t shift) { + auto &in = std::get>(Objects.at(args1)); + + while (shift > 0 && shift >= in.size()) + shift -= in.size(); + while (shift < 0) + shift += in.size(); + + out.clear(); + out.reserve(in.size()); + copy_n(in.cend() - shift, shift, back_inserter(out)); + copy_n(in.cbegin(), in.size() - shift, back_inserter(out)); + } + + void leftRotateRaw(std::vector &out, const Term::Ptr &args1, + std::int32_t shift) { + auto &in = std::get>(Objects.at(args1)); + + while (shift > 0 && shift >= in.size()) + shift -= in.size(); + while (shift < 0) + shift += in.size(); + + out.clear(); + out.reserve(in.size()); + copy_n(in.cbegin() + shift, in.size() - shift, back_inserter(out)); + copy_n(in.cbegin(), shift, back_inserter(out)); + } + + template + void binOpRaw(std::vector &out, const Term::Ptr &args1, + const Term::Ptr &args2) { + auto &in1 = std::get>(Objects.at(args1)); + auto &in2 = std::get>(Objects.at(args2)); + assert(in1.size() == in2.size()); + + out.clear(); + out.reserve(in1.size()); + transform(in1.cbegin(), in1.cend(), in2.cbegin(), back_inserter(out), Op()); + } + + void negateRaw(std::vector &out, const Term::Ptr &args1) { + auto &in = std::get>(Objects.at(args1)); + + out.clear(); + out.reserve(in.size()); + transform(in.cbegin(), in.cend(), back_inserter(out), + std::negate()); + } + + void add(seal::Ciphertext &output, const Term::Ptr &args1, + const Term::Ptr &args2) { + if (!isCipher(args1)) { + assert(isCipher(args2)); + add(output, args2, args1); + return; + } + seal::Ciphertext &input1 = std::get(Objects.at(args1)); + // TODO: should a previous lowering get rid of this dispatch? + std::visit(Overloaded{[&](const seal::Ciphertext &input2) { + evaluator.add(input1, input2, output); + }, + [&](const seal::Plaintext &input2) { + evaluator.add_plain(input1, input2, output); + }, + [&](const std::vector &input2) { + throw std::runtime_error( + "Unsupported operation encountered"); + }}, + Objects.at(args2)); + } + + void sub(seal::Ciphertext &output, const Term::Ptr &args1, + const Term::Ptr &args2) { + seal::Ciphertext &input1 = std::get(Objects.at(args1)); + std::visit(Overloaded{[&](const seal::Ciphertext &input2) { + evaluator.sub(input1, input2, output); + }, + [&](const seal::Plaintext &input2) { + evaluator.sub_plain(input1, input2, output); + }, + [&](const std::vector &input2) { + throw std::runtime_error( + "Unsupported operation encountered"); + }}, + Objects.at(args2)); + } + + void mul(seal::Ciphertext &output, const Term::Ptr &args1, + const Term::Ptr &args2) { + // swap args if arg1 is plain type and arg2 is of cipher type + if (!isCipher(args1) && isCipher(args2)) { + mul(output, args2, args1); + return; + } + seal::Ciphertext &input1 = std::get(Objects.at(args1)); + std::visit(Overloaded{[&](const seal::Ciphertext &input2) { + if (args1 == args2) { + evaluator.square(input1, output); + } else { + evaluator.multiply(input1, input2, output); + } + }, + [&](const seal::Plaintext &input2) { + evaluator.multiply_plain(input1, input2, output); + }, + [&](const std::vector &input2) { + throw std::runtime_error( + "Unsupported operation encountered"); + }}, + Objects.at(args2)); + } + + void leftRotate(seal::Ciphertext &output, const Term::Ptr &args1, + std::int32_t rotation) { + assert(isCipher(args1)); + seal::Ciphertext &input1 = std::get(Objects.at(args1)); + evaluator.rotate_vector(input1, rotation, galoisKeys, output); + } + + void rightRotate(seal::Ciphertext &output, const Term::Ptr &args1, + std::int32_t rotation) { + assert(isCipher(args1)); + seal::Ciphertext &input1 = std::get(Objects.at(args1)); + evaluator.rotate_vector(input1, -rotation, galoisKeys, output); + } + + void negate(seal::Ciphertext &output, const Term::Ptr &args1) { + assert(isCipher(args1)); + seal::Ciphertext &input1 = std::get(Objects.at(args1)); + evaluator.negate(input1, output); + } + + void relinearize(seal::Ciphertext &output, const Term::Ptr &args1) { + assert(isCipher(args1)); + seal::Ciphertext &input1 = std::get(Objects.at(args1)); + evaluator.relinearize(input1, relinKeys, output); + } + + void modSwitch(seal::Ciphertext &output, const Term::Ptr &args1) { + assert(isCipher(args1)); + seal::Ciphertext &input1 = std::get(Objects.at(args1)); + evaluator.mod_switch_to_next(input1, output); + } + + void rescale(seal::Ciphertext &output, const Term::Ptr &args1, + std::uint32_t divisor) { + assert(isCipher(args1)); + seal::Ciphertext &input1 = std::get(Objects.at(args1)); + evaluator.rescale_to_next(input1, output); + output.scale() = input1.scale() / pow(2.0, divisor); + } + + void encodeRaw(seal::Plaintext &output, const Term::Ptr &args1, + uint32_t scale, uint32_t level) { + auto &in = std::get>(Objects.at(args1)); + + auto ctxData = context.first_context_data(); + for (std::size_t i = 0; i < level; ++i) { + ctxData = ctxData->next_context_data(); + } + + // If the slot count is larger than the vector size, then encode repetitions + // of the vector to fill the slot count. This will provide the correct + // semantics for rotations. + assert(encoder.slot_count() % program.getVecSize() == 0); + auto copies = encoder.slot_count() / program.getVecSize(); +#ifdef EVA_USE_GALOIS + auto &scratch = *tempVec.getLocal(); +#else + auto &scratch = tempVec; +#endif + scratch.clear(); + scratch.reserve(encoder.slot_count()); + for (int i = 0; i < copies; ++i) { + scratch.insert(scratch.end(), std::begin(in), std::end(in)); + } + + encoder.encode(scratch, ctxData->parms_id(), pow(2.0, scale), output); + } + + void expandConstant(std::vector &output, + const std::shared_ptr constant) { + constant->expandTo(output, program.getVecSize()); + } + + template T &initValue(const Term::Ptr &term) { + return std::get(Objects[term] = T{}); + } + +public: + SEALExecutor(Program &g, seal::SEALContext ctx, seal::CKKSEncoder &ce, + seal::Encryptor &enc, seal::Evaluator &e, seal::GaloisKeys &gk, + seal::RelinKeys &rk) + : program(g), context(ctx), encoder(ce), encryptor(enc), evaluator(e), + galoisKeys(gk), relinKeys(rk), Objects(g) { + assert(program.getVecSize() <= encoder.slot_count()); + assert((encoder.slot_count() % program.getVecSize()) == 0); + } + + void setInputs(const SEALValuation &inputs) { + for (auto &in : inputs) { + auto term = program.getInput(in.first); + std::visit( + Overloaded{ + [&](const seal::Ciphertext &input) { Objects[term] = input; }, + [&](const seal::Plaintext &input) { Objects[term] = input; }, + [&](const std::shared_ptr &input) { + auto &value = initValue>(term); + expandConstant(value, input); + }}, + in.second); + } + } + + void operator()(const Term::Ptr &term) { + if (verbosityAtLeast(Verbosity::Debug)) { + printf("EVA: Execute t%lu = %s(", term->index, + getOpName(term->op).c_str()); + bool first = true; + for (auto &operand : term->getOperands()) { + if (first) { + first = false; + printf("t%lu", operand->index); + } else { + printf(",t%lu", operand->index); + } + } + printf(")\n"); + fflush(stdout); + } + + if (term->op == Op::Input) return; + auto args = term->getOperands(); + switch (term->op) { + case Op::Constant: { + auto &output = initValue>(term); + expandConstant(output, term->get()); + } break; + case Op::Encode: { + assert(args.size() == 1); + assert(isRaw(args[0])); + auto &output = initValue(term); + encodeRaw(output, args[0], term->get(), + term->get()); + } break; + case Op::Add: + assert(args.size() == 2); + if (isRaw(args[0]) && isRaw(args[1])) { + auto &output = initValue>(term); + binOpRaw>(output, args[0], args[1]); + } else { // handles plain and cipher + assert(isCipher(args[0]) || isPlain(args[0])); + assert(isCipher(args[1]) || isPlain(args[1])); + auto &output = initValue(term); + add(output, args[0], args[1]); + } + break; + case Op::Sub: + assert(args.size() == 2); + if (isRaw(args[0]) && isRaw(args[1])) { + auto &output = initValue>(term); + binOpRaw>(output, args[0], args[1]); + } else { // handles plain and cipher + assert(isCipher(args[0]) || isPlain(args[0])); + assert(isCipher(args[1]) || isPlain(args[1])); + auto &output = initValue(term); + sub(output, args[0], args[1]); + } + break; + case Op::Mul: + assert(args.size() == 2); + if (isRaw(args[0]) && isRaw(args[1])) { + auto &output = initValue>(term); + binOpRaw>(output, args[0], args[1]); + } else { // works on cipher, no plaintext support + assert(isCipher(args[0]) || isCipher(args[1])); + assert(!isRaw(args[0]) && !isRaw(args[1])); + auto &output = initValue(term); + mul(output, args[0], args[1]); + } + break; + case Op::RotateLeftConst: + assert(args.size() == 1); + if (isRaw(args[0])) { + auto &output = initValue>(term); + leftRotateRaw(output, args[0], term->get()); + } else { // works on cipher, no plaintext support + assert(isCipher(args[0])); + auto &output = initValue(term); + leftRotate(output, args[0], term->get()); + } + break; + case Op::RotateRightConst: + assert(args.size() == 1); + if (isRaw(args[0])) { + auto &output = initValue>(term); + rightRotateRaw(output, args[0], term->get()); + } else { // works on cipher, no plaintext support + assert(isCipher(args[0])); + auto &output = initValue(term); + rightRotate(output, args[0], term->get()); + } + break; + case Op::Negate: + assert(args.size() == 1); + if (isRaw(args[0])) { + auto &output = initValue>(term); + negateRaw(output, args[0]); + } else { // works on cipher, no plaintext support + assert(isCipher(args[0])); + auto &output = initValue(term); + negate(output, args[0]); + } + break; + case Op::Relinearize: { + assert(args.size() == 1); + assert(isCipher(args[0])); + auto &output = initValue(term); + relinearize(output, args[0]); + } break; + case Op::ModSwitch: { + assert(args.size() == 1); + assert(isCipher(args[0])); + auto &output = initValue(term); + modSwitch(output, args[0]); + } break; + case Op::Rescale: { + assert(args.size() == 1); + assert(isCipher(args[0])); + auto &output = initValue(term); + rescale(output, args[0], term->get()); + } break; + case Op::Output: { + assert(args.size() == 1); + Objects[term] = Objects.at(args[0]); + } break; + default: + throw std::runtime_error("Unhandled op " + getOpName(term->op)); + } + } + + void free(const Term::Ptr &term) { + if (term->op == Op::Output) { + return; + } + auto &obj = Objects.at(term); + std::visit(Overloaded{[](seal::Ciphertext &cipher) { cipher.release(); }, + [](seal::Plaintext &plain) { plain.release(); }, + [](std::vector &raw) { + raw.clear(); + raw.shrink_to_fit(); + }}, + obj); + } + + void getOutputs(SEALValuation &encOutputs) { + for (auto &out : program.getOutputs()) { + std::visit(Overloaded{[&](const seal::Ciphertext &output) { + encOutputs[out.first] = output; + }, + [&](const seal::Plaintext &output) { + encOutputs[out.first] = output; + }, + [&](const std::vector &output) { + encOutputs[out.first] = + std::make_shared( + program.getVecSize(), output); + }}, + Objects.at(out.second)); + } + } +}; + +} // namespace eva diff --git a/eva/serialization/CMakeLists.txt b/eva/serialization/CMakeLists.txt new file mode 100644 index 0000000..f9eccbb --- /dev/null +++ b/eva/serialization/CMakeLists.txt @@ -0,0 +1,16 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +protobuf_generate_cpp(PROTO_SRCS PROTO_HDRS known_type.proto eva.proto ckks.proto seal.proto) +add_library(protobuf OBJECT ${PROTO_SRCS} ${PROTO_HDRS}) +target_include_directories(protobuf PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) +target_link_libraries(protobuf PUBLIC protobuf::libprotobuf) + +target_sources(eva PRIVATE + $ + known_type.cpp + save_load.cpp + eva_serialization.cpp + ckks_serialization.cpp + seal_serialization.cpp +) diff --git a/eva/serialization/ckks.proto b/eva/serialization/ckks.proto new file mode 100644 index 0000000..4ecb739 --- /dev/null +++ b/eva/serialization/ckks.proto @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +syntax = "proto3"; + +package eva.msg; + +message CKKSParameters { + repeated uint32 prime_bits = 1; + repeated int32 rotations = 2; + uint32 poly_modulus_degree = 3; +} + +message CKKSEncodingInfo { + int32 input_type = 1; + int32 scale = 2; + int32 level = 3; +} + +message CKKSSignature { + int32 vec_size = 1; + map inputs = 2; +} diff --git a/eva/serialization/ckks_serialization.cpp b/eva/serialization/ckks_serialization.cpp new file mode 100644 index 0000000..b48d8f7 --- /dev/null +++ b/eva/serialization/ckks_serialization.cpp @@ -0,0 +1,82 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "eva/ckks/ckks_parameters.h" +#include "eva/ckks/ckks_signature.h" +#include "eva/serialization/ckks.pb.h" +#include +#include + +using namespace std; + +namespace eva { + +unique_ptr serialize(const CKKSParameters &obj) { + // Create a new protobuf message + auto msg = make_unique(); + + // Save the prime bit counts + auto primeBitsMsg = msg->mutable_prime_bits(); + primeBitsMsg->Reserve(obj.primeBits.size()); + for (const auto &bits : obj.primeBits) { + primeBitsMsg->Add(bits); + } + + // Save the rotations that are needed + auto rotationsMsg = msg->mutable_rotations(); + rotationsMsg->Reserve(obj.rotations.size()); + for (const auto &rotation : obj.rotations) { + rotationsMsg->Add(rotation); + } + + // Save the polynomial modulus degree + msg->set_poly_modulus_degree(obj.polyModulusDegree); + + return msg; +} + +unique_ptr deserialize(const msg::CKKSParameters &msg) { + // Create a new CKKSParameters object + auto obj = make_unique(); + + // Load the values from the protobuf message + obj->primeBits = {msg.prime_bits().begin(), msg.prime_bits().end()}; + obj->rotations = {msg.rotations().begin(), msg.rotations().end()}; + obj->polyModulusDegree = msg.poly_modulus_degree(); + + return obj; +} + +unique_ptr serialize(const CKKSSignature &obj) { + // Create a new protobuf message + auto msg = make_unique(); + + // Save the vector size + msg->set_vec_size(obj.vecSize); + + // Save the input map + auto &inputsMap = *msg->mutable_inputs(); + for (auto &[key, info] : obj.inputs) { + auto &infoMsg = inputsMap[key]; + infoMsg.set_input_type(static_cast(info.inputType)); + infoMsg.set_scale(info.scale); + infoMsg.set_level(info.level); + } + + return msg; +} + +unique_ptr deserialize(const msg::CKKSSignature &msg) { + // Create a new map of CKKSEncodingInfo objects and load the data + unordered_map inputs; + for (auto &[key, infoMsg] : msg.inputs()) { + inputs.emplace(key, + CKKSEncodingInfo(static_cast(infoMsg.input_type()), + infoMsg.scale(), infoMsg.level())); + } + + // Return a new CKKSSignature object + return make_unique(msg.vec_size(), move(inputs)); +} + +} // namespace eva diff --git a/eva/serialization/eva.proto b/eva/serialization/eva.proto new file mode 100644 index 0000000..62f304c --- /dev/null +++ b/eva/serialization/eva.proto @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +syntax = "proto3"; + +package eva.msg; + +message Term { + uint32 op = 1; + // Absolute indices to list of terms + repeated uint64 operands = 2; + repeated Attribute attributes = 3; +} + +message ConstantValue { + uint32 size = 1; + // If sparse_indices is set then values are interpreted as a sparse set of values + // Otherwise values is interpreted as dense with broadcasting semantics and size must divide vec_size + // If values is empty then the whole constant is zero + repeated double values = 2; + repeated uint32 sparse_indices = 3; +} + +message Attribute { + uint32 key = 1; + oneof value { + uint32 uint32 = 2; + sint32 int32 = 3; + uint32 type = 4; + ConstantValue constant_value = 5; + } +} + +message TermName { + uint64 term = 1; + string name = 2; +} + +message Program { + uint32 ir_version = 1; + string name = 2; + uint32 vec_size = 3; + repeated Term terms = 4; + repeated TermName inputs = 5; + repeated TermName outputs = 6; +} diff --git a/eva/serialization/eva_format_version.h b/eva/serialization/eva_format_version.h new file mode 100644 index 0000000..df1e548 --- /dev/null +++ b/eva/serialization/eva_format_version.h @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include + +namespace eva { + +// Bump the version for any changes that break serialization +const std::int32_t EVA_FORMAT_VERSION = 2; + +} // namespace eva diff --git a/eva/serialization/eva_serialization.cpp b/eva/serialization/eva_serialization.cpp new file mode 100644 index 0000000..dfc4b2a --- /dev/null +++ b/eva/serialization/eva_serialization.cpp @@ -0,0 +1,291 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "eva/ir/attribute_list.h" +#include "eva/ir/program.h" +#include "eva/ir/term_map.h" +#include "eva/serialization/eva_format_version.h" +#include "eva/util/overloaded.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace std; + +namespace eva { + +// Definition for member function AttributeList::loadAttribute +// The function definition is here so that all serialization code is together +// and not spread out throughout the project +void AttributeList::loadAttribute(const msg::Attribute &msg) { + // Load the attribute key; this encodes the type of the attribute + AttributeKey key = static_cast(msg.key()); + + // A variant that holds the possible value types for attributes + AttributeValue value; + + switch (msg.value_case()) { + case msg::Attribute::kUint32: + // The attribute holds a uint32; load it + value.emplace(msg.uint32()); + break; + case msg::Attribute::kInt32: + // The attribute holds an int32; load it + value.emplace(msg.int32()); + break; + case msg::Attribute::kType: + // The attribute holds a Type (see eva/ir/types.h); load it + value.emplace(static_cast(msg.type())); + break; + case msg::Attribute::kConstantValue: + // The attribute holds a constant value; load it + value.emplace>(deserialize(msg.constant_value())); + break; + case msg::Attribute::VALUE_NOT_SET: + // No value is set; set the attribute to std::monostate + value.emplace(); + break; + default: + // An unexpected value + throw runtime_error("Unknown attribute type"); + } + + // Check that the attribute value is valid + if (!isValidAttribute(key, value)) { + throw runtime_error("Invalid attribute encountered"); + } + + // Add the attribute to this AttributeList + set(key, move(value)); +} + +// Definition for member function AttributeList::serializeAttributes +// The function definition is here so that all serialization code is together +// and not spread out throughout the project +void AttributeList::serializeAttributes( + function addMsg) const { + // Nothing to do if key is zero (empty attribute; see eva/ir/attributes.h) + if (key == 0) { + return; + } + + // Go over each attribute in this list + const AttributeList *curr = this; + do { + // Get a pointer to a new attribute + msg::Attribute *msg = addMsg(); + + // Set the key and value + msg->set_key(curr->key); + visit(Overloaded{[&](const monostate &value) { + // This is an empty attribute + }, + [&](const uint32_t &value) { msg->set_uint32(value); }, + [&](const int32_t &value) { msg->set_int32(value); }, + [&](const Type &value) { + msg->set_type(static_cast(value)); + }, + [&](const shared_ptr &value) { + auto valueMsg = serialize(*value); + msg->set_allocated_constant_value(valueMsg.release()); + }}, + curr->value); + + // Move on to the next attribute + curr = curr->tail.get(); + } while (curr); +} + +unique_ptr serialize(const ConstantValue &obj) { + // Save a constant value; the implementations are in eva/ir/constant_value.h + auto msg = std::make_unique(); + obj.serialize(*msg); + return msg; +} + +shared_ptr deserialize(const msg::ConstantValue &msg) { + if (msg.size() == 0) { + throw runtime_error("Constant must have non-zero size"); + } + + size_t size = msg.size(); + if (msg.values_size() == 0) { + // Zero size; return a sparse zero constant + return make_shared(size, + vector>{}); + } else if (msg.sparse_indices_size() == 0) { + // No sparse indices so this is a dense constant + vector values(msg.values().begin(), msg.values().end()); + return make_shared(size, move(values)); + } else { + // Must be a sparse constant; check that the data is consistent + if (msg.sparse_indices_size() != msg.values_size()) { + throw runtime_error("Values and sparse indices count mismatch"); + } + + // Load the sparse representation + vector> values; + auto indexIter = msg.sparse_indices().begin(); + auto valueIter = msg.values().begin(); + while (indexIter != msg.sparse_indices().end()) { + values.emplace_back(*indexIter, *valueIter); + ++indexIter; + ++valueIter; + } + + return make_shared(size, move(values)); + } +} + +unique_ptr serialize(const Program &obj) { + // Create a new program message for serialization + auto msg = make_unique(); + + // Save the IR version and vector size + msg->set_ir_version(EVA_FORMAT_VERSION); + msg->set_vec_size(obj.vecSize); + + // Save all terms in topologically sorted order; this is convenient so we can + // easily load it back and set up operand pointers immediately after loading + // a term. To each term we assign a topological index (operands of each term + // have indices less than the index of the current term). The edges of the + // program graph are saved by providing the operand term indices for each + // term. + + // Table of topological indices assigned to the terms + unordered_map indices; + + // Index to be assigned next + uint64_t nextIndex = 0; + + // Work stack of terms that need to be processed + // The bool ("visit") signals whether the operands of the term have already + // been processed. If this is false, we are ready to give the term an index. + stack> work; + + // Add each sink to the work stack with visit flag set to true + for (const auto &sink : obj.getSinks()) { + work.emplace(true, sink.get()); + } + + // Operate on the stack until empty + while (!work.empty()) { + // Pop from the stack + bool visit = work.top().first; + auto term = work.top().second; + work.pop(); + + // If this term has already been given an index, there is nothing to do; + // all of its operands are guaranteed to have been assigned indices + if (indices.count(term)) { + continue; + } + + // This term does not yet appear in the index map + // Do we need to process its operands or are we ready to assign an index? + if (visit) { + // Add the term back with to-do/visit flag set to false; next time we + // process this term we will give it an index + work.emplace(false, term); + + // Add the operands to work stack with visit flag set to true + for (const auto &operand : term->getOperands()) { + work.emplace(true, operand.get()); + } + } else { + // The operands of this term have already been processed; ready to assign + // an index + + // Read the current index value and increment the counter for next + auto index = nextIndex; + nextIndex += 1; + + // Add this term to the indices map + indices[term] = index; + + // Add a new term to the message; set the opcode and add operands + auto termMsg = msg->add_terms(); + termMsg->set_op(static_cast(term->op)); + for (const auto &operand : term->getOperands()) { + // Add operands to the current term by saving the indices + termMsg->add_operands(indices.at(operand.get())); + } + + // Save the attributes for this term + term->serializeAttributes([&]() { return termMsg->add_attributes(); }); + } + } + + // Save the input term indices and labels + for (const auto &entry : obj.inputs) { + auto termNameMsg = msg->add_inputs(); + termNameMsg->set_name(entry.first); + termNameMsg->set_term(indices.at(entry.second.get())); + } + + // Save the output term indices and labels + for (const auto &entry : obj.outputs) { + auto termNameMsg = msg->add_outputs(); + termNameMsg->set_name(entry.first); + termNameMsg->set_term(indices.at(entry.second.get())); + } + + return msg; +} + +unique_ptr deserialize(const msg::Program &msg) { + // Ensure serialization version is compatible + if (msg.ir_version() != EVA_FORMAT_VERSION) { + throw runtime_error("Serialization format version mismatch"); + } + + // Create a new program with the loaded name and vector size + auto obj = make_unique(msg.name(), msg.vec_size()); + + // Create a vector of term pointers + vector terms; + + for (auto &term : msg.terms()) { + // The terms were saved in topologically sorted order, so the operands of + // the current term were already loaded and their pointers are in the terms + // vector. Moreover, the serialized indices match the saving/loading order, + // i.e., the index in the terms vector. + + // Check opcode validity + auto op = static_cast(term.op()); + if (!isValidOp(op)) { + throw runtime_error("Invalid op encountered"); + } + + // Create a new term and set its operands (already loaded) + terms.emplace_back(obj->makeTerm(op)); + for (auto &operandIdx : term.operands()) { + terms.back()->addOperand(terms.at(operandIdx)); + } + + // Load attributes for this term + for (auto &attribute : term.attributes()) { + terms.back()->loadAttribute(attribute); + } + } + + // Set the inputs + for (auto &in : msg.inputs()) { + obj->inputs.emplace(in.name(), terms.at(in.term())); + } + + // Set the outputs + for (auto &out : msg.outputs()) { + obj->outputs.emplace(out.name(), terms.at(out.term())); + } + + return obj; +} + +} // namespace eva diff --git a/eva/serialization/known_type.cpp b/eva/serialization/known_type.cpp new file mode 100644 index 0000000..a249206 --- /dev/null +++ b/eva/serialization/known_type.cpp @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "eva/serialization/known_type.h" + +using namespace std; + +namespace eva { + +namespace { + +inline void dispatchKnownTypeDeserialize(KnownType &obj, + const msg::KnownType &msg) { + // Try loading msg until the correct type is found + EVA_KNOWN_TYPE_TRY_DESERIALIZE(msg::Program); + EVA_KNOWN_TYPE_TRY_DESERIALIZE(msg::CKKSParameters); + EVA_KNOWN_TYPE_TRY_DESERIALIZE(msg::CKKSSignature); + EVA_KNOWN_TYPE_TRY_DESERIALIZE(msg::SEALValuation); + EVA_KNOWN_TYPE_TRY_DESERIALIZE(msg::SEALPublic); + EVA_KNOWN_TYPE_TRY_DESERIALIZE(msg::SEALSecret); + + // This is not a known type + throw runtime_error("Unknown inner message type " + + msg.contents().type_url()); +} + +} // namespace + +KnownType deserialize(const msg::KnownType &msg) { + KnownType obj; + dispatchKnownTypeDeserialize(obj, msg); + return obj; +} + +} // namespace eva diff --git a/eva/serialization/known_type.h b/eva/serialization/known_type.h new file mode 100644 index 0000000..e6d1203 --- /dev/null +++ b/eva/serialization/known_type.h @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include "eva/ckks/ckks_parameters.h" +#include "eva/ir/program.h" +#include "eva/seal/seal.h" +#include "eva/serialization/known_type.pb.h" +#include "eva/util/overloaded.h" +#include +#include +#include + +#define EVA_KNOWN_TYPE_TRY_DESERIALIZE(MsgType) \ + do { \ + if (msg.contents().Is()) { \ + MsgType inner; \ + if (!msg.contents().UnpackTo(&inner)) { \ + throw std::runtime_error("Unpacking inner message failed"); \ + } \ + obj = deserialize(inner); \ + return; \ + } \ + } while (false) + +namespace eva { + +// Represents any serializable EVA object +using KnownType = + std::variant, std::unique_ptr, + std::unique_ptr, std::unique_ptr, + std::unique_ptr, std::unique_ptr>; + +KnownType deserialize(const msg::KnownType &msg); + +} // namespace eva diff --git a/eva/serialization/known_type.proto b/eva/serialization/known_type.proto new file mode 100644 index 0000000..1c4e3b6 --- /dev/null +++ b/eva/serialization/known_type.proto @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +syntax = "proto3"; + +package eva.msg; + +import "google/protobuf/any.proto"; + +message KnownType { + google.protobuf.Any contents = 1; + string creator = 2; +} diff --git a/eva/serialization/save_load.cpp b/eva/serialization/save_load.cpp new file mode 100644 index 0000000..f01a21f --- /dev/null +++ b/eva/serialization/save_load.cpp @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "eva/serialization/save_load.h" + +using namespace std; + +namespace eva { + +KnownType load(istream &in) { + msg::KnownType msg; + if (msg.ParseFromIstream(&in)) { + return deserialize(msg); + } else { + throw runtime_error("Could not parse message"); + } +} + +KnownType loadFromFile(const string &path) { + ifstream in(path); + if (in.fail()) { + throw runtime_error("Could not open file"); + } + return load(in); +} + +KnownType loadFromString(const string &str) { + msg::KnownType msg; + if (msg.ParseFromString(str)) { + return deserialize(msg); + } else { + throw runtime_error("Could not parse message"); + } +} + +} // namespace eva diff --git a/eva/serialization/save_load.h b/eva/serialization/save_load.h new file mode 100644 index 0000000..30ab1f6 --- /dev/null +++ b/eva/serialization/save_load.h @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include "eva/serialization/known_type.h" +#include "eva/version.h" +#include +#include +#include +#include + +namespace eva { + +KnownType load(std::istream &in); +KnownType loadFromFile(const std::string &path); +KnownType loadFromString(const std::string &str); + +template T load(std::istream &in) { return std::get(load(in)); } + +template T loadFromFile(const std::string &path) { + return std::get(loadFromFile(path)); +} + +template T loadFromString(const std::string &str) { + return std::get(loadFromString(str)); +} + +namespace detail { +template void serializeKnownType(const T &obj, msg::KnownType &msg) { + auto inner = serialize(obj); + msg.set_creator("EVA " + version()); + msg.mutable_contents()->PackFrom(*inner); +} +} // namespace detail + +template void save(const T &obj, std::ostream &out) { + msg::KnownType msg; + detail::serializeKnownType(obj, msg); + if (!msg.SerializeToOstream(&out)) { + throw std::runtime_error("Could not serialize message"); + } +} + +template void saveToFile(const T &obj, const std::string &path) { + std::ofstream out(path); + if (out.fail()) { + throw std::runtime_error("Could not open file"); + } + save(obj, out); +} + +template std::string saveToString(const T &obj) { + msg::KnownType msg; + detail::serializeKnownType(obj, msg); + std::string str; + if (msg.SerializeToString(&str)) { + return str; + } else { + throw std::runtime_error("Could not serialize message"); + } +} + +} // namespace eva diff --git a/eva/serialization/seal.proto b/eva/serialization/seal.proto new file mode 100644 index 0000000..1465aed --- /dev/null +++ b/eva/serialization/seal.proto @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +syntax = "proto3"; + +package eva.msg; + +import "eva.proto"; + +message SEALObject { + enum SEALType { + UNKNOWN = 0; + CIPHERTEXT = 1; + PLAINTEXT = 2; + SECRET_KEY = 3; + PUBLIC_KEY = 4; + GALOIS_KEYS = 5; + RELIN_KEYS = 6; + ENCRYPTION_PARAMETERS = 7; + } + SEALType seal_type = 1; + bytes data = 2; +} + +message SEALPublic { + SEALObject encryption_parameters = 1; + SEALObject public_key = 2; + SEALObject galois_keys = 3; + SEALObject relin_keys = 4; +} + +message SEALSecret { + SEALObject encryption_parameters = 1; + SEALObject secret_key = 2; +} + +message SEALValuation { + SEALObject encryption_parameters = 1; + map values = 2; + map raw_values = 3; +} diff --git a/eva/serialization/seal_serialization.cpp b/eva/serialization/seal_serialization.cpp new file mode 100644 index 0000000..59532ae --- /dev/null +++ b/eva/serialization/seal_serialization.cpp @@ -0,0 +1,232 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "eva/seal/seal.h" +#include "eva/util/overloaded.h" +#include +#include +#include + +using namespace std; + +namespace eva { + +using SEALObject = msg::SEALObject; + +template auto getSEALTypeTag(); + +template <> auto getSEALTypeTag() { + return SEALObject::CIPHERTEXT; +} + +template <> auto getSEALTypeTag() { + return SEALObject::PLAINTEXT; +} + +template <> auto getSEALTypeTag() { + return SEALObject::SECRET_KEY; +} + +template <> auto getSEALTypeTag() { + return SEALObject::PUBLIC_KEY; +} + +template <> auto getSEALTypeTag() { + return SEALObject::GALOIS_KEYS; +} + +template <> auto getSEALTypeTag() { + return SEALObject::RELIN_KEYS; +} + +template <> auto getSEALTypeTag() { + return SEALObject::ENCRYPTION_PARAMETERS; +} + +template void serializeSEALType(const T &obj, SEALObject *msg) { + // Get an upper bound for the size from SEAL; use default compression mode + auto maxSize = obj.save_size(seal::Serialization::compr_mode_default); + + // Set up a buffer (std::string) + // We allocate the string into a std::unique_ptr and eventually pass ownership + // to the Protobuf message below + auto data = make_unique(); + data->resize(maxSize); + + // Note, since C++11 std::string is guaranteed to be contiguous + auto actualSize = + obj.save(reinterpret_cast(&data->operator[](0)), + maxSize, seal::Serialization::compr_mode_default); + data->resize(actualSize); + + // Change ownership of the data string to msg + msg->set_allocated_data(data.release()); + + // Set the type tag to indicate the SEAL object type + msg->set_seal_type(getSEALTypeTag()); +} + +template void deserializeSEALType(T &obj, const SEALObject &msg) { + // Unknown type; throw + if (msg.seal_type() == SEALObject::UNKNOWN) { + throw runtime_error("SEAL message type set to UNKNOWN"); + } + + // Type of obj is incompatible with the type indicated in msg + if (msg.seal_type() != getSEALTypeTag()) { + throw runtime_error("SEAL message type mismatch"); + } + + // Load the SEAL object + obj.load(reinterpret_cast(msg.data().c_str()), + msg.data().size()); +} + +template +void deserializeSEALTypeWithContext(const seal::SEALContext &context, T &obj, + const SEALObject &msg) { + // Most SEAL objects require the SEALContext for safe loading + // Unknown type; throw + if (msg.seal_type() == SEALObject::UNKNOWN) { + throw runtime_error("SEAL message type set to UNKNOWN"); + } + + // Type of obj is incompatible with the type indicated in msg + if (msg.seal_type() != getSEALTypeTag()) { + throw runtime_error("SEAL message type mismatch"); + } + + // Load the SEAL object and check its validity against given context + obj.load(context, + reinterpret_cast(msg.data().c_str()), + msg.data().size()); +} + +unique_ptr deserialize(const msg::SEALValuation &msg) { + // Deserialize a SEAL valuation: either plaintexts or ciphertexts + // First need to load the encryption parameters and obtain the context + seal::EncryptionParameters encParams; + deserializeSEALType(encParams, msg.encryption_parameters()); + auto context = getSEALContext(encParams); + + // Create the destination valuation and load the correct type + auto obj = make_unique(encParams); + for (const auto &entry : msg.values()) { + auto &value = obj->operator[](entry.first); + + // Load the correct kind of object based on value + switch (entry.second.seal_type()) { + case SEALObject::CIPHERTEXT: { + value = seal::Ciphertext(); + deserializeSEALTypeWithContext(context, get(value), + entry.second); + break; + } + case SEALObject::PLAINTEXT: { + value = seal::Plaintext(); + deserializeSEALTypeWithContext(context, get(value), + entry.second); + break; + } + default: + throw runtime_error("Not a ciphertext or plaintext"); + } + } + + // Deserialize the raw part of the valuation + for (const auto &entry : msg.raw_values()) { + obj->operator[](entry.first) = deserialize(entry.second); + } + + return obj; +} + +unique_ptr serialize(const SEALValuation &obj) { + // Create the Protobuf message and save the encryption parameters + auto msg = make_unique(); + serializeSEALType(obj.params, msg->mutable_encryption_parameters()); + // Serialize a SEAL valuation: either plaintexts or ciphertexts + auto &valuesMsg = *msg->mutable_values(); + auto &rawValuesMsg = *msg->mutable_raw_values(); + for (const auto &entry : obj) { + // Visit entry.second with an overloaded lambda function; we need to specify + // handling for both possible data types (plaintexts and ciphertexts) + visit(Overloaded{[&](const seal::Ciphertext &cipher) { + serializeSEALType(cipher, &valuesMsg[entry.first]); + }, + [&](const seal::Plaintext &plain) { + serializeSEALType(plain, &valuesMsg[entry.first]); + }, + [&](const std::shared_ptr raw) { + raw->serialize(rawValuesMsg[entry.first]); + }}, + entry.second); + } + + return msg; +} + +unique_ptr serialize(const SEALPublic &obj) { + // Serialize a SEALPublic object + auto msg = make_unique(); + + // Save the encryption parameters + serializeSEALType(obj.context.key_context_data()->parms(), + msg->mutable_encryption_parameters()); + + // Save the different public keys + serializeSEALType(obj.publicKey, msg->mutable_public_key()); + serializeSEALType(obj.galoisKeys, msg->mutable_galois_keys()); + serializeSEALType(obj.relinKeys, msg->mutable_relin_keys()); + + return msg; +} + +unique_ptr deserialize(const msg::SEALPublic &msg) { + // Deserialize a SEALPublic object + // Load the encryption parameters and acquire a SEALContext; this is needed + // for safe loading of the other objects + seal::EncryptionParameters encParams; + deserializeSEALType(encParams, msg.encryption_parameters()); + auto context = getSEALContext(encParams); + + // Load the different public keys + seal::PublicKey pk; + deserializeSEALTypeWithContext(context, pk, msg.public_key()); + seal::GaloisKeys gk; + deserializeSEALTypeWithContext(context, gk, msg.galois_keys()); + seal::RelinKeys rk; + deserializeSEALTypeWithContext(context, rk, msg.relin_keys()); + + return make_unique(context, pk, gk, rk); +} + +unique_ptr serialize(const SEALSecret &obj) { + // Serialize a SEALSecret object + auto msg = make_unique(); + + // Save the encryption parameters + serializeSEALType(obj.context.key_context_data()->parms(), + msg->mutable_encryption_parameters()); + + // Save the secret key + serializeSEALType(obj.secretKey, msg->mutable_secret_key()); + return msg; +} + +unique_ptr deserialize(const msg::SEALSecret &msg) { + // Deserialize a SEALSecret object + // Load the encryption parameters and acquire a SEALContext; this is needed + // for safe loading of the other objects + seal::EncryptionParameters encParams; + deserializeSEALType(encParams, msg.encryption_parameters()); + auto context = getSEALContext(encParams); + + // Load the secret key + seal::SecretKey sk; + deserializeSEALTypeWithContext(context, sk, msg.secret_key()); + + return make_unique(context, sk); +} + +} // namespace eva diff --git a/eva/util/CMakeLists.txt b/eva/util/CMakeLists.txt new file mode 100644 index 0000000..55f959e --- /dev/null +++ b/eva/util/CMakeLists.txt @@ -0,0 +1,12 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +if(USE_GALOIS) + target_sources(eva PRIVATE + galois.cpp + ) +endif() + +target_sources(eva PRIVATE + logging.cpp +) diff --git a/eva/util/galois.cpp b/eva/util/galois.cpp new file mode 100644 index 0000000..0cabfba --- /dev/null +++ b/eva/util/galois.cpp @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "eva/util/galois.h" + +namespace eva { + +GaloisGuard::GaloisGuard() { + // Galois doesn't exit quietly, so lets just leak it instead. + // It was also crashing on exit when this decision was made. + static galois::SharedMemSys *galois = new galois::SharedMemSys(); +} + +} // namespace eva diff --git a/eva/util/galois.h b/eva/util/galois.h new file mode 100644 index 0000000..0531144 --- /dev/null +++ b/eva/util/galois.h @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include +#include + +namespace eva { + +struct GaloisGuard { + GaloisGuard(); +}; + +} // namespace eva diff --git a/eva/util/logging.cpp b/eva/util/logging.cpp new file mode 100644 index 0000000..5195efa --- /dev/null +++ b/eva/util/logging.cpp @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "eva/util/logging.h" +#include +#include +#include +#include + +namespace eva { + +int getUserVerbosity() { + static int userVerbosity = 0; + static bool parsed = false; + if (!parsed) { + if (const char *envP = std::getenv("EVA_VERBOSITY")) { + auto envStr = std::string(envP); + try { + userVerbosity = std::stoi(envStr); + } catch (std::invalid_argument e) { + std::transform(envStr.begin(), envStr.end(), envStr.begin(), ::tolower); + if (envStr == "silent") { + userVerbosity = 0; + } else if (envStr == "info") { + userVerbosity = (int)Verbosity::Info; + } else if (envStr == "debug") { + userVerbosity = (int)Verbosity::Debug; + } else if (envStr == "trace") { + userVerbosity = (int)Verbosity::Trace; + } else { + std::cerr << "Invalid verbosity EVA_VERBOSITY=" << envStr + << " Defaulting to silent.\n"; + userVerbosity = 0; + } + } + } + parsed = true; + } + return userVerbosity; +} + +void log(Verbosity verbosity, const char *fmt, ...) { + if (getUserVerbosity() >= (int)verbosity) { + printf("EVA: "); + va_list args; + va_start(args, fmt); + vprintf(fmt, args); + va_end(args); + printf("\n"); + fflush(stdout); + } +} + +bool verbosityAtLeast(Verbosity verbosity) { + return getUserVerbosity() >= (int)verbosity; +} + +void warn(const char *fmt, ...) { + fprintf(stderr, "WARNING: "); + va_list args; + va_start(args, fmt); + vfprintf(stderr, fmt, args); + va_end(args); + fprintf(stderr, "\n"); + fflush(stderr); +} + +} // namespace eva diff --git a/eva/util/logging.h b/eva/util/logging.h new file mode 100644 index 0000000..39752e1 --- /dev/null +++ b/eva/util/logging.h @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include +#include +#include + +namespace eva { + +enum class Verbosity { + Info = 1, + Debug = 2, + Trace = 3, +}; + +void log(Verbosity verbosity, const char *fmt, ...); +bool verbosityAtLeast(Verbosity verbosity); +void warn(const char *fmt, ...); + +} // namespace eva diff --git a/eva/util/overloaded.h b/eva/util/overloaded.h new file mode 100644 index 0000000..a1013fc --- /dev/null +++ b/eva/util/overloaded.h @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include + +namespace eva { + +// The "Overloaded" trick to create convenient overloaded function objects for +// use in std::visit. +template struct Overloaded : Ts... { + // Bring the various operator() overloads to this namespace + using Ts::operator()...; +}; + +// Add a user-defined deduction guide for the class template +template Overloaded(Ts...) -> Overloaded; + +} // namespace eva diff --git a/eva/version.cpp b/eva/version.cpp new file mode 100644 index 0000000..9c21f36 --- /dev/null +++ b/eva/version.cpp @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "version.h" + +using namespace std; + +namespace eva { + +string version() { return EVA_VERSION_STR; } + +} // namespace eva diff --git a/eva/version.h b/eva/version.h new file mode 100644 index 0000000..92fdbf2 --- /dev/null +++ b/eva/version.h @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include + +namespace eva { + +std::string version(); + +} diff --git a/examples/.gitignore b/examples/.gitignore new file mode 100644 index 0000000..7956c43 --- /dev/null +++ b/examples/.gitignore @@ -0,0 +1,8 @@ +*.eva +*.evaparams +*.evasignature +*.sealpublic +*.sealsecret +*.sealvals +*_encrypted.png +*_reference.png \ No newline at end of file diff --git a/examples/baboon.png b/examples/baboon.png new file mode 100644 index 0000000000000000000000000000000000000000..bb3cd656286b8461fc3131bd5e486d884884d9e3 GIT binary patch literal 11917 zcmV;8E^^U{P)00001b5ch_0Itp) z=>Px#1ZP1_K>z@;j|==^1poj532;bRa{vIiivR$)ivgap9Z&!O00(qQO+^RK1_cid z2qfLI5dZ)H7<5HgbW?9;ba!ELWdLwtX>N2bZe?^JG%hhNHDpIvQUCxg>q$gGRA@sL zw`XuwXWA{OUsE+TQ*&?Ed~?74+Sj4owz0v0NkS5!tdp}&&N=6tbJU4CQAavaIU|G+ z2t<-V1{>3E2Qo-R8%z)g?CzfVF}hXz$6NcyuC?B^*7K}gLqmN*LT|^yc9`|ePG_4x z+NRd*N#r1Ace~aG)Hq?bF1pUw=}d_IJ^Dz+tqbdTRxwr2Wt#a4mq2XeYXWL*)GBuC zI3@{IDQ2o97$Kc$(v#Fej$4D^Q}}@{u|KHq$+0X}uFKBR>qS;QQ!l6MRvhVGd8UUO%drx9wk^cZ7r1gAP@w_U zbhdxwGS}L@z*xN1=4x^JTeAJ6cn{DXfu;tzDi_`mW9oufmICfC+2tOYChTTAw33L6 zspp7oTqGVYvg@cS5t4!gq5v2k4adi!C=dv?140G>3GGlC0Kvjz#XO9JOLtpH3OQLO zz>B$fnSi90F--=#$t(?oP3fYqJmx8sm63qlDzb#s8i&H2u;^UOSf44{ElTAS17oV5 zA$fjSYxk21C5F|5lqwr*Fu)JVxGfq}V=@g5r%-`zaNj^@tU$EK5Wy_P9wqxzOt~E? zF?LG)3Ps2!^t-rD3tb~;>18Me8OgvQa1a<4LRYBRCNrKx?&LB+XgEeJgfK8D0T0fj z(5yCbdek(2(mOS4DD^XRI;M6LbWoQmI!lAD@{l#1)pw`Gfv_r(l6JZ28X-?Z6{=WL zqu3D;`V*@BfTp`B_e4pt3_sT|4|Y==9*o8gHTn@Xc=TsF^>&Tspwka@hY-C36n7Y6 z@U@0}D7FaNoo28k?NS|x#BL`Du?!`PDCe4}Fv<~BEKbZM%Xl5|CJd885vwIm zzdDk(mM3-T0#?jW##8c4#om9!S)4NG`muZ(oP%itHPiGOyhMx>G7%Iq2#Y{ckr+M^ zL58Bqa3mg#qaiwpEi?vHtYV5}NRHrQhPxZpH_B34p*&R^3cZIRgb~=yP-HCvcK`&b1Hu{ru!g4Q zo$Y}A08lLqQrq6Ly}fm7^PwFOB#3FUF^v|wTuwD>+c9lem6~BR;Dj_J2?&Di1tMzU z#8w0a+6t*_2JHjFTOf!wFl6&Pd)k59!LYq;?Yn@mgD5;0gK49Zni&F!E6DLB1o;73 zwvUr9b2Ecli-!^lfpiwI!HcUgsP}W(dtHutqoFmHBos#&T?t6uCvtAi}NyV)PdWqKL0K9(auG~4h5A_&xS5D08*1GIo3txbT&hQ>WRwr<6B z)+fn-4r1OnL;hj~eE{xg>_CDKHSgQ^)s}|6TN@AV0HY70xJZ-$2_ZBcJhYRBZ`JD= zWF>XazCD=`;6m^}U6OqT(f@rm@IMr8{=tSlYy%t30K-WjXlE0Vk7WudQllu8wz?B? zUsBMOlw=Fa!iY9o=JpH;%{FEz!B837YOIEaR0`$sG&voJTpw|0k~(^rF)(FJmF4lG zLa4+NiO^0W0*q;IZrqC`gTdquFs6Cm-d)|U=F8gulh^d24f)9;@4u{a?A|@Qc_tPZ zzJKSI4-VCT*3|K76L8xh=+=X6TS=XbY7Ru?pw~8R`{J`ty4;}ey8io^;lIu4{&d#- zzSJygY(5aq`)o-oj!!1Z*bKEqPqNHDbSO@)PReaAs@ejV zX+bp_`F@A>fYAU5MG$6Nd#0CF8s_FJ!s3v;J0*yuG&Y|~>oCw|EG(Zu)k$0#7u`tN zw(V0gtN9zv|MrvRfB(V!f6U81fPh-R{N#NYrj04=Y~B0Ot}p-G()uwBwH*lE(%kwt zF1lWAra&o8XhiLoUw+UCt}PkPBmLC@vCsgs0L2bvIU7eNy!QchgxJ|pYx(1u2Ih8~vyMXClYTvLl zmBaS-2@*+Kpo`%QN^M??)}vz?R7{(}-V>3SXbrHstvf#hw0vSi|H+JfuYJ#!tslMD z(!5PTIY^aKQ1FHW+uv)he-DS>jRyUVORiIjQ4~F`libMXgLm)$a_8<3w|@G;?k(>( zw$>elBj`qvD{c$+_?S`}hKW|VbTWrh=g|f8&RDM{JL*pl86#yw{&1u`7YL;}siMH* z=Q_G50&!!FQn<@*Z}9|LZDwGhNc4BX#tsYe1!642$`1&8`lVJM*B!T+{06Q`B68@N z4v8V5C94T=A$m_oQ(Z^pqpkMJ1y(dG~>d`yxd?ZOk zXn^jeDJiCem*dottws=+4wneYN{Py64s`qd=^%+u?_}bsn_JGN5*m0ylRQ!gBnHFL zl6_>>(Ld#nmDH69RrjDioZ|%I6lbVYtl|(m+b^t)h+{ph zfe}e*L}d>$oB@R`WY9R|GMkENl+(>(zLv>TQCLO}MNVNGI6^I7VK%CKE{d9W2-&f( z?h75Mg?8v84;vgK?(PgGoj>&E=SYHB`)Yqp>v{ zgjg)CnLIOB;v{-#Jr!oS{DH*l|~O z#1<99H zQx{;Z4P7?VpnxFW`(BSmagqAjDM4$OirLVym0|E!t~?t5_HN?z4gKs@{?sM<$hXAA zJdAIKQ0Q2Zj3^*6Wi*M2X$x!JIcKQH94(q7MN@v#lbi6D$BorlSG1(_WQC3p!x3j{ zHoK67-{7``bc%L|1Cofrd_Gj6$Vg>2+Y)5_h@vOY)@exH8I{}3muqPl5&%m<;lz{{ z1PtytOQ`zmKCLBo$!}))P}QW34^*f zP@0*#ayz>CQ{uuS*UGy5!XK=K7yQC;rrFaeP(cY8l7dSS;{|H0*u+xV_>QC@*>85G z&Mn4;gSm;+s(pqcRahCP$4V zI8G-m5ofx*3?a8uCPPBeO#p01V8CfOi%;camwm^tS1S$SOT8|a1a zysF_|rPTMrKmI|&dPVs37H;cJH8;mT5T5;e?&6*N%Kg;k-?}cnQXao6F3rm)zZ1^h zXAfM(sm)BYjv`|)q@8jT(HNB5;!0mun;39sCY`AfTXe`cH0O>ETYSCJNR_7z675ND zjZ{!C;MB{-0Ij|~6s5+KjKKkxP6OSHLA0A~wlZQ-hSktYA~ea=G=WZxm&;(vB=_h; zQTa6c*dzV<=k9ZlybHI6R_={_^Ju&>yn@)iW)R&(eKvvk>{}u01$FyN>if49#ABqU zO#+3lS4)>4m9PBjUc99^@}uD7Q|bA4{1d-&yQ?_4l;UPYbJ2NL#c@u*rbhjDbl9jrdv&N}$(6FI2>- z9=6-VGMexvJKND^G6xk@Aq{{*AgmST*%y|RtDMolTh2a>EImq{zgszf^VoOyXDfxP zojab(XseL-et`V-DHp$i``at@|NNDWx+Ek-4*$LH>$QQ)4@(zUoeMY3OY8DY^z%y~5 z)N|*#Ura}@8%{k8E-VSQFStB%?8b>u)i;w z90~Y)t;MN8b=sS%I=YMEa8~F|3VdmaEsT@c(KT#(GmFxy5H-vB%>fr&DF^x-c$>D} zq=Tj6jP9h+9TK=BI+Ii73>&yg5lJ0krOtAXKewL!C3=32Y-FLGt|C_}6zA zQsI}Z)|d2-JE#xVNxNV0I$sk%eS>}P84tB8CdQ848UOad#5eac7ak=qz3^RF*PeWA zKKaPF^jg03j#^n31Y)x8oHmt>j#d2Cp7>BtS1D2)&-9E$V||Y7pebI|be9x{0L7aW znPW`33tJ=P0XWQdsh~wK2a9=dj|r;R;0(sjt{~Cu<5ANmR~Not>%X{K`1+T~g%@2*Pfhdp{TE+pm*4TvZ19FIGJ;8w&&7{~yp^JB zupA!FSB?}?6Wz&bATyGRRXnNw&4VjkcIYA!u>&PAx7EnF0F4xCRzrm>h{M6|$?~FI zESG~63bIQ@X(B1P|?6amuN}m3n-kEZ8GUHF`N`nSl*bwf~o4kBi zLMSq$2$IcT)214&5NJ|>DVRn9v(00}T8tR87MqH*3MHW{zzz60PB&8_>(o2sDx09S z147FEWWBJ?nYybw_uPB_N%Z`KSb;{^H}x z?D6R$dTrd=JrU1Lr!ymw!f3!>hy-(bi=S_cv5j#m)zF3(9IO#=pnM+EWJWUa2ieRP zqY0^hnU-I3$9|Td`CmCsU3i#0dv9Q6 zExG(SdH&(mkn1%Y{SN-YI`ot0bksw=-C2|?!`N~i;t5lzxh_yl_%Epi%&Nv60J0Zl3{6P zp1^N(p=F;gvBo+uY6bVDCw)3wx-7g*Zhxh!iJr^JLEZ;l4 za((pDYUIqr8#em}x#J!9zcv6L-seiM*;7v?q*v4}uL$Uqj0;Oy^$6M}PEgPNeg%sHq+V7mbnOX58!p!y2J{{1&u>6mZp$s- zW@esoQLj;-zeaqxN`NI!-A|l-Jb(FaA~$u-Mle!V&8arAlJIh!qQu_x8`q)KLGE4*yBu2Wzo(sc-#le%dkl-?xJYVd%T zdLFwOaNtWE@gSGq4nZ8G^I%LC9E)$GP}&K&_STkzEzF2u_$SK9s&o0JYxbI9^5@}8 z_XjTD&!4%uc;#;I`L#P5)oUj7ci?}#hHPDxSiZ|1{Z&AGgZ=ys@~?NP*wo^K?77D$ zF5l@rb9eOX2mN0^2`@b}kNucmeqvvGEn0Yp6fru;5VoH0h-$<^r7EB^#toLNjc?&9 zyb_h4&DC^@EJTrw!m<#=A+gBHs`052YB8Qk>!4Cv@z^E;;UEHYK%xW^sC676NGQV) zafdoucf!DR2lwwgh&9PiJjbQK)1Fz6Ej_kP|6E#L?f-f$zIb=@UUTQx9x$ozkXzpY z|L054t~IgsYGUFEpZ*5@`5V}$x0#I0=?8-s?@uk?tSqgbx^i#w@*~&m&ESz+nTt<1 zEhIkj1CW6RvET?Mm~UoiA_}obtP3ct-8#9ODX}v+Dx^S-CJR~#JfOfXU>eCB9j!(u z$9UXwtC1<>bdV{BxSRucViOj*4~5#(0osSbY<8apIBYi(dSLsuFSqQ+YZu?Jax0wq z=iN(Bx=-9noV;`N^3CL#dxtOF$}c@wWf6Y|{pV%|UVwL9*QKuX%sgddHxQq`hHt&a zW+qNQocMaJ=lJ!}i>t>j-z_aaunvEpI&(j~{F~?EI;*k*VPo)eA zkUQEt_L1m-L$zOawC-qc*;8A)^S!@($vXVhGPhadZideOQaF3pKYOFPbhmo`_TcFo z{TJ45@QD9x|8Kvy|LHft&P#^A<<#6Az2+_K<2SIs{UB2(79UPryq7t3Z|vgQkxOgc zbJwCX*YX#CF`r-8EIy|ePQz$GoQ#5zz-d}K+r(fQF)THLE<&>9Xet*dw$X)VI#CE^ zX()Ieh$w{;1#LAGz3@VUQt4BqW6V&1Y`0)^TBuynPGj!JlMVqozG`XQiorJ{5Y6=m zcfa@N_o4n%_QfaM?1k{@$HV9EMUPy|pIc3zxtBh6ZE|JpTa)FF_CIZO{O2q1wjb27 z3!_Uv_6)p5f3X4i>t#nUzOZ)q!tM0p{fUdKvzPA$NB%x=el@%DoBr$@(eyoXZVX84 z08?O0r;1@0F-=^i8p~Fr=>jlEhhu69R56TW#FL~bJO_m40#IBqlGIvrxJsBCp_a1L za(8Ecnw(Ej6H#m=2qzNu)YWd=v->juU}x*0t*vdfd-i?t(MKP*C@Sic&$PwM&Jz#E zFWgBU`LTHB_R!Mp@X=dS7w(;j7ykf!umSw!1+?~CL+Vmu;-;(X4WZ#R_={t?nf%$+ z(X*S1t<9|5%PrhY%wC`SW-WOBg=6WBc;XgNqHcgTfam}i1IBTPSSEqUDP*azTpfvH z#$y>BECZh7;80~~k_drh0x;BO_@;%%6VPKNrLlS6 zp#yu{o4?q%Yim=>zCHVP{q^Hd>aZ5}YBzu%z5slAldrpM$voim-hw}P4s94Xc4PFyz2w}D@r$c-U*8Q({A2dw?YVCr zgwDUv9(^R6{1HIvY=ZCaAcL_=GMoXy36LDcrUE#gmOv3g>EceZ6oum>8Cn)aipKF! zD10lN*j_WAZ96l;I8tUE?iZF*RIeSzVm4whdyx40L(Ti@_SZJI?W(W+tYP0*^=mcVYV_K~#k=F@u7xMA9y@<+`rM7`@oSR{KR?GezXp8py!oTs81nZT z-wRU5+m657rqZfs*EXRCH-);mI=pl*Jo)3XOLsC$PrXYo6=OH)#gp}*hB{C^fQ|yO za6~PSuA-A9Xo?U{6~ZYJ6kUwMv4Lbc8ZX!!L|~Etf+RP>a1AvVh9MV5N#{l=b7ev? zf+)n$K7Er)dZ?{s$AS7SwYzui+qGr;ryuRvy<>mfo{!%DXm>jn6P!1kd}<#2F1oNf zbM9Jj_)2m1$Ku?z%;6tL&)&YpqrPtY^NZ$BZ&Ueq*rM0)?eALOyDm{?PTVV=xE7zi zIkt2!HGe%kc6I8~edpQd+Nm4X;}4OJ?%gfB_cYhFlhGgs94}{aTndVYN@N2V3IbO} zAaFrs0i2;G5Eu{|A3|ielQ^w7T2syP0DO56yI91W&OuIO+e`jduN|QfgP`DT&4+f^ z)$Z7~EeH%bdPJ^hzP&flCoe={}n^YHvn zx5Vt{fKQ&aetv@?Srth)P}|?MetN~1E-l;&R9CVmSI3rD4Lzre$FFyv`$aeZSXaKR zn19ql*EYZ*jh*l|92`IdVZ~Iuf=QH6@N58_)W(qFIchRfgkearObLn1K@;fB41P0_ z)l{=ohMn&NEftYx3h41TsNidL8XKI-eQIel1ldx%ciXNV+rIql<9!F}4|X&kgu}P& zZD`?yh_O@VWA`PMZ@L%mS))gLj{G!u;zoMrT6y}LC)k6(HU8yg#}{{)ygLfrI(E-X z_?~agkt3Jx>N95wr`M|IR)zNd{Mm>4lP^T&?{$OM=*24?0`unJ1L28Kd?$iRCChj$ zwUi;_64(GNvyCA{(FG8J3MbT2dFoCarM;8XfhX5dScht^j3K@m!mX6iiy3HIvoCG0 z?+t@vUMK=m`{fs3v~@H$x3)l0FdT;pM8oO;9lILq_5d)Qo=IxoRbg^on7QN`x@=0F z&L6p+o&EWUH~ARX@w)ZBb@0|Tw&1ad^BQ0O98>$V-c*>qs)^2|PON&zt_bZVH1r=+nB^>F)?U9D+B4~|t9V=VviRE+x~i> zZC72>{vGvu>oEA8?VU*L0Kq$>82*8zE}8O6B3Grjc(?b+kKY;%FCbrTfIoW!-u6%+ zeojNb>8yPL-}8V;$&Y<)N-Ydtcqk5yi(C`3^ks2o**y9Wgt8ayoF@8)P;wWP$%D{H zZ73jGEI`nRI3kG5M`KxS3=x5+AaaCgwv5CUp#>@eh1S9316Tq`&CipE{xJf$QUP2p zfW932e?C4q{~NA(v&L-zwroK6Jf%@z657|X^{=rF>)5u6v$t^UM*8?IrlDvX zyJRL=$iyFtQc`h7(VVuFD@2Fe`R3S(RFV?Th}2wUtvMNG3n1q z@K^Y@H~9M3$lB{B*Yt^>L!(#4_7QDjk>{ug12Z^ff#f_)+|*N-gi4%nuDGKU48mjZ zVgW-bXDEb3CIUwU;E7<4ghr*q*?b&@3J?g8A`wC?LCR#PnrpMb8z-Q*79rOcF*oO- zKMq0iW}saQQA=PHCSG9Fg$kMUXel{bb|!)(9=(G|#mdc{x)4O@rHdnzBHd47-V-L` zHT0{uurJ?%K3<1^^$LCPITrqm%6vtFz9xV-I_oyzJKrGpJ!DcA0+mqlTkYUAj(3XV zJxVo~G13r4n-JKOEQ6D6w!_&BG>64C8<2NP&^)P%>vT2~Y+P#^qw< zN|Hp0k?Tk`*H1z2oCe-r0Nh!I-n)RiRt5CCnthrUz6c0Jv^t`$(p051-P<>qP4xP( z6!d-oqJyJ_aEugVZVUxkBO)L1xW6$7uSk#$#NG|qju+V0m$=sFWWoyyc75|D_TW48 zp5GC>-jP~1C~a@?hn^5Ssw%%Ma+c$nVY{ZOmVUIN8^Lv8ERjyNQ|3y_Jwbs{Eix&@ zdYvKYlF7QFTuV$W)4;~EwAP$qkz z^E{DTej`&};!w|6;ursmL+3gXyp97tF!do(0`s0A4HXnQ$F)NDj&raEThz6&05&uF61ks`ErcXa`W*<8O_jLRj z)c#|j?T^WgZ|J}cF60raKCDlQ%>&}DNxCV)wB;m5x7-!xJ7P?oL*VktBO!^`DzQj} zGKNsXm1vb@HiM_+O0^t`o~t$r909S_$*H+EPr5OKxw{Bmo!`H9V*l+U&6D0HuY8YI z(Ik;W^La(JVwybMH+{UnI9>|(_VSI!&4!L;NRPl_>!59K!P__Q_$7;WuM--p|NIKR z^%p+vkwAOfS6Z`Xp6G2aRGb^6gFVgrFW}G**qrAS`0s?qe-aP8g?&+Aie$DPWw4JR zwPX2mo-@W&*pv=GLuJ<_depImxi76tdNdBLSWOYCC=#WNBNNNa3Wb`lx5+$7lPfH) z`TGq0$5ZTkr(4#JH{6?PxHVdPBG&8^)TsprT~=IA6gpZ_&y4woheAV>RevhUQ7eFO zESW1ij6^&Ke6<1IxoMb}c>Gh5b!Twe42Nx}6qgj=(KTN919M(Tbn?mwMuib=oH z3+s~+WGRB|4PeK56%%=Gb+8!AX1$55%H>7j=m|9bVaMKAutP71)Ms4HOB!#TK!1$_ z{@k(au7Lduzw;>*xyoZaL)5(?Vc+19za!iK39kJ&=&OGtcl{f-D+9r4y=+I+;POb(xGFwb`X{g>_vS zUCosR?GID9pJq@$PXm7(?)YJ#@k(z~zqQ_{2d8{38Bbd!ikQxIoR|(w9i6RA97*(6 zH3k!rBJ|1)m$C45MB@gcbpw9rExhq9yk!FedI4*`#S=eL=u`m8Gu{1#{CFiZ zSD7_M-7r$MO-`@#(&5cX()UBv}~g8yT@j;%b#~f=GP=J@f{$?;U8zJK&CY z;QhZt8{PqSyg=5jvry}uO>a?cZ!y4iBKCDB;0+f1JLbUu3;x@`fgk+_-#;Rdm`om~ zOvm6T^$s6btn_$Nz~Y*r_zbE-eRWf%zU+;Z}gar0hKpy z3FZU7th;9EsQmIg=j&tAZ;!DrPvgHY)UTu)N6q`ElWn;GxDX+ZC&@EeUfE0=%qOe4 zR5Dcv}d)xZqAFY4=1Nh0m+CKVM`-i_m_WjDIZvbom3ETZo z@W<~EyWb(|{y;SR8@Bh~;4K@FeLr%T4w+VCaN9i*s#vEshf9^AzHnh^^0>}oq;YvY zV-tOo{R5N5%7}Yt+*TRV`rTB6o+Xt~buN|3tFXt-uB0to4%952(4L*MTsUF5a8!I{ z0DQG{@It=zY){i{ytNpGRdcdF!-452VYXzN9_$$^X=`mpR}Go;kTNFfu&Y9X~ErT*aZ+5Wr`s)?d)To1L9E z==@(u*tdxKKOj4QM{e0bHNJ!HeM-e%XW$ERRVvhF*J=ldir$BW|FC(diHWJjLf>3pe?HY~vY5C`NxC}XOZ6ptQn9SFuiu-BQF~(4zO2#jz&LFr zvx#rC3iK|OF{m@g{l>7v6mr%KOm4}W2lv(yhB2qPy_oy91)8iN*n z;A#Lpn6*{o=CNXIq%TtL4QC3K$rGo_n=u=!o?Mtead>97I5b`v9qpeQ>M4(uhen4d z#*5X7Y<{lVKby$T8cc^{>9JUPW^!_VXz=Xp%+g@_c%eK~twbG>N_9wQaVl*#TQKU+ znl85~5v&;)iuYBMlf&-b zs%L!6czlR{G>0tZWs`lRY6?HljVXozNdqKlR#$??k(haEBtJe7o1B{;nH&rFY<;<& za(Cj?aQ{kg@?V)=}swFoKP^SvY9u1>56Cbqvd4AF*q2{_3FFh!bF4}@(^5Bra?>B=ovaA zeRGR#t^%FkqYL@vPFGE~($hPXs160vCC~7nak|K;C_L!nJ6QZ{GOW_PJg-Q8({KZPXylq*!|Px$ulY3NG|{@>}SZ&kc= zD&dK6pkKnQ+VpX?rj+cO7(EsyjJ6?+S*LVr5lt%-HXJyt|4M)d|TCPRl_ zhfyjvp%a8$q{+n9s5eEBt4$`A#VC>+WM)^5zgn=x-2R?ayf2&|iKm8o1NrV)wcMTe zM-sMB$msQP%zCIv-s#ZcIPmQW2VuBRKV9Z8E|eCI_AZ{DnV*|Daq7gvk*O0SgXenF zXM?8G0movbiO;c9;=lkA?HJ3Ki(Jac@s zQqcBf^`%}#zDH$JHrlLcivwY@ViY=zNCcJ1IvIG7Qo%NPjWV@>MW^z`d^(+i$6#x0 znS?o!^k!o2bUfG-_Vg42rAnaM=Z!hs0jsae;PUbfYP4R1GHOu@=6(wwkg&k33F)aB z`>BQe;)(L&*}2naXU?2BeEQhbnYr0Bqr+#%#?Fq8F6VRS3Vr7*eJcZ%%lXvB!Sc%V z;L=#>_(*Afw0P>+^r;iWizh3m=W-{G7bmKO{yuul-x>3ut%d_a;USTvStRPDvws&Ta6J2IRc8H|h!_l^%{#s{ +#include +#include +#include "eva/eva.h" + +namespace py = pybind11; +using namespace eva; +using namespace std; + +const char* const SAVE_DOC_STRING = R"DELIMITER(Serialize and save an EVA object to a file. + +Parameters +---------- +path : str + Path of the file to save to +)DELIMITER"; + +// clang-format off +PYBIND11_MODULE(_eva, m) { + m.doc() = "Python wrapper for EVA"; + m.attr("__name__") = "eva._eva"; + + py::enum_(m, "Op") +#define X(op,code) .value(#op, Op::op) +EVA_OPS +#undef X + ; + py::enum_(m, "Type") +#define X(type,code) .value(#type, Type::type) +EVA_TYPES +#undef X + ; + py::class_>(m, "Term", "EVA's native Term class") + .def_readonly("op", &Term::op, "The operation performed by this term"); + py::class_(m, "Program", "EVA's native Program class") + .def(py::init(), py::arg("name"), py::arg("vec_size")) + .def_property("name", &Program::getName, &Program::setName, "The name of this program") + .def_property_readonly("vec_size", &Program::getVecSize, "The number of elements for all vectors in this program") + .def_property_readonly("inputs", &Program::getInputs, py::keep_alive<0,1>(), "A dictionary from input names to terms") + .def_property_readonly("outputs", &Program::getOutputs, py::keep_alive<0,1>(), "A dictionary from output names to terms") + .def("set_output_ranges", [](const Program& prog, uint32_t range) { + for (auto& entry : prog.getOutputs()) { + entry.second->set(range); + } + }, R"DELIMITER(Affects the ranges of output that the program must accomodate. Sets all +outputs at once. + +The value given here does not directly translate to a supported range of +values, as this only ensures the ranges that coefficients may take in +CKKS's encoded form. Some patterns of values may result in coefficients +that are larger than any of the values themselves. If you see overflow +increasing the value given here will help. + +Parameters +---------- +range : int + The range in bits. Must be positive.)DELIMITER", py::arg("range")) + .def("set_input_scales", [](const Program& prog, uint32_t scale) { + for (auto& source : prog.getSources()) { + source->set(scale); + } + }, R"DELIMITER(Sets the scales that inputs will be encoded at. Sets the scales for all +inputs at once. This value will also be interpreted as the minimum scale +that any intermediate value have. + +Parameters +---------- +scale : int + The scale in bits. Must be positive.)DELIMITER", py::arg("scale")) + .def("to_DOT", &Program::toDOT, R"DELIMITER(Produce a graph representation of the program in the DOT format. + +Returns +------- +str + The graph in DOT format)DELIMITER") + .def("_make_term", &Program::makeTerm, py::keep_alive<0,1>()) + .def("_make_left_rotation", &Program::makeLeftRotation, py::keep_alive<0,1>()) + .def("_make_right_rotation", &Program::makeRightRotation, py::keep_alive<0,1>()) + .def("_make_dense_constant", &Program::makeDenseConstant, py::keep_alive<0,1>()) + .def("_make_uniform_constant", &Program::makeUniformConstant, py::keep_alive<0,1>()) + .def("_make_input", &Program::makeInput, py::keep_alive<0,1>()) + .def("_make_output", &Program::makeOutput, py::keep_alive<0,1>()); + + m.def("evaluate", &evaluate, R"DELIMITER(Evaluate the program without homomorphic encryption + +This function implements the reference semantics of EVA. During your +development process you may check that homomorphic evaluation is +giving results that match the unencrypted evaluation given by this function. + +Parameters +---------- +program : Program + The program to be evaluated +inputs : dict from strings to lists of numbers + The inputs for the evaluation + +Returns +------- +dict from strings to lists of numbers + The outputs from the evaluation)DELIMITER", py::arg("program"), py::arg("inputs")); + + // Serialization + m.def("save", &saveToFile, SAVE_DOC_STRING, py::arg("obj"), py::arg("path")); + m.def("save", &saveToFile, SAVE_DOC_STRING, py::arg("obj"), py::arg("path")); + m.def("save", &saveToFile, SAVE_DOC_STRING, py::arg("obj"), py::arg("path")); + m.def("save", &saveToFile, SAVE_DOC_STRING, py::arg("obj"), py::arg("path")); + m.def("save", &saveToFile, SAVE_DOC_STRING, py::arg("obj"), py::arg("path")); + m.def("save", &saveToFile, SAVE_DOC_STRING, py::arg("obj"), py::arg("path")); + m.def("load", static_cast(&loadFromFile), R"DELIMITER(Load and deserialize a previously serialized EVA object from a file. + +Parameters +---------- +path : str + Path of the file to load from + +Returns +------- +An object of the same class as was previously serialized)DELIMITER", py::arg("path")); + + // CKKS compiler + py::module mckks = m.def_submodule("_ckks", "Python wrapper for EVA CKKS compiler"); + py::class_(mckks, "CKKSCompiler") + .def(py::init(), "Create a compiler with the default config") + .def(py::init>(), R"DELIMITER(Create a compiler with a custom config + +Parameters +---------- +config : dict from strings to strings + The configuration options to override)DELIMITER", py::arg("config")) + .def("compile", &CKKSCompiler::compile, R"DELIMITER(Compile a program for CKKS + +Parameters +---------- +program : Program + The program to compile + +Returns +------- +Program + The compiled program +CKKSParameters + The selected encryption parameters +CKKSSignature + The signature of the program)DELIMITER", py::arg("program")); + py::class_(mckks, "CKKSParameters", "Abstract encryption parameters for CKKS") + .def_readonly("prime_bits", &CKKSParameters::primeBits, "List of number of bits each prime should have") + .def_readonly("rotations", &CKKSParameters::rotations, "List of steps that rotation keys should be generated for") + .def_readonly("poly_modulus_degree", &CKKSParameters::polyModulusDegree, "The polynomial degree N required"); + py::class_(mckks, "CKKSSignature", "The signature of a compiled program used for encoding and decoding") + .def_readonly("vec_size", &CKKSSignature::vecSize, "The vector size of the program") + .def_readonly("inputs", &CKKSSignature::inputs, "Dictionary of CKKSEncodingInfo objects for each input"); + py::class_(mckks, "CKKSEncodingInfo", "Holds the information required for encoding an input") + .def_readonly("input_type", &CKKSEncodingInfo::inputType, "The type of this input. Decides whether input is encoded, also encrypted or neither.") + .def_readonly("scale", &CKKSEncodingInfo::scale, "The scale encoding should happen at") + .def_readonly("level", &CKKSEncodingInfo::level, "The level encoding should happen at"); + + // SEAL backend + py::module mseal = m.def_submodule("_seal", "Python wrapper for EVA SEAL backend"); + mseal.def("generate_keys", &generateKeys, R"DELIMITER(Generate keys required for evaluation with SEAL + +Parameters +---------- +abstract_params : CKKSParameters + Specification of the encryption parameters from the compiler + +Returns +------- +SEALPublic + The public part of the SEAL context that is used for encryption and execution. +SEALSecret + The secret part of the SEAL context that is used for decryption. + WARNING: This object holds your generated secret key. Do not share this object + (or its serialized form) with anyone you do not want having access + to the values encrypted with the public context.)DELIMITER", py::arg("absract_params")); + py::class_(mseal, "SEALValuation", "A valuation for inputs or outputs holding values encrypted with SEAL"); + py::class_(mseal, "SEALPublic", "The public part of the SEAL context that is used for encryption and execution.") + .def("encrypt", &SEALPublic::encrypt, R"DELIMITER(Encrypt inputs for a compiled EVA program + +Parameters +---------- +inputs : dict from strings to lists of numbers + The values to be encrypted +signature : CKKSSignature + The signature of the program the inputs are being encrypted for + +Returns +------- +SEALValuation + The encrypted inputs)DELIMITER", py::arg("inputs"), py::arg("signature")) + .def("execute", &SEALPublic::execute, R"DELIMITER(Execute a compiled EVA program with SEAL + +Parameters +---------- +program : Program + The program to be executed +inputs : SEALValuation + The encrypted valuation for the inputs of the program + +Returns +------- +SEALValuation + The encrypted outputs)DELIMITER", py::arg("program"), py::arg("inputs")); + py::class_(mseal, "SEALSecret", R"DELIMITER(The secret part of the SEAL context that is used for decryption. + +WARNING: This object holds your generated secret key. Do not share this object + (or its serialized form) with anyone you do not want having access + to the values encrypted with the public context.)DELIMITER") + .def("decrypt", &SEALSecret::decrypt, R"DELIMITER(Decrypt outputs from a compiled EVA program + +Parameters +---------- +enc_outputs : SEALValuation + The values to be decrypted +signature : CKKSSignature + The signature of the program the outputs are being decrypted for + +Returns +------- +dict from strings to lists of numbers + The decrypted outputs)DELIMITER", py::arg("enc_outputs"), py::arg("signature")); +} +// clang-format on diff --git a/python/setup.py.in b/python/setup.py.in new file mode 100644 index 0000000..c9ee4c8 --- /dev/null +++ b/python/setup.py.in @@ -0,0 +1,23 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +from setuptools import setup, find_packages +from setuptools.dist import Distribution + +class BinaryDistribution(Distribution): + """Distribution which always forces a binary package with platform name""" + def has_ext_modules(foo): + return True + +setup( + name='eva', + version='${PROJECT_VERSION}', + author='Microsoft Research EVA compiler team', + author_email='evacompiler@microsoft.com', + description='Compiler for the Microsoft SEAL homomorphic encryption library', + packages=find_packages('${CMAKE_CURRENT_BINARY_DIR}'), + package_data={ + 'eva': ['$'], + }, + distclass=BinaryDistribution +) diff --git a/scripts/clang-format-all.sh b/scripts/clang-format-all.sh new file mode 100755 index 0000000..1f05f84 --- /dev/null +++ b/scripts/clang-format-all.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +BASE_DIR=$(dirname "$0") +PROJECT_ROOT_DIR=$BASE_DIR/../ +shopt -s globstar +clang-format -i $PROJECT_ROOT_DIR/eva/**/*.h +clang-format -i $PROJECT_ROOT_DIR/eva/**/*.cpp diff --git a/tests/all.py b/tests/all.py new file mode 100644 index 0000000..46f2615 --- /dev/null +++ b/tests/all.py @@ -0,0 +1,10 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +from bug_fixes import * +from features import * +from large_programs import * +from std import * + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/bug_fixes.py b/tests/bug_fixes.py new file mode 100644 index 0000000..292401f --- /dev/null +++ b/tests/bug_fixes.py @@ -0,0 +1,71 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +import unittest +from common import * +from eva import EvaProgram, Input, Output + +class BugFixes(EvaTestCase): + + def test_high_inner_term_scale(self): + """ Test lazy waterline rescaler with a program causing a high inner term scale + + This test was added for a bug that was an interaction between + rescaling not being inserted (causing high scales to be accumulated) + and parameter selection not handling high scales in inner terms.""" + + prog = EvaProgram('HighInnerTermScale', vec_size=4) + with prog: + x1 = Input('x1') + x2 = Input('x2') + Output('y', x1*x1*x2) + + prog.set_output_ranges(20) + prog.set_input_scales(60) + + self.assert_compiles_and_matches_reference(prog, config={'rescaler':'lazy_waterline'}) + + @unittest.skip('not fixed in SEAL yet') + def test_large_and_small(self): + """ Check that a ciphertext with very large and small values decodes accurately + + This test was added to track a common bug in CKKS implementations, + where double precision floating points used in decoding fail to + provide good accuracy for small values in ciphertexts when other + very large values are present.""" + + prog = EvaProgram('LargeAndSmall', vec_size=4) + with prog: + x = Input('x') + Output('y', pow(x,8)) + + prog.set_output_ranges(60) + prog.set_input_scales(60) + + inputs = { + 'x': [0,1,10,100] + } + + self.assert_compiles_and_matches_reference(prog, inputs, config={'warn_vec_size':'false'}) + + def test_output_rescaled(self): + """ Check that the lazy waterline policy rescales outputs + + This test was added for a bug where outputs could be returned with + more primes in their modulus than necessary, which causes them to + take more space when serialized.""" + + prog = EvaProgram('OutputRescaled', vec_size=4) + with prog: + x = Input('x') + Output('y', x*x) + + prog.set_output_ranges(20) + prog.set_input_scales(60) + + compiler = CKKSCompiler(config={'rescaler':'lazy_waterline', 'warn_vec_size':'false'}) + prog, params, signature = compiler.compile(prog) + self.assertEqual(params.prime_bits, [60, 20, 60, 60]) + +if __name__ == '__main__': + unittest.main() diff --git a/tests/common.py b/tests/common.py new file mode 100644 index 0000000..3882a67 --- /dev/null +++ b/tests/common.py @@ -0,0 +1,36 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +import unittest +from random import uniform +from eva import evaluate +from eva.ckks import CKKSCompiler +from eva.seal import generate_keys +from eva.metric import valuation_mse + +class EvaTestCase(unittest.TestCase): + def assert_compiles_and_matches_reference(self, prog, inputs = None, config={}): + if inputs == None: + inputs = { name: [uniform(-2,2) for _ in range(prog.vec_size)] + for name in prog.inputs } + config['warn_vec_size'] = 'false' + + reference = evaluate(prog, inputs) + + compiler = CKKSCompiler(config = config) + compiled_prog, params, signature = compiler.compile(prog) + + reference_compiled = evaluate(compiled_prog, inputs) + ref_mse = valuation_mse(reference, reference_compiled) + self.assertTrue(ref_mse < 0.0000000001, + f"Mean squared error was {ref_mse}") + + public_ctx, secret_ctx = generate_keys(params) + encInputs = public_ctx.encrypt(inputs, signature) + encOutputs = public_ctx.execute(compiled_prog, encInputs) + outputs = secret_ctx.decrypt(encOutputs, signature) + + he_mse = valuation_mse(outputs, reference) + self.assertTrue(he_mse < 0.01, f"Mean squared error was {he_mse}") + + return (compiled_prog, params, signature) \ No newline at end of file diff --git a/tests/features.py b/tests/features.py new file mode 100644 index 0000000..843e5a0 --- /dev/null +++ b/tests/features.py @@ -0,0 +1,220 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +import unittest +import tempfile +import os +from common import * +from eva import EvaProgram, Input, Output, save, load + +class Features(EvaTestCase): + def test_bin_ops(self): + """ Test all binary ops """ + + for binOp in [lambda a, b: a + b, lambda a, b: a - b, lambda a, b: a * b]: + for enc1 in [False, True]: + for enc2 in [False, True]: + prog = EvaProgram('BinOp', vec_size = 64) + with prog: + a = Input('a', enc1) + b = Input('b', enc2) + Output('y', binOp(a,b)) + + prog.set_output_ranges(20) + prog.set_input_scales(30) + + self.assert_compiles_and_matches_reference(prog, + config={'warn_vec_size':'false'}) + + def test_unary_ops(self): + """ Test all unary ops """ + + for unOp in [lambda x: x, lambda x: -x, lambda x: x**3, lambda x: 42]: + for enc in [False, True]: + prog = EvaProgram('UnOp', vec_size = 64) + with prog: + x = Input('x', enc) + Output('y', unOp(x)) + + prog.set_output_ranges(20) + prog.set_input_scales(30) + + self.assert_compiles_and_matches_reference(prog, + config={'warn_vec_size':'false'}) + + def test_rotations(self): + """ Test all rotations """ + + for rotOp in [lambda x, r: x << r, lambda x, r: x >> r]: + for enc in [False, True]: + for rot in range(-2,2): + prog = EvaProgram('RotOp', vec_size = 8) + with prog: + x = Input('x') + Output('y', rotOp(x,rot)) + + prog.set_output_ranges(20) + prog.set_input_scales(30) + + self.assert_compiles_and_matches_reference(prog, + config={'warn_vec_size':'false'}) + + def test_unencrypted_computation(self): + """ Test computation on unencrypted values """ + + for enc1 in [False, True]: + for enc2 in [False, True]: + prog = EvaProgram('UnencryptedInputs', vec_size=128) + with prog: + x1 = Input('x1', enc1) + x2 = Input('x2', enc2) + Output('y', pow(x2,3) + x1*x2) + + prog.set_output_ranges(20) + prog.set_input_scales(30) + + self.assert_compiles_and_matches_reference(prog, + config={'warn_vec_size':'false'}) + + def test_security_levels(self): + """ Check that all supported security levels work """ + + security_levels = ['128','192','256'] + quantum_safety = ['false','true'] + + for s in security_levels: + for q in quantum_safety: + prog = EvaProgram('SecurityLevel', vec_size=512) + with prog: + x = Input('x') + Output('y', 5*x*x + 3*x + x<<12 + 10) + + prog.set_output_ranges(20) + prog.set_input_scales(30) + + self.assert_compiles_and_matches_reference(prog, + config={'security_level':s, 'quantum_safe':q, 'warn_vec_size':'false'}) + + @unittest.expectedFailure + def test_unsupported_security_level(self): + """ Check that unsupported security levels error out """ + + prog = EvaProgram('SecurityLevel', vec_size=512) + with prog: + x = Input('x') + Output('y', 5*x*x + 3*x + x<<12 + 10) + + prog.set_output_ranges(20) + prog.set_input_scales(30) + + self.assert_compiles_and_matches_reference(prog, + config={'security_level':'1024','warn_vec_size':'false'}) + + def test_reduction_balancer(self): + """ Check that reductions are balanced under balance_reductions=true """ + + prog = EvaProgram('ReductionTree', vec_size=16384) + with prog: + x1 = Input('x1') + x2 = Input('x2') + x3 = Input('x3') + x4 = Input('x4') + Output('y', (x1*(x2*(x3*x4))) + (x1+(x2+(x3+x4)))) + + prog.set_output_ranges(20) + prog.set_input_scales(60) + + progc, params, signature = self.assert_compiles_and_matches_reference(prog, + config={'rescaler':'always', 'balance_reductions':'false', 'warn_vec_size':'false'}) + self.assertEqual(params.prime_bits, [60, 20, 60, 60, 60, 60]) + + progc, params, signature = self.assert_compiles_and_matches_reference(prog, + config={'rescaler':'always', 'balance_reductions':'true', 'warn_vec_size':'false'}) + self.assertEqual(params.prime_bits, [60, 20, 60, 60, 60]) + + def test_seal_no_throw_on_transparent(self): + """ Check that SEAL is compiled with -DSEAL_THROW_ON_TRANSPARENT_CIPHERTEXT=OFF + + An HE compiler cannot in general work with transparent ciphertext detection + turned on because it is not possible to statically detect all situations that + result in them. For example, x1-x2 is transparent only if the user gives the + same ciphertext as both inputs.""" + + prog = EvaProgram('Transparent', vec_size=4096) + with prog: + x = Input('x') + Output('y', x-x+x*0) + + prog.set_output_ranges(20) + prog.set_input_scales(30) + + self.assert_compiles_and_matches_reference(prog, + config={'warn_vec_size':'false'}) + + def test_serialization(self): + """ Test (de)serialization and check that results stay the same """ + + poly = EvaProgram('Polynomial', vec_size=4096) + with poly: + x = Input('x') + Output('y', 3*x**2 + 5*x - 2) + + poly.set_output_ranges(20) + poly.set_input_scales(30) + + inputs = { + 'x': [i for i in range(poly.vec_size)] + } + reference = evaluate(poly, inputs) + + compiler = CKKSCompiler(config={'warn_vec_size':'false'}) + poly, params, signature = compiler.compile(poly) + + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = lambda x: os.path.join(tmp_dir, x) + + save(poly, tmp_path('poly.eva')) + save(params, tmp_path('poly.evaparams')) + save(signature, tmp_path('poly.evasignature')) + + # Key generation time + + params = load(tmp_path('poly.evaparams')) + + public_ctx, secret_ctx = generate_keys(params) + + save(public_ctx, tmp_path('poly.sealpublic')) + save(secret_ctx, tmp_path('poly.sealsecret')) + + # Runtime on client + + signature = load(tmp_path('poly.evasignature')) + public_ctx = load(tmp_path('poly.sealpublic')) + + encInputs = public_ctx.encrypt(inputs, signature) + + save(encInputs, tmp_path('poly_inputs.sealvals')) + + # Runtime on server + + poly = load(tmp_path('poly.eva')) + public_ctx = load(tmp_path('poly.sealpublic')) + encInputs = load(tmp_path('poly_inputs.sealvals')) + + encOutputs = public_ctx.execute(poly, encInputs) + + save(encOutputs, tmp_path('poly_outputs.sealvals')) + + # Runtime back on client + + secret_ctx = load(tmp_path('poly.sealsecret')) + encOutputs = load(tmp_path('poly_outputs.sealvals')) + + outputs = secret_ctx.decrypt(encOutputs, signature) + + reference_compiled = evaluate(poly, inputs) + self.assertTrue(valuation_mse(reference, reference_compiled) < 0.0000000001) + self.assertTrue(valuation_mse(outputs, reference) < 0.01) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/large_programs.py b/tests/large_programs.py new file mode 100644 index 0000000..ee63b8c --- /dev/null +++ b/tests/large_programs.py @@ -0,0 +1,149 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +import unittest +import math +from common import * +from eva import EvaProgram, Input, Output + +class LargePrograms(EvaTestCase): + def test_sobel_configs(self): + """ Check accuracy of Sobel filter on random image with various compiler configurations """ + + def convolutionXY(image, width, filter): + for i in range(len(filter)): + for j in range(len(filter[0])): + rotated = image << (i * width + j) + horizontal = rotated * filter[i][j] + vertical = rotated * filter[j][i] + if i == 0 and j == 0: + Ix = horizontal + Iy = vertical + else: + Ix += horizontal + Iy += vertical + return Ix, Iy + + h = 90 + w = 90 + + sobel = EvaProgram('sobel', vec_size=2**(math.ceil(math.log(h*w, 2)))) + with sobel: + image = Input('image') + + sobel_filter = [ + [-1, 0, 1], + [-2, 0, 2], + [-1, 0, 1]] + + a1 = 2.2137874823876622 + a2 = -1.0984324107372518 + a3 = 0.17254603006834726 + + conv_hor, conv_ver = convolutionXY(image, w, sobel_filter) + x = conv_hor**2 + conv_ver**2 + Output('image', x * a1 + x**2 * a2 + x**3 * a3) + + sobel.set_input_scales(45) + sobel.set_output_ranges(20) + + for rescaler in ['lazy_waterline','eager_waterline','always']: + for balance_reductions in ['true','false']: + self.assert_compiles_and_matches_reference(sobel, + config={'rescaler':rescaler,'balance_reductions':balance_reductions}) + + def test_regression(self): + """ Test batched compilation and execution of multiple linear regression programs """ + + linreg = EvaProgram('linear_regression', vec_size=2048) + with linreg: + p = 63 + + x = [Input(f'x{i}') for i in range(p)] + e = Input('e') + b0 = 6.56 + b = [i * 0.732 for i in range(p)] + + y = e + b0 + for i in range(p): + t = x[i] * b[i] + y += t + + Output('y', y) + + linreg.set_input_scales(40) + linreg.set_output_ranges(30) + + linreg_inputs = {'e': [(linreg.vec_size - i) * 0.001 for i in range(linreg.vec_size)]} + for i in range(p): + linreg_inputs[f'x{i}'] = [i * j * 0.01 for j in range(linreg.vec_size)] + + polyreg = EvaProgram('polynomial_regression', vec_size=4096) + with polyreg: + p = 4 + + x = Input('x') + e = Input('e') + b0 = 6.56 + b = [i * 0.732 for i in range(p)] + + y = e + b0 + for i in range(p): + x_i = x + for j in range(i): + x_i = x_i * x + t = x_i * b[i] + y += t + + Output('y', y) + + polyreg.set_input_scales(40) + polyreg.set_output_ranges(30) + + polyreg_inputs = { + 'x': [i * 0.01 for i in range(polyreg.vec_size)], + 'e': [(polyreg.vec_size - i) * 0.001 for i in range(polyreg.vec_size)], + } + + multireg = EvaProgram('multivariate_regression', vec_size=2048) + with multireg: + p = 63 + k = 4 + + x = [Input(f'x{i}') for i in range(p)] + e = [Input(f'e{j}') for j in range(k)] + b0 = [j * 0.56 for j in range(k)] + b = [[k * i * 0.732 for i in range(p)] for j in range(k)] + + y = [0 for j in range(k)] + for j in range(k): + y[j] = e[j] + b0[j] + for i in range(p): + t = x[i] * b[j][i] + y[j] += t + + for j in range(k): + Output(f'y{j}', y[j]) + + multireg.set_input_scales(40) + multireg.set_output_ranges(30) + + multireg_inputs = {} + for i in range(p): + multireg_inputs[f'x{i}'] = [i * j * 0.01 for j in range(multireg.vec_size)] + for j in range(k): + multireg_inputs[f'e{j}'] = [(multireg.vec_size - i) * j * 0.001 for i in range(multireg.vec_size)] + + compiler = CKKSCompiler(config={'warn_vec_size':'false'}) + + for prog, inputs in [(linreg, linreg_inputs), (polyreg, polyreg_inputs), (multireg, multireg_inputs)]: + compiled_prog, params, signature = compiler.compile(prog) + public_ctx, secret_ctx = generate_keys(params) + enc_inputs = public_ctx.encrypt(inputs, signature) + enc_outputs = public_ctx.execute(compiled_prog, enc_inputs) + outputs = secret_ctx.decrypt(enc_outputs, signature) + reference = evaluate(compiled_prog, inputs) + self.assertTrue(valuation_mse(outputs, reference) < 0.01) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/std.py b/tests/std.py new file mode 100644 index 0000000..f24beb6 --- /dev/null +++ b/tests/std.py @@ -0,0 +1,38 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +import unittest +from common import * +from eva import EvaProgram, Input, Output +from eva.std.numeric import horizontal_sum + +class Std(EvaTestCase): + def test_horizontal_sum(self): + """ Test eva.std.numeric.horizontal_sum """ + + for enc in [True, False]: + prog = EvaProgram('HorizontalSum', vec_size = 2048) + with prog: + x = Input('x', is_encrypted=enc) + y = horizontal_sum(x) + Output('y', y) + + prog.set_output_ranges(25) + prog.set_input_scales(33) + + self.assert_compiles_and_matches_reference(prog, + config={'warn_vec_size':'false'}) + + prog = EvaProgram('HorizontalSumConstant', vec_size = 2048) + with prog: + y = horizontal_sum([1 for _ in range(prog.vec_size)]) + Output('y', y) + + prog.set_output_ranges(25) + prog.set_input_scales(33) + + self.assert_compiles_and_matches_reference(prog, + config={'warn_vec_size':'false'}) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/third_party/Galois b/third_party/Galois new file mode 160000 index 0000000..306535c --- /dev/null +++ b/third_party/Galois @@ -0,0 +1 @@ +Subproject commit 306535c4931b8d398518624b9b6428f7120a0b44 diff --git a/third_party/pybind11 b/third_party/pybind11 new file mode 160000 index 0000000..f1abf5d --- /dev/null +++ b/third_party/pybind11 @@ -0,0 +1 @@ +Subproject commit f1abf5d9159b805674197f6bc443592e631c9130