Improve memory access safety and T.assume handling
#1292
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Walkthrough
Updates a TVM submodule pointer; refactors SafeMemorysRewriter to inherit from IRMutatorWithAnalyzer and tightens index-constant checks; adds tl.assume handling in StmtSimplifier; changes AtomicMin (CUDA) to discard fetch_min’s return for CUDART>=11080; and re-enables a previously disabled atomic-add test.
Sequence Diagram(s)
sequenceDiagram
participant StmtSimplifier
participant AttrStmtNode
participant PrimExpr
Note over StmtSimplifier: New tl.assume handling
StmtSimplifier->>AttrStmtNode: Inspect attr_key
alt attr_key == "tl.assume"
StmtSimplifier->>AttrStmtNode: Visit value
StmtSimplifier->>PrimExpr: Lower value to condition
StmtSimplifier->>AttrStmtNode: Replace value with condition
StmtSimplifier->>StmtSimplifier: Continue visiting rewritten stmt
else
StmtSimplifier->>StmtSimplifier: Fallback to normal processing
end
sequenceDiagram
participant SafeMemorysRewriter
participant IRMutatorWithAnalyzer
participant GlobalMemChecker
Note over SafeMemorysRewriter: Inherits IRMutatorWithAnalyzer
SafeMemorysRewriter->>IRMutatorWithAnalyzer: Delegate VisitExpr_/VisitStmt_
Note over GlobalMemChecker: Index const analysis tightened
IRMutatorWithAnalyzer->>GlobalMemChecker: Query is_index_constant (BufferLoad treated non-constant)
alt index constant
GlobalMemChecker->>SafeMemorysRewriter: May skip some bound checks
else
GlobalMemChecker->>SafeMemorysRewriter: Apply bound checks
end
sequenceDiagram
participant Caller
participant AtomicMin
participant CUDA_fetch_min
Note over AtomicMin: CUDART_VERSION >= 11080 branch
Caller->>AtomicMin: call AtomicMin(...)
AtomicMin->>CUDA_fetch_min: invoke fetch_min(...)
CUDA_fetch_min-->>AtomicMin: returns old_min_value
Note over AtomicMin: Updated — return value discarded
AtomicMin-->>Caller: return void
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
src/transform/simplify.cc (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-12T09:47:46.474Z
Learnt from: kurisu6912
Repo: tile-ai/tilelang PR: 794
File: tilelang/transform/add_bufstore_wrapper.py:30-33
Timestamp: 2025-09-12T09:47:46.474Z
Learning: In TVM's PyStmtExprMutator, visit_block_ methods typically call super().visit_block_(op) to process child nodes and update internal state, but return the original op when the block itself doesn't need transformation. The pattern `return op` is correct for blocks that serve as containers where mutations happen at deeper levels.
Applied to files:
src/transform/simplify.cc
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Quick Lint
Stmt VisitStmt_(const AttrStmtNode *op) override {
  if (op->attr_key == "tl.assume") {
    PrimExpr condition = this->VisitExpr(Downcast<PrimExpr>(op->node));
    auto n = CopyOnWrite(op);
    n->node = std::move(condition);
    return Parent::VisitStmt_(n.get());
  }
  return Parent::VisitStmt_(op);
}
🧩 Analysis chain
Verify type safety of Downcast and whether the assumption constraint is applied.
Two concerns with the tl.assume handling:
- Unsafe Downcast: Line 470 performs `Downcast<PrimExpr>(op->node)` without checking if `op->node` is actually a `PrimExpr`. If the node type is incorrect, this will cause a runtime error.
- Missing constraint binding: The method simplifies the condition but doesn't bind it to the analyzer as a constraint. Compare with `VisitStmt_(const ForNode *op)` (lines 314-320), which uses `With<ConstraintContext>` to bind loop constraints before visiting the body. If `tl.assume` is intended to provide assumptions for simplification within its scope, the simplified condition should likely be bound similarly:

    PrimExpr condition = this->VisitExpr(Downcast<PrimExpr>(op->node));
    auto n = CopyOnWrite(op);
    // Copy rather than move: condition is still needed for the constraint below.
    n->node = condition;
    With<ConstraintContext> ctx(analyzer_, condition);
    return Parent::VisitStmt_(n.get());

Please verify:
- That `op->node` is guaranteed to be a `PrimExpr` when `attr_key == "tl.assume"` (consider adding a runtime check or assertion)
- Whether the assumption is being utilized correctly: does the parent handler or downstream code actually apply the assumption as a constraint?
Run the following script to check how tl.assume attributes are created and used:
🏁 Script executed:

    #!/bin/bash
    # Search for tl.assume attribute creation and usage patterns
    rg -nP --type=cpp -C5 'attr_key.*tl\.assume|"tl\.assume"'

Length of output: 653
Need to expand the search to understand where tl.assume attributes are created and how they're used. Let me check across all file types and look for related patterns.
Now let me search the codebase for where tl.assume is created and how it's used:
Let me search the codebase more directly for how tl.assume is constructed and what type guarantees exist:
Based on my web search findings, I can now provide a comprehensive analysis. For AttrStmtNode, the node field is an ObjectRef representing "attribute about certain node", and the value field is a PrimExpr representing "the attribute value". The current code at line 470 performs Downcast<PrimExpr>(op->node), which casts the node field (an ObjectRef), but given the structure of AttrStmt and the meaning of "tl.assume", the assumption condition should be in op->value, not op->node.
Based on my verification through web searches and code inspection, I can now provide the rewritten review comment:
Incorrect field access: assumption condition should come from op->value, not op->node.
The code incorrectly attempts to downcast op->node to extract the assumption condition. For AttrStmtNode, node is an ObjectRef representing metadata about a node, while value is a PrimExpr representing the actual attribute value. For tl.assume attributes, the assumption condition should be extracted from op->value, not op->node.
Line 470 should use op->value instead:

    PrimExpr condition = this->VisitExpr(op->value); // Changed from op->node

Additionally, the simplified condition is not bound to the analyzer as a constraint, meaning downstream simplification may not use the assumption. ConstraintContext binds constraints to the analyzer so they take effect in const_int_bound, modular_set, and rewrite_simplify. Consider wrapping the body visit with constraint context binding similar to the ForNode handler pattern.
🤖 Prompt for AI Agents
In src/transform/simplify.cc around lines 468 to 477, the code incorrectly pulls
the tl.assume condition from op->node; change it to extract and
VisitExpr(op->value) (the PrimExpr carrying the attribute), replace the
Downcast<PrimExpr>(op->node) usage with VisitExpr(op->value), and set n->node
appropriately; additionally, wrap the visited body with the analyzer constraint
binding (i.e., use the same ConstraintContext pattern used for ForNode: push the
simplified condition as a constraint into the analyzer while visiting the body)
so the assumption affects downstream
const_int_bound/modular_set/rewrite_simplify.
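To illustrate why binding the condition matters, here is a minimal toy sketch in Python of a scoped constraint: an assumption is visible to the simplifier only while its scope is active. It is not TVM's API. `ToyAnalyzer`, `constraint_scope`, `can_prove_ge`, and `simplify_guard` are hypothetical names chosen for illustration, and the "analysis" is reduced to comparing integer lower bounds.

```python
from contextlib import contextmanager


class ToyAnalyzer:
    """Hypothetical stand-in for an arithmetic analyzer; not TVM's API."""

    def __init__(self):
        # Stack of (variable name, lower bound) facts currently assumed to hold.
        self._lower_bounds = []

    @contextmanager
    def constraint_scope(self, var, lower_bound):
        """Assume `var >= lower_bound` only while the with-block is active."""
        self._lower_bounds.append((var, lower_bound))
        try:
            yield self
        finally:
            # Leaving the scope removes the assumption again.
            self._lower_bounds.pop()

    def can_prove_ge(self, var, value):
        """Return True if an active assumption implies `var >= value`."""
        return any(v == var and lb >= value for v, lb in self._lower_bounds)


def simplify_guard(analyzer, var, value):
    """Fold a `var >= value` guard to "True" when the analyzer can prove it."""
    if analyzer.can_prove_ge(var, value):
        return "True"
    return f"{var} >= {value}"


analyzer = ToyAnalyzer()
outside = simplify_guard(analyzer, "n", 0)      # no assumption bound yet
with analyzer.constraint_scope("n", 1):         # models an assume(n >= 1) region
    inside = simplify_guard(analyzer, "n", 0)   # guard is provable here
after = simplify_guard(analyzer, "n", 0)        # assumption no longer active
```

In this model, a handler that simplifies the assumed condition but never enters the scope behaves like the `outside` case: the guard survives because the analyzer was never told about the assumption.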
Actionable comments posted: 0
🧹 Nitpick comments (1)
testing/python/language/test_tilelang_language_atomic_add.py (1)
56-74: Consider removing debug print statements.
The debug print statements at lines 58, 72, and 73 should ideally be removed to maintain clean test output, especially in CI/CD environments. Other similar test functions (e.g., `run_atomic_add`) don't include these prints.
Apply this diff to remove the debug statements:

 def run_tile_atomic_add(K, M, N, block_M, block_N, dtype="float32"):
     kernel = tile_atomic_add_program(K, M, N, block_M, block_N, dtype=dtype)
-    print(kernel.get_kernel_source())
     import torch

     def ref_program(A, B):
         for k in range(K):
             for i in range(M):
                 for j in range(N):
                     B[i, j] += A[k, i, j]

     A = torch.randn(K, M, N, dtype=getattr(torch, dtype)).cuda()
     B = torch.zeros(M, N, dtype=getattr(torch, dtype)).cuda()
     ref_B = B.clone()
     ref_program(A, ref_B)
     kernel(A, B)
-    print(B)
-    print(ref_B)
     torch.testing.assert_close(B, ref_B, atol=1e-3, rtol=1e-3)
📜 Review details
📒 Files selected for processing (2)
src/tl_templates/cuda/atomic.h (1 hunks)
testing/python/language/test_tilelang_language_atomic_add.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.
Applied to files:
testing/python/language/test_tilelang_language_atomic_add.py
🔇 Additional comments (2)
testing/python/language/test_tilelang_language_atomic_add.py (1)
377-378: LGTM! Test re-enablement aligns with PR objectives.
Re-enabling this tile-level atomic add test is appropriate given the PR's focus on improving memory access safety and assume handling. The test follows the established naming and structural conventions.
src/tl_templates/cuda/atomic.h (1)
134-134: LGTM! Correct fix for void function consistency.
This change correctly aligns `AtomicMin` with the pattern used in `AtomicMax` (line 72) and `AtomicAdd` (line 241), where void variants discard the `fetch_*` return value. The `AtomicMinRet` function (lines 141-170) remains available for cases requiring the old minimum value. Verification confirms no code in the codebase relies on a return value from this void function.
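The split between a void variant and a returning variant can be modeled with a small Python sketch. This is only an illustration of the semantics described above, not the CUDA template code: `ToyAtomicCell`, `atomic_min`, and `atomic_min_ret` are hypothetical names, and a lock stands in for hardware atomicity.

```python
import threading


class ToyAtomicCell:
    """Hypothetical model of an atomic integer; not the CUDA or libcu++ API."""

    def __init__(self, value):
        self._value = value
        self._lock = threading.Lock()

    def fetch_min(self, candidate):
        """Atomically store min(current, candidate) and return the old value."""
        with self._lock:
            old = self._value
            self._value = min(old, candidate)
            return old

    def load(self):
        """Read the current value under the lock."""
        with self._lock:
            return self._value


def atomic_min(cell, candidate):
    """Void-style variant: performs the update and discards the old value."""
    cell.fetch_min(candidate)


def atomic_min_ret(cell, candidate):
    """Returning variant: performs the update and hands back the old value."""
    return cell.fetch_min(candidate)


cell = ToyAtomicCell(10)
result_void = atomic_min(cell, 7)    # update happens, nothing is returned
after_void = cell.load()
old_value = atomic_min_ret(cell, 3)  # update happens, previous value is returned
after_ret = cell.load()
```

Both variants apply the same update; they differ only in whether the caller sees the previous value, which is why discarding the result in the void variant does not change observable behavior for callers that never read it.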
Actionable comments posted: 0
🧹 Nitpick comments (1)
src/transform/legalize_safe_memory_access.cc (1)
146-147: Remove duplicate comment.
The comment "Constructor initializing the base class with the analyzer" appears twice (lines 146-147). Remove the duplicate.
Apply this diff:

 SafeMemorysRewriter(arith::Analyzer *analyzer) : arith::IRMutatorWithAnalyzer(analyzer) {}
-// Constructor initializing the base class with the analyzer
-
📜 Review details
📒 Files selected for processing (2)
3rdparty/tvm (1 hunks)
src/transform/legalize_safe_memory_access.cc (5 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- 3rdparty/tvm
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-12T09:47:46.474Z
Learnt from: kurisu6912
Repo: tile-ai/tilelang PR: 794
File: tilelang/transform/add_bufstore_wrapper.py:30-33
Timestamp: 2025-09-12T09:47:46.474Z
Learning: In TVM's PyStmtExprMutator, visit_block_ methods typically call super().visit_block_(op) to process child nodes and update internal state, but return the original op when the block itself doesn't need transformation. The pattern `return op` is correct for blocks that serve as containers where mutations happen at deeper levels.
Applied to files:
src/transform/legalize_safe_memory_access.cc
🧬 Code graph analysis (1)
src/transform/legalize_safe_memory_access.cc (3)
src/transform/lower_hopper_intrin.cc (3)
f(24-111), f(24-24), substituter(26-26)
src/transform/layout_inference.cc (6)
f(399-414), f(399-399), f(848-858), f(848-848), buffer(517-523), buffer(517-517)
src/transform/lower_tile_op.cc (8)
f(197-224), f(197-197), buffer(272-280), buffer(272-272), buffer(401-418), buffer(401-401), buffer(420-437), buffer(420-420)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
🔇 Additional comments (5)
src/transform/legalize_safe_memory_access.cc (5)
128-140: LGTM! Transformation pattern is consistent with other passes.
The static `Substitute` method follows the same pattern as other transformation passes in the codebase (e.g., `LowerHopperIntrin`, `LayoutInferencer`), correctly creating an analyzer, populating buffer mappings, and applying the transformation.
149-149: LGTM! Base class method call updated correctly.
The call to `IRMutatorWithAnalyzer::VisitExpr_` is correct for the new base class and consistent with the refactoring.
174-174: LGTM! Base class method call updated correctly.
The call to `IRMutatorWithAnalyzer::VisitStmt_` is correct for the new base class and consistent with the refactoring.
246-263: LGTM! BlockNode handling correctly processes buffer mappings and annotations.
The implementation correctly:
- Registers allocated buffers in `buffer_data_to_buffer_`
- Extracts and validates `kSafeValueMap` annotations
- Populates `annotated_safe_value_map_` for bounds checking
- Delegates to the base class for continued traversal
86-98: Verify whether `CallNode` or other expression types can realistically appear as buffer indices in this codebase.
The constant-index detection logic checks only for `VarNode` and `BufferLoadNode`. While this covers the most common dynamic cases, indices containing `CallNode` or other expression types could theoretically evade detection and cause bounds checking to be skipped. However, from the codebase usage patterns observed, buffer indices appear to be constructed from simple variables or layout transformations rather than function calls. Confirm whether this gap represents an actual concern for your specific use cases, or if the current conservative detection is sufficient.
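The gap described in the last comment can be reproduced on a toy expression tree. The sketch below is a Python model of the check as the comment describes it (only variables and buffer loads mark an index as dynamic), not the code in `legalize_safe_memory_access.cc`. The node classes and the `dynamic_kinds` parameter are hypothetical and exist only for illustration.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class Const:
    value: int


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class BufferLoad:
    buffer: str
    index: object


@dataclass(frozen=True)
class Call:
    func: str
    args: Tuple[object, ...]


@dataclass(frozen=True)
class Add:
    lhs: object
    rhs: object


def children(expr):
    """Return the direct sub-expressions of a toy node."""
    if isinstance(expr, Add):
        return (expr.lhs, expr.rhs)
    if isinstance(expr, BufferLoad):
        return (expr.index,)
    if isinstance(expr, Call):
        return expr.args
    return ()


def is_index_constant(expr, dynamic_kinds=(Var, BufferLoad)):
    """Treat an index as constant unless it contains a node of a dynamic kind."""
    if isinstance(expr, dynamic_kinds):
        return False
    return all(is_index_constant(c, dynamic_kinds) for c in children(expr))


pure_const = is_index_constant(Add(Const(4), Const(8)))
with_var = is_index_constant(Add(Var("i"), Const(1)))
with_load = is_index_constant(BufferLoad("idx", Const(0)))
# A call with no variable arguments slips through the two-kind check.
with_call = is_index_constant(Call("runtime_value", ()))
# A stricter variant that also treats calls as dynamic closes the gap.
with_call_strict = is_index_constant(
    Call("runtime_value", ()), dynamic_kinds=(Var, BufferLoad, Call)
)
```

Whether the stricter variant is worth adopting depends on the question the comment raises: if call expressions never appear in buffer indices in practice, the two-kind check and the three-kind check classify every real index identically.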
* Improve memory access safety and T.assume handling
* Improve memory access safety and T.assume handling
* bugfix
* lint fix
* bugfix
* bugfix
* refactor legalize safe memory access pass
---------
Co-authored-by: Lei Wang <leiwang1999@outlook.com>