feat[next]: GTIR temporary extraction pass #1678

Merged: 30 commits, Oct 18, 2024

Contributor

@tehrengruber tehrengruber commented Oct 1, 2024

New temporary extraction pass. Transforms an itir.Program like

testee(inp, out) {
  out @ c⟨ IDimₕ: [0, 1) ⟩
       ← as_fieldop(deref, c⟨ IDimₕ: [0, 1) ⟩)(as_fieldop(deref, c⟨ IDimₕ: [0, 1) ⟩)(inp));
}

into

testee(inp, out) {
  __tmp_1 = temporary(domain=c⟨ IDimₕ: [0, 1) ⟩, dtype=float64);
  __tmp_1 @ c⟨ IDimₕ: [0, 1) ⟩ ← as_fieldop(deref, c⟨ IDimₕ: [0, 1) ⟩)(inp);
  out @ c⟨ IDimₕ: [0, 1) ⟩ ← as_fieldop(deref, c⟨ IDimₕ: [0, 1) ⟩)(__tmp_1);
}

Note that this pass intentionally extracts unconditionally. If you don't want a temporary, fuse the as_fieldop calls beforehand. As such, the fusion pass (see #1670) contains the heuristics on what to fuse.
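
The transformation described above can be sketched on a toy IR. This is only an illustration under stated assumptions: `Ref`, `FieldOp`, `SetAt` and `extract_temporaries` are hypothetical stand-ins, not gt4py classes or the pass added in this PR, and the dtype of the temporary is omitted.

```python
from __future__ import annotations

from dataclasses import dataclass
from itertools import count


@dataclass(frozen=True)
class Ref:
    """Reference to a named field (a program argument or a temporary)."""
    name: str


@dataclass(frozen=True)
class FieldOp:
    """Toy stand-in for an `as_fieldop(stencil, domain)(*args)` call."""
    stencil: str
    domain: str
    args: tuple[Ref | FieldOp, ...]


@dataclass(frozen=True)
class SetAt:
    """Toy stand-in for `target @ domain ← expr`."""
    target: str
    domain: str
    expr: Ref | FieldOp


def extract_temporaries(stmt: SetAt) -> tuple[list[tuple[str, str]], list[SetAt]]:
    """Unconditionally move every nested FieldOp argument into a temporary."""
    ids = count(1)
    declarations: list[tuple[str, str]] = []  # (temporary name, domain)
    stmts: list[SetAt] = []

    def visit(expr):
        if isinstance(expr, Ref):
            return expr
        new_args = []
        for arg in expr.args:
            arg = visit(arg)  # innermost first, so producers precede consumers
            if isinstance(arg, FieldOp):
                name = f"__tmp_{next(ids)}"
                declarations.append((name, arg.domain))
                stmts.append(SetAt(name, arg.domain, arg))
                arg = Ref(name)
            new_args.append(arg)
        return FieldOp(expr.stencil, expr.domain, tuple(new_args))

    stmts.append(SetAt(stmt.target, stmt.domain, visit(stmt.expr)))
    return declarations, stmts


# The example from the PR description, written in the toy IR.
domain = "c⟨ IDimₕ: [0, 1) ⟩"
inner = FieldOp("deref", domain, (Ref("inp"),))
testee = SetAt("out", domain, FieldOp("deref", domain, (inner,)))
declarations, stmts = extract_temporaries(testee)
```

There is deliberately no heuristic in the sketch: every nested field operation becomes a temporary, which mirrors the statement that the decision belongs to the fusion pass that runs earlier.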

@tehrengruber tehrengruber marked this pull request as draft October 1, 2024 22:39
@tehrengruber tehrengruber changed the title from "feat[next]: GTIR temporaries extraction" to "feat[next]: GTIR temporaries extraction pass" Oct 3, 2024
@tehrengruber tehrengruber marked this pull request as ready for review October 10, 2024 15:02
@@ -8,582 +8,183 @@

Contributor Author

The pass here has been removed completely and replaced by the new pass.

@@ -6,464 +6,219 @@
# Please, refer to the LICENSE file in the root directory.
Contributor Author

This file has been removed and then rewritten from scratch with new tests.

@tehrengruber tehrengruber changed the title from "feat[next]: GTIR temporaries extraction pass" to "feat[next]: GTIR temporary extraction pass" Oct 10, 2024
@@ -157,23 +160,23 @@ def infer_as_fieldop(
raise ValueError(f"Unsupported expression of type '{type(in_field)}'.")
input_ids.append(id_)

- accessed_domains: ACCESSED_DOMAINS = _extract_accessed_domains(
+ inputs_accessed_domains: ACCESSED_DOMAINS = _extract_accessed_domains(
Contributor Author

Reasoning for the following changes in this function: This dict contains as keys not only the symref inputs, but also temporary ids. The symrefs are already added to the result dict by the loop below, while the temporary ids should not be in the result anyway. As such, do not use this dict as the starting point for the domain union in the loop below.
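
A minimal sketch of the reasoning above, assuming toy domains modelled as sets of indices. The function `union_domains`, the `seed_with_inputs` switch and the id `__dom_inf_1` are hypothetical names for illustration and do not come from `infer_as_fieldop`.

```python
from __future__ import annotations


def union_domains(
    inputs_accessed_domains: dict[str, set[int]],
    per_input_results: list[dict[str, set[int]]],
    *,
    seed_with_inputs: bool,
) -> dict[str, set[int]]:
    """Union the domains contributed by each input of a field operation."""
    # Seeding with the inputs dict leaks the temporary ids into the result;
    # starting from an empty dict keeps only what the loop adds.
    result = dict(inputs_accessed_domains) if seed_with_inputs else {}
    for contribution in per_input_results:
        for name, domain in contribution.items():
            result[name] = result.get(name, set()) | domain
    return result


# "inp" is a symref input; "__dom_inf_1" is a made-up id for a nested expression.
inputs_accessed_domains = {"inp": {0, 1}, "__dom_inf_1": {1, 2}}
# What each input contributes: the symref itself, then the nested expression,
# which in this toy case reads "inp" on the domain it is accessed on.
per_input_results = [{"inp": {0, 1}}, {"inp": {1, 2}}]

seeded = union_domains(inputs_accessed_domains, per_input_results, seed_with_inputs=True)
clean = union_domains(inputs_accessed_domains, per_input_results, seed_with_inputs=False)
```

In this toy setup only the unseeded variant returns a dict keyed purely by symref inputs.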

@@ -21,7 +21,7 @@
ir_makers as im,
Contributor Author

@SF-N Can you review the functional changes in this file?

Contributor

They all make sense to me.

Contributor

@egparedes egparedes left a comment

Just questions and minor style comments (most of them actually optional).

@@ -319,7 +320,7 @@ def extract_subexpression(
subexprs = CollectSubexpressions.apply(node)

# collect multiple occurrences and map them to fresh symbols
- expr_map = dict[int, itir.SymRef]()
+ expr_map: dict[int, itir.SymRef] = {}
Contributor

👍

src/gt4py/next/iterator/transforms/global_tmps.py (outdated, resolved)


def _transform_by_pattern(
    stmt: itir.Stmt, predicate, declarations: list[itir.Temporary], uids: eve_utils.UIDGenerator
Contributor

Missing type annotation for predicate (why is mypy not complaining? We need to make our settings stricter)

Contributor Author

Maybe because the function is private?
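
For context, and as far as I know: mypy only reports unannotated parameters when `disallow_untyped_defs` or `disallow_incomplete_defs` is enabled; otherwise an unannotated parameter is silently treated as `Any`, for private and public functions alike. Whether those flags are set in this repository is not stated in this thread. A sketch of an annotated predicate parameter follows, using a hypothetical string-based function rather than the real `_transform_by_pattern` signature:

```python
from typing import Callable, List, Optional


def transform_by_pattern(
    stmt: str, predicate: Callable[[str], bool]
) -> Optional[List[str]]:
    """Return replacement statements if `predicate` matches, else None."""
    if not predicate(stmt):
        return None
    # Hypothetical rewrite: split the statement into a producer and a consumer.
    return [f"tmp = {stmt}", "use(tmp)"]


matched = transform_by_pattern("a + b", lambda s: "+" in s)
unmatched = transform_by_pattern("a", lambda s: "+" in s)
```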

for tmp_sym, tmp_expr in extracted_fields.items():
    domain = tmp_expr.annex.domain

    # TODO(tehrengruber): Implement. This happens when the expression for a combination
Contributor

Check this comment and rephrase it (what does Implement. mean here?)

Suggested change
- # TODO(tehrengruber): Implement. This happens when the expression for a combination
+ # TODO(tehrengruber): Implement. This happens when the expression is a combination

Contributor Author

@tehrengruber tehrengruber Oct 17, 2024

I'm not sure how or why to improve this comment. The task to do is to implement the case described in detail, whose conditions are expressed in the code that leads to a NotImplementedError thereafter.

src/gt4py/next/iterator/transforms/global_tmps.py (outdated, resolved)
Comment on lines 84 to 91
flattened_domains: tuple[domain_utils.SymbolicDomain] = (
    next_utils.flatten_nested_tuple(domain)  # type: ignore[assignment]  # mypy not smart enough
)
if not all(d == flattened_domains[0] for d in flattened_domains):
    raise NotImplementedError(
        "Tuple expressions with different domains is not " "supported yet."
    )
domain = flattened_domains[0]
Contributor

Optional suggestion for an alternative (cleaner?) implementation of this (assuming SymbolicDomain is a hashable frozen dataclass):

Suggested change
- flattened_domains: tuple[domain_utils.SymbolicDomain] = (
-     next_utils.flatten_nested_tuple(domain)  # type: ignore[assignment]  # mypy not smart enough
- )
- if not all(d == flattened_domains[0] for d in flattened_domains):
-     raise NotImplementedError(
-         "Tuple expressions with different domains is not " "supported yet."
-     )
- domain = flattened_domains[0]
+ flattened_domains: tuple[domain_utils.SymbolicDomain] = set(
+     next_utils.flatten_nested_tuple(domain)
+ )
+ domain = flattened_domains.pop()
+ if flattened_domains:
+     raise NotImplementedError(
+         "Tuple expressions with different domains is not " "supported yet."
+     )

Contributor Author

@tehrengruber tehrengruber Oct 17, 2024

SymbolicDomain contains a dict and is not hashable. I could make it hashable, but investing in that for a NotImplementedError path isn't worth it.
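
The hashability point can be reproduced with a toy class. `ToyDomain` is a hypothetical stand-in; the only property assumed from the thread is that the real `SymbolicDomain` holds a dict.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class ToyDomain:
    """Hypothetical stand-in: a domain that stores its ranges in a dict."""
    ranges: dict


def single_domain(domains: list) -> ToyDomain:
    """Equality-based check, mirroring the `all(...)` version in the diff."""
    first = domains[0]
    if not all(d == first for d in domains):
        raise NotImplementedError(
            "Tuple expressions with different domains are not supported yet."
        )
    return first


a = ToyDomain({"IDim": (0, 1)})
b = ToyDomain({"IDim": (0, 1)})

# A frozen dataclass derives __hash__ from its fields, so hashing fails as
# soon as one field is a dict. Building a set therefore raises TypeError.
try:
    set([a, b])
    set_based_check_works = True
except TypeError:
    set_based_check_works = False

domain = single_domain([a, b])  # only needs __eq__, which dicts support
```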

for transform in transforms:
    transformed_stmts = transform(stmt=stmt, declarations=declarations, uids=uids)
    if transformed_stmts:
        unprocessed_stmts = [*transformed_stmts, *unprocessed_stmts]
Contributor

Using a collections.deque for unprocessed_stmts could make this simpler and probably more efficient.

Suggested change
- unprocessed_stmts = [*transformed_stmts, *unprocessed_stmts]
+ unprocessed_stmts.appendleft(transformed_stmts)

Contributor Author

@tehrengruber tehrengruber Oct 17, 2024

A deque doesn't really fit since

  1. we only need a one-sided queue
  2. deque.extendleft reverses the order of transformed_stmts, which is wrong here. Manually reversing negates the simplicity again.

I can reverse the order of the list if you prefer that, but it feels like premature optimization.
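
The ordering behaviour discussed here can be checked with the standard library alone. The statement names are placeholders, not values from the pass.

```python
from collections import deque

unprocessed = ["c", "d"]
transformed = ["a", "b"]

# List version from the PR: transformed statements go to the front, in order.
as_list = [*transformed, *unprocessed]

# appendleft inserts the whole list as a single element.
dq_append = deque(unprocessed)
dq_append.appendleft(transformed)

# extendleft pushes items one by one, which reverses their relative order.
dq_extend = deque(unprocessed)
dq_extend.extendleft(transformed)

# Reversing first restores the intended order, at the cost of an extra step.
dq_fixed = deque(unprocessed)
dq_fixed.extendleft(reversed(transformed))

print(as_list)          # ['a', 'b', 'c', 'd']
print(list(dq_append))  # [['a', 'b'], 'c', 'd']
print(list(dq_extend))  # ['b', 'a', 'c', 'd']
print(list(dq_fixed))   # ['a', 'b', 'c', 'd']
```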

# The following example doesn't have a lift around the shift, i.e. temporary pass will not extract it.
# λ(inp) → scan(λ(state, k, kp) → state + ·k + ·kp, True, 0.0)(inp, ⟪Koffₒ, 1ₒ⟫(inp))`
ir = _inline_into_scan(ir)
# ruff: noqa: ERA001
Contributor

Why is this commented out and not deleted?

Contributor Author

Because the implementation is still missing and this is the template for it. It will be deleted in the next PR. Not sure why I commented it out, but doesn't matter.

otf_workflow__translation__lift_mode=transforms.LiftMode.USE_TEMPORARIES,
- otf_workflow__translation__temporary_extraction_heuristics=global_tmps.SimpleTemporaryExtractionHeuristics,
+ # otf_workflow__translation__temporary_extraction_heuristics=global_tmps.SimpleTemporaryExtractionHeuristics, # noqa: ERA001
Contributor

Why is this commented out?

Contributor Author

Because there are no heuristics right now, and I didn't want to remove it only to rewire / rewrite it again later.

UIDs.reset_sequence()
testee = ir.FencilDefinition(
id="f",
def program_factory(
Contributor

Is this really the only place where this factory is/should be used? It looks like a pretty general utility to me...

Contributor Author

Most tests actually only work on expressions, not programs, and since every test needs a slightly different version of this, making it a general utility is not so useful.

Contributor

@egparedes egparedes left a comment

LGTM

@tehrengruber tehrengruber merged commit eb0a0c1 into GridTools:main Oct 18, 2024
31 checks passed
3 participants