From 670b95e6a05e182d01d8b3c5c050cffb25135520 Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Mon, 22 Jan 2024 10:42:14 +0100 Subject: [PATCH] refactor(compiler): Re-implement TFHE multi-parameter parametrization with type inference The current pass applying the parameters determined by the optimizer to the IR propagates the parametrized TFHE types to operations not directly tagged with an optimizer ID only under certain conditions. In particular, it does not always properly propagate types into nested regions (e.g., of `scf.for` loops). This burdens preceding transformations that are applied in between the invocation of the optimizer and the parametrization pass with data-flow analysis and book-keeping in order to tag newly inserted operations with the right optimizer IDs that ensure proper parametrization. This commit replaces the current parametrization pass with a new pass that propagates parametrized TFHE types up and down def-use chains using type inference and a proper rewriter. The pass is limited to the operations supported by `TFHEParametrizationTypeResolver::resolve`. --- .../Dialect/TFHE/Transforms/Transforms.h | 4 +- .../Dialect/TFHE/Transforms/Transforms.td | 2 +- .../TFHECircuitSolutionParametrization.cpp | 1297 +++++++++++------ .../compiler/lib/Support/CompilerEngine.cpp | 2 + .../compiler/lib/Support/Pipeline.cpp | 23 +- 5 files changed, 849 insertions(+), 479 deletions(-) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/Transforms/Transforms.h b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/Transforms/Transforms.h index d53adfb4f1..b5d2c410fc 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/Transforms/Transforms.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/Transforms/Transforms.h @@ -16,9 +16,9 @@ namespace mlir { namespace concretelang { std::unique_ptr> createTFHEOptimizationPass(); -std::unique_ptr> +std::unique_ptr> createTFHECircuitSolutionParametrizationPass( - concrete_optimizer::dag::CircuitSolution); + std::optional); } // namespace concretelang } // namespace mlir diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/Transforms/Transforms.td b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/Transforms/Transforms.td index 2609e1cced..a0af3edda0 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/Transforms/Transforms.td +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/Transforms/Transforms.td @@ -10,7 +10,7 @@ def TFHEOptimization : Pass<"tfhe-optimization"> { let dependentDialects = [ "mlir::concretelang::TFHE::TFHEDialect" ]; } -def TFHECircuitSolutionParametrization : Pass<"tfhe-circuit-solution-parametrization"> { +def TFHECircuitSolutionParametrization : Pass<"tfhe-circuit-solution-parametrization", "mlir::ModuleOp"> { let summary = "Parametrize TFHE with a circuit solution given by the optimizer"; let constructor = "mlir::concretelang::createTFHECircuitSolutionParametrizationPass()"; let options = []; diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Transforms/TFHECircuitSolutionParametrization.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Transforms/TFHECircuitSolutionParametrization.cpp index fcf7a10475..797d75a1ef 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Transforms/TFHECircuitSolutionParametrization.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Transforms/TFHECircuitSolutionParametrization.cpp @@ -3,542 +3,905 @@ // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. -#include "llvm/Support/Debug.h" - -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -#include "concretelang/Dialect/TFHE/IR/TFHEOps.h" -#include "concretelang/Dialect/TFHE/Transforms/Transforms.h" -#include "concretelang/Support/Constants.h" -#include "concretelang/Support/logging.h" +#include "concrete-optimizer.hpp" +#include "concretelang/Dialect/TFHE/IR/TFHEParameters.h" +#include "concretelang/Dialect/TypeInference/IR/TypeInferenceOps.h" + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include namespace mlir { namespace concretelang { - namespace { +// Return the element type of `t` if `t` is a tensor or memref type +// or `t` itself +template static std::optional tryGetScalarType(mlir::Type t) { + if (T ctt = t.dyn_cast()) + return ctt; + + if (mlir::RankedTensorType rtt = t.dyn_cast()) + return tryGetScalarType(rtt.getElementType()); + else if (mlir::MemRefType mrt = t.dyn_cast()) + return tryGetScalarType(mrt.getElementType()); + + return std::nullopt; +} + +// Wraps a `::concrete_optimizer::dag::CircuitSolution` and provides +// helper functions for lookups and code generation +class CircuitSolutionWrapper { +public: + CircuitSolutionWrapper( + const ::concrete_optimizer::dag::CircuitSolution &solution) + : solution(solution) {} + + enum class SolutionKeyKind { + OPERAND, + RESULT, + KSK_IN, + KSK_OUT, + CKSK_IN, + CKSK_OUT, + BSK_IN, + BSK_OUT + }; + + // Returns the `GLWESecretKey` type for a secrete key + TFHE::GLWESecretKey + toGLWESecretKey(const ::concrete_optimizer::dag::SecretLweKey &key) const { + return TFHE::GLWESecretKey::newParameterized( + key.glwe_dimension * key.polynomial_size, 1, key.identifier); + } -#define DEBUG(MSG) \ - if (llvm::DebugFlag) \ - llvm::errs() << MSG << "\n"; + // Looks up the keys associated to an operation with a given `oid` + const ::concrete_optimizer::dag::InstructionKeys & + lookupInstructionKeys(int64_t oid) const { + assert(oid <= (int64_t)solution.instructions_keys.size() && + "Invalid optimizer ID"); -#define VERBOSE(MSG) \ - if (mlir::concretelang::isVerbose()) { \ - llvm::errs() << MSG << "\n"; \ + return solution.instructions_keys[oid]; } -namespace TFHE = mlir::concretelang::TFHE; + // Returns a `GLWEKeyswitchKeyAttr` for a given keyswitch key + // (either of type `KeySwitchKey` or `ConversionKeySwitchKey`) + template + TFHE::GLWEKeyswitchKeyAttr getKeyswitchKeyAttr(mlir::MLIRContext *ctx, + const KeyT &ksk) const { + return TFHE::GLWEKeyswitchKeyAttr::get( + ctx, toGLWESecretKey(ksk.input_key), toGLWESecretKey(ksk.output_key), + ksk.ks_decomposition_parameter.level, + ksk.ks_decomposition_parameter.log2_base, -1); + } -/// Optimization pass that should choose more efficient ways of performing -/// crypto operations. -class TFHECircuitSolutionParametrizationPass - : public TFHECircuitSolutionParametrizationBase< - TFHECircuitSolutionParametrizationPass> { -public: - TFHECircuitSolutionParametrizationPass( - concrete_optimizer::dag::CircuitSolution solution) - : solution(solution){}; + // Returns a `GLWEKeyswitchKeyAttr` for the keyswitch key of an + // operation tagged with a given `oid` + TFHE::GLWEKeyswitchKeyAttr getKeyswitchKeyAttr(mlir::MLIRContext *ctx, + int64_t oid) const { + const ::concrete_optimizer::dag::KeySwitchKey &ksk = + lookupKeyswitchKey(oid); - void runOnOperation() override { - mlir::Operation *op = getOperation(); - op->walk([&](mlir::func::FuncOp func) { - DEBUG("apply solution: \n" << solution.dump().c_str()); - DEBUG("process func: " << func); - // Process function arguments, change type of arguments according of the - // optimizer identifier stored in the "TFHE.OId" attribute. - for (size_t i = 0; i < func.getNumArguments(); i++) { - auto arg = func.getArgument(i); - auto attr = func.getArgAttrOfType(i, "TFHE.OId"); - if (attr != nullptr) { - DEBUG("process arg = " << arg) - arg.setType(getParametrizedType(arg.getType(), attr)); - } else { - DEBUG("skip arg " << arg) - } - } - // Process operations, apply the instructions keys according of the - // optimizer identifier stored in the "TFHE.OId" - VERBOSE("\n### BEFORE Apply instruction keys " << func); - applyInstructionKeys(func); - // The keyswitch operator is an internal node of the optimizer tlu node, - // so it don't follow the same rule than the other operator on the type of - // outputs - VERBOSE("\n### BEFORE Fixup keyswitch \n" << func); - fixupKeyswitchOuputs(func); - // Propagate types on non parametrized operators - VERBOSE("\n### BEFORE Fixup non parametrized ops \n" << func); - fixupNonParametrizedOps(func); - // Fixup incompatible operators with extra conversion keys - VERBOSE("\n### BEFORE Fixup with extra conversion keys \n" << func); - fixupIncompatibleLeveledOpWithExtraConversionKeys(func); - // Fixup the function signature - VERBOSE("\n### BEFORE Fixup function signature \n" << func); - fixupFunctionSignature(func); - // Remove optimizer identifiers - VERBOSE("\n### BEFORE Remove optimizer identifiers \n" << func); - removeOptimizerIdentifiers(func); - }); + return getKeyswitchKeyAttr(ctx, ksk); } - static mlir::Type getParametrizedType(mlir::Type originalType, - TFHE::GLWECipherTextType newGlwe) { - if (auto oldGlwe = originalType.dyn_cast(); - oldGlwe != nullptr) { - assert(oldGlwe.getKey().isNone()); - return newGlwe; - } else if (auto oldTensor = originalType.dyn_cast(); - oldTensor != nullptr) { - auto oldGlwe = - oldTensor.getElementType().dyn_cast(); - assert(oldGlwe != nullptr); - assert(oldGlwe.getKey().isNone()); - return mlir::RankedTensorType::get(oldTensor.getShape(), newGlwe); - } - assert(false); + // Returns a `GLWEBootstrapKeyAttr` for the bootstrap key of an + // operation tagged with a given `oid` + TFHE::GLWEBootstrapKeyAttr getBootstrapKeyAttr(mlir::MLIRContext *ctx, + int64_t oid) const { + const ::concrete_optimizer::dag::BootstrapKey &bsk = + lookupBootstrapKey(oid); + + return TFHE::GLWEBootstrapKeyAttr::get( + ctx, toGLWESecretKey(bsk.input_key), toGLWESecretKey(bsk.output_key), + bsk.output_key.polynomial_size, bsk.output_key.glwe_dimension, + bsk.br_decomposition_parameter.level, + bsk.br_decomposition_parameter.log2_base, -1); } - mlir::Type getParametrizedType(mlir::Type originalType, - mlir::IntegerAttr optimizerAttrID) { - auto context = originalType.getContext(); - auto newGlwe = - getOutputLWECipherTextType(context, optimizerAttrID.getInt()); - return getParametrizedType(originalType, newGlwe); + // Looks up the keyswitch key for an operation tagged with a given + // `oid` + const ::concrete_optimizer::dag::KeySwitchKey & + lookupKeyswitchKey(int64_t oid) const { + uint64_t keyID = lookupInstructionKeys(oid).tlu_keyswitch_key; + return solution.circuit_keys.keyswitch_keys[keyID]; } - static TFHE::GLWECipherTextType getGlweTypeFromType(mlir::Type type) { - if (auto glwe = type.dyn_cast(); - glwe != nullptr) { - return glwe; - } else if (auto tensor = type.dyn_cast(); - tensor != nullptr) { - auto glwe = tensor.getElementType().dyn_cast(); - if (glwe == nullptr) { - return nullptr; - } - return glwe; - } - return nullptr; + // Looks up the bootstrap key for an operation tagged with a given + // `oid` + const ::concrete_optimizer::dag::BootstrapKey & + lookupBootstrapKey(int64_t oid) const { + uint64_t keyID = lookupInstructionKeys(oid).tlu_bootstrap_key; + return solution.circuit_keys.bootstrap_keys[keyID]; } - // Return the - static TFHE::GLWECipherTextType - getParametrizedGlweTypeFromType(mlir::Type type) { - auto glwe = getGlweTypeFromType(type); - if (glwe != nullptr && glwe.getKey().isParameterized()) { - return glwe; - } - return nullptr; + // Looks up the conversion keyswitch key for an operation tagged + // with a given `oid` + const ::concrete_optimizer::dag::ConversionKeySwitchKey & + lookupConversionKeyswitchKey(uint64_t oid) const { + uint64_t keyID = lookupInstructionKeys(oid).extra_conversion_keys[0]; + return solution.circuit_keys.conversion_keyswitch_keys[keyID]; } - // Returns true if the type is or contains a glwe type with a none key. - static bool isNoneGlweType(mlir::Type type) { - auto glwe = getGlweTypeFromType(type); - return glwe != nullptr && glwe.getKey().isNone(); + // Looks up the conversion keyswitch key for the conversion of the + // key with the ID `fromKeyID` to the key with the ID `toKeyID`. The + // key must exist, otherwise an assertion is triggered. + const ::concrete_optimizer::dag::ConversionKeySwitchKey & + lookupConversionKeyswitchKey(uint64_t fromKeyID, uint64_t toKeyID) const { + auto convKSKIt = std::find_if( + solution.circuit_keys.conversion_keyswitch_keys.cbegin(), + solution.circuit_keys.conversion_keyswitch_keys.cend(), + [&](const ::concrete_optimizer::dag::ConversionKeySwitchKey &arg) { + return arg.input_key.identifier == fromKeyID && + arg.output_key.identifier == toKeyID; + }); + + assert(convKSKIt != + solution.circuit_keys.conversion_keyswitch_keys.cend() && + "Required conversion key must be available"); + + return *convKSKIt; } - void applyInstructionKeys(mlir::func::FuncOp func) { - auto context = func.getContext(); - func.walk([&](mlir::Operation *op) { - auto attrOptimizerID = op->getAttrOfType("TFHE.OId"); - // Skip operation is no optimizer identifier - if (attrOptimizerID == nullptr) { - DEBUG("skip operation: " << op->getName()) - return; - } - DEBUG("process operation: " << *op); - auto optimizerID = attrOptimizerID.getInt(); - // Change the output type of the operation - for (auto result : op->getResults()) { - result.setType(getParametrizedType(result.getType(), attrOptimizerID)); - } - // Set the keyswitch_key attribute - // TODO: Change ambiguous attribute name - auto attrKeyswitchKey = - op->getAttrOfType("key"); - if (attrKeyswitchKey == nullptr) { - DEBUG("no keyswitch key"); - } else { - op->setAttr("key", getKeyswitchKeyAttr(context, optimizerID)); - } - // Set boostrap_key attribute - // TODO: Change ambiguous attribute name - auto attrBootstrapKey = - op->getAttrOfType("key"); - if (attrBootstrapKey == nullptr) { - DEBUG("no bootstrap key"); - } else { - op->setAttr("key", getBootstrapKeyAttr(context, optimizerID)); - // FIXME: For now we know that if there are an extra conversion key - // this result will only be used in another partition. This is a - // STRONG assumptions of how the optimization work, this is done like - // that to avoid a bug in type propagation, but the extra conversion - // key should be added at the use and not here. - auto instKeys = getInstructionKey(optimizerID); - if (instKeys.extra_conversion_keys.size() != 0) { - assert(instKeys.extra_conversion_keys.size() == 1); - auto convKSK = - solution.circuit_keys - .conversion_keyswitch_keys[instKeys.extra_conversion_keys[0]]; - auto convKSKAttr = getExtraConversionKeyAttr(context, convKSK); - mlir::IRRewriter rewriter(context); - rewriter.setInsertionPointAfter(op); - auto outputKey = toLWESecretKey(convKSK.output_key); - auto resType = TFHE::GLWECipherTextType::get(context, outputKey); - auto extraKSK = rewriter.create( - op->getLoc(), resType, op->getResult(0), convKSKAttr); - rewriter.replaceAllUsesExcept(op->getResult(0), extraKSK, extraKSK); - } - } - }); + // Looks up the secret key of type `kind` for an instruction tagged + // with the optimizer id `oid` + const ::concrete_optimizer::dag::SecretLweKey & + lookupSecretKey(int64_t oid, SolutionKeyKind kind) const { + uint64_t keyID; + + switch (kind) { + case SolutionKeyKind::OPERAND: + keyID = lookupInstructionKeys(oid).input_key; + return solution.circuit_keys.secret_keys[keyID]; + case SolutionKeyKind::RESULT: + keyID = lookupInstructionKeys(oid).output_key; + return solution.circuit_keys.secret_keys[keyID]; + case SolutionKeyKind::KSK_IN: + return lookupKeyswitchKey(oid).input_key; + case SolutionKeyKind::KSK_OUT: + return lookupKeyswitchKey(oid).output_key; + case SolutionKeyKind::CKSK_IN: + return lookupConversionKeyswitchKey(oid).input_key; + case SolutionKeyKind::CKSK_OUT: + return lookupConversionKeyswitchKey(oid).output_key; + case SolutionKeyKind::BSK_IN: + return lookupBootstrapKey(oid).input_key; + case SolutionKeyKind::BSK_OUT: + return lookupBootstrapKey(oid).output_key; + } + + llvm_unreachable("Unknown key kind"); } - void fixupKeyswitchOuputs(mlir::func::FuncOp func) { - auto context = func.getContext(); - func.walk([&](TFHE::KeySwitchGLWEOp op) { - DEBUG("process op: " << op) - auto attrKeyswitchKey = - op->getAttrOfType("key"); - assert(attrKeyswitchKey != nullptr); - auto outputKey = attrKeyswitchKey.getOutputKey(); - outputKey = GLWESecretKeyAsLWE(outputKey); - op.getResult().setType(TFHE::GLWECipherTextType::get(context, outputKey)); - DEBUG("fixed op: " << op) - }); - // Fixup input of the boostrap operator - DEBUG("### Fixup input tlu of bootstrap") - func.walk([&](TFHE::BootstrapGLWEOp op) { - DEBUG("process op: " << op) - auto attrBootstrapKey = - op->getAttrOfType("key"); - assert(attrBootstrapKey != nullptr); - auto polySize = attrBootstrapKey.getPolySize(); - auto lutDefiningOp = op.getLookupTable().getDefiningOp(); - // Dirty fixup of the lookup table as we known the operators that can - // define it - // TODO: Do something more robust, using the GLWE type? - mlir::Builder builder(op->getContext()); - assert(lutDefiningOp != nullptr); - if (auto encodeOp = mlir::dyn_cast( - lutDefiningOp); - encodeOp != nullptr) { - encodeOp.setPolySize(polySize); - } else if (auto constantOp = - mlir::dyn_cast(lutDefiningOp)) { - // Rounded PBS case - auto denseAttr = - constantOp.getValueAttr().dyn_cast(); - auto val = denseAttr.getValues()[0]; - std::vector lut(polySize, val); - constantOp.setValueAttr(mlir::DenseIntElementsAttr::get( - mlir::RankedTensorType::get(lut.size(), builder.getIntegerType(64)), - lut)); - } - op.getLookupTable().setType(mlir::RankedTensorType::get( - mlir::ArrayRef(polySize), builder.getI64Type())); - // Also fixup the bootstrap key as the TFHENormalization rely on - // GLWESecretKey structure and not on identifier - // TODO: FIXME - auto outputKey = attrBootstrapKey.getOutputKey().getParameterized(); - auto newOutputKey = TFHE::GLWESecretKey::newParameterized( - outputKey->polySize * outputKey->dimension, 1, outputKey->identifier); - auto newAttrBootstrapKey = TFHE::GLWEBootstrapKeyAttr::get( - context, attrBootstrapKey.getInputKey(), newOutputKey, - attrBootstrapKey.getPolySize(), attrBootstrapKey.getGlweDim(), - attrBootstrapKey.getLevels(), attrBootstrapKey.getBaseLog(), -1); - op.setKeyAttr(newAttrBootstrapKey); - }); + TFHE::GLWECipherTextType + getTFHETypeForKey(mlir::MLIRContext *ctx, + const ::concrete_optimizer::dag::SecretLweKey &key) const { + return TFHE::GLWECipherTextType::get(ctx, toGLWESecretKey(key)); } - static void - fixupNonParametrizedOp(mlir::Operation *op, - TFHE::GLWECipherTextType parametrizedGlweType) { - DEBUG(" START Fixup {" << *op) - for (auto result : op->getResults()) { - if (isNoneGlweType(result.getType())) { - DEBUG(" -> Fixing result " << result) - result.setType( - getParametrizedType(result.getType(), parametrizedGlweType)); - DEBUG(" -> Fixed result " << result) - // Recurse on all users of the fixed result - for (auto user : result.getUsers()) { - DEBUG(" -> Propagate on user " << *user) - fixupNonParametrizedOp(user, parametrizedGlweType); - } - } - } - // Recursively fixup producer of op operands - mlir::Block *parentBlock = nullptr; - for (auto operand : op->getOperands()) { - if (isNoneGlweType(operand.getType())) { - DEBUG(" -> Propagate on operand " << operand.getType()) - if (auto opResult = operand.dyn_cast(); - opResult != nullptr) { - fixupNonParametrizedOp(opResult.getOwner(), parametrizedGlweType); - continue; - } - if (auto blockArg = operand.dyn_cast(); - blockArg != nullptr) { - DEBUG(" -> Fixing block arg " << blockArg) - blockArg.setType( - getParametrizedType(blockArg.getType(), parametrizedGlweType)); - for (auto users : blockArg.getUsers()) { - fixupNonParametrizedOp(users, parametrizedGlweType); +protected: + const ::concrete_optimizer::dag::CircuitSolution &solution; +}; + +// Type resolver for the type inference for values with unparametrized +// `tfhe.glwe` types +class TFHEParametrizationTypeResolver : public TypeResolver { +public: + TFHEParametrizationTypeResolver( + std::optional solution) + : solution(solution) {} + + LocalInferenceState + resolve(mlir::Operation *op, + const LocalInferenceState &inferredTypes) override { + LocalInferenceState state = inferredTypes; + + mlir::TypeSwitch(op) + .Case([&](auto op) { + TypeConstraintSet<> cs; + + if (solution.has_value()) { + cs.addConstraint( + *this, solution.value()); + } + + cs.addConstraint(); + + cs.converge(op, *this, state, inferredTypes); + }) + + .Case([&](auto op) { + converge(op, state, inferredTypes); + }) + + .Case([&](auto op) { + converge, + SameOperandAndResultTypeConstraint<0, 0>>(op, state, + inferredTypes); + }) + .Case([&](auto op) { + converge>(op, state, + inferredTypes); + }) + .Case([&](auto op) { + converge>(op, state, + inferredTypes); + }) + .Case( + [&](auto op) { + converge>( + op, state, inferredTypes); + }) + + .Case([&](auto op) { + TypeConstraintSet<> cs; + + if (solution.has_value()) { + cs.addConstraint(*this, + solution.value()); + } + + // TODO: This can be quite slow for `tensor.from_elements` + // with lots of operands; implement + // SameOperandTypeConstraint taking into account all + // operands at once. + for (size_t i = 1; i < op.getNumOperands(); i++) { + cs.addConstraint< + DynamicSameTypeConstraint>(0, i); } - auto blockOwner = blockArg.getOwner(); - if (blockOwner->isEntryBlock()) { - DEBUG(" -> Will propagate on parent op " - << blockOwner->getParentOp()); - assert(parentBlock == blockOwner || parentBlock == nullptr); - parentBlock = blockOwner; + + cs.addConstraint>(); + cs.converge(op, *this, state, inferredTypes); + }) + + .Case( + [&](auto op) { + converge, + SameOperandAndResultTypeConstraint<1, 0>>(op, state, + inferredTypes); + }) + .Case([&](auto op) { + converge>( + op, state, inferredTypes); + }) + .Case([&](mlir::scf::ForOp op) { + TypeConstraintSet<> cs; + + if (solution.has_value()) { + cs.addConstraint(*this, + solution.value()); } - continue; + + for (size_t i = 0; i < op.getNumIterOperands(); i++) { + mlir::Value initArg = op.getInitArgs()[i]; + mlir::Value regionIterArg = op.getRegionIterArg(i); + mlir::Value result = op.getResult(i); + mlir::Value terminatorOperand = + op.getBody()->getTerminator()->getOperand(i); + + // Ensure that init args, return values, region iter args and + // operands of terminator all have the same type + cs.addConstraint>( + [=]() { return initArg; }, [=]() { return regionIterArg; }); + + cs.addConstraint>( + [=]() { return initArg; }, [=]() { return result; }); + + cs.addConstraint>( + [=]() { return result; }, [=]() { return terminatorOperand; }); + } + + cs.converge(op, *this, state, inferredTypes); + }) + + .Case([&](auto op) { + TypeConstraintSet<> cs; + + if (solution.has_value()) { + cs.addConstraint(*this, + solution.value()); + } + + for (size_t i = 0; i < op.getNumOperands(); i++) { + cs.addConstraint>( + [=]() { return op->getParentOp()->getResult(i); }, + [=]() { return op->getOperand(i); }); + } + + cs.converge(op, *this, state, inferredTypes); + }) + .Default([&](auto _op) { assert(false && "unsupported op type"); }); + + return state; + } + + bool isUnresolvedType(mlir::Type t) const override { + return isUnparametrizedGLWEType(t); + } + +protected: + // Type constraint that applies the type assigned to the operation + // by a TFHE solver via the `TFHE.OId` attribute + class ApplySolverSolutionConstraint : public TypeConstraint { + public: + ApplySolverSolutionConstraint(const TypeResolver &typeResolver, + const CircuitSolutionWrapper &solution) + : solution(solution), typeResolver(typeResolver) {} + + void apply(mlir::Operation *op, TypeResolver &resolver, + LocalInferenceState &currState, + const LocalInferenceState &prevState) override { + mlir::IntegerAttr oid = op->getAttrOfType("TFHE.OId"); + + if (!oid) + return; + + mlir::TypeSwitch(op) + .Case( + [&](auto op) { + applyKeyswitch(op, resolver, currState, prevState, + oid.getInt()); + }) + .Case( + [&](auto op) { + applyBootstrap(op, resolver, currState, prevState, + oid.getInt()); + }) + .Default([&](auto op) { + applyGeneric(op, resolver, currState, prevState, oid.getInt()); + }); + } + + protected: + // For any value in `values`, set the scalar or element type to + // `t` if the value is of an unresolved type or of a tensor type + // with an unresolved element type + void setUnresolvedTo(mlir::ValueRange values, mlir::Type t, + TypeResolver &resolver, + LocalInferenceState &currState) { + for (mlir::Value v : values) { + if (typeResolver.isUnresolvedType(v.getType())) { + currState.set(v, + TypeInferenceUtils::applyElementType(t, v.getType())); } - // An mlir::Value should always be an OpResult or a BlockArgument - assert(false); } } - DEBUG(" } END Fixup") - if (parentBlock != nullptr) { - fixupNonParametrizedOp(parentBlock->getParentOp(), parametrizedGlweType); + + // Apply the rule to a keyswitch or batched keyswitch operation + void applyKeyswitch(mlir::Operation *op, TypeResolver &resolver, + LocalInferenceState &currState, + const LocalInferenceState &prevState, int64_t oid) { + // Operands + TFHE::GLWECipherTextType scalarOperandType = solution.getTFHETypeForKey( + op->getContext(), + solution.lookupSecretKey( + oid, CircuitSolutionWrapper::SolutionKeyKind::KSK_IN)); + setUnresolvedTo(op->getOperands(), scalarOperandType, resolver, + currState); + + // Results + TFHE::GLWECipherTextType scalarResultType = solution.getTFHETypeForKey( + op->getContext(), + solution.lookupSecretKey( + oid, CircuitSolutionWrapper::SolutionKeyKind::KSK_OUT)); + setUnresolvedTo(op->getResults(), scalarResultType, resolver, currState); } - } - void - fixupIncompatibleLeveledOpWithExtraConversionKeys(mlir::func::FuncOp func) { - auto context = func.getContext(); - func.walk([&](mlir::Operation *op) { - // Skip bootstrap/keyswitch - if (mlir::isa(op) || - mlir::isa(op)) { - return; - } - auto attrOptimizerID = op->getAttrOfType("TFHE.OId"); - // Skip operation with no optimizer identifier - if (attrOptimizerID == nullptr) { - return; - } - DEBUG(" -> process op: " << *op) - // TFHE operators have only one ciphertext result - assert(op->getNumResults() == 1); - auto resType = - op->getResult(0).getType().dyn_cast(); - // For each ciphertext operands apply the extra keyswitch if found - for (const auto &p : llvm::enumerate(op->getOperands())) { - if (resType == nullptr) { - // We don't expect tensor operands to exist at this point of the - // pipeline for now, but if we happen to have some, this assert will - // break, and things will need to be changed to allow tensor ops to - // be parameterized. - // TODO: Actually this case could happens with tensor manipulation - // operators, so for now we just skip it and that should be fixed - // and tested. As the operand will not be fixed the validation of - // operators should not validate the operators. - continue; - } - auto operand = p.value(); - auto operandIdx = p.index(); - DEBUG(" -> processing operand " << operand); - auto operandType = - operand.getType().dyn_cast(); - if (operandType == nullptr) { - DEBUG(" -> skip operand, no glwe"); - continue; - } - if (operandType.getKey() == resType.getKey()) { - DEBUG(" -> skip operand, unnecessary conversion"); - continue; - } - // Lookup for the extra conversion key - DEBUG(" -> get extra conversion key") - auto extraConvKey = getExtraConversionKeyAttr( - context, operandType.getKey(), resType.getKey()); - if (extraConvKey == nullptr) { - DEBUG(" -> extra conversion key, not found") - assert(false); + // Apply the rule to a bootstrap or batched bootstrap operation + void applyBootstrap(mlir::Operation *op, TypeResolver &resolver, + LocalInferenceState &currState, + const LocalInferenceState &prevState, int64_t oid) { + // Operands + TFHE::GLWECipherTextType scalarOperandType = solution.getTFHETypeForKey( + op->getContext(), + solution.lookupSecretKey( + oid, CircuitSolutionWrapper::SolutionKeyKind::BSK_IN)); + setUnresolvedTo(op->getOperands(), scalarOperandType, resolver, + currState); + + // Results + TFHE::GLWECipherTextType scalarResultType = solution.getTFHETypeForKey( + op->getContext(), + solution.lookupSecretKey( + oid, CircuitSolutionWrapper::SolutionKeyKind::BSK_OUT)); + setUnresolvedTo(op->getResults(), scalarResultType, resolver, currState); + } + + // Apply the rule to any operation that is neither a keyswitch, + // nor a bootstrap operation + void applyGeneric(mlir::Operation *op, TypeResolver &resolver, + LocalInferenceState &currState, + const LocalInferenceState &prevState, int64_t oid) { + // Operands + TFHE::GLWECipherTextType scalarOperandType = solution.getTFHETypeForKey( + op->getContext(), + solution.lookupSecretKey( + oid, CircuitSolutionWrapper::SolutionKeyKind::OPERAND)); + setUnresolvedTo(op->getOperands(), scalarOperandType, resolver, + currState); + + // Results + TFHE::GLWECipherTextType scalarResultType = solution.getTFHETypeForKey( + op->getContext(), + solution.lookupSecretKey( + oid, CircuitSolutionWrapper::SolutionKeyKind::RESULT)); + setUnresolvedTo(op->getResults(), scalarResultType, resolver, currState); + } + + protected: + const CircuitSolutionWrapper &solution; + const TypeResolver &typeResolver; + }; + + // Type constraint that applies the type assigned to the arguments + // of a function by a TFHE solver via the `TFHE.OId` attributes of + // the function arguments + class ApplySolverSolutionToFunctionArgsConstraint : public TypeConstraint { + public: + ApplySolverSolutionToFunctionArgsConstraint( + const TypeResolver &typeResolver, + const CircuitSolutionWrapper &solution) + : solution(solution), typeResolver(typeResolver) {} + + void apply(mlir::Operation *op, TypeResolver &resolver, + LocalInferenceState &currState, + const LocalInferenceState &prevState) override { + mlir::func::FuncOp func = llvm::cast(op); + + for (size_t i = 0; i < func.getNumArguments(); i++) { + mlir::BlockArgument arg = func.getArgument(i); + if (mlir::IntegerAttr oidAttr = + func.getArgAttrOfType(i, "TFHE.OId")) { + TFHE::GLWECipherTextType scalarOperandType = + solution.getTFHETypeForKey( + func->getContext(), + solution.lookupSecretKey( + oidAttr.getInt(), + CircuitSolutionWrapper::SolutionKeyKind::RESULT)); + + currState.set(arg, TypeInferenceUtils::applyElementType( + scalarOperandType, arg.getType())); } - mlir::IRRewriter rewriter(context); - rewriter.setInsertionPoint(op); - auto newKSK = rewriter.create( - op->getLoc(), resType, operand, extraConvKey); - DEBUG("create extra conversion keyswitch: " << newKSK); - op->setOperand(operandIdx, newKSK); } - }); + } + + protected: + const CircuitSolutionWrapper &solution; + const TypeResolver &typeResolver; + }; + + // Instantiates the constraint types `ConstraintTs`, adds a solver + // solution constraint as the first constraint and converges on all + // constraints + template + void converge(mlir::Operation *op, LocalInferenceState &state, + const LocalInferenceState &inferredTypes) { + TypeConstraintSet<> cs; + + if (solution.has_value()) + cs.addConstraint(*this, solution.value()); + + cs.addConstraints(); + cs.converge(op, *this, state, inferredTypes); + } + + // Return `true` iff `t` is GLWE type that is not parameterized, + // otherwise `false` + static bool isUnparametrizedGLWEType(mlir::Type t) { + std::optional ctt = + tryGetScalarType(t); + + return ctt.has_value() && ctt.value().getKey().isNone(); } - static void fixupNonParametrizedOps(mlir::func::FuncOp func) { - // Lookup all operators that uses function arguments - for (const auto arg : func.getArguments()) { - auto parametrizedGlweType = - getParametrizedGlweTypeFromType(arg.getType()); - if (parametrizedGlweType != nullptr) { - DEBUG(" -> Fixup uses of arg " << arg) - // The argument is glwe, so propagate the glwe parametrization to all - // operators which use it - for (auto userOp : arg.getUsers()) { - fixupNonParametrizedOp(userOp, parametrizedGlweType); + std::optional solution; +}; + +// TFHE-specific rewriter that handles conflicts of contradicting TFHE +// types through the introduction of `tfhe.keyswitch` / +// `tfhe.batched_keyswitch` operations and that removes `TFHE.OId` +// attributes after the rewrite. +class TFHECircuitSolutionRewriter : public TypeInferenceRewriter { +public: + TFHECircuitSolutionRewriter( + const mlir::DataFlowSolver &solver, + TFHEParametrizationTypeResolver &typeResolver, + const std::optional &solution) + : TypeInferenceRewriter(solver, typeResolver), typeResolver(typeResolver), + solution(solution) {} + + virtual mlir::LogicalResult postRewriteHook(mlir::IRRewriter &rewriter, + mlir::Operation *oldOp, + mlir::Operation *newOp) override { + mlir::IntegerAttr oid = newOp->getAttrOfType("TFHE.OId"); + + if (oid) { + newOp->removeAttr("TFHE.OId"); + + if (solution.has_value()) { + // Fixup key attributes + if (TFHE::GLWEKeyswitchKeyAttr attrKeyswitchKey = + newOp->getAttrOfType("key")) { + newOp->setAttr("key", solution->getKeyswitchKeyAttr( + newOp->getContext(), oid.getInt())); + } else if (TFHE::GLWEBootstrapKeyAttr attrBootstrapKey = + newOp->getAttrOfType( + "key")) { + newOp->setAttr("key", solution->getBootstrapKeyAttr( + newOp->getContext(), oid.getInt())); } } } - // Fixup all operators that take at least a parametrized glwe and produce an - // none glwe - func.walk([&](mlir::Operation *op) { - for (auto operand : op->getOperands()) { - auto parametrizedGlweType = - getParametrizedGlweTypeFromType(operand.getType()); - if (parametrizedGlweType != nullptr) { - // An operand is a parametrized glwe - for (auto result : op->getResults()) { - if (isNoneGlweType(result.getType())) { - DEBUG(" -> Fixup illegal op " << *op) - fixupNonParametrizedOp(op, parametrizedGlweType); - return; - } - } - } - } - }); - } - static void removeOptimizerIdentifiers(mlir::func::FuncOp func) { - for (size_t i = 0; i < func.getNumArguments(); i++) { - func.removeArgAttr(i, "TFHE.OId"); + // Bootstrap operations that have changed keys may need an + // adjustment of their lookup tables. This is currently limited to + // bootstrap operations using static LUTs to implement the rounded + // PBS operation and to bootstrap operations whose LUTs are + // encoded within function scope using an encode and expand + // operation. + if (TFHE::BootstrapGLWEOp newBSOp = + llvm::dyn_cast(newOp)) { + TFHE::BootstrapGLWEOp oldBSOp = llvm::cast(oldOp); + + if (checkFixupBootstrapLUTs(rewriter, oldBSOp, newBSOp).failed()) + return mlir::failure(); } - func.walk([&](mlir::Operation *op) { op->removeAttr("TFHE.OId"); }); + + return mlir::success(); } - static void fixupFunctionSignature(mlir::func::FuncOp func) { - mlir::SmallVector inputs; - mlir::SmallVector outputs; - // Set inputs by looking actual arguments types - for (auto arg : func.getArguments()) { - inputs.push_back(arg.getType()); + // Resolves conflicts between cipher text scalar or ciphertext + // tensor types by creating keyswitch / batched keyswitch operations + mlir::Value handleConflict(mlir::IRRewriter &rewriter, + mlir::OpOperand &oldOperand, + mlir::Type resolvedType, + mlir::Value producerValue) override { + mlir::Operation *oldOp = oldOperand.getOwner(); + + std::optional cttFrom = + tryGetScalarType(producerValue.getType()); + std::optional cttTo = + tryGetScalarType(resolvedType); + + // Only handle conflicts wrt. ciphertext types or tensors of ciphertext + // types + if (!cttFrom.has_value() || !cttTo.has_value() || + resolvedType.isa()) { + return TypeInferenceRewriter::handleConflict(rewriter, oldOperand, + resolvedType, producerValue); } - // Look for return to set the outputs - func.walk([&](mlir::func::ReturnOp returnOp) { - // TODO: multiple return op - for (auto output : returnOp->getOperandTypes()) { - outputs.push_back(output); + + // Place keyswitch operation near the producer of the value to + // avoid nesting it too depply into loops + if (mlir::Operation *producer = producerValue.getDefiningOp()) + rewriter.setInsertionPointAfter(producer); + + assert(cttFrom->getKey().getParameterized().has_value()); + assert(cttTo->getKey().getParameterized().has_value()); + + const ::concrete_optimizer::dag::ConversionKeySwitchKey &cksk = + solution->lookupConversionKeyswitchKey( + cttFrom->getKey().getParameterized()->identifier, + cttTo->getKey().getParameterized()->identifier); + + TFHE::GLWEKeyswitchKeyAttr kskAttr = + solution->getKeyswitchKeyAttr(rewriter.getContext(), cksk); + + // For tensor types, conversion must be done using a batched + // keyswitch operation, otherwise a simple keyswitch op is + // sufficient + if (mlir::RankedTensorType rtt = + resolvedType.dyn_cast()) { + if (rtt.getShape().size() == 1) { + // Flat input shapes can be handled directly by a batched + // keyswitch operation + return rewriter.create( + oldOp->getLoc(), resolvedType, producerValue, kskAttr); + } else { + // Input shapes with more dimensions must first be flattened + // using a tensor.collapse_shape operation before passing the + // values to a batched keyswitch operation + + mlir::ReassociationIndices reassocDims(rtt.getShape().size()); + + for (size_t i = 0; i < rtt.getShape().size(); i++) + reassocDims[i] = i; + + llvm::SmallVector reassocs = {reassocDims}; + + // Flatten inputs + mlir::Value collapsed = rewriter.create( + oldOp->getLoc(), producerValue, reassocs); + + mlir::Type collapsedResolvedType = mlir::RankedTensorType::get( + {rtt.getNumElements()}, rtt.getElementType()); + + TFHE::BatchedKeySwitchGLWEOp ksOp = + rewriter.create( + oldOp->getLoc(), collapsedResolvedType, collapsed, kskAttr); + + // Restore original shape for the result + return rewriter.create( + oldOp->getLoc(), resolvedType, ksOp.getResult(), reassocs); } - }); - auto funcType = - mlir::FunctionType::get(func->getContext(), inputs, outputs); - func.setFunctionType(funcType); + } else { + // Scalar inputs are directly handled by a simple keyswitch + // operation + return rewriter.create( + oldOp->getLoc(), resolvedType, producerValue, kskAttr); + } } - const concrete_optimizer::dag::InstructionKeys & - getInstructionKey(size_t optimizerID) { - DEBUG("lookup instruction key: #" << optimizerID); - return solution.instructions_keys[optimizerID]; - } +protected: + // Checks if the lookup table for a freshly rewritten bootstrap + // operation needs to be adjusted and performs the adjustment if + // this is the case. + mlir::LogicalResult checkFixupBootstrapLUTs(mlir::IRRewriter &rewriter, + TFHE::BootstrapGLWEOp oldBSOp, + TFHE::BootstrapGLWEOp newBSOp) { + TFHE::GLWEBootstrapKeyAttr oldBSKeyAttr = + oldBSOp->getAttrOfType("key"); - const TFHE::GLWESecretKey GLWESecretKeyAsLWE(TFHE::GLWESecretKey key) { - auto keyP = key.getParameterized(); - assert(keyP.has_value()); - return TFHE::GLWESecretKey::newParameterized( - keyP->polySize * keyP->dimension, 1, keyP->identifier); - } + assert(oldBSKeyAttr); - const TFHE::GLWESecretKey - toGLWESecretKey(concrete_optimizer::dag::SecretLweKey key) { - return TFHE::GLWESecretKey::newParameterized( - key.glwe_dimension, key.polynomial_size, key.identifier); - } + mlir::Value lut = newBSOp.getLookupTable(); + mlir::RankedTensorType lutType = + lut.getType().cast(); - const TFHE::GLWESecretKey - toLWESecretKey(concrete_optimizer::dag::SecretLweKey key) { - return TFHE::GLWESecretKey::newParameterized( - key.glwe_dimension * key.polynomial_size, 1, key.identifier); - } + assert(lutType.getShape().size() == 1); - const TFHE::GLWESecretKey getLWESecretKey(size_t keyID) { - DEBUG("lookup secret key: #" << keyID); - auto key = solution.circuit_keys.secret_keys[keyID]; - assert(keyID == key.identifier); - return toLWESecretKey(key); - } + if (lutType.getShape()[0] == oldBSKeyAttr.getPolySize()) { + // Parametrization has no effect on LUT + return mlir::success(); + } - const TFHE::GLWESecretKey getInputLWESecretKey(size_t optimizerID) { - auto keyID = getInstructionKey(optimizerID).input_key; - return getLWESecretKey(keyID); - } + mlir::Operation *lutOp = lut.getDefiningOp(); - const TFHE::GLWESecretKey getOutputLWESecretKey(size_t optimizerID) { - auto keyID = getInstructionKey(optimizerID).output_key; - return getLWESecretKey(keyID); - } + TFHE::GLWEBootstrapKeyAttr newBSKeyAttr = + newBSOp->getAttrOfType("key"); - const TFHE::GLWEKeyswitchKeyAttr - getKeyswitchKeyAttr(mlir::MLIRContext *context, size_t optimizerID) { - auto keyID = getInstructionKey(optimizerID).tlu_keyswitch_key; - DEBUG("lookup keyswicth key: #" << keyID); - auto key = solution.circuit_keys.keyswitch_keys[keyID]; - return TFHE::GLWEKeyswitchKeyAttr::get( - context, toLWESecretKey(key.input_key), toLWESecretKey(key.output_key), - key.ks_decomposition_parameter.level, - key.ks_decomposition_parameter.log2_base, -1); - } + assert(newBSKeyAttr); - const TFHE::GLWEKeyswitchKeyAttr - getExtraConversionKeyAttr(mlir::MLIRContext *context, - TFHE::GLWESecretKey inputKey, - TFHE::GLWESecretKey ouputKey) { - auto convKSK = std::find_if( - solution.circuit_keys.conversion_keyswitch_keys.begin(), - solution.circuit_keys.conversion_keyswitch_keys.end(), - [&](concrete_optimizer::dag::ConversionKeySwitchKey &arg) { - assert(ouputKey.isParameterized() && inputKey.isParameterized()); - return arg.input_key.identifier == - inputKey.getParameterized()->identifier && - arg.output_key.identifier == - ouputKey.getParameterized()->identifier; - }); - assert(convKSK != solution.circuit_keys.conversion_keyswitch_keys.end()); - return TFHE::GLWEKeyswitchKeyAttr::get( - context, toLWESecretKey(convKSK->input_key), - toLWESecretKey(convKSK->output_key), - convKSK->ks_decomposition_parameter.level, - convKSK->ks_decomposition_parameter.log2_base, -1); - } + mlir::RankedTensorType newLUTType = mlir::RankedTensorType::get( + mlir::ArrayRef{newBSKeyAttr.getPolySize()}, + rewriter.getI64Type()); - const TFHE::GLWEKeyswitchKeyAttr getExtraConversionKeyAttr( - mlir::MLIRContext *context, - concrete_optimizer::dag::ConversionKeySwitchKey convKSK) { - return TFHE::GLWEKeyswitchKeyAttr::get( - context, toLWESecretKey(convKSK.input_key), - toLWESecretKey(convKSK.output_key), - convKSK.ks_decomposition_parameter.level, - convKSK.ks_decomposition_parameter.log2_base, -1); + if (arith::ConstantOp oldCstOp = + llvm::dyn_cast_or_null(lutOp)) { + // LUT is generated from a constant. Parametrization is only + // supported if this is a scenario, in which the bootstrap + // operation is used as a rounded bootstrap with identical + // entries in the LUT. + mlir::DenseIntElementsAttr oldCstValsAttr = + oldCstOp.getValueAttr().dyn_cast(); + + if (!oldCstValsAttr.isSplat()) { + oldBSOp->emitError( + "Bootstrap operation uses a constant LUT, but with different " + "entries. Only constants with identical elements for the " + "implementation of a rounded PBS are supported for now"); + + return mlir::failure(); + } + + rewriter.setInsertionPointAfter(oldCstOp); + mlir::arith::ConstantOp newCstOp = + rewriter.create( + oldCstOp.getLoc(), newLUTType, + oldCstValsAttr.resizeSplat(newLUTType)); + + newBSOp.setOperand(1, newCstOp); + } else if (TFHE::EncodeExpandLutForBootstrapOp oldEncodeOp = + llvm::dyn_cast_or_null( + lutOp)) { + // For encode and expand operations, simply update the size of + // the polynomial + + rewriter.setInsertionPointAfter(oldEncodeOp); + + TFHE::EncodeExpandLutForBootstrapOp newEncodeOp = + rewriter.create( + oldEncodeOp.getLoc(), newLUTType, + oldEncodeOp.getInputLookupTable(), newBSKeyAttr.getPolySize(), + oldEncodeOp.getOutputBits(), oldEncodeOp.getIsSigned()); + + newBSOp.setOperand(1, newEncodeOp); + } else { + oldBSOp->emitError( + "Cannot update lookup table after parametrization, only constants " + "and tables generated through TFHE.encode_expand_lut_for_bootstrap " + "are supported"); + + return mlir::failure(); + } + + return mlir::success(); } - const TFHE::GLWEBootstrapKeyAttr - getBootstrapKeyAttr(mlir::MLIRContext *context, size_t optimizerID) { - auto keyID = getInstructionKey(optimizerID).tlu_bootstrap_key; - DEBUG("lookup bootstrap key: #" << keyID); - auto key = solution.circuit_keys.bootstrap_keys[keyID]; - return TFHE::GLWEBootstrapKeyAttr::get( - context, toLWESecretKey(key.input_key), toGLWESecretKey(key.output_key), - key.output_key.polynomial_size, key.output_key.glwe_dimension, - key.br_decomposition_parameter.level, - key.br_decomposition_parameter.log2_base, -1); + TFHEParametrizationTypeResolver &typeResolver; + const std::optional &solution; +}; + +// Rewrite pattern that materializes the boundary between two +// partitions specified in the solution of the optimizer by an extra +// conversion key for a bootstrap operation. +// +// Replaces the pattern: +// +// %v = TFHE.bootstrap_glwe(%i0, %i1) : (T0, T1) -> T2 +// ... +// ... someotherop(..., %v, ...) : (..., T2, ...) -> ... +// +// with: +// +// %v = TFHE.bootstrap_glwe(%i0, %i1) : (T0, T1) -> T2 +// %v1 = TypeInference.propagate_upward(%v) : T2 -> CT0 +// %v2 = TFHE.keyswitch_glwe(%v1) : CT0 -> CT1 +// %v3 = TypeInference.propagate_downward(%v) : CT1 -> T2 +// ... +// ... someotherop(..., %v3, ...) : (..., T2, ...) -> ... +// +// The TypeInference operations are necessary to avoid producing +// invalid IR if `T2` is an unparametrized type. +class MaterializePartitionBoundaryPattern + : public mlir::OpRewritePattern { +public: + static constexpr mlir::StringLiteral kTransformMarker = + "__internal_materialize_partition_boundary_marker__"; + + MaterializePartitionBoundaryPattern(mlir::MLIRContext *ctx, + const CircuitSolutionWrapper &solution) + : mlir::OpRewritePattern(ctx, 0), + solution(solution) {} + + mlir::LogicalResult + matchAndRewrite(TFHE::BootstrapGLWEOp bsOp, + mlir::PatternRewriter &rewriter) const override { + // Avoid infinite recursion + if (bsOp->hasAttr(kTransformMarker)) + return mlir::failure(); + + mlir::IntegerAttr oidAttr = + bsOp->getAttrOfType("TFHE.OId"); + + if (!oidAttr) + return mlir::failure(); + + int64_t oid = oidAttr.getInt(); + + const ::concrete_optimizer::dag::InstructionKeys &instrKeys = + solution.lookupInstructionKeys(oid); + + if (instrKeys.extra_conversion_keys.size() == 0) + return mlir::failure(); + + assert(instrKeys.extra_conversion_keys.size() == 1); + + // Mark operation to avoid infinite recursion + bsOp->setAttr(kTransformMarker, rewriter.getUnitAttr()); + + const ::concrete_optimizer::dag::ConversionKeySwitchKey &cksk = + solution.lookupConversionKeyswitchKey(oid); + + TFHE::GLWECipherTextType cksInputType = + solution.getTFHETypeForKey(bsOp->getContext(), cksk.input_key); + + TFHE::GLWECipherTextType cksOutputType = + solution.getTFHETypeForKey(bsOp->getContext(), cksk.output_key); + + rewriter.setInsertionPointAfter(bsOp); + + TypeInference::PropagateUpwardOp puOp = + rewriter.create( + bsOp->getLoc(), cksInputType, bsOp.getResult()); + + TFHE::GLWEKeyswitchKeyAttr keyAttr = + solution.getKeyswitchKeyAttr(rewriter.getContext(), cksk); + + TFHE::KeySwitchGLWEOp ksOp = rewriter.create( + bsOp->getLoc(), cksOutputType, puOp.getResult(), keyAttr); + + mlir::Type unparametrizedType = TFHE::GLWECipherTextType::get( + rewriter.getContext(), TFHE::GLWESecretKey::newNone()); + + TypeInference::PropagateDownwardOp pdOp = + rewriter.create( + bsOp->getLoc(), unparametrizedType, ksOp.getResult()); + + rewriter.replaceAllUsesExcept(bsOp.getResult(), pdOp.getResult(), puOp); + + return mlir::success(); } - const TFHE::GLWECipherTextType - getOutputLWECipherTextType(mlir::MLIRContext *context, size_t optimizerID) { - auto outputKey = getOutputLWESecretKey(optimizerID); - return TFHE::GLWECipherTextType::get(context, outputKey); +protected: + const CircuitSolutionWrapper &solution; +}; + +class TFHECircuitSolutionParametrizationPass + : public TFHECircuitSolutionParametrizationBase< + TFHECircuitSolutionParametrizationPass> { +public: + TFHECircuitSolutionParametrizationPass( + std::optional<::concrete_optimizer::dag::CircuitSolution> solution) + : solution(solution){}; + + void runOnOperation() override { + mlir::ModuleOp module = this->getOperation(); + mlir::DataFlowSolver solver; + std::optional solutionWrapper = + solution.has_value() + ? std::make_optional(solution.value()) + : std::nullopt; + + if (solutionWrapper.has_value()) { + // The optimizer may have decided to place bootstrap operations + // at the edge of a partition. This is indicated by the presence + // of an "extra" conversion key for the OIds of the affected + // bootstrap operations. + // + // To keep type inference and the subsequent rewriting simple, + // materialize the required keyswitch operations straight away. + mlir::RewritePatternSet patterns(module->getContext()); + patterns.add( + module->getContext(), solutionWrapper.value()); + + if (mlir::applyPatternsAndFoldGreedily(module, std::move(patterns)) + .failed()) + this->signalPassFailure(); + + // Clean operations from transformation marker + module.walk([](TFHE::BootstrapGLWEOp bsOp) { + bsOp->removeAttr(MaterializePartitionBoundaryPattern::kTransformMarker); + }); + } + + TFHEParametrizationTypeResolver typeResolver(solutionWrapper); + mlir::SymbolTableCollection symbolTables; + + solver.load(); + solver.load(); + solver.load(typeResolver); + solver.load(symbolTables, typeResolver); + + if (failed(solver.initializeAndRun(module))) + return signalPassFailure(); + + TFHECircuitSolutionRewriter tir(solver, typeResolver, solutionWrapper); + + if (tir.rewrite(module).failed()) + signalPassFailure(); } private: - concrete_optimizer::dag::CircuitSolution solution; + std::optional<::concrete_optimizer::dag::CircuitSolution> solution; }; } // end anonymous namespace -std::unique_ptr> +std::unique_ptr> createTFHECircuitSolutionParametrizationPass( - concrete_optimizer::dag::CircuitSolution solution) { + std::optional<::concrete_optimizer::dag::CircuitSolution> solution) { return std::make_unique(solution); } diff --git a/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp b/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp index a1d551c38e..6b62c75d3f 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp @@ -45,6 +45,7 @@ #include "concretelang/Dialect/TFHE/IR/TFHEDialect.h" #include "concretelang/Dialect/Tracing/IR/TracingDialect.h" #include "concretelang/Dialect/Tracing/Transforms/BufferizableOpInterfaceImpl.h" +#include "concretelang/Dialect/TypeInference/IR/TypeInferenceDialect.h" #include "concretelang/Runtime/DFRuntime.hpp" #include "concretelang/Support/CompilerEngine.h" #include "concretelang/Support/Encodings.h" @@ -81,6 +82,7 @@ mlir::MLIRContext *CompilationContext::getMLIRContext() { if (this->mlirContext == nullptr) { mlir::DialectRegistry registry; registry.insert< + mlir::concretelang::TypeInference::TypeInferenceDialect, mlir::concretelang::Tracing::TracingDialect, mlir::concretelang::RT::RTDialect, mlir::concretelang::FHE::FHEDialect, mlir::concretelang::TFHE::TFHEDialect, diff --git a/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp b/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp index f18c6570fe..adb141f975 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp @@ -283,23 +283,28 @@ mlir::LogicalResult parametrizeTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, std::optional &fheContext, std::function enablePass) { - if (!fheContext) - return mlir::success(); - mlir::PassManager pm(&context); pipelinePrinting("ParametrizeTFHE", pm, context); - if (auto monoSolution = std::get_if(&fheContext->solution); - monoSolution != nullptr) { + if (!fheContext) { + // For tests, which invoke the pipeline without determining FHE + // parameters + addPotentiallyNestedPass( + pm, + mlir::concretelang::createTFHECircuitSolutionParametrizationPass( + std::nullopt), + enablePass); + } else if (auto monoSolution = + std::get_if(&fheContext->solution); + monoSolution != nullptr) { addPotentiallyNestedPass( pm, mlir::concretelang::createConvertTFHEGlobalParametrizationPass( *monoSolution), enablePass); - } - if (auto circuitSolution = - std::get_if(&fheContext->solution); - circuitSolution != nullptr) { + } else if (auto circuitSolution = + std::get_if(&fheContext->solution); + circuitSolution != nullptr) { addPotentiallyNestedPass( pm, mlir::concretelang::createTFHECircuitSolutionParametrizationPass(