Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reduce transpose operations to speedup #589

Merged
merged 4 commits into from
Feb 3, 2024
Merged

Conversation

hjchen2
Copy link
Contributor

@hjchen2 hjchen2 commented Jan 30, 2024

No description provided.

@hjchen2 hjchen2 requested a review from lixiang007666 January 30, 2024 06:26
@@ -450,7 +450,9 @@ def forward(
# hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
# Rewrite for onediff SVD dynamic shape
hidden_states = (
hidden_states.permute(0, 2, 1).reshape_as(hidden_states_in).contiguous()
hidden_states.reshape_as(residual.permute(0, 2, 3, 1))
.permute(0, 3, 1, 2)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reshape后面的permute得保留,避免破坏NCHW->NHWC转换后permute消除的优化。

@@ -382,9 +382,9 @@ def forward(
# height * width, batch_size, 1, time_context.shape[-1]
# )
# Rewrite for onediff SVD dynamic shape
broadcast_tensor = hidden_states.flatten(2, 3).permute(2, 0, 1)
broadcast_tensor = hidden_states.flatten(2, 3)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里减少一个permute,尽管还是多了一次flatten,但flatten支持view,不会产生实际的kernel计算开销

@hjchen2 hjchen2 merged commit 814053b into main Feb 3, 2024
4 of 5 checks passed
@hjchen2 hjchen2 deleted the dev_speedup_svd_dynamic_shape branch February 3, 2024 17:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants