Skip to content

Commit 5bd3f94

Browse files
authored
[Enhancement] Add role assignment for AllocateNode in warp specialization (#657)
- Implemented a new role assignment for `AllocateNode` in `warp_specialized_rewriter.cc`, setting the role to `kConsumer` to ensure proper handling of memory allocation scenarios. - This can avoid bug when using T.reduce(clear=False)
1 parent 8205791 commit 5bd3f94

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

src/transform/warp_specialized_rewriter.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,12 @@ class WarpSpecializedRoleMarker : public StmtVisitor {
170170
SetRole(op, GetRole(op->block));
171171
}
172172

173+
void VisitStmt_(const AllocateNode *op) final {
174+
StmtVisitor::VisitStmt_(op);
175+
Role role = Role::kConsumer;
176+
SetRole(op, role);
177+
}
178+
173179
template <class NodeType> void HandleBodyStmt(const NodeType *op) {
174180
StmtVisitor::VisitStmt_(op);
175181
SetRole(op, GetRole(op->body));

0 commit comments

Comments
 (0)