[Bugfix] Implement classic arena algorithm for shmem merge and WAW conflict detection #1146
Conversation
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the project's format checks before requesting review. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Walkthrough
Replaces legacy arena plumbing with a self-contained arena-oriented planner that computes per-buffer liveness intervals, performs alignment-aware linear-scan packing, and derives per-buffer offsets and the arena size (with a sequential fallback on overlap). Also expands sync conflict checks with pointer-range disjointness and propagates TMA async-copy context to storage-access entries.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Analyzer as Liveness Analyzer
    participant Collector as Buf Collector
    participant Planner as ArenaPlan Builder
    participant Packer as LinearScanPacker
    participant Fallback as SequentialLayouter
    participant Emitter as Offset Emitter
    Analyzer->>Collector: supply start/end indices & levels
    Collector->>Planner: build BufInfo (size, align, const flag)
    Planner->>Packer: provide intervals + alignment constraints
    rect rgb(235,248,255)
        Packer->>Packer: linear-scan packing → assign offsets
        Packer->>Planner: success / report overlaps
    end
    alt Overlap for constant buffers
        Planner->>Fallback: compute sequential offsets (no reuse)
        Fallback->>Emitter: emit offsets & merged arena_size
    else Successful packing
        Packer->>Emitter: emit offsets & merged arena_size
    end
    Emitter->>Emitter: log arena plan (verbose)
```
```mermaid
sequenceDiagram
    participant Stmt as Statement processing
    participant Access as Access analyzer
    participant Disjoint as PointerAccessIsDisjoint
    participant Sync as Sync inserter
    Stmt->>Access: identify conflicting accesses (reads/writes)
    Access->>Disjoint: if both pointer accesses → prove disjointness
    alt Proven disjoint
        Disjoint-->>Access: no conflict
        Access->>Stmt: skip sync
    else Not proven / non-pointer
        Disjoint-->>Access: conflict
        Access->>Sync: insert sync (hoist if loop-carried & safe)
    end
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks: ❌ Failed checks (1 warning) · ✅ Passed checks (2 passed)
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
src/transform/thread_storage_sync.cc (2)
663-669: Bug: sync dropped when thread_count is not a multiple of 32
Returning an empty Stmt removes the barrier entirely, risking data races. Fall back to the original sync when the named-barrier rewrite is unsafe.
Apply this diff:
```diff
-  if (thread_count % 32 != 0) {
-    // TODO(lei): This is a workaround for the case where the thread count is
-    // not a multiple of 32. we should enhance the pass to analysis index
-    // instead of buffer expression etc.
-    return Stmt();
-  }
+  if (thread_count % 32 != 0) {
+    // Fallback: keep the original full sync (no named-barrier rewrite).
+    return Evaluate(IRMutatorWithAnalyzer::VisitExpr_(op));
+  }
```
728-736: Guard against non-constant thread extents
min/extent are assumed to be IntImm; if dynamic, min_node/extent_node can be null, leading to undefined behavior. Return "full extent" (no rewrite) in that case.
Apply this diff:
```diff
   const auto *min_node = iv->dom->min.as<IntImmNode>();
   const auto *extent_node = iv->dom->extent.as<IntImmNode>();
+  if (!min_node || !extent_node) {
+    // Dynamic or non-constant extent: treat as full extent to avoid unsafe rewrite.
+    return true;
+  }
```

src/transform/merge_shared_memory_allocations.cc (1)
943-951: Kill-point boundary condition likely inverted
The last statement at gen_level should be chosen when the next statement is at a different level (or the sequence ends). The current check uses == gen_level, which stops prematurely.
Apply this diff:
```diff
-  for (; stmt_it != gen_kill_seq.end(); ++stmt_it) {
-    // Check if next statement has different level
-    auto next_it = stmt_it + 1;
-    if (next_it == gen_kill_seq.end() ||
-        stmt_attrs.at(next_it->stmt).level == gen_level) {
-      last_stmt_at_level = stmt_it->stmt;
-      break;
-    }
-  }
+  for (; stmt_it != gen_kill_seq.end(); ++stmt_it) {
+    // Advance until the next statement leaves gen_level
+    auto next_it = stmt_it + 1;
+    if (next_it == gen_kill_seq.end() ||
+        stmt_attrs.at(next_it->stmt).level != gen_level) {
+      last_stmt_at_level = stmt_it->stmt;
+      break;
+    }
+  }
```
🧹 Nitpick comments (1)
src/transform/merge_shared_memory_allocations.cc (1)
1148-1160: Use 64-bit dtype for arena offsets for robustness
offset_dtype is derived from the first buffer's size dtype. Prefer a fixed 64-bit integer for offsets to avoid accidental narrowing across mixed-dtype extents (harmless on GPUs today, but safer and clearer).
Apply this minimal change:
```diff
-  DataType offset_dtype =
-      buf_infos.empty() ? DataType::Int(32) : buf_infos.front().size_dtype;
+  DataType offset_dtype = DataType::Int(64);
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
src/transform/merge_shared_memory_allocations.cc (5 hunks)
src/transform/thread_storage_sync.cc (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/transform/merge_shared_memory_allocations.cc (1)
src/transform/storage_rewrite.cc (21)
buf (249-256), buf (249-249), buf (509-526), buf (509-509), buf (1740-1765), buf (1740-1740), i (720-860), n (368-372), n (368-368), n (373-377), n (373-373), dtype (712-718), dtype (712-712), seq (906-932), seq (906-906), seq (958-1049), seq (959-961), var (1153-1176), var (1153-1153), e (862-904), e (862-862)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Quick Lint
🔇 Additional comments (7)
src/transform/thread_storage_sync.cc (5)
86-92: Good: seed sync_before_stmt from pre-marked inserts
Seeding from pre_marked_sync and clearing read/write sets maintains correct hazard windows across known barriers.
101-103: WAW hazard detection added
Checking writes against prior writes (and reads) fixes missed write-after-write conflicts.
128-130: Avoid double insertion for pre-marked syncs
Guarding insert_syncs with !pre_marked_sync prevents redundant barriers.
147-149: Loop-carry: include WAW in carry conflicts
Considering both RAW and WAW across loop boundaries improves correctness of carry dependencies.
333-336: Index-compare early-exit condition fixed
Breaking only when indices differ avoids prematurely skipping later dimensions. This aligns with the intent to bail out only after detecting a divergence.
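For illustration, a minimal sketch of the fixed comparison, wrapped as a hypothetical helper; ExprDeepEqual is TVM's structural-equality functor for PrimExpr:

```cpp
#include <tvm/ir/expr.h>
#include <tvm/tir/analysis.h>

// Hypothetical helper: true iff two accesses use structurally identical indices.
bool SameIndices(const tvm::Array<tvm::PrimExpr>& a,
                 const tvm::Array<tvm::PrimExpr>& b) {
  if (a.size() != b.size()) return false;
  tvm::tir::ExprDeepEqual deep_equal;
  for (size_t i = 0; i < a.size(); ++i) {
    // Exit only when a dimension actually differs; an equal dimension must
    // not short-circuit the scan of the remaining dimensions.
    if (!deep_equal(a[i], b[i])) return false;
  }
  return true;
}
```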
src/transform/merge_shared_memory_allocations.cc (2)
196-224: Relaxed load-level constraint is correct
Switching to ICHECK_LE and attributing same-level reads to the allocation frame avoids false CHECK failures and keeps liveness tight.
226-246: Var reference handling mirrors BufferLoad path
Accepting same-level references and attributing them to the allocation frame aligns with the realities of the flattening pipeline.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/transform/merge_shared_memory_allocations.cc (1)
960-973: Bug: kill-point relocation picks the first, not the last, statement at gen_level.
The condition checks are inverted. You want the last statement at gen_level (i.e., when the next statement is at a different level, or the sequence ends). The current code breaks early when the next statement has the same level.
Apply this fix:
```diff
-  // start from current statement and find the last statement at
-  // gen_level
-
-  for (; stmt_it != gen_kill_seq.end(); ++stmt_it) {
-    // Check if next statement has different level
-    auto next_it = stmt_it + 1;
-    if (next_it == gen_kill_seq.end() ||
-        stmt_attrs.at(next_it->stmt).level == gen_level) {
-      last_stmt_at_level = stmt_it->stmt;
-      break;
-    }
-  }
+  // Start from current statement and find the last consecutive stmt at gen_level.
+  for (; stmt_it != gen_kill_seq.end(); ++stmt_it) {
+    if (stmt_attrs.at(stmt_it->stmt).level != gen_level) continue;
+    auto next_it = stmt_it + 1;
+    if (next_it == gen_kill_seq.end() ||
+        stmt_attrs.at(next_it->stmt).level != gen_level) {
+      last_stmt_at_level = stmt_it->stmt;
+      break;
+    }
+  }
```

Add a test where gen is at level k, kill at level k+1, and ensure the kill moves to the final stmt at level k.
🧹 Nitpick comments (6)
src/transform/merge_shared_memory_allocations.cc (6)
204-222: Guard against potential size_t underflow in access_level.
`std::min(it->second.level, scope_.size() - 1)` underflows when `scope_.empty()`. It's usually non-empty here, but make it robust and clearer. Also, the comment says "Add write access" in a read path (nit).

```diff
-  ICHECK_LE(it->second.level, scope_.size())
+  ICHECK_LE(it->second.level, scope_.size())
+      << "Load memory in places other than store.";
   ...
-  // When the access happens in the same scope frame as the allocation
+  // When the access happens in the same scope frame as the allocation
   // we attribute it to that frame instead of the outer parent. This
   // keeps the liveness window tight while still accounting for nested
   // scopes that legitimately touch the buffer deeper in the tree.
-  size_t access_level = std::min(it->second.level, scope_.size() - 1);
+  size_t access_level = scope_.empty()
+                            ? 0
+                            : std::min(it->second.level, scope_.size() - 1);
   scope_[access_level].touched.push_back(buf);
```

And fix the nearby comment "Add write access." to "Add read access." if desired.
232-245: Apply the same empty-scope guard for direct Var reads.
Mirror the defensive calculation used for BufferLoad to avoid underflow and keep behavior consistent.

```diff
-  // Attribute same-level uses to the allocation frame, mirroring the
+  // Attribute same-level uses to the allocation frame, mirroring the
   // BufferLoad handling to keep reuse decisions consistent.
-  size_t access_level = std::min(it->second.level, scope_.size() - 1);
+  size_t access_level = scope_.empty()
+                            ? 0
+                            : std::min(it->second.level, scope_.size() - 1);
   scope_[access_level].touched.push_back(buf);
```
624-674: FreeList best-fit is fine, but Normalize on every allocate/free can be costly.
For many blocks, repeated full sort/merge adds unnecessary overhead. Not critical here (SMEM buffer counts are small), but easy to improve; see the sketch after this list.
- Defer coalescing: only call Normalize() when fragmentation exceeds a threshold, or once before the next allocation scan.
- Maintain blocks sorted by insertion order and perform local merges with adjacent neighbors instead of sorting the full vector each time.
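A possible shape for the deferred variant, as a standalone sketch (LazyFreeList and kMergeThreshold are invented here; the actual FreeList in this file may differ):

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

struct Block {
  int64_t offset;
  int64_t size;
};

// Free list that defers coalescing: instead of sorting and merging on every
// Free(), it normalizes only after several frees have accumulated.
class LazyFreeList {
 public:
  void Free(Block b) {
    blocks_.push_back(b);
    if (++frees_since_merge_ >= kMergeThreshold) Normalize();
  }
  // Allocate() would also call Normalize() once, before a failed best-fit scan.

 private:
  void Normalize() {
    std::sort(blocks_.begin(), blocks_.end(),
              [](const Block& a, const Block& b) { return a.offset < b.offset; });
    std::vector<Block> merged;
    merged.reserve(blocks_.size());
    for (const Block& b : blocks_) {
      if (!merged.empty() &&
          merged.back().offset + merged.back().size >= b.offset) {
        // Touching or overlapping: extend the previous block.
        int64_t end = std::max(merged.back().offset + merged.back().size,
                               b.offset + b.size);
        merged.back().size = end - merged.back().offset;
      } else {
        merged.push_back(b);
      }
    }
    blocks_ = std::move(merged);
    frees_since_merge_ = 0;
  }

  static constexpr int kMergeThreshold = 8;  // arbitrary choice for this sketch
  std::vector<Block> blocks_;
  int frees_since_merge_ = 0;
};
```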
775-786: AlignPrimExpr: integer-only guard is correct; behavior matches AlignUp.
Small suggestion: reuse the AlignUpSize semantics by factoring the constant (alignment - 1) into a helper if it is repeated, but this is not necessary.
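For reference, the align-up identity both helpers implement, as a tiny sketch (assumes a positive alignment):

```cpp
#include <cstdint>

// Round x up to the next multiple of align; e.g. AlignUp(100, 16) == 112.
// The divide/multiply form works for any positive align, not only powers of two.
static inline int64_t AlignUp(int64_t x, int64_t align) {
  return (x + align - 1) / align * align;
}
```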
1176-1188: Unify offset dtype to int64 for safety across mixed-extent dtypes.
Offsets can exceed 2^31 in extreme cases, and using the first buffer's extent dtype may vary across functions. Prefer a fixed 64-bit offset dtype for stability.

```diff
-  DataType offset_dtype =
-      buf_infos.empty() ? DataType::Int(32) : buf_infos.front().size_dtype;
+  DataType offset_dtype = DataType::Int(64);
```

Downcast only at use-sites that demand narrower dtypes, if needed.
1215-1253: Overlap detector + fallback is good; document behavior in verbose logs.
The runtime sanity check is valuable. Consider logging the total number of overlaps to gauge severity. No functional issues.
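One way to surface severity, sketched with the hypothetical BufInterval records from the packing example above: count every pair whose liveness intervals and assigned byte ranges both intersect, then log the total.

```cpp
#include <vector>

// Count conflicting pairs: two buffers truly overlap only if they are live at
// the same time AND occupy intersecting byte ranges in the arena.
int CountOverlaps(const std::vector<BufInterval>& bufs) {
  int overlaps = 0;
  for (size_t i = 0; i < bufs.size(); ++i) {
    for (size_t j = i + 1; j < bufs.size(); ++j) {
      const BufInterval& a = bufs[i];
      const BufInterval& b = bufs[j];
      bool live_overlap = a.start < b.end && b.start < a.end;
      bool byte_overlap =
          a.offset < b.offset + b.size && b.offset < a.offset + a.size;
      if (live_overlap && byte_overlap) ++overlaps;
    }
  }
  return overlaps;  // emit this in verbose logs to convey severity, not just presence
}
```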
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
src/transform/merge_shared_memory_allocations.cc (9 hunks)
src/transform/storage_access.cc (3 hunks)
src/transform/storage_access.h (2 hunks)
src/transform/thread_storage_sync.cc (7 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/transform/storage_access.cc (3)
src/transform/thread_storage_sync.cc (16)
op (421-433), op (421-421), op (479-493), op (479-479), op (524-530), op (524-524), op (531-537), op (531-531), op (538-556), op (538-538), op (558-597), op (558-558), op (612-631), op (612-612), op (682-701), op (682-682)
tilelang/language/builtin.py (1)
tma_load (85-94)
src/transform/storage_access.h (1)
tma_depth_ (167-167)
src/transform/merge_shared_memory_allocations.cc (2)
src/transform/storage_rewrite.cc (21)
buf (249-256), buf (249-249), buf (509-526), buf (509-509), buf (1740-1765), buf (1740-1740), i (720-860), n (368-372), n (368-368), n (373-377), n (373-373), dtype (712-718), dtype (712-712), seq (906-932), seq (906-906), seq (958-1049), seq (959-961), var (1153-1176), var (1153-1153), e (862-904), e (862-862)
tilelang/language/tir/op.py (2)
ptx_wgmma_ss (1065-1104), ptx_wgmma_rs (1107-1144)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
🔇 Additional comments (18)
src/transform/storage_access.h (2)
86-89: LGTM! Well-documented async copy context flag.
The addition of is_async_copy with clear documentation aligns well with the PR's goal of handling TMA async-copy semantics. The default value of false ensures backward compatibility.
166-167: LGTM! Proper depth tracking for TMA operations.
The tma_depth_ member is correctly initialized to 0 and will track the nesting depth of TMA operations as documented.
src/transform/storage_access.cc (4)
32-32: LGTM! Required include for TMA operations.
The inclusion of builtin.h is necessary to access the TMA load operation definitions used in the visitor logic below.
304-322: LGTM! Correct TMA depth tracking pattern.
The implementation correctly tracks TMA load nesting depth by incrementing before visiting arguments and decrementing after. The early return is appropriate since the function manually visits all arguments, preventing double traversal.
417-417: LGTM! Correct async copy propagation for reads.
The logic correctly marks read accesses as async when they occur within a TMA load context (tma_depth_ > 0).
422-422: LGTM! Correct async copy propagation for writes.
The logic correctly marks write accesses as async when they occur within a TMA load context, consistent with the read access handling.
src/transform/thread_storage_sync.cc (5)
102-103: LGTM! Critical WAW conflict detection added.
The enhancement correctly detects write-after-write conflicts by checking both previous reads and writes, as mentioned in the PR title. This improves synchronization accuracy when multiple threads write to the same shared memory location.
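In sketch form, with AccessEntry standing in for the storage-access records and may_conflict as a hypothetical overlap test passed in by the caller:

```cpp
#include <vector>

// Hypothetical hazard check: a new write needs a barrier if it may touch the
// same region as any earlier read (RAW ordering) or earlier write (WAW).
template <typename AccessEntry, typename ConflictFn>
bool NeedsSyncBeforeWrite(const AccessEntry& curr,
                          const std::vector<AccessEntry>& prev_reads,
                          const std::vector<AccessEntry>& prev_writes,
                          ConflictFn may_conflict) {
  for (const AccessEntry& r : prev_reads) {
    if (may_conflict(curr, r)) return true;  // write after read
  }
  for (const AccessEntry& w : prev_writes) {
    if (may_conflict(curr, w)) return true;  // write after write (the new check)
  }
  return false;
}
```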
255-262: LGTM! Correct async-copy write optimization.
The special case correctly relaxes synchronization between multiple async-copy writes (e.g., TMA loads), since they do not require interspersed barriers among themselves. The logic appropriately preserves conflicts with reads to ensure visibility before consumption.
286-296: LGTM! Sound pointer disjointness optimization.
The enhancement correctly attempts to prove disjointness for pointer accesses before conservatively assuming a conflict. The fallback to treating them as overlapping when disjointness cannot be proven maintains correctness.
401-419: LGTM! Sound implementation of pointer disjointness check.
The function correctly attempts to prove disjointness by checking whether one range lies entirely before the other using symbolic bounds. The conservative requirement of exactly one touched interval per side prevents false negatives, and the use of kSymbolicBound proof strength is appropriate.
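A minimal sketch of such a proof using TVM's arith::Analyzer (the range endpoints here are hypothetical parameters; the real function derives them from the touched intervals):

```cpp
#include <tvm/arith/analyzer.h>
#include <tvm/tir/op.h>

// Two byte ranges [a_lo, a_hi) and [b_lo, b_hi) are provably disjoint if one
// ends before the other begins, under symbolic-bound proof strength.
bool ProvablyDisjoint(tvm::arith::Analyzer* analyzer,
                      const tvm::PrimExpr& a_lo, const tvm::PrimExpr& a_hi,
                      const tvm::PrimExpr& b_lo, const tvm::PrimExpr& b_hi) {
  const auto strength = tvm::arith::ProofStrength::kSymbolicBound;
  return analyzer->CanProve(a_hi <= b_lo, strength) ||
         analyzer->CanProve(b_hi <= a_lo, strength);
}
```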
134-194: Code implementation is sound, but explicit test coverage for both hoisting paths is not evident.
The loop sync hoisting optimization (lines 134-189) has two distinct branches: hoisting syncs before loops when no reads are present (has_read_in_scope = false), and inserting syncs inside loops when reads are present (has_read_in_scope = true). While the test suite includes loop scenarios, the existing tests do not explicitly isolate and verify both optimization paths.
Recommend adding dedicated test cases that specifically exercise:
- Write-only loops (e.g., stmatrix to shared.dyn with no reads)
- Loops with read-write conflicts requiring in-body sync insertion
to ensure both code paths are validated.
src/transform/merge_shared_memory_allocations.cc (7)
34-39: New headers look appropriate.
Algorithm/queue/optional/limits/functional/stringstream are all used below. No concerns.
359-364: Good expansion of alignment-sensitive intrinsics.
Including tma_load/store and ptx_wgmma_{ss,rs} under the alignment scope is correct and future-proofs the planner.
419-423: Order-of-ops LGTM: liveness before arena packing.
The pass first computes liveness and then feeds it into packing; this matches the new arena model.
684-711: Normalize() correctness is good; consider a tighter merge condition.
The current merge treats touching ranges (blk.offset == last_end) as mergeable, which is good. If you keep Normalize per-call, note that merged.reserve(blocks_.size()) is already used; nice touch. No functional issues.
725-774: Linear-scan packing is solid; tie-breakers ensure determinism.
Sorting by start, then larger size first, then Var address gives deterministic placements and reduces fragmentation. Nice.
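The ordering described above, sketched with hypothetical stand-in fields for the planner's buffer records:

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

struct BufRecord {            // hypothetical stand-in for the planner's BufInfo
  int start;                  // liveness start index
  int64_t size_bytes;         // buffer size in bytes
  const void* alloc_var;      // allocation Var pointer, used only for ordering
};

// Deterministic linear-scan order: earliest liveness start first, then larger
// buffers first (less fragmentation), then the Var address as a stable final
// tie-breaker within a single compilation.
void SortForPacking(std::vector<BufRecord>& bufs) {
  std::sort(bufs.begin(), bufs.end(),
            [](const BufRecord& a, const BufRecord& b) {
              if (a.start != b.start) return a.start < b.start;
              if (a.size_bytes != b.size_bytes) return a.size_bytes > b.size_bytes;
              return a.alloc_var < b.alloc_var;
            });
}
```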
1054-1060: Early return when no shmem allocs: OK.
Setting the merged size to 0 is correct; using Int(64) there is fine.
1255-1273: Fallback sequential layout: OK; alignment preserved.
The sequential path aligns each buffer individually and recomputes the total size with AlignPrimExpr(new_total, align_bytes_). Good. No issues.
@codex review
Codex Review: Didn't find any major issues. Breezy!
local test passed, merge:) |
…nflict detection (tile-ai#1146)
* atomic_fix
* atomic_fix
* mem fix
* lint fix
* add some comments
* fix
* fix
* lint fix
* handle async copy
* lint fix
This pull request refines the logic for thread storage synchronization planning in the TileLangThreadSyncPlanner class. The main focus is on improving conflict detection for read/write accesses and ensuring synchronization statements are only inserted when necessary. The changes also include minor code cleanups for clarity.

Synchronization and Conflict Detection Improvements:
Sync statements (insert_syncs) are now only inserted if not already pre-marked, preventing redundant syncs; a minimal sketch of this guard follows.
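A minimal sketch of the guard, with hypothetical container names and statement IDs standing in for the planner's actual types:

```cpp
#include <unordered_set>

// Plan a barrier before `stmt_id` only when no barrier was already pre-marked
// there (e.g., seeded from an earlier analysis pass), avoiding redundant syncs.
void PlanSync(size_t stmt_id,
              const std::unordered_set<size_t>& pre_marked_sync,
              std::unordered_set<size_t>& insert_syncs) {
  if (!pre_marked_sync.count(stmt_id)) {
    insert_syncs.insert(stmt_id);
  }
}
```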
Minor Cleanups: