[Fix] fix the mha fwd_v3 segment fault in torch.compile(mode="reduce-overhead", fullgraph=True) #1794

Merged: valarLip merged 5 commits into main from mmd/fix/torchcompile on Jan 13, 2026

minmengdie (Contributor) commented on Jan 8, 2026

Motivation

Fix the segmentation fault in mha fwd_v3 that occurs under torch.compile(mode="reduce-overhead", fullgraph=True).

Technical Details

  1. Changed the capture of args from by-reference to by-value, so that args is not destroyed before it is used.
  2. Added an explicit assignment to force the evaluation order and prevent the compiler from reordering operations in a way that could access uninitialized args.
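
The two points above can be illustrated with a minimal C++ sketch. It is not the code from csrc/cpp_itfs/mha_fwd.cpp: `FwdArgs`, `launch`, `make_deferred_launch`, and `run_later` are hypothetical names chosen for illustration, and the sketch assumes that the args were captured by a callable that runs after the enclosing scope has ended.

```cpp
#include <cassert>
#include <functional>
#include <string>

// Hypothetical stand-in for the kernel argument struct (not the real aiter type).
struct FwdArgs {
    int batch;
    int heads;
    std::string tag;
};

// Hypothetical stand-in for a kernel launch; returns a value that is easy to check.
inline int launch(const FwdArgs& a) {
    assert(a.batch > 0 && a.heads > 0);
    return a.batch * a.heads;
}

// Builds a callable that launches later.
// Step 1: `args` is assigned to a named local, so it is fully constructed
//         before the lambda below is created (explicit sequencing).
// Step 2: the lambda captures `args` by value, so it owns its own copy.
// With [&args] instead, the returned callable would hold a dangling reference,
// because the local `args` is destroyed when this function returns.
inline std::function<int()> make_deferred_launch(int batch, int heads) {
    FwdArgs args{batch, heads, "fwd_v3"};
    return [args]() { return launch(args); };
}

// Invokes the callable after the original `args` has gone out of scope.
inline int run_later(int batch, int heads) {
    std::function<int()> task = make_deferred_launch(batch, heads);
    return task();
}
```

Calling `run_later(2, 8)` invokes the callable after `make_deferred_launch` has returned, which is safe only because the lambda holds a copy. The cost is one copy of the argument struct per callable, which is usually small next to a kernel launch.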

Test Plan

cd /root/rtfm-amd
source ./.venv/bin/activate
python -m wlt.models.rtfm.demo.server.server --model-config rtfm_10_06_512_chonky_meanflow --port 8083

Test Result

(Screenshot of the test result; image not preserved.)

Submission Checklist

Copilot AI (Contributor) left a comment

Pull request overview

This PR adds thread safety improvements via a mutex and temporary debug logging to the MHA (Multi-Head Attention) forward pass implementation, while also adjusting test configurations.

Key Changes:

  • Added mutex protection for thread-safe access to the kernel implementation pointer map in C++ code
  • Added extensive debug logging statements for troubleshooting purposes
  • Modified test parameters (GQA head count and test configuration dimensions)
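
The first bullet describes guarding a kernel pointer map with a mutex. The pattern can be sketched as follows; this is an illustrative assumption rather than the reviewed code, and `KernelImpl` and `KernelCache` are hypothetical names. The overview was written against an early revision of the PR, so the merged change may differ.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for a loaded kernel implementation.
struct KernelImpl {
    std::string name;
};

// Hypothetical cache mapping kernel names to implementations.
// Every access to the map goes through the same mutex, so concurrent
// callers cannot observe the map mid-insertion.
class KernelCache {
public:
    // Returns the cached implementation for `name`, creating it on first use.
    KernelImpl* get_or_create(const std::string& name) {
        std::lock_guard<std::mutex> lock(mutex_);
        auto it = impls_.find(name);
        if (it == impls_.end()) {
            it = impls_.emplace(name, std::make_unique<KernelImpl>(KernelImpl{name})).first;
        }
        return it->second.get();
    }

    // Number of cached implementations; also taken under the lock.
    std::size_t size() const {
        std::lock_guard<std::mutex> lock(mutex_);
        return impls_.size();
    }

private:
    mutable std::mutex mutex_;
    std::unordered_map<std::string, std::unique_ptr<KernelImpl>> impls_;
};
```

Storing the implementations behind `std::unique_ptr` keeps the returned pointers owned by the cache, so they remain valid for as long as the entry stays in the cache.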

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 10 comments.

| File | Description |
| --- | --- |
| csrc/cpp_itfs/mha_fwd.cpp | Added mutex for thread-safe kernel pointer map access and debug logging statements throughout |
| op_tests/test_mha.py | Commented out parameter, reduced test dimensions, changed GQA head count, and disabled seq_padding test |

minmengdie changed the title from "add log and mutex for test" to "[Fix] fix the mha fwd_v3 segment fault in torch.compile(mode="reduce-overhead", fullgraph=True)" on Jan 12, 2026
valarLip merged commit 2985cb6 into main on Jan 13, 2026 (17 checks passed)
valarLip deleted the mmd/fix/torchcompile branch on January 13, 2026 04:14
zhuyuhua-v pushed a commit that referenced this pull request Jan 14, 2026
…overhead", fullgraph=True) (#1794)

* add log and mutex for test

* add thread_local

* value capture args

* fix the Explicit assignment forces evaluation order and prevents compiler from reordering operations

* delete some logs

3 participants