Output stride order #2548
Conversation
…oid permutation info lost in cache hit
Looks good
    std::vector<int64_t> stride_order) {
  FUSER_PERF_SCOPE("FusionDefinition.add_output (tensor)");
  TORCH_CHECK(
      !self.id().has_value(),
Off-topic to this PR, but self.completed() exists for these.
  TORCH_CHECK(
      duplicate_check == (1 << reverse_perm.size()) - 1,
      "duplicated elements in stride_order detected!");
  tv_output = permute(tv_output, reverse_perm);
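The duplicate check in the diff works by OR-ing a one-hot bit per entry of the permutation and comparing the accumulated mask against the full mask `(1 << n) - 1`. A standalone sketch of the same trick (function name is illustrative, not the actual PR code):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Returns true iff `perm` contains each of 0..perm.size()-1 exactly once.
// Mirrors the bitmask check in the diff: set bit perm[i] for each entry;
// duplicates set the same bit twice, leaving some other bit unset, so the
// mask equals (1 << n) - 1 exactly when `perm` is a valid permutation.
bool isValidPermutation(const std::vector<int64_t>& perm) {
  int64_t mask = 0;
  for (int64_t v : perm) {
    if (v < 0 || v >= static_cast<int64_t>(perm.size())) {
      return false; // out-of-range entry can never form a permutation
    }
    mask |= int64_t(1) << v;
  }
  return mask == (int64_t(1) << perm.size()) - 1;
}
```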
I could be wrong, but I think this call is the first time the length of the provided perm is checked to equal the ndim of the other argument. It might be nice to check that right up front.
Good point. I had it in python_bindings.cpp earlier and removed it as duplication. But I think it makes sense to throw an error up front; I'll add one down there.
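Such an up-front validation could look like the following sketch (hypothetical helper, not the code actually added in the PR):

```cpp
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical sketch: reject a stride_order whose length does not match
// the output's ndim before the permutation is used anywhere downstream.
// An empty stride_order is treated as "not specified" and allowed through.
void checkStrideOrderLength(
    const std::vector<int64_t>& stride_order,
    size_t ndim) {
  if (!stride_order.empty() && stride_order.size() != ndim) {
    throw std::runtime_error(
        "stride_order needs to have " + std::to_string(ndim) +
        " entries, but got " + std::to_string(stride_order.size()));
  }
}
```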
LGTM. For serialization, we just need to add the stride_order field for OutputRecord.
Review comments have been addressed. Running CI locally and will merge the PR afterwards.
I am seeing failing CIs, but I don't think they are relevant. I'm merging this one.
@zasdfgbnm @jacobhinkle could we revisit this approach now that we have allocation domains?
Warning: this is the csarofeen/pytorch repo.
Added new python API fd.ops.add_output(tensor, stride_order), where stride_order means that output axis i is the stride_order[i]-th fastest dimension. E.g., if we want to specify the output to be in channels-last format, we should specify fd.ops.add_output(tensor_view, [0, 3, 1, 2]), where a given output with shape [N, C, H, W] will have stride [H*W*C, 1, W*C, C].
Implementation details:
It's currently done in a naive way. Since nvfuser doesn't support user-specified stride order yet, we fake it by: