Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(compiler): clean statistic passes #551

Merged
merged 2 commits into from
Sep 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.

#ifndef CONCRETELANG_ANALYSIS_UTILS_H
#define CONCRETELANG_ANALYSIS_UTILS_H

#include <boost/outcome.h>
#include <concretelang/Common/Error.h>
#include <mlir/Dialect/SCF/IR/SCF.h>
#include <mlir/IR/Location.h>

namespace mlir {
namespace concretelang {

/// Get the string representation of a location
std::string locationString(mlir::Location loc);

/// Compute the number of iterations based on loop info
int64_t calculateNumberOfIterations(int64_t start, int64_t stop, int64_t step);

/// Compute the number of iterations of an scf for loop
outcome::checked<int64_t, ::concretelang::error::StringError>
calculateNumberOfIterations(scf::ForOp &op);

} // namespace concretelang
} // namespace mlir

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#ifndef CONCRETELANG_DIALECT_CONCRETE_ANALYSIS
#define CONCRETELANG_DIALECT_CONCRETE_ANALYSIS

include "mlir/Pass/PassBase.td"

def MemoryUsage : Pass<"MemoryUsage", "::mlir::ModuleOp"> {
let summary = "Compute memory usage";
let description = [{
Computes memory usage per location, and provides those numbers throught the CompilationFeedback.
}];
}

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
set(LLVM_TARGET_DEFINITIONS Analysis.td)
mlir_tablegen(Analysis.h.inc -gen-pass-decls -name Analysis)
add_public_tablegen_target(ConcretelangConcreteAnalysisPassIncGen)
add_dependencies(mlir-headers ConcretelangConcreteAnalysisPassIncGen)
Original file line number Diff line number Diff line change
Expand Up @@ -7,59 +7,16 @@
#define CONCRETELANG_DIALECT_CONCRETE_MEMORY_USAGE_H

#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/Operation.h>
#include <mlir/Pass/Pass.h>

#include <concretelang/Support/CompilationFeedback.h>

namespace mlir {
namespace concretelang {
namespace Concrete {

struct MemoryUsagePass
: public PassWrapper<MemoryUsagePass, OperationPass<ModuleOp>> {
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
createMemoryUsagePass(CompilationFeedback &feedback);

CompilationFeedback &feedback;

MemoryUsagePass(CompilationFeedback &feedback) : feedback{feedback} {};

void runOnOperation() override {
WalkResult walk =
getOperation()->walk([&](Operation *op, const WalkStage &stage) {
if (stage.isBeforeAllRegions()) {
std::optional<StringError> error = this->enter(op);
if (error.has_value()) {
op->emitError() << error->mesg;
return WalkResult::interrupt();
}
}

if (stage.isAfterAllRegions()) {
std::optional<StringError> error = this->exit(op);
if (error.has_value()) {
op->emitError() << error->mesg;
return WalkResult::interrupt();
}
}

return WalkResult::advance();
});

if (walk.wasInterrupted()) {
signalPassFailure();
}
}

std::optional<StringError> enter(Operation *op);

std::optional<StringError> exit(Operation *op);

std::map<std::string, std::vector<mlir::Value>> visitedValuesPerLoc;

size_t iterations = 1;
};

} // namespace Concrete
} // namespace concretelang
} // namespace mlir

Expand Down
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
add_subdirectory(Analysis)
add_subdirectory(IR)
add_subdirectory(Transforms)
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#ifndef CONCRETELANG_DIALECT_TFHE_ANALYSIS
#define CONCRETELANG_DIALECT_TFHE_ANALYSIS

include "mlir/Pass/PassBase.td"

def ExtractStatistics : Pass<"ExtractStatistics", "::mlir::ModuleOp"> {
let summary = "Extracts statistics";
let description = [{
Extracts different statistics (e.g. number of certain crypto operations),
and provides those numbers throught the CompilationFeedback.
}];
}

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
set(LLVM_TARGET_DEFINITIONS Analysis.td)
mlir_tablegen(Analysis.h.inc -gen-pass-decls -name Analysis)
add_public_tablegen_target(ConcretelangTFHEAnalysisPassIncGen)
add_dependencies(mlir-headers ConcretelangTFHEAnalysisPassIncGen)
Original file line number Diff line number Diff line change
Expand Up @@ -7,58 +7,15 @@
#define CONCRETELANG_DIALECT_TFHE_ANALYSIS_EXTRACT_STATISTICS_H

#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/Operation.h>
#include <mlir/Pass/Pass.h>

#include <concretelang/Support/CompilationFeedback.h>

namespace mlir {
namespace concretelang {
namespace TFHE {

struct ExtractTFHEStatisticsPass
: public PassWrapper<ExtractTFHEStatisticsPass, OperationPass<ModuleOp>> {

CompilationFeedback &feedback;

ExtractTFHEStatisticsPass(CompilationFeedback &feedback)
: feedback{feedback} {};

void runOnOperation() override {
WalkResult walk =
getOperation()->walk([&](Operation *op, const WalkStage &stage) {
if (stage.isBeforeAllRegions()) {
std::optional<StringError> error = this->enter(op);
if (error.has_value()) {
op->emitError() << error->mesg;
return WalkResult::interrupt();
}
}

if (stage.isAfterAllRegions()) {
std::optional<StringError> error = this->exit(op);
if (error.has_value()) {
op->emitError() << error->mesg;
return WalkResult::interrupt();
}
}

return WalkResult::advance();
});

if (walk.wasInterrupted()) {
signalPassFailure();
}
}

std::optional<StringError> enter(Operation *op);

std::optional<StringError> exit(Operation *op);

size_t iterations = 1;
};

} // namespace TFHE
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
createStatisticExtractionPass(CompilationFeedback &feedback);
} // namespace concretelang
} // namespace mlir

Expand Down
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
add_subdirectory(Analysis)
add_subdirectory(IR)
add_subdirectory(Transforms)
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
add_mlir_library(
AnalysisUtils
Utils.cpp
DEPENDS
mlir-headers
LINK_LIBS
PUBLIC
MLIRIR)
67 changes: 67 additions & 0 deletions compilers/concrete-compiler/compiler/lib/Analysis/Utils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#include <concretelang/Analysis/Utils.h>
#include <mlir/Dialect/Arith/IR/Arith.h>

using ::concretelang::error::StringError;

namespace mlir {
namespace concretelang {
std::string locationString(mlir::Location loc) {
auto location = std::string();
auto locationStream = llvm::raw_string_ostream(location);
loc->print(locationStream);
return location;
}

int64_t calculateNumberOfIterations(int64_t start, int64_t stop, int64_t step) {
int64_t high;
int64_t low;

if (step > 0) {
low = start;
high = stop;
} else {
low = stop;
high = start;
step = -step;
}

if (low >= high) {
return 0;
}

return ((high - low - 1) / step) + 1;
}

outcome::checked<int64_t, StringError>
calculateNumberOfIterations(scf::ForOp &op) {
mlir::Value startValue = op.getLowerBound();
mlir::Value stopValue = op.getUpperBound();
mlir::Value stepValue = op.getStep();

auto startOp =
llvm::dyn_cast_or_null<arith::ConstantOp>(startValue.getDefiningOp());
auto stopOp =
llvm::dyn_cast_or_null<arith::ConstantOp>(stopValue.getDefiningOp());
auto stepOp =
llvm::dyn_cast_or_null<arith::ConstantOp>(stepValue.getDefiningOp());

if (!startOp || !stopOp || !stepOp) {
return StringError("only static loops can be analyzed");
}

auto startAttr = startOp.getValue().cast<mlir::IntegerAttr>();
auto stopAttr = stopOp.getValue().cast<mlir::IntegerAttr>();
auto stepAttr = stepOp.getValue().cast<mlir::IntegerAttr>();

if (!startOp || !stopOp || !stepOp) {
return StringError("only integer loops can be analyzed");
}

int64_t start = startAttr.getInt();
int64_t stop = stopAttr.getInt();
int64_t step = stepAttr.getInt();

return calculateNumberOfIterations(start, stop, step);
}
} // namespace concretelang
} // namespace mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

"""Compilation feedback."""

import re
from typing import Dict, Set

# pylint: disable=no-name-in-module,import-error,too-many-instance-attributes
Expand All @@ -19,6 +20,10 @@
from .wrapper import WrapperCpp


# matches (@tag, separator( | ), filename)
REGEX_LOCATION = r"loc\(\"(@[\w\.]+)?( \| )?(.+)\""


class CompilationFeedback(WrapperCpp):
"""CompilationFeedback is a set of hint computed by the compiler engine."""

Expand Down Expand Up @@ -130,8 +135,9 @@ def count_per_tag(self, *, operations: Set[PrimitiveOperation]) -> Dict[str, int
if statistic.operation not in operations:
continue

file_and_maybe_tag = statistic.location.split("@")
tag = "" if len(file_and_maybe_tag) == 1 else file_and_maybe_tag[1].strip()
tag, _, _ = re.match(REGEX_LOCATION, statistic.location).groups()
# remove the @
tag = tag[1:] if tag else ""

tag_components = tag.split(".")
for i in range(1, len(tag_components) + 1):
Expand Down Expand Up @@ -176,8 +182,9 @@ def count_per_tag_per_parameter(
if statistic.operation not in operations:
continue

file_and_maybe_tag = statistic.location.split("@")
tag = "" if len(file_and_maybe_tag) == 1 else file_and_maybe_tag[1].strip()
tag, _, _ = re.match(REGEX_LOCATION, statistic.location).groups()
# remove the @
tag = tag[1:] if tag else ""

tag_components = tag.split(".")
for i in range(1, len(tag_components) + 1):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ const LLVM_TARGET_SPECIFIC_STATIC_LIBS: &[&str] = &[
const LLVM_TARGET_SPECIFIC_STATIC_LIBS: &[&str] = &["LLVMX86CodeGen", "LLVMX86Desc", "LLVMX86Info"];

const CONCRETE_COMPILER_STATIC_LIBS: &[&str] = &[
"AnalysisUtils",
"RTDialect",
"RTDialectTransforms",
"ConcretelangSupport",
Expand Down
1 change: 1 addition & 0 deletions compilers/concrete-compiler/compiler/lib/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
add_subdirectory(Analysis)
add_subdirectory(Dialect)
add_subdirectory(Conversion)
add_subdirectory(Transforms)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ add_mlir_library(
LINK_LIBS
PUBLIC
MLIRIR
ConcreteDialect)
ConcreteDialect
AnalysisUtils)
Loading