|
3 | 3 | * \brief Warp specialized Pipeline for cuda GPU (sm90+) |
4 | 4 | */ |
5 | 5 |
|
6 | | -#include "arith/ir_visitor_with_analyzer.h" |
7 | | -#include "tir/analysis/var_use_def_analysis.h" |
8 | | -#include <tvm/ffi/reflection/registry.h> |
9 | | -#include <tvm/tir/analysis.h> |
10 | | -#include <tvm/tir/builtin.h> |
11 | | -#include <tvm/tir/op.h> |
12 | | -#include <tvm/tir/stmt_functor.h> |
13 | | -#include <tvm/tir/transform.h> |
14 | | - |
15 | | -#include <utility> |
16 | | - |
17 | | -#include "../op/builtin.h" |
18 | | -#include "./common/collector.h" |
19 | | -#include "runtime/thread_storage_scope.h" |
20 | | -#include "tir/transforms/ir_utils.h" |
| 6 | +#include "warp_specialized_rewriter.h" |
21 | 7 |
|
22 | 8 | namespace tvm { |
23 | 9 | namespace tl { |
@@ -1284,73 +1270,6 @@ class WarpSpecializedRewriter : public StmtExprMutator { |
1284 | 1270 | bool disable_shuffle_elect_ = false; |
1285 | 1271 | }; |
1286 | 1272 |
|
1287 | | -class WarpSpecializedDetector : public IRVisitorWithAnalyzer { |
1288 | | -public: |
1289 | | - // return true means this aws will be disabled |
1290 | | - static bool Detect(const Stmt &stmt, bool skip_thread_partition = false) { |
1291 | | - WarpSpecializedDetector detector; |
1292 | | - detector.VisitStmt(stmt); |
1293 | | - if (detector.has_warp_specialization_) { |
1294 | | - LOG(WARNING) << "Auto warp specialization will be disabled because warp " |
1295 | | - "specialization is manually enabled"; |
1296 | | - return true; |
1297 | | - } |
1298 | | - if (detector.has_tma_op_ && detector.has_mbarrier_op_) { |
1299 | | - LOG(WARNING) << "Auto warp specialization will be disabled because TMA " |
1300 | | - "and mbarrier are both present"; |
1301 | | - return true; |
1302 | | - } |
1303 | | - return false; |
1304 | | - } |
1305 | | - |
1306 | | - WarpSpecializedDetector() { |
1307 | | - has_tma_op_ = false; |
1308 | | - has_mbarrier_op_ = false; |
1309 | | - has_warp_specialization_ = false; |
1310 | | - } |
1311 | | - |
1312 | | -private: |
1313 | | - void VisitStmt_(const EvaluateNode *op) final { |
1314 | | - if (const CallNode *call = op->value.as<CallNode>()) { |
1315 | | - if (call->op.same_as(create_list_of_mbarrier()) || |
1316 | | - call->op.same_as(mbarrier_wait_parity()) || |
1317 | | - call->op.same_as(builtin::ptx_arrive_barrier()) || |
1318 | | - call->op.same_as(builtin::ptx_cp_async_barrier())) { |
1319 | | - has_mbarrier_op_ = true; |
1320 | | - } |
1321 | | - } |
1322 | | - IRVisitorWithAnalyzer::VisitStmt_(op); |
1323 | | - } |
1324 | | - |
1325 | | - void VisitExpr_(const CallNode *op) final { |
1326 | | - if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col()) || |
1327 | | - op->op.same_as(set_max_nreg())) { |
1328 | | - has_tma_op_ = true; |
1329 | | - } |
1330 | | - IRVisitorWithAnalyzer::VisitExpr_(op); |
1331 | | - } |
1332 | | - |
1333 | | - void VisitStmt_(const AttrStmtNode *op) final { |
1334 | | - if (op->attr_key == "warp_specialize" && |
1335 | | - op->value.as<IntImmNode>()->value == 1) { |
1336 | | - has_warp_specialization_ = true; |
1337 | | - } |
1338 | | - if (op->attr_key == tir::attr::thread_extent) { |
1339 | | - IterVar iv = Downcast<IterVar>(op->node); |
1340 | | - if (iv->thread_tag == "threadIdx.x") { |
1341 | | - ICHECK(iv->dom->extent.as<IntImmNode>()); |
1342 | | - thread_var_ = iv; |
1343 | | - } |
1344 | | - } |
1345 | | - IRVisitorWithAnalyzer::VisitStmt_(op); |
1346 | | - } |
1347 | | - |
1348 | | - bool has_tma_op_{false}; |
1349 | | - IterVar thread_var_; |
1350 | | - bool has_mbarrier_op_{false}; |
1351 | | - bool has_warp_specialization_{false}; |
1352 | | -}; |
1353 | | - |
1354 | 1273 | using namespace tir::transform; |
1355 | 1274 |
|
1356 | 1275 | tvm::transform::Pass WarpSpecialized() { |
|
0 commit comments