diff --git a/frontends/concrete-python/concrete/fhe/mlir/processors/assign_bit_widths.py b/frontends/concrete-python/concrete/fhe/mlir/processors/assign_bit_widths.py index 644cafe5c8..ef8863e02b 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/processors/assign_bit_widths.py +++ b/frontends/concrete-python/concrete/fhe/mlir/processors/assign_bit_widths.py @@ -43,13 +43,13 @@ def __init__( self.shifts_with_promotion = shifts_with_promotion def apply(self, graph: Graph): - optimizer = z3.Optimize() + solver = z3.Solver() max_bit_width: z3.Int = z3.Int("max") bit_widths: Dict[Node, z3.Int] = {} additional_constraints = AdditionalConstraints( - optimizer, + solver, graph, bit_widths, self.comparison_strategy_preference, @@ -69,19 +69,17 @@ def apply(self, graph: Graph): bit_width = z3.Int(f"%{i}") bit_widths[node] = bit_width - optimizer.add(max_bit_width >= bit_width) - optimizer.add(bit_width >= required_bit_width) + solver.add(max_bit_width >= bit_width) + solver.add(bit_width >= required_bit_width) additional_constraints.generate_for(node, bit_width) if self.single_precision: for bit_width in bit_widths.values(): - optimizer.add(bit_width == max_bit_width) + solver.add(bit_width == max_bit_width) - optimizer.minimize(sum(bit_width**2 for bit_width in bit_widths.values())) - - assert optimizer.check() == z3.sat - model = optimizer.model() + assert solver.check() == z3.sat + model = solver.model() for node, bit_width in bit_widths.items(): assert isinstance(node.output.dtype, Integer) @@ -99,7 +97,7 @@ class AdditionalConstraints: AdditionalConstraints class to customize bit-width assignment step easily. """ - optimizer: z3.Optimize + solver: z3.Solver graph: Graph bit_widths: Dict[Node, z3.Int] @@ -114,14 +112,14 @@ class AdditionalConstraints: def __init__( self, - optimizer: z3.Optimize, + solver: z3.Solver, graph: Graph, bit_widths: Dict[Node, z3.Int], comparison_strategy_preference: List[ComparisonStrategy], bitwise_strategy_preference: List[BitwiseStrategy], shifts_with_promotion: bool, ): - self.optimizer = optimizer + self.solver = solver self.graph = graph self.bit_widths = bit_widths @@ -192,12 +190,12 @@ def has_overflow_protection(self, node: Node, preds: List[Node]) -> bool: def inputs_share_precision(self, node: Node, preds: List[Node]): for i in range(len(preds) - 1): - self.optimizer.add(self.bit_widths[preds[i]] == self.bit_widths[preds[i + 1]]) + self.solver.add(self.bit_widths[preds[i]] == self.bit_widths[preds[i + 1]]) def inputs_and_output_share_precision(self, node: Node, preds: List[Node]): self.inputs_share_precision(node, preds) if len(preds) != 0: - self.optimizer.add(self.bit_widths[preds[-1]] == self.bit_widths[node]) + self.solver.add(self.bit_widths[preds[-1]] == self.bit_widths[node]) def inputs_require_one_more_bit(self, node: Node, preds: List[Node]): for pred in preds: @@ -206,7 +204,7 @@ def inputs_require_one_more_bit(self, node: Node, preds: List[Node]): actual_bit_width = pred.output.dtype.bit_width required_bit_width = actual_bit_width + 1 - self.optimizer.add(self.bit_widths[pred] >= required_bit_width) + self.solver.add(self.bit_widths[pred] >= required_bit_width) def comparison(self, node: Node, preds: List[Node]): assert len(preds) == 2 @@ -226,11 +224,11 @@ def comparison(self, node: Node, preds: List[Node]): for strategy in strategies + fallback: if strategy.can_be_used(x.output, y.output): new_x_bit_width, new_y_bit_width = strategy.promotions(x.output, y.output) - self.optimizer.add(self.bit_widths[x] >= new_x_bit_width) - self.optimizer.add(self.bit_widths[y] >= new_y_bit_width) + self.solver.add(self.bit_widths[x] >= new_x_bit_width) + self.solver.add(self.bit_widths[y] >= new_y_bit_width) if strategy == ComparisonStrategy.ONE_TLU_PROMOTED: - self.optimizer.add(self.bit_widths[x] == self.bit_widths[y]) + self.solver.add(self.bit_widths[x] == self.bit_widths[y]) node.properties["strategy"] = strategy break @@ -252,11 +250,11 @@ def bitwise(self, node: Node, preds: List[Node]): for strategy in strategies + fallback: if strategy.can_be_used(x.output, y.output): new_x_bit_width, new_y_bit_width = strategy.promotions(x.output, y.output) - self.optimizer.add(self.bit_widths[x] >= new_x_bit_width) - self.optimizer.add(self.bit_widths[y] >= new_y_bit_width) + self.solver.add(self.bit_widths[x] >= new_x_bit_width) + self.solver.add(self.bit_widths[y] >= new_y_bit_width) if strategy == BitwiseStrategy.ONE_TLU_PROMOTED: - self.optimizer.add(self.bit_widths[x] == self.bit_widths[y]) + self.solver.add(self.bit_widths[x] == self.bit_widths[y]) node.properties["strategy"] = strategy break @@ -266,7 +264,7 @@ def bitwise(self, node: Node, preds: List[Node]): and node.properties["strategy"] == BitwiseStrategy.CHUNKED and self.shifts_with_promotion ): - self.optimizer.add(self.bit_widths[x] == self.bit_widths[node]) + self.solver.add(self.bit_widths[x] == self.bit_widths[node]) # ========== # Operations