From 3d16efb68152906fdf0ccc70878c7615c9abe8a8 Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Mon, 5 Feb 2024 11:51:39 +0100 Subject: [PATCH 1/6] feat(compiler): Add functions for type inference debugging to TypeInferenceUtils The main debugging function is `TypeInferenceUtils::dumpAllState(mlir::Operation* op)` which dumps the entire state of type inference for the function containing `op`. --- .../Analysis/TypeInferenceAnalysis.h | 57 +++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Analysis/TypeInferenceAnalysis.h b/compilers/concrete-compiler/compiler/include/concretelang/Analysis/TypeInferenceAnalysis.h index 0851766de5..9818d42885 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Analysis/TypeInferenceAnalysis.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Analysis/TypeInferenceAnalysis.h @@ -916,6 +916,63 @@ class TypeInferenceAnalysisBase : public AnalysisT { } } + // Prints an indentation composed of `indent` times `" "`. + void printIndent(int indent) { + for (int i = 0; i < indent; i++) + llvm::dbgs() << " "; + } + + // Dumps the state of type inference for the operation `op` with an + // indentation level of `indent` as the name of the operation, + // followed by the types inferred for each operand, followed by + // `->`, followed by a dump of the state for any operation nested in + // any region of `op`. + void dumpStateForOp(mlir::Operation *op, int indent) { + const LocalInferenceState state = getCurrentInferredTypes(op); + + printIndent(indent); + llvm::dbgs() << op->getName() << " {"; + + llvm::interleaveComma( + op->getAttrs(), llvm::dbgs(), [&](const mlir::NamedAttribute &attr) { + llvm::dbgs() << attr.getName() << " = " << attr.getValue(); + }); + + llvm::dbgs() << "} : ("; + + llvm::interleaveComma(op->getOperands(), llvm::dbgs(), [&](mlir::Value v) { + llvm::dbgs() << state.find(v); + }); + + llvm::dbgs() << ") -> ("; + + llvm::interleaveComma(op->getResults(), llvm::dbgs(), [&](mlir::Value v) { + llvm::dbgs() << state.find(v); + }); + + llvm::dbgs() << ")\n"; + + for (mlir::Region &r : op->getRegions()) + for (mlir::Block &b : r.getBlocks()) + for (mlir::Operation &childOp : b.getOperations()) + dumpStateForOp(&childOp, indent + 1); + } + + // Dumps the entire state of type inference for the function + // containing the operation `op`. For each operation, this prints + // the name of the operation, followed by the types inferred for + // each operand, followed by `->`, followed by the types inferred + // for the results. + void dumpAllState(mlir::Operation *op) { + mlir::Operation *funcOp = op; + while (funcOp && !llvm::isa(funcOp)) + funcOp = funcOp->getParentOp(); + + assert(funcOp); + + dumpStateForOp(funcOp, 0); + } + TypeResolver &resolver; }; From 8e660e2f75a442e5713727e32c15295a2acc408a Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Tue, 20 Feb 2024 14:40:27 +0100 Subject: [PATCH 2/6] feat(compiler): Add dialect with operations related to the optimizer This adds a new dialect called `Optimizer` with operations related to the Concrete Optimizer. Currently, there is only one operation `optimizer.partition_frontier` that can be inserted between a producer and a consumer which belong to different partitions computed by the optimizer. The purpose of this operation is to preserve explicit key changes from the invocation of the optimizer on high-level dialects (i.e., FHELinalg / FHE) until the IR is provided with actual references to keys in low-level dialects (i.e., TFHE). --- .../concretelang/Dialect/CMakeLists.txt | 1 + .../Dialect/Optimizer/CMakeLists.txt | 1 + .../Dialect/Optimizer/IR/CMakeLists.txt | 10 +++++ .../Dialect/Optimizer/IR/OptimizerDialect.h | 15 +++++++ .../Dialect/Optimizer/IR/OptimizerDialect.td | 20 +++++++++ .../Dialect/Optimizer/IR/OptimizerOps.h | 16 +++++++ .../Dialect/Optimizer/IR/OptimizerOps.td | 37 ++++++++++++++++ .../compiler/lib/Dialect/CMakeLists.txt | 1 + .../lib/Dialect/Optimizer/CMakeLists.txt | 1 + .../lib/Dialect/Optimizer/IR/CMakeLists.txt | 13 ++++++ .../Dialect/Optimizer/IR/OptimizerDialect.cpp | 18 ++++++++ .../lib/Dialect/Optimizer/IR/OptimizerOps.cpp | 9 ++++ .../compiler/src/CMakeLists.txt | 1 + docs/dev/compilation/OptimizerDialect.md | 42 +++++++++++++++++++ 14 files changed, 185 insertions(+) create mode 100644 compilers/concrete-compiler/compiler/include/concretelang/Dialect/Optimizer/CMakeLists.txt create mode 100644 compilers/concrete-compiler/compiler/include/concretelang/Dialect/Optimizer/IR/CMakeLists.txt create mode 100644 compilers/concrete-compiler/compiler/include/concretelang/Dialect/Optimizer/IR/OptimizerDialect.h create mode 100644 compilers/concrete-compiler/compiler/include/concretelang/Dialect/Optimizer/IR/OptimizerDialect.td create mode 100644 compilers/concrete-compiler/compiler/include/concretelang/Dialect/Optimizer/IR/OptimizerOps.h create mode 100644 compilers/concrete-compiler/compiler/include/concretelang/Dialect/Optimizer/IR/OptimizerOps.td create mode 100644 compilers/concrete-compiler/compiler/lib/Dialect/Optimizer/CMakeLists.txt create mode 100644 compilers/concrete-compiler/compiler/lib/Dialect/Optimizer/IR/CMakeLists.txt create mode 100644 compilers/concrete-compiler/compiler/lib/Dialect/Optimizer/IR/OptimizerDialect.cpp create mode 100644 compilers/concrete-compiler/compiler/lib/Dialect/Optimizer/IR/OptimizerOps.cpp create mode 100644 docs/dev/compilation/OptimizerDialect.md diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/CMakeLists.txt b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/CMakeLists.txt index 7f2a2d1b81..67438fa2cf 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/CMakeLists.txt +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/CMakeLists.txt @@ -6,3 +6,4 @@ add_subdirectory(RT) add_subdirectory(SDFG) add_subdirectory(Tracing) add_subdirectory(TypeInference) +add_subdirectory(Optimizer) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Optimizer/CMakeLists.txt b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Optimizer/CMakeLists.txt new file mode 100644 index 0000000000..f33061b2d8 --- /dev/null +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Optimizer/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Optimizer/IR/CMakeLists.txt b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Optimizer/IR/CMakeLists.txt new file mode 100644 index 0000000000..7e142837af --- /dev/null +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Optimizer/IR/CMakeLists.txt @@ -0,0 +1,10 @@ +set(LLVM_TARGET_DEFINITIONS OptimizerOps.td) +mlir_tablegen(OptimizerOps.h.inc -gen-op-decls) +mlir_tablegen(OptimizerOps.cpp.inc -gen-op-defs) +mlir_tablegen(OptimizerOpsDialect.h.inc -gen-dialect-decls -dialect=Optimizer) +mlir_tablegen(OptimizerOpsDialect.cpp.inc -gen-dialect-defs -dialect=Optimizer) +add_public_tablegen_target(MLIROptimizerOpsIncGen) +add_dependencies(mlir-headers MLIROptimizerOpsIncGen) + +add_concretelang_doc(OptimizerOps OptimizerDialect concretelang/ -gen-dialect-doc -dialect=Optimizer) +add_concretelang_doc(OptimizerOps OptimizerOps concretelang/ -gen-op-doc) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Optimizer/IR/OptimizerDialect.h b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Optimizer/IR/OptimizerDialect.h new file mode 100644 index 0000000000..4160a7a52d --- /dev/null +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Optimizer/IR/OptimizerDialect.h @@ -0,0 +1,15 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#ifndef CONCRETELANG_DIALECT_OPTIMIZER_IR_OPTIMIZERDIALECT_H +#define CONCRETELANG_DIALECT_OPTIMIZER_IR_OPTIMIZERDIALECT_H + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" + +#include "concretelang/Dialect/Optimizer/IR/OptimizerOpsDialect.h.inc" + +#endif diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Optimizer/IR/OptimizerDialect.td b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Optimizer/IR/OptimizerDialect.td new file mode 100644 index 0000000000..c8a466c9d5 --- /dev/null +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Optimizer/IR/OptimizerDialect.td @@ -0,0 +1,20 @@ +//===- OptimizerDialect.td - Optimizer dialect ----------------*- tablegen -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef CONCRETELANG_DIALECT_OPTIMIZER_IR_OPTIMIZER_DIALECT +#define CONCRETELANG_DIALECT_OPTIMIZER_IR_OPTIMIZER_DIALECT + +include "mlir/IR/OpBase.td" + +def Optimizer_Dialect : Dialect { + let name = "Optimizer"; + let summary = "Auxiliary operations for the interaction with the optimizer"; + let cppNamespace = "::mlir::concretelang::Optimizer"; +} + +#endif diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Optimizer/IR/OptimizerOps.h b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Optimizer/IR/OptimizerOps.h new file mode 100644 index 0000000000..b3e9762301 --- /dev/null +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Optimizer/IR/OptimizerOps.h @@ -0,0 +1,16 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#ifndef CONCRETELANG_DIALECT_OPTIMIZER_IR_OPTIMIZEROPS_H +#define CONCRETELANG_DIALECT_OPTIMIZER_IR_OPTIMIZEROPS_H + +#include +#include +#include + +#define GET_OP_CLASSES +#include "concretelang/Dialect/Optimizer/IR/OptimizerOps.h.inc" + +#endif diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Optimizer/IR/OptimizerOps.td b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Optimizer/IR/OptimizerOps.td new file mode 100644 index 0000000000..84602d1a07 --- /dev/null +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Optimizer/IR/OptimizerOps.td @@ -0,0 +1,37 @@ +//===- OptimizerOps.td - Optimizer dialect ops ----------------*- tablegen -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef CONCRETELANG_DIALECT_OPTIMIZER_IR_OPTIMIZER_OPS +#define CONCRETELANG_DIALECT_OPTIMIZER_IR_OPTIMIZER_OPS + +include "mlir/Interfaces/SideEffectInterfaces.td" +include "concretelang/Dialect/Optimizer/IR/OptimizerDialect.td" + +class Optimizer_Op traits = []> : + Op; + +def Optimizer_PartitionFrontierOp : Optimizer_Op<"partition_frontier", [Pure]> { + let summary = "Models an explicit edge between two partitions"; + + let description = [{ + Models an explicit edge between two partitions in the solution + determined by the optimizer requiring a key change between the + encrypted values of the operand and the encrypted values of + the result. + }]; + + let arguments = (ins + AnyType:$input, + I64Attr:$inputKeyID, + I32Attr:$outputKeyID + ); + + let results = (outs AnyType); +} + +#endif diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/CMakeLists.txt b/compilers/concrete-compiler/compiler/lib/Dialect/CMakeLists.txt index 4502cb2e7b..599e2661c5 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/CMakeLists.txt +++ b/compilers/concrete-compiler/compiler/lib/Dialect/CMakeLists.txt @@ -8,3 +8,4 @@ add_subdirectory(RT) add_subdirectory(SDFG) add_subdirectory(Tracing) add_subdirectory(TypeInference) +add_subdirectory(Optimizer) diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/Optimizer/CMakeLists.txt b/compilers/concrete-compiler/compiler/lib/Dialect/Optimizer/CMakeLists.txt new file mode 100644 index 0000000000..f33061b2d8 --- /dev/null +++ b/compilers/concrete-compiler/compiler/lib/Dialect/Optimizer/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/Optimizer/IR/CMakeLists.txt b/compilers/concrete-compiler/compiler/lib/Dialect/Optimizer/IR/CMakeLists.txt new file mode 100644 index 0000000000..d30bca394c --- /dev/null +++ b/compilers/concrete-compiler/compiler/lib/Dialect/Optimizer/IR/CMakeLists.txt @@ -0,0 +1,13 @@ +add_mlir_dialect_library( + OptimizerDialect + OptimizerDialect.cpp + OptimizerOps.cpp + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/Optimizer + DEPENDS + mlir-headers + LINK_LIBS + PUBLIC + MLIRIR) + +target_link_libraries(OptimizerDialect PUBLIC MLIRIR) diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/Optimizer/IR/OptimizerDialect.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/Optimizer/IR/OptimizerDialect.cpp new file mode 100644 index 0000000000..19a9f0c89f --- /dev/null +++ b/compilers/concrete-compiler/compiler/lib/Dialect/Optimizer/IR/OptimizerDialect.cpp @@ -0,0 +1,18 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#include "concretelang/Dialect/Optimizer/IR/OptimizerDialect.h" +#include "concretelang/Dialect/Optimizer/IR/OptimizerOps.h" + +#include "concretelang/Dialect/Optimizer/IR/OptimizerOpsDialect.cpp.inc" + +using namespace mlir::concretelang::Optimizer; + +void OptimizerDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "concretelang/Dialect/Optimizer/IR/OptimizerOps.cpp.inc" + >(); +} diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/Optimizer/IR/OptimizerOps.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/Optimizer/IR/OptimizerOps.cpp new file mode 100644 index 0000000000..98965198ed --- /dev/null +++ b/compilers/concrete-compiler/compiler/lib/Dialect/Optimizer/IR/OptimizerOps.cpp @@ -0,0 +1,9 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#include "concretelang/Dialect/Optimizer/IR/OptimizerOps.h" + +#define GET_OP_CLASSES +#include "concretelang/Dialect/Optimizer/IR/OptimizerOps.cpp.inc" diff --git a/compilers/concrete-compiler/compiler/src/CMakeLists.txt b/compilers/concrete-compiler/compiler/src/CMakeLists.txt index 5aba3e8ccd..323bf08c41 100644 --- a/compilers/concrete-compiler/compiler/src/CMakeLists.txt +++ b/compilers/concrete-compiler/compiler/src/CMakeLists.txt @@ -16,6 +16,7 @@ target_link_libraries( TFHEDialect FHEDialect SDFGDialect + OptimizerDialect ConcretelangSupport ConcretelangTransforms MLIRIR diff --git a/docs/dev/compilation/OptimizerDialect.md b/docs/dev/compilation/OptimizerDialect.md new file mode 100644 index 0000000000..a146959719 --- /dev/null +++ b/docs/dev/compilation/OptimizerDialect.md @@ -0,0 +1,42 @@ + +# 'Optimizer' Dialect + +Auxiliary operations for the interaction with the optimizer + + +## Operation definition + +### `Optimizer.partition_frontier` (::mlir::concretelang::Optimizer::PartitionFrontierOp) + +Models an explicit edge between two partitions + +Models an explicit edge between two partitions in the solution +determined by the optimizer requiring a key change between the +encrypted values of the operand and the encrypted values of +the result. + +Traits: AlwaysSpeculatableImplTrait + +Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) + +Effects: MemoryEffects::Effect{} + +#### Attributes: + +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `inputKeyID` | ::mlir::IntegerAttr | 64-bit signless integer attribute +| `outputKeyID` | ::mlir::IntegerAttr | 32-bit signless integer attribute + +#### Operands: + +| Operand | Description | +| :-----: | ----------- | +| `input` | any type + +#### Results: + +| Result | Description | +| :----: | ----------- | +«unnamed» | any type + From a701b3a74227b0f62b481a32e1a3bdb5b9a7fd9b Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Tue, 20 Feb 2024 14:58:58 +0100 Subject: [PATCH 3/6] feat(compiler): Add support for tensor.empty in the pipeline from FHE to std --- .../lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp | 5 +++++ .../FHEToTFHEScalar/FHEToTFHEScalar.cpp | 6 ++++++ .../TFHEGlobalParametrization.cpp | 6 ++++++ .../TFHEKeyNormalization/TFHEKeyNormalization.cpp | 6 ++++++ .../Conversion/TFHEToConcrete/TFHEToConcrete.cpp | 15 ++++++++------- .../TFHECircuitSolutionParametrization.cpp | 2 +- .../compiler/lib/Support/Pipeline.cpp | 6 ++++++ 7 files changed, 38 insertions(+), 8 deletions(-) diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp index 3bcb127f0c..2f88e5c26c 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp @@ -996,6 +996,8 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase { mlir::tensor::CollapseShapeOp>(target, converter); mlir::concretelang::addDynamicallyLegalTypeOp( target, converter); + mlir::concretelang::addDynamicallyLegalTypeOp( + target, converter); //---------------------------------------------------------- Adding patterns mlir::RewritePatternSet patterns(&getContext()); @@ -1076,6 +1078,9 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase { patterns.add>(&getContext(), converter); + patterns.add>(&getContext(), converter); + mlir::concretelang::populateWithRTTypeConverterPatterns(patterns, target, converter); diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp index 5e494dbc34..9aa0f5b9b7 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp @@ -888,6 +888,12 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase { patterns.add>(&getContext(), converter); + patterns.add>(&getContext(), converter); + + mlir::concretelang::addDynamicallyLegalTypeOp( + target, converter); + mlir::concretelang::populateWithRTTypeConverterPatterns(patterns, target, converter); diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp index 0ecde5abfa..575f0e712b 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp @@ -347,6 +347,12 @@ void TFHEGlobalParametrizationPass::runOnOperation() { mlir::concretelang::addDynamicallyLegalTypeOp< mlir::bufferization::AllocTensorOp>(target, converter); + patterns.add< + mlir::concretelang::GenericTypeConverterPattern>( + &getContext(), converter); + mlir::concretelang::addDynamicallyLegalTypeOp( + target, converter); + patterns.add>( &getContext(), converter); diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEKeyNormalization/TFHEKeyNormalization.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEKeyNormalization/TFHEKeyNormalization.cpp index 7d9efdbca4..a80edee508 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEKeyNormalization/TFHEKeyNormalization.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEKeyNormalization/TFHEKeyNormalization.cpp @@ -365,6 +365,12 @@ void TFHEKeyNormalizationPass::runOnOperation() { mlir::concretelang::addDynamicallyLegalTypeOp< mlir::bufferization::AllocTensorOp>(target, typeConverter); + patterns.add< + mlir::concretelang::GenericTypeConverterPattern>( + &getContext(), typeConverter); + mlir::concretelang::addDynamicallyLegalTypeOp( + target, typeConverter); + patterns.add>( &getContext(), typeConverter); diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp index f71c88c0de..f35159ab0d 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp @@ -875,11 +875,11 @@ void TFHEToConcretePass::runOnOperation() { mlir::tensor::YieldOp, mlir::scf::YieldOp, mlir::tensor::GenerateOp, mlir::tensor::ExtractSliceOp, mlir::tensor::ExtractOp, mlir::tensor::InsertSliceOp, mlir::tensor::ExpandShapeOp, - mlir::tensor::CollapseShapeOp, mlir::bufferization::AllocTensorOp>( - [&](mlir::Operation *op) { - return converter.isLegal(op->getResultTypes()) && - converter.isLegal(op->getOperandTypes()); - }); + mlir::tensor::CollapseShapeOp, mlir::bufferization::AllocTensorOp, + mlir::tensor::EmptyOp>([&](mlir::Operation *op) { + return converter.isLegal(op->getResultTypes()) && + converter.isLegal(op->getOperandTypes()); + }); // rewrite scf for loops if working on illegal types patterns.add, mlir::concretelang::TypeConvertingReinstantiationPattern< - mlir::bufferization::AllocTensorOp, true>>(&getContext(), - converter); + mlir::bufferization::AllocTensorOp, true>, + mlir::concretelang::TypeConvertingReinstantiationPattern< + mlir::tensor::EmptyOp, true>>(&getContext(), converter); mlir::concretelang::populateWithRTTypeConverterPatterns(patterns, target, converter); diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Transforms/TFHECircuitSolutionParametrization.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Transforms/TFHECircuitSolutionParametrization.cpp index 797d75a1ef..58a3210466 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Transforms/TFHECircuitSolutionParametrization.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Transforms/TFHECircuitSolutionParametrization.cpp @@ -227,7 +227,7 @@ class TFHEParametrizationTypeResolver : public TypeResolver { TFHE::BatchedBootstrapGLWEOp, TFHE::EncodeExpandLutForBootstrapOp, TFHE::EncodeLutForCrtWopPBSOp, TFHE::EncodePlaintextWithCrtOp, TFHE::WopPBSGLWEOp, mlir::func::ReturnOp, - Tracing::TraceCiphertextOp>([&](auto op) { + Tracing::TraceCiphertextOp, mlir::tensor::EmptyOp>([&](auto op) { converge(op, state, inferredTypes); }) diff --git a/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp b/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp index 4dcd52dd15..0221880b10 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp @@ -445,6 +445,12 @@ mlir::LogicalResult lowerToStd(mlir::MLIRContext &context, mlir::PassManager pm(&context); pipelinePrinting("Lowering to Std", pm, context); + // Replace non-bufferizable ops (e;g., `tensor.empty` -> + // `bufferization.alloc_tensor`) + addPotentiallyNestedPass( + pm, mlir::bufferization::createEmptyTensorToAllocTensorPass(), + enablePass); + // Bufferize mlir::bufferization::OneShotBufferizationOptions bufferizationOptions; bufferizationOptions.allowReturnAllocs = true; From 9b6878316fbcab5bea38311183d8754bbd187103 Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Tue, 20 Feb 2024 15:02:40 +0100 Subject: [PATCH 4/6] fix(compiler): Preserve explicit optimizer partition boundaries through the pipeline The Concrete Optimizer is invoked on a representation of the program in the high-level FHELinalg / FHE Dialects and yields a solution with a one-to-one mapping of operations to keys. However, the abstractions used by these dialects do not allow for references to keys and the application of the solution is delayed until the pipeline reaches a representation of the program in the lower-level TFHE dialect. Various transformations applied by the pipeline along the way may break the one-to-one mapping and add indirections into producer-consumer relationships, resulting in ambiguous or partial mappings of TFHE operations to the keys. In particular, explicit frontiers between optimizer partitions may not be recovered. This commit preserves explicit frontiers between optimizer partitions as `optimizer.partition_frontier` operations and lowers these to keyswitch operations before parametrization of TFHE operations. --- .../Dialect/FHE/Transforms/CMakeLists.txt | 1 + .../FHE/Transforms/Optimizer/CMakeLists.txt | 4 + .../FHE/Transforms/Optimizer/Optimizer.h | 30 ++++ .../FHE/Transforms/Optimizer/Optimizer.td | 23 +++ .../concretelang/Support/CompilerEngine.h | 2 + .../include/concretelang/Support/Pipeline.h | 5 + .../FHETensorOpsToLinalg/CMakeLists.txt | 4 +- .../TensorOpsToLinalg.cpp | 80 ++++++++++ .../Conversion/FHEToTFHECrt/CMakeLists.txt | 1 + .../Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp | 7 + .../Conversion/FHEToTFHEScalar/CMakeLists.txt | 1 + .../FHEToTFHEScalar/FHEToTFHEScalar.cpp | 8 + .../lib/Dialect/FHE/Transforms/CMakeLists.txt | 5 +- .../lib/Dialect/FHE/Transforms/Optimizer.cpp | 143 ++++++++++++++++++ .../Dialect/TFHE/Transforms/CMakeLists.txt | 4 +- .../TFHECircuitSolutionParametrization.cpp | 66 +++----- .../compiler/lib/Support/CompilerEngine.cpp | 19 +++ .../compiler/lib/Support/Pipeline.cpp | 28 ++++ 18 files changed, 379 insertions(+), 52 deletions(-) create mode 100644 compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Transforms/Optimizer/CMakeLists.txt create mode 100644 compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Transforms/Optimizer/Optimizer.h create mode 100644 compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Transforms/Optimizer/Optimizer.td create mode 100644 compilers/concrete-compiler/compiler/lib/Dialect/FHE/Transforms/Optimizer.cpp diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Transforms/CMakeLists.txt b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Transforms/CMakeLists.txt index 9c798458c1..71e5cb7030 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Transforms/CMakeLists.txt +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Transforms/CMakeLists.txt @@ -3,3 +3,4 @@ add_subdirectory(DynamicTLU) add_subdirectory(BigInt) add_subdirectory(Boolean) add_subdirectory(Max) +add_subdirectory(Optimizer) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Transforms/Optimizer/CMakeLists.txt b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Transforms/Optimizer/CMakeLists.txt new file mode 100644 index 0000000000..dc69f1507a --- /dev/null +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Transforms/Optimizer/CMakeLists.txt @@ -0,0 +1,4 @@ +set(LLVM_TARGET_DEFINITIONS Optimizer.td) +mlir_tablegen(Optimizer.h.inc -gen-pass-decls -name Transforms) +add_public_tablegen_target(ConcretelangFHEOptimizerPassIncGen) +add_dependencies(mlir-headers ConcretelangFHEOptimizerPassIncGen) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Transforms/Optimizer/Optimizer.h b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Transforms/Optimizer/Optimizer.h new file mode 100644 index 0000000000..aea2c46857 --- /dev/null +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Transforms/Optimizer/Optimizer.h @@ -0,0 +1,30 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#ifndef CONCRETELANG_DIALECT_OPTIMIZER_TRANSFORMS_PASSES_H +#define CONCRETELANG_DIALECT_OPTIMIZER_TRANSFORMS_PASSES_H + +#include +#include + +#include +#include +#include +#include + +#define GEN_PASS_CLASSES +#include + +namespace mlir { +namespace concretelang { + +std::unique_ptr> +createOptimizerPartitionFrontierMaterializationPass( + const optimizer::CircuitSolution &solverSolution); + +} // namespace concretelang +} // namespace mlir + +#endif diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Transforms/Optimizer/Optimizer.td b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Transforms/Optimizer/Optimizer.td new file mode 100644 index 0000000000..9fa3002dca --- /dev/null +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Transforms/Optimizer/Optimizer.td @@ -0,0 +1,23 @@ +#ifndef CONCRETELANG_DIALECT_OPTIMIZER_TRANSFORMS_PASSES +#define CONCRETELANG_DIALECT_OPTIMIZER_TRANSFORMS_PASSES + +include "mlir/Pass/PassBase.td" + +def OptimizerPartitionFrontierMaterializationPass + : Pass<"optimizer-partition-frontier-materialization", + "::mlir::func::FuncOp"> { + let summary = + "Inserts Optimizer.partition_frontier operations between FHE operations " + "that were explicitly marked by the optimizer as belonging to separate " + "partitions via an extra conversion key in the optimizer solution."; + + let constructor = "mlir::concretelang::" + "createOptimizerPartitionFrontierMaterializationPass()"; + let options = []; + let dependentDialects = [ + "mlir::concretelang::FHE::FHEDialect", + "mlir::concretelang::Optimizer::OptimizerDialect" + ]; +} + +#endif diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Support/CompilerEngine.h b/compilers/concrete-compiler/compiler/include/concretelang/Support/CompilerEngine.h index 7422ea094a..ba08f29dd6 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Support/CompilerEngine.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Support/CompilerEngine.h @@ -342,6 +342,8 @@ class CompilerEngine { llvm::Expected> getConcreteOptimizerDescription(CompilationResult &res); llvm::Error determineFHEParameters(CompilationResult &res); + mlir::LogicalResult + materializeOptimizerPartitionFrontiers(CompilationResult &res); }; } // namespace concretelang diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Support/Pipeline.h b/compilers/concrete-compiler/compiler/include/concretelang/Support/Pipeline.h index b5e72d8b97..6ea78678ef 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Support/Pipeline.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Support/Pipeline.h @@ -19,6 +19,11 @@ namespace pipeline { mlir::LogicalResult autopar(mlir::MLIRContext &context, mlir::ModuleOp &module, std::function enablePass); +mlir::LogicalResult materializeOptimizerPartitionFrontiers( + mlir::MLIRContext &context, mlir::ModuleOp &module, + std::optional &fheContext, + std::function enablePass); + llvm::Expected>> getFHEContextFromFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, optimizer::Config config, diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/FHETensorOpsToLinalg/CMakeLists.txt b/compilers/concrete-compiler/compiler/lib/Conversion/FHETensorOpsToLinalg/CMakeLists.txt index b8668cbb97..e00142e7bb 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/FHETensorOpsToLinalg/CMakeLists.txt +++ b/compilers/concrete-compiler/compiler/lib/Conversion/FHETensorOpsToLinalg/CMakeLists.txt @@ -6,11 +6,13 @@ add_mlir_dialect_library( DEPENDS FHEDialect FHELinalgDialect + OptimizerDialect mlir-headers LINK_LIBS PUBLIC MLIRIR FHEDialect - FHELinalgDialect) + FHELinalgDialect + OptimizerDialect) target_link_libraries(FHEDialect PUBLIC MLIRIR) diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp index e89ab824cf..e8eb990e9a 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp @@ -22,6 +22,7 @@ #include "concretelang/Dialect/FHE/IR/FHEOps.h" #include "concretelang/Dialect/FHELinalg/IR/FHELinalgDialect.h" #include "concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h" +#include "concretelang/Dialect/Optimizer/IR/OptimizerOps.h" #include "concretelang/Support/Constants.h" #include "concretelang/Support/logging.h" @@ -1936,6 +1937,77 @@ struct FHELinalgUnaryOpToLinalgGeneric }; }; +// Replaces a `optimizer.partition_frontier` operation with a tensor +// operand and a tensor result with a `linalg.generic` operation +// applying a `optimizer.partition_frontier` operation with scalar +// operands. +struct TensorPartitionFrontierOpToLinalgGeneric + : public mlir::OpRewritePattern< + mlir::concretelang::Optimizer::PartitionFrontierOp> { + TensorPartitionFrontierOpToLinalgGeneric( + ::mlir::MLIRContext *context, + mlir::PatternBenefit benefit = + mlir::concretelang::DEFAULT_PATTERN_BENEFIT) + : ::mlir::OpRewritePattern< + mlir::concretelang::Optimizer::PartitionFrontierOp>(context, + benefit) {} + + ::mlir::LogicalResult + matchAndRewrite(mlir::concretelang::Optimizer::PartitionFrontierOp pfOp, + ::mlir::PatternRewriter &rewriter) const override { + mlir::RankedTensorType resultTy = + pfOp.getResult().getType().cast(); + mlir::RankedTensorType tensorTy = + pfOp.getInput().getType().cast(); + + mlir::Value init = rewriter.create( + pfOp.getLoc(), resultTy, mlir::ValueRange{}); + + // Create affine maps and iterator types for an embarassingly + // parallel op + llvm::SmallVector maps{ + mlir::AffineMap::getMultiDimIdentityMap(tensorTy.getShape().size(), + this->getContext()), + mlir::AffineMap::getMultiDimIdentityMap(resultTy.getShape().size(), + this->getContext()), + }; + + llvm::SmallVector iteratorTypes( + resultTy.getShape().size(), mlir::utils::IteratorType::parallel); + + // Create the body of the `linalg.generic` op applying a + // `tensor.partition_frontier` op on the scalar arguments + auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder, + mlir::Location nestedLoc, + mlir::ValueRange blockArgs) { + mlir::concretelang::Optimizer::PartitionFrontierOp scalarOp = + nestedBuilder + .create( + pfOp.getLoc(), resultTy.getElementType(), blockArgs[0], + pfOp->getAttrs()); + + nestedBuilder.create(pfOp.getLoc(), + scalarOp.getResult()); + }; + + // Create the `linalg.generic` op + llvm::SmallVector resTypes{init.getType()}; + llvm::SmallVector ins{pfOp.getInput()}; + llvm::SmallVector outs{init}; + llvm::StringRef doc{""}; + llvm::StringRef call{""}; + + mlir::linalg::GenericOp genericOp = + rewriter.create(pfOp.getLoc(), resTypes, ins, + outs, maps, iteratorTypes, doc, + call, bodyBuilder); + + rewriter.replaceOp(pfOp, {genericOp.getResult(0)}); + + return ::mlir::success(); + }; +}; + namespace { struct FHETensorOpsToLinalg : public FHETensorOpsToLinalgBase { @@ -1956,6 +2028,13 @@ void FHETensorOpsToLinalg::runOnOperation() { target.addIllegalOp(); target.addIllegalDialect(); + target.addDynamicallyLegalOp< + mlir::concretelang::Optimizer::PartitionFrontierOp>( + [&](mlir::concretelang::Optimizer::PartitionFrontierOp op) { + return !op.getInput().getType().isa() && + !op.getResult().getType().isa(); + }); + mlir::RewritePatternSet patterns(&getContext()); patterns.insert(&getContext()); patterns.insert(&getContext()); patterns.insert(&getContext()); + patterns.insert(&getContext()); if (mlir::applyPartialConversion(function, target, std::move(patterns)) .failed()) diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHECrt/CMakeLists.txt b/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHECrt/CMakeLists.txt index f79790fa24..631a80cf68 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHECrt/CMakeLists.txt +++ b/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHECrt/CMakeLists.txt @@ -5,6 +5,7 @@ add_mlir_dialect_library( ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/FHE DEPENDS FHEDialect + OptimizerDialect mlir-headers LINK_LIBS PUBLIC diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp index 2f88e5c26c..d6c1e88571 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp @@ -11,6 +11,7 @@ #include "concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h" #include "concretelang/Conversion/Utils/ReinstantiatingOpTypeConversion.h" +#include "concretelang/Dialect/Optimizer/IR/OptimizerOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" @@ -996,6 +997,8 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase { mlir::tensor::CollapseShapeOp>(target, converter); mlir::concretelang::addDynamicallyLegalTypeOp( target, converter); + mlir::concretelang::addDynamicallyLegalTypeOp< + mlir::concretelang::Optimizer::PartitionFrontierOp>(target, converter); mlir::concretelang::addDynamicallyLegalTypeOp( target, converter); @@ -1078,6 +1081,10 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase { patterns.add>(&getContext(), converter); + patterns.add>( + &getContext(), converter); + patterns.add>(&getContext(), converter); diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/CMakeLists.txt b/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/CMakeLists.txt index 57655c8a39..115be1775b 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/CMakeLists.txt +++ b/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/CMakeLists.txt @@ -5,6 +5,7 @@ add_mlir_dialect_library( ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/FHE DEPENDS FHEDialect + OptimizerDialect mlir-headers LINK_LIBS PUBLIC diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp index 9aa0f5b9b7..4caa26eb2c 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp @@ -10,6 +10,7 @@ #include #include +#include "concretelang/Dialect/Optimizer/IR/OptimizerOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" @@ -888,6 +889,13 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase { patterns.add>(&getContext(), converter); + patterns.add>( + &getContext(), converter); + + mlir::concretelang::addDynamicallyLegalTypeOp< + mlir::concretelang::Optimizer::PartitionFrontierOp>(target, converter); + patterns.add>(&getContext(), converter); diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Transforms/CMakeLists.txt b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Transforms/CMakeLists.txt index e15a7fa856..97a27bbbec 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Transforms/CMakeLists.txt +++ b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Transforms/CMakeLists.txt @@ -5,12 +5,15 @@ add_mlir_library( Max.cpp EncryptedMulToDoubleTLU.cpp DynamicTLU.cpp + Optimizer.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/FHE DEPENDS FHEDialect + OptimizerDialect mlir-headers LINK_LIBS PUBLIC MLIRIR - FHEDialect) + FHEDialect + OptimizerDialect) diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Transforms/Optimizer.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Transforms/Optimizer.cpp new file mode 100644 index 0000000000..96b27f2466 --- /dev/null +++ b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Transforms/Optimizer.cpp @@ -0,0 +1,143 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#include + +#include +#include +#include + +#include + +namespace mlir { +namespace concretelang { + +struct OptimizerPartitionFrontierMaterializationPass + : public OptimizerPartitionFrontierMaterializationPassBase< + OptimizerPartitionFrontierMaterializationPass> { + + OptimizerPartitionFrontierMaterializationPass( + const optimizer::CircuitSolution &solverSolution) + : solverSolution(solverSolution) {} + + enum class OperationKind { PRODUCER, CONSUMER }; + + std::optional getOid(mlir::Operation *op, OperationKind kind) { + if (mlir::IntegerAttr oidAttr = + op->getAttrOfType("TFHE.OId")) { + return oidAttr.getInt(); + } else if (mlir::DenseI32ArrayAttr oidArrayAttr = + op->getAttrOfType("TFHE.OId")) { + assert(oidArrayAttr.size() > 0); + + if (kind == OperationKind::CONSUMER) { + return oidArrayAttr[0]; + } else { + // All operations with a `TFHE.OId` array attribute store the + // OId of the result at the last position, except + // multiplications, which use the 6th element (at index 5), + // see `mlir::concretelang::optimizer::FunctionToDag::addMul`. + if (llvm::dyn_cast(op) || + llvm::dyn_cast(op)) { + assert(oidArrayAttr.size() >= 6); + return oidArrayAttr[5]; + } else { + return oidArrayAttr[oidArrayAttr.size() - 1]; + } + } + } else { + return std::nullopt; + } + } + + void runOnOperation() final { + mlir::func::FuncOp func = this->getOperation(); + + func.walk([&](mlir::Operation *producer) { + std::optional producerOid = + getOid(producer, OperationKind::PRODUCER); + + if (!producerOid.has_value()) + return; + + assert(*producerOid < solverSolution.instructions_keys.size()); + + auto &eck = + solverSolution.instructions_keys[*producerOid].extra_conversion_keys; + + if (eck.size() == 0) + return; + + assert(eck.size() == 1); + assert(eck[0] < + solverSolution.circuit_keys.conversion_keyswitch_keys.size()); + + uint64_t producerOutKeyID = + solverSolution.instructions_keys[*producerOid].output_key; + + uint64_t conversionOutKeyID = + solverSolution.circuit_keys.conversion_keyswitch_keys[eck[0]] + .output_key.identifier; + + mlir::IRRewriter rewriter(producer->getContext()); + rewriter.setInsertionPointAfter(producer); + + for (mlir::Value res : producer->getResults()) { + mlir::Value resConverted; + + for (mlir::OpOperand &operand : + llvm::make_early_inc_range(res.getUses())) { + mlir::Operation *consumer = operand.getOwner(); + + std::optional consumerOid = + getOid(consumer, OperationKind::CONSUMER); + + // By default, all consumers need the converted value, + // unless it is explicitly specified that the original value + // is needed + bool needsConvertedValue = true; + + if (consumerOid.has_value()) { + assert(*consumerOid < solverSolution.instructions_keys.size()); + + uint64_t consumerInKeyID = + solverSolution.instructions_keys[*consumerOid].input_key; + + if (consumerInKeyID == producerOutKeyID) { + needsConvertedValue = false; + } else { + assert(consumerInKeyID == conversionOutKeyID && + "Consumer needs converted value, but with a key that is " + "not the extra conversion key of the producer"); + } + } + + if (needsConvertedValue) { + if (!resConverted) { + resConverted = rewriter.create( + producer->getLoc(), res.getType(), res, producerOutKeyID, + conversionOutKeyID); + } + + operand.set(resConverted); + } + } + } + }); + } + +protected: + const optimizer::CircuitSolution &solverSolution; +}; + +std::unique_ptr> +createOptimizerPartitionFrontierMaterializationPass( + const optimizer::CircuitSolution &solverSolution) { + return std::make_unique( + solverSolution); +} + +} // namespace concretelang +} // namespace mlir diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Transforms/CMakeLists.txt b/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Transforms/CMakeLists.txt index fc47e62a30..375d2816f5 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Transforms/CMakeLists.txt +++ b/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Transforms/CMakeLists.txt @@ -6,8 +6,10 @@ add_mlir_library( ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/TFHE DEPENDS TFHEDialect + OptimizerDialect mlir-headers LINK_LIBS PUBLIC MLIRIR - TFHEDialect) + TFHEDialect + OptimizerDialect) diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Transforms/TFHECircuitSolutionParametrization.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Transforms/TFHECircuitSolutionParametrization.cpp index 58a3210466..f1d81d6923 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Transforms/TFHECircuitSolutionParametrization.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Transforms/TFHECircuitSolutionParametrization.cpp @@ -8,6 +8,7 @@ #include "concretelang/Dialect/TypeInference/IR/TypeInferenceOps.h" #include +#include #include #include #include @@ -765,71 +766,46 @@ class TFHECircuitSolutionRewriter : public TypeInferenceRewriter { // The TypeInference operations are necessary to avoid producing // invalid IR if `T2` is an unparametrized type. class MaterializePartitionBoundaryPattern - : public mlir::OpRewritePattern { + : public mlir::OpRewritePattern { public: - static constexpr mlir::StringLiteral kTransformMarker = - "__internal_materialize_partition_boundary_marker__"; - MaterializePartitionBoundaryPattern(mlir::MLIRContext *ctx, const CircuitSolutionWrapper &solution) - : mlir::OpRewritePattern(ctx, 0), + : mlir::OpRewritePattern(ctx, 0), solution(solution) {} mlir::LogicalResult - matchAndRewrite(TFHE::BootstrapGLWEOp bsOp, + matchAndRewrite(Optimizer::PartitionFrontierOp pfOp, mlir::PatternRewriter &rewriter) const override { - // Avoid infinite recursion - if (bsOp->hasAttr(kTransformMarker)) - return mlir::failure(); - - mlir::IntegerAttr oidAttr = - bsOp->getAttrOfType("TFHE.OId"); - - if (!oidAttr) - return mlir::failure(); - - int64_t oid = oidAttr.getInt(); - - const ::concrete_optimizer::dag::InstructionKeys &instrKeys = - solution.lookupInstructionKeys(oid); - - if (instrKeys.extra_conversion_keys.size() == 0) - return mlir::failure(); - - assert(instrKeys.extra_conversion_keys.size() == 1); - - // Mark operation to avoid infinite recursion - bsOp->setAttr(kTransformMarker, rewriter.getUnitAttr()); - const ::concrete_optimizer::dag::ConversionKeySwitchKey &cksk = - solution.lookupConversionKeyswitchKey(oid); + solution.lookupConversionKeyswitchKey(pfOp.getInputKeyID(), + pfOp.getOutputKeyID()); TFHE::GLWECipherTextType cksInputType = - solution.getTFHETypeForKey(bsOp->getContext(), cksk.input_key); + solution.getTFHETypeForKey(pfOp->getContext(), cksk.input_key); TFHE::GLWECipherTextType cksOutputType = - solution.getTFHETypeForKey(bsOp->getContext(), cksk.output_key); + solution.getTFHETypeForKey(pfOp->getContext(), cksk.output_key); - rewriter.setInsertionPointAfter(bsOp); + rewriter.setInsertionPointAfter(pfOp); TypeInference::PropagateUpwardOp puOp = rewriter.create( - bsOp->getLoc(), cksInputType, bsOp.getResult()); + pfOp->getLoc(), cksInputType, pfOp.getInput()); TFHE::GLWEKeyswitchKeyAttr keyAttr = solution.getKeyswitchKeyAttr(rewriter.getContext(), cksk); TFHE::KeySwitchGLWEOp ksOp = rewriter.create( - bsOp->getLoc(), cksOutputType, puOp.getResult(), keyAttr); + pfOp->getLoc(), cksOutputType, puOp.getResult(), keyAttr); mlir::Type unparametrizedType = TFHE::GLWECipherTextType::get( rewriter.getContext(), TFHE::GLWESecretKey::newNone()); TypeInference::PropagateDownwardOp pdOp = rewriter.create( - bsOp->getLoc(), unparametrizedType, ksOp.getResult()); + pfOp->getLoc(), unparametrizedType, ksOp.getResult()); - rewriter.replaceAllUsesExcept(bsOp.getResult(), pdOp.getResult(), puOp); + rewriter.replaceAllUsesWith(pfOp.getResult(), pdOp.getResult()); return mlir::success(); } @@ -855,13 +831,10 @@ class TFHECircuitSolutionParametrizationPass : std::nullopt; if (solutionWrapper.has_value()) { - // The optimizer may have decided to place bootstrap operations - // at the edge of a partition. This is indicated by the presence - // of an "extra" conversion key for the OIds of the affected - // bootstrap operations. - // - // To keep type inference and the subsequent rewriting simple, - // materialize the required keyswitch operations straight away. + // Materialize explicit transitions between optimizer partitions + // by replacing `optimizer.partition_frontier` operations with + // keyswitch operations in order to keep type inference and the + // subsequent rewriting simple. mlir::RewritePatternSet patterns(module->getContext()); patterns.add( module->getContext(), solutionWrapper.value()); @@ -869,11 +842,6 @@ class TFHECircuitSolutionParametrizationPass if (mlir::applyPatternsAndFoldGreedily(module, std::move(patterns)) .failed()) this->signalPassFailure(); - - // Clean operations from transformation marker - module.walk([](TFHE::BootstrapGLWEOp bsOp) { - bsOp->removeAttr(MaterializePartitionBoundaryPattern::kTransformMarker); - }); } TFHEParametrizationTypeResolver typeResolver(solutionWrapper); diff --git a/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp b/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp index cddc262c45..e0d98c7274 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp @@ -38,6 +38,7 @@ #include "concretelang/Dialect/Concrete/Transforms/BufferizableOpInterfaceImpl.h" #include "concretelang/Dialect/FHE/IR/FHEDialect.h" #include "concretelang/Dialect/FHELinalg/IR/FHELinalgDialect.h" +#include "concretelang/Dialect/Optimizer/IR/OptimizerDialect.h" #include "concretelang/Dialect/RT/IR/RTDialect.h" #include "concretelang/Dialect/RT/Transforms/BufferizableOpInterfaceImpl.h" #include "concretelang/Dialect/SDFG/IR/SDFGDialect.h" @@ -85,6 +86,7 @@ mlir::MLIRContext *CompilationContext::getMLIRContext() { registry.insert< mlir::concretelang::TypeInference::TypeInferenceDialect, mlir::concretelang::Tracing::TracingDialect, + mlir::concretelang::Optimizer::OptimizerDialect, mlir::concretelang::RT::RTDialect, mlir::concretelang::FHE::FHEDialect, mlir::concretelang::TFHE::TFHEDialect, mlir::concretelang::FHELinalg::FHELinalgDialect, @@ -246,6 +248,18 @@ llvm::Error CompilerEngine::determineFHEParameters(CompilationResult &res) { return llvm::Error::success(); } +mlir::LogicalResult +CompilerEngine::materializeOptimizerPartitionFrontiers(CompilationResult &res) { + mlir::ModuleOp module = res.mlirModuleRef->get(); + + if (res.fheContext.has_value()) { + return pipeline::materializeOptimizerPartitionFrontiers( + *module.getContext(), module, res.fheContext, enablePass); + } + + return mlir::success(); +} + using OptionalLib = std::optional>; // Compile the sources managed by the source manager `sm` to the // target dialect `target`. If successful, the result can be retrieved @@ -364,6 +378,11 @@ CompilerEngine::compile(mlir::ModuleOp moduleOp, Target target, if (auto err = this->determineFHEParameters(res)) return std::move(err); + if (this->materializeOptimizerPartitionFrontiers(res).failed()) { + return StreamStringError( + "Could not materialize explicit optimizer partition frontiers"); + } + // Now that FHE Parameters were computed, we can set the encoding mode of // integer ciphered inputs. if ((this->generateProgramInfo || target == Target::LIBRARY)) { diff --git a/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp b/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp index 0221880b10..74d3f4d150 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp @@ -6,6 +6,7 @@ #include "llvm/Support/TargetSelect.h" #include "concretelang/Support/CompilationFeedback.h" +#include "concretelang/Support/V0Parameters.h" #include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h" #include "mlir/Conversion/Passes.h" #include "mlir/Dialect/Bufferization/Transforms/Passes.h" @@ -40,6 +41,7 @@ #include "concretelang/Dialect/FHE/Transforms/DynamicTLU/DynamicTLU.h" #include "concretelang/Dialect/FHE/Transforms/EncryptedMulToDoubleTLU/EncryptedMulToDoubleTLU.h" #include "concretelang/Dialect/FHE/Transforms/Max/Max.h" +#include "concretelang/Dialect/FHE/Transforms/Optimizer/Optimizer.h" #include "concretelang/Dialect/FHELinalg/Transforms/Tiling.h" #include "concretelang/Dialect/RT/Analysis/Autopar.h" #include "concretelang/Dialect/TFHE/Analysis/ExtractStatistics.h" @@ -146,6 +148,32 @@ getFHEContextFromFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, return std::move(descriptions); } +mlir::LogicalResult materializeOptimizerPartitionFrontiers( + mlir::MLIRContext &context, mlir::ModuleOp &module, + std::optional &fheContext, + std::function enablePass) { + + if (!fheContext.has_value()) + return mlir::success(); + + optimizer::CircuitSolution *circuitSolution = + std::get_if(&fheContext->solution); + + if (!circuitSolution) + return mlir::success(); + + mlir::PassManager pm(&context); + pipelinePrinting("MaterializeOptimizerPartitionFrontiers", pm, context); + + addPotentiallyNestedPass( + pm, + mlir::concretelang::createOptimizerPartitionFrontierMaterializationPass( + *circuitSolution), + enablePass); + + return pm.run(module.getOperation()); +} + mlir::LogicalResult autopar(mlir::MLIRContext &context, mlir::ModuleOp &module, std::function enablePass) { mlir::PassManager pm(&context); From b9589146f45550b1c696e7a5169ef2665c45ed86 Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Fri, 9 Feb 2024 06:19:41 +0100 Subject: [PATCH 5/6] test(compiler): Add tests generating explicit optimizer partition frontiers --- .../check_tests/BugReport/bug_report_538.mlir | 31 +++++++++ .../consumers_in_different_partitions.mlir | 66 +++++++++++++++++++ 2 files changed, 97 insertions(+) create mode 100644 compilers/concrete-compiler/compiler/tests/check_tests/BugReport/bug_report_538.mlir create mode 100644 compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHECircuitParametrization/consumers_in_different_partitions.mlir diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/BugReport/bug_report_538.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/BugReport/bug_report_538.mlir new file mode 100644 index 0000000000..e438fe6f14 --- /dev/null +++ b/compilers/concrete-compiler/compiler/tests/check_tests/BugReport/bug_report_538.mlir @@ -0,0 +1,31 @@ +// RUN: concretecompiler --action=dump-parametrized-tfhe --optimizer-strategy=dag-multi %s + +// CHECK-NEXT: func.func @main(%[[Varg0:.*]]: !TFHE.glwe>, %[[Varg1:.*]]: !TFHE.glwe>) -> (!TFHE.glwe>, !TFHE.glwe>) { +// CHECK-NEXT: %[[Vcst:.*]] = arith.constant dense<0> : tensor<256xi64> +// CHECK-NEXT: %[[Vcst_0:.*]] = arith.constant dense<0> : tensor<128xi64> +// CHECK-NEXT: %[[Vcst_1:.*]] = arith.constant dense<{{\[0, 1\]}}> : tensor<2xi64> +// CHECK-NEXT: %[[V0:.*]] = "TFHE.encode_expand_lut_for_bootstrap"(%[[Vcst_1]]) {isSigned = false, outputBits = 8 : i32, polySize = 256 : i32} : (tensor<2xi64>) -> tensor<256xi64> +// CHECK-NEXT: %[[V1:.*]] = "TFHE.keyswitch_glwe"(%[[Varg0]]) {TFHE.OId = 2 : i32, key = #TFHE.ksk, sk<3,1,601>, 3, 3>} : (!TFHE.glwe>) -> !TFHE.glwe> +// CHECK-NEXT: %[[V2:.*]] = "TFHE.bootstrap_glwe"(%[[V1]], %[[V0]]) {TFHE.OId = 2 : i32, key = #TFHE.bsk, sk<0,1,1536>, 256, 6, 2, 12>} : (!TFHE.glwe>, tensor<256xi64>) -> !TFHE.glwe> +// CHECK-NEXT: %[[V3:.*]] = "TFHE.encode_expand_lut_for_bootstrap"(%[[Vcst_0]]) {isSigned = false, outputBits = 8 : i32, polySize = 8192 : i32} : (tensor<128xi64>) -> tensor<8192xi64> +// CHECK-NEXT: %[[V4:.*]] = "TFHE.keyswitch_glwe"(%[[Varg1]]) {TFHE.OId = 3 : i32, key = #TFHE.ksk, sk<4,1,923>, 6, 3>} : (!TFHE.glwe>) -> !TFHE.glwe> +// CHECK-NEXT: %[[V5:.*]] = "TFHE.bootstrap_glwe"(%[[V4]], %[[V3]]) {TFHE.OId = 3 : i32, key = #TFHE.bsk, sk<1,1,8192>, 8192, 1, 2, 15>} : (!TFHE.glwe>, tensor<8192xi64>) -> !TFHE.glwe> +// CHECK-NEXT: %[[V6:.*]] = "TFHE.keyswitch_glwe"(%[[V5]]) {key = #TFHE.ksk, sk<0,1,1536>, 1, 19>} : (!TFHE.glwe>) -> !TFHE.glwe> +// CHECK-NEXT: %[[V7:.*]] = "TFHE.add_glwe"(%[[V2]], %[[V6]]) {TFHE.OId = 4 : i32} : (!TFHE.glwe>, !TFHE.glwe>) -> !TFHE.glwe> +// CHECK-NEXT: %[[V8:.*]] = "TFHE.encode_expand_lut_for_bootstrap"(%[[Vcst]]) {isSigned = false, outputBits = 5 : i32, polySize = 16384 : i32} : (tensor<256xi64>) -> tensor<16384xi64> +// CHECK-NEXT: %[[V9:.*]] = "TFHE.keyswitch_glwe"(%[[V5]]) {TFHE.OId = 5 : i32, key = #TFHE.ksk, sk<5,1,967>, 7, 3>} : (!TFHE.glwe>) -> !TFHE.glwe> +// CHECK-NEXT: %[[V10:.*]] = "TFHE.bootstrap_glwe"(%[[V9]], %[[V8]]) {TFHE.OId = 5 : i32, key = #TFHE.bsk, sk<2,1,16384>, 16384, 1, 1, 22>} : (!TFHE.glwe>, tensor<16384xi64>) -> !TFHE.glwe> +// CHECK-NEXT: return %[[V7]], %[[V10]] : !TFHE.glwe>, !TFHE.glwe> +// CHECK-NEXT: } +func.func @main(%arg0: !FHE.eint<1>, %arg1: !FHE.eint<7>) -> (!FHE.eint<8>, !FHE.eint<5>) { + %cst = arith.constant dense<[0, 1]> : tensor<2xi64> + %0 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<1>, tensor<2xi64>) -> !FHE.eint<8> + %cst_0 = arith.constant dense<0> : tensor<128xi64> + %1 = "FHE.apply_lookup_table"(%arg1, %cst_0) : (!FHE.eint<7>, tensor<128xi64>) -> !FHE.eint<8> + %2 = "FHE.add_eint"(%0, %1) : (!FHE.eint<8>, !FHE.eint<8>) -> !FHE.eint<8> + %c4_i4 = arith.constant 4 : i4 + %cst_1 = arith.constant dense<0> : tensor<256xi64> + %3 = "FHE.apply_lookup_table"(%1, %cst_1) : (!FHE.eint<8>, tensor<256xi64>) -> !FHE.eint<5> + return %2, %3 : !FHE.eint<8>, !FHE.eint<5> +} + diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHECircuitParametrization/consumers_in_different_partitions.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHECircuitParametrization/consumers_in_different_partitions.mlir new file mode 100644 index 0000000000..392aa0309f --- /dev/null +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/TFHECircuitParametrization/consumers_in_different_partitions.mlir @@ -0,0 +1,66 @@ +// RUN: concretecompiler --action=dump-parametrized-tfhe --optimizer-strategy=dag-multi %s + +// CHECK: module { +// CHECK-NEXT: func.func @main(%[[Varg0:.*]]: tensor<2x!TFHE.glwe>>, %[[Varg1:.*]]: tensor<2x!TFHE.glwe>>) -> (tensor<2x!TFHE.glwe>>, tensor<2x!TFHE.glwe>>) { +// CHECK-NEXT: %[[Vc0:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[Vc2:.*]] = arith.constant 2 : index +// CHECK-NEXT: %[[Vc1:.*]] = arith.constant 1 : index +// CHECK-NEXT: %[[Vcst:.*]] = arith.constant dense<0> : tensor<128xi64> +// CHECK-NEXT: %[[Vcst_0:.*]] = arith.constant dense<{{\[0, 1\]}}> : tensor<2xi64> +// CHECK-NEXT: %[[V0:.*]] = "TFHE.zero_tensor"() : () -> tensor<2x!TFHE.glwe>> +// CHECK-NEXT: %[[V1:.*]] = scf.for %[[Varg2:.*]] = %[[Vc0]] to %[[Vc2]] step %[[Vc1]] iter_args(%[[Varg3:.*]] = %[[V0]]) -> (tensor<2x!TFHE.glwe>>) { +// CHECK-NEXT: %[[Vextracted:.*]] = tensor.extract %[[Varg0]]{{\[}}%[[Varg2]]{{\]}} : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: %[[V10:.*]] = "TFHE.encode_expand_lut_for_bootstrap"(%[[Vcst_0]]) {isSigned = false, outputBits = 8 : i32, polySize = 256 : i32} : (tensor<2xi64>) -> tensor<256xi64> +// CHECK-NEXT: %[[V11:.*]] = "TFHE.keyswitch_glwe"(%[[Vextracted]]) {key = #TFHE.ksk, sk<2,1,604>, 3, 3>} : (!TFHE.glwe>) -> !TFHE.glwe> +// CHECK-NEXT: %[[V12:.*]] = "TFHE.bootstrap_glwe"(%[[V11]], %[[V10]]) {key = #TFHE.bsk, sk<0,1,1536>, 256, 6, 2, 12>} : (!TFHE.glwe>, tensor<256xi64>) -> !TFHE.glwe> +// CHECK-NEXT: %[[Vinserted:.*]] = tensor.insert %[[V12]] into %[[Varg3]]{{\[}}%[[Varg2]]{{\]}} : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: scf.yield %[[Vinserted]] : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: } +// CHECK-NEXT: %[[V2:.*]] = "TFHE.zero_tensor"() : () -> tensor<2x!TFHE.glwe>> +// CHECK-NEXT: %[[V3:.*]] = scf.for %[[Varg2:.*]] = %[[Vc0]] to %[[Vc2]] step %[[Vc1]] iter_args(%[[Varg3:.*]] = %[[V2]]) -> (tensor<2x!TFHE.glwe>>) { +// CHECK-NEXT: %[[Vextracted:.*]] = tensor.extract %[[Varg1]]{{\[}}%[[Varg2]]{{\]}} : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: %[[V10:.*]] = "TFHE.encode_expand_lut_for_bootstrap"(%[[Vcst]]) {isSigned = false, outputBits = 8 : i32, polySize = 8192 : i32} : (tensor<128xi64>) -> tensor<8192xi64> +// CHECK-NEXT: %[[V11:.*]] = "TFHE.keyswitch_glwe"(%[[Vextracted]]) {key = #TFHE.ksk, sk<3,1,942>, 6, 3>} : (!TFHE.glwe>) -> !TFHE.glwe> +// CHECK-NEXT: %[[V12:.*]] = "TFHE.bootstrap_glwe"(%[[V11]], %[[V10]]) {key = #TFHE.bsk, sk<1,1,8192>, 8192, 1, 1, 22>} : (!TFHE.glwe>, tensor<8192xi64>) -> !TFHE.glwe> +// CHECK-NEXT: %[[Vinserted:.*]] = tensor.insert %[[V12]] into %[[Varg3]]{{\[}}%[[Varg2]]{{\]}} : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: scf.yield %[[Vinserted]] : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: } +// CHECK-NEXT: %[[V4:.*]] = tensor.empty() : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: %[[V5:.*]] = scf.for %[[Varg2:.*]] = %[[Vc0]] to %[[Vc2]] step %[[Vc1]] iter_args(%[[Varg3:.*]] = %[[V4]]) -> (tensor<2x!TFHE.glwe>>) { +// CHECK-NEXT: %[[Vextracted:.*]] = tensor.extract %[[V3]]{{\[}}%[[Varg2]]{{\]}} : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: %[[V10:.*]] = "TFHE.keyswitch_glwe"(%[[Vextracted]]) {key = #TFHE.ksk, sk<0,1,1536>, 1, 19>} : (!TFHE.glwe>) -> !TFHE.glwe> +// CHECK-NEXT: %[[Vinserted:.*]] = tensor.insert %[[V10]] into %[[Varg3]]{{\[}}%[[Varg2]]{{\]}} : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: scf.yield %[[Vinserted]] : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: } +// CHECK-NEXT: %[[V6:.*]] = "TFHE.zero_tensor"() : () -> tensor<2x!TFHE.glwe>> +// CHECK-NEXT: %[[V7:.*]] = scf.for %[[Varg2:.*]] = %[[Vc0]] to %[[Vc2]] step %[[Vc1]] iter_args(%[[Varg3:.*]] = %[[V6]]) -> (tensor<2x!TFHE.glwe>>) { +// CHECK-NEXT: %[[Vextracted:.*]] = tensor.extract %[[V1]]{{\[}}%[[Varg2]]{{\]}} : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: %[[Vextracted_1:.*]] = tensor.extract %[[V5]]{{\[}}%[[Varg2]]{{\]}} : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: %[[V10:.*]] = "TFHE.add_glwe"(%[[Vextracted]], %[[Vextracted_1]]) : (!TFHE.glwe>, !TFHE.glwe>) -> !TFHE.glwe> +// CHECK-NEXT: %[[Vinserted:.*]] = tensor.insert %[[V10]] into %[[Varg3]]{{\[}}%[[Varg2]]{{\]}} : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: scf.yield %[[Vinserted]] : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: } +// CHECK-NEXT: %[[V8:.*]] = "TFHE.zero_tensor"() : () -> tensor<2x!TFHE.glwe>> +// CHECK-NEXT: %[[V9:.*]] = scf.for %[[Varg2:.*]] = %[[Vc0]] to %[[Vc2]] step %[[Vc1]] iter_args(%[[Varg3:.*]] = %[[V8]]) -> (tensor<2x!TFHE.glwe>>) { +// CHECK-NEXT: %[[Vextracted:.*]] = tensor.extract %[[Varg1]]{{\[}}%[[Varg2]]{{\]}} : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: %[[V10:.*]] = "TFHE.encode_expand_lut_for_bootstrap"(%[[Vcst]]) {isSigned = false, outputBits = 5 : i32, polySize = 8192 : i32} : (tensor<128xi64>) -> tensor<8192xi64> +// CHECK-NEXT: %[[V11:.*]] = "TFHE.keyswitch_glwe"(%[[Vextracted]]) {key = #TFHE.ksk, sk<3,1,942>, 6, 3>} : (!TFHE.glwe>) -> !TFHE.glwe> +// CHECK-NEXT: %[[V12:.*]] = "TFHE.bootstrap_glwe"(%[[V11]], %[[V10]]) {key = #TFHE.bsk, sk<1,1,8192>, 8192, 1, 1, 22>} : (!TFHE.glwe>, tensor<8192xi64>) -> !TFHE.glwe> +// CHECK-NEXT: %[[Vinserted:.*]] = tensor.insert %[[V12]] into %[[Varg3]]{{\[}}%[[Varg2]]{{\]}} : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: scf.yield %[[Vinserted]] : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: } +// CHECK-NEXT: return %[[V7]], %[[V9]] : tensor<2x!TFHE.glwe>>, tensor<2x!TFHE.glwe>> +// CHECK-NEXT: } +// CHECK-NEXT: } + +func.func @main(%arg0: tensor<2x!FHE.eint<1>>, %arg1: tensor<2x!FHE.eint<7>>) -> (tensor<2x!FHE.eint<8>>, tensor<2x!FHE.eint<5>>) { + %cst = arith.constant dense<[0, 1]> : tensor<2xi64> + %0 = "FHELinalg.apply_lookup_table"(%arg0, %cst) : (tensor<2x!FHE.eint<1>>, tensor<2xi64>) -> tensor<2x!FHE.eint<8>> + %cst_0 = arith.constant dense<0> : tensor<128xi64> + %1 = "FHELinalg.apply_lookup_table"(%arg1, %cst_0) : (tensor<2x!FHE.eint<7>>, tensor<128xi64>) -> tensor<2x!FHE.eint<8>> + %2 = "FHELinalg.add_eint"(%0, %1) : (tensor<2x!FHE.eint<8>>, tensor<2x!FHE.eint<8>>) -> tensor<2x!FHE.eint<8>> + %c4_i4 = arith.constant 4 : i4 + %cst_1 = arith.constant dense<0> : tensor<256xi64> + %3 = "FHELinalg.apply_lookup_table"(%1, %cst_1) : (tensor<2x!FHE.eint<8>>, tensor<256xi64>) -> tensor<2x!FHE.eint<5>> + return %2, %3 : tensor<2x!FHE.eint<8>>, tensor<2x!FHE.eint<5>> +} From 32199292bb755c1db03de91fafc521ec408c795b Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Fri, 23 Feb 2024 16:07:00 +0100 Subject: [PATCH 6/6] test(frontend-python): Re-enable min / max tests for multi-parameter optimization --- .../tests/execution/test_min_max.py | 36 ++++--------------- 1 file changed, 6 insertions(+), 30 deletions(-) diff --git a/frontends/concrete-python/tests/execution/test_min_max.py b/frontends/concrete-python/tests/execution/test_min_max.py index c12b5f899c..c25d8fc9e3 100644 --- a/frontends/concrete-python/tests/execution/test_min_max.py +++ b/frontends/concrete-python/tests/execution/test_min_max.py @@ -11,8 +11,6 @@ from concrete.fhe.dtypes import Integer from concrete.fhe.values import ValueDescription -KNOWN_MULTI_BUG = "_mark_xfail_if_multi" - cases = [ [ # operation @@ -107,7 +105,7 @@ for rhs_is_signed in [False, True] for operation in [ ( - "maximum" + KNOWN_MULTI_BUG, + "maximum", lambda x, y: np.maximum(x, y), ), ] @@ -144,8 +142,8 @@ strategy, ] for operation in [ - ("minimum" + KNOWN_MULTI_BUG, lambda x, y: np.minimum(x, y)), - ("maximum" + KNOWN_MULTI_BUG, lambda x, y: np.maximum(x, y)), + ("minimum", lambda x, y: np.minimum(x, y)), + ("maximum", lambda x, y: np.maximum(x, y)), ] for strategy in strategies ] @@ -195,14 +193,6 @@ def test_minimum_maximum( parameter_encryption_statuses = {"x": "encrypted", "y": "encrypted"} configuration = helpers.configuration() - can_be_known_failure = ( - name.endswith("xfail_if_multi") - and configuration.parameter_selection_strategy == fhe.ParameterSelectionStrategy.MULTI - ) - if can_be_known_failure: - pytest.skip( - reason="Compiler parametrization pass make the compilation fail with an assertion" - ) if strategy is not None: configuration = configuration.fork(min_max_strategy_preference=[strategy]) @@ -216,10 +206,8 @@ def test_minimum_maximum( ) for _ in range(100) ] - try: - circuit = compiler.compile(inputset, configuration) - except AssertionError: - pytest.xfail("Known cp bug") + + circuit = compiler.compile(inputset, configuration) samples = [ [ @@ -244,16 +232,4 @@ def test_minimum_maximum( ], ] for sample in samples: - try: - helpers.check_execution(circuit, function, sample, retries=5) - except RuntimeError as exc: - if ( - str(exc) == "RuntimeError: Can't compile: Parametrization of TFHE operations failed" - and can_be_known_failure - ): - pytest.xfail("Known multi output bug, bad compilation") - raise - except AssertionError: - if can_be_known_failure: - pytest.xfail("Known multi output bug, bad result") - raise + helpers.check_execution(circuit, function, sample, retries=5)