Skip to content

Commit

Permalink
openfhe: support cyclic repetition for MakePackedPlaintext
Browse files Browse the repository at this point in the history
  • Loading branch information
ZenithalHourlyRate committed Dec 2, 2024
1 parent a55ec9d commit d76bede
Show file tree
Hide file tree
Showing 12 changed files with 145 additions and 85 deletions.
19 changes: 2 additions & 17 deletions docs/content/en/docs/getting_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -336,25 +336,10 @@ int main(int argc, char *argv[]) {

cryptoContext = dot_product__configure_crypto_context(cryptoContext, keyPair.secretKey);

int32_t n = cryptoContext->GetCryptoParameters()
->GetElementParams()
->GetCyclotomicOrder() /
2;
int16_t arg0Vals[8] = {1, 2, 3, 4, 5, 6, 7, 8};
int16_t arg1Vals[8] = {2, 3, 4, 5, 6, 7, 8, 9};
std::vector<int16_t> arg0 = {1, 2, 3, 4, 5, 6, 7, 8};
std::vector<int16_t> arg1 = {2, 3, 4, 5, 6, 7, 8, 9};
int64_t expected = 240;

std::vector<int16_t> arg0;
std::vector<int16_t> arg1;
arg0.reserve(n);
arg1.reserve(n);

// TODO(#645): support cyclic repetition in add-client-interface
for (int i = 0; i < n; ++i) {
arg0.push_back(arg0Vals[i % 8]);
arg1.push_back(arg1Vals[i % 8]);
}

auto arg0Encrypted =
dot_product__encrypt__arg0(cryptoContext, arg0, keyPair.publicKey);
auto arg1Encrypted =
Expand Down
41 changes: 36 additions & 5 deletions lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -510,12 +510,27 @@ LogicalResult OpenFhePkeEmitter::printOperation(
LogicalResult OpenFhePkeEmitter::printOperation(
openfhe::MakePackedPlaintextOp op) {
std::string inputVarName = variableNames->getNameForValue(op.getValue());
std::string inputVarFilledName = inputVarName + "_filled";
std::string inputVarFilledLengthName = inputVarName + "_filled_n";

emitAutoAssignPrefix(op.getResult());
FailureOr<Value> resultCC = getContextualCryptoContext(op.getOperation());
if (failed(resultCC)) return resultCC;
os << variableNames->getNameForValue(resultCC.value())
<< "->MakePackedPlaintext(" << inputVarName << ");\n";
std::string cc = variableNames->getNameForValue(resultCC.value());

// cyclic repetition to mitigate openfhe zero-padding (#645)
os << "auto " << inputVarFilledLengthName << " = " << cc
<< "->GetCryptoParameters()->GetElementParams()->GetRingDimension() / "
"2;\n";
os << "auto " << inputVarFilledName << " = " << inputVarName << ";\n";
os << inputVarFilledName << ".clear();\n";
os << inputVarFilledName << ".reserve(" << inputVarFilledLengthName << ");\n";
os << "for (auto i = 0; i < " << inputVarFilledLengthName << "; ++i) {\n";
os << " " << inputVarFilledName << ".push_back(" << inputVarName << "[i % "
<< inputVarName << ".size()]);\n";
os << "}\n";

emitAutoAssignPrefix(op.getResult());
os << cc << "->MakePackedPlaintext(" << inputVarFilledName << ");\n";
return success();
}

Expand All @@ -527,12 +542,28 @@ LogicalResult OpenFhePkeEmitter::printOperation(
}

std::string inputVarName = variableNames->getNameForValue(op.getValue());
std::string inputVarFilledName = inputVarName + "_filled";
std::string inputVarFilledLengthName = inputVarName + "_filled_n";

emitAutoAssignPrefix(op.getResult());
FailureOr<Value> resultCC = getContextualCryptoContext(op.getOperation());
if (failed(resultCC)) return resultCC;
std::string cc = variableNames->getNameForValue(resultCC.value());

// cyclic repetition to mitigate openfhe zero-padding (#645)
os << "auto " << inputVarFilledLengthName << " = " << cc
<< "->GetCryptoParameters()->GetElementParams()->GetRingDimension() / "
"2;\n";
os << "auto " << inputVarFilledName << " = " << inputVarName << ";\n";
os << inputVarFilledName << ".clear();\n";
os << inputVarFilledName << ".reserve(" << inputVarFilledLengthName << ");\n";
os << "for (auto i = 0; i < " << inputVarFilledLengthName << "; ++i) {\n";
os << " " << inputVarFilledName << ".push_back(" << inputVarName << "[i % "
<< inputVarName << ".size()]);\n";
os << "}\n";

emitAutoAssignPrefix(op.getResult());
os << variableNames->getNameForValue(resultCC.value())
<< "->MakeCKKSPackedPlaintext(" << inputVarName << ");\n";
<< "->MakeCKKSPackedPlaintext(" << inputVarFilledName << ");\n";
return success();
}

Expand Down
14 changes: 14 additions & 0 deletions tests/Examples/openfhe/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ load("@heir//tests/Examples/openfhe:test.bzl", "openfhe_end_to_end_test")

package(default_applicable_licenses = ["@heir//:license"])

# BGV

openfhe_end_to_end_test(
name = "binops_test",
generated_lib_header = "binops_lib.h",
Expand Down Expand Up @@ -48,6 +50,18 @@ openfhe_end_to_end_test(
test_src = "roberts_cross_test.cpp",
)

# CKKS

openfhe_end_to_end_test(
name = "dot_product_8f_test",
generated_lib_header = "dot_product_8f_lib.h",
heir_opt_flags = ["--mlir-to-openfhe-ckks=entry-function=dot_product ciphertext-degree=8"],
heir_translate_flags = ["--openfhe-scheme=ckks"],
mlir_src = "dot_product_8f.mlir",
tags = ["notap"],
test_src = "dot_product_8f_test.cpp",
)

openfhe_end_to_end_test(
name = "naive_matmul_test",
generated_lib_header = "naive_matmul_lib.h",
Expand Down
11 changes: 3 additions & 8 deletions tests/Examples/openfhe/box_blur_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,13 @@ TEST(BoxBlurTest, TestInput1) {
auto secretKey = keyPair.secretKey;
cryptoContext = box_blur__configure_crypto_context(cryptoContext, secretKey);

int32_t n = cryptoContext->GetCryptoParameters()
->GetElementParams()
->GetCyclotomicOrder() /
2;
std::vector<int16_t> input;
std::vector<int16_t> expected;
input.reserve(n);
input.reserve(4096);
expected.reserve(4096);

// TODO(#645): support cyclic repetition in add-client-interface
for (int i = 0; i < n; ++i) {
input.push_back(i % 4096);
for (int i = 0; i < 4096; ++i) {
input.push_back(i);
}

for (int row = 0; row < 64; ++row) {
Expand Down
30 changes: 4 additions & 26 deletions tests/Examples/openfhe/dot_product_8_test.cpp
Original file line number Diff line number Diff line change
@@ -1,15 +1,8 @@
#include <cstdint>
#include <vector>

#include "gtest/gtest.h" // from @googletest
#include "src/core/include/lattice/hal/lat-backend.h" // from @openfhe
#include "src/pke/include/constants.h" // from @openfhe
#include "src/pke/include/cryptocontext-fwd.h" // from @openfhe
#include "src/pke/include/gen-cryptocontext.h" // from @openfhe
#include "src/pke/include/key/keypair.h" // from @openfhe
#include "src/pke/include/openfhe.h" // from @openfhe
#include "src/pke/include/scheme/bgvrns/gen-cryptocontext-bgvrns-params.h" // from @openfhe
#include "src/pke/include/scheme/bgvrns/gen-cryptocontext-bgvrns.h" // from @openfhe
#include "gtest/gtest.h" // from @googletest
#include "src/pke/include/openfhe.h" // from @openfhe

// Generated headers (block clang-format from messing up order)
#include "tests/Examples/openfhe/dot_product_8_lib.h"
Expand All @@ -26,25 +19,10 @@ TEST(DotProduct8Test, RunTest) {
cryptoContext =
dot_product__configure_crypto_context(cryptoContext, secretKey);

int32_t n = cryptoContext->GetCryptoParameters()
->GetElementParams()
->GetCyclotomicOrder() /
2;
int16_t arg0Vals[8] = {1, 2, 3, 4, 5, 6, 7, 8};
int16_t arg1Vals[8] = {2, 3, 4, 5, 6, 7, 8, 9};
std::vector<int16_t> arg0 = {1, 2, 3, 4, 5, 6, 7, 8};
std::vector<int16_t> arg1 = {2, 3, 4, 5, 6, 7, 8, 9};
int64_t expected = 240;

// TODO(#645): support cyclic repetition in add-client-interface
std::vector<int16_t> arg0;
std::vector<int16_t> arg1;
arg0.reserve(n);
arg1.reserve(n);

for (int i = 0; i < n; ++i) {
arg0.push_back(arg0Vals[i % 8]);
arg1.push_back(arg1Vals[i % 8]);
}

auto arg0Encrypted =
dot_product__encrypt__arg0(cryptoContext, arg0, publicKey);
auto arg1Encrypted =
Expand Down
12 changes: 12 additions & 0 deletions tests/Examples/openfhe/dot_product_8f.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
func.func @dot_product(%arg0: tensor<8xf16>, %arg1: tensor<8xf16>) -> f16 {
%c0 = arith.constant 0 : index
%c0_sf16 = arith.constant 0.1 : f16
%0 = affine.for %arg2 = 0 to 8 iter_args(%iter = %c0_sf16) -> (f16) {
%1 = tensor.extract %arg0[%arg2] : tensor<8xf16>
%2 = tensor.extract %arg1[%arg2] : tensor<8xf16>
%3 = arith.mulf %1, %2 : f16
%4 = arith.addf %iter, %3 : f16
affine.yield %4 : f16
}
return %0 : f16
}
65 changes: 65 additions & 0 deletions tests/Examples/openfhe/dot_product_8f_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#include <cstdint>
#include <vector>

#include "gtest/gtest.h" // from @googletest
#include "src/pke/include/openfhe.h" // from @openfhe

// Generated headers (block clang-format from messing up order)
#include "tests/Examples/openfhe/dot_product_8f_lib.h"

namespace mlir {
namespace heir {
namespace openfhe {

// TODO(#891): support other schemes besides BGV in add-client-interface
CiphertextT dot_product__encrypt__arg0(CryptoContextT v16,
std::vector<double> v17,
PublicKeyT v18) {
int32_t n =
v16->GetCryptoParameters()->GetElementParams()->GetRingDimension() / 2;
std::vector<double> outputs;
outputs.reserve(n);
for (int i = 0; i < n; ++i) {
outputs.push_back(v17[i % v17.size()]);
}
const auto& v19 = v16->MakeCKKSPackedPlaintext(outputs);
const auto& v20 = v16->Encrypt(v18, v19);
return v20;
}

double dot_product__decrypt__result0(CryptoContextT v26, CiphertextT v27,
PrivateKeyT v28) {
PlaintextT v29;
v26->Decrypt(v28, v27, &v29);
double v30 = v29->GetCKKSPackedValue()[0].real();
return v30;
}

TEST(DotProduct8FTest, RunTest) {
auto cryptoContext = dot_product__generate_crypto_context();
auto keyPair = cryptoContext->KeyGen();
auto publicKey = keyPair.publicKey;
auto secretKey = keyPair.secretKey;
cryptoContext =
dot_product__configure_crypto_context(cryptoContext, secretKey);

std::vector<double> arg0 = {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8};
std::vector<double> arg1 = {0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9};
double expected = 2.4 + 0.1;

// TODO(#891): support other schemes besides BGV in add-client-interface
auto arg0Encrypted =
dot_product__encrypt__arg0(cryptoContext, arg0, publicKey);
auto arg1Encrypted =
dot_product__encrypt__arg0(cryptoContext, arg1, publicKey);
auto outputEncrypted =
dot_product(cryptoContext, arg0Encrypted, arg1Encrypted);
auto actual =
dot_product__decrypt__result0(cryptoContext, outputEncrypted, secretKey);

EXPECT_NEAR(expected, actual, 1e-3);
}

} // namespace openfhe
} // namespace heir
} // namespace mlir
1 change: 0 additions & 1 deletion tests/Examples/openfhe/halevi_shoup_matmul_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ TEST(NaiveMatmulTest, RunTest) {
// 0.099224686622619628) and adds -0.45141533017158508
double expected = -0.35219;

// TODO(#645): support cyclic repetition in add-client-interface
// TODO(#891): support other schemes besides BGV in add-client-interface
auto arg0Encrypted =
matmul__encrypt__arg0(cryptoContext, arg0Vals, publicKey);
Expand Down
1 change: 0 additions & 1 deletion tests/Examples/openfhe/naive_matmul_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ TEST(NaiveMatmulTest, RunTest) {
// adds 0.25
double expected = 0.3492247;

// TODO(#645): support cyclic repetition in add-client-interface
// TODO(#891): support other schemes besides BGV in add-client-interface
auto arg0Encrypted =
matmul__encrypt__arg0(cryptoContext, arg0Vals, publicKey);
Expand Down
11 changes: 3 additions & 8 deletions tests/Examples/openfhe/roberts_cross_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,13 @@ TEST(RobertsCrossTest, TestInput1) {
cryptoContext =
roberts_cross__configure_crypto_context(cryptoContext, secretKey);

int32_t n = cryptoContext->GetCryptoParameters()
->GetElementParams()
->GetCyclotomicOrder() /
2;
std::vector<int16_t> input;
std::vector<int16_t> expected;
input.reserve(n);
input.reserve(4096);
expected.reserve(4096);

// TODO(#645): support cyclic repetition in add-client-interface
for (int i = 0; i < n; ++i) {
input.push_back(i % 4096);
for (int i = 0; i < 4096; ++i) {
input.push_back(i);
}

for (int row = 0; row < 64; ++row) {
Expand Down
21 changes: 3 additions & 18 deletions tests/Examples/openfhe/simple_sum_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,24 +26,9 @@ TEST(BinopsTest, TestInput1) {
cryptoContext =
simple_sum__configure_crypto_context(cryptoContext, secretKey);

int32_t n = cryptoContext->GetCryptoParameters()
->GetElementParams()
->GetCyclotomicOrder() /
2;
std::vector<int16_t> input;
// TODO(#645): support cyclic repetition in add-client-interface
// I want to do this, but MakePackedPlaintext does not repeat the values.
// It zero pads, and rotating the zero-padded values will not achieve the
// rotate-and-reduce trick required for simple_sum
//
// = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
// 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,
// 23, 24, 25, 26, 27, 28, 29, 30, 31, 32};
input.reserve(n);

for (int i = 0; i < n; ++i) {
input.push_back((i % 32) + 1);
}
std::vector<int16_t> input = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,
23, 24, 25, 26, 27, 28, 29, 30, 31, 32};
int64_t expected = 16 * 33;

auto inputEncrypted =
Expand Down
4 changes: 3 additions & 1 deletion tests/Transforms/mlir_to_openfhe_ckks/naive_matmul.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
// CHECK-DAG: std::vector<double> [[v4:.*]](16, 3.000000e+00);
// CHECK-DAG: std::vector<double> [[v5:.*]](16, 4.000000e+00);
// CHECK-DAG: std::vector<double> [[v6:.*]](16, 2.000000e+00);
// CHECK-DAG: auto [[v6_filled:.*]] = [[v6]];
// CHECK-DAG: [[v6_filled]].push_back([[v6]]
// CHECK-DAG: size_t [[v7:.*]] = 1;
// CHECK-DAG: size_t [[v8:.*]] = 0;
// CHECK-DAG: const auto& [[v9:.*]] = [[v1]][0][0];
// CHECK-DAG: const auto& [[v10:.*]] = [[v2]][0][0];
// CHECK: const auto& [[v11:.*]] = [[v0]]->MakeCKKSPackedPlaintext([[v6]]);
// CHECK: const auto& [[v11:.*]] = [[v0]]->MakeCKKSPackedPlaintext([[v6_filled]]);
// CHECK-NEXT: const auto& [[v12:.*]] = [[v0]]->EvalMult([[v9]], [[v11]]);
// CHECK-NEXT: const auto& [[v13:.*]] = [[v0]]->EvalAdd([[v10]], [[v12]]);
// CHECK-NEXT: [[v2]][0][0] = [[v13]];
Expand Down

0 comments on commit d76bede

Please sign in to comment.