
Conversation

@Elevator14B (Collaborator) commented Nov 10, 2025

This PR aims to resolve various issues with int64_t static and dynamic shapes.

TODOs:

  • The wrapper/adapter should be further improved to use the corresponding data type instead of enforcing c_int64;
  • Add corresponding cases as unit tests.

Summary by CodeRabbit

  • Bug Fixes

    • Zero-value checks are now dtype-aware to avoid incorrect comparisons.
    • Dynamic kernel argument types now follow each symbol's declared dtype instead of assuming integer, improving correctness across backends.
  • Tests

    • Added integration tests that exercise very large int64-shaped workloads to validate dynamic-dimension and kernel-argument behavior.

@coderabbitai bot (Contributor) commented Nov 10, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Propagates per-symbol dtypes through the JIT adapters and makes the zero comparisons in TVM's assume-injection pass dtype-aware; updates kernel argument typing and marshalling for dynamic symbols and adds tests for int64-sized symbolic/static fills.

Changes

  • TVM transform (src/transform/inject_assumes.cc): Added #include <tvm/tir/op.h> and replaced GT(e.expr, 0) with GT(e.expr, make_zero(e.expr->dtype)) in AssumeInjector::build to use a dtype-aware zero.
  • Core dynamic-symbol refactor (tilelang/jit/adapter/wrapper.py): dynamic_symbolic_set now stores name→dtype; get_dynamic_symbolic_set returns (name, dtype) pairs; call sites unpack the pairs and use _lookup_type(dtype) instead of assuming int (see the sketch after this list).
  • NVRTC adapter (tilelang/jit/adapter/nvrtc/wrapper.py): Iteration unpacks (dyn_sym, dyn_sym_dtype) and function argument types use self._lookup_type(dyn_sym_dtype), a dtype-aware mapping.
  • Cython adapter (tilelang/jit/adapter/cython/cython_wrapper.pyx): Dynamic-symbol shape/stride values are marshalled as ctypes.c_int64 when appended to kernel args, replacing plain Python ints.
  • Tests, new (testing/python/language/test_tilelang_language_int64.py): New tests exercise large, int64-sized symbolic/static fills for CUDA kernels; adds helpers, runners, and test entry points.
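
For orientation, here is a minimal sketch of the dtype-aware dynamic-symbol flow described above. It assumes a small dtype-to-ctypes table; the names collect_dynamic_symbols, CTYPE_MAP, and marshal are illustrative, not the wrapper's actual API, and the real _lookup_type may cover more dtypes.

import ctypes

from tvm import tir

# Illustrative dtype -> ctypes mapping; the wrapper's _lookup_type may differ.
CTYPE_MAP = {"int32": ctypes.c_int32, "int64": ctypes.c_int64}


def collect_dynamic_symbols(prim_func):
    """Return (name, dtype) pairs for dynamic shape vars, mirroring the refactor."""
    pairs = []
    for param in prim_func.params:
        if param in prim_func.buffer_map:
            for dim in prim_func.buffer_map[param].shape:
                if isinstance(dim, tir.Var):
                    pairs.append((dim.name, str(dim.dtype)))
    return pairs


def marshal(value, dtype):
    # dtype-aware marshalling instead of assuming a plain Python int
    return CTYPE_MAP[dtype](value)

Per the walkthrough above, the real adapters additionally deduplicate symbols and emit the matching kernel argument signatures per backend.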

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant Wrapper
    participant Adapter
    participant Kernel

    rect rgb(230,240,220)
    Note over Wrapper,Adapter: NEW dtype-aware dynamic symbol flow
    User->>Wrapper: request dispatch / prim_func
    Wrapper->>Wrapper: collect dynamic symbols as (name, dtype)
    Wrapper->>Adapter: send [(name1,dtype1), (name2,dtype2), ...]
    Adapter->>Adapter: for each (name,dtype): arg_type = _lookup_type(dtype)
    Adapter->>Adapter: marshal values (e.g., ctypes.c_int64 or mapped type)
    Adapter->>Kernel: call kernel with dtype-aware args
    Kernel-->>Adapter: result/return
    end

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

  • Attention points:
    • All call sites of get_dynamic_symbolic_set() must unpack (name, dtype) consistently.
    • Verify _lookup_type() covers expected dtype strings and maps correctly to ctypes.
    • Confirm Cython ctypes.c_int64 choice matches target kernel ABIs and platforms.
    • Validate make_zero(expr->dtype) is correct for all dtypes used in assumes (see the sketch after this list).
    • New tests allocate large GPU memory; check CI feasibility or add a guard.
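
As a quick illustration of the dtype-aware zero, the following uses TVM's Python API rather than the C++ make_zero call from the patch; the variable n and the use of tir.const are just for demonstration.

import tvm
from tvm import tir

n = tir.Var("n", "int64")
# Build the zero with the expression's own dtype, which is the intent of
# make_zero(e.expr->dtype) in inject_assumes.cc: the comparison stays in int64.
zero = tir.const(0, n.dtype)
cond = n > zero
analyzer = tvm.arith.Analyzer()
print(analyzer.simplify(cond))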

Poem

🐇 I hopped through fields of dtype and code,

symbols found labels on each little node,
zeros now match what tensors expect,
args dressed up clean for the kernel connect,
a joyful nibble — merge and goad.

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 5.88%, which is below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed: Check skipped because CodeRabbit's high-level summary is enabled.
  • Title check ✅ Passed: The title accurately and concisely summarizes the main objective: fixing issues related to int64_t handling for both static and dynamic shapes, which aligns with all five file changes in the changeset.

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7f6efe6 and e273ba3.

📒 Files selected for processing (1)
  • src/transform/inject_assumes.cc (2 hunks)
🔇 Additional comments (2)
src/transform/inject_assumes.cc (2)

9-9: LGTM: Include added for dtype-aware operations.

The include is necessary to provide the make_zero function used in the dtype-aware comparison below.


66-67: LGTM: Dtype-aware zero comparison implemented.

This change correctly implements the suggestion from the previous review to use make_zero(e.expr->dtype) instead of a literal 0. This ensures type consistency when e.expr has dtype int64_t, preventing potential type mismatches in the comparison.



@github-actions commented:

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@LeiWang1999 (Member) commented:

@codex review

  auto analyzer = arith::Analyzer{};
  for (const auto &e : items) {
-   auto simplified = analyzer.Simplify(GT(e.expr, 0));
+   auto simplified = analyzer.Simplify(GT(e.expr, IntImm(e.expr->dtype, 0)));

I think we can use make_zero(e.expr->dtype) instead.

  PrimExpr offset = buffer.OffsetOf(indices).back();
  // condition for alignment, maybe useless
- condition_ = (FloorMod(offset, vector_size_) == 0);
+ condition_ = (FloorMod(offset, IntImm(offset.dtype(), vector_size_)) == IntImm(offset.dtype(), 0));

make_zero(e.expr->dtype)

@chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines 602 to 607
      buffer = prim_func.buffer_map[param]
      for stride in buffer.strides:
          if isinstance(stride, tvm.tir.Var):
-             unique_push_back(stride.name)
+             unique_push_back(stride.name, str(dim.dtype))


P1: Use stride dtype when deduplicating dynamic symbols

The updated dynamic symbol gathering now tracks both name and dtype, but the stride pass still calls unique_push_back(stride.name, str(dim.dtype)). At this point dim refers to whatever value was left from the previous shape loop (or may be undefined if no dynamic shape variables existed), so stride variables inherit the wrong dtype or raise a NameError. This causes the new assertion in unique_push_back to fire and the wrapper to fail even when the stride symbol itself is valid. The call should use stride.dtype for each stride entry.

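A minimal, self-contained sketch of the corrected stride pass is below; collect_dynamic_stride_symbols and unique_push_back are reconstructed here for illustration and are not copied from wrapper.py.

from tvm import tir


def collect_dynamic_stride_symbols(prim_func):
    """Gather name -> dtype for dynamic stride vars, using each stride's own dtype."""
    seen = {}

    def unique_push_back(name, dtype):
        # a real implementation might assert that a re-seen name keeps the same dtype
        seen.setdefault(name, dtype)

    for param in prim_func.params:
        if param not in prim_func.buffer_map:
            continue
        buffer = prim_func.buffer_map[param]
        for stride in buffer.strides:
            if isinstance(stride, tir.Var):
                # use stride.dtype, not the dtype left over in `dim` from the shape loop
                unique_push_back(stride.name, str(stride.dtype))
    return seen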

@Elevator14B marked this pull request as ready for review on November 17, 2025 at 03:34.
@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4ee7b24 and 7f6efe6.

📒 Files selected for processing (1)
  • testing/python/language/test_tilelang_language_int64.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.

Applied to files:

  • testing/python/language/test_tilelang_language_int64.py
🧬 Code graph analysis (1)
testing/python/language/test_tilelang_language_int64.py (3)
tilelang/language/symbolics.py (1)
  • symbolic (25-27)
tilelang/language/loop.py (1)
  • Parallel (12-32)
tilelang/language/v2/dtypes.py (1)
  • bfloat16 (297-297)

Comment on lines +22 to +61
def run_fill_symbolic(n: int):
    import torch

    x = torch.zeros(n, dtype=torch.bfloat16, device="cuda")
    fill_symbolic(1.0)(x)
    assert x.min() == 1.0 and x.max() == 1.0


def test_fill_symbolic():
    # Requires 8GB VRAM
    run_fill_symbolic(2**32)


@tilelang.jit
def fill_static(n: int, value: float, dtype="bfloat16"):
    block_n = 512

    @T.prim_func
    def main(x: T.Tensor[n, dtype]):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(n, block_n), threads=128) as bx:
            # Doesn't yet work with int64-shaped global tensor
            # T.fill(x[bx * block_n : (bx + 1) * block_n], value)
            for i in T.Parallel(block_n):
                x[bx * block_n + i] = value

    return main


def run_fill_static(n: int):
    import torch

    x = torch.zeros(n, dtype=torch.bfloat16, device="cuda")
    fill_static(n, 1.0)(x)
    assert x.min() == 1.0 and x.max() == 1.0


def test_fill_static():
    # Requires 8GB VRAM
    run_fill_static(2**32)

⚠️ Potential issue | 🔴 Critical

Prevent OOM and missing-CUDA crashes in the int64 fill tests

run_fill_symbolic / run_fill_static unconditionally create a torch.zeros buffer on CUDA and the test invokes them with n = 2**32. On hosts without a CUDA-capable GPU this raises immediately because torch.cuda.is_available() is false, and even when a GPU exists the allocation needs ~8.6 GiB for bfloat16 data, which will throw torch.cuda.OutOfMemoryError on the default CI machines long before the kernel is exercised. (debuglab.net)

Please gate these helpers on CUDA availability and skip when the requested tensor cannot fit into the active device before trying to allocate it. One option is:

@@
-import tilelang
-import tilelang.language as T
+import pytest
+import torch
+import tilelang
+import tilelang.language as T
@@
-def run_fill_symbolic(n: int):
-    import torch
-
-    x = torch.zeros(n, dtype=torch.bfloat16, device="cuda")
+def run_fill_symbolic(n: int):
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA device required for int64 fill tests")
+    elem_bytes = torch.tensor([], dtype=torch.bfloat16).element_size()
+    total_mem = torch.cuda.get_device_properties(0).total_memory
+    if n * elem_bytes > total_mem:
+        pytest.skip(
+            f"Requires ~{n * elem_bytes / (1 << 30):.1f} GiB,"
+            f" but only {total_mem / (1 << 30):.1f} GiB available"
+        )
+    device = torch.device("cuda")
+    x = torch.zeros(n, dtype=torch.bfloat16, device=device)
@@
-def run_fill_static(n: int):
-    import torch
-
-    x = torch.zeros(n, dtype=torch.bfloat16, device="cuda")
+def run_fill_static(n: int):
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA device required for int64 fill tests")
+    elem_bytes = torch.tensor([], dtype=torch.bfloat16).element_size()
+    total_mem = torch.cuda.get_device_properties(0).total_memory
+    if n * elem_bytes > total_mem:
+        pytest.skip(
+            f"Requires ~{n * elem_bytes / (1 << 30):.1f} GiB,"
+            f" but only {total_mem / (1 << 30):.1f} GiB available"
+        )
+    device = torch.device("cuda")
+    x = torch.zeros(n, dtype=torch.bfloat16, device=device)

This keeps the int64 coverage when the hardware can handle it and lets the suite pass everywhere else.

🤖 Prompt for AI Agents
testing/python/language/test_tilelang_language_int64.py lines 22-61: The helpers
unconditionally allocate a large CUDA tensor (n=2**32) which fails on machines
without CUDA or without enough free VRAM; modify run_fill_symbolic and
run_fill_static to first check torch.cuda.is_available() and skip (or return)
when CUDA is absent, then compute required_bytes = n * element_size (bfloat16 ->
2) and query the device free memory (torch.cuda.mem_get_info or
torch.cuda.get_device_properties/free mem API) and skip when required_bytes >
free_bytes; perform the allocation only after these checks and use pytest.skip
with a clear message so tests are gated safely.
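
Following that prompt, one compact way to gate the helpers could look like the sketch below; _skip_unless_cuda_fits is a hypothetical name, and torch.cuda.mem_get_info reports free/total memory for the current device.

import pytest
import torch


def _skip_unless_cuda_fits(n_elems: int, dtype=torch.bfloat16):
    # Hypothetical guard: skip when CUDA is absent or the tensor cannot fit.
    if not torch.cuda.is_available():
        pytest.skip("CUDA device required for int64 fill tests")
    required = n_elems * torch.tensor([], dtype=dtype).element_size()
    free_bytes, _total = torch.cuda.mem_get_info()
    if required > free_bytes:
        pytest.skip(f"Needs ~{required / (1 << 30):.1f} GiB, "
                    f"only {free_bytes / (1 << 30):.1f} GiB free")

Each runner would then call the guard, e.g. _skip_unless_cuda_fits(2**32), before allocating the tensor.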

@LeiWang1999 merged commit 49c8571 into tile-ai:main on Nov 18, 2025
6 checks passed
RubiaCx pushed a commit to RubiaCx/tilelang that referenced this pull request Nov 24, 2025
…#1218)

* Fix various issues under int64_t static and dynamic shape.

* Resolve reviewed issues.

* Add unit test.

* fix

---------

Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>