@@ -29,7 +29,7 @@ def unwrap_expr(expr) -> PrimExpr | int | float:
2929 expr = tir .BufferLoad (expr , indices = [0 ])
3030 elif isinstance (expr , (EqualOp , NotEqualOp )):
3131 expr = expr .asobject ()
32- elif isinstance (expr , tir . IntImm ) and expr .dtype == 'int32' :
32+ elif isinstance (expr , IntImm ) and expr .dtype == 'int32' :
3333 expr = expr .value
3434 return expr
3535
@@ -257,10 +257,9 @@ def bind(self, name, value, annot=BaseBuilder.empty):
257257 return res
258258
259259 def unwrap_value (self , value ):
260+ value = unwrap_expr (value )
260261 # handle bx, by = tl.Kernel(128, 128), rval is frame
261- if isinstance (value , tir .meta_var ):
262- return value .value
263- elif isinstance (value , tir .frame .IRBuilderFrame ):
262+ if isinstance (value , tir .frame .IRBuilderFrame ):
264263 return self .enter_frame (value )
265264 else :
266265 return value
@@ -295,9 +294,9 @@ def assign_slice(self, lval: Any, sl: slice, value: Any, annot=BaseBuilder.empty
295294 return super ().assign_slice (lval , sl , value )
296295
297296 def aug_assign (self , op , target , aug_value ):
298- if isinstance (target , Buffer ) and target . scope () == 'local.var' :
299- tir .buffer_store (target , eval_op (op , target , aug_value ), 0 )
300- if isinstance (target , Buffer ):
297+ if is_var (target ) :
298+ tir .buffer_store (target , eval_op (op , target [ 0 ] , aug_value ), 0 )
299+ elif isinstance (target , Buffer ):
301300 raise RuntimeError ("Augmented assignment is not supported for Buffer" )
302301 else :
303302 return super ().aug_assign (op , target , aug_value )
@@ -370,11 +369,7 @@ def rval(self, name: str, value: Any) -> Any:
370369 f"Use immutable variable `{ name } ` outside its defining region, did you forget **alloc_var**?\n "
371370 f"variable `{ name } ` is defined in frame: { frame } , current frames: { self .frames } ."
372371 )
373- if isinstance (value , tir .IntImm ):
374- return value .value
375- if isinstance (value , Buffer ) and value .scope () == 'local.var' :
376- return tir .BufferLoad (value , indices = [0 ])
377- return super ().rval (name , value )
372+ return unwrap_expr (value )
378373
379374 def arg (self , name , value ):
380375 if self .find_frame_idx (MacroFrame ) is not None :
0 commit comments