Skip to content

Commit

Permalink
openfhe: add bootstrap op
Browse files Browse the repository at this point in the history
  • Loading branch information
ZenithalHourlyRate committed Dec 18, 2024
1 parent 9031113 commit a3c5ce8
Show file tree
Hide file tree
Showing 10 changed files with 283 additions and 20 deletions.
23 changes: 21 additions & 2 deletions lib/Dialect/Openfhe/IR/OpenfheOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,16 @@ class Openfhe_BinaryOp<string mnemonic, list<Trait> traits = []>
def GenParamsOp : Openfhe_Op<"gen_params"> {
let arguments = (ins
I64Attr:$mulDepth,
I64Attr:$plainMod
I64Attr:$plainMod,
BoolAttr:$insecure
);
let results = (outs Openfhe_CCParams:$params);
}

def GenContextOp : Openfhe_Op<"gen_context"> {
let arguments = (ins
Openfhe_CCParams:$params
Openfhe_CCParams:$params,
BoolAttr:$supportFHE
);
let results = (outs Openfhe_CryptoContext:$context);
}
Expand All @@ -77,6 +79,21 @@ def GenRotKeyOp : Openfhe_Op<"gen_rotkey"> {
);
}

def SetupBootstrapOp : Openfhe_Op<"setup_bootstrap"> {
let arguments = (ins
Openfhe_CryptoContext:$cryptoContext,
Builtin_IntegerAttr:$levelBudgetEncode,
Builtin_IntegerAttr:$levelBudgetDecode
);
}

def GenBootstrapKeyOp : Openfhe_Op<"gen_bootstrapkey"> {
let arguments = (ins
Openfhe_CryptoContext:$cryptoContext,
Openfhe_PrivateKey:$privateKey
);
}

def MakePackedPlaintextOp : Openfhe_Op<"make_packed_plaintext", [Pure]> {
let arguments = (ins
Openfhe_CryptoContext:$cryptoContext,
Expand Down Expand Up @@ -208,4 +225,6 @@ def KeySwitchOp : Openfhe_Op<"key_switch", [
let results = (outs NewLWECiphertext:$output);
}

def BootstrapOp : Openfhe_UnaryTypeSwitchOp<"bootstrap"> { let summary = "OpenFHE bootstrap operation of a ciphertext. (For CKKS)"; }

#endif // LIB_DIALECT_OPENFHE_IR_OPENFHEOPS_TD_
67 changes: 52 additions & 15 deletions lib/Dialect/Openfhe/Transforms/ConfigureCryptoContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,23 @@ SmallVector<int64_t> findAllRotIndices(func::FuncOp op) {
return rotIndicesResult;
}

// Helper function to check if the function has BootstrapOp
bool hasBootstrapOp(func::FuncOp op) {
bool result = false;
op.walk<WalkOrder::PreOrder>([&](Operation *op) {
if (isa<openfhe::BootstrapOp>(op)) {
result = true;
return WalkResult::interrupt();
}
return WalkResult::advance();
});
return result;
}

// function that generates the crypto context with proper parameters
LogicalResult generateGenFunc(func::FuncOp op, const std::string &genFuncName,
int64_t mulDepth, ImplicitLocOpBuilder &builder) {
int64_t mulDepth, bool hasBootstrapOp,
bool insecure, ImplicitLocOpBuilder &builder) {
Type openfheContextType =
openfhe::CryptoContextType::get(builder.getContext());
SmallVector<Type> funcArgTypes;
Expand All @@ -73,21 +87,21 @@ LogicalResult generateGenFunc(func::FuncOp op, const std::string &genFuncName,
// TODO(#661) : Calculate the appropriate values by analyzing the function
int64_t plainMod = 4295294977;
Type openfheParamsType = openfhe::CCParamsType::get(builder.getContext());
Value ccParams = builder.create<openfhe::GenParamsOp>(openfheParamsType,
mulDepth, plainMod);
Value cryptoContext =
builder.create<openfhe::GenContextOp>(openfheContextType, ccParams);
Value ccParams = builder.create<openfhe::GenParamsOp>(
openfheParamsType, mulDepth, plainMod, insecure);
Value cryptoContext = builder.create<openfhe::GenContextOp>(
openfheContextType, ccParams,
BoolAttr::get(builder.getContext(), hasBootstrapOp));

builder.create<func::ReturnOp>(cryptoContext);
return success();
}

// function that configures the crypto context with proper keygeneration
LogicalResult generateConfigFunc(func::FuncOp op,
const std::string &configFuncName,
bool hasRelinOp,
SmallVector<int64_t> rotIndices,
ImplicitLocOpBuilder &builder) {
LogicalResult generateConfigFunc(
func::FuncOp op, const std::string &configFuncName, bool hasRelinOp,
SmallVector<int64_t> rotIndices, bool hasBootstrapOp, int levelBudgetEncode,
int levelBudgetDecode, ImplicitLocOpBuilder &builder) {
Type openfheContextType =
openfhe::CryptoContextType::get(builder.getContext());
Type privateKeyType = openfhe::PrivateKeyType::get(builder.getContext());
Expand All @@ -108,18 +122,28 @@ LogicalResult generateConfigFunc(func::FuncOp op,
Value cryptoContext = configFuncOp.getArgument(0);
Value privateKey = configFuncOp.getArgument(1);

if (hasRelinOp) {
if (hasRelinOp || hasBootstrapOp) {
builder.create<openfhe::GenMulKeyOp>(cryptoContext, privateKey);
}
if (!rotIndices.empty()) {
builder.create<openfhe::GenRotKeyOp>(cryptoContext, privateKey, rotIndices);
}
if (hasBootstrapOp) {
builder.create<openfhe::SetupBootstrapOp>(
cryptoContext,
IntegerAttr::get(IndexType::get(builder.getContext()),
levelBudgetEncode),
IntegerAttr::get(IndexType::get(builder.getContext()),
levelBudgetDecode));
builder.create<openfhe::GenBootstrapKeyOp>(cryptoContext, privateKey);
}

builder.create<func::ReturnOp>(cryptoContext);
return success();
}

LogicalResult convertFunc(func::FuncOp op) {
LogicalResult convertFunc(func::FuncOp op, int levelBudgetEncode,
int levelBudgetDecode, bool insecure) {
auto module = op->getParentOfType<ModuleOp>();
std::string genFuncName("");
llvm::raw_string_ostream genNameOs(genFuncName);
Expand All @@ -146,7 +170,16 @@ LogicalResult convertFunc(func::FuncOp op) {
}
}

if (failed(generateGenFunc(op, genFuncName, mulDepth, builder))) {
bool hasBootstrapOpResult = hasBootstrapOp(op);
// TODO(#1207): determine mulDepth earlier in mgmt level
// approxModDepth = 14, this solely depends on secretKeyDist
// here we use the value for UNIFORM_TERNARY
int bootstrapDepth = levelBudgetEncode + 14 + levelBudgetDecode;
if (hasBootstrapOpResult) {
mulDepth += bootstrapDepth;
}
if (failed(generateGenFunc(op, genFuncName, mulDepth, hasBootstrapOpResult,
insecure, builder))) {
return failure();
}

Expand All @@ -155,7 +188,9 @@ LogicalResult convertFunc(func::FuncOp op) {
bool hasRelinOpResult = hasRelinOp(op);
SmallVector<int64_t> rotIndices = findAllRotIndices(op);
if (failed(generateConfigFunc(op, configFuncName, hasRelinOpResult,
rotIndices, builder))) {
rotIndices, hasBootstrapOpResult,
levelBudgetEncode, levelBudgetDecode,
builder))) {
return failure();
}
return success();
Expand All @@ -169,7 +204,9 @@ struct ConfigureCryptoContext
auto result =
getOperation()->walk<WalkOrder::PreOrder>([&](func::FuncOp op) {
auto funcName = op.getSymName();
if ((funcName == entryFunction) && failed(convertFunc(op))) {
if ((funcName == entryFunction) &&
failed(convertFunc(op, levelBudgetEncode, levelBudgetDecode,
insecure))) {
op->emitError("Failed to configure the crypto context for func");
return WalkResult::interrupt();
}
Expand Down
10 changes: 8 additions & 2 deletions lib/Dialect/Openfhe/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,14 @@ def ConfigureCryptoContext : Pass<"openfhe-configure-crypto-context"> {
let options = [
Option<"entryFunction", "entry-function", "std::string",
/*default=*/"", "Default entry function "
"name of entry function.">

"name of entry function.">,
Option<"levelBudgetEncode", "level-budget-encode", "int",
/*default=*/"3", "Level budget for CKKS bootstrap encode (s2c) phase">,
Option<"levelBudgetDecode", "level-budget-decode", "int",
/*default=*/"3", "Level budget for CKKS bootstrap decode (c2s) phase">,
Option<"insecure", "insecure", "bool",
/*default=*/"false", "Whether to use insecure parameter for faster evaluation"
"(should only be used in test) (defaults to false)">
];
}

Expand Down
36 changes: 35 additions & 1 deletion lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ LogicalResult OpenFhePkeEmitter::translate(Operation &op) {
SquareOp, NegateOp, MulConstOp, RelinOp, ModReduceOp,
LevelReduceOp, RotOp, AutomorphOp, KeySwitchOp, EncryptOp,
DecryptOp, GenParamsOp, GenContextOp, GenMulKeyOp, GenRotKeyOp,
MakePackedPlaintextOp, MakeCKKSPackedPlaintextOp>(
GenBootstrapKeyOp, MakePackedPlaintextOp,
MakeCKKSPackedPlaintextOp, SetupBootstrapOp, BootstrapOp>(
[&](auto op) { return printOperation(op); })
.Default([&](Operation &) {
return emitError(op.getLoc(), "unable to find printer for op");
Expand Down Expand Up @@ -317,6 +318,11 @@ LogicalResult OpenFhePkeEmitter::printOperation(KeySwitchOp op) {
{op.getCiphertext(), op.getEvalKey()}, "KeySwitch");
}

LogicalResult OpenFhePkeEmitter::printOperation(BootstrapOp op) {
return printEvalMethod(op.getResult(), op.getCryptoContext(),
{op.getCiphertext()}, "EvalBootstrap");
}

LogicalResult OpenFhePkeEmitter::printOperation(arith::ConstantOp op) {
auto valueAttr = op.getValue();
if (auto intAttr = dyn_cast<IntegerAttr>(valueAttr)) {
Expand Down Expand Up @@ -679,6 +685,10 @@ LogicalResult OpenFhePkeEmitter::printOperation(GenParamsOp op) {
os << "CCParamsT " << paramsName << ";\n";
os << paramsName << ".SetMultiplicativeDepth(" << mulDepth << ");\n";
os << paramsName << ".SetPlaintextModulus(" << plainMod << ");\n";
if (op.getInsecure()) {
os << paramsName << ".SetSecurityLevel(lbcrypto::HEStd_NotSet);\n";
os << paramsName << ".SetRingDim(128);\n";
}
return success();
}

Expand All @@ -691,6 +701,10 @@ LogicalResult OpenFhePkeEmitter::printOperation(GenContextOp op) {
os << contextName << "->Enable(PKE);\n";
os << contextName << "->Enable(KEYSWITCH);\n";
os << contextName << "->Enable(LEVELEDSHE);\n";
if (op.getSupportFHE()) {
os << contextName << "->Enable(ADVANCEDSHE);\n";
os << contextName << "->Enable(FHE);\n";
}
return success();
}

Expand All @@ -715,6 +729,26 @@ LogicalResult OpenFhePkeEmitter::printOperation(GenRotKeyOp op) {
return success();
}

LogicalResult OpenFhePkeEmitter::printOperation(GenBootstrapKeyOp op) {
auto contextName = variableNames->getNameForValue(op.getCryptoContext());
auto privateKeyName = variableNames->getNameForValue(op.getPrivateKey());
// compiler can not determine slot num for now
// full packing for CKKS, as we currently always full packing
os << "auto numSlots = " << contextName << "->GetRingDimension() / 2;\n";
os << contextName << "->EvalBootstrapKeyGen(" << privateKeyName
<< ", numSlots);\n";
return success();
}

LogicalResult OpenFhePkeEmitter::printOperation(SetupBootstrapOp op) {
auto contextName = variableNames->getNameForValue(op.getCryptoContext());
os << contextName << "->EvalBootstrapSetup({";
os << op.getLevelBudgetEncode().getValue() << ", ";
os << op.getLevelBudgetDecode().getValue();
os << "});\n";
return success();
}

LogicalResult OpenFhePkeEmitter::emitType(Type type, Location loc) {
auto result = convertType(type, loc);
if (failed(result)) {
Expand Down
3 changes: 3 additions & 0 deletions lib/Target/OpenFhePke/OpenFhePkeEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,14 @@ class OpenFhePkeEmitter {
LogicalResult printOperation(AddOp op);
LogicalResult printOperation(AddPlainOp op);
LogicalResult printOperation(AutomorphOp op);
LogicalResult printOperation(BootstrapOp op);
LogicalResult printOperation(DecryptOp op);
LogicalResult printOperation(EncryptOp op);
LogicalResult printOperation(GenParamsOp op);
LogicalResult printOperation(GenContextOp op);
LogicalResult printOperation(GenMulKeyOp op);
LogicalResult printOperation(GenRotKeyOp op);
LogicalResult printOperation(GenBootstrapKeyOp op);
LogicalResult printOperation(KeySwitchOp op);
LogicalResult printOperation(LevelReduceOp op);
LogicalResult printOperation(MakePackedPlaintextOp op);
Expand All @@ -85,6 +87,7 @@ class OpenFhePkeEmitter {
LogicalResult printOperation(NegateOp op);
LogicalResult printOperation(RelinOp op);
LogicalResult printOperation(RotOp op);
LogicalResult printOperation(SetupBootstrapOp op);
LogicalResult printOperation(SquareOp op);
LogicalResult printOperation(SubOp op);

Expand Down
19 changes: 19 additions & 0 deletions tests/Dialect/Openfhe/IR/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#ciphertext_space_L0_ = #lwe.ciphertext_space<ring = #ring_rns_L0_1_x1024_, encryption_type = lsb>

!pk = !openfhe.public_key
!sk = !openfhe.private_key
!ek = !openfhe.eval_key
!cc = !openfhe.crypto_context
!ct = !lwe.new_lwe_ciphertext<application_data = <message_type = i3>, plaintext_space = #plaintext_space, ciphertext_space = #ciphertext_space_L0_, key = #key, modulus_chain = #modulus_chain_L5_C0_>
Expand Down Expand Up @@ -150,4 +151,22 @@ module {
%out = openfhe.level_reduce %cc, %ct: (!cc, !ct) -> !ct
return
}

// CHECK-LABEL: func @test_bootstrap
func.func @test_bootstrap(%cc : !cc, %ct : !ct) {
%out = openfhe.bootstrap %cc, %ct: (!cc, !ct) -> !ct
return
}

// CHECK-LABEL: func @test_gen_bootstrap_key
func.func @test_gen_bootstrap_key(%cc : !cc, %sk : !sk) {
openfhe.gen_bootstrapkey %cc, %sk: (!cc, !sk) -> ()
return
}

// CHECK-LABEL: func @test_setup_bootstrap
func.func @test_setup_bootstrap(%cc : !cc) {
openfhe.setup_bootstrap %cc {levelBudgetEncode = 3, levelBudgetDecode = 3}: (!cc) -> ()
return
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// RUN: heir-opt --openfhe-configure-crypto-context=entry-function=bootstrap %s | FileCheck %s

!Z1095233372161_i64_ = !mod_arith.int<1095233372161 : i64>
!Z65537_i64_ = !mod_arith.int<65537 : i64>
#full_crt_packing_encoding = #lwe.full_crt_packing_encoding<scaling_factor = 0>
#key = #lwe.key<>
#modulus_chain_L5_C0_ = #lwe.modulus_chain<elements = <1095233372161 : i64, 1032955396097 : i64, 1005037682689 : i64, 998595133441 : i64, 972824936449 : i64, 959939837953 : i64>, current = 0>
!rns_L0_ = !rns.rns<!Z1095233372161_i64_>
#ring_Z65537_i64_1_x32_ = #polynomial.ring<coefficientType = !Z65537_i64_, polynomialModulus = <1 + x**32>>
#plaintext_space = #lwe.plaintext_space<ring = #ring_Z65537_i64_1_x32_, encoding = #full_crt_packing_encoding>
#ring_rns_L0_1_x32_ = #polynomial.ring<coefficientType = !rns_L0_, polynomialModulus = <1 + x**32>>
#ciphertext_space_L0_ = #lwe.ciphertext_space<ring = #ring_rns_L0_1_x32_, encryption_type = lsb>
!ct_L0_ = !lwe.new_lwe_ciphertext<application_data = <message_type = i16>, plaintext_space = #plaintext_space, ciphertext_space = #ciphertext_space_L0_, key = #key, modulus_chain = #modulus_chain_L5_C0_>


func.func @bootstrap(%arg0: !openfhe.crypto_context, %arg1: !ct_L0_) -> !ct_L0_ {
%0 = openfhe.bootstrap %arg0, %arg1 : (!openfhe.crypto_context, !ct_L0_) -> !ct_L0_
return %0 : !ct_L0_
}

// CHECK: @bootstrap
// CHECK: @bootstrap__generate_crypto_context
// CHECK: mulDepth = 20
// CHECK: openfhe.gen_context %{{.*}} {supportFHE = true}

// CHECK: @bootstrap__configure_crypto_context
// CHECK: openfhe.gen_mulkey
// CHECK: openfhe.setup_bootstrap %{{.*}} {levelBudgetDecode = 3 : index, levelBudgetEncode = 3 : index}
// CHECK: openfhe.gen_bootstrapkey
15 changes: 15 additions & 0 deletions tests/Examples/openfhe/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,18 @@ openfhe_end_to_end_test(
tags = ["notap"],
test_src = "halevi_shoup_matmul_test.cpp",
)

openfhe_end_to_end_test(
name = "simple_ckks_bootstrapping_test",
generated_lib_header = "simple_ckks_bootstrapping_lib.h",
heir_opt_flags = [
"--openfhe-configure-crypto-context=entry-function=simple_ckks_bootstrapping insecure=true",
],
heir_translate_flags = [
"--openfhe-scheme=ckks",
"--openfhe-include-type=source-relative",
],
mlir_src = "simple_ckks_bootstrapping.mlir",
tags = ["notap"],
test_src = "simple_ckks_bootstrapping_test.cpp",
)
Loading

0 comments on commit a3c5ce8

Please sign in to comment.