@@ -24,7 +24,8 @@ namespace tl {
2424
2525using namespace tir ;
2626
27- static Buffer makeBufferWithLayout (const Buffer &buffer, const Layout &layout) {
27+ static Buffer makeBufferWithLayout (const Buffer &buffer, const Layout &layout,
28+ Map<Var, Var> &var_remap) {
2829 const auto *ptr_type =
2930 TVM_TYPE_AS (buffer->data ->type_annotation , PointerTypeNode);
3031 Type new_type;
@@ -38,7 +39,12 @@ static Buffer makeBufferWithLayout(const Buffer &buffer, const Layout &layout) {
3839 if (ptr_type->storage_scope == " global" ) {
3940 new_var = buffer->data ;
4041 } else {
41- new_var = Var (buffer->data ->name_hint , new_type);
42+ if (var_remap.count (buffer->data )) {
43+ new_var = var_remap[buffer->data ];
44+ } else {
45+ new_var = Var (buffer->data ->name_hint , new_type);
46+ var_remap.Set (buffer->data , new_var);
47+ }
4248 }
4349 Array<PrimExpr> layout_shape = layout->OutputShape ();
4450 Array<PrimExpr> output_shape = layout_shape;
@@ -62,7 +68,6 @@ static Buffer makeBufferWithLayout(const Buffer &buffer, const Layout &layout) {
6268 output_shape.insert (output_shape.begin (), replicate_extent);
6369 }
6470 }
65-
6671 return Buffer (new_var, buffer->dtype , output_shape, {}, buffer->elem_offset ,
6772 buffer->name , buffer->data_alignment , buffer->offset_factor ,
6873 buffer->buffer_type );
@@ -106,7 +111,8 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
106111 .as <Map<Buffer, Layout>>()
107112 .value ();
108113 for (auto [buffer, layout] : layout_map) {
109- buffer_remap_.Set (buffer, makeBufferWithLayout (buffer, layout));
114+ buffer_remap_.Set (buffer,
115+ makeBufferWithLayout (buffer, layout, var_remap_));
110116 layout_map_.Set (buffer, layout);
111117 }
112118 }
@@ -265,21 +271,34 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
265271 if (is_ptx_) {
266272 return load;
267273 }
268-
269- if (buffer_remap_.count (load-> buffer )) {
270- auto new_indices = layout_map_[load-> buffer ]->Forward (load->indices );
274+ auto buffer = load-> buffer ;
275+ if (buffer_remap_.count (buffer)) {
276+ auto new_indices = layout_map_[buffer]->Forward (load->indices );
271277 auto new_buffer = buffer_remap_[load->buffer ];
272278 return BufferLoad (new_buffer, new_indices);
279+ } else if (var_remap_.count (buffer->data )) {
280+ auto new_buffer = Buffer (
281+ var_remap_[buffer->data ], buffer->dtype , buffer->shape ,
282+ buffer->strides , buffer->elem_offset , buffer->name ,
283+ buffer->data_alignment , buffer->offset_factor , buffer->buffer_type );
284+ return BufferLoad (new_buffer, load->indices );
273285 }
274286 return load;
275287 }
276288
277289 Stmt VisitStmt_ (const BufferStoreNode *op) final {
278290 auto store = Downcast<BufferStore>(IRMutatorWithAnalyzer::VisitStmt_ (op));
279- if (buffer_remap_.count (store->buffer )) {
280- auto new_indices = layout_map_[store->buffer ]->Forward (store->indices );
291+ auto buffer = store->buffer ;
292+ if (buffer_remap_.count (buffer)) {
293+ auto new_indices = layout_map_[buffer]->Forward (store->indices );
281294 auto new_buffer = buffer_remap_[store->buffer ];
282295 return BufferStore (new_buffer, store->value , new_indices);
296+ } else if (var_remap_.count (buffer->data )) {
297+ auto new_buffer = Buffer (
298+ var_remap_[buffer->data ], buffer->dtype , buffer->shape ,
299+ buffer->strides , buffer->elem_offset , buffer->name ,
300+ buffer->data_alignment , buffer->offset_factor , buffer->buffer_type );
301+ return BufferStore (new_buffer, store->value , store->indices );
283302 }
284303 return store;
285304 }
@@ -364,6 +383,7 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
364383 bool is_ptx_{false };
365384 // Mapping from data Var of a Buffer to Buffer, for lookup
366385 std::unordered_map<Var, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_map_;
386+ Map<Var, Var> var_remap_;
367387};
368388
369389namespace transform {
0 commit comments