@LJC00118
Collaborator

@LJC00118 LJC00118 commented Nov 20, 2025

Summary by CodeRabbit

  • Refactor

    • Stricter memory-access safety checks and centralized mutation analysis to improve reliability and simplify memory-legalization flow.
  • New Features

    • tl.assume attribute now recognized during simplification so assumed conditions influence transformations.
  • Bug Fixes

    • Atomic-min non-returning path now discards the old-min value as intended.
  • Tests

    • Re-enabled a previously disabled atomic-add test.
  • Chores

    • Updated third-party submodule reference.


@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Contributor

coderabbitai bot commented Nov 20, 2025

Walkthrough

Updates a TVM submodule pointer; refactors SafeMemorysRewriter to inherit from IRMutatorWithAnalyzer and tightens index-constant checks; adds tl.assume handling in StmtSimplifier; changes AtomicMin (CUDA) to discard fetch_min’s return for CUDART>=11080; and re-enables a previously disabled atomic-add test.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| Submodule Update: 3rdparty/tvm | Submodule pointer updated from commit f4affc7f31e36e7f88c0fe1c715b03215c6a0c62 to bc31e7ad9f9fafd7659dfabafe359fd55a0ffc1e; no code changes. |
| Mutation Dispatch Refactor: src/transform/legalize_safe_memory_access.cc | SafeMemorysRewriter now publicly inherits from IRMutatorWithAnalyzer (removed LeafForFinder/SafeMemoryLegalizer), VisitExpr_/VisitStmt_ delegate to IRMutatorWithAnalyzer; added VisitStmt_(const BlockNode*); index-const analysis tightened (use is_index_constant, treat BufferLoadNode in indices as non-constant). |
| Simplifier Enhancement: src/transform/simplify.cc | StmtSimplifier gains VisitStmt_(const AttrStmtNode*) handling for attr_key == "tl.assume": visits and lowers the attribute value to a PrimExpr condition, rewrites the AttrStmt with that condition, then continues standard visitation. |
| CUDA Atomic Change: src/tl_templates/cuda/atomic.h | In AtomicMin (non-ret variant) for CUDART_VERSION >= 11080, fetch_min is invoked but its return value is discarded (function becomes void in that branch). |
| Test Activation: testing/python/language/test_tilelang_language_atomic_add.py | Previously commented-out test_tile_atomic_add restored and enabled (calls run_tile_atomic_add(8, 128, 128, 32, 32)). |

Sequence Diagram(s)

sequenceDiagram
    participant StmtSimplifier
    participant AttrStmtNode
    participant PrimExpr

    Note over StmtSimplifier: New tl.assume handling
    StmtSimplifier->>AttrStmtNode: Inspect attr_key
    alt attr_key == "tl.assume"
        StmtSimplifier->>AttrStmtNode: Visit value
        StmtSimplifier->>PrimExpr: Lower value to condition
        StmtSimplifier->>AttrStmtNode: Replace value with condition
        StmtSimplifier->>StmtSimplifier: Continue visiting rewritten stmt
    else
        StmtSimplifier->>StmtSimplifier: Fallback to normal processing
    end
sequenceDiagram
    participant SafeMemorysRewriter
    participant IRMutatorWithAnalyzer
    participant GlobalMemChecker

    Note over SafeMemorysRewriter: Inherits IRMutatorWithAnalyzer
    SafeMemorysRewriter->>IRMutatorWithAnalyzer: Delegate VisitExpr_/VisitStmt_
    Note over GlobalMemChecker: Index const analysis tightened
    IRMutatorWithAnalyzer->>GlobalMemChecker: Query is_index_constant (BufferLoad treated non-constant)
    alt index constant
        GlobalMemChecker->>SafeMemorysRewriter: May skip some bound checks
    else
        GlobalMemChecker->>SafeMemorysRewriter: Apply bound checks
    end
sequenceDiagram
    participant Caller
    participant AtomicMin
    participant CUDA_fetch_min

    Note over AtomicMin: CUDART_VERSION >= 11080 branch
    Caller->>AtomicMin: call AtomicMin(...)
    AtomicMin->>CUDA_fetch_min: invoke fetch_min(...)
    CUDA_fetch_min-->>AtomicMin: returns old_min_value
    Note over AtomicMin: Updated — return value discarded
    AtomicMin-->>Caller: return void

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

  • Pay attention to src/transform/legalize_safe_memory_access.cc (inheritance change, visitor dispatch, BlockNode handling, and is_index_constant semantics).
  • Verify src/transform/simplify.cc integration of tl.assume lowering with other simplifications.
  • Confirm consumers/tests expecting a return value from AtomicMin are updated after the CUDA change.
  • Run the re-enabled atomic-add test to catch regressions.

Possibly related PRs

Suggested reviewers

  • LeiWang1999

Poem

🐇 I hopped through visitors, analyzers in sight,
I nudged an assume so conditions shine bright,
I let fetch_min sleep silently at night,
I woke a test to run and take flight,
A rabbit guards bounds till all is right.

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 53.33% which is insufficient. The required threshold is 80.00%. | You can run @coderabbitai generate docstrings to improve docstring coverage. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title accurately summarizes the main changes in the pull request: improvements to memory access safety (SafeMemorysRewriter refactoring, GlobalMemChecker updates, atomic operation changes) and T.assume handling (new AttrStmtNode visitor in StmtSimplifier). |

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 3e12bc5 and 5c8019c.

📒 Files selected for processing (1)
  • src/transform/simplify.cc (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-12T09:47:46.474Z
Learnt from: kurisu6912
Repo: tile-ai/tilelang PR: 794
File: tilelang/transform/add_bufstore_wrapper.py:30-33
Timestamp: 2025-09-12T09:47:46.474Z
Learning: In TVM's PyStmtExprMutator, visit_block_ methods typically call super().visit_block_(op) to process child nodes and update internal state, but return the original op when the block itself doesn't need transformation. The pattern `return op` is correct for blocks that serve as containers where mutations happen at deeper levels.

Applied to files:

  • src/transform/simplify.cc
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Quick Lint

Comment on lines +468 to +477
Stmt VisitStmt_(const AttrStmtNode *op) override {
  if (op->attr_key == "tl.assume") {
    PrimExpr condition = this->VisitExpr(Downcast<PrimExpr>(op->node));
    auto n = CopyOnWrite(op);
    n->node = std::move(condition);
    return Parent::VisitStmt_(n.get());
  }
  return Parent::VisitStmt_(op);
}


⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Verify type safety of Downcast and whether the assumption constraint is applied.

Two concerns with the tl.assume handling:

  1. Unsafe Downcast: Line 470 performs Downcast<PrimExpr>(op->node) without checking if op->node is actually a PrimExpr. If the node type is incorrect, this will cause a runtime error.

  2. Missing constraint binding: The method simplifies the condition but doesn't bind it to the analyzer as a constraint. Compare with VisitStmt_(const ForNode *op) (lines 314-320), which uses With<ConstraintContext> to bind loop constraints before visiting the body. If tl.assume is intended to provide assumptions for simplification within its scope, the simplified condition should likely be bound similarly:

PrimExpr condition = this->VisitExpr(Downcast<PrimExpr>(op->node));
auto n = CopyOnWrite(op);
n->node = condition;  // copy rather than move: condition is still used below
With<ConstraintContext> ctx(analyzer_, condition);
return Parent::VisitStmt_(n.get());

Please verify:

  • That op->node is guaranteed to be a PrimExpr when attr_key == "tl.assume" (consider adding a runtime check or assertion)
  • Whether the assumption is being utilized correctly—does the parent handler or downstream code actually apply the assumption as a constraint?
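
To make the first concern concrete, here is a minimal sketch of a checked cast, written against toy types rather than TVM. Node, Expr, Buffer, TryAsExpr, and AssumeCondition are hypothetical names introduced only for illustration; in TVM the analogous check is typically written with ObjectRef::as<T>(), which yields a null pointer on a type mismatch.

```cpp
#include <cassert>
#include <memory>
#include <optional>
#include <string>
#include <utility>

// Hypothetical stand-ins for an object hierarchy; these are not TVM classes.
struct Node {
  virtual ~Node() = default;
};
struct Expr : Node {
  std::string text;
  explicit Expr(std::string t) : text(std::move(t)) {}
};
struct Buffer : Node {};

// Returns the node as an Expr if it really is one, or nullptr otherwise,
// instead of assuming the dynamic type as an unchecked downcast would.
std::shared_ptr<Expr> TryAsExpr(const std::shared_ptr<Node>& node) {
  return std::dynamic_pointer_cast<Expr>(node);
}

// Extracts the assumed condition text, or nullopt when the node is not an Expr.
std::optional<std::string> AssumeCondition(const std::shared_ptr<Node>& node) {
  if (auto expr = TryAsExpr(node)) {
    return expr->text;
  }
  return std::nullopt;
}
```

A caller can then fall back to the default visitor when AssumeCondition returns an empty optional, rather than failing at runtime.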

Run the following script to check how tl.assume attributes are created and used:


🏁 Script executed:

#!/bin/bash
# Search for tl.assume attribute creation and usage patterns
rg -nP --type=cpp -C5 'attr_key.*tl\.assume|"tl\.assume"'

Length of output: 653


The C++-only search is not enough to show where tl.assume attributes are created or what type guarantees exist, so the analysis falls back on the definition of AttrStmtNode. There, the node field is an ObjectRef described as the "attribute about certain node", and the value field is a PrimExpr described as "the attribute value". The code at line 470 performs Downcast<PrimExpr>(op->node), which casts the node field (an ObjectRef); given the structure of AttrStmt and the meaning of "tl.assume", the assumption condition would be expected in op->value, not op->node. The rewritten review comment follows.


Incorrect field access: assumption condition should come from op->value, not op->node.

The code incorrectly attempts to downcast op->node to extract the assumption condition. For AttrStmtNode, node is an ObjectRef representing metadata about a node, while value is a PrimExpr representing the actual attribute value. For tl.assume attributes, the assumption condition should be extracted from op->value, not op->node.

Line 470 should use op->value instead:

PrimExpr condition = this->VisitExpr(op->value);  // Changed from op->node

Additionally, the simplified condition is not bound to the analyzer as a constraint, meaning downstream simplification may not use the assumption. ConstraintContext binds constraints to the analyzer so they take effect in const_int_bound, modular_set, and rewrite_simplify. Consider wrapping the body visit with constraint context binding similar to the ForNode handler pattern.
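
As a rough illustration of what binding a constraint for the duration of a scope means, the sketch below uses a toy analyzer and an RAII guard. ToyAnalyzer, ScopedLowerBound, and ProvableInsideAssume are hypothetical and are not TVM's arith::Analyzer or With<ConstraintContext>; the sketch only mirrors the enter/exit behaviour, under the assumption that the constraint should be visible while the body is visited and dropped afterwards.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Toy analyzer that tracks lower bounds on a single variable x.
// Hypothetical stand-in, not TVM's arith::Analyzer.
class ToyAnalyzer {
 public:
  // Can "x >= bound" be proven from the constraints currently in scope?
  bool CanProveAtLeast(int64_t bound) const {
    for (int64_t known : lower_bounds_) {
      if (known >= bound) return true;
    }
    return false;
  }

 private:
  friend class ScopedLowerBound;
  std::vector<int64_t> lower_bounds_;
};

// RAII guard: the constraint "x >= bound" is visible to the analyzer only
// while the guard is alive, and is popped again in the destructor.
class ScopedLowerBound {
 public:
  ScopedLowerBound(ToyAnalyzer* analyzer, int64_t bound) : analyzer_(analyzer) {
    assert(analyzer_ != nullptr);
    analyzer_->lower_bounds_.push_back(bound);
  }
  ~ScopedLowerBound() { analyzer_->lower_bounds_.pop_back(); }
  ScopedLowerBound(const ScopedLowerBound&) = delete;
  ScopedLowerBound& operator=(const ScopedLowerBound&) = delete;

 private:
  ToyAnalyzer* analyzer_;
};

// Simulates visiting a body under the assumption "x >= assumed" and reports
// whether "x >= query" is provable while that assumption is in scope.
bool ProvableInsideAssume(ToyAnalyzer* analyzer, int64_t assumed, int64_t query) {
  ScopedLowerBound scope(analyzer, assumed);
  return analyzer->CanProveAtLeast(query);
}
```

Once the guard goes out of scope the toy analyzer forgets the bound, which is the scoping behaviour the comment suggests around the body visit.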

🤖 Prompt for AI Agents
In src/transform/simplify.cc around lines 468 to 477, the code incorrectly pulls
the tl.assume condition from op->node; change it to extract and
VisitExpr(op->value) (the PrimExpr carrying the attribute), replace the
Downcast<PrimExpr>(op->node) usage with VisitExpr(op->value), and set n->node
appropriately; additionally, wrap the visited body with the analyzer constraint
binding (i.e., use the same ConstraintContext pattern used for ForNode: push the
simplified condition as a constraint into the analyzer while visiting the body)
so the assumption affects downstream
const_int_bound/modular_set/rewrite_simplify.

LeiWang1999 previously approved these changes Nov 20, 2025
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (1)
testing/python/language/test_tilelang_language_atomic_add.py (1)

56-74: Consider removing debug print statements.

The debug print statements at lines 58, 72, and 73 should ideally be removed to maintain clean test output, especially in CI/CD environments. Other similar test functions (e.g., run_atomic_add) don't include these prints.

Apply this diff to remove the debug statements:

 def run_tile_atomic_add(K, M, N, block_M, block_N, dtype="float32"):
     kernel = tile_atomic_add_program(K, M, N, block_M, block_N, dtype=dtype)
-    print(kernel.get_kernel_source())
     import torch
 
     def ref_program(A, B):
         for k in range(K):
             for i in range(M):
                 for j in range(N):
                     B[i, j] += A[k, i, j]
 
     A = torch.randn(K, M, N, dtype=getattr(torch, dtype)).cuda()
     B = torch.zeros(M, N, dtype=getattr(torch, dtype)).cuda()
     ref_B = B.clone()
     ref_program(A, ref_B)
     kernel(A, B)
-    print(B)
-    print(ref_B)
     torch.testing.assert_close(B, ref_B, atol=1e-3, rtol=1e-3)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5c8019c and 051439e.

📒 Files selected for processing (2)
  • src/tl_templates/cuda/atomic.h (1 hunks)
  • testing/python/language/test_tilelang_language_atomic_add.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.

Applied to files:

  • testing/python/language/test_tilelang_language_atomic_add.py
🔇 Additional comments (2)
testing/python/language/test_tilelang_language_atomic_add.py (1)

377-378: LGTM! Test re-enablement aligns with PR objectives.

Re-enabling this tile-level atomic add test is appropriate given the PR's focus on improving memory access safety and assume handling. The test follows the established naming and structural conventions.

src/tl_templates/cuda/atomic.h (1)

134-134: LGTM! Correct fix for void function consistency.

This change correctly aligns AtomicMin with the pattern used in AtomicMax (line 72) and AtomicAdd (line 241), where void variants discard the fetch_* return value. The AtomicMinRet function (lines 141-170) remains available for cases requiring the old minimum value. Verification confirms no code in the codebase relies on a return value from this void function.
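
The relationship between the void and returning variants can be sketched on the host with std::atomic<int>. This is an analogy only: the real header targets CUDA, the bodies below are not the code in atomic.h, and the compare-exchange loop is an assumption standing in for fetch_min, which std::atomic does not offer in C++20.

```cpp
#include <atomic>
#include <cassert>

// Host-side analogy of the returning variant: stores min(target, value)
// and returns the value that was stored before the update.
int AtomicMinRet(std::atomic<int>& target, int value) {
  int old = target.load(std::memory_order_relaxed);
  // Retry while the stored value is still larger than value and the
  // compare-exchange fails; a failed exchange refreshes old.
  while (value < old &&
         !target.compare_exchange_weak(old, value, std::memory_order_relaxed)) {
  }
  return old;
}

// Non-returning variant: performs the same update and deliberately
// discards the previous value.
void AtomicMin(std::atomic<int>& target, int value) {
  (void)AtomicMinRet(target, value);
}
```

Keeping both entry points lets callers that need the previous minimum opt in explicitly, while the void variant has no return value to misuse.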

LeiWang1999 previously approved these changes Nov 21, 2025
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (1)
src/transform/legalize_safe_memory_access.cc (1)

146-147: Remove duplicate comment.

The comment "Constructor initializing the base class with the analyzer" appears twice (lines 146-147). Remove the duplicate.

Apply this diff:

   SafeMemorysRewriter(arith::Analyzer *analyzer)
       : arith::IRMutatorWithAnalyzer(analyzer) {}
-  // Constructor initializing the base class with the analyzer
-
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between bd12f67 and 7ed19f7.

📒 Files selected for processing (2)
  • 3rdparty/tvm (1 hunks)
  • src/transform/legalize_safe_memory_access.cc (5 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • 3rdparty/tvm
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-12T09:47:46.474Z
Learnt from: kurisu6912
Repo: tile-ai/tilelang PR: 794
File: tilelang/transform/add_bufstore_wrapper.py:30-33
Timestamp: 2025-09-12T09:47:46.474Z
Learning: In TVM's PyStmtExprMutator, visit_block_ methods typically call super().visit_block_(op) to process child nodes and update internal state, but return the original op when the block itself doesn't need transformation. The pattern `return op` is correct for blocks that serve as containers where mutations happen at deeper levels.

Applied to files:

  • src/transform/legalize_safe_memory_access.cc
🧬 Code graph analysis (1)
src/transform/legalize_safe_memory_access.cc (3)
src/transform/lower_hopper_intrin.cc (3)
  • f (24-111)
  • f (24-24)
  • substituter (26-26)
src/transform/layout_inference.cc (6)
  • f (399-414)
  • f (399-399)
  • f (848-858)
  • f (848-848)
  • buffer (517-523)
  • buffer (517-517)
src/transform/lower_tile_op.cc (8)
  • f (197-224)
  • f (197-197)
  • buffer (272-280)
  • buffer (272-272)
  • buffer (401-418)
  • buffer (401-401)
  • buffer (420-437)
  • buffer (420-420)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
🔇 Additional comments (5)
src/transform/legalize_safe_memory_access.cc (5)

128-140: LGTM! Transformation pattern is consistent with other passes.

The static Substitute method follows the same pattern as other transformation passes in the codebase (e.g., LowerHopperIntrin, LayoutInferencer), correctly creating an analyzer, populating buffer mappings, and applying the transformation.


149-149: LGTM! Base class method call updated correctly.

The call to IRMutatorWithAnalyzer::VisitExpr_ is correct for the new base class and consistent with the refactoring.


174-174: LGTM! Base class method call updated correctly.

The call to IRMutatorWithAnalyzer::VisitStmt_ is correct for the new base class and consistent with the refactoring.


246-263: LGTM! BlockNode handling correctly processes buffer mappings and annotations.

The implementation correctly:

  • Registers allocated buffers in buffer_data_to_buffer_
  • Extracts and validates kSafeValueMap annotations
  • Populates annotated_safe_value_map_ for bounds checking
  • Delegates to the base class for continued traversal

86-98: Verify whether CallNode or other expression types can realistically appear as buffer indices in this codebase.

The constant-index detection logic checks only for VarNode and BufferLoadNode. While this covers the most common dynamic cases, indices containing CallNode or other expression types could theoretically evade detection and cause bounds checking to be skipped. However, from the codebase usage patterns observed, buffer indices appear to be constructed from simple variables or layout transformations rather than function calls. Confirm whether this gap represents an actual concern for your specific use cases, or if the current conservative detection is sufficient.
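
One way to close the gap described above is to flip the check into an allowlist, so that any node kind not known to be constant is treated as dynamic. The sketch below does this on a toy expression tree; Kind, Expr, Make, and IsIndexConstant are hypothetical stand-ins, and this is not how the PR's is_index_constant is written.

```cpp
#include <cassert>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical miniature expression tree; not TVM's PrimExpr classes.
enum class Kind { kIntImm, kVar, kBufferLoad, kCall, kAdd, kMul };

struct Expr {
  Kind kind;
  std::vector<std::shared_ptr<Expr>> children;
};

std::shared_ptr<Expr> Make(Kind kind,
                           std::vector<std::shared_ptr<Expr>> children = {}) {
  return std::make_shared<Expr>(Expr{kind, std::move(children)});
}

// Allowlist check: an index counts as constant only if it is an integer
// literal or arithmetic over constant operands. Var, BufferLoad, Call, and
// any kind added later all fall through to "not constant".
bool IsIndexConstant(const Expr& e) {
  switch (e.kind) {
    case Kind::kIntImm:
      return true;
    case Kind::kAdd:
    case Kind::kMul:
      for (const auto& child : e.children) {
        if (!IsIndexConstant(*child)) return false;
      }
      return true;
    default:
      return false;
  }
}
```

The trade-off is that an allowlist may keep bound checks that a smarter analysis could drop, in exchange for never skipping a check on an unrecognized expression.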

@LeiWang1999 LeiWang1999 merged commit 470eb74 into tile-ai:main Nov 22, 2025
6 checks passed
@LJC00118 LJC00118 linked an issue Nov 24, 2025 that may be closed by this pull request
2 tasks
RubiaCx pushed a commit to RubiaCx/tilelang that referenced this pull request Nov 24, 2025
* Improve memory access safety and T.assume handling

* Improve memory access safety and T.assume handling

* bugfix

* lint fix

* bugfix

* bugfix

* refactor legalize safe memory access pass

---------

Co-authored-by: Lei Wang <leiwang1999@outlook.com>
Development

Successfully merging this pull request may close these issues.

[BUG] Global store is assumed to be correct

2 participants