-
Notifications
You must be signed in to change notification settings - Fork 332
[Language] Enhance T.alloc_var for AugAssign and AnnAsign
#979
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀 |
WalkthroughIntroduces a TileLang parser override package and wires it into the language package. On import, overrides register custom handlers for AugAssign and AnnAssign to route assignments to local buffers via buffer_store when applicable; otherwise fall back to default evaluation. No public API signatures changed. Changes
Sequence Diagram(s)sequenceDiagram
autonumber
participant UserCode as import tilelang.language
participant LangInit as tilelang/language/__init__.py
participant Overrides as tilelang/language/overrides/__init__.py
participant ParserOv as overrides/parser.py
participant Dispatcher as TVMScript Parser Dispatcher
UserCode->>LangInit: import
LangInit->>Overrides: import _overrides
Overrides->>ParserOv: import parser overrides
ParserOv->>Dispatcher: register tilelang_visit_ann_assign
ParserOv->>Dispatcher: register tilelang_visit_aug_assign
note right of Dispatcher: Handlers available for parsing
sequenceDiagram
autonumber
participant Parser as TVMScript Parser
participant Handler as TileLang Assign Handler
participant VarTable as var_table/frame
participant TIR as tvm.tir API
Parser->>Handler: visit AnnAssign/AugAssign(node)
Handler->>Handler: compute spans, set Load/Store ctx
Handler->>VarTable: resolve LHS/RHS, track locals
alt LHS is local buffer target
alt subscripted or 1D with [0]
Handler->>TIR: buffer_store(local_buf, value, indices)
else non-subscripted local.var
Handler->>TIR: buffer_store(local_buf, value, ...)
end
else
Handler->>Parser: fallback eval_assign/bind
end
Handler-->>Parser: assignment bound
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Poem
Pre-merge checks and finishing touches❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🧹 Nitpick comments (2)
tilelang/language/__init__.py (1)
10-10: Drop the redundantnoqaRuff reports this
# noqa: F401as unused (RUF100). Since the override import is already safely namespaced via_overrides, we can simply remove the directive (or otherwise use the name) to keep lint clean.tilelang/language/overrides/__init__.py (1)
8-8: Remove the unusednoqaRuff flags
# noqa: F401here as redundant (RUF100). Dropping the directive keeps the side-effect import while satisfying lint.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
tilelang/language/__init__.py(1 hunks)tilelang/language/overrides/__init__.py(1 hunks)tilelang/language/overrides/parser.py(1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tilelang/language/overrides/parser.py (2)
tilelang/language/ast/ir.py (2)
buffer_store(1263-1300)LetStmt(880-908)tilelang/language/parser/parser.py (2)
bind_assign_value(114-160)visit_tvm_annotation(423-437)
🪛 Ruff (0.13.3)
tilelang/language/__init__.py
10-10: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
tilelang/language/overrides/parser.py
20-20: Possible hardcoded password assigned to argument: "token"
(S106)
68-68: Possible hardcoded password assigned to argument: "token"
(S106)
tilelang/language/overrides/__init__.py
8-8: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: build-test-metal
- GitHub Check: build-test-amd
| """Override `AnnAssign` to support writes into `local.var` buffers.""" | ||
| lhs = node.target | ||
| rhs = self.eval_expr(node.value) | ||
| ann_var = self.visit_tvm_annotation(node.annotation) | ||
| if not isinstance(ann_var, Var): | ||
| self.report_error(node.annotation, "Annotation should be Var") | ||
|
|
||
| if isinstance(lhs, doc.Name) and lhs.id in self.var_table.get(): | ||
| load_ctx = doc.Load() | ||
| store_ctx = doc.Store() | ||
| lhs.ctx = load_ctx | ||
| lhs_value = self.eval_expr(lhs) | ||
| lhs.ctx = store_ctx | ||
| if (isinstance(lhs_value, BufferLoad) and lhs_value.buffer.scope() == "local.var" and | ||
| len(lhs_value.indices) == 1 and lhs_value.indices[0] == 0): | ||
| T.buffer_store(lhs_value.buffer, rhs, indices=[0]) | ||
| return | ||
|
|
||
| self.eval_assign(target=lhs, source=ann_var, bind_value=tvm_tir_parser.bind_assign_value) | ||
| frame = T.LetStmt(rhs, var=ann_var) | ||
| frame.add_callback(partial(frame.__exit__, None, None, None)) | ||
| frame.__enter__() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Handle AnnAssign without a value
doc.AnnAssign nodes are allowed to omit value (e.g. buf: T.handle). The override now calls self.eval_expr(node.value) unconditionally, so value=None will raise before reaching eval_assign, breaking valid TileLang/TIR scripts that previously parsed. Please restore the upstream guard: only evaluate/store rhs when node.value is not None, and skip the LetStmt when there's no initializer.
- rhs = self.eval_expr(node.value)
+ rhs = self.eval_expr(node.value) if node.value is not None else None
@@
- if isinstance(lhs, doc.Name) and lhs.id in self.var_table.get():
+ if rhs is not None and isinstance(lhs, doc.Name) and lhs.id in self.var_table.get():
@@
- self.eval_assign(target=lhs, source=ann_var, bind_value=tvm_tir_parser.bind_assign_value)
- frame = T.LetStmt(rhs, var=ann_var)
- frame.add_callback(partial(frame.__exit__, None, None, None))
- frame.__enter__()
+ self.eval_assign(target=lhs, source=ann_var, bind_value=tvm_tir_parser.bind_assign_value)
+ if rhs is not None:
+ frame = T.LetStmt(rhs, var=ann_var)
+ frame.add_callback(partial(frame.__exit__, None, None, None))
+ frame.__enter__()📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| """Override `AnnAssign` to support writes into `local.var` buffers.""" | |
| lhs = node.target | |
| rhs = self.eval_expr(node.value) | |
| ann_var = self.visit_tvm_annotation(node.annotation) | |
| if not isinstance(ann_var, Var): | |
| self.report_error(node.annotation, "Annotation should be Var") | |
| if isinstance(lhs, doc.Name) and lhs.id in self.var_table.get(): | |
| load_ctx = doc.Load() | |
| store_ctx = doc.Store() | |
| lhs.ctx = load_ctx | |
| lhs_value = self.eval_expr(lhs) | |
| lhs.ctx = store_ctx | |
| if (isinstance(lhs_value, BufferLoad) and lhs_value.buffer.scope() == "local.var" and | |
| len(lhs_value.indices) == 1 and lhs_value.indices[0] == 0): | |
| T.buffer_store(lhs_value.buffer, rhs, indices=[0]) | |
| return | |
| self.eval_assign(target=lhs, source=ann_var, bind_value=tvm_tir_parser.bind_assign_value) | |
| frame = T.LetStmt(rhs, var=ann_var) | |
| frame.add_callback(partial(frame.__exit__, None, None, None)) | |
| frame.__enter__() | |
| """Override `AnnAssign` to support writes into `local.var` buffers.""" | |
| lhs = node.target | |
| rhs = self.eval_expr(node.value) if node.value is not None else None | |
| ann_var = self.visit_tvm_annotation(node.annotation) | |
| if not isinstance(ann_var, Var): | |
| self.report_error(node.annotation, "Annotation should be Var") | |
| if rhs is not None and isinstance(lhs, doc.Name) and lhs.id in self.var_table.get(): | |
| load_ctx = doc.Load() | |
| store_ctx = doc.Store() | |
| lhs.ctx = load_ctx | |
| lhs_value = self.eval_expr(lhs) | |
| lhs.ctx = store_ctx | |
| if (isinstance(lhs_value, BufferLoad) and lhs_value.buffer.scope() == "local.var" and | |
| len(lhs_value.indices) == 1 and lhs_value.indices[0] == 0): | |
| T.buffer_store(lhs_value.buffer, rhs, indices=[0]) | |
| return | |
| self.eval_assign(target=lhs, source=ann_var, bind_value=tvm_tir_parser.bind_assign_value) | |
| if rhs is not None: | |
| frame = T.LetStmt(rhs, var=ann_var) | |
| frame.add_callback(partial(frame.__exit__, None, None, None)) | |
| frame.__enter__() |
🤖 Prompt for AI Agents
In tilelang/language/overrides/parser.py around lines 70 to 91: the code
unconditionally calls self.eval_expr(node.value) and always creates a LetStmt,
which breaks valid AnnAssign nodes that omit value; change it so rhs is only
computed when node.value is not None, guard the buffer_store and LHS evaluation
that use rhs behind that same check, and only create/enter the T.LetStmt when a
value/initializer exists (if node.value is None skip evaluating rhs, skip the
buffer_store branch and do not create the LetStmt), but keep the annotation
handling and eval_assign call for the declaration path.
|
merged as local test passed :) |
This pull request introduces TileLang-specific parser overrides to support custom assignment semantics, particularly for buffer writes into
local.varscopes. The changes ensure that augmented and annotated assignments in TileLang scripts are handled correctly, extending the upstream TVMScript parser functionality. The main changes are grouped below.Parser override infrastructure:
tilelang/language/overridespackage that registers TileLang-specific parser overrides on import, ensuring custom behavior is available automatically.tilelang/language/__init__.pyto import theoverridespackage, activating the custom handlers for all TileLang users.Custom parser logic for assignments:
tilelang_visit_aug_assignandtilelang_visit_ann_assignfunctions intilelang/language/overrides/parser.pyto overrideAugAssignandAnnAssignhandling, allowing writes intolocal.varbuffers and supporting TileLang-specific assignment semantics.Summary by CodeRabbit
New Features
Bug Fixes