Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preserve explicit optimizer partition frontiers for TFHE circuit parametrization #702

Merged
merged 6 commits into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -916,6 +916,63 @@ class TypeInferenceAnalysisBase : public AnalysisT {
}
}

// Prints an indentation composed of `indent` times `" "`.
void printIndent(int indent) {
for (int i = 0; i < indent; i++)
llvm::dbgs() << " ";
}

// Dumps the state of type inference for the operation `op` with an
// indentation level of `indent` as the name of the operation,
// followed by the types inferred for each operand, followed by
// `->`, followed by a dump of the state for any operation nested in
// any region of `op`.
void dumpStateForOp(mlir::Operation *op, int indent) {
const LocalInferenceState state = getCurrentInferredTypes(op);

printIndent(indent);
llvm::dbgs() << op->getName() << " {";

llvm::interleaveComma(
op->getAttrs(), llvm::dbgs(), [&](const mlir::NamedAttribute &attr) {
llvm::dbgs() << attr.getName() << " = " << attr.getValue();
});

llvm::dbgs() << "} : (";

llvm::interleaveComma(op->getOperands(), llvm::dbgs(), [&](mlir::Value v) {
llvm::dbgs() << state.find(v);
});

llvm::dbgs() << ") -> (";

llvm::interleaveComma(op->getResults(), llvm::dbgs(), [&](mlir::Value v) {
llvm::dbgs() << state.find(v);
});

llvm::dbgs() << ")\n";

for (mlir::Region &r : op->getRegions())
for (mlir::Block &b : r.getBlocks())
for (mlir::Operation &childOp : b.getOperations())
dumpStateForOp(&childOp, indent + 1);
}

// Dumps the entire state of type inference for the function
// containing the operation `op`. For each operation, this prints
// the name of the operation, followed by the types inferred for
// each operand, followed by `->`, followed by the types inferred
// for the results.
void dumpAllState(mlir::Operation *op) {
mlir::Operation *funcOp = op;
while (funcOp && !llvm::isa<mlir::func::FuncOp>(funcOp))
funcOp = funcOp->getParentOp();

assert(funcOp);

dumpStateForOp(funcOp, 0);
}

TypeResolver &resolver;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ add_subdirectory(RT)
add_subdirectory(SDFG)
add_subdirectory(Tracing)
add_subdirectory(TypeInference)
add_subdirectory(Optimizer)
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ add_subdirectory(DynamicTLU)
add_subdirectory(BigInt)
add_subdirectory(Boolean)
add_subdirectory(Max)
add_subdirectory(Optimizer)
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
set(LLVM_TARGET_DEFINITIONS Optimizer.td)
mlir_tablegen(Optimizer.h.inc -gen-pass-decls -name Transforms)
add_public_tablegen_target(ConcretelangFHEOptimizerPassIncGen)
add_dependencies(mlir-headers ConcretelangFHEOptimizerPassIncGen)
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.

#ifndef CONCRETELANG_DIALECT_OPTIMIZER_TRANSFORMS_PASSES_H
#define CONCRETELANG_DIALECT_OPTIMIZER_TRANSFORMS_PASSES_H

#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Pass/Pass.h>

#include <concretelang/Dialect/FHE/IR/FHEDialect.h>
#include <concretelang/Dialect/Optimizer/IR/OptimizerDialect.h>
#include <concretelang/Dialect/Optimizer/IR/OptimizerOps.h>
#include <concretelang/Support/V0Parameters.h>

#define GEN_PASS_CLASSES
#include <concretelang/Dialect/FHE/Transforms/Optimizer/Optimizer.h.inc>

namespace mlir {
namespace concretelang {

std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
createOptimizerPartitionFrontierMaterializationPass(
const optimizer::CircuitSolution &solverSolution);

} // namespace concretelang
} // namespace mlir

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#ifndef CONCRETELANG_DIALECT_OPTIMIZER_TRANSFORMS_PASSES
#define CONCRETELANG_DIALECT_OPTIMIZER_TRANSFORMS_PASSES

include "mlir/Pass/PassBase.td"

def OptimizerPartitionFrontierMaterializationPass
: Pass<"optimizer-partition-frontier-materialization",
"::mlir::func::FuncOp"> {
let summary =
"Inserts Optimizer.partition_frontier operations between FHE operations "
"that were explicitly marked by the optimizer as belonging to separate "
"partitions via an extra conversion key in the optimizer solution.";

let constructor = "mlir::concretelang::"
"createOptimizerPartitionFrontierMaterializationPass()";
let options = [];
let dependentDialects = [
"mlir::concretelang::FHE::FHEDialect",
"mlir::concretelang::Optimizer::OptimizerDialect"
];
}

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
add_subdirectory(IR)
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
set(LLVM_TARGET_DEFINITIONS OptimizerOps.td)
mlir_tablegen(OptimizerOps.h.inc -gen-op-decls)
mlir_tablegen(OptimizerOps.cpp.inc -gen-op-defs)
mlir_tablegen(OptimizerOpsDialect.h.inc -gen-dialect-decls -dialect=Optimizer)
mlir_tablegen(OptimizerOpsDialect.cpp.inc -gen-dialect-defs -dialect=Optimizer)
add_public_tablegen_target(MLIROptimizerOpsIncGen)
add_dependencies(mlir-headers MLIROptimizerOpsIncGen)

add_concretelang_doc(OptimizerOps OptimizerDialect concretelang/ -gen-dialect-doc -dialect=Optimizer)
add_concretelang_doc(OptimizerOps OptimizerOps concretelang/ -gen-op-doc)
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.

#ifndef CONCRETELANG_DIALECT_OPTIMIZER_IR_OPTIMIZERDIALECT_H
#define CONCRETELANG_DIALECT_OPTIMIZER_IR_OPTIMIZERDIALECT_H

#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"

#include "concretelang/Dialect/Optimizer/IR/OptimizerOpsDialect.h.inc"

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
//===- OptimizerDialect.td - Optimizer dialect ----------------*- tablegen -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef CONCRETELANG_DIALECT_OPTIMIZER_IR_OPTIMIZER_DIALECT
#define CONCRETELANG_DIALECT_OPTIMIZER_IR_OPTIMIZER_DIALECT

include "mlir/IR/OpBase.td"

def Optimizer_Dialect : Dialect {
let name = "Optimizer";
let summary = "Auxiliary operations for the interaction with the optimizer";
let cppNamespace = "::mlir::concretelang::Optimizer";
}

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.

#ifndef CONCRETELANG_DIALECT_OPTIMIZER_IR_OPTIMIZEROPS_H
#define CONCRETELANG_DIALECT_OPTIMIZER_IR_OPTIMIZEROPS_H

#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>

#define GET_OP_CLASSES
#include "concretelang/Dialect/Optimizer/IR/OptimizerOps.h.inc"

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
//===- OptimizerOps.td - Optimizer dialect ops ----------------*- tablegen -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef CONCRETELANG_DIALECT_OPTIMIZER_IR_OPTIMIZER_OPS
#define CONCRETELANG_DIALECT_OPTIMIZER_IR_OPTIMIZER_OPS

include "mlir/Interfaces/SideEffectInterfaces.td"
include "concretelang/Dialect/Optimizer/IR/OptimizerDialect.td"

class Optimizer_Op<string mnemonic, list<Trait> traits = []> :
Op<Optimizer_Dialect, mnemonic, traits>;

def Optimizer_PartitionFrontierOp : Optimizer_Op<"partition_frontier", [Pure]> {
let summary = "Models an explicit edge between two partitions";

let description = [{
Models an explicit edge between two partitions in the solution
determined by the optimizer requiring a key change between the
encrypted values of the operand and the encrypted values of
the result.
}];

let arguments = (ins
AnyType:$input,
I64Attr:$inputKeyID,
I32Attr:$outputKeyID
);

let results = (outs AnyType);
}

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,8 @@ class CompilerEngine {
llvm::Expected<std::optional<optimizer::Description>>
getConcreteOptimizerDescription(CompilationResult &res);
llvm::Error determineFHEParameters(CompilationResult &res);
mlir::LogicalResult
materializeOptimizerPartitionFrontiers(CompilationResult &res);
};

} // namespace concretelang
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ namespace pipeline {
mlir::LogicalResult autopar(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass);

mlir::LogicalResult materializeOptimizerPartitionFrontiers(
mlir::MLIRContext &context, mlir::ModuleOp &module,
std::optional<V0FHEContext> &fheContext,
std::function<bool(mlir::Pass *)> enablePass);

llvm::Expected<std::map<std::string, std::optional<optimizer::Description>>>
getFHEContextFromFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
optimizer::Config config,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ add_mlir_dialect_library(
DEPENDS
FHEDialect
FHELinalgDialect
OptimizerDialect
mlir-headers
LINK_LIBS
PUBLIC
MLIRIR
FHEDialect
FHELinalgDialect)
FHELinalgDialect
OptimizerDialect)

target_link_libraries(FHEDialect PUBLIC MLIRIR)
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "concretelang/Dialect/FHE/IR/FHEOps.h"
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgDialect.h"
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h"
#include "concretelang/Dialect/Optimizer/IR/OptimizerOps.h"
#include "concretelang/Support/Constants.h"
#include "concretelang/Support/logging.h"

Expand Down Expand Up @@ -1936,6 +1937,77 @@ struct FHELinalgUnaryOpToLinalgGeneric
};
};

// Replaces a `optimizer.partition_frontier` operation with a tensor
// operand and a tensor result with a `linalg.generic` operation
// applying a `optimizer.partition_frontier` operation with scalar
// operands.
struct TensorPartitionFrontierOpToLinalgGeneric
: public mlir::OpRewritePattern<
mlir::concretelang::Optimizer::PartitionFrontierOp> {
TensorPartitionFrontierOpToLinalgGeneric(
::mlir::MLIRContext *context,
mlir::PatternBenefit benefit =
mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
: ::mlir::OpRewritePattern<
mlir::concretelang::Optimizer::PartitionFrontierOp>(context,
benefit) {}

::mlir::LogicalResult
matchAndRewrite(mlir::concretelang::Optimizer::PartitionFrontierOp pfOp,
::mlir::PatternRewriter &rewriter) const override {
mlir::RankedTensorType resultTy =
pfOp.getResult().getType().cast<mlir::RankedTensorType>();
mlir::RankedTensorType tensorTy =
pfOp.getInput().getType().cast<mlir::RankedTensorType>();

mlir::Value init = rewriter.create<mlir::tensor::EmptyOp>(
pfOp.getLoc(), resultTy, mlir::ValueRange{});

// Create affine maps and iterator types for an embarassingly
// parallel op
llvm::SmallVector<mlir::AffineMap, 2> maps{
mlir::AffineMap::getMultiDimIdentityMap(tensorTy.getShape().size(),
this->getContext()),
mlir::AffineMap::getMultiDimIdentityMap(resultTy.getShape().size(),
this->getContext()),
};

llvm::SmallVector<mlir::utils::IteratorType> iteratorTypes(
resultTy.getShape().size(), mlir::utils::IteratorType::parallel);

// Create the body of the `linalg.generic` op applying a
// `tensor.partition_frontier` op on the scalar arguments
auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder,
mlir::Location nestedLoc,
mlir::ValueRange blockArgs) {
mlir::concretelang::Optimizer::PartitionFrontierOp scalarOp =
nestedBuilder
.create<mlir::concretelang::Optimizer::PartitionFrontierOp>(
pfOp.getLoc(), resultTy.getElementType(), blockArgs[0],
pfOp->getAttrs());

nestedBuilder.create<mlir::linalg::YieldOp>(pfOp.getLoc(),
scalarOp.getResult());
};

// Create the `linalg.generic` op
llvm::SmallVector<mlir::Type, 1> resTypes{init.getType()};
llvm::SmallVector<mlir::Value, 1> ins{pfOp.getInput()};
llvm::SmallVector<mlir::Value, 1> outs{init};
llvm::StringRef doc{""};
llvm::StringRef call{""};

mlir::linalg::GenericOp genericOp =
rewriter.create<mlir::linalg::GenericOp>(pfOp.getLoc(), resTypes, ins,
outs, maps, iteratorTypes, doc,
call, bodyBuilder);

rewriter.replaceOp(pfOp, {genericOp.getResult(0)});

return ::mlir::success();
};
};

namespace {
struct FHETensorOpsToLinalg
: public FHETensorOpsToLinalgBase<FHETensorOpsToLinalg> {
Expand All @@ -1956,6 +2028,13 @@ void FHETensorOpsToLinalg::runOnOperation() {
target.addIllegalOp<mlir::concretelang::FHELinalg::Dot>();
target.addIllegalDialect<mlir::concretelang::FHELinalg::FHELinalgDialect>();

target.addDynamicallyLegalOp<
mlir::concretelang::Optimizer::PartitionFrontierOp>(
[&](mlir::concretelang::Optimizer::PartitionFrontierOp op) {
return !op.getInput().getType().isa<mlir::RankedTensorType>() &&
!op.getResult().getType().isa<mlir::RankedTensorType>();
});

mlir::RewritePatternSet patterns(&getContext());

patterns.insert<DotToLinalgGeneric<mlir::concretelang::FHELinalg::Dot,
Expand Down Expand Up @@ -2109,6 +2188,7 @@ void FHETensorOpsToLinalg::runOnOperation() {
patterns.insert<FHELinalgMaxpool2dToLinalgMaxpool2d>(&getContext());
patterns.insert<TransposeToLinalgGeneric>(&getContext());
patterns.insert<FromElementToTensorFromElements>(&getContext());
patterns.insert<TensorPartitionFrontierOpToLinalgGeneric>(&getContext());

if (mlir::applyPartialConversion(function, target, std::move(patterns))
.failed())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ add_mlir_dialect_library(
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/FHE
DEPENDS
FHEDialect
OptimizerDialect
mlir-headers
LINK_LIBS
PUBLIC
Expand Down
Loading
Loading