Avoid input cache when resized #28
Conversation
}
producer->cacheAfter();
cached.insert(producer);
// Resized tensors are those created by operations like pad and
This is an unrelated change. Instead of using cacheAfter, just create a copy of producer and use it as the input to the resize expr.
Example:
tv1 = sum(tv0)
tv2 = some_resize_op(tv1);
tv3 = some_other_op(tv1);
When tv1 is promoted to Global, we want to avoid reducing to a global memory tensor, so with cacheAfter:
tv1 = sum(tv0);
tv4 = tv1
tv4->setMemoryType(Global)
tv2 = some_resize_op(tv4)
tv3 = some_other_op(tv4);
This way, the reduction is done using Local, but some_other_op doesn't actually need to use the gmem copy of tv1. This should be just fine:
tv1 = sum(tv0);
tv4 = tv1
tv4->setMemoryType(Global)
tv2 = some_resize_op(tv4)
tv3 = some_other_op(tv1);
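A minimal sketch of that alternative (assumptions: producer and resize_expr name the tensors in the surrounding pass, set() is the identity-copy op, and a helper like ir_utils::replaceValInExprInputs exists to rewire a single expression input; the exact helper name and signature may differ across nvFuser versions):
// Sketch only: promote just the resize path instead of calling cacheAfter().
TensorView* gmem_copy = set(producer);        // gmem_copy = producer
gmem_copy->setMemoryType(MemoryType::Global); // promote only the copy
// Rewire just the resize expression to read the Global copy; every other
// consumer keeps reading the original (Local) producer.
ir_utils::replaceValInExprInputs(resize_expr, producer, gmem_copy);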
Is this just cacheFork?
Don't think so. There's some similarity, though here we don't change fusion outputs.
I think we can extend the interface of cacheFork to do this?
If we have:
tv1 = sum(tv0)
tv2 = some_resize_op(tv1);
tv3 = some_other_op(tv1);
Can we just do
tv1->setMemoryType(Global);
tv4 = tv1->cacheFork(/*keep_global=*/{tv2->definition()})
And we get:
tv0 -> tv4 (local) -> tv1 -> tv2
tv4 (fork) -> tv3
Just an idea; I have no strong opinion on doing it this way.
cacheFork is also designed to change Fusion outputs, which shouldn't be done in this case. We could make that optional, of course, but this is a simple transformation, so I don't feel it's worth consolidating into an extended cacheFork.
cached.insert(producer);
// Resized tensors are those created by operations like pad and
// slice. If it has no defining expression, it must be a fusion
// input, and no need of the memory type promotion
Can a fusion input have a resize rfactor? I thought fusion inputs never have an rfactor domain.
Yes. When a fusion is segmented, a resized tensor may be an output of one segment and an input to the next segment.
C++ test copied from #27
If the producer of a resized tensor is an input cache and we need to promote it to global memory, don't cache the input at all; instead, read directly from the global memory input. It doesn't make any sense to cache a global memory input in global memory.
This fixes #27, which is caused by the grid sync inserted for an input to pad when that input is a copy of a fusion input.
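A rough sketch of that logic (illustrative only: it assumes an input cache shows up as a LoadStoreOp copy of a fusion input and that a helper like ir_utils::replaceValInExprInputs is available to rewire an expression input; this is not the exact code of the PR):
// If the producer of the resized tensor is just a cache of a fusion input,
// read the fusion input directly instead of promoting the cache.
auto* copy_def = dynamic_cast<LoadStoreOp*>(producer->definition());
if (copy_def != nullptr && copy_def->in()->isFusionInput()) {
  // The fusion input already lives in global memory; re-caching it in global
  // memory would only add a grid sync before the pad.
  ir_utils::replaceValInExprInputs(resize_expr, producer, copy_def->in());
} else {
  // Otherwise keep the existing behavior: cache the producer and promote the
  // cache to global memory.
  auto* cache = producer->cacheAfter();
  cache->setMemoryType(MemoryType::Global);
}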