Skip to content

Commit

Permalink
feat(frontend-python): multi parameter strategy in Configuration
Browse files Browse the repository at this point in the history
  • Loading branch information
rudy-6-4 committed Sep 25, 2023
1 parent 556eeac commit 0f77a1e
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
.value("DAG_MULTI", optimizer::Strategy::DAG_MULTI)
.export_values();

pybind11::enum_<concrete_optimizer::MultiParamStrategy>(
m, "OptimizerMultiParamStrategy")
.value("PRECISION", concrete_optimizer::MultiParamStrategy::ByPrecision)
.value("PRECISION_AND_NORM2",
concrete_optimizer::MultiParamStrategy::ByPrecisionAndNorm2)
.export_values();

pybind11::enum_<concrete_optimizer::Encoding>(m, "Encoding")
.value("AUTO", concrete_optimizer::Encoding::Auto)
.value("CRT", concrete_optimizer::Encoding::Crt)
Expand Down Expand Up @@ -103,6 +110,11 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
[](CompilationOptions &options, optimizer::Strategy strategy) {
options.optimizerConfig.strategy = strategy;
})
.def("set_optimizer_multi_parameter_strategy",
[](CompilationOptions &options,
concrete_optimizer::MultiParamStrategy strategy) {
options.optimizerConfig.multi_param_strategy = strategy;
})
.def("set_global_p_error",
[](CompilationOptions &options, double global_p_error) {
options.optimizerConfig.global_p_error = global_p_error;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from mlir._mlir_libs._concretelang._compiler import (
CompilationOptions as _CompilationOptions,
OptimizerStrategy as _OptimizerStrategy,
OptimizerMultiParamStrategy as _OptimizerMultiParamStrategy,
Encoding,
)
from .wrapper import WrapperCpp
Expand Down Expand Up @@ -177,12 +178,27 @@ def set_optimizer_strategy(self, strategy: _OptimizerStrategy):
strategy (OptimizerStrategy): Use the specified optmizer strategy.
Raises:
TypeError: if the value is not a bool
TypeError: if the value is not an OptimizerStrategy
"""
if not isinstance(strategy, _OptimizerStrategy):
raise TypeError("enable should be a bool")
self.cpp().set_optimizer_strategy(strategy)

def set_optimizer_multi_parameter_strategy(
self, strategy: _OptimizerMultiParamStrategy
):
"""Set the strategy of the optimizer for multi-parameter.
Args:
strategy (OptimizerMultiParameteStrategy): Use the specified optmizer strategy.
Raises:
TypeError: if the value is not a OptimizerMultiParameteStrategy
"""
if not isinstance(strategy, _OptimizerMultiParamStrategy):
raise TypeError("enable should be a bool")
self.cpp().set_optimizer_multi_parameter_strategy(strategy)

def set_global_p_error(self, global_p_error: float):
"""Set global error probability for the full circuit.
Expand Down
1 change: 1 addition & 0 deletions frontends/concrete-python/concrete/fhe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
DebugArtifacts,
EncryptionStatus,
Keys,
MultiParamStrategy,
ParameterSelectionStrategy,
Server,
Value,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
BitwiseStrategy,
ComparisonStrategy,
Configuration,
MultiParamStrategy,
ParameterSelectionStrategy,
)
from .keys import Keys
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,32 @@ def parse(cls, string: str) -> "ParameterSelectionStrategy":
raise ValueError(message)


class MultiParamStrategy(str, Enum):
"""
MultiParamStrategy, to set optimization strategy for mutli-parameter.
"""

PRECISION = "precision"
PRECISION_AND_NORM2 = "precision_and_norm2"

@classmethod
def parse(cls, string: str) -> "MultiParamStrategy":
"""Convert a string to a MultiParamStrategy."""
if isinstance(string, cls):
return string
if not isinstance(string, str):
message = f"{string} cannot be parsed to a {cls.__name__}"
raise TypeError(message)
for value in MultiParamStrategy:
if string.lower().replace("-", "_") == value.value:
return value
message = (
f"'{string}' is not a valid '{friendly_type_format(cls)}' ("
f"{', '.join(v.value for v in MultiParamStrategy)})"
)
raise ValueError(message)


class ComparisonStrategy(str, Enum):
"""
ComparisonStrategy, to specify implementation preference for comparisons.
Expand Down Expand Up @@ -639,6 +665,7 @@ class Configuration:
auto_adjust_rounders: bool
single_precision: bool
parameter_selection_strategy: ParameterSelectionStrategy
multi_parameter_strategy: ParameterSelectionStrategy
show_progress: bool
progress_title: str
progress_tag: Union[bool, int]
Expand Down Expand Up @@ -673,6 +700,7 @@ def __init__(
parameter_selection_strategy: Union[
ParameterSelectionStrategy, str
] = ParameterSelectionStrategy.MULTI,
multi_parameter_strategy: Union[MultiParamStrategy, str] = MultiParamStrategy.PRECISION,
show_progress: bool = False,
progress_title: str = "",
progress_tag: Union[bool, int] = False,
Expand Down Expand Up @@ -714,6 +742,7 @@ def __init__(
self.parameter_selection_strategy = ParameterSelectionStrategy.parse(
parameter_selection_strategy
)
self.multi_parameter_strategy = MultiParamStrategy.parse(multi_parameter_strategy)
self.show_progress = show_progress
self.progress_title = progress_title
self.progress_tag = progress_tag
Expand Down Expand Up @@ -768,6 +797,7 @@ def fork(
auto_adjust_rounders: Union[Keep, bool] = KEEP,
single_precision: Union[Keep, bool] = KEEP,
parameter_selection_strategy: Union[Keep, Union[ParameterSelectionStrategy, str]] = KEEP,
multi_parameter_strategy: Union[Keep, Union[MultiParamStrategy, str]] = KEEP,
show_progress: Union[Keep, bool] = KEEP,
progress_title: Union[Keep, str] = KEEP,
progress_tag: Union[Keep, Union[bool, int]] = KEEP,
Expand Down
15 changes: 14 additions & 1 deletion frontends/concrete-python/concrete/fhe/compilation/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,20 @@
set_compiler_logging,
set_llvm_debug_flag,
)
from mlir._mlir_libs._concretelang._compiler import KeyType, OptimizerStrategy, PrimitiveOperation
from mlir._mlir_libs._concretelang._compiler import (
KeyType,
OptimizerMultiParamStrategy,
OptimizerStrategy,
PrimitiveOperation,
)
from mlir.ir import Module as MlirModule

from ..internal.utils import assert_that
from .configuration import (
DEFAULT_GLOBAL_P_ERROR,
DEFAULT_P_ERROR,
Configuration,
MultiParamStrategy,
ParameterSelectionStrategy,
)
from .specs import ClientSpecs
Expand Down Expand Up @@ -161,6 +167,13 @@ def create(
options.set_optimizer_strategy(OptimizerStrategy.DAG_MONO)
elif parameter_selection_strategy == ParameterSelectionStrategy.MULTI: # pragma: no cover
options.set_optimizer_strategy(OptimizerStrategy.DAG_MULTI)

multi_parameter_strategy = configuration.multi_parameter_strategy
converter = {
MultiParamStrategy.PRECISION: OptimizerMultiParamStrategy.PRECISION,
MultiParamStrategy.PRECISION_AND_NORM2: OptimizerMultiParamStrategy.PRECISION_AND_NORM2,
}
options.set_optimizer_multi_parameter_strategy(converter[multi_parameter_strategy])
try:
if configuration.compiler_debug_mode: # pragma: no cover
set_llvm_debug_flag(True)
Expand Down

0 comments on commit 0f77a1e

Please sign in to comment.