[WIP] mHC: Manifold-constrained Hyper Connection #1859
Draft
anhminhnguyenhoang wants to merge 68 commits into main
Conversation
… kernel (#1877)
* Refactor mHC kernel and wrapper to implement equations 14-18 with fused kernel
* improve comments
* Enhance documentation for mhc function: clarify equations, input/output shapes, and activation details
* Enhance documentation in test_mhc.py: clarify equations, input/output shapes, and activation details for mHC kernel tests
* Add _sinkhorn_knopp_log_domain_kernel to the fusion module
* Add logging and sync Sinkhorn-Knopp function for doubly stochastic matrices
* sync log-domain Sinkhorn-Knopp kernel for doubly stochastic matrix projection
* Improve logging in mhc function to include all alpha parameters
aiter/ops/triton/fusions/mhc.py
Outdated
else:
    assert out.shape == (M, N), f"Output shape mismatch: expected ({M}, {N}), got {out.shape}"
    assert out.dtype == x.dtype, f"Output dtype mismatch: expected {x.dtype}, got {out.dtype}"
    assert out.device == x.device, f"Output device mismatch"
Contributor
…plified comments for sinkhorn-knopp impl
# Res-stream: no constraints (identity activation)
# Just verify it exists
assert out_res.shape == (M, n_squared), f"Res-stream shape mismatch"
Contributor
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

from .mhc_ref import *
Contributor
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

from .mhc_ref import *
from .mla_decode_ref import *
Contributor
from .mhc_ref import *
from .mla_decode_ref import *
from .mla_extend_ref import *
Contributor
from .mhc_ref import *
from .mla_decode_ref import *
from .mla_extend_ref import *
from .rotary_embedding import *
Contributor
- H^res: [2n:2n+n²] residual connection (identity) (n² elements)
"""
x_f32 = x.to(torch.float32)
nC = x.shape[1]
Contributor
H_tilde = x_norm @ phi_f32

# Split into three streams
n_squared = n * n
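The stream layout named in the docstring can be sketched in plain Python. This is only a sketch under assumptions: the helper `split_streams` is hypothetical, the `[0:n]` and `[n:2n]` ranges for the pre and post streams are inferred from the `[2n:2n+n²]` offset given for H^res, and the row-major reshape of the residual block is a guess. The PR itself works on torch tensors rather than lists.

```python
def split_streams(h_row, n):
    """Split one row of H_tilde (length 2n + n^2) into pre, post, and residual parts.

    Hypothetical helper for illustration; not part of the PR.
    """
    n_squared = n * n
    expected = 2 * n + n_squared
    if len(h_row) != expected:
        raise ValueError(f"expected {expected} elements, got {len(h_row)}")

    h_pre = h_row[:n]                            # assumed range [0:n]
    h_post = h_row[n:2 * n]                      # assumed range [n:2n]
    h_res_flat = h_row[2 * n:2 * n + n_squared]  # range [2n:2n+n^2] from the docstring

    # Assumption: the n^2 residual elements are laid out row by row.
    h_res = [h_res_flat[i * n:(i + 1) * n] for i in range(n)]
    return h_pre, h_post, h_res


pre, post, res = split_streams(list(range(8)), n=2)
print(pre, post, res)
```

With `n = 2` the row has `2*2 + 4 = 8` elements, so the residual block is a 2x2 matrix.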
Contributor
* Refactor mHC kernel and wrapper to implement equations 14-18 with fused kernel
* improve comments
* Enhance documentation for mhc function: clarify equations, input/output shapes, and activation details
* Enhance documentation in test_mhc.py: clarify equations, input/output shapes, and activation details for mHC kernel tests
* Add _sinkhorn_knopp_log_domain_kernel to the fusion module
* Add logging and sync Sinkhorn-Knopp function for doubly stochastic matrices
* sync log-domain Sinkhorn-Knopp kernel for doubly stochastic matrix projection
* Improve logging in mhc function to include all alpha parameters
* Fix H dimensions
* Refactor mHC function to return separate output tensors for pre, post, and residual streams
* Refactor mhc_torch to return separate output tensors for pre, post, and residual streams
* Adjust tolerance for is_doubly_stochastic assertion in test_sk_matrix_sizes for bfloat16 precision
…ke H_res doubly stochastic
Author
anhminhnguyenhoang
left a comment
Looks good. I would personally clean up the comments, as they look a bit redundant.
  H_res_torch.to(torch.float32),
- atol=1e-2,
- rtol=1e-2,
+ atol=5e-2,
Author
Did you have to relax the tolerance because this test, or similar ones, failed at the tighter setting?
Yes, mainly because of Sinkhorn, which is an iterative process and returns larger differences with only 10 iterations. Maybe we can try 20 for better results?
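To make the iteration-count question concrete, here is a minimal pure-Python sketch of log-domain Sinkhorn-Knopp that compares how far the result is from doubly stochastic after 10 and 20 iterations. It is not the PR's Triton kernel: the function names, the example logits, and the row-then-column update order are all assumptions for illustration, and it runs in float64, so it does not reproduce any bfloat16 rounding effects.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v))) for a list of floats."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def sinkhorn_log_domain(logits, num_iters):
    """Alternately normalize rows and columns of exp(logits), working in log space."""
    n = len(logits)
    log_p = [row[:] for row in logits]
    for _ in range(num_iters):
        # Row step: subtract each row's log-sum-exp so every row sums to 1.
        for i in range(n):
            lse = logsumexp(log_p[i])
            log_p[i] = [v - lse for v in log_p[i]]
        # Column step: subtract each column's log-sum-exp so every column sums to 1.
        for j in range(n):
            lse = logsumexp([log_p[i][j] for i in range(n)])
            for i in range(n):
                log_p[i][j] -= lse
    return [[math.exp(v) for v in row] for row in log_p]


def max_marginal_error(matrix):
    """Largest deviation of any row sum or column sum from 1."""
    n = len(matrix)
    row_err = max(abs(sum(row) - 1.0) for row in matrix)
    col_err = max(abs(sum(matrix[i][j] for i in range(n)) - 1.0) for j in range(n))
    return max(row_err, col_err)


# Arbitrary example logits (hypothetical, not taken from the PR's tests).
logits = [
    [2.0, -1.0, 0.5, 0.0],
    [0.3, 1.5, -2.0, 0.7],
    [-0.5, 0.2, 1.0, -1.2],
    [1.1, -0.3, 0.4, 2.2],
]
err_10 = max_marginal_error(sinkhorn_log_domain(logits, 10))
err_20 = max_marginal_error(sinkhorn_log_domain(logits, 20))
print(f"10 iterations: {err_10:.2e}, 20 iterations: {err_20:.2e}")
```

A check like this only measures convergence of the projection itself. Whether 20 iterations would let the test tolerance go back to `1e-2` also depends on the kernel's precision, which this sketch does not model.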
…o pre, post, and residual streams; update tests accordingly.
…ccuracy of assertions.
- Update benchmark script to use dynamic configurations
…dd log for reference
…d parameters tuning for better performance on MI355X (gfx950)
This reverts commit 33ded07.
Changes:
- Fix BLOCK_N calculation: Use max(n_factorial, n_squared) instead of conditionally selecting based on hres_op flag. This ensures BLOCK_N is always large enough for both intermediate computation (n! values for softmax over permutations) and final output (n² doubly stochastic matrix).
- Rename n_factorial to n_res_expected for clarity, since this variable holds factorial(n) in lite mode but n_squared in sinkhorn mode. The new name better reflects its purpose as the expected dimension of the residual stream output.
- Remove redundant conditional in acc_res_cols assignment since the mode-specific logic is already handled in n_res_expected definition.
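The sizing rule in the first bullet can be sketched as follows. This is an illustration under assumptions, not the PR's code: `pick_block_n` and the boolean `lite_mode` are hypothetical names (the real selector is the `hres_op` flag, whose values are not shown here), and the rounding up to a power of two is an added assumption motivated by Triton's `tl.arange` requiring power-of-two sizes.

```python
import math


def next_power_of_2(value):
    """Smallest power of two that is >= value (for value >= 1)."""
    return 1 if value <= 1 else 1 << (value - 1).bit_length()


def pick_block_n(n, lite_mode):
    """Return (n_res_expected, block_n) for n streams.

    Hypothetical helper: `lite_mode` stands in for the PR's hres_op flag.
    """
    n_squared = n * n
    n_factorial = math.factorial(n)

    # Per the commit message: n! in lite mode, n^2 in Sinkhorn mode.
    n_res_expected = n_factorial if lite_mode else n_squared

    # BLOCK_N covers both the n! intermediate values and the n^2 output,
    # regardless of mode. Rounding to a power of two is an assumption.
    block_n = next_power_of_2(max(n_factorial, n_squared))
    return n_res_expected, block_n


for n in (2, 3, 4, 5):
    print(n, pick_block_n(n, lite_mode=True), pick_block_n(n, lite_mode=False))
```

Neither quantity dominates for all `n`: `n²` is larger for `n = 2` and `n = 3`, while `n!` is larger from `n = 4` onward, which is why taking the maximum is the safe choice.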
Co-authors: @waqahmed-amd-fi @anhminhnguyenhoang
Motivation
Technical Details
Test Plan
Test Result
Submission Checklist