Skip to content

Commit

Permalink
openfhe: add bootstrap op
Browse files Browse the repository at this point in the history
  • Loading branch information
ZenithalHourlyRate committed Dec 16, 2024
1 parent 8c68164 commit 42d961e
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 8 deletions.
30 changes: 29 additions & 1 deletion lib/Dialect/Openfhe/IR/OpenfheOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ def GenParamsOp : Openfhe_Op<"gen_params"> {

def GenContextOp : Openfhe_Op<"gen_context"> {
let arguments = (ins
Openfhe_CCParams:$params
Openfhe_CCParams:$params,
OptionalAttr<BoolAttr>:$supportFHE
);
let results = (outs Openfhe_CryptoContext:$context);
}
Expand All @@ -77,6 +78,21 @@ def GenRotKeyOp : Openfhe_Op<"gen_rotkey"> {
);
}

def SetupBootstrapOp : Openfhe_Op<"setup_bootstrap"> {
let arguments = (ins
Openfhe_CryptoContext:$cryptoContext,
Builtin_IntegerAttr:$levelBudgetEncode,
Builtin_IntegerAttr:$levelBudgetDecode
);
}

def GenBootstrapKeyOp : Openfhe_Op<"gen_bootstrap_key"> {
let arguments = (ins
Openfhe_CryptoContext:$cryptoContext,
Openfhe_PrivateKey:$privateKey
);
}

def MakePackedPlaintextOp : Openfhe_Op<"make_packed_plaintext", [Pure]> {
let arguments = (ins
Openfhe_CryptoContext:$cryptoContext,
Expand Down Expand Up @@ -208,4 +224,16 @@ def KeySwitchOp : Openfhe_Op<"key_switch", [
let results = (outs NewLWECiphertext:$output);
}

def BootstrapOp : Openfhe_Op<"bootstrap", [
Pure,
AllTypesMatch<["ciphertext", "output"]>
]> {
let summary = "OpenFHE bootstrap operation of a ciphertext. (For CKKS)";
let arguments = (ins
Openfhe_CryptoContext:$cryptoContext,
NewLWECiphertext:$ciphertext
);
let results = (outs NewLWECiphertext:$output);
}

#endif // LIB_DIALECT_OPENFHE_IR_OPENFHEOPS_TD_
48 changes: 42 additions & 6 deletions lib/Dialect/Openfhe/Transforms/ConfigureCryptoContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,23 @@ SmallVector<int64_t> findAllRotIndices(func::FuncOp op) {
return rotIndicesResult;
}

// Helper function to check if the function has BootstrapOp
bool hasBootstrapOp(func::FuncOp op) {
bool result = false;
op.walk<WalkOrder::PreOrder>([&](Operation *op) {
if (isa<openfhe::BootstrapOp>(op)) {
result = true;
return WalkResult::interrupt();
}
return WalkResult::advance();
});
return result;
}

// function that generates the crypto context with proper parameters
LogicalResult generateGenFunc(func::FuncOp op, const std::string &genFuncName,
int64_t mulDepth, ImplicitLocOpBuilder &builder) {
int64_t mulDepth, bool hasBootstrapOp,
ImplicitLocOpBuilder &builder) {
Type openfheContextType =
openfhe::CryptoContextType::get(builder.getContext());
SmallVector<Type> funcArgTypes;
Expand All @@ -76,8 +90,9 @@ LogicalResult generateGenFunc(func::FuncOp op, const std::string &genFuncName,
Type openfheParamsType = openfhe::CCParamsType::get(builder.getContext());
Value ccParams = builder.create<openfhe::GenParamsOp>(openfheParamsType,
mulDepth, plainMod);
Value cryptoContext =
builder.create<openfhe::GenContextOp>(openfheContextType, ccParams);
Value cryptoContext = builder.create<openfhe::GenContextOp>(
openfheContextType, ccParams,
BoolAttr::get(builder.getContext(), hasBootstrapOp));

builder.create<func::ReturnOp>(cryptoContext);
return success();
Expand All @@ -87,6 +102,7 @@ LogicalResult generateGenFunc(func::FuncOp op, const std::string &genFuncName,
LogicalResult generateConfigFunc(func::FuncOp op,
const std::string &configFuncName,
bool hasMulOp, SmallVector<int64_t> rotIndices,
bool hasBootstrapOp,
ImplicitLocOpBuilder &builder) {
Type openfheContextType =
openfhe::CryptoContextType::get(builder.getContext());
Expand All @@ -108,12 +124,20 @@ LogicalResult generateConfigFunc(func::FuncOp op,
Value cryptoContext = configFuncOp.getArgument(0);
Value privateKey = configFuncOp.getArgument(1);

if (hasMulOp) {
if (hasMulOp || hasBootstrapOp) {
builder.create<openfhe::GenMulKeyOp>(cryptoContext, privateKey);
}
if (!rotIndices.empty()) {
builder.create<openfhe::GenRotKeyOp>(cryptoContext, privateKey, rotIndices);
}
if (hasBootstrapOp) {
// TODO: determine level budget otherwise
builder.create<openfhe::SetupBootstrapOp>(
cryptoContext,
IntegerAttr::get(IndexType::get(builder.getContext()), 3),
IntegerAttr::get(IndexType::get(builder.getContext()), 3));
builder.create<openfhe::GenBootstrapKeyOp>(cryptoContext, privateKey);
}

builder.create<func::ReturnOp>(cryptoContext);
return success();
Expand All @@ -132,7 +156,19 @@ LogicalResult convertFunc(func::FuncOp op, int64_t mulDepth) {
ImplicitLocOpBuilder builder =
ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody());

if (failed(generateGenFunc(op, genFuncName, mulDepth, builder))) {
bool hasBootstrapOpResult = hasBootstrapOp(op);
// TODO: determine this earlier, including mulDepth
// TODO: determine bootstrapDepth from levelBudget and approxModDepth
// levelBudgetEncode = 3
// approxModDepth = 14, this solely depends on secretKeyDist
// here we use the value for UNIFORM_TERNARY
// levelBudgetDecode = 3
int bootstrapDepth = 3 + 14 + 3;
if (hasBootstrapOpResult) {
mulDepth += bootstrapDepth;
}
if (failed(generateGenFunc(op, genFuncName, mulDepth, hasBootstrapOpResult,
builder))) {
return failure();
}

Expand All @@ -141,7 +177,7 @@ LogicalResult convertFunc(func::FuncOp op, int64_t mulDepth) {
bool hasMulOpResult = hasMulOp(op);
SmallVector<int64_t> rotIndices = findAllRotIndices(op);
if (failed(generateConfigFunc(op, configFuncName, hasMulOpResult, rotIndices,
builder))) {
hasBootstrapOpResult, builder))) {
return failure();
}
return success();
Expand Down
32 changes: 31 additions & 1 deletion lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ LogicalResult OpenFhePkeEmitter::translate(Operation &op) {
SquareOp, NegateOp, MulConstOp, RelinOp, ModReduceOp,
LevelReduceOp, RotOp, AutomorphOp, KeySwitchOp, EncryptOp,
DecryptOp, GenParamsOp, GenContextOp, GenMulKeyOp, GenRotKeyOp,
MakePackedPlaintextOp, MakeCKKSPackedPlaintextOp>(
GenBootstrapKeyOp, MakePackedPlaintextOp,
MakeCKKSPackedPlaintextOp, SetupBootstrapOp, BootstrapOp>(
[&](auto op) { return printOperation(op); })
.Default([&](Operation &) {
return emitError(op.getLoc(), "unable to find printer for op");
Expand Down Expand Up @@ -317,6 +318,11 @@ LogicalResult OpenFhePkeEmitter::printOperation(KeySwitchOp op) {
{op.getCiphertext(), op.getEvalKey()}, "KeySwitch");
}

LogicalResult OpenFhePkeEmitter::printOperation(BootstrapOp op) {
return printEvalMethod(op.getResult(), op.getCryptoContext(),
{op.getCiphertext()}, "EvalBootstrap");
}

LogicalResult OpenFhePkeEmitter::printOperation(arith::ConstantOp op) {
auto valueAttr = op.getValue();
if (auto intAttr = dyn_cast<IntegerAttr>(valueAttr)) {
Expand Down Expand Up @@ -691,6 +697,10 @@ LogicalResult OpenFhePkeEmitter::printOperation(GenContextOp op) {
os << contextName << "->Enable(PKE);\n";
os << contextName << "->Enable(KEYSWITCH);\n";
os << contextName << "->Enable(LEVELEDSHE);\n";
if (op.getSupportFHE().has_value() && op.getSupportFHE().value()) {
os << contextName << "->Enable(ADVANCEDSHE);\n";
os << contextName << "->Enable(FHE);\n";
}
return success();
}

Expand All @@ -715,6 +725,26 @@ LogicalResult OpenFhePkeEmitter::printOperation(GenRotKeyOp op) {
return success();
}

LogicalResult OpenFhePkeEmitter::printOperation(GenBootstrapKeyOp op) {
auto contextName = variableNames->getNameForValue(op.getCryptoContext());
auto privateKeyName = variableNames->getNameForValue(op.getPrivateKey());
// compiler can not determine slot num for now
// full packing for CKKS, as we currently always full packing
os << "auto numSlots = " << contextName << "->GetRingDimension() / 2;\n";
os << contextName << "->EvalBootstrapKeyGen(" << privateKeyName
<< ", numSlots);\n";
return success();
}

LogicalResult OpenFhePkeEmitter::printOperation(SetupBootstrapOp op) {
auto contextName = variableNames->getNameForValue(op.getCryptoContext());
os << contextName << "->EvalBootstrapSetup({";
os << op.getLevelBudgetEncode().getValue() << ", ";
os << op.getLevelBudgetDecode().getValue();
os << "});\n";
return success();
}

LogicalResult OpenFhePkeEmitter::emitType(Type type, Location loc) {
auto result = convertType(type, loc);
if (failed(result)) {
Expand Down
3 changes: 3 additions & 0 deletions lib/Target/OpenFhePke/OpenFhePkeEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,14 @@ class OpenFhePkeEmitter {
LogicalResult printOperation(AddOp op);
LogicalResult printOperation(AddPlainOp op);
LogicalResult printOperation(AutomorphOp op);
LogicalResult printOperation(BootstrapOp op);
LogicalResult printOperation(DecryptOp op);
LogicalResult printOperation(EncryptOp op);
LogicalResult printOperation(GenParamsOp op);
LogicalResult printOperation(GenContextOp op);
LogicalResult printOperation(GenMulKeyOp op);
LogicalResult printOperation(GenRotKeyOp op);
LogicalResult printOperation(GenBootstrapKeyOp op);
LogicalResult printOperation(KeySwitchOp op);
LogicalResult printOperation(LevelReduceOp op);
LogicalResult printOperation(MakePackedPlaintextOp op);
Expand All @@ -85,6 +87,7 @@ class OpenFhePkeEmitter {
LogicalResult printOperation(NegateOp op);
LogicalResult printOperation(RelinOp op);
LogicalResult printOperation(RotOp op);
LogicalResult printOperation(SetupBootstrapOp op);
LogicalResult printOperation(SquareOp op);
LogicalResult printOperation(SubOp op);

Expand Down

0 comments on commit 42d961e

Please sign in to comment.