diff --git a/lib/Dialect/Openfhe/IR/OpenfheOps.td b/lib/Dialect/Openfhe/IR/OpenfheOps.td index 5cdaf23c6..5000a2132 100644 --- a/lib/Dialect/Openfhe/IR/OpenfheOps.td +++ b/lib/Dialect/Openfhe/IR/OpenfheOps.td @@ -57,7 +57,8 @@ def GenParamsOp : Openfhe_Op<"gen_params"> { def GenContextOp : Openfhe_Op<"gen_context"> { let arguments = (ins - Openfhe_CCParams:$params + Openfhe_CCParams:$params, + OptionalAttr:$supportFHE ); let results = (outs Openfhe_CryptoContext:$context); } @@ -77,6 +78,21 @@ def GenRotKeyOp : Openfhe_Op<"gen_rotkey"> { ); } +def SetupBootstrapOp : Openfhe_Op<"setup_bootstrap"> { + let arguments = (ins + Openfhe_CryptoContext:$cryptoContext, + Builtin_IntegerAttr:$levelBudgetEncode, + Builtin_IntegerAttr:$levelBudgetDecode + ); +} + +def GenBootstrapKeyOp : Openfhe_Op<"gen_bootstrap_key"> { + let arguments = (ins + Openfhe_CryptoContext:$cryptoContext, + Openfhe_PrivateKey:$privateKey + ); +} + def MakePackedPlaintextOp : Openfhe_Op<"make_packed_plaintext", [Pure]> { let arguments = (ins Openfhe_CryptoContext:$cryptoContext, @@ -208,4 +224,16 @@ def KeySwitchOp : Openfhe_Op<"key_switch", [ let results = (outs NewLWECiphertext:$output); } +def BootstrapOp : Openfhe_Op<"bootstrap", [ + Pure, + AllTypesMatch<["ciphertext", "output"]> +]> { + let summary = "OpenFHE bootstrap operation of a ciphertext. (For CKKS)"; + let arguments = (ins + Openfhe_CryptoContext:$cryptoContext, + NewLWECiphertext:$ciphertext + ); + let results = (outs NewLWECiphertext:$output); +} + #endif // LIB_DIALECT_OPENFHE_IR_OPENFHEOPS_TD_ diff --git a/lib/Dialect/Openfhe/Transforms/ConfigureCryptoContext.cpp b/lib/Dialect/Openfhe/Transforms/ConfigureCryptoContext.cpp index 3f77f6417..9789aa8b2 100644 --- a/lib/Dialect/Openfhe/Transforms/ConfigureCryptoContext.cpp +++ b/lib/Dialect/Openfhe/Transforms/ConfigureCryptoContext.cpp @@ -57,9 +57,23 @@ SmallVector findAllRotIndices(func::FuncOp op) { return rotIndicesResult; } +// Helper function to check if the function has BootstrapOp +bool hasBootstrapOp(func::FuncOp op) { + bool result = false; + op.walk([&](Operation *op) { + if (isa(op)) { + result = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return result; +} + // function that generates the crypto context with proper parameters LogicalResult generateGenFunc(func::FuncOp op, const std::string &genFuncName, - int64_t mulDepth, ImplicitLocOpBuilder &builder) { + int64_t mulDepth, bool hasBootstrapOp, + ImplicitLocOpBuilder &builder) { Type openfheContextType = openfhe::CryptoContextType::get(builder.getContext()); SmallVector funcArgTypes; @@ -76,8 +90,9 @@ LogicalResult generateGenFunc(func::FuncOp op, const std::string &genFuncName, Type openfheParamsType = openfhe::CCParamsType::get(builder.getContext()); Value ccParams = builder.create(openfheParamsType, mulDepth, plainMod); - Value cryptoContext = - builder.create(openfheContextType, ccParams); + Value cryptoContext = builder.create( + openfheContextType, ccParams, + BoolAttr::get(builder.getContext(), hasBootstrapOp)); builder.create(cryptoContext); return success(); @@ -87,6 +102,7 @@ LogicalResult generateGenFunc(func::FuncOp op, const std::string &genFuncName, LogicalResult generateConfigFunc(func::FuncOp op, const std::string &configFuncName, bool hasMulOp, SmallVector rotIndices, + bool hasBootstrapOp, ImplicitLocOpBuilder &builder) { Type openfheContextType = openfhe::CryptoContextType::get(builder.getContext()); @@ -108,12 +124,20 @@ LogicalResult generateConfigFunc(func::FuncOp op, Value cryptoContext = configFuncOp.getArgument(0); Value privateKey = configFuncOp.getArgument(1); - if (hasMulOp) { + if (hasMulOp || hasBootstrapOp) { builder.create(cryptoContext, privateKey); } if (!rotIndices.empty()) { builder.create(cryptoContext, privateKey, rotIndices); } + if (hasBootstrapOp) { + // TODO: determine level budget otherwise + builder.create( + cryptoContext, + IntegerAttr::get(IndexType::get(builder.getContext()), 3), + IntegerAttr::get(IndexType::get(builder.getContext()), 3)); + builder.create(cryptoContext, privateKey); + } builder.create(cryptoContext); return success(); @@ -132,7 +156,19 @@ LogicalResult convertFunc(func::FuncOp op, int64_t mulDepth) { ImplicitLocOpBuilder builder = ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody()); - if (failed(generateGenFunc(op, genFuncName, mulDepth, builder))) { + bool hasBootstrapOpResult = hasBootstrapOp(op); + // TODO: determine this earlier, including mulDepth + // TODO: determine bootstrapDepth from levelBudget and approxModDepth + // levelBudgetEncode = 3 + // approxModDepth = 14, this solely depends on secretKeyDist + // here we use the value for UNIFORM_TERNARY + // levelBudgetDecode = 3 + int bootstrapDepth = 3 + 14 + 3; + if (hasBootstrapOpResult) { + mulDepth += bootstrapDepth; + } + if (failed(generateGenFunc(op, genFuncName, mulDepth, hasBootstrapOpResult, + builder))) { return failure(); } @@ -141,7 +177,7 @@ LogicalResult convertFunc(func::FuncOp op, int64_t mulDepth) { bool hasMulOpResult = hasMulOp(op); SmallVector rotIndices = findAllRotIndices(op); if (failed(generateConfigFunc(op, configFuncName, hasMulOpResult, rotIndices, - builder))) { + hasBootstrapOpResult, builder))) { return failure(); } return success(); diff --git a/lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp b/lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp index 820d19921..0a4accc07 100644 --- a/lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp +++ b/lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp @@ -96,7 +96,8 @@ LogicalResult OpenFhePkeEmitter::translate(Operation &op) { SquareOp, NegateOp, MulConstOp, RelinOp, ModReduceOp, LevelReduceOp, RotOp, AutomorphOp, KeySwitchOp, EncryptOp, DecryptOp, GenParamsOp, GenContextOp, GenMulKeyOp, GenRotKeyOp, - MakePackedPlaintextOp, MakeCKKSPackedPlaintextOp>( + GenBootstrapKeyOp, MakePackedPlaintextOp, + MakeCKKSPackedPlaintextOp, SetupBootstrapOp, BootstrapOp>( [&](auto op) { return printOperation(op); }) .Default([&](Operation &) { return emitError(op.getLoc(), "unable to find printer for op"); @@ -317,6 +318,11 @@ LogicalResult OpenFhePkeEmitter::printOperation(KeySwitchOp op) { {op.getCiphertext(), op.getEvalKey()}, "KeySwitch"); } +LogicalResult OpenFhePkeEmitter::printOperation(BootstrapOp op) { + return printEvalMethod(op.getResult(), op.getCryptoContext(), + {op.getCiphertext()}, "EvalBootstrap"); +} + LogicalResult OpenFhePkeEmitter::printOperation(arith::ConstantOp op) { auto valueAttr = op.getValue(); if (auto intAttr = dyn_cast(valueAttr)) { @@ -691,6 +697,10 @@ LogicalResult OpenFhePkeEmitter::printOperation(GenContextOp op) { os << contextName << "->Enable(PKE);\n"; os << contextName << "->Enable(KEYSWITCH);\n"; os << contextName << "->Enable(LEVELEDSHE);\n"; + if (op.getSupportFHE().has_value() && op.getSupportFHE().value()) { + os << contextName << "->Enable(ADVANCEDSHE);\n"; + os << contextName << "->Enable(FHE);\n"; + } return success(); } @@ -715,6 +725,26 @@ LogicalResult OpenFhePkeEmitter::printOperation(GenRotKeyOp op) { return success(); } +LogicalResult OpenFhePkeEmitter::printOperation(GenBootstrapKeyOp op) { + auto contextName = variableNames->getNameForValue(op.getCryptoContext()); + auto privateKeyName = variableNames->getNameForValue(op.getPrivateKey()); + // compiler can not determine slot num for now + // full packing for CKKS, as we currently always full packing + os << "auto numSlots = " << contextName << "->GetRingDimension() / 2;\n"; + os << contextName << "->EvalBootstrapKeyGen(" << privateKeyName + << ", numSlots);\n"; + return success(); +} + +LogicalResult OpenFhePkeEmitter::printOperation(SetupBootstrapOp op) { + auto contextName = variableNames->getNameForValue(op.getCryptoContext()); + os << contextName << "->EvalBootstrapSetup({"; + os << op.getLevelBudgetEncode().getValue() << ", "; + os << op.getLevelBudgetDecode().getValue(); + os << "});\n"; + return success(); +} + LogicalResult OpenFhePkeEmitter::emitType(Type type, Location loc) { auto result = convertType(type, loc); if (failed(result)) { diff --git a/lib/Target/OpenFhePke/OpenFhePkeEmitter.h b/lib/Target/OpenFhePke/OpenFhePkeEmitter.h index c1fa0adcf..9a7f8ce29 100644 --- a/lib/Target/OpenFhePke/OpenFhePkeEmitter.h +++ b/lib/Target/OpenFhePke/OpenFhePkeEmitter.h @@ -67,12 +67,14 @@ class OpenFhePkeEmitter { LogicalResult printOperation(AddOp op); LogicalResult printOperation(AddPlainOp op); LogicalResult printOperation(AutomorphOp op); + LogicalResult printOperation(BootstrapOp op); LogicalResult printOperation(DecryptOp op); LogicalResult printOperation(EncryptOp op); LogicalResult printOperation(GenParamsOp op); LogicalResult printOperation(GenContextOp op); LogicalResult printOperation(GenMulKeyOp op); LogicalResult printOperation(GenRotKeyOp op); + LogicalResult printOperation(GenBootstrapKeyOp op); LogicalResult printOperation(KeySwitchOp op); LogicalResult printOperation(LevelReduceOp op); LogicalResult printOperation(MakePackedPlaintextOp op); @@ -85,6 +87,7 @@ class OpenFhePkeEmitter { LogicalResult printOperation(NegateOp op); LogicalResult printOperation(RelinOp op); LogicalResult printOperation(RotOp op); + LogicalResult printOperation(SetupBootstrapOp op); LogicalResult printOperation(SquareOp op); LogicalResult printOperation(SubOp op);