Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(frontend-python): remove optimization goal from z3 #567

Merged
merged 1 commit into from
Sep 20, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,13 @@ def __init__(
self.shifts_with_promotion = shifts_with_promotion

def apply(self, graph: Graph):
optimizer = z3.Optimize()
solver = z3.Solver()

max_bit_width: z3.Int = z3.Int("max")
bit_widths: Dict[Node, z3.Int] = {}

additional_constraints = AdditionalConstraints(
optimizer,
solver,
graph,
bit_widths,
self.comparison_strategy_preference,
Expand All @@ -69,19 +69,17 @@ def apply(self, graph: Graph):
bit_width = z3.Int(f"%{i}")
bit_widths[node] = bit_width

optimizer.add(max_bit_width >= bit_width)
optimizer.add(bit_width >= required_bit_width)
solver.add(max_bit_width >= bit_width)
solver.add(bit_width >= required_bit_width)

additional_constraints.generate_for(node, bit_width)

if self.single_precision:
for bit_width in bit_widths.values():
optimizer.add(bit_width == max_bit_width)
solver.add(bit_width == max_bit_width)

optimizer.minimize(sum(bit_width**2 for bit_width in bit_widths.values()))

assert optimizer.check() == z3.sat
model = optimizer.model()
assert solver.check() == z3.sat
model = solver.model()

for node, bit_width in bit_widths.items():
assert isinstance(node.output.dtype, Integer)
Expand All @@ -99,7 +97,7 @@ class AdditionalConstraints:
AdditionalConstraints class to customize bit-width assignment step easily.
"""

optimizer: z3.Optimize
solver: z3.Solver
graph: Graph
bit_widths: Dict[Node, z3.Int]

Expand All @@ -114,14 +112,14 @@ class AdditionalConstraints:

def __init__(
self,
optimizer: z3.Optimize,
solver: z3.Solver,
graph: Graph,
bit_widths: Dict[Node, z3.Int],
comparison_strategy_preference: List[ComparisonStrategy],
bitwise_strategy_preference: List[BitwiseStrategy],
shifts_with_promotion: bool,
):
self.optimizer = optimizer
self.solver = solver
self.graph = graph
self.bit_widths = bit_widths

Expand Down Expand Up @@ -192,12 +190,12 @@ def has_overflow_protection(self, node: Node, preds: List[Node]) -> bool:

def inputs_share_precision(self, node: Node, preds: List[Node]):
for i in range(len(preds) - 1):
self.optimizer.add(self.bit_widths[preds[i]] == self.bit_widths[preds[i + 1]])
self.solver.add(self.bit_widths[preds[i]] == self.bit_widths[preds[i + 1]])

def inputs_and_output_share_precision(self, node: Node, preds: List[Node]):
self.inputs_share_precision(node, preds)
if len(preds) != 0:
self.optimizer.add(self.bit_widths[preds[-1]] == self.bit_widths[node])
self.solver.add(self.bit_widths[preds[-1]] == self.bit_widths[node])

def inputs_require_one_more_bit(self, node: Node, preds: List[Node]):
for pred in preds:
Expand All @@ -206,7 +204,7 @@ def inputs_require_one_more_bit(self, node: Node, preds: List[Node]):
actual_bit_width = pred.output.dtype.bit_width
required_bit_width = actual_bit_width + 1

self.optimizer.add(self.bit_widths[pred] >= required_bit_width)
self.solver.add(self.bit_widths[pred] >= required_bit_width)

def comparison(self, node: Node, preds: List[Node]):
assert len(preds) == 2
Expand All @@ -226,11 +224,11 @@ def comparison(self, node: Node, preds: List[Node]):
for strategy in strategies + fallback:
if strategy.can_be_used(x.output, y.output):
new_x_bit_width, new_y_bit_width = strategy.promotions(x.output, y.output)
self.optimizer.add(self.bit_widths[x] >= new_x_bit_width)
self.optimizer.add(self.bit_widths[y] >= new_y_bit_width)
self.solver.add(self.bit_widths[x] >= new_x_bit_width)
self.solver.add(self.bit_widths[y] >= new_y_bit_width)

if strategy == ComparisonStrategy.ONE_TLU_PROMOTED:
self.optimizer.add(self.bit_widths[x] == self.bit_widths[y])
self.solver.add(self.bit_widths[x] == self.bit_widths[y])

node.properties["strategy"] = strategy
break
Expand All @@ -252,11 +250,11 @@ def bitwise(self, node: Node, preds: List[Node]):
for strategy in strategies + fallback:
if strategy.can_be_used(x.output, y.output):
new_x_bit_width, new_y_bit_width = strategy.promotions(x.output, y.output)
self.optimizer.add(self.bit_widths[x] >= new_x_bit_width)
self.optimizer.add(self.bit_widths[y] >= new_y_bit_width)
self.solver.add(self.bit_widths[x] >= new_x_bit_width)
self.solver.add(self.bit_widths[y] >= new_y_bit_width)

if strategy == BitwiseStrategy.ONE_TLU_PROMOTED:
self.optimizer.add(self.bit_widths[x] == self.bit_widths[y])
self.solver.add(self.bit_widths[x] == self.bit_widths[y])

node.properties["strategy"] = strategy
break
Expand All @@ -266,7 +264,7 @@ def bitwise(self, node: Node, preds: List[Node]):
and node.properties["strategy"] == BitwiseStrategy.CHUNKED
and self.shifts_with_promotion
):
self.optimizer.add(self.bit_widths[x] == self.bit_widths[node])
self.solver.add(self.bit_widths[x] == self.bit_widths[node])

# ==========
# Operations
Expand Down