Skip to content

Commit

Permalink
feat(frontend-python): multi parameter strategy in Configuration
Browse files Browse the repository at this point in the history
  • Loading branch information
rudy-6-4 committed Sep 25, 2023
1 parent 2228075 commit 866b0a2
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
.value("DAG_MULTI", optimizer::Strategy::DAG_MULTI)
.export_values();

pybind11::enum_<concrete_optimizer::MultiParamStrategy>(m, "OptimizerMultiParamStrategy")
.value("PRECISION", concrete_optimizer::MultiParamStrategy::ByPrecision)
.value("PRECISION_AND_NORM2", concrete_optimizer::MultiParamStrategy::ByPrecisionAndNorm2)
.export_values();

pybind11::enum_<concrete_optimizer::Encoding>(m, "Encoding")
.value("AUTO", concrete_optimizer::Encoding::Auto)
.value("CRT", concrete_optimizer::Encoding::Crt)
Expand Down Expand Up @@ -103,6 +108,10 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
[](CompilationOptions &options, optimizer::Strategy strategy) {
options.optimizerConfig.strategy = strategy;
})
.def("set_optimizer_multi_parameter_strategy",
[](CompilationOptions &options, concrete_optimizer::MultiParamStrategy strategy) {
options.optimizerConfig.multi_param_strategy = strategy;
})
.def("set_global_p_error",
[](CompilationOptions &options, double global_p_error) {
options.optimizerConfig.global_p_error = global_p_error;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from mlir._mlir_libs._concretelang._compiler import (
CompilationOptions as _CompilationOptions,
OptimizerStrategy as _OptimizerStrategy,
OptimizerMultiParamStrategy as _OptimizerMultiParamStrategy,
Encoding,
)
from .wrapper import WrapperCpp
Expand Down Expand Up @@ -177,12 +178,25 @@ def set_optimizer_strategy(self, strategy: _OptimizerStrategy):
strategy (OptimizerStrategy): Use the specified optmizer strategy.
Raises:
TypeError: if the value is not a bool
TypeError: if the value is not an OptimizerStrategy
"""
if not isinstance(strategy, _OptimizerStrategy):
raise TypeError("enable should be a bool")
self.cpp().set_optimizer_strategy(strategy)

def set_optimizer_multi_parameter_strategy(self, strategy: _OptimizerMultiParamStrategy):
"""Set the strategy of the optimizer for multi-parameter.
Args:
strategy (OptimizerMultiParameteStrategy): Use the specified optmizer strategy.
Raises:
TypeError: if the value is not a OptimizerMultiParameteStrategy
"""
if not isinstance(strategy, _OptimizerMultiParamStrategy):
raise TypeError("enable should be a bool")
self.cpp().set_optimizer_multi_parameter_strategy(strategy)

def set_global_p_error(self, global_p_error: float):
"""Set global error probability for the full circuit.
Expand Down
1 change: 1 addition & 0 deletions frontends/concrete-python/concrete/fhe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
DebugArtifacts,
EncryptionStatus,
Keys,
MultiParamStrategy,
ParameterSelectionStrategy,
Server,
Value,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
BitwiseStrategy,
ComparisonStrategy,
Configuration,
MultiParamStrategy,
ParameterSelectionStrategy,
)
from .keys import Keys
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,30 @@ def parse(cls, string: str) -> "ParameterSelectionStrategy":
)
raise ValueError(message)

class MultiParamStrategy(str, Enum):
"""
MultiParamStrategy, to set optimization strategy for mutli-parameter.
"""

PRECISION = "precision"
PRECISION_AND_NORM2 = "precision_and_norm2"

@classmethod
def parse(cls, string: str) -> "MultiParamStrategy":
"""Convert a string to a MultiParamStrategy."""
if isinstance(string, cls):
return string
if not isinstance(string, str):
message = f"{string} cannot be parsed to a {cls.__name__}"
raise TypeError(message)
for value in MultiParamStrategy:
if string.lower().replace("-", "_") == value.value:
return value
message = (
f"'{string}' is not a valid '{friendly_type_format(cls)}' ("
f"{', '.join(v.value for v in MultiParamStrategy)})"
)
raise ValueError(message)

class ComparisonStrategy(str, Enum):
"""
Expand Down Expand Up @@ -639,6 +663,7 @@ class Configuration:
auto_adjust_rounders: bool
single_precision: bool
parameter_selection_strategy: ParameterSelectionStrategy
multi_parameter_strategy: ParameterSelectionStrategy
show_progress: bool
progress_title: str
progress_tag: Union[bool, int]
Expand Down Expand Up @@ -673,6 +698,9 @@ def __init__(
parameter_selection_strategy: Union[
ParameterSelectionStrategy, str
] = ParameterSelectionStrategy.MULTI,
multi_parameter_strategy: Union[
MultiParamStrategy, str
] = MultiParamStrategy.PRECISION,
show_progress: bool = False,
progress_title: str = "",
progress_tag: Union[bool, int] = False,
Expand Down Expand Up @@ -714,6 +742,9 @@ def __init__(
self.parameter_selection_strategy = ParameterSelectionStrategy.parse(
parameter_selection_strategy
)
self.multi_parameter_strategy = MultiParamStrategy.parse(
multi_parameter_strategy
)
self.show_progress = show_progress
self.progress_title = progress_title
self.progress_tag = progress_tag
Expand Down Expand Up @@ -768,6 +799,7 @@ def fork(
auto_adjust_rounders: Union[Keep, bool] = KEEP,
single_precision: Union[Keep, bool] = KEEP,
parameter_selection_strategy: Union[Keep, Union[ParameterSelectionStrategy, str]] = KEEP,
multi_parameter_strategy: Union[Keep, Union[MultiParamStrategy, str]] = KEEP,
show_progress: Union[Keep, bool] = KEEP,
progress_title: Union[Keep, str] = KEEP,
progress_tag: Union[Keep, Union[bool, int]] = KEEP,
Expand Down
10 changes: 9 additions & 1 deletion frontends/concrete-python/concrete/fhe/compilation/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,15 @@
set_compiler_logging,
set_llvm_debug_flag,
)
from mlir._mlir_libs._concretelang._compiler import KeyType, OptimizerStrategy, PrimitiveOperation
from mlir._mlir_libs._concretelang._compiler import KeyType, OptimizerMultiParamStrategy, OptimizerStrategy, PrimitiveOperation
from mlir.ir import Module as MlirModule

from ..internal.utils import assert_that
from .configuration import (
DEFAULT_GLOBAL_P_ERROR,
DEFAULT_P_ERROR,
Configuration,
MultiParamStrategy,
ParameterSelectionStrategy,
)
from .specs import ClientSpecs
Expand Down Expand Up @@ -161,6 +162,13 @@ def create(
options.set_optimizer_strategy(OptimizerStrategy.DAG_MONO)
elif parameter_selection_strategy == ParameterSelectionStrategy.MULTI: # pragma: no cover
options.set_optimizer_strategy(OptimizerStrategy.DAG_MULTI)

multi_parameter_strategy = configuration.multi_parameter_strategy
converter = {
MultiParamStrategy.PRECISION: OptimizerMultiParamStrategy.PRECISION,
MultiParamStrategy.PRECISION_AND_NORM2: OptimizerMultiParamStrategy.PRECISION_AND_NORM2
}
options.set_optimizer_multi_parameter_strategy(converter[multi_parameter_strategy])
try:
if configuration.compiler_debug_mode: # pragma: no cover
set_llvm_debug_flag(True)
Expand Down

1 comment on commit 866b0a2

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark

Benchmark suite Current: 866b0a2 Previous: dcf7329 Ratio
v0 PBS table generation 88277734 ns/iter (± 188736) 88480005 ns/iter (± 427124) 1.00
v0 PBS simulate dag table generation 55062759 ns/iter (± 118744) 54397444 ns/iter (± 131472) 1.01
v0 WoP-PBS table generation 137355919 ns/iter (± 1754476) 144758715 ns/iter (± 156850) 0.95

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.