diff --git a/docs/content/en/docs/getting_started.md b/docs/content/en/docs/getting_started.md index 1bea7c4e3..a63fc8f3f 100644 --- a/docs/content/en/docs/getting_started.md +++ b/docs/content/en/docs/getting_started.md @@ -336,25 +336,10 @@ int main(int argc, char *argv[]) { cryptoContext = dot_product__configure_crypto_context(cryptoContext, keyPair.secretKey); - int32_t n = cryptoContext->GetCryptoParameters() - ->GetElementParams() - ->GetCyclotomicOrder() / - 2; - int16_t arg0Vals[8] = {1, 2, 3, 4, 5, 6, 7, 8}; - int16_t arg1Vals[8] = {2, 3, 4, 5, 6, 7, 8, 9}; + std::vector arg0 = {1, 2, 3, 4, 5, 6, 7, 8}; + std::vector arg1 = {2, 3, 4, 5, 6, 7, 8, 9}; int64_t expected = 240; - std::vector arg0; - std::vector arg1; - arg0.reserve(n); - arg1.reserve(n); - - // TODO(#645): support cyclic repetition in add-client-interface - for (int i = 0; i < n; ++i) { - arg0.push_back(arg0Vals[i % 8]); - arg1.push_back(arg1Vals[i % 8]); - } - auto arg0Encrypted = dot_product__encrypt__arg0(cryptoContext, arg0, keyPair.publicKey); auto arg1Encrypted = diff --git a/lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp b/lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp index 853b906c7..595f2b3fa 100644 --- a/lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp +++ b/lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp @@ -510,12 +510,27 @@ LogicalResult OpenFhePkeEmitter::printOperation( LogicalResult OpenFhePkeEmitter::printOperation( openfhe::MakePackedPlaintextOp op) { std::string inputVarName = variableNames->getNameForValue(op.getValue()); + std::string inputVarFilledName = inputVarName + "_filled"; + std::string inputVarFilledLengthName = inputVarName + "_filled_n"; - emitAutoAssignPrefix(op.getResult()); FailureOr resultCC = getContextualCryptoContext(op.getOperation()); if (failed(resultCC)) return resultCC; - os << variableNames->getNameForValue(resultCC.value()) - << "->MakePackedPlaintext(" << inputVarName << ");\n"; + std::string cc = variableNames->getNameForValue(resultCC.value()); + + // cyclic repetition to mitigate openfhe zero-padding (#645) + os << "auto " << inputVarFilledLengthName << " = " << cc + << "->GetCryptoParameters()->GetElementParams()->GetRingDimension() / " + "2;\n"; + os << "auto " << inputVarFilledName << " = " << inputVarName << ";\n"; + os << inputVarFilledName << ".clear();\n"; + os << inputVarFilledName << ".reserve(" << inputVarFilledLengthName << ");\n"; + os << "for (auto i = 0; i < " << inputVarFilledLengthName << "; ++i) {\n"; + os << " " << inputVarFilledName << ".push_back(" << inputVarName << "[i % " + << inputVarName << ".size()]);\n"; + os << "}\n"; + + emitAutoAssignPrefix(op.getResult()); + os << cc << "->MakePackedPlaintext(" << inputVarFilledName << ");\n"; return success(); } @@ -527,12 +542,28 @@ LogicalResult OpenFhePkeEmitter::printOperation( } std::string inputVarName = variableNames->getNameForValue(op.getValue()); + std::string inputVarFilledName = inputVarName + "_filled"; + std::string inputVarFilledLengthName = inputVarName + "_filled_n"; - emitAutoAssignPrefix(op.getResult()); FailureOr resultCC = getContextualCryptoContext(op.getOperation()); if (failed(resultCC)) return resultCC; + std::string cc = variableNames->getNameForValue(resultCC.value()); + + // cyclic repetition to mitigate openfhe zero-padding (#645) + os << "auto " << inputVarFilledLengthName << " = " << cc + << "->GetCryptoParameters()->GetElementParams()->GetRingDimension() / " + "2;\n"; + os << "auto " << inputVarFilledName << " = " << inputVarName << ";\n"; + os << inputVarFilledName << ".clear();\n"; + os << inputVarFilledName << ".reserve(" << inputVarFilledLengthName << ");\n"; + os << "for (auto i = 0; i < " << inputVarFilledLengthName << "; ++i) {\n"; + os << " " << inputVarFilledName << ".push_back(" << inputVarName << "[i % " + << inputVarName << ".size()]);\n"; + os << "}\n"; + + emitAutoAssignPrefix(op.getResult()); os << variableNames->getNameForValue(resultCC.value()) - << "->MakeCKKSPackedPlaintext(" << inputVarName << ");\n"; + << "->MakeCKKSPackedPlaintext(" << inputVarFilledName << ");\n"; return success(); } diff --git a/tests/Examples/openfhe/BUILD b/tests/Examples/openfhe/BUILD index 178f36c87..f0a64c0bf 100644 --- a/tests/Examples/openfhe/BUILD +++ b/tests/Examples/openfhe/BUILD @@ -4,6 +4,8 @@ load("@heir//tests/Examples/openfhe:test.bzl", "openfhe_end_to_end_test") package(default_applicable_licenses = ["@heir//:license"]) +# BGV + openfhe_end_to_end_test( name = "binops_test", generated_lib_header = "binops_lib.h", @@ -48,6 +50,18 @@ openfhe_end_to_end_test( test_src = "roberts_cross_test.cpp", ) +# CKKS + +openfhe_end_to_end_test( + name = "dot_product_8f_test", + generated_lib_header = "dot_product_8f_lib.h", + heir_opt_flags = ["--mlir-to-openfhe-ckks=entry-function=dot_product ciphertext-degree=8"], + heir_translate_flags = ["--openfhe-scheme=ckks"], + mlir_src = "dot_product_8f.mlir", + tags = ["notap"], + test_src = "dot_product_8f_test.cpp", +) + openfhe_end_to_end_test( name = "naive_matmul_test", generated_lib_header = "naive_matmul_lib.h", diff --git a/tests/Examples/openfhe/box_blur_test.cpp b/tests/Examples/openfhe/box_blur_test.cpp index 32ffddf3b..ac7bf4584 100644 --- a/tests/Examples/openfhe/box_blur_test.cpp +++ b/tests/Examples/openfhe/box_blur_test.cpp @@ -31,18 +31,13 @@ TEST(BoxBlurTest, TestInput1) { auto secretKey = keyPair.secretKey; cryptoContext = box_blur__configure_crypto_context(cryptoContext, secretKey); - int32_t n = cryptoContext->GetCryptoParameters() - ->GetElementParams() - ->GetCyclotomicOrder() / - 2; std::vector input; std::vector expected; - input.reserve(n); + input.reserve(4096); expected.reserve(4096); - // TODO(#645): support cyclic repetition in add-client-interface - for (int i = 0; i < n; ++i) { - input.push_back(i % 4096); + for (int i = 0; i < 4096; ++i) { + input.push_back(i); } for (int row = 0; row < 64; ++row) { diff --git a/tests/Examples/openfhe/dot_product_8_test.cpp b/tests/Examples/openfhe/dot_product_8_test.cpp index 0b7093b02..8e040013c 100644 --- a/tests/Examples/openfhe/dot_product_8_test.cpp +++ b/tests/Examples/openfhe/dot_product_8_test.cpp @@ -1,15 +1,8 @@ #include #include -#include "gtest/gtest.h" // from @googletest -#include "src/core/include/lattice/hal/lat-backend.h" // from @openfhe -#include "src/pke/include/constants.h" // from @openfhe -#include "src/pke/include/cryptocontext-fwd.h" // from @openfhe -#include "src/pke/include/gen-cryptocontext.h" // from @openfhe -#include "src/pke/include/key/keypair.h" // from @openfhe -#include "src/pke/include/openfhe.h" // from @openfhe -#include "src/pke/include/scheme/bgvrns/gen-cryptocontext-bgvrns-params.h" // from @openfhe -#include "src/pke/include/scheme/bgvrns/gen-cryptocontext-bgvrns.h" // from @openfhe +#include "gtest/gtest.h" // from @googletest +#include "src/pke/include/openfhe.h" // from @openfhe // Generated headers (block clang-format from messing up order) #include "tests/Examples/openfhe/dot_product_8_lib.h" @@ -26,25 +19,10 @@ TEST(DotProduct8Test, RunTest) { cryptoContext = dot_product__configure_crypto_context(cryptoContext, secretKey); - int32_t n = cryptoContext->GetCryptoParameters() - ->GetElementParams() - ->GetCyclotomicOrder() / - 2; - int16_t arg0Vals[8] = {1, 2, 3, 4, 5, 6, 7, 8}; - int16_t arg1Vals[8] = {2, 3, 4, 5, 6, 7, 8, 9}; + std::vector arg0 = {1, 2, 3, 4, 5, 6, 7, 8}; + std::vector arg1 = {2, 3, 4, 5, 6, 7, 8, 9}; int64_t expected = 240; - // TODO(#645): support cyclic repetition in add-client-interface - std::vector arg0; - std::vector arg1; - arg0.reserve(n); - arg1.reserve(n); - - for (int i = 0; i < n; ++i) { - arg0.push_back(arg0Vals[i % 8]); - arg1.push_back(arg1Vals[i % 8]); - } - auto arg0Encrypted = dot_product__encrypt__arg0(cryptoContext, arg0, publicKey); auto arg1Encrypted = diff --git a/tests/Examples/openfhe/dot_product_8f.mlir b/tests/Examples/openfhe/dot_product_8f.mlir new file mode 100644 index 000000000..92ab5bca3 --- /dev/null +++ b/tests/Examples/openfhe/dot_product_8f.mlir @@ -0,0 +1,12 @@ +func.func @dot_product(%arg0: tensor<8xf16>, %arg1: tensor<8xf16>) -> f16 { + %c0 = arith.constant 0 : index + %c0_sf16 = arith.constant 0.1 : f16 + %0 = affine.for %arg2 = 0 to 8 iter_args(%iter = %c0_sf16) -> (f16) { + %1 = tensor.extract %arg0[%arg2] : tensor<8xf16> + %2 = tensor.extract %arg1[%arg2] : tensor<8xf16> + %3 = arith.mulf %1, %2 : f16 + %4 = arith.addf %iter, %3 : f16 + affine.yield %4 : f16 + } + return %0 : f16 +} diff --git a/tests/Examples/openfhe/dot_product_8f_test.cpp b/tests/Examples/openfhe/dot_product_8f_test.cpp new file mode 100644 index 000000000..6a5aa2579 --- /dev/null +++ b/tests/Examples/openfhe/dot_product_8f_test.cpp @@ -0,0 +1,65 @@ +#include +#include + +#include "gtest/gtest.h" // from @googletest +#include "src/pke/include/openfhe.h" // from @openfhe + +// Generated headers (block clang-format from messing up order) +#include "tests/Examples/openfhe/dot_product_8f_lib.h" + +namespace mlir { +namespace heir { +namespace openfhe { + +// TODO(#891): support other schemes besides BGV in add-client-interface +CiphertextT dot_product__encrypt__arg0(CryptoContextT v16, + std::vector v17, + PublicKeyT v18) { + int32_t n = + v16->GetCryptoParameters()->GetElementParams()->GetRingDimension() / 2; + std::vector outputs; + outputs.reserve(n); + for (int i = 0; i < n; ++i) { + outputs.push_back(v17[i % v17.size()]); + } + const auto& v19 = v16->MakeCKKSPackedPlaintext(outputs); + const auto& v20 = v16->Encrypt(v18, v19); + return v20; +} + +double dot_product__decrypt__result0(CryptoContextT v26, CiphertextT v27, + PrivateKeyT v28) { + PlaintextT v29; + v26->Decrypt(v28, v27, &v29); + double v30 = v29->GetCKKSPackedValue()[0].real(); + return v30; +} + +TEST(DotProduct8FTest, RunTest) { + auto cryptoContext = dot_product__generate_crypto_context(); + auto keyPair = cryptoContext->KeyGen(); + auto publicKey = keyPair.publicKey; + auto secretKey = keyPair.secretKey; + cryptoContext = + dot_product__configure_crypto_context(cryptoContext, secretKey); + + std::vector arg0 = {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}; + std::vector arg1 = {0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9}; + double expected = 2.4 + 0.1; + + // TODO(#891): support other schemes besides BGV in add-client-interface + auto arg0Encrypted = + dot_product__encrypt__arg0(cryptoContext, arg0, publicKey); + auto arg1Encrypted = + dot_product__encrypt__arg0(cryptoContext, arg1, publicKey); + auto outputEncrypted = + dot_product(cryptoContext, arg0Encrypted, arg1Encrypted); + auto actual = + dot_product__decrypt__result0(cryptoContext, outputEncrypted, secretKey); + + EXPECT_NEAR(expected, actual, 1e-3); +} + +} // namespace openfhe +} // namespace heir +} // namespace mlir diff --git a/tests/Examples/openfhe/halevi_shoup_matmul_test.cpp b/tests/Examples/openfhe/halevi_shoup_matmul_test.cpp index d48278ca0..c66faa32c 100644 --- a/tests/Examples/openfhe/halevi_shoup_matmul_test.cpp +++ b/tests/Examples/openfhe/halevi_shoup_matmul_test.cpp @@ -49,7 +49,6 @@ TEST(NaiveMatmulTest, RunTest) { // 0.099224686622619628) and adds -0.45141533017158508 double expected = -0.35219; - // TODO(#645): support cyclic repetition in add-client-interface // TODO(#891): support other schemes besides BGV in add-client-interface auto arg0Encrypted = matmul__encrypt__arg0(cryptoContext, arg0Vals, publicKey); diff --git a/tests/Examples/openfhe/naive_matmul_test.cpp b/tests/Examples/openfhe/naive_matmul_test.cpp index 358efb93d..fbcf473bf 100644 --- a/tests/Examples/openfhe/naive_matmul_test.cpp +++ b/tests/Examples/openfhe/naive_matmul_test.cpp @@ -69,7 +69,6 @@ TEST(NaiveMatmulTest, RunTest) { // adds 0.25 double expected = 0.3492247; - // TODO(#645): support cyclic repetition in add-client-interface // TODO(#891): support other schemes besides BGV in add-client-interface auto arg0Encrypted = matmul__encrypt__arg0(cryptoContext, arg0Vals, publicKey); diff --git a/tests/Examples/openfhe/roberts_cross_test.cpp b/tests/Examples/openfhe/roberts_cross_test.cpp index 1f06a9ca6..b8eb25007 100644 --- a/tests/Examples/openfhe/roberts_cross_test.cpp +++ b/tests/Examples/openfhe/roberts_cross_test.cpp @@ -29,18 +29,13 @@ TEST(RobertsCrossTest, TestInput1) { cryptoContext = roberts_cross__configure_crypto_context(cryptoContext, secretKey); - int32_t n = cryptoContext->GetCryptoParameters() - ->GetElementParams() - ->GetCyclotomicOrder() / - 2; std::vector input; std::vector expected; - input.reserve(n); + input.reserve(4096); expected.reserve(4096); - // TODO(#645): support cyclic repetition in add-client-interface - for (int i = 0; i < n; ++i) { - input.push_back(i % 4096); + for (int i = 0; i < 4096; ++i) { + input.push_back(i); } for (int row = 0; row < 64; ++row) { diff --git a/tests/Examples/openfhe/simple_sum_test.cpp b/tests/Examples/openfhe/simple_sum_test.cpp index 07f7298cb..eec9e8a0c 100644 --- a/tests/Examples/openfhe/simple_sum_test.cpp +++ b/tests/Examples/openfhe/simple_sum_test.cpp @@ -26,24 +26,9 @@ TEST(BinopsTest, TestInput1) { cryptoContext = simple_sum__configure_crypto_context(cryptoContext, secretKey); - int32_t n = cryptoContext->GetCryptoParameters() - ->GetElementParams() - ->GetCyclotomicOrder() / - 2; - std::vector input; - // TODO(#645): support cyclic repetition in add-client-interface - // I want to do this, but MakePackedPlaintext does not repeat the values. - // It zero pads, and rotating the zero-padded values will not achieve the - // rotate-and-reduce trick required for simple_sum - // - // = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, - // 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, - // 23, 24, 25, 26, 27, 28, 29, 30, 31, 32}; - input.reserve(n); - - for (int i = 0; i < n; ++i) { - input.push_back((i % 32) + 1); - } + std::vector input = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, + 23, 24, 25, 26, 27, 28, 29, 30, 31, 32}; int64_t expected = 16 * 33; auto inputEncrypted = diff --git a/tests/Transforms/mlir_to_openfhe_ckks/naive_matmul.mlir b/tests/Transforms/mlir_to_openfhe_ckks/naive_matmul.mlir index ceb50bac5..594b63148 100644 --- a/tests/Transforms/mlir_to_openfhe_ckks/naive_matmul.mlir +++ b/tests/Transforms/mlir_to_openfhe_ckks/naive_matmul.mlir @@ -8,11 +8,13 @@ // CHECK-DAG: std::vector [[v4:.*]](16, 3.000000e+00); // CHECK-DAG: std::vector [[v5:.*]](16, 4.000000e+00); // CHECK-DAG: std::vector [[v6:.*]](16, 2.000000e+00); +// CHECK-DAG: auto [[v6_filled:.*]] = [[v6]]; +// CHECK-DAG: [[v6_filled]].push_back([[v6]] // CHECK-DAG: size_t [[v7:.*]] = 1; // CHECK-DAG: size_t [[v8:.*]] = 0; // CHECK-DAG: const auto& [[v9:.*]] = [[v1]][0][0]; // CHECK-DAG: const auto& [[v10:.*]] = [[v2]][0][0]; -// CHECK: const auto& [[v11:.*]] = [[v0]]->MakeCKKSPackedPlaintext([[v6]]); +// CHECK: const auto& [[v11:.*]] = [[v0]]->MakeCKKSPackedPlaintext([[v6_filled]]); // CHECK-NEXT: const auto& [[v12:.*]] = [[v0]]->EvalMult([[v9]], [[v11]]); // CHECK-NEXT: const auto& [[v13:.*]] = [[v0]]->EvalAdd([[v10]], [[v12]]); // CHECK-NEXT: [[v2]][0][0] = [[v13]];