Skip to content

Commit

Permalink
chore(frontend-python): remove optimization goal from z3 as it's extr…
Browse files Browse the repository at this point in the history
…emely slow for big graphs
  • Loading branch information
umut-sahin committed Sep 20, 2023
1 parent 2327c15 commit 4827e9e
Showing 1 changed file with 20 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,13 @@ def __init__(
self.shifts_with_promotion = shifts_with_promotion

def apply(self, graph: Graph):
optimizer = z3.Optimize()
solver = z3.Solver()

max_bit_width: z3.Int = z3.Int("max")
bit_widths: Dict[Node, z3.Int] = {}

additional_constraints = AdditionalConstraints(
optimizer,
solver,
graph,
bit_widths,
self.comparison_strategy_preference,
Expand All @@ -69,19 +69,17 @@ def apply(self, graph: Graph):
bit_width = z3.Int(f"%{i}")
bit_widths[node] = bit_width

optimizer.add(max_bit_width >= bit_width)
optimizer.add(bit_width >= required_bit_width)
solver.add(max_bit_width >= bit_width)
solver.add(bit_width >= required_bit_width)

additional_constraints.generate_for(node, bit_width)

if self.single_precision:
for bit_width in bit_widths.values():
optimizer.add(bit_width == max_bit_width)
solver.add(bit_width == max_bit_width)

optimizer.minimize(sum(bit_width**2 for bit_width in bit_widths.values()))

assert optimizer.check() == z3.sat
model = optimizer.model()
assert solver.check() == z3.sat
model = solver.model()

for node, bit_width in bit_widths.items():
assert isinstance(node.output.dtype, Integer)
Expand All @@ -99,7 +97,7 @@ class AdditionalConstraints:
AdditionalConstraints class to customize bit-width assignment step easily.
"""

optimizer: z3.Optimize
solver: z3.Solver
graph: Graph
bit_widths: Dict[Node, z3.Int]

Expand All @@ -114,14 +112,14 @@ class AdditionalConstraints:

def __init__(
self,
optimizer: z3.Optimize,
solver: z3.Solver,
graph: Graph,
bit_widths: Dict[Node, z3.Int],
comparison_strategy_preference: List[ComparisonStrategy],
bitwise_strategy_preference: List[BitwiseStrategy],
shifts_with_promotion: bool,
):
self.optimizer = optimizer
self.solver = solver
self.graph = graph
self.bit_widths = bit_widths

Expand Down Expand Up @@ -192,12 +190,12 @@ def has_overflow_protection(self, node: Node, preds: List[Node]) -> bool:

def inputs_share_precision(self, node: Node, preds: List[Node]):
for i in range(len(preds) - 1):
self.optimizer.add(self.bit_widths[preds[i]] == self.bit_widths[preds[i + 1]])
self.solver.add(self.bit_widths[preds[i]] == self.bit_widths[preds[i + 1]])

def inputs_and_output_share_precision(self, node: Node, preds: List[Node]):
self.inputs_share_precision(node, preds)
if len(preds) != 0:
self.optimizer.add(self.bit_widths[preds[-1]] == self.bit_widths[node])
self.solver.add(self.bit_widths[preds[-1]] == self.bit_widths[node])

def inputs_require_one_more_bit(self, node: Node, preds: List[Node]):
for pred in preds:
Expand All @@ -206,7 +204,7 @@ def inputs_require_one_more_bit(self, node: Node, preds: List[Node]):
actual_bit_width = pred.output.dtype.bit_width
required_bit_width = actual_bit_width + 1

self.optimizer.add(self.bit_widths[pred] >= required_bit_width)
self.solver.add(self.bit_widths[pred] >= required_bit_width)

def comparison(self, node: Node, preds: List[Node]):
assert len(preds) == 2
Expand All @@ -226,11 +224,11 @@ def comparison(self, node: Node, preds: List[Node]):
for strategy in strategies + fallback:
if strategy.can_be_used(x.output, y.output):
new_x_bit_width, new_y_bit_width = strategy.promotions(x.output, y.output)
self.optimizer.add(self.bit_widths[x] >= new_x_bit_width)
self.optimizer.add(self.bit_widths[y] >= new_y_bit_width)
self.solver.add(self.bit_widths[x] >= new_x_bit_width)
self.solver.add(self.bit_widths[y] >= new_y_bit_width)

if strategy == ComparisonStrategy.ONE_TLU_PROMOTED:
self.optimizer.add(self.bit_widths[x] == self.bit_widths[y])
self.solver.add(self.bit_widths[x] == self.bit_widths[y])

node.properties["strategy"] = strategy
break
Expand All @@ -252,11 +250,11 @@ def bitwise(self, node: Node, preds: List[Node]):
for strategy in strategies + fallback:
if strategy.can_be_used(x.output, y.output):
new_x_bit_width, new_y_bit_width = strategy.promotions(x.output, y.output)
self.optimizer.add(self.bit_widths[x] >= new_x_bit_width)
self.optimizer.add(self.bit_widths[y] >= new_y_bit_width)
self.solver.add(self.bit_widths[x] >= new_x_bit_width)
self.solver.add(self.bit_widths[y] >= new_y_bit_width)

if strategy == BitwiseStrategy.ONE_TLU_PROMOTED:
self.optimizer.add(self.bit_widths[x] == self.bit_widths[y])
self.solver.add(self.bit_widths[x] == self.bit_widths[y])

node.properties["strategy"] = strategy
break
Expand All @@ -266,7 +264,7 @@ def bitwise(self, node: Node, preds: List[Node]):
and node.properties["strategy"] == BitwiseStrategy.CHUNKED
and self.shifts_with_promotion
):
self.optimizer.add(self.bit_widths[x] == self.bit_widths[node])
self.solver.add(self.bit_widths[x] == self.bit_widths[node])

# ==========
# Operations
Expand Down

0 comments on commit 4827e9e

Please sign in to comment.