Skip to content

Commit eff7916

Browse files
committed
minor fix
1 parent d592fbf commit eff7916

File tree

3 files changed

+20
-18
lines changed

3 files changed

+20
-18
lines changed

tilelang/language/v2/ast.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -396,11 +396,18 @@ def flush_binds():
396396
def visit_Assign(self, node: ast.Assign) -> list[ast.AST]:
397397
node = self.generic_visit(node)
398398
rval = node.value
399-
stmts = []
400-
for target in reversed(node.targets):
401-
stmts.extend(self._emit_assign_target(target, rval))
402-
rval = target
403-
return stmts
399+
if len(node.targets) == 1:
400+
return self._emit_assign_target(node.targets[0], rval)
401+
else:
402+
tmp_name = self.get_tmp()
403+
tmp_store = ast.Name(tmp_name, ctx=ast.Store())
404+
tmp_load = ast.Name(tmp_name, ctx=ast.Load())
405+
ast_set_span(tmp_store, node.targets[0])
406+
ast_set_span(tmp_load, node.targets[0])
407+
stmt = self._emit_assign_target(tmp_store, rval)
408+
for target in node.targets:
409+
stmt.extend(self._emit_assign_target(target, tmp_load))
410+
return stmt
404411

405412
def visit_AugAssign(self, node: ast.AugAssign) -> list[ast.AST]:
406413
node = self.generic_visit(node)

tilelang/language/v2/builder.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def unwrap_expr(expr) -> PrimExpr | int | float:
2929
expr = tir.BufferLoad(expr, indices=[0])
3030
elif isinstance(expr, (EqualOp, NotEqualOp)):
3131
expr = expr.asobject()
32-
elif isinstance(expr, tir.IntImm) and expr.dtype == 'int32':
32+
elif isinstance(expr, IntImm) and expr.dtype == 'int32':
3333
expr = expr.value
3434
return expr
3535

@@ -257,10 +257,9 @@ def bind(self, name, value, annot=BaseBuilder.empty):
257257
return res
258258

259259
def unwrap_value(self, value):
260+
value = unwrap_expr(value)
260261
# handle bx, by = tl.Kernel(128, 128), rval is frame
261-
if isinstance(value, tir.meta_var):
262-
return value.value
263-
elif isinstance(value, tir.frame.IRBuilderFrame):
262+
if isinstance(value, tir.frame.IRBuilderFrame):
264263
return self.enter_frame(value)
265264
else:
266265
return value
@@ -295,9 +294,9 @@ def assign_slice(self, lval: Any, sl: slice, value: Any, annot=BaseBuilder.empty
295294
return super().assign_slice(lval, sl, value)
296295

297296
def aug_assign(self, op, target, aug_value):
298-
if isinstance(target, Buffer) and target.scope() == 'local.var':
299-
tir.buffer_store(target, eval_op(op, target, aug_value), 0)
300-
if isinstance(target, Buffer):
297+
if is_var(target):
298+
tir.buffer_store(target, eval_op(op, target[0], aug_value), 0)
299+
elif isinstance(target, Buffer):
301300
raise RuntimeError("Augmented assignment is not supported for Buffer")
302301
else:
303302
return super().aug_assign(op, target, aug_value)
@@ -370,11 +369,7 @@ def rval(self, name: str, value: Any) -> Any:
370369
f"Use immutable variable `{name}` outside its defining region, did you forget **alloc_var**?\n"
371370
f"variable `{name}` is defined in frame: {frame}, current frames: {self.frames}."
372371
)
373-
if isinstance(value, tir.IntImm):
374-
return value.value
375-
if isinstance(value, Buffer) and value.scope() == 'local.var':
376-
return tir.BufferLoad(value, indices=[0])
377-
return super().rval(name, value)
372+
return unwrap_expr(value)
378373

379374
def arg(self, name, value):
380375
if self.find_frame_idx(MacroFrame) is not None:

tilelang/language/v2/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def get_compiled_object(source: str | ast.AST,
100100
compiled = disk_compile(source, name)
101101
except Exception as e:
102102
source_str = source if isinstance(source, str) else ast.unparse(source)
103-
raise RuntimeError(f'Failed to compile source for {name}:\n{source_str}') from e
103+
raise RuntimeError(f'Failed to compile source for {name}, Error: {e}:\n{source_str}') from e
104104
locs = {}
105105
exec(compiled, globals, locs)
106106
return locs[name]

0 commit comments

Comments
 (0)