Skip to content

Commit

Permalink
feat(compiler/frontend): add flag to enable/disable overflow detection
Browse files Browse the repository at this point in the history
in simulation
  • Loading branch information
youben11 committed May 23, 2024
1 parent 3740d38 commit 32cf4b2
Show file tree
Hide file tree
Showing 9 changed files with 52 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ struct CompilationOptions {
/// Simulate options
bool simulate;

/// Enable overflow detection during simulation
bool enableOverflowDetectionInSimulation;

/// Parallelization options
bool autoParallelize;
bool loopParallelize;
Expand Down Expand Up @@ -110,7 +113,7 @@ struct CompilationOptions {
CompilationOptions()
: v0FHEConstraints(std::nullopt), verifyDiagnostics(false),
/// Simulate options
simulate(false),
simulate(false), enableOverflowDetectionInSimulation(false),
// Parallelization options
autoParallelize(false), loopParallelize(true),
dataflowParallelize(false),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ transformTFHEOperations(mlir::MLIRContext &context, mlir::ModuleOp &module,
mlir::LogicalResult simulateTFHE(mlir::MLIRContext &context,
mlir::ModuleOp &module,
std::optional<V0FHEContext> &fheContext,
bool enableOverflowDetection,
std::function<bool(mlir::Pass *)> enablePass);

mlir::LogicalResult extractSDFGOps(mlir::MLIRContext &context,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -803,6 +803,11 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
.def("set_print_tlu_fusing",
[](CompilationOptions &options, bool printTluFusing) {
options.printTluFusing = printTluFusing;
})
.def("set_enable_overflow_detection_in_simulation",
[](CompilationOptions &options, bool enableOverflowDetection) {
options.enableOverflowDetectionInSimulation =
enableOverflowDetection;
});

pybind11::enum_<mlir::concretelang::PrimitiveOperation>(m,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -481,3 +481,20 @@ def set_print_tlu_fusing(self, print_tlu_fusing: bool):
if not isinstance(print_tlu_fusing, bool):
raise TypeError("need to pass a boolean value")
self.cpp().set_print_tlu_fusing(print_tlu_fusing)

def set_enable_overflow_detection_in_simulation(
self, enable_overflow_detection: bool
):
"""Enable or disable overflow detection during simulation.
Args:
enable_overflow_detection (bool): flag to enable or disable overflow detection
Raises:
TypeError: if the value to set is not bool
"""
if not isinstance(enable_overflow_detection, bool):
raise TypeError("need to pass a boolean value")
self.cpp().set_enable_overflow_detection_in_simulation(
enable_overflow_detection
)
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,8 @@ CompilerEngine::compile(mlir::ModuleOp moduleOp, Target target,

if (options.simulate) {
if (mlir::concretelang::pipeline::simulateTFHE(
mlirContext, module, res.fheContext, this->enablePass)
mlirContext, module, res.fheContext,
options.enableOverflowDetectionInSimulation, this->enablePass)
.failed()) {
return StreamStringError("Simulating TFHE failed");
}
Expand Down
4 changes: 2 additions & 2 deletions compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -439,13 +439,13 @@ transformTFHEOperations(mlir::MLIRContext &context, mlir::ModuleOp &module,
mlir::LogicalResult simulateTFHE(mlir::MLIRContext &context,
mlir::ModuleOp &module,
std::optional<V0FHEContext> &fheContext,
bool enableOverflowDetection,
std::function<bool(mlir::Pass *)> enablePass) {
mlir::PassManager pm(&context);

// we want to disable overflow detection if CRT is used (overflow would be
// expected)
bool enableOverflowDetection = true;
if (fheContext) {
if (fheContext && enableOverflowDetection) {
auto solution = fheContext.value().solution;
auto optCrt = getCrtDecompositionFromSolution(solution);
if (optCrt) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def compile_run_assert(
):
# compile with simulation
options.simulation(True)
options.set_enable_overflow_detection_in_simulation(True)
compilation_result = engine.compile(mlir_input, options)
result = run_simulated(engine, args_and_shape, compilation_result)
assert_result(result, expected_result)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -990,6 +990,7 @@ class Configuration:
enable_tlu_fusing: bool
print_tlu_fusing: bool
optimize_tlu_based_on_original_bit_width: Union[bool, int]
detect_overflow_in_simulation: bool

def __init__(
self,
Expand Down Expand Up @@ -1055,6 +1056,7 @@ def __init__(
enable_tlu_fusing: bool = True,
print_tlu_fusing: bool = False,
optimize_tlu_based_on_original_bit_width: Union[bool, int] = 8,
detect_overflow_in_simulation=False,
):
self.verbose = verbose
self.compiler_debug_mode = compiler_debug_mode
Expand Down Expand Up @@ -1155,6 +1157,8 @@ def __init__(

self.optimize_tlu_based_on_original_bit_width = optimize_tlu_based_on_original_bit_width

self.detect_overflow_in_simulation = detect_overflow_in_simulation

self._validate()

class Keep:
Expand Down Expand Up @@ -1198,14 +1202,17 @@ def fork(
compiler_debug_mode: Union[Keep, bool] = KEEP,
compiler_verbose_mode: Union[Keep, bool] = KEEP,
comparison_strategy_preference: Union[
Keep, Optional[Union[ComparisonStrategy, str, List[Union[ComparisonStrategy, str]]]]
Keep,
Optional[Union[ComparisonStrategy, str, List[Union[ComparisonStrategy, str]]]],
] = KEEP,
bitwise_strategy_preference: Union[
Keep, Optional[Union[BitwiseStrategy, str, List[Union[BitwiseStrategy, str]]]]
Keep,
Optional[Union[BitwiseStrategy, str, List[Union[BitwiseStrategy, str]]]],
] = KEEP,
shifts_with_promotion: Union[Keep, bool] = KEEP,
multivariate_strategy_preference: Union[
Keep, Optional[Union[MultivariateStrategy, str, List[Union[MultivariateStrategy, str]]]]
Keep,
Optional[Union[MultivariateStrategy, str, List[Union[MultivariateStrategy, str]]]],
] = KEEP,
min_max_strategy_preference: Union[
Keep, Optional[Union[MinMaxStrategy, str, List[Union[MinMaxStrategy, str]]]]
Expand All @@ -1223,6 +1230,7 @@ def fork(
enable_tlu_fusing: Union[Keep, bool] = KEEP,
print_tlu_fusing: Union[Keep, bool] = KEEP,
optimize_tlu_based_on_original_bit_width: Union[Keep, bool, int] = KEEP,
detect_overflow_in_simulation: bool = KEEP,
) -> "Configuration":
"""
Get a new configuration from another one specified changes.
Expand Down
10 changes: 9 additions & 1 deletion frontends/concrete-python/concrete/fhe/compilation/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ def create(
options.set_compress_evaluation_keys(configuration.compress_evaluation_keys)
options.set_compress_input_ciphertexts(configuration.compress_input_ciphertexts)
options.set_composable(configuration.composable)
options.set_enable_overflow_detection_in_simulation(
configuration.detect_overflow_in_simulation
)

if configuration.auto_parallelize or configuration.dataflow_parallelize:
# pylint: disable=c-extension-no-member,no-member
Expand Down Expand Up @@ -319,7 +322,12 @@ def load(path: Union[str, Path]) -> "Server":
server_program = ServerProgram.load(support, is_simulated)

return Server(
client_specs, output_dir, support, compilation_result, server_program, is_simulated
client_specs,
output_dir,
support,
compilation_result,
server_program,
is_simulated,
)

def run(
Expand Down

0 comments on commit 32cf4b2

Please sign in to comment.