Fix various issues under int64_t static and dynamic shape (#1218)
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Note: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough: Propagates per-symbol dtypes through the JIT adapters and makes the zero comparisons in TVM's assume injection dtype-aware; updates kernel argument typing and marshalling for dynamic symbols and adds tests for int64-sized symbolic/static fills.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant Wrapper
    participant Adapter
    participant Kernel
    rect rgb(230,240,220)
    Note over Wrapper,Adapter: NEW dtype-aware dynamic symbol flow
    User->>Wrapper: request dispatch / prim_func
    Wrapper->>Wrapper: collect dynamic symbols as (name, dtype)
    Wrapper->>Adapter: send [(name1,dtype1), (name2,dtype2), ...]
    Adapter->>Adapter: for each (name,dtype): arg_type = _lookup_type(dtype)
    Adapter->>Adapter: marshal values (e.g., ctypes.c_int64 or mapped type)
    Adapter->>Kernel: call kernel with dtype-aware args
    Kernel-->>Adapter: result/return
    end
```
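The adapter-side marshalling step in the diagram can be sketched in plain Python. The lookup table and function names here are illustrative assumptions, not the actual adapter API:

```python
import ctypes

# Hypothetical dtype -> ctypes lookup table; the real adapter's
# mapping (and its name) may differ.
_CTYPE_OF = {
    "int32": ctypes.c_int32,
    "int64": ctypes.c_int64,
    "uint64": ctypes.c_uint64,
    "float32": ctypes.c_float,
}


def marshal_dynamic_symbols(symbols, values):
    """Turn [(name, dtype), ...] plus runtime values into ctypes args."""
    args = []
    for (name, dtype), value in zip(symbols, values):
        ctype = _CTYPE_OF.get(dtype)
        if ctype is None:
            raise TypeError(f"unsupported dtype {dtype!r} for symbol {name!r}")
        args.append(ctype(value))
    return args


# An int64 symbol survives a value beyond the int32 range:
args = marshal_dynamic_symbols([("n", "int64")], [2**32])
```

With a name-only symbol list there is no way to pick `c_int64` over `c_int32`, which is why the flow above carries `(name, dtype)` pairs.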
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
@codex review
src/transform/inject_assumes.cc (outdated)

```diff
 auto analyzer = arith::Analyzer{};
 for (const auto &e : items) {
-  auto simplified = analyzer.Simplify(GT(e.expr, 0));
+  auto simplified = analyzer.Simplify(GT(e.expr, IntImm(e.expr->dtype, 0)));
```
I think we can use `make_zero(e.expr->dtype)` instead.
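The dtype of that zero literal is not cosmetic: comparing an int64 expression against a default 32-bit constant can silently truncate values past the int32 range. A quick stdlib illustration of the truncation itself (plain ctypes, not TVM code):

```python
import ctypes

# A value just past the int32 range, like a > 2**31 tensor extent.
big = 2**32 + 1

# Stored in a 32-bit slot, the high bits are dropped...
truncated = ctypes.c_int32(big).value
# ...while a 64-bit slot preserves the value.
preserved = ctypes.c_int64(big).value
```

This is the class of bug the `IntImm(e.expr->dtype, 0)` (or `make_zero`) change guards against on the IR side.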
```diff
 PrimExpr offset = buffer.OffsetOf(indices).back();
 // condition for alignment, maybe useless
-condition_ = (FloorMod(offset, vector_size_) == 0);
+condition_ = (FloorMod(offset, IntImm(offset.dtype(), vector_size_)) == IntImm(offset.dtype(), 0));
```
`make_zero(e.expr->dtype)`
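The alignment condition in the hunk above reduces to an offset-mod-vector-size check. A plain-Python sketch (illustrative names, with Python ints standing in for int64 `PrimExpr`s) shows the check working on offsets beyond the int32 range, which is exactly where a 32-bit `FloorMod` operand would have been unsafe:

```python
def is_aligned(offset: int, vector_size: int) -> bool:
    # Mirrors FloorMod(offset, vector_size) == 0 from the pass.
    return offset % vector_size == 0


aligned = is_aligned(2**32, 8)        # offset past the int32 range
misaligned = is_aligned(2**32 + 4, 8)
```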
💡 Codex Review
Here are some automated review suggestions for this pull request.
tilelang/jit/adapter/wrapper.py (outdated)

```diff
 buffer = prim_func.buffer_map[param]
 for stride in buffer.strides:
     if isinstance(stride, tvm.tir.Var):
-        unique_push_back(stride.name)
+        unique_push_back(stride.name, str(dim.dtype))
```
Use stride dtype when deduplicating dynamic symbols

The updated dynamic symbol gathering now tracks both name and dtype, but the stride pass still calls `unique_push_back(stride.name, str(dim.dtype))`. At this point `dim` refers to whatever value was left over from the previous shape loop (or may be undefined if no dynamic shape variables existed), so stride variables inherit the wrong dtype or raise a `NameError`. This causes the new assertion in `unique_push_back` to fire and the wrapper to fail even when the stride symbol itself is valid. The call should use `stride.dtype` for each stride entry.
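A minimal sketch of the dedup-with-dtype behavior described above (the names and the assertion are assumptions about the wrapper, not its actual code). It shows why re-registering a symbol with a stale or mismatched dtype, as the buggy `dim.dtype` call could do, trips the assertion:

```python
def make_symbol_collector():
    seen = {}
    order = []

    def unique_push_back(name, dtype):
        if name in seen:
            # The same symbol must not be re-registered with a
            # conflicting dtype.
            assert seen[name] == dtype, f"dtype mismatch for {name!r}"
            return
        seen[name] = dtype
        order.append((name, dtype))

    return unique_push_back, order


unique_push_back, symbols = make_symbol_collector()
unique_push_back("n", "int64")
unique_push_back("n", "int64")        # duplicate with same dtype: ignored
unique_push_back("stride0", "int64")  # correct: taken from stride.dtype
```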
Actionable comments posted: 1
📜 Review details — Configuration: CodeRabbit UI · Review profile: CHILL · Plan: Pro

📒 Files selected for processing (1): `testing/python/language/test_tilelang_language_int64.py` (1 hunk)
```python
def run_fill_symbolic(n: int):
    import torch

    x = torch.zeros(n, dtype=torch.bfloat16, device="cuda")
    fill_symbolic(1.0)(x)
    assert x.min() == 1.0 and x.max() == 1.0


def test_fill_symbolic():
    # Requires 8GB VRAM
    run_fill_symbolic(2**32)


@tilelang.jit
def fill_static(n: int, value: float, dtype="bfloat16"):
    block_n = 512

    @T.prim_func
    def main(x: T.Tensor[n, dtype]):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(n, block_n), threads=128) as bx:
            # Doesn't yet work with int64-shaped global tensor
            # T.fill(x[bx * block_n : (bx + 1) * block_n], value)
            for i in T.Parallel(block_n):
                x[bx * block_n + i] = value

    return main


def run_fill_static(n: int):
    import torch

    x = torch.zeros(n, dtype=torch.bfloat16, device="cuda")
    fill_static(n, 1.0)(x)
    assert x.min() == 1.0 and x.max() == 1.0


def test_fill_static():
    # Requires 8GB VRAM
    run_fill_static(2**32)
```
Prevent OOM and missing-CUDA crashes in the int64 fill tests

`run_fill_symbolic` / `run_fill_static` unconditionally create a `torch.zeros` buffer on CUDA, and the tests invoke them with `n = 2**32`. On hosts without a CUDA-capable GPU this raises immediately because `torch.cuda.is_available()` is false, and even when a GPU exists the allocation needs ~8.6 GB for bfloat16 data, which will throw `torch.cuda.OutOfMemoryError` on the default CI machines long before the kernel is exercised.

Please gate these helpers on CUDA availability and skip when the requested tensor cannot fit on the active device before trying to allocate it. One option is:
```diff
@@
-import tilelang
-import tilelang.language as T
+import pytest
+import torch
+import tilelang
+import tilelang.language as T
@@
-def run_fill_symbolic(n: int):
-    import torch
-
-    x = torch.zeros(n, dtype=torch.bfloat16, device="cuda")
+def run_fill_symbolic(n: int):
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA device required for int64 fill tests")
+    elem_bytes = torch.tensor([], dtype=torch.bfloat16).element_size()
+    total_mem = torch.cuda.get_device_properties(0).total_memory
+    if n * elem_bytes > total_mem:
+        pytest.skip(
+            f"Requires ~{n * elem_bytes / (1 << 30):.1f} GiB,"
+            f" but only {total_mem / (1 << 30):.1f} GiB available"
+        )
+    device = torch.device("cuda")
+    x = torch.zeros(n, dtype=torch.bfloat16, device=device)
@@
-def run_fill_static(n: int):
-    import torch
-
-    x = torch.zeros(n, dtype=torch.bfloat16, device="cuda")
+def run_fill_static(n: int):
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA device required for int64 fill tests")
+    elem_bytes = torch.tensor([], dtype=torch.bfloat16).element_size()
+    total_mem = torch.cuda.get_device_properties(0).total_memory
+    if n * elem_bytes > total_mem:
+        pytest.skip(
+            f"Requires ~{n * elem_bytes / (1 << 30):.1f} GiB,"
+            f" but only {total_mem / (1 << 30):.1f} GiB available"
+        )
+    device = torch.device("cuda")
+    x = torch.zeros(n, dtype=torch.bfloat16, device=device)
```

This keeps the int64 coverage when the hardware can handle it and lets the suite pass everywhere else.
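The gating arithmetic in that suggestion can be isolated into a pure function so it is testable without a GPU. The helper name and message format below are assumptions mirroring the suggested diff, not the test file's actual code:

```python
def skip_reason(n_elems: int, elem_bytes: int, total_mem_bytes: int):
    """Return a skip message if the tensor cannot fit on the device, else None."""
    required = n_elems * elem_bytes
    if required > total_mem_bytes:
        return (f"Requires ~{required / (1 << 30):.1f} GiB,"
                f" but only {total_mem_bytes / (1 << 30):.1f} GiB available")
    return None


# bfloat16 is 2 bytes/element, so 2**32 elements need exactly 8 GiB.
fits = skip_reason(2**32, 2, 16 * (1 << 30))     # enough memory: no skip
too_big = skip_reason(2**32, 2, 4 * (1 << 30))   # 4 GiB device: skip
```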
🤖 Prompt for AI Agents

```
testing/python/language/test_tilelang_language_int64.py lines 22-61: The helpers
unconditionally allocate a large CUDA tensor (n=2**32), which fails on machines
without CUDA or without enough free VRAM; modify run_fill_symbolic and
run_fill_static to first check torch.cuda.is_available() and skip (or return)
when CUDA is absent, then compute required_bytes = n * element_size (bfloat16 ->
2) and query the device free memory (torch.cuda.mem_get_info or the
torch.cuda.get_device_properties free-memory API) and skip when required_bytes >
free_bytes; perform the allocation only after these checks and use pytest.skip
with a clear message so tests are gated safely.
```
…#1218)

* Fix various issues under int64_t static and dynamic shape.
* Resolve reviewed issues.
* Add unit test.
* fix

---------

Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
This PR tries to resolve various issues under `int64_t` static and dynamic shape.

TODOs:

* `c_int64`;

Summary by CodeRabbit
Bug Fixes
Tests