From 54daaf5281971150b6150906b2a36206bdd0d499 Mon Sep 17 00:00:00 2001 From: youben11 Date: Tue, 29 Aug 2023 15:49:09 +0100 Subject: [PATCH 1/2] refactor(compiler): clean statistic passes --- .../include/concretelang/Analysis/Utils.h | 30 + .../Dialect/Concrete/Analysis/Analysis.td | 13 + .../Dialect/Concrete/Analysis/CMakeLists.txt | 4 + .../Dialect/Concrete/Analysis/MemoryUsage.h | 47 +- .../Dialect/Concrete/CMakeLists.txt | 1 + .../Dialect/TFHE/Analysis/Analysis.td | 14 + .../Dialect/TFHE/Analysis/CMakeLists.txt | 4 + .../Dialect/TFHE/Analysis/ExtractStatistics.h | 47 +- .../concretelang/Dialect/TFHE/CMakeLists.txt | 1 + .../compiler/lib/Analysis/CMakeLists.txt | 8 + .../compiler/lib/Analysis/Utils.cpp | 67 ++ .../compiler/lib/Bindings/Rust/build.rs | 1 + .../compiler/lib/CMakeLists.txt | 1 + .../Dialect/Concrete/Analysis/CMakeLists.txt | 3 +- .../Dialect/Concrete/Analysis/MemoryUsage.cpp | 340 +++++----- .../lib/Dialect/TFHE/Analysis/CMakeLists.txt | 3 +- .../TFHE/Analysis/ExtractStatistics.cpp | 589 +++++++++--------- .../compiler/lib/Support/Pipeline.cpp | 4 +- 18 files changed, 599 insertions(+), 578 deletions(-) create mode 100644 compilers/concrete-compiler/compiler/include/concretelang/Analysis/Utils.h create mode 100644 compilers/concrete-compiler/compiler/include/concretelang/Dialect/Concrete/Analysis/Analysis.td create mode 100644 compilers/concrete-compiler/compiler/include/concretelang/Dialect/Concrete/Analysis/CMakeLists.txt create mode 100644 compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/Analysis/Analysis.td create mode 100644 compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/Analysis/CMakeLists.txt create mode 100644 compilers/concrete-compiler/compiler/lib/Analysis/CMakeLists.txt create mode 100644 compilers/concrete-compiler/compiler/lib/Analysis/Utils.cpp diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Analysis/Utils.h b/compilers/concrete-compiler/compiler/include/concretelang/Analysis/Utils.h new file mode 100644 index 0000000000..132cf966c0 --- /dev/null +++ b/compilers/concrete-compiler/compiler/include/concretelang/Analysis/Utils.h @@ -0,0 +1,30 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#ifndef CONCRETELANG_ANALYSIS_UTILS_H +#define CONCRETELANG_ANALYSIS_UTILS_H + +#include +#include +#include +#include + +namespace mlir { +namespace concretelang { + +/// Get the string representation of a location +std::string locationString(mlir::Location loc); + +/// Compute the number of iterations based on loop info +int64_t calculateNumberOfIterations(int64_t start, int64_t stop, int64_t step); + +/// Compute the number of iterations of an scf for loop +outcome::checked +calculateNumberOfIterations(scf::ForOp &op); + +} // namespace concretelang +} // namespace mlir + +#endif diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Concrete/Analysis/Analysis.td b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Concrete/Analysis/Analysis.td new file mode 100644 index 0000000000..a351612fb0 --- /dev/null +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Concrete/Analysis/Analysis.td @@ -0,0 +1,13 @@ +#ifndef CONCRETELANG_DIALECT_CONCRETE_ANALYSIS +#define CONCRETELANG_DIALECT_CONCRETE_ANALYSIS + +include "mlir/Pass/PassBase.td" + +def MemoryUsage : Pass<"MemoryUsage", "::mlir::ModuleOp"> { + let summary = "Compute memory usage"; + let description = [{ + Computes memory usage per location, and provides those numbers throught the CompilationFeedback. + }]; +} + +#endif diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Concrete/Analysis/CMakeLists.txt b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Concrete/Analysis/CMakeLists.txt new file mode 100644 index 0000000000..6fa5f625a3 --- /dev/null +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Concrete/Analysis/CMakeLists.txt @@ -0,0 +1,4 @@ +set(LLVM_TARGET_DEFINITIONS Analysis.td) +mlir_tablegen(Analysis.h.inc -gen-pass-decls -name Analysis) +add_public_tablegen_target(ConcretelangConcreteAnalysisPassIncGen) +add_dependencies(mlir-headers ConcretelangConcreteAnalysisPassIncGen) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Concrete/Analysis/MemoryUsage.h b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Concrete/Analysis/MemoryUsage.h index 22d815001b..38d1f5285b 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Concrete/Analysis/MemoryUsage.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Concrete/Analysis/MemoryUsage.h @@ -7,59 +7,16 @@ #define CONCRETELANG_DIALECT_CONCRETE_MEMORY_USAGE_H #include -#include #include #include namespace mlir { namespace concretelang { -namespace Concrete { -struct MemoryUsagePass - : public PassWrapper> { +std::unique_ptr> +createMemoryUsagePass(CompilationFeedback &feedback); - CompilationFeedback &feedback; - - MemoryUsagePass(CompilationFeedback &feedback) : feedback{feedback} {}; - - void runOnOperation() override { - WalkResult walk = - getOperation()->walk([&](Operation *op, const WalkStage &stage) { - if (stage.isBeforeAllRegions()) { - std::optional error = this->enter(op); - if (error.has_value()) { - op->emitError() << error->mesg; - return WalkResult::interrupt(); - } - } - - if (stage.isAfterAllRegions()) { - std::optional error = this->exit(op); - if (error.has_value()) { - op->emitError() << error->mesg; - return WalkResult::interrupt(); - } - } - - return WalkResult::advance(); - }); - - if (walk.wasInterrupted()) { - signalPassFailure(); - } - } - - std::optional enter(Operation *op); - - std::optional exit(Operation *op); - - std::map> visitedValuesPerLoc; - - size_t iterations = 1; -}; - -} // namespace Concrete } // namespace concretelang } // namespace mlir diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Concrete/CMakeLists.txt b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Concrete/CMakeLists.txt index 9f57627c32..306b439685 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Concrete/CMakeLists.txt +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Concrete/CMakeLists.txt @@ -1,2 +1,3 @@ +add_subdirectory(Analysis) add_subdirectory(IR) add_subdirectory(Transforms) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/Analysis/Analysis.td b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/Analysis/Analysis.td new file mode 100644 index 0000000000..43e1da32b2 --- /dev/null +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/Analysis/Analysis.td @@ -0,0 +1,14 @@ +#ifndef CONCRETELANG_DIALECT_TFHE_ANALYSIS +#define CONCRETELANG_DIALECT_TFHE_ANALYSIS + +include "mlir/Pass/PassBase.td" + +def ExtractStatistics : Pass<"ExtractStatistics", "::mlir::ModuleOp"> { + let summary = "Extracts statistics"; + let description = [{ + Extracts different statistics (e.g. number of certain crypto operations), + and provides those numbers throught the CompilationFeedback. + }]; +} + +#endif diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/Analysis/CMakeLists.txt b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/Analysis/CMakeLists.txt new file mode 100644 index 0000000000..902b098392 --- /dev/null +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/Analysis/CMakeLists.txt @@ -0,0 +1,4 @@ +set(LLVM_TARGET_DEFINITIONS Analysis.td) +mlir_tablegen(Analysis.h.inc -gen-pass-decls -name Analysis) +add_public_tablegen_target(ConcretelangTFHEAnalysisPassIncGen) +add_dependencies(mlir-headers ConcretelangTFHEAnalysisPassIncGen) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/Analysis/ExtractStatistics.h b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/Analysis/ExtractStatistics.h index d253073b2b..866ff6d04a 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/Analysis/ExtractStatistics.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/Analysis/ExtractStatistics.h @@ -7,58 +7,15 @@ #define CONCRETELANG_DIALECT_TFHE_ANALYSIS_EXTRACT_STATISTICS_H #include -#include #include #include namespace mlir { namespace concretelang { -namespace TFHE { -struct ExtractTFHEStatisticsPass - : public PassWrapper> { - - CompilationFeedback &feedback; - - ExtractTFHEStatisticsPass(CompilationFeedback &feedback) - : feedback{feedback} {}; - - void runOnOperation() override { - WalkResult walk = - getOperation()->walk([&](Operation *op, const WalkStage &stage) { - if (stage.isBeforeAllRegions()) { - std::optional error = this->enter(op); - if (error.has_value()) { - op->emitError() << error->mesg; - return WalkResult::interrupt(); - } - } - - if (stage.isAfterAllRegions()) { - std::optional error = this->exit(op); - if (error.has_value()) { - op->emitError() << error->mesg; - return WalkResult::interrupt(); - } - } - - return WalkResult::advance(); - }); - - if (walk.wasInterrupted()) { - signalPassFailure(); - } - } - - std::optional enter(Operation *op); - - std::optional exit(Operation *op); - - size_t iterations = 1; -}; - -} // namespace TFHE +std::unique_ptr> +createStatisticExtractionPass(CompilationFeedback &feedback); } // namespace concretelang } // namespace mlir diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/CMakeLists.txt b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/CMakeLists.txt index 9f57627c32..306b439685 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/CMakeLists.txt +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/CMakeLists.txt @@ -1,2 +1,3 @@ +add_subdirectory(Analysis) add_subdirectory(IR) add_subdirectory(Transforms) diff --git a/compilers/concrete-compiler/compiler/lib/Analysis/CMakeLists.txt b/compilers/concrete-compiler/compiler/lib/Analysis/CMakeLists.txt new file mode 100644 index 0000000000..51cfb96338 --- /dev/null +++ b/compilers/concrete-compiler/compiler/lib/Analysis/CMakeLists.txt @@ -0,0 +1,8 @@ +add_mlir_library( + AnalysisUtils + Utils.cpp + DEPENDS + mlir-headers + LINK_LIBS + PUBLIC + MLIRIR) diff --git a/compilers/concrete-compiler/compiler/lib/Analysis/Utils.cpp b/compilers/concrete-compiler/compiler/lib/Analysis/Utils.cpp new file mode 100644 index 0000000000..fd795cc838 --- /dev/null +++ b/compilers/concrete-compiler/compiler/lib/Analysis/Utils.cpp @@ -0,0 +1,67 @@ +#include +#include + +using ::concretelang::error::StringError; + +namespace mlir { +namespace concretelang { +std::string locationString(mlir::Location loc) { + auto location = std::string(); + auto locationStream = llvm::raw_string_ostream(location); + loc->print(locationStream); + return location; +} + +int64_t calculateNumberOfIterations(int64_t start, int64_t stop, int64_t step) { + int64_t high; + int64_t low; + + if (step > 0) { + low = start; + high = stop; + } else { + low = stop; + high = start; + step = -step; + } + + if (low >= high) { + return 0; + } + + return ((high - low - 1) / step) + 1; +} + +outcome::checked +calculateNumberOfIterations(scf::ForOp &op) { + mlir::Value startValue = op.getLowerBound(); + mlir::Value stopValue = op.getUpperBound(); + mlir::Value stepValue = op.getStep(); + + auto startOp = + llvm::dyn_cast_or_null(startValue.getDefiningOp()); + auto stopOp = + llvm::dyn_cast_or_null(stopValue.getDefiningOp()); + auto stepOp = + llvm::dyn_cast_or_null(stepValue.getDefiningOp()); + + if (!startOp || !stopOp || !stepOp) { + return StringError("only static loops can be analyzed"); + } + + auto startAttr = startOp.getValue().cast(); + auto stopAttr = stopOp.getValue().cast(); + auto stepAttr = stepOp.getValue().cast(); + + if (!startOp || !stopOp || !stepOp) { + return StringError("only integer loops can be analyzed"); + } + + int64_t start = startAttr.getInt(); + int64_t stop = stopAttr.getInt(); + int64_t step = stepAttr.getInt(); + + return calculateNumberOfIterations(start, stop, step); +} +} // namespace concretelang +} // namespace mlir diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Rust/build.rs b/compilers/concrete-compiler/compiler/lib/Bindings/Rust/build.rs index 92998b8041..7da8bb0e4a 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Rust/build.rs +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Rust/build.rs @@ -262,6 +262,7 @@ const LLVM_TARGET_SPECIFIC_STATIC_LIBS: &[&str] = &[ const LLVM_TARGET_SPECIFIC_STATIC_LIBS: &[&str] = &["LLVMX86CodeGen", "LLVMX86Desc", "LLVMX86Info"]; const CONCRETE_COMPILER_STATIC_LIBS: &[&str] = &[ + "AnalysisUtils", "RTDialect", "RTDialectTransforms", "ConcretelangSupport", diff --git a/compilers/concrete-compiler/compiler/lib/CMakeLists.txt b/compilers/concrete-compiler/compiler/lib/CMakeLists.txt index 21644c29ad..541b3d708d 100644 --- a/compilers/concrete-compiler/compiler/lib/CMakeLists.txt +++ b/compilers/concrete-compiler/compiler/lib/CMakeLists.txt @@ -1,3 +1,4 @@ +add_subdirectory(Analysis) add_subdirectory(Dialect) add_subdirectory(Conversion) add_subdirectory(Transforms) diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/Concrete/Analysis/CMakeLists.txt b/compilers/concrete-compiler/compiler/lib/Dialect/Concrete/Analysis/CMakeLists.txt index 5b77920c98..b276cc6cb7 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/Concrete/Analysis/CMakeLists.txt +++ b/compilers/concrete-compiler/compiler/lib/Dialect/Concrete/Analysis/CMakeLists.txt @@ -9,4 +9,5 @@ add_mlir_library( LINK_LIBS PUBLIC MLIRIR - ConcreteDialect) + ConcreteDialect + AnalysisUtils) diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/Concrete/Analysis/MemoryUsage.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/Concrete/Analysis/MemoryUsage.cpp index 7c8de9694c..d90c62ef73 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/Concrete/Analysis/MemoryUsage.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/Concrete/Analysis/MemoryUsage.cpp @@ -1,109 +1,20 @@ +#include #include #include #include #include #include #include +#include +#include #include #include using namespace mlir::concretelang; using namespace mlir; -using Concrete::MemoryUsagePass; - namespace { -std::string locationString(mlir::Location loc) { - auto location = std::string(); - auto locationStream = llvm::raw_string_ostream(location); - loc->print(locationStream); - return location; -} - -int64_t calculateNumberOfIterations(int64_t start, int64_t stop, int64_t step) { - int64_t high; - int64_t low; - - if (step > 0) { - low = start; - high = stop; - } else { - low = stop; - high = start; - step = -step; - } - - if (low >= high) { - return 0; - } - - return ((high - low - 1) / step) + 1; -} - -std::optional calculateNumberOfIterations(scf::ForOp &op, - int64_t &result) { - mlir::Value startValue = op.getLowerBound(); - mlir::Value stopValue = op.getUpperBound(); - mlir::Value stepValue = op.getStep(); - - auto startOp = - llvm::dyn_cast_or_null(startValue.getDefiningOp()); - auto stopOp = - llvm::dyn_cast_or_null(stopValue.getDefiningOp()); - auto stepOp = - llvm::dyn_cast_or_null(stepValue.getDefiningOp()); - - if (!startOp || !stopOp || !stepOp) { - return StringError("only static loops can be analyzed"); - } - - auto startAttr = startOp.getValue().cast(); - auto stopAttr = stopOp.getValue().cast(); - auto stepAttr = stepOp.getValue().cast(); - - if (!startOp || !stopOp || !stepOp) { - return StringError("only integer loops can be analyzed"); - } - - int64_t start = startAttr.getInt(); - int64_t stop = stopAttr.getInt(); - int64_t step = stepAttr.getInt(); - - result = calculateNumberOfIterations(start, stop, step); - return std::nullopt; -} - -static std::optional on_enter(scf::ForOp &op, - MemoryUsagePass &pass) { - int64_t numberOfIterations; - - std::optional error = - calculateNumberOfIterations(op, numberOfIterations); - if (error.has_value()) { - return error; - } - - assert(numberOfIterations > 0); - pass.iterations *= (uint64_t)numberOfIterations; - return std::nullopt; -} - -static std::optional on_exit(scf::ForOp &op, - MemoryUsagePass &pass) { - int64_t numberOfIterations; - - std::optional error = - calculateNumberOfIterations(op, numberOfIterations); - if (error.has_value()) { - return error; - } - - assert(numberOfIterations > 0); - pass.iterations /= (uint64_t)numberOfIterations; - return std::nullopt; -} - int64_t getElementTypeSize(mlir::Type elementType) { if (auto integerType = mlir::dyn_cast(elementType)) { auto width = integerType.getWidth(); @@ -146,108 +57,185 @@ bool isBufferDeallocated(mlir::Value buffer) { return false; } -static std::optional on_enter(memref::AllocOp &op, - MemoryUsagePass &pass) { +} // namespace - auto maybeBufferSize = getBufferSize(op.getResult().getType()); - if (!maybeBufferSize) { - return maybeBufferSize.error(); +namespace mlir { +namespace concretelang { +namespace Concrete { + +struct MemoryUsagePass + : public PassWrapper> { + + CompilationFeedback &feedback; + + MemoryUsagePass(CompilationFeedback &feedback) : feedback{feedback} {}; + + void runOnOperation() override { + WalkResult walk = + getOperation()->walk([&](Operation *op, const WalkStage &stage) { + if (stage.isBeforeAllRegions()) { + std::optional error = this->enter(op); + if (error.has_value()) { + op->emitError() << error->mesg; + return WalkResult::interrupt(); + } + } + + if (stage.isAfterAllRegions()) { + std::optional error = this->exit(op); + if (error.has_value()) { + op->emitError() << error->mesg; + return WalkResult::interrupt(); + } + } + + return WalkResult::advance(); + }); + + if (walk.wasInterrupted()) { + signalPassFailure(); + } } - // if the allocated buffer is being deallocated then count it as one. - // Otherwise (and there must be a problem) multiply it by the number of - // iterations - int64_t numberOfAlloc = - isBufferDeallocated(op.getResult()) ? 1 : pass.iterations; - - auto location = locationString(op.getLoc()); - // pass.iterations number of allocation of size: shape_1 * ... * shape_n * - // element_size - auto memoryUsage = numberOfAlloc * maybeBufferSize.value(); - - pass.feedback.memoryUsagePerLoc[location] += memoryUsage; - return std::nullopt; -} - -static std::optional on_enter(mlir::Operation *op, - MemoryUsagePass &pass) { - for (auto operand : op->getOperands()) { - // we only consider buffers - if (!mlir::isa(operand.getType())) - continue; - // find the origin of the buffer - auto definingOp = operand.getDefiningOp(); - mlir::Value lastVisitedBuffer = operand; - while (definingOp) { - mlir::ViewLikeOpInterface viewLikeOp = - mlir::dyn_cast(definingOp); - if (viewLikeOp) { - lastVisitedBuffer = viewLikeOp.getViewSource(); - definingOp = lastVisitedBuffer.getDefiningOp(); - } else { - break; + std::optional enter(mlir::Operation *op) { + // specialized calls + if (auto typedOp = llvm::dyn_cast(op)) { + std::optional error = on_enter(typedOp, *this); + if (error.has_value()) { + return error; + } + } + if (auto typedOp = llvm::dyn_cast(op)) { + std::optional error = on_enter(typedOp, *this); + if (error.has_value()) { + return error; } } - // we already count allocations separately - if (definingOp && mlir::isa(definingOp) && - definingOp->getLoc() == op->getLoc()) - continue; - - auto location = locationString(op->getLoc()); - - std::vector &visited = pass.visitedValuesPerLoc[location]; - // the search would be faster if we use an unsorted_set, but we need a hash - // function for mlir::Value - if (std::find(visited.begin(), visited.end(), lastVisitedBuffer) == - visited.end()) { - visited.push_back(lastVisitedBuffer); + // call generic enter + std::optional error = on_enter(op, *this); + if (error.has_value()) { + return error; + } + return std::nullopt; + } - auto maybeBufferSize = - getBufferSize(lastVisitedBuffer.getType().cast()); - if (!maybeBufferSize) { - return maybeBufferSize.error(); + std::optional exit(mlir::Operation *op) { + if (auto typedOp = llvm::dyn_cast(op)) { + std::optional error = on_exit(typedOp, *this); + if (error.has_value()) { + return error; } - auto bufferSize = maybeBufferSize.value(); - - pass.feedback.memoryUsagePerLoc[location] += bufferSize; } + return std::nullopt; } - return std::nullopt; -} + static std::optional on_enter(scf::ForOp &op, + MemoryUsagePass &pass) { + auto numberOfIterations = calculateNumberOfIterations(op); + if (!numberOfIterations) { + return numberOfIterations.error(); + } -} // namespace + assert(numberOfIterations.value() > 0); + pass.iterations *= (uint64_t)numberOfIterations.value(); + return std::nullopt; + } -std::optional MemoryUsagePass::enter(mlir::Operation *op) { - // specialized calls - if (auto typedOp = llvm::dyn_cast(op)) { - std::optional error = on_enter(typedOp, *this); - if (error.has_value()) { - return error; + static std::optional on_exit(scf::ForOp &op, + MemoryUsagePass &pass) { + auto numberOfIterations = calculateNumberOfIterations(op); + if (!numberOfIterations) { + return numberOfIterations.error(); } + + assert(numberOfIterations.value() > 0); + pass.iterations /= (uint64_t)numberOfIterations.value(); + return std::nullopt; } - if (auto typedOp = llvm::dyn_cast(op)) { - std::optional error = on_enter(typedOp, *this); - if (error.has_value()) { - return error; + + static std::optional on_enter(memref::AllocOp &op, + MemoryUsagePass &pass) { + + auto maybeBufferSize = getBufferSize(op.getResult().getType()); + if (!maybeBufferSize) { + return maybeBufferSize.error(); } - } + // if the allocated buffer is being deallocated then count it as one. + // Otherwise (and there must be a problem) multiply it by the number of + // iterations + int64_t numberOfAlloc = + isBufferDeallocated(op.getResult()) ? 1 : pass.iterations; + + auto location = locationString(op.getLoc()); + // pass.iterations number of allocation of size: shape_1 * ... * shape_n * + // element_size + auto memoryUsage = numberOfAlloc * maybeBufferSize.value(); + + pass.feedback.memoryUsagePerLoc[location] += memoryUsage; + + return std::nullopt; + } + + static std::optional on_enter(mlir::Operation *op, + MemoryUsagePass &pass) { + for (auto operand : op->getOperands()) { + // we only consider buffers + if (!mlir::isa(operand.getType())) + continue; + // find the origin of the buffer + auto definingOp = operand.getDefiningOp(); + mlir::Value lastVisitedBuffer = operand; + while (definingOp) { + mlir::ViewLikeOpInterface viewLikeOp = + mlir::dyn_cast(definingOp); + if (viewLikeOp) { + lastVisitedBuffer = viewLikeOp.getViewSource(); + definingOp = lastVisitedBuffer.getDefiningOp(); + } else { + break; + } + } + // we already count allocations separately + if (definingOp && mlir::isa(definingOp) && + definingOp->getLoc() == op->getLoc()) + continue; - // call generic enter - std::optional error = on_enter(op, *this); - if (error.has_value()) { - return error; - } - return std::nullopt; -} + auto location = locationString(op->getLoc()); -std::optional MemoryUsagePass::exit(mlir::Operation *op) { - if (auto typedOp = llvm::dyn_cast(op)) { - std::optional error = on_exit(typedOp, *this); - if (error.has_value()) { - return error; + std::vector &visited = pass.visitedValuesPerLoc[location]; + + // the search would be faster if we use an unsorted_set, but we need a + // hash function for mlir::Value + if (std::find(visited.begin(), visited.end(), lastVisitedBuffer) == + visited.end()) { + visited.push_back(lastVisitedBuffer); + + auto maybeBufferSize = + getBufferSize(lastVisitedBuffer.getType().cast()); + if (!maybeBufferSize) { + return maybeBufferSize.error(); + } + auto bufferSize = maybeBufferSize.value(); + + pass.feedback.memoryUsagePerLoc[location] += bufferSize; + } } + + return std::nullopt; } - return std::nullopt; + + std::map> visitedValuesPerLoc; + + size_t iterations = 1; +}; + +} // namespace Concrete + +std::unique_ptr> +createMemoryUsagePass(CompilationFeedback &feedback) { + return std::make_unique(feedback); } + +} // namespace concretelang +} // namespace mlir diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Analysis/CMakeLists.txt b/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Analysis/CMakeLists.txt index 97e005a754..b0eb8a2ba1 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Analysis/CMakeLists.txt +++ b/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Analysis/CMakeLists.txt @@ -9,4 +9,5 @@ add_mlir_library( LINK_LIBS PUBLIC MLIRIR - TFHEDialect) + TFHEDialect + AnalysisUtils) diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Analysis/ExtractStatistics.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Analysis/ExtractStatistics.cpp index 2a13c03f22..781e46fa75 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Analysis/ExtractStatistics.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Analysis/ExtractStatistics.cpp @@ -1,383 +1,356 @@ +#include #include #include #include +#include +#include #include using namespace mlir::concretelang; using namespace mlir; -using TFHE::ExtractTFHEStatisticsPass; +namespace mlir { +namespace concretelang { +namespace TFHE { -// ######### -// Utilities -// ######### - -template std::string locationOf(Op op) { - auto location = std::string(); - auto locationStream = llvm::raw_string_ostream(location); - op.getLoc()->print(locationStream); - return location.substr(5, location.size() - 2 - 5); // remove loc(" and ") -} - -// ####### -// scf.for -// ####### - -int64_t calculateNumberOfIterations(int64_t start, int64_t stop, int64_t step) { - int64_t high; - int64_t low; - - if (step > 0) { - low = start; - high = stop; - } else { - low = stop; - high = start; - step = -step; +#define DISPATCH_ENTER(type) \ + if (auto typedOp = llvm::dyn_cast(op)) { \ + std::optional error = on_enter(typedOp, *this); \ + if (error.has_value()) { \ + return error; \ + } \ } - if (low >= high) { - return 0; +#define DISPATCH_EXIT(type) \ + if (auto typedOp = llvm::dyn_cast(op)) { \ + std::optional error = on_exit(typedOp, *this); \ + if (error.has_value()) { \ + return error; \ + } \ } - return ((high - low - 1) / step) + 1; -} - -std::optional calculateNumberOfIterations(scf::ForOp &op, - int64_t &result) { - mlir::Value startValue = op.getLowerBound(); - mlir::Value stopValue = op.getUpperBound(); - mlir::Value stepValue = op.getStep(); - - auto startOp = - llvm::dyn_cast_or_null(startValue.getDefiningOp()); - auto stopOp = - llvm::dyn_cast_or_null(stopValue.getDefiningOp()); - auto stepOp = - llvm::dyn_cast_or_null(stepValue.getDefiningOp()); - - if (!startOp || !stopOp || !stepOp) { - return StringError("only static loops can be analyzed"); +struct ExtractTFHEStatisticsPass + : public PassWrapper> { + + CompilationFeedback &feedback; + + ExtractTFHEStatisticsPass(CompilationFeedback &feedback) + : feedback{feedback} {}; + + void runOnOperation() override { + WalkResult walk = + getOperation()->walk([&](Operation *op, const WalkStage &stage) { + if (stage.isBeforeAllRegions()) { + std::optional error = this->enter(op); + if (error.has_value()) { + op->emitError() << error->mesg; + return WalkResult::interrupt(); + } + } + + if (stage.isAfterAllRegions()) { + std::optional error = this->exit(op); + if (error.has_value()) { + op->emitError() << error->mesg; + return WalkResult::interrupt(); + } + } + + return WalkResult::advance(); + }); + + if (walk.wasInterrupted()) { + signalPassFailure(); + } } - auto startAttr = startOp.getValue().cast(); - auto stopAttr = stopOp.getValue().cast(); - auto stepAttr = stepOp.getValue().cast(); - - if (!startOp || !stopOp || !stepOp) { - return StringError("only integer loops can be analyzed"); + std::optional enter(mlir::Operation *op) { + DISPATCH_ENTER(scf::ForOp) + DISPATCH_ENTER(TFHE::AddGLWEOp) + DISPATCH_ENTER(TFHE::AddGLWEIntOp) + DISPATCH_ENTER(TFHE::BootstrapGLWEOp) + DISPATCH_ENTER(TFHE::KeySwitchGLWEOp) + DISPATCH_ENTER(TFHE::MulGLWEIntOp) + DISPATCH_ENTER(TFHE::NegGLWEOp) + DISPATCH_ENTER(TFHE::SubGLWEIntOp) + DISPATCH_ENTER(TFHE::WopPBSGLWEOp) + return std::nullopt; } - int64_t start = startAttr.getInt(); - int64_t stop = stopAttr.getInt(); - int64_t step = stepAttr.getInt(); - - result = calculateNumberOfIterations(start, stop, step); - return std::nullopt; -} - -static std::optional on_enter(scf::ForOp &op, - ExtractTFHEStatisticsPass &pass) { - int64_t numberOfIterations; - - std::optional error = - calculateNumberOfIterations(op, numberOfIterations); - if (error.has_value()) { - return error; + std::optional exit(mlir::Operation *op) { + DISPATCH_EXIT(scf::ForOp) + return std::nullopt; } - assert(numberOfIterations > 0); - pass.iterations *= (uint64_t)numberOfIterations; - return std::nullopt; -} - -static std::optional on_exit(scf::ForOp &op, - ExtractTFHEStatisticsPass &pass) { - int64_t numberOfIterations; + static std::optional on_enter(scf::ForOp &op, + ExtractTFHEStatisticsPass &pass) { + auto numberOfIterations = calculateNumberOfIterations(op); + if (!numberOfIterations) { + return numberOfIterations.error(); + } - std::optional error = - calculateNumberOfIterations(op, numberOfIterations); - if (error.has_value()) { - return error; + assert(numberOfIterations.value() > 0); + pass.iterations *= (uint64_t)numberOfIterations.value(); + return std::nullopt; } - assert(numberOfIterations > 0); - pass.iterations /= (uint64_t)numberOfIterations; - return std::nullopt; -} + static std::optional on_exit(scf::ForOp &op, + ExtractTFHEStatisticsPass &pass) { + auto numberOfIterations = calculateNumberOfIterations(op); + if (!numberOfIterations) { + return numberOfIterations.error(); + } -// ############# -// TFHE.add_glwe -// ############# + assert(numberOfIterations.value() > 0); + pass.iterations /= (uint64_t)numberOfIterations.value(); + return std::nullopt; + } -static std::optional on_enter(TFHE::AddGLWEOp &op, - ExtractTFHEStatisticsPass &pass) { - auto resultingKey = op.getType().getKey().getNormalized(); + // ############# + // TFHE.add_glwe + // ############# - auto location = locationOf(op); - auto operation = PrimitiveOperation::ENCRYPTED_ADDITION; - auto keys = std::vector>(); - auto count = pass.iterations; + static std::optional on_enter(TFHE::AddGLWEOp &op, + ExtractTFHEStatisticsPass &pass) { + auto resultingKey = op.getType().getKey().getNormalized(); - std::pair key = - std::make_pair(KeyType::SECRET, (size_t)resultingKey->index); - keys.push_back(key); + auto location = locationString(op.getLoc()); + auto operation = PrimitiveOperation::ENCRYPTED_ADDITION; + auto keys = std::vector>(); + auto count = pass.iterations; - pass.feedback.statistics.push_back(Statistic{ - location, - operation, - keys, - count, - }); + std::pair key = + std::make_pair(KeyType::SECRET, (size_t)resultingKey->index); + keys.push_back(key); - return std::nullopt; -} + pass.feedback.statistics.push_back(concretelang::Statistic{ + location, + operation, + keys, + count, + }); -// ################# -// TFHE.add_glwe_int -// ################# + return std::nullopt; + } -static std::optional on_enter(TFHE::AddGLWEIntOp &op, - ExtractTFHEStatisticsPass &pass) { - auto resultingKey = op.getType().getKey().getNormalized(); + // ################# + // TFHE.add_glwe_int + // ################# - auto location = locationOf(op); - auto operation = PrimitiveOperation::CLEAR_ADDITION; - auto keys = std::vector>(); - auto count = pass.iterations; + static std::optional on_enter(TFHE::AddGLWEIntOp &op, + ExtractTFHEStatisticsPass &pass) { + auto resultingKey = op.getType().getKey().getNormalized(); - std::pair key = - std::make_pair(KeyType::SECRET, (size_t)resultingKey->index); - keys.push_back(key); + auto location = locationString(op.getLoc()); + auto operation = PrimitiveOperation::CLEAR_ADDITION; + auto keys = std::vector>(); + auto count = pass.iterations; - pass.feedback.statistics.push_back(Statistic{ - location, - operation, - keys, - count, - }); + std::pair key = + std::make_pair(KeyType::SECRET, (size_t)resultingKey->index); + keys.push_back(key); - return std::nullopt; -} + pass.feedback.statistics.push_back(concretelang::Statistic{ + location, + operation, + keys, + count, + }); -// ################### -// TFHE.bootstrap_glwe -// ################### + return std::nullopt; + } -static std::optional on_enter(TFHE::BootstrapGLWEOp &op, - ExtractTFHEStatisticsPass &pass) { - auto bsk = op.getKey(); + // ################### + // TFHE.bootstrap_glwe + // ################### - auto location = locationOf(op); - auto operation = PrimitiveOperation::PBS; - auto keys = std::vector>(); - auto count = pass.iterations; + static std::optional on_enter(TFHE::BootstrapGLWEOp &op, + ExtractTFHEStatisticsPass &pass) { + auto bsk = op.getKey(); - std::pair key = - std::make_pair(KeyType::BOOTSTRAP, (size_t)bsk.getIndex()); - keys.push_back(key); + auto location = locationString(op.getLoc()); + auto operation = PrimitiveOperation::PBS; + auto keys = std::vector>(); + auto count = pass.iterations; - pass.feedback.statistics.push_back(Statistic{ - location, - operation, - keys, - count, - }); + std::pair key = + std::make_pair(KeyType::BOOTSTRAP, (size_t)bsk.getIndex()); + keys.push_back(key); - return std::nullopt; -} + pass.feedback.statistics.push_back(concretelang::Statistic{ + location, + operation, + keys, + count, + }); -// ################### -// TFHE.keyswitch_glwe -// ################### + return std::nullopt; + } -static std::optional on_enter(TFHE::KeySwitchGLWEOp &op, - ExtractTFHEStatisticsPass &pass) { - auto ksk = op.getKey(); + // ################### + // TFHE.keyswitch_glwe + // ################### - auto location = locationOf(op); - auto operation = PrimitiveOperation::KEY_SWITCH; - auto keys = std::vector>(); - auto count = pass.iterations; + static std::optional on_enter(TFHE::KeySwitchGLWEOp &op, + ExtractTFHEStatisticsPass &pass) { + auto ksk = op.getKey(); - std::pair key = - std::make_pair(KeyType::KEY_SWITCH, (size_t)ksk.getIndex()); - keys.push_back(key); + auto location = locationString(op.getLoc()); + auto operation = PrimitiveOperation::KEY_SWITCH; + auto keys = std::vector>(); + auto count = pass.iterations; - pass.feedback.statistics.push_back(Statistic{ - location, - operation, - keys, - count, - }); + std::pair key = + std::make_pair(KeyType::KEY_SWITCH, (size_t)ksk.getIndex()); + keys.push_back(key); - return std::nullopt; -} + pass.feedback.statistics.push_back(concretelang::Statistic{ + location, + operation, + keys, + count, + }); -// ################# -// TFHE.mul_glwe_int -// ################# - -static std::optional on_enter(TFHE::MulGLWEIntOp &op, - ExtractTFHEStatisticsPass &pass) { - auto resultingKey = op.getType().getKey().getNormalized(); + return std::nullopt; + } - auto location = locationOf(op); - auto operation = PrimitiveOperation::CLEAR_MULTIPLICATION; - auto keys = std::vector>(); - auto count = pass.iterations; + // ################# + // TFHE.mul_glwe_int + // ################# - std::pair key = - std::make_pair(KeyType::SECRET, (size_t)resultingKey->index); - keys.push_back(key); + static std::optional on_enter(TFHE::MulGLWEIntOp &op, + ExtractTFHEStatisticsPass &pass) { + auto resultingKey = op.getType().getKey().getNormalized(); - pass.feedback.statistics.push_back(Statistic{ - location, - operation, - keys, - count, - }); + auto location = locationString(op.getLoc()); + auto operation = PrimitiveOperation::CLEAR_MULTIPLICATION; + auto keys = std::vector>(); + auto count = pass.iterations; - return std::nullopt; -} + std::pair key = + std::make_pair(KeyType::SECRET, (size_t)resultingKey->index); + keys.push_back(key); -// ############# -// TFHE.neg_glwe -// ############# + pass.feedback.statistics.push_back(concretelang::Statistic{ + location, + operation, + keys, + count, + }); -static std::optional on_enter(TFHE::NegGLWEOp &op, - ExtractTFHEStatisticsPass &pass) { - auto resultingKey = op.getType().getKey().getNormalized(); + return std::nullopt; + } - auto location = locationOf(op); - auto operation = PrimitiveOperation::ENCRYPTED_NEGATION; - auto keys = std::vector>(); - auto count = pass.iterations; + // ############# + // TFHE.neg_glwe + // ############# - std::pair key = - std::make_pair(KeyType::SECRET, (size_t)resultingKey->index); - keys.push_back(key); + static std::optional on_enter(TFHE::NegGLWEOp &op, + ExtractTFHEStatisticsPass &pass) { + auto resultingKey = op.getType().getKey().getNormalized(); - pass.feedback.statistics.push_back(Statistic{ - location, - operation, - keys, - count, - }); + auto location = locationString(op.getLoc()); + auto operation = PrimitiveOperation::ENCRYPTED_NEGATION; + auto keys = std::vector>(); + auto count = pass.iterations; - return std::nullopt; -} + std::pair key = + std::make_pair(KeyType::SECRET, (size_t)resultingKey->index); + keys.push_back(key); -// ################# -// TFHE.sub_int_glwe -// ################# - -static std::optional on_enter(TFHE::SubGLWEIntOp &op, - ExtractTFHEStatisticsPass &pass) { - auto resultingKey = op.getType().getKey().getNormalized(); - - auto location = locationOf(op); - auto keys = std::vector>(); - auto count = pass.iterations; - - std::pair key = - std::make_pair(KeyType::SECRET, (size_t)resultingKey->index); - keys.push_back(key); - - // clear - encrypted = clear + neg(encrypted) - - auto operation = PrimitiveOperation::ENCRYPTED_NEGATION; - pass.feedback.statistics.push_back(Statistic{ - location, - operation, - keys, - count, - }); - - operation = PrimitiveOperation::CLEAR_ADDITION; - pass.feedback.statistics.push_back(Statistic{ - location, - operation, - keys, - count, - }); - - return std::nullopt; -} + pass.feedback.statistics.push_back(concretelang::Statistic{ + location, + operation, + keys, + count, + }); -// ################# -// TFHE.wop_pbs_glwe -// ################# + return std::nullopt; + } -static std::optional on_enter(TFHE::WopPBSGLWEOp &op, - ExtractTFHEStatisticsPass &pass) { - auto bsk = op.getBsk(); - auto ksk = op.getKsk(); - auto pksk = op.getPksk(); + // ################# + // TFHE.sub_int_glwe + // ################# + + static std::optional on_enter(TFHE::SubGLWEIntOp &op, + ExtractTFHEStatisticsPass &pass) { + auto resultingKey = op.getType().getKey().getNormalized(); + + auto location = locationString(op.getLoc()); + auto keys = std::vector>(); + auto count = pass.iterations; + + std::pair key = + std::make_pair(KeyType::SECRET, (size_t)resultingKey->index); + keys.push_back(key); + + // clear - encrypted = clear + neg(encrypted) + + auto operation = PrimitiveOperation::ENCRYPTED_NEGATION; + pass.feedback.statistics.push_back(concretelang::Statistic{ + location, + operation, + keys, + count, + }); + + operation = PrimitiveOperation::CLEAR_ADDITION; + pass.feedback.statistics.push_back(concretelang::Statistic{ + location, + operation, + keys, + count, + }); + + return std::nullopt; + } - auto location = locationOf(op); - auto operation = PrimitiveOperation::WOP_PBS; - auto keys = std::vector>(); - auto count = pass.iterations; + // ################# + // TFHE.wop_pbs_glwe + // ################# - std::pair key = - std::make_pair(KeyType::BOOTSTRAP, (size_t)bsk.getIndex()); - keys.push_back(key); + static std::optional on_enter(TFHE::WopPBSGLWEOp &op, + ExtractTFHEStatisticsPass &pass) { + auto bsk = op.getBsk(); + auto ksk = op.getKsk(); + auto pksk = op.getPksk(); - key = std::make_pair(KeyType::KEY_SWITCH, (size_t)ksk.getIndex()); - keys.push_back(key); + auto location = locationString(op.getLoc()); + auto operation = PrimitiveOperation::WOP_PBS; + auto keys = std::vector>(); + auto count = pass.iterations; - key = std::make_pair(KeyType::PACKING_KEY_SWITCH, (size_t)pksk.getIndex()); - keys.push_back(key); + std::pair key = + std::make_pair(KeyType::BOOTSTRAP, (size_t)bsk.getIndex()); + keys.push_back(key); - pass.feedback.statistics.push_back(Statistic{ - location, - operation, - keys, - count, - }); + key = std::make_pair(KeyType::KEY_SWITCH, (size_t)ksk.getIndex()); + keys.push_back(key); - return std::nullopt; -} + key = std::make_pair(KeyType::PACKING_KEY_SWITCH, (size_t)pksk.getIndex()); + keys.push_back(key); -// ######## -// Dispatch -// ######## + pass.feedback.statistics.push_back(concretelang::Statistic{ + location, + operation, + keys, + count, + }); -#define DISPATCH_ENTER(type) \ - if (auto typedOp = llvm::dyn_cast(op)) { \ - std::optional error = on_enter(typedOp, *this); \ - if (error.has_value()) { \ - return error; \ - } \ + return std::nullopt; } -#define DISPATCH_EXIT(type) \ - if (auto typedOp = llvm::dyn_cast(op)) { \ - std::optional error = on_exit(typedOp, *this); \ - if (error.has_value()) { \ - return error; \ - } \ - } + size_t iterations = 1; +}; -std::optional -ExtractTFHEStatisticsPass::enter(mlir::Operation *op) { - DISPATCH_ENTER(scf::ForOp) - DISPATCH_ENTER(TFHE::AddGLWEOp) - DISPATCH_ENTER(TFHE::AddGLWEIntOp) - DISPATCH_ENTER(TFHE::BootstrapGLWEOp) - DISPATCH_ENTER(TFHE::KeySwitchGLWEOp) - DISPATCH_ENTER(TFHE::MulGLWEIntOp) - DISPATCH_ENTER(TFHE::NegGLWEOp) - DISPATCH_ENTER(TFHE::SubGLWEIntOp) - DISPATCH_ENTER(TFHE::WopPBSGLWEOp) - return std::nullopt; -} +} // namespace TFHE -std::optional -ExtractTFHEStatisticsPass::exit(mlir::Operation *op) { - DISPATCH_EXIT(scf::ForOp) - return std::nullopt; +std::unique_ptr> +createStatisticExtractionPass(CompilationFeedback &feedback) { + return std::make_unique(feedback); } + +} // namespace concretelang +} // namespace mlir diff --git a/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp b/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp index 2844d7ba23..e83425cbab 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp @@ -332,7 +332,7 @@ extractTFHEStatistics(mlir::MLIRContext &context, mlir::ModuleOp &module, pipelinePrinting("TFHEStatistics", pm, context); addPotentiallyNestedPass( - pm, std::make_unique(feedback), + pm, mlir::concretelang::createStatisticExtractionPass(feedback), enablePass); return pm.run(module.getOperation()); @@ -358,7 +358,7 @@ computeMemoryUsage(mlir::MLIRContext &context, mlir::ModuleOp &module, pipelinePrinting("Computing Memory Usage", pm, context); addPotentiallyNestedPass( - pm, std::make_unique(feedback), enablePass); + pm, mlir::concretelang::createMemoryUsagePass(feedback), enablePass); return pm.run(module.getOperation()); } From 8b74e2546b939a799bf68adb3c8989191168ec1a Mon Sep 17 00:00:00 2001 From: youben11 Date: Wed, 30 Aug 2023 09:55:11 +0100 Subject: [PATCH 2/2] refactor(frontend): change location format --- .../concrete/compiler/compilation_feedback.py | 15 +++++++++++---- .../concrete-python/concrete/fhe/mlir/context.py | 12 +++++++++--- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_feedback.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_feedback.py index f3b99b90bd..22990bf3c7 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_feedback.py +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_feedback.py @@ -3,6 +3,7 @@ """Compilation feedback.""" +import re from typing import Dict, Set # pylint: disable=no-name-in-module,import-error,too-many-instance-attributes @@ -19,6 +20,10 @@ from .wrapper import WrapperCpp +# matches (@tag, separator( | ), filename) +REGEX_LOCATION = r"loc\(\"(@[\w\.]+)?( \| )?(.+)\"" + + class CompilationFeedback(WrapperCpp): """CompilationFeedback is a set of hint computed by the compiler engine.""" @@ -130,8 +135,9 @@ def count_per_tag(self, *, operations: Set[PrimitiveOperation]) -> Dict[str, int if statistic.operation not in operations: continue - file_and_maybe_tag = statistic.location.split("@") - tag = "" if len(file_and_maybe_tag) == 1 else file_and_maybe_tag[1].strip() + tag, _, _ = re.match(REGEX_LOCATION, statistic.location).groups() + # remove the @ + tag = tag[1:] if tag else "" tag_components = tag.split(".") for i in range(1, len(tag_components) + 1): @@ -176,8 +182,9 @@ def count_per_tag_per_parameter( if statistic.operation not in operations: continue - file_and_maybe_tag = statistic.location.split("@") - tag = "" if len(file_and_maybe_tag) == 1 else file_and_maybe_tag[1].strip() + tag, _, _ = re.match(REGEX_LOCATION, statistic.location).groups() + # remove the @ + tag = tag[1:] if tag else "" tag_components = tag.split(".") for i in range(1, len(tag_components) + 1): diff --git a/frontends/concrete-python/concrete/fhe/mlir/context.py b/frontends/concrete-python/concrete/fhe/mlir/context.py index 7b8a0d7247..54a3730404 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/context.py +++ b/frontends/concrete-python/concrete/fhe/mlir/context.py @@ -116,9 +116,15 @@ def location(self) -> MlirLocation: """ Create an MLIR location from the node that is being converted. """ - - tag = "" if self.converting.tag == "" else f" @ {self.converting.tag}" - return MlirLocation.name(f"{self.converting.location}{tag}", context=self.context) + filename, lineno = self.converting.location.rsplit(":", maxsplit=1) + + tag = "" if self.converting.tag == "" else f"@{self.converting.tag} | " + return MlirLocation.file( + f"{tag}{self.converting.location}", + line=int(lineno), + col=0, + context=self.context, + ) def attribute(self, resulting_type: ConversionType, value: Any) -> MlirAttribute: """