Output stride order #2548

Merged 26 commits into devel from output_stride_order on Mar 7, 2023
Conversation

@jjsjann123 (Collaborator) commented Mar 7, 2023:

Added a new Python API, fd.ops.add_output(tensor, stride_order), where stride_order[i] gives the position of output axis i in the memory layout, from the outermost (slowest-varying) dimension at position 0 to the innermost (fastest-varying) dimension.

e.g. to request a channels-last output, specify fd.ops.add_output(tensor_view, [0, 3, 1, 2]); an output with shape [N, C, H, W] will then have strides [H*W*C, 1, W*C, C].
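To make the mapping concrete, here is a small sketch (a hypothetical helper, not part of the nvFuser API) that computes the strides implied by a stride_order under the convention described above:

```python
# Hypothetical helper illustrating how stride_order maps to strides:
# stride_order[i] is axis i's position in memory, 0 = outermost
# (slowest-varying) .. ndim-1 = innermost (fastest-varying).
def strides_from_order(shape, stride_order):
    ndim = len(shape)
    # memory_order[p] = which axis is stored at position p (outermost first)
    memory_order = [None] * ndim
    for axis, pos in enumerate(stride_order):
        memory_order[pos] = axis
    strides = [0] * ndim
    running = 1
    for pos in reversed(range(ndim)):  # walk from innermost outward
        axis = memory_order[pos]
        strides[axis] = running
        running *= shape[axis]
    return strides

N, C, H, W = 2, 3, 4, 5
# the channels-last request from the PR description:
assert strides_from_order([N, C, H, W], [0, 3, 1, 2]) == [H * W * C, 1, W * C, C]
```

This reproduces the [H*W*C, 1, W*C, C] strides from the example above for any concrete N, C, H, W.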

Implementation details:
It's currently done in a naive way. Since nvFuser doesn't support a user-specified stride order yet, we fake it by:

  1. adding a permute op on outputs inside the generated kernel, to ensure that the output is stored in the desired memory layout;
  2. after the kernel has executed, permuting the corresponding output to undo the permutation inside the kernel, which gives us a semantically correct output in the desired memory layout.
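The two steps above can be sketched in NumPy (an illustration of the idea only, not the actual nvFuser code path), using the channels-last example from the description:

```python
import numpy as np

# Step-by-step illustration of the permute + inverse-permute trick.
x = np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5)  # NCHW

# Step 1: inside the "kernel", permute NCHW -> NHWC and store contiguously,
# so the bytes land in channels-last order.
kernel_out = np.ascontiguousarray(x.transpose(0, 2, 3, 1))

# Step 2: after execution, undo the permutation (NHWC -> NCHW) as a view,
# so the caller sees an NCHW-shaped tensor with channels-last strides.
result = kernel_out.transpose(0, 3, 1, 2)

assert result.shape == (2, 3, 4, 5)
assert np.array_equal(result, x)  # values are semantically unchanged
# strides in bytes (float32 = 4 bytes): [H*W*C, 1, W*C, C] * 4
assert result.strides == (240, 4, 60, 12)
```

The transpose in step 2 is the inverse of the one in step 1, so the values are unchanged; only the underlying memory layout differs.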

@jacobhinkle (Collaborator) left a comment:

Looks good

    std::vector<int64_t> stride_order) {
      FUSER_PERF_SCOPE("FusionDefinition.add_output (tensor)");
      TORCH_CHECK(
          !self.id().has_value(),
A collaborator commented:

Off-topic to this PR, but self.completed() exists for these.

    TORCH_CHECK(
        duplicate_check == (1 << reverse_perm.size()) - 1,
        "duplicated elements in stride_order detected!");
    tv_output = permute(tv_output, reverse_perm);
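For readers unfamiliar with the bitmask idiom in the quoted check, here is a hypothetical Python rendering of the same logic: OR together one bit per entry of the permutation; the mask equals 2^n - 1 exactly when each of 0..n-1 appears once.

```python
# Hypothetical Python rendering of the C++ duplicate check above:
# a valid permutation of 0..n-1 sets exactly the low n bits,
# so the accumulated mask must equal (1 << n) - 1.
def has_duplicates(stride_order):
    duplicate_check = 0
    for v in stride_order:
        duplicate_check |= 1 << v
    return duplicate_check != (1 << len(stride_order)) - 1

assert not has_duplicates([0, 3, 1, 2])  # valid permutation
assert has_duplicates([0, 3, 1, 1])      # position 1 repeated
```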
A collaborator commented:

I could be wrong, but I think this call is the first time the length of the provided permutation is checked to equal the ndim of the other argument. It might be nice to check that right up front.

@jjsjann123 (Collaborator, Author) replied:

Good point. I had it in python_bindings.cpp earlier and removed it as a duplicate.

But I think it makes sense to throw an error up front. I'll add one down there.

@rdspring1 (Collaborator) left a comment:

LGTM.

For serialization, we just need to add the stride_order field for OutputRecord.

@jjsjann123 (Collaborator, Author) commented:

Review comments have been addressed. I'm running CI locally and will merge the PR afterwards.

@jjsjann123 (Collaborator, Author) commented:

I am seeing failing CIs, but I don't think they are related to this change. I'm merging this one.

@jjsjann123 jjsjann123 merged commit 8ed9540 into devel Mar 7, 2023
@jjsjann123 jjsjann123 deleted the output_stride_order branch March 7, 2023 21:15
@csarofeen (Owner) commented:

@zasdfgbnm @jacobhinkle could we revisit this approach now that we have allocation domains?

@csarofeen (Owner) commented:

Warning: this is the csarofeen/pytorch repo.
