feat[next]: GTIR temporary extraction pass #1678
@@ -8,582 +8,183 @@
The pass here has been removed completely and been replaced by the new pass.
@@ -6,464 +6,219 @@
# Please, refer to the LICENSE file in the root directory.
This file has been removed and then rewritten from scratch with new tests.
@@ -157,23 +160,23 @@ def infer_as_fieldop(
  raise ValueError(f"Unsupported expression of type '{type(in_field)}'.")
  input_ids.append(id_)

- accessed_domains: ACCESSED_DOMAINS = _extract_accessed_domains(
+ inputs_accessed_domains: ACCESSED_DOMAINS = _extract_accessed_domains(
Reasoning for the following changes in this function: this dict contains as keys not only the symref inputs, but also temporary ids. The symrefs are already added to the result dict by the loop below, while the temporary ids should not be in the result anyway. As such, do not use this dict as the starting point for the domain union in the loop below.
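The reasoning in this comment can be illustrated with plain dictionaries. This is only a minimal sketch, assuming domains are modelled as `(start, stop)` intervals; `merge_input_domains`, `union`, and the `__tmp_0` id are hypothetical names for illustration and are not taken from the gt4py code.

```python
Interval = tuple[int, int]


def union(a: Interval, b: Interval) -> Interval:
    """Smallest interval covering both inputs (toy stand-in for a domain union)."""
    return (min(a[0], b[0]), max(a[1], b[1]))


def merge_input_domains(
    inputs: list[tuple[str, bool]],
    inputs_accessed_domains: dict[str, Interval],
) -> dict[str, Interval]:
    """Collect accessed domains for symref inputs only.

    `inputs` pairs each id with a flag: True for a symref, False for a temporary id.
    """
    # Start from an empty dict instead of a copy of `inputs_accessed_domains`,
    # so that keys belonging to temporaries never reach the result.
    result: dict[str, Interval] = {}
    for id_, is_symref in inputs:
        if not is_symref:
            continue  # temporary ids must not leak into the result
        domain = inputs_accessed_domains[id_]
        result[id_] = union(result[id_], domain) if id_ in result else domain
    return result


accessed = {"inp": (0, 4), "__tmp_0": (1, 3)}
merged = merge_input_domains([("inp", True), ("__tmp_0", False)], accessed)
```

Had the result been seeded with `accessed` itself, the `__tmp_0` entry would have been carried along even though the loop never visits it as a symref.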
@@ -21,7 +21,7 @@
ir_makers as im,
@SF-N Can you review the functional changes in this file?
They all make sense to me.
Just questions and minor style comments (most of them actually optional).
@@ -319,7 +320,7 @@ def extract_subexpression(
  subexprs = CollectSubexpressions.apply(node)

  # collect multiple occurrences and map them to fresh symbols
- expr_map = dict[int, itir.SymRef]()
+ expr_map: dict[int, itir.SymRef] = {}
👍
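For context on the change above: both spellings produce an empty dict at runtime on Python 3.9+, so the edit looks purely stylistic. The variable names in this short sketch are made up for illustration.

```python
# Calling the parameterized alias simply constructs a plain dict.
called_alias = dict[int, str]()

# The annotated literal states the same type without a runtime call on the alias.
annotated: dict[int, str] = {}

same_value = called_alias == annotated
same_type = type(called_alias) is type(annotated) is dict
```

The annotated form is the more conventional one and keeps the type information in the place where type checkers and readers usually look for it.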
def _transform_by_pattern(
stmt: itir.Stmt, predicate, declarations: list[itir.Temporary], uids: eve_utils.UIDGenerator
Missing type annotation for predicate
(why is mypy not complaining? We need to make our settings stricter)
Maybe because the function is private?
for tmp_sym, tmp_expr in extracted_fields.items():
    domain = tmp_expr.annex.domain

    # TODO(tehrengruber): Implement. This happens when the expression for a combination
Check this comment and rephrase it (what does `Implement.` mean here?)
- # TODO(tehrengruber): Implement. This happens when the expression for a combination
+ # TODO(tehrengruber): Implement. This happens when the expression is a combination
I'm not sure how or why to improve this comment. The task is to implement the case described in detail, whose conditions are expressed in the code that leads to the `NotImplementedError` thereafter.
flattened_domains: tuple[domain_utils.SymbolicDomain] = (
    next_utils.flatten_nested_tuple(domain)  # type: ignore[assignment]  # mypy not smart enough
)
if not all(d == flattened_domains[0] for d in flattened_domains):
    raise NotImplementedError(
        "Tuple expressions with different domains is not " "supported yet."
    )
domain = flattened_domains[0]
Optional suggestion for an alternative (cleaner?) implementation of this (assuming `SymbolicDomain` is a hashable frozen dataclass):
- flattened_domains: tuple[domain_utils.SymbolicDomain] = (
-     next_utils.flatten_nested_tuple(domain)  # type: ignore[assignment]  # mypy not smart enough
- )
- if not all(d == flattened_domains[0] for d in flattened_domains):
-     raise NotImplementedError(
-         "Tuple expressions with different domains is not " "supported yet."
-     )
- domain = flattened_domains[0]
+ flattened_domains: tuple[domain_utils.SymbolicDomain] = set(
+     next_utils.flatten_nested_tuple(domain)
+ )
+ domain = flattened_domains.pop()
+ if flattened_domains:
+     raise NotImplementedError(
+         "Tuple expressions with different domains is not " "supported yet."
+     )
`SymbolicDomain` contains a dict and is not hashable. I could make it hashable, but investing effort in a `NotImplementedError` path isn't worth it.
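The hashability point can be reproduced in isolation. The sketch below assumes a toy `ToyDomain` dataclass with a dict field; it is a hypothetical stand-in and not the actual `SymbolicDomain` definition. A frozen dataclass derives its hash from its fields, so a dict field makes set construction fail, while the equality-based check from the diff keeps working.

```python
from dataclasses import dataclass, field


@dataclass(frozen=True)
class ToyDomain:
    """Hypothetical stand-in for a domain type that stores a dict."""

    ranges: dict[str, tuple[int, int]] = field(default_factory=dict)


def single_domain(domains: tuple[ToyDomain, ...]) -> ToyDomain:
    """Return the common domain, comparing by equality instead of hashing."""
    first = domains[0]
    if not all(d == first for d in domains):
        raise NotImplementedError(
            "Tuple expressions with different domains are not supported yet."
        )
    return first


a = ToyDomain({"I": (0, 10)})
b = ToyDomain({"I": (0, 10)})

try:
    {a, b}  # hashing a frozen dataclass hashes its fields; dicts are unhashable
    set_worked = True
except TypeError:
    set_worked = False

common = single_domain((a, b))
```

The equality-based version costs one comparison per element, which is negligible for the handful of domains in a flattened tuple.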
for transform in transforms:
    transformed_stmts = transform(stmt=stmt, declarations=declarations, uids=uids)
    if transformed_stmts:
        unprocessed_stmts = [*transformed_stmts, *unprocessed_stmts]
Using a `collections.deque` for `unprocessed_stmts` could make this simpler and probably more efficient.
- unprocessed_stmts = [*transformed_stmts, *unprocessed_stmts]
+ unprocessed_stmts.appendleft(transformed_stmts)
A deque doesn't really fit since
- we only need a one-sided queue
- `deque.extendleft` reverses the order of `transformed_stmts`, which is wrong here. Manually reversing negates the simplicity again.
I can reverse the order of the list if you prefer that, but it feels like premature optimization.
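The ordering issue raised here is easy to check with the standard library alone. In this sketch the statement names are placeholder strings. Note also that `appendleft`, as written in the suggestion, would insert the whole list as a single element, so `extendleft` is the method that would actually be needed.

```python
from collections import deque

transformed_stmts = ["t1", "t2"]

# List approach from the diff: new statements go first, in their original order.
unprocessed_list = ["s1", "s2"]
unprocessed_list = [*transformed_stmts, *unprocessed_list]

# deque.extendleft pushes items one at a time, so they end up reversed.
reversed_queue = deque(["s1", "s2"])
reversed_queue.extendleft(transformed_stmts)

# Reversing first restores the intended order, at the cost of an extra call.
fixed_queue = deque(["s1", "s2"])
fixed_queue.extendleft(reversed(transformed_stmts))
```

With the explicit `reversed` call the deque version matches the list version, which is the extra step the reply refers to.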
# The following example doesn't have a lift around the shift, i.e. temporary pass will not extract it.
# λ(inp) → scan(λ(state, k, kp) → state + ·k + ·kp, True, 0.0)(inp, ⟪Koffₒ, 1ₒ⟫(inp))`
ir = _inline_into_scan(ir)
# ruff: noqa: ERA001
Why is this commented out and not deleted?
Because the implementation is still missing and this is the template for it. It will be deleted in the next PR. Not sure why I commented it out, but doesn't matter.
  otf_workflow__translation__lift_mode=transforms.LiftMode.USE_TEMPORARIES,
- otf_workflow__translation__temporary_extraction_heuristics=global_tmps.SimpleTemporaryExtractionHeuristics,
+ # otf_workflow__translation__temporary_extraction_heuristics=global_tmps.SimpleTemporaryExtractionHeuristics,  # noqa: ERA001
Why is this commented out?
Because there are no heuristics right now and I didn't want to remove it only to rewire / rewrite it again later.
UIDs.reset_sequence()
testee = ir.FencilDefinition(
id="f",
def program_factory(
Is this really the only place where this factory is/should be used? It looks like a pretty general utility to me...
Most tests actually only work on expressions, not programs, and since every test needs a slightly different version of this, making it a general utility is not so useful.
LGTM
New temporary extraction pass. Transforms an `itir.Program` like … into …

Note that this pass intentionally extracts unconditionally. If you don't want a temporary, fuse the `as_fieldop` beforehand. As such, the fusion pass (see #1670) contains the heuristics on what to fuse.
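To make "unconditionally extracts" concrete, here is a toy model of the idea. It is a sketch under strong assumptions and not the pass from this PR: `Call`, `extract_temporaries`, and the `__tmp_N` naming are hypothetical, and domains, dtypes, and temporary declarations are left out entirely. Every nested call that appears as an argument is moved into its own assignment, with no heuristic deciding whether that is worthwhile.

```python
from dataclasses import dataclass
from itertools import count


@dataclass(frozen=True)
class Call:
    """Toy stand-in for an `as_fieldop` application (not the ITIR node)."""

    op: str
    args: tuple  # each argument is a field name (str) or a nested Call


def extract_temporaries(target: str, expr: Call) -> list[tuple[str, Call]]:
    """Return `(target, expr)` assignments in which no Call has a Call argument."""
    fresh = (f"__tmp_{i}" for i in count())
    stmts: list[tuple[str, Call]] = []

    def visit(node: Call) -> Call:
        new_args = []
        for arg in node.args:
            if isinstance(arg, Call):
                tmp = next(fresh)
                # The nested call is always extracted; its own arguments are
                # processed first so definitions precede their uses.
                stmts.append((tmp, visit(arg)))
                new_args.append(tmp)
            else:
                new_args.append(arg)
        return Call(node.op, tuple(new_args))

    stmts.append((target, visit(expr)))
    return stmts


nested = Call("f", (Call("g", (Call("h", ("a",)),)), "b"))
result = extract_temporaries("out", nested)
```

In this model, avoiding a temporary means rewriting the input so that the nested call no longer exists before extraction runs, which mirrors the division of labour with the fusion pass described above.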