New TransformPropagator algorithm #1763
Conversation
Marking as ready for review. Tests are all green. 😎 Feel free to play around, and I will do some cleanup and add more comments.
I am done adding docs.
Just some superficial suggestions for now
// nullptr used to start from starting_tv
return next_hop.to->nDims();
}
// TODO: why does TransformReplay require specifying a position in the leaf domain?
I think that's because TransformReplay was designed for computeAt, which takes a position in the leaf domain. TransformPropagator does not take such a position parameter, but that may be something we would want eventually?
@@ -4,7 +4,6 @@
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_base_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_interface_nodes.h>
Just curious, was this causing any problem?
No, it worked fine, though I don't know why. It looks weird to me that two headers include each other.
// I think I need to modify TransformReplay to add a new interface to specify
// the root domains, instead of a position in the leaf domain. With the new
// interface, this function will not be needed.
This is due to a mismatch between TransformReplay and TransformPropagator. The former replays the first N leaf IDs, whereas the latter replays everything from the starting reference tensor. I suspect we would want the behavior of TransformReplay, but I'm not sure.
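To make that contrast concrete, here is a toy sketch under simplified, hypothetical types (LeafId, selectFirstN, and selectFromReference are made-up names, not nvFuser APIs): TransformReplay-style selection takes a prefix of the leaf domain, while TransformPropagator-style selection takes every leaf ID that carries information from the reference tensor's roots.

```cpp
#include <algorithm>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for a leaf IterDomain; the flag is assumed to have been
// computed elsewhere (e.g. by walking back to the root domain).
struct LeafId {
  std::string name;
  bool from_reference_root;
};

// TransformReplay-style selection: replay the first n leaf IDs.
std::vector<LeafId> selectFirstN(const std::vector<LeafId>& leaves, std::size_t n) {
  n = std::min(n, leaves.size());
  return std::vector<LeafId>(
      leaves.begin(), leaves.begin() + static_cast<std::ptrdiff_t>(n));
}

// TransformPropagator-style selection: replay everything that carries
// information from the starting reference tensor, wherever it sits.
std::vector<LeafId> selectFromReference(const std::vector<LeafId>& leaves) {
  std::vector<LeafId> out;
  for (const LeafId& id : leaves) {
    if (id.from_reference_root) {
      out.push_back(id);
    }
  }
  return out;
}
```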
// Find the pos where all leaf IDs at <= pos contain
// information about the starting root domain
//
// TODO: should I change to the following behavior?
Can you be more specific?
I added a few more lines in the comment below
Can you create a test exhibiting this behavior?
Looks like this will never happen? I added a new test at 5cc10c8, but it does not seem to be triggered. I think this is because, if from is the reference tensor, then we always have full information. If not, then from must come from a replay, and the axes containing reference tensor information are always placed at the front. So this case will never happen?
I will merge this PR for now and write a follow-up PR for TransformReplay starting from a specified root domain. Once that is done, this issue will no longer be relevant. Feel free to leave more comments, and I will resolve them in my follow-up PR.
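For what it's worth, here is a minimal sketch of the position computation being discussed, assuming the invariant argued above (leaf IDs carrying reference-tensor information are placed at the front after a replay). The names are made up for illustration, not taken from the actual code.

```cpp
#include <cstddef>
#include <vector>

// leaf_has_reference_info[i] is assumed to say whether leaf ID i carries
// information about the starting root domain. Under the front-placement
// invariant, the pos is simply the length of the leading run of such IDs.
std::size_t referenceInfoPos(const std::vector<bool>& leaf_has_reference_info) {
  std::size_t pos = 0;
  while (pos < leaf_has_reference_info.size() && leaf_has_reference_info[pos]) {
    ++pos;
  }
  return pos;
}
```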
I'm not sure we can rewrite TransformReplay to use root domain positions rather than leaf domain positions. In other words, computeAt currently is specified with leaf positions. Can we change that to use root positions?
For the meantime, we may want to assert that no remaining domains are included in relevant_leaves.
> I'm not sure we can rewrite TransformReplay to use root domain positions rather than leaf domain positions. In other words, computeAt currently is specified with leaf positions. Can we change that to use root positions?
I think the interface and outcome of computeAt should still use leaf positions. But during propagation, since we are saving information about root/rfactor domains, I think it makes sense to change the interface (or at least add an additional interface) to specify root/rfactor domains. I am not sure how doable this is; I need to dig into the code to see.
> For the meantime, we may want to assert that no remaining domains are included in relevant_leaves.
I think it makes sense to add an assert. If we decide that we should not change TransformReplay to specify the root domain, then I will add the assert in a separate PR.
We should talk more about this. Root domains are a good mechanism to understand "how replayed" one replay is versus another, but I'm skeptical that propagating based on root domains is a good idea.
Co-authored-by: Naoya Maruyama <naoyam@users.noreply.github.com>
LGTM!
continue;
}
for (auto root_id : root_ids) {
if (id == root_id || DependencyCheck::isDependencyOf(root_id, id)) {
So this is giving full credit to an rfactor domain that could potentially only partially come from the root. I wonder how safe this is through complex uses of view. Are there instances where we would have to accurately track "partial" ownership of an rfactor domain with view?
That's an interesting question. I don't know if there is an actual adverse case. It also seems to be a question of how propagation should be done: should transformations be propagated through partial ownership? I'm not sure which should be preferred.
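To illustrate the "full credit" behavior in question, here is a toy, self-contained example (Id, dependsOn, and deps are hypothetical names, not nvFuser code): an rfactor ID produced by merging two root IDs is counted as carrying reference information as soon as just one of its root inputs is preserved.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <vector>

using Id = std::string;

// deps maps an ID to the IDs it was derived from (e.g. a view that merges roots).
bool dependsOn(const std::map<Id, std::vector<Id>>& deps, const Id& root, const Id& id) {
  auto it = deps.find(id);
  if (it == deps.end()) {
    return false;
  }
  for (const Id& input : it->second) {
    if (input == root || dependsOn(deps, root, input)) {
      return true;
    }
  }
  return false;
}

int main() {
  // f = merge(r0, r1), as a view might produce.
  std::map<Id, std::vector<Id>> deps = {{"f", {"r0", "r1"}}};
  std::set<Id> preserved_roots = {"r0"};  // only r0 carries reference info

  // Mirrors the shape of the check in the snippet above: "f" gets full credit
  // even though "r1" also contributes to it.
  bool f_gets_credit = false;
  for (const Id& root : preserved_roots) {
    if (root == "f" || dependsOn(deps, root, "f")) {
      f_gets_credit = true;
    }
  }
  assert(f_gets_credit);
  return 0;
}
```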
continue;
}
for (auto rfactor_id : rfactor_ids) {
if (DependencyCheck::isDependencyOf(id, rfactor_id)) {
Same as above.
Really cool algorithm. The only comments I really have: I think it makes sense to update the compute-at PR with the new algorithm, fix up our current scheduling, then come back and revisit for view, as we want view support for after the next PyTorch release.
Agree with Christian.
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes include:
- TransformPropagator refactor: switched to Dijkstra instead of exhaustive enumeration on all possible paths to reduce compilation time on transform propagation;
- Indexing refactor: remove reference tensor creation in all tensor indexing logic (csarofeen#1690);
- (more) generic grouped grid reduction kernel;
- Minor parser/fuser patches:
  1. zero-dim tensor reduction support
  3. no-op binary removal within fused graph
  4. expand supported in fusion

Squashed commits to WAR github API. Commits that are actually in this PR from the devel branch:
```
a054b3e Refactor TransormPropagator to allow specifying a position and propagating to part of the DAG (csarofeen#1775)
d67e1cd Indexing refactor stage 1: remove reference tensor creation in all tensor indexing logic (csarofeen#1690)
1b65299 Issue 1770 (csarofeen#1774)
35b0427 Avoid compilation errors like below: (csarofeen#1773)
452c773 Ignore reductions of zero-dim tensors per PyTorch conventions (csarofeen#1771)
31d6c56 TransformPropagator refactor (csarofeen#1769)
570c5a8 Merge pull request csarofeen#1767 from csarofeen/upstream_merge_0621
9d6c3d8 merging upstream 61305cd
0ed815f New TransformPropagator algorithm (csarofeen#1763)
6c19520 no-op binary removal (csarofeen#1764)
ec7fa41 Proper propagation of IterType (csarofeen#1762)
b263562 Fix dimensionality check (csarofeen#1759)
2d6343f More generic grouped grid reduction kernel (csarofeen#1740)
64e2b56 [nvfuser] prevent spamming warning message (pytorch#77777) (csarofeen#1758)
0c43162 [nvFuser] Improving bitwise ops support (pytorch#77158) (csarofeen#1757)
b93a147 Parser expand (csarofeen#1754)
```
RUN_TORCHBENCH: nvfuser

Pull Request resolved: pytorch#80355
Approved by: https://github.com/davidberard98
Fixes #1760, but goes far beyond that.

Per offline discussion with @csarofeen and @naoyam, I completely rewrote the TransformPropagator. In this new TransformPropagator, I explicitly keep track of which root IDs in the starting tensor are preserved; the RootIDInfo stores that information for each root ID. view is not treated differently from other ops. During propagation, I run Dijkstra to find, for each tensor in the graph, the path that preserves the most information. Each tensor is only replayed once.
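As a rough, self-contained sketch of the Dijkstra idea described above (not the actual nvFuser implementation; Hop, propagateBestFirst, and the single integer information score are hypothetical simplifications of the per-root-ID tracking in RootIDInfo): each tensor is settled exactly once, along the neighboring path that loses the least reference-root information.

```cpp
#include <algorithm>
#include <queue>
#include <unordered_map>
#include <utility>
#include <vector>

// A candidate hop from one tensor to a producer/consumer neighbor. `preserved`
// is assumed to be precomputed: how much reference root-ID information this hop
// can keep (the real code tracks per-root-ID flags, not a single number).
struct Hop {
  int to;
  int preserved;
};

// Dijkstra-style best-first propagation: starting from the reference tensor,
// settle each tensor exactly once along the path that keeps the most
// information. `graph[t]` lists the hops reachable from tensor t.
std::unordered_map<int, int> propagateBestFirst(
    const std::unordered_map<int, std::vector<Hop>>& graph,
    int start,
    int full_info) {
  std::unordered_map<int, int> best_info;  // settled tensors and their scores
  using Entry = std::pair<int, int>;       // (info, tensor), max-heap by info
  std::priority_queue<Entry> frontier;
  frontier.push({full_info, start});

  while (!frontier.empty()) {
    auto [info, tv] = frontier.top();
    frontier.pop();
    if (best_info.count(tv)) {
      continue;  // already settled (and "replayed") along a better path
    }
    best_info[tv] = info;  // in the real pass, this is where tv would be replayed
    auto it = graph.find(tv);
    if (it == graph.end()) {
      continue;
    }
    for (const Hop& hop : it->second) {
      // Information can only be lost along a hop, never regained.
      int next_info = std::min(info, hop.preserved);
      if (!best_info.count(hop.to)) {
        frontier.push({next_info, hop.to});
      }
    }
  }
  return best_info;
}
```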