diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/include/triton/Dialect/TritonGPU/Transforms/Passes.td index dec9271f3515..9020bf8d3994 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Passes.td +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.td @@ -23,6 +23,38 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> { ]; } +def TritonGPUTestPipelineAssignLatencies : Pass<"tritongpu-test-pipeline-assign-latencies", "mlir::ModuleOp"> { + let summary = "test assigning latencies to interesting ops ahead of pipelining"; + + let description = [{ + This is a test pass that tests `assignLatencies` method of `TritonGPULoopScheduling`. + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect", + "mlir::scf::SCFDialect", + "mlir::arith::ArithDialect"]; + + let options = [ + Option<"numStages", "num-stages", + "int32_t", /*default*/"3", + "number of pipeline stages"> + ]; +} + +def TritonGPUTestPipelineScheduleLoop : Pass<"tritongpu-test-pipeline-schedule-loop", "mlir::ModuleOp"> { + let summary = "test scheduling a loop for software pipelining"; + + let description = [{ + This is a test pass that tests `scheduleLoop` method of `TritonGPULoopScheduling`. + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect", + "mlir::scf::SCFDialect", + "mlir::arith::ArithDialect"]; +} + def TritonGPUF32DotTC : Pass<"tritongpu-F32DotTC", "mlir::ModuleOp"> { let summary = "3xTF32 trick"; diff --git a/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h b/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h index 933ce6024872..cdf22d15d499 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h +++ b/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h @@ -11,6 +11,9 @@ static const char *kNumStagesAttrName = "tt.num_stages"; static const char *kLoopStageAttrName = "loop.stage"; static const char *kLoopClusterAttrName = "loop.cluster"; +bool loopHasDistGreaterThanOne(scf::ForOp forOp); +bool isOuterLoop(scf::ForOp forOp); + /// Function to mask operations during scheduling. Operation *predicateOp(RewriterBase &rewriter, Operation *op, Value pred); diff --git a/include/triton/Dialect/TritonGPU/Transforms/Schedule.h b/include/triton/Dialect/TritonGPU/Transforms/Schedule.h index ac28742b4ada..916c9b252267 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Schedule.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Schedule.h @@ -11,6 +11,18 @@ namespace mlir { namespace triton { +namespace gpu { + +/// Discover operations that should become async and assign latencies to them +/// based on the numStages value provided by the user. +DenseMap assignLatencies(ModuleOp forOp, int numStages); + +/// Schedule the loop based on the latencies assigned to the operations. +void scheduleLoop(scf::ForOp forOp, + const DenseMap &opLatency); + +}; // namespace gpu + /// This fill out the pipelining options including schedule and annotations /// for wait ops. This also does pre-processing by converting some of the /// loads into async loads so that the IR is ready to be pipelined. @@ -108,8 +120,7 @@ class CoarseSchedule { // Add dependencies of anchor ops to the coarse schedule. Schedule them to // the same stage and ordering cluster as the anchor op. 
-void scheduleDependencies(scf::ForOp forOp, CoarseSchedule &schedule, - int numStages); +void scheduleDependencies(scf::ForOp forOp, CoarseSchedule &schedule); } // namespace triton } // namespace mlir diff --git a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt index 740014b77948..da176b0fd1a8 100644 --- a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt +++ b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt @@ -8,9 +8,12 @@ add_triton_library(TritonGPUTransforms OptimizeAccumulatorInit.cpp OptimizeDotOperands.cpp OptimizeThreadLocality.cpp + Pipeliner/AssignLatencies.cpp Pipeliner/MatmulLoopPipeline.cpp Pipeliner/OuterLoopPipeline.cpp Pipeliner/PipelineExpander.cpp + Pipeliner/TestPipelineAssignLatencies.cpp + Pipeliner/TestPipelineScheduleLoop.cpp Pipeliner/SoftwarePipeliner.cpp Pipeliner/TMAStoresPipeline.cpp Pipeliner/PipeliningUtility.cpp diff --git a/lib/Dialect/TritonGPU/Transforms/LoopScheduling.cpp b/lib/Dialect/TritonGPU/Transforms/LoopScheduling.cpp index fea8540e78d7..e15b43960031 100644 --- a/lib/Dialect/TritonGPU/Transforms/LoopScheduling.cpp +++ b/lib/Dialect/TritonGPU/Transforms/LoopScheduling.cpp @@ -1,28 +1,11 @@ -#include "mlir/Analysis/SliceAnalysis.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/Dominance.h" -#include "mlir/IR/IRMapping.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/Verifier.h" -#include "mlir/Interfaces/InferTypeOpInterface.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "mlir/Transforms/Passes.h" -#include "mlir/Transforms/RegionUtils.h" #include "triton/Analysis/AxisInfo.h" -#include "triton/Analysis/Utility.h" -#include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" #include "triton/Dialect/TritonGPU/Transforms/Schedule.h" -#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "llvm/Support/Debug.h" #define DEBUG_TYPE "triton-loop-schedule" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") @@ -35,310 +18,108 @@ namespace mlir { namespace triton { namespace gpu { -// Create a map from load ops to their indirection level and the -// final use of the load op (another load op, or a dot op). -// Indirection level is "0" for the load op directly used by the dot op, -// "1" for the load op used by the load op used by the dot op, and so on. -static llvm::SmallVector> -loadOpsToIndirectionLevelAndUse(scf::ForOp forOp) { - llvm::SmallVector> - loadOpToIndLevelAndUse; - DenseSet seen; - - std::function dfs = - [&](Operation *op, int distance, Operation *use) { - if (!seen.insert(op).second) - return; - if (isa(op)) { - // TODO: What if there are multiple uses at different distances? - loadOpToIndLevelAndUse.push_back(std::make_tuple(op, distance, use)); - use = op; - distance++; - } - for (Value operand : op->getOperands()) { - if (op->hasTrait()) { - // Heuristic: only pipeline A and B operands of the dot op. 
- if (operand == op->getOperand(2)) - continue; - } - Value v = operand; - Operation *defOp = v.getDefiningOp(); - if (defOp && defOp->getBlock() == op->getBlock()) { - dfs(defOp, distance, use); - } - } - }; +#define GEN_PASS_DEF_TRITONGPULOOPSCHEDULING +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" - for (Operation &op : forOp.getBody()->without_terminator()) { - if (!op.hasTrait()) - continue; - seen.clear(); - dfs(&op, 0, &op); - } +namespace { - // If the loop has numStages attribute, also consider pipelining other loads - // that are not directly used by dot ops. - if (forOp->hasAttr(tt::kNumStagesAttrName)) { - for (Operation &op : forOp.getBody()->without_terminator()) { - if (!isa(op)) - dfs(&op, 0, &op); - } +bool hasLatenciesAssigned(scf::ForOp forOp, + const DenseMap &opLatency) { + for (auto &op : forOp.getBody()->without_terminator()) { + if (opLatency.count(&op)) + return true; } - - return loadOpToIndLevelAndUse; + return false; } -static bool hasSharedEncodingHelper(Operation *loadOp) { - // If the load is used by a LocalAllocOp, use the same encoding as the allocs. - // If the allocs don't all have the same encoding, bail. - if (llvm::any_of(loadOp->getUsers(), [&](Operation *user) { - return isa(user); - })) { - ttg::SharedEncodingAttr localAllocEnc; - for (auto user : loadOp->getUsers()) { - auto localAlloc = dyn_cast(user); - if (!localAlloc) - continue; - auto enc = mlir::cast( - localAlloc.getType().getEncoding()); - if (!localAllocEnc) { - localAllocEnc = enc; - } - if (enc != localAllocEnc) - return false; - } - return true; +CoarseSchedule scheduleKeyOps(scf::ForOp forOp, + const DenseMap &opLatency) { + llvm::MapVector opToStage; + // Find terminator for later reference + auto terminator = cast(forOp.getBody()->getTerminator()); + // Determine all operations that have a non-zero latency + SmallVector latOps; + for (auto &op : forOp.getBody()->without_terminator()) { + if (opLatency.count(&op)) + latOps.push_back(&op); } - return true; -} - -// Check to see if loads can be pipelined. -static llvm::DenseSet -filterPipelinedLoad(llvm::SmallVector> - &loadOpToIndLevelAndUse, - tt::ModuleAxisInfoAnalysis &axisInfoAnalysis) { - llvm::DenseSet loadsToPipeline; - for (auto &[op, dist_, use] : loadOpToIndLevelAndUse) { - if (loadsToPipeline.count(op)) - // TODO pawel: err, we'd need to verify that the distance is the same - continue; - - if (auto loadOp = dyn_cast(op)) { - assert(!isLoadFromTensorPtr(loadOp) && - "Block ptr should have been lowered before this pass."); - auto ptr = loadOp.getPtr(); - unsigned vec = axisInfoAnalysis.getPtrContiguity(ptr); - if (auto mask = loadOp.getMask()) - vec = std::min(vec, axisInfoAnalysis.getMaskAlignment(mask)); - - auto tensorTy = dyn_cast(ptr.getType()); - if (!tensorTy) - continue; - auto ty = - cast(tensorTy.getElementType()).getPointeeType(); - unsigned width = vec * ty.getIntOrFloatBitWidth(); - - // We do not pipeline all loads for the following reasons: - // 1. On nvidia GPUs, cp.async's cp-size can only be 4, 8, or 16. - // 2. It's likely that pipling small loads won't offer much performance - // improvement and may even hurt performance by increasing register - // pressure. 
- LDBG("Load " << *loadOp << " has width " << width); - if (width < 32) - continue; - } - - bool hasSharedEncoding = false; - if (use->hasTrait()) { - auto mmaLoadType = getMMALoadType(op); - auto dot = dyn_cast(use); - auto warpGroupDot = dyn_cast(use); - bool isMMAv3Shared = mmaLoadType == MMALoadType::SharedV3; - bool isMMAv3Registers = - (mmaLoadType == MMALoadType::Registers) && warpGroupDot; - - if (isMMAv3Shared) { - hasSharedEncoding = true; - } else if (isa(op)) { - hasSharedEncoding = true; - } else if (isMMAv3Registers || dot) { - // FIXME: if we have a better solution in handling incompatible shared - // encoding, we can simplify the logic here by checking if all users are - // dot encoding. Fow now, getSharedEncIfAllUsersAreDotEnc will be used - // during both scheduling and lowering. - bool incompatible = false; - auto sharedEncoding = - getSharedEncIfAllUsersAreDotEnc(op->getResult(0), incompatible) - .value_or(nullptr); - hasSharedEncoding = sharedEncoding != nullptr; - // If we can't agree on a shared encoding skip pipelinig the load. - if (incompatible) - continue; - } - } else if (auto loadOp = dyn_cast(use)) { - // The use of this loadOp is another loadOp. If the use is not in the - // loadsToPipeline already, it means that the use is not valid for - // pipelining for some reason. We should skip this loadOp, too. Note that - // we have an assumption that distAndUse.second (i.e. the use of this - // loadOp) has already be processed in a previous loop iteration. This - // assumption is held by how loadOpsToIndirectionLevelAndUse recursively - // collects loadOpToIndLevelAndUse using DFS. - if (loadsToPipeline.count(loadOp) == 0) { + // If no latency ops, nothing to schedule + if (latOps.empty()) + return CoarseSchedule(0); + + // Compute the longest path to the yield for each operation reachable + // from any latency operation. + DenseMap distance; + std::function computeDistance = [&](Operation *op) -> int { + auto it = distance.find(op); + if (it != distance.end()) + return it->second; + // Compute max distance among all users that are inside the loop body + int maxDist = -1; + for (Operation *user : op->getUsers()) { + // Only consider users inside the same block and not the terminator + Operation *inBlockUser = forOp.getBody()->findAncestorOpInBlock(*user); + if (!inBlockUser || inBlockUser == terminator) continue; - } - } - - // If we still don't have a shared encoding, try a "generic" shared - // encoding. - if (!hasSharedEncoding && !isa(use)) - hasSharedEncoding = hasSharedEncodingHelper(op); - - // If that still didn't work, bail on pipelining this load. - if (!hasSharedEncoding) { - continue; - } - loadsToPipeline.insert(op); - } - return loadsToPipeline; -} - -static void scheduleLoads(scf::ForOp forOp, tt::CoarseSchedule &schedule, - DenseSet &rootUsers, int numStages) { - - ModuleOp moduleOp = forOp->getParentOfType(); - tt::ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp); - - // Get all loads that are (transitively) used by dot ops and their distance - // to the dot op. 
- llvm::SmallVector> - loadOpToIndLevelAndUse = loadOpsToIndirectionLevelAndUse(forOp); - LLVM_DEBUG({ - LDBG("Found " << loadOpToIndLevelAndUse.size() << " loads to pipeline:"); - for (const auto &[l, i, u] : loadOpToIndLevelAndUse) { - LDBG(" - load: " << *l); - LDBG(" at indirection level: " << i); - LDBG(" used by op: " << *u); + int distUser = computeDistance(inBlockUser); + if (distUser > maxDist) + maxDist = distUser; } - }); - if (loadOpToIndLevelAndUse.empty()) - return; - - // We assume loads with different dist are assigned to different stages. - // If numStages is 2, we will have no stage available for indirect loads - // with dist >= 1. In general, when dist is equal to numStages - 1, we - // should not pipeline it. - auto it = llvm::remove_if(loadOpToIndLevelAndUse, [=](auto op) { - return std::get<1>(op) >= numStages - 1; - }); - loadOpToIndLevelAndUse.erase(it, loadOpToIndLevelAndUse.end()); - - // Check which loads are good for pipelining. - llvm::DenseSet loadsToPipeline = - filterPipelinedLoad(loadOpToIndLevelAndUse, axisInfoAnalysis); - if (loadsToPipeline.empty()) - return; + int lat = 0; + if (opLatency.count(op)) + lat = opLatency.lookup(op); + // If an op has no users (maxDist == -1) but has latency, we include its + // latency otherwise it contributes 0 to the distance. + int d = lat + (maxDist < 0 ? 0 : maxDist); + distance[op] = d; + return d; + }; - // Calculate the stage distance between applicable loads. - int maxIndirectionLevel = -1; - for (auto [loadOp, dist, use] : loadOpToIndLevelAndUse) { - if (loadsToPipeline.count(loadOp) == 0) - continue; - maxIndirectionLevel = std::max(maxIndirectionLevel, dist); + // Compute distances for all latency-starting ops + int maxDistance = 0; + for (Operation *latOp : latOps) { + int d = computeDistance(latOp); + if (d > maxDistance) + maxDistance = d; } - unsigned stagesBetweenLoads = - ceil(numStages - 2, maxIndirectionLevel + 1); - tt::CoarseSchedule::Cluster rootUsersCluster = schedule.clusters.newAtFront(); - // Put the root uses of the loads in the last stage. - for (auto &[loadOp, dist, use] : loadOpToIndLevelAndUse) { - if (loadsToPipeline.count(loadOp) == 0) - continue; - // Non-LoadOp(s) are the root uses of all LoadOp(s) and should be - // always present in the opInfo - if (!isa(use)) { - schedule.insert(use, numStages - 1, rootUsersCluster); - rootUsers.insert(use); - } + // Assign stage to each op reachable from a latency op + for (auto &kv : distance) { + Operation *op = kv.first; + int dist = kv.second; + // We only schedule ops that are downstream of a latency op + // (had a non-negative distance due to a latency op). + if (dist >= 0) + opToStage[op] = maxDistance - dist; } - SmallVector loadsClusters; - for (int i = 0; i < maxIndirectionLevel + 1; i++) { - loadsClusters.push_back(schedule.clusters.newAtBack()); + auto stages = llvm::make_second_range(opToStage); + int maxStage = *std::max_element(stages.begin(), stages.end()); + CoarseSchedule schedule(maxStage + 1); + SmallVector clusters(maxStage + 1); + for (int i = 0; i <= maxStage; i++) { + clusters[i] = schedule.clusters.newAtBack(); } - // Assign stages to the loads. - for (auto [loadOp, indLevel, _] : loadOpToIndLevelAndUse) { - if (loadsToPipeline.count(loadOp) == 0) + CoarseSchedule::Cluster epilogue = schedule.clusters.newAtBack(); + // Assign ops to the clusters in reverse-stage order; + // ops with higher stage numbers are assigned first. This way we will + // end up with roughly reverse program order in the clusters. 
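+  // For example (illustrative): with a single latency-2 load feeding a dot,
+  // distance(load) = 2 and distance(dot) = 0, so maxDistance = 2; the load
+  // then lands in stage 0 and the dot in stage 2.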
+ for (auto [op, stage] : opToStage) { + if (isa(op)) { + schedule.insert(op, stage, epilogue); continue; - int stage = (maxIndirectionLevel - indLevel) * stagesBetweenLoads; - schedule.insert(loadOp, stage, loadsClusters[indLevel]); - } -} - -// Schedule the prologue and epilogue `if` ops in the loop, pushing them as -// close to the loop boundaries as possible. Return the cluster after the -// prologue (or the beginning of the loop if there is no prologue). -static tt::CoarseSchedule::Cluster -schedulePrologueAndEpilogue(scf::ForOp forOp, tt::CoarseSchedule &schedule, - DenseSet &rootUsers, int numStages) { - tt::CoarseSchedule::Cluster afterPrologue = schedule.clusters.begin(); - - // Look for the IfOp that is in the backward slice any of the currently - // scheduled ops and put it at the beginning of the loop. - DenseMap ifsToStage; - // Go stage by stage. - for (int stage = 0; stage < numStages; stage++) { - for (auto [op, stage_, cluster] : schedule.getOpsInOrder(forOp)) { - if (stage_ != stage) - continue; - SetVector backwardSlice; - BackwardSliceOptions opt; - opt.omitBlockArguments = true; - getBackwardSlice((Operation *)op, &backwardSlice, opt); - - for (auto op : backwardSlice) { - if (auto ifOp = dyn_cast(op)) { - ifsToStage.insert({ifOp, stage}); - } - } } - } - tt::CoarseSchedule::Cluster prologueCluster = schedule.clusters.newAtFront(); - for (auto [ifOp, stage] : ifsToStage) { - schedule.insert(ifOp, stage, prologueCluster); + schedule.insert(op, stage, clusters[maxStage - stage]); } - // Look for the IfOp that is in the forward slice of the root users and put it - // at the end of the loop. - tt::CoarseSchedule::Cluster epilogueCluster = schedule.clusters.newAtBack(); - for (auto rootUser : rootUsers) { - SetVector forwardSlice; - getForwardSlice(rootUser, &forwardSlice); - - int stage = schedule[rootUser].first; - for (auto op : forwardSlice) { - scf::IfOp ifOp = dyn_cast(op); - if (ifOp == nullptr) { - // check if the op is in the body of an if op that's part of the loop - auto parentOp = op->getParentOp(); - if (parentOp != nullptr && - parentOp->getParentOp() == forOp.getOperation()) { - ifOp = dyn_cast(parentOp); - } - } - if (ifOp) { - schedule.insertIfAbsent(ifOp, stage, - epilogueCluster); // after prefetch extracts - } - } - } - return afterPrologue; + return schedule; } // Find dependencies with distance of 1. They will go to the next stage, // but in the cluster before the current op. -static void scheduleDistanceOneDependencies(scf::ForOp forOp, - tt::CoarseSchedule &schedule, - int numStages) { +void scheduleDistanceOneDependencies(scf::ForOp forOp, + CoarseSchedule &schedule) { + int numStages = schedule.numStages; auto getNestedOperands = [](Operation *op) -> SmallVector { SmallVector operands; op->walk([&](Operation *nestedOp) { @@ -351,8 +132,7 @@ static void scheduleDistanceOneDependencies(scf::ForOp forOp, }; // Mapping from the cluster to the cluster before it. - DenseMap - dist1Cluster; + DenseMap dist1Cluster; for (auto &op : forOp.getBody()->without_terminator()) { if (schedule.count(&op) == 0) continue; @@ -387,14 +167,61 @@ static void scheduleDistanceOneDependencies(scf::ForOp forOp, } } -static void -scheduleRemainingToLastStage(scf::ForOp forOp, tt::CoarseSchedule &schedule, - tt::CoarseSchedule::Cluster afterPrologue, - int numStages) { +// Schedule the prologue and epilogue `if` ops in the loop, pushing them as +// close to the loop boundaries as possible. 
Return the cluster after the +// prologue (or the beginning of the loop if there is no prologue). +CoarseSchedule::Cluster schedulePrologueAndEpilogue(scf::ForOp forOp, + CoarseSchedule &schedule) { + int numStages = schedule.numStages; + CoarseSchedule::Cluster afterPrologue = schedule.clusters.begin(); + + // Look for the IfOp that is in the backward slice of any of the currently + // scheduled ops and put it at the beginning of the loop. + DenseMap<scf::IfOp, int> ifsToStage; + // Go stage by stage. + for (int stage = 0; stage < numStages; stage++) { + for (auto [op, stage_, cluster] : schedule.getOpsInOrder(forOp)) { + if (stage_ != stage) + continue; + SetVector<Operation *> backwardSlice; + BackwardSliceOptions opt; + opt.omitBlockArguments = true; + getBackwardSlice((Operation *)op, &backwardSlice, opt); + + for (auto op : backwardSlice) { + if (auto ifOp = dyn_cast<scf::IfOp>(op)) { + ifsToStage.insert({ifOp, stage}); + } + } + } + } + if (!ifsToStage.empty()) { + CoarseSchedule::Cluster prologueCluster = schedule.clusters.newAtFront(); + for (auto [ifOp, stage] : ifsToStage) { + schedule.insert(ifOp, stage, prologueCluster); + } + } + + // Other IfOps should be pushed to the end. + CoarseSchedule::Cluster epilogueCluster = schedule.clusters.newAtBack(); + for (auto &op : forOp.getBody()->without_terminator()) { + if (auto ifOp = dyn_cast<scf::IfOp>(op)) { + if (ifsToStage.count(ifOp) == 0) { + schedule.insertIfAbsent(ifOp, numStages - 1, + epilogueCluster); // after prefetch extracts + } + } + } + return afterPrologue; +} + +void scheduleRemainingToLastStage(scf::ForOp forOp, CoarseSchedule &schedule, + CoarseSchedule::Cluster afterPrologue) { + int numStages = schedule.numStages; // Assign the rest of the ops to the last stage. // Take care of the ordering of the ops - uses cannot be scheduled to the // cluster before the definition. - DenseMap<Operation *, tt::CoarseSchedule::Cluster> opToCluster; + DenseMap<Operation *, CoarseSchedule::Cluster> opToCluster; for (auto &op : forOp.getBody()->without_terminator()) { if (schedule.count(&op) == 0) { opToCluster[&op] = afterPrologue; @@ -412,8 +239,8 @@ scheduleRemainingToLastStage(scf::ForOp forOp, tt::CoarseSchedule &schedule, Operation *op = queue.pop_back_val(); for (auto user : op->getUsers()) { if (opToCluster.count(user)) { - tt::CoarseSchedule::Cluster userCluster = opToCluster[user]; - tt::CoarseSchedule::Cluster opCluster; + CoarseSchedule::Cluster userCluster = opToCluster[user]; + CoarseSchedule::Cluster opCluster; if (schedule.count(op)) opCluster = schedule[op].second; else @@ -430,8 +257,46 @@ scheduleRemainingToLastStage(scf::ForOp forOp, tt::CoarseSchedule &schedule, } } -#define GEN_PASS_DEF_TRITONGPULOOPSCHEDULING -#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" +}; // namespace + +void scheduleLoop(scf::ForOp forOp, + const DenseMap<Operation *, int> &opLatency) { + if (!hasLatenciesAssigned(forOp, opLatency)) + return; + // Based on the latencies, schedule the key ops to the stages.
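+  // The later steps below refine this schedule: prologue/epilogue ifs,
+  // dependencies, distance-1 dependencies, and finally all remaining ops are
+  // pushed to the last stage.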
+ CoarseSchedule schedule = scheduleKeyOps(forOp, opLatency); + if (schedule.empty()) + return; + LLVM_DEBUG({ + LDBG("Initial coarse schedule:"); + schedule.dump(); + }); + // Schedule the dependencies + CoarseSchedule::Cluster afterPrologue = + schedulePrologueAndEpilogue(forOp, schedule); + LLVM_DEBUG({ + LDBG("Coarse schedule with prologue and epilogue:"); + schedule.dump(); + }); + scheduleDependencies(forOp, schedule); + LLVM_DEBUG({ + LDBG("Coarse schedule with dependencies:"); + schedule.dump(); + }); + scheduleDistanceOneDependencies(forOp, schedule); + LLVM_DEBUG({ + LDBG("Coarse schedule with dist 1:"); + schedule.dump(); + }); + scheduleRemainingToLastStage(forOp, schedule, afterPrologue); + LLVM_DEBUG({ + LDBG("Final coarse schedule:"); + schedule.dump(); + }); + + // Write the schedule to the IR + schedule.serialize(forOp); +} class TritonGPULoopSchedulingPass : public impl::TritonGPULoopSchedulingBase<TritonGPULoopSchedulingPass> { @@ -450,61 +315,28 @@ class TritonGPULoopSchedulingPass } void runOnOperation() override { + // Go over the interesting ops and assign latencies (based on the + // numStages) to them, trying to populate the allowed stages. This + // step will at some point be extracted to a separate pass that will be run + // only for loops missing the latency information. + DenseMap<Operation *, int> opLatency = + assignLatencies(getOperation(), numStages); + // numStages should not be used below this point. We should know everything + // based on the assigned stages. + + // Schedule the loops SmallVector<scf::ForOp> loops; getOperation()->walk([&](scf::ForOp forOp) { // Bail out for loops with num_stage <= 1. if (getNumStagesOrDefault(forOp) > 1) loops.push_back(forOp); }); - if (loops.empty()) return; - for (scf::ForOp forOp : loops) { - int loopNumStages = getNumStagesOrDefault(forOp); - DenseSet<Operation *> rootUsers; - tt::CoarseSchedule coarseSchedule(loopNumStages); - scheduleLoads(forOp, coarseSchedule, rootUsers, loopNumStages); - if (coarseSchedule.opToStageAndCluster.size() == 0) - continue; - tt::CoarseSchedule::Cluster afterPrologue = schedulePrologueAndEpilogue( - forOp, coarseSchedule, rootUsers, loopNumStages); - - scheduleDependencies(forOp, coarseSchedule, loopNumStages); - LLVM_DEBUG({ - LDBG("Coarse schedule with dependencies:"); - coarseSchedule.dump(); - }); - scheduleDistanceOneDependencies(forOp, coarseSchedule, loopNumStages); - LLVM_DEBUG({ - LDBG("Coarse schedule with dist 1:"); - coarseSchedule.dump(); - }); - - LDBG("afterPrologue = " << *afterPrologue); - scheduleRemainingToLastStage(forOp, coarseSchedule, afterPrologue, - loopNumStages); - LLVM_DEBUG({ - LDBG("Final coarse schedule:"); - coarseSchedule.dump(); - }); - - // Go through schedule and assign (stage, cluster).
- // shift so afterPrologue will be at clusterId 0 - auto ctx = forOp.getContext(); - for (auto [op, stage_, cluster] : coarseSchedule.getOpsInOrder(forOp)) { - op->setAttr(mlir::triton::kLoopStageAttrName, - IntegerAttr::get(IntegerType::get(ctx, 32), stage_)); - op->setAttr(mlir::triton::kLoopClusterAttrName, - IntegerAttr::get(IntegerType::get(ctx, 32), - *cluster /*- *afterPrologue*/)); - LLVM_DEBUG({ - LDBG("set stage " << stage_ << " cluster " << (*cluster)); - op->dump(); - }); - } + for (auto forOp : loops) { + scheduleLoop(forOp, opLatency); } - return; } }; diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp new file mode 100644 index 000000000000..f274363730c4 --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp @@ -0,0 +1,252 @@ +#include "triton/Analysis/AxisInfo.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "triton-pipeline-schedule" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = mlir::triton::nvidia_gpu; + +namespace mlir { +namespace triton { +namespace gpu { + +namespace { + +// Return true if the preconditions for pipelining the loop are met. +bool preCondition(scf::ForOp forOp) { + // Skip loop with distance > 1 for now. + // TODO: relax the constraint in the expander. + if (loopHasDistGreaterThanOne(forOp)) + return false; + // Don't pipeline outer loops. + if (isOuterLoop(forOp)) + return false; + return true; +} + +bool canHaveSharedEncoding(tt::LoadOp op) { + // If used by an user with DotOp encoding, all the uses must be compatible. + bool incompatible = false; + getSharedEncIfAllUsersAreDotEnc(op.getResult(), incompatible); + if (incompatible) + return false; + // If the load is used by a LocalAllocOp, all the users need to have the same + // encoding. + if (llvm::any_of(op->getUsers(), [&](Operation *user) { + return isa(user); + })) { + ttg::SharedEncodingAttr localAllocEnc; + for (auto user : op->getUsers()) { + auto localAlloc = dyn_cast(user); + if (!localAlloc) + continue; + auto enc = mlir::cast( + localAlloc.getType().getEncoding()); + if (!localAllocEnc) { + localAllocEnc = enc; + } + if (enc != localAllocEnc) + return false; + } + return true; + } + return true; +} + +bool isSmallLoad(tt::LoadOp loadOp, + tt::ModuleAxisInfoAnalysis &axisInfoAnalysis) { + assert(!isLoadFromTensorPtr(loadOp) && + "Block ptr should have been lowered before this pass."); + auto ptr = loadOp.getPtr(); + unsigned vec = axisInfoAnalysis.getPtrContiguity(ptr); + if (auto mask = loadOp.getMask()) + vec = std::min(vec, axisInfoAnalysis.getMaskAlignment(mask)); + + auto tensorTy = dyn_cast(ptr.getType()); + if (!tensorTy) + return true; + auto ty = cast(tensorTy.getElementType()).getPointeeType(); + unsigned width = vec * ty.getIntOrFloatBitWidth(); + + // We do not pipeline all loads for the following reasons: + // 1. On nvidia GPUs, cp.async's cp-size can only be 4, 8, or 16. + // 2. 
It's likely that pipling small loads won't offer much performance + // improvement and may even hurt performance by increasing register + // pressure. + LDBG("Load " << *loadOp << " has width " << width); + return width < 32; +} + +bool isPipeliningBeneficial(Operation *op, Operation *finalUser, + tt::ModuleAxisInfoAnalysis &axisInfoAnalysis) { + if (auto loadOp = dyn_cast(op)) { + if (isSmallLoad(loadOp, axisInfoAnalysis)) { + LDBG("Load " << *loadOp << " is too small for pipelining"); + return false; + } + } + if (isa(op)) + return true; + if (isa(finalUser) && + getMMALoadType(op) == MMALoadType::DoNotPipeline) { + LDBG("Load " << *op << " used by WarpGroupDotOp with incompatible layout"); + return false; + } + if (!canHaveSharedEncoding(cast(op))) { + LDBG("Load " << *op << " cannot have shared encoding"); + return false; + } + + return true; +} + +// Create a map from load ops to their indirection level and the +// final use of the load op (another load op, or a dot op). +// Indirection level is "0" for the load op directly used by the dot op, +// "1" for the load op used by the load op used by the dot op, and so on. +llvm::MapVector +loadOpsToIndirectionLevel(scf::ForOp forOp, bool pipelineWithoutDot, + tt::ModuleAxisInfoAnalysis &axisInfoAnalysis) { + llvm::MapVector loadOpToIndLevel; + DenseSet seen; + DenseSet excluded; + + std::function dfs = + [&](Operation *op, Operation *finalUser, int distance) { + if (!seen.insert(op).second || excluded.count(op)) + return; + if (isa(op)) { + if (!isPipeliningBeneficial(op, finalUser, axisInfoAnalysis)) + return; + if (loadOpToIndLevel.count(op)) { + int level = loadOpToIndLevel[op]; + if (level != distance) { + // If we have multiple uses at different distances, we don't know + // which one to pick. + LDBG("Load " << *op + << " has multiple uses at different distances:" + << level << " and " << distance); + loadOpToIndLevel.erase(op); + excluded.insert(op); + return; + } + } else { + LDBG("Load " << *op << " considered for pipelining with distance " + << distance); + loadOpToIndLevel[op] = distance; + } + finalUser = op; + distance++; + } + for (Value operand : op->getOperands()) { + if (op->hasTrait()) { + // Heuristic: only pipeline A and B operands of the dot op. + if (operand == op->getOperand(2)) + continue; + } + Value v = operand; + Operation *defOp = v.getDefiningOp(); + if (defOp && defOp->getBlock() == op->getBlock()) { + dfs(defOp, finalUser, distance); + } + } + }; + + bool seenDot = false; + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!op.hasTrait()) + continue; + seenDot = true; + seen.clear(); + dfs(&op, &op, 0); + } + + // If the loop has numStages attribute, also consider pipelining other loads + // that are not directly used by dot ops. + if (pipelineWithoutDot && !seenDot) { + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!isa(op)) + dfs(&op, &op, 0); + } + } + + return loadOpToIndLevel; +} + +} // namespace + +// Look for load ops that directly or indirectly feed into dot ops. Based +// on the requested number of stages assign the latencies in a way that +// cover all the stages with the sum of latencies in the chain from the first +// load to the final dot op. +DenseMap assignLatencies(ModuleOp moduleOp, + int defaultNumStages) { + auto getNumStagesOrDefault = [defaultNumStages](scf::ForOp forOp) -> int { + // Use the attribute attached to the loop if it exists otherwise use the + // global control. 
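+    // For example, a loop annotated with {tt.num_stages = 4 : i32} is
+    // assigned latencies for 4 stages even when the pass-level default is 3.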
+ if (!forOp->hasAttr(mlir::triton::kNumStagesAttrName)) + return defaultNumStages; + return mlir::cast( + forOp->getAttr(mlir::triton::kNumStagesAttrName)) + .getInt(); + }; + + SmallVector loops; + moduleOp->walk([&](scf::ForOp forOp) { + // Bail out for loops with num_stage <= 1. + if (preCondition(forOp) && getNumStagesOrDefault(forOp) > 1) + loops.push_back(forOp); + }); + if (loops.empty()) + return DenseMap(); + + DenseMap opLatency; + for (auto forOp : loops) { + int numStages = getNumStagesOrDefault(forOp); + bool pipelineWithoutDot = forOp->hasAttr(mlir::triton::kNumStagesAttrName); + ModuleOp moduleOp = forOp->getParentOfType(); + tt::ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp); + + llvm::MapVector loadOpToIndLevel = + loadOpsToIndirectionLevel(forOp, pipelineWithoutDot, axisInfoAnalysis); + if (loadOpToIndLevel.empty()) + continue; + + // We assume loads with different dist are assigned to different stages. + // If numStages is 2, we will have no stage available for indirect loads + // with dist >= 1. In general, when dist is equal to numStages - 1, we + // should not pipeline it. + for (auto iter = loadOpToIndLevel.begin(); + iter != loadOpToIndLevel.end();) { + if (iter->second >= numStages - 1) + iter = loadOpToIndLevel.erase(iter); + else + ++iter; + } + + // Calculate the stage distance between applicable loads. + auto vals = llvm::make_second_range(loadOpToIndLevel); + int maxIndirectionLevel = + vals.empty() ? 0 : *std::max_element(vals.begin(), vals.end()); + unsigned loadLatency = (numStages - 1) / (maxIndirectionLevel + 1); + + for (auto [loadOp, dist] : loadOpToIndLevel) { + opLatency[loadOp] = loadLatency; + } + } + return opLatency; +} + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index 9aa6a8f8d3fa..9407276c21d8 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -822,7 +822,7 @@ createAsyncOps(scf::ForOp &forOp, tt::CoarseSchedule coarseSchedule(numStages); coarseSchedule.deSerialize(forOp); - scheduleDependencies(forOp, coarseSchedule, numStages); + scheduleDependencies(forOp, coarseSchedule); coarseSchedule.serialize(forOp); // Make sure all ops have attributes. diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp index 5fd98355f9c5..aab560770720 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp @@ -13,6 +13,20 @@ namespace tt = mlir::triton; namespace ttg = mlir::triton::gpu; namespace ttng = mlir::triton::nvidia_gpu; +bool mlir::triton::loopHasDistGreaterThanOne(scf::ForOp forOp) { + return llvm::any_of(forOp.getBody()->getTerminator()->getOperands(), + [](Value operand) { + Operation *def = operand.getDefiningOp(); + return !def; + }); +} + +bool mlir::triton::isOuterLoop(scf::ForOp forOp) { + return llvm::any_of(forOp.getBody()->getOperations(), [](Operation &op) { + return isa(op); + }); +} + // Combine the current mask with the given predicate. 
static Value getPredMask(RewriterBase &rewriter, Type typeLike, Value currentMask, Value pred) { diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp index 2710c0365384..bfb31a3e8d6a 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp @@ -114,10 +114,11 @@ void tt::CoarseSchedule::deSerialize(scf::ForOp &forOp) { } } +// TODO: Should this be moved somewhere else? // Add dependencies of anchor ops to the coarse schedule. Schedule them to // the same stage and ordering cluster as the anchor op. -void tt::scheduleDependencies(scf::ForOp forOp, tt::CoarseSchedule &schedule, - int numStages) { +void tt::scheduleDependencies(scf::ForOp forOp, tt::CoarseSchedule &schedule) { + int numStages = schedule.numStages; SmallVector> opsInOrder = schedule.getOpsInOrder(forOp); // Schedule dependencies stage by stage. diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp index 8766e82b9f15..3361087a9c7d 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp @@ -37,22 +37,10 @@ namespace gpu { static bool preCondition(scf::ForOp forOp) { // Skip loop with distance > 1 for now. // TODO: relax the constraint in the expander. - if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(), - [](Value operand) { - Operation *def = operand.getDefiningOp(); - return !def; - })) + if (loopHasDistGreaterThanOne(forOp)) return false; // Don't pipeline outer loops. - if (forOp - ->walk([&](Operation *op) { - if (forOp.getOperation() == op) - return WalkResult::advance(); - if (isa(op)) - return WalkResult::interrupt(); - return WalkResult::advance(); - }) - .wasInterrupted()) + if (isOuterLoop(forOp)) return false; return true; } diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/TestPipelineAssignLatencies.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/TestPipelineAssignLatencies.cpp new file mode 100644 index 000000000000..ae3f3a97f9d4 --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/TestPipelineAssignLatencies.cpp @@ -0,0 +1,43 @@ +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = mlir::triton::nvidia_gpu; + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUTESTPIPELINEASSIGNLATENCIES +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +static const char *kLatencyAttrName = "tt.latency"; + +struct TestPipelineAssignLatencies + : public impl::TritonGPUTestPipelineAssignLatenciesBase< + TestPipelineAssignLatencies> { + using impl::TritonGPUTestPipelineAssignLatenciesBase< + TestPipelineAssignLatencies>::TritonGPUTestPipelineAssignLatenciesBase; + + void runOnOperation() override { + ModuleOp m = getOperation(); + + DenseMap opLatencies = assignLatencies(m, numStages); + + for (auto [op, latency] : opLatencies) { + op->setAttr( + kLatencyAttrName, + IntegerAttr::get(IntegerType::get(m.getContext(), 32), latency)); + 
} + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/TestPipelineScheduleLoop.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/TestPipelineScheduleLoop.cpp new file mode 100644 index 000000000000..54956c7177ca --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/TestPipelineScheduleLoop.cpp @@ -0,0 +1,54 @@ +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = mlir::triton::nvidia_gpu; + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUTESTPIPELINESCHEDULELOOP +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +static const char *kLatencyAttrName = "tt.latency"; + +struct TestPipelineScheduleLoop + : public impl::TritonGPUTestPipelineScheduleLoopBase< + TestPipelineScheduleLoop> { + using impl::TritonGPUTestPipelineScheduleLoopBase< + TestPipelineScheduleLoop>::TritonGPUTestPipelineScheduleLoopBase; + + void runOnOperation() override { + ModuleOp m = getOperation(); + + DenseMap opLatencies; + + // Deserialize latencies from the IR. + m.walk([&](Operation *op) { + if (op->hasAttr(kLatencyAttrName)) { + int latency = + mlir::cast(op->getAttr(kLatencyAttrName)).getInt(); + op->removeAttr(kLatencyAttrName); + opLatencies[op] = latency; + } + }); + + SmallVector loops; + m.walk([&](scf::ForOp forOp) { loops.push_back(forOp); }); + + for (auto forOp : loops) { + scheduleLoop(forOp, opLatencies); + } + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/python/test/unit/test_perf_warning.py b/python/test/unit/test_perf_warning.py index 602b1d395b91..6646d94f50a8 100644 --- a/python/test/unit/test_perf_warning.py +++ b/python/test/unit/test_perf_warning.py @@ -162,16 +162,6 @@ def kernel_pipe_error(in_ptr, out_ptr): if tl.max(val) > 0: k += 1 - with enable_remark_context(): - triton.compile( - triton.compiler.ASTSource( - fn=kernel_pipe_error, - signature={"in_ptr": "*fp32", "out_ptr": "*fp32"}, - constants={}, - ), - options={"cluster_dims": (1, 1, 1)}, - ) - - _, err = capfd.readouterr() - - assert "operation scheduled before its operands" not in err, "expect swp op remark" + i = torch.empty(64 * 64, dtype=torch.float32).cuda() + o = torch.empty(64 * 64, dtype=torch.float32).cuda() + kernel_pipe_error[(1, )](i, o) diff --git a/test/TritonGPU/loop-schedule.mlir b/test/TritonGPU/loop-schedule.mlir index adf7050da315..afd4ec75db54 100644 --- a/test/TritonGPU/loop-schedule.mlir +++ b/test/TritonGPU/loop-schedule.mlir @@ -12,8 +12,8 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { // CHECK-LABLE: @matmul_loop_load_acc // CHECK: tt.load %{{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32} // CHECK: tt.load %{{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32} -// CHECK: tt.load %{{.*}} {loop.cluster = 1 : i32, loop.stage = 2 : i32} -// CHECK: tt.dot {{.*}} {loop.cluster = 1 : i32, loop.stage = 2 : i32} +// CHECK: tt.load %{{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} +// CHECK: tt.dot {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} tt.func @matmul_loop_load_acc(%lb : index, %ub : index, 
%step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, %B : !tt.ptr {tt.divisibility = 16 : i32}, diff --git a/test/TritonGPU/pipeline-assign-latencies.mlir b/test/TritonGPU/pipeline-assign-latencies.mlir new file mode 100644 index 000000000000..9ff318b77983 --- /dev/null +++ b/test/TritonGPU/pipeline-assign-latencies.mlir @@ -0,0 +1,376 @@ +// RUN: triton-opt %s -split-input-file -tritongpu-test-pipeline-assign-latencies=num-stages=3 -canonicalize | FileCheck %s + +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#A = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth=2}> +#B = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth=2}> +#shared = #ttg.shared<{vec = 16, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 16, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = true}> +#shared2 = #ttg.shared<{vec = 16, perPhase = 2, maxPhase = 4, order = [0, 1], hasLeadingOffset = true}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 128, 32]}> + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { +// CHECK-LABEL: @default_stages +tt.func @default_stages(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %b_ptr_init : tensor<32x128x!tt.ptr, #BL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}) -> tensor<128x128xf32, #C> { + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + // CHECK: tt.load {{.*}} {tt.latency = 2 : i32} + %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + // CHECK: tt.load {{.*}} {tt.latency = 2 : i32} + %b_ = tt.load %b_ptr : tensor<32x128x!tt.ptr, #BL> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + tt.return %loop#2: tensor<128x128xf32, #C> +} + +// CHECK-LABEL: @small_load +// We should *not* assign latency to the load of b_ptr. 
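+// (b_ptr_init carries no divisibility/contiguity hints, so the analysis can
+// only prove a narrow, sub-32-bit access and treats it as a small load.)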
+tt.func @small_load(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %b_ptr_init : tensor<32x128x!tt.ptr, #BL>) -> tensor<128x128xf32, #C> { + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + // CHECK: tt.load {{.*}} {tt.latency = 2 : i32} + %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + // CHECK: tt.load {{.*}} + // CHECK-NOT: tt.latency + %b_ = tt.load %b_ptr : tensor<32x128x!tt.ptr, #BL> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + tt.return %loop#2: tensor<128x128xf32, #C> +} + +// CHECK-LABEL: @load_into_shared +tt.func @load_into_shared(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %b_ptr_init : tensor<32x128x!tt.ptr, #BL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}) -> tensor<128x128xf32, #mma> { + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #mma> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #mma>) { + // CHECK: tt.load {{.*}} {tt.latency = 2 : i32} + %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = ttg.local_alloc %a_ : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory> + // CHECK: tt.load {{.*}} {tt.latency = 2 : i32} + %b_ = tt.load %b_ptr : tensor<32x128x!tt.ptr, #BL> + %b = ttg.local_alloc %b_ : (tensor<32x128xf16, #BL>) -> !ttg.memdesc<32x128xf16, #shared1, #ttg.shared_memory> + + %c = ttng.warp_group_dot %a, %b, %prev_c {maxNumImpreciseAcc = 1073741824 : i32} : !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<32x128xf16, #shared1, #ttg.shared_memory> -> tensor<128x128xf32, #mma> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #mma> + } + tt.return %loop#2: tensor<128x128xf32, #mma> +} + +// CHECK-LABEL: @load_into_shared_incompat_layout +tt.func @load_into_shared_incompat_layout(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %b_ptr_init : tensor<32x128x!tt.ptr, #BL> 
{tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}) -> tensor<128x128xf32, #mma> { + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #mma> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #mma>) { + // CHECK: tt.load {{.*}} {tt.latency = 2 : i32} + %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = ttg.local_alloc %a_ : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory> + // CHECK: tt.load + // CHECK-NOT: {tt.latency = 2 : i32} + %b_ = tt.load %b_ptr : tensor<32x128x!tt.ptr, #BL> + %b = ttg.local_alloc %b_ : (tensor<32x128xf16, #BL>) -> !ttg.memdesc<32x128xf16, #shared2, #ttg.shared_memory> + + %c = ttng.warp_group_dot %a, %b, %prev_c {maxNumImpreciseAcc = 1073741824 : i32} : !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<32x128xf16, #shared2, #ttg.shared_memory> -> tensor<128x128xf32, #mma> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #mma> + } + tt.return %loop#2: tensor<128x128xf32, #mma> +} + +// CHECK-LABEL: @intermediate_use +tt.func @intermediate_use(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %b_ptr_init : tensor<32x128x!tt.ptr, #BL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}) -> tensor<128x128xf32, #C> { + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + %c2 = arith.constant dense<2.00> : tensor<32x128xf16, #BL> + + %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + // CHECK: tt.load {{.*}} {tt.latency = 2 : i32} + %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + // CHECK: tt.load {{.*}} {tt.latency = 2 : i32} + %b_ = tt.load %b_ptr : tensor<32x128x!tt.ptr, #BL> + %b_2 = arith.mulf %b_ , %c2 : tensor<32x128xf16, #BL> + %b = ttg.convert_layout %b_2 : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + tt.return %loop#2: tensor<128x128xf32, #C> +} + +// CHECK-LABEL: @indirect_load +tt.func @indirect_load(%lb : index, %ub : index, %step : index, + %a_ind_ptr_init : tensor<128x32x!tt.ptr, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %b_ind_ptr_init : tensor<32x128x!tt.ptr, #BL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + 
%a_ptr_init : tensor<128x32x!tt.ptr, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %b_ptr_init : tensor<32x128x!tt.ptr, #BL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}) -> tensor<128x128xf32, #C> { + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + %a_ind_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_ind_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %loop:5 = scf.for %iv = %lb to %ub step %step iter_args(%a_ind_ptr = %a_ind_ptr_init, %b_ind_ptr = %b_ind_ptr_init, %a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + // CHECK: tt.load {{.*}} {tt.latency = 1 : i32} + %a_off = tt.load %a_ind_ptr : tensor<128x32x!tt.ptr, #AL> + // CHECK: tt.load {{.*}} {tt.latency = 1 : i32} + %b_off = tt.load %b_ind_ptr : tensor<32x128x!tt.ptr, #BL> + %next_a_ind_ptr = tt.addptr %a_ind_ptr, %a_ind_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ind_ptr = tt.addptr %b_ind_ptr, %b_ind_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + %next_a_ptr = tt.addptr %a_ptr, %a_off {tt.divisibility = dense<16> : tensor<128x32xi32>, tt.contiguity = dense<32> : tensor<128x32xi32>, tt.constancy = dense<1> : tensor<128x32xi32>} : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off {tt.divisibility = dense<16> : tensor<32x128xi32>, tt.contiguity = dense<32> : tensor<32x128xi32>, tt.constancy = dense<1> : tensor<32x128xi32>} : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + // CHECK: tt.load {{.*}} {tt.latency = 1 : i32} + %a_ = tt.load %next_a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + // CHECK: tt.load {{.*}} {tt.latency = 1 : i32} + %b_ = tt.load %next_b_ptr : tensor<32x128x!tt.ptr, #BL> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + scf.yield %next_a_ind_ptr, %next_b_ind_ptr, %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + tt.return %loop#4: tensor<128x128xf32, #C> +} + +// CHECK-LABEL: @mixed_loads +tt.func @mixed_loads(%lb : index, %ub : index, %step : index, + %a_ind_ptr_init : tensor<128x32x!tt.ptr, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %a_ptr_init : tensor<128x32x!tt.ptr, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %b_ptr_init : tensor<32x128x!tt.ptr, #BL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}) -> tensor<128x128xf32, #C> { + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + %a_ind_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %loop:4 = scf.for %iv = %lb to %ub step %step iter_args(%a_ind_ptr = %a_ind_ptr_init, %a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + // CHECK: tt.load {{.*}} {tt.latency = 1 : i32} + %a_off = tt.load %a_ind_ptr : tensor<128x32x!tt.ptr, #AL> + %next_a_ind_ptr = tt.addptr %a_ind_ptr, %a_ind_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + 
%next_a_ptr = tt.addptr %a_ptr, %a_off {tt.divisibility = dense<16> : tensor<128x32xi32>, tt.contiguity = dense<32> : tensor<128x32xi32>, tt.constancy = dense<1> : tensor<128x32xi32>} : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + // CHECK: tt.load {{.*}} {tt.latency = 1 : i32} + %a_ = tt.load %next_a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + // CHECK: tt.load {{.*}} {tt.latency = 1 : i32} + %b_ = tt.load %next_b_ptr : tensor<32x128x!tt.ptr, #BL> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + scf.yield %next_a_ind_ptr, %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + tt.return %loop#3: tensor<128x128xf32, #C> +} + +// CHECK-LABEL: @per_loop_stages +tt.func @per_loop_stages(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %b_ptr_init : tensor<32x128x!tt.ptr, #BL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}) -> (tensor<128x128xf32, #C>, tensor<128x128xf32, #C>) { + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %loop_cust_stages:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + // CHECK: tt.load {{.*}} {tt.latency = 3 : i32} + %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + // CHECK: tt.load {{.*}} {tt.latency = 3 : i32} + %b_ = tt.load %b_ptr : tensor<32x128x!tt.ptr, #BL> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } {tt.num_stages = 4 : i32} + + %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + // CHECK: tt.load {{.*}} {tt.latency = 2 : i32} + %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + // CHECK: tt.load {{.*}} {tt.latency = 2 : i32} + %b_ = tt.load %b_ptr : tensor<32x128x!tt.ptr, #BL> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield 
+// CHECK-LABEL: @per_loop_stages
+tt.func @per_loop_stages(%lb : index, %ub : index, %step : index,
+                  %a_ptr_init : tensor<128x32x!tt.ptr<f16>, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32},
+                  %b_ptr_init : tensor<32x128x!tt.ptr<f16>, #BL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}) -> (tensor<128x128xf32, #C>, tensor<128x128xf32, #C>) {
+  %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C>
+  %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL>
+  %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>
+
+  %loop_cust_stages:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
+    // CHECK: tt.load {{.*}} {tt.latency = 3 : i32}
+    %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr<f16>, #AL>
+    %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A>
+    // CHECK: tt.load {{.*}} {tt.latency = 3 : i32}
+    %b_ = tt.load %b_ptr : tensor<32x128x!tt.ptr<f16>, #BL>
+    %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B>
+
+    %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
+
+    %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
+    %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
+    scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
+  } {tt.num_stages = 4 : i32}
+
+  %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
+    // CHECK: tt.load {{.*}} {tt.latency = 2 : i32}
+    %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr<f16>, #AL>
+    %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A>
+    // CHECK: tt.load {{.*}} {tt.latency = 2 : i32}
+    %b_ = tt.load %b_ptr : tensor<32x128x!tt.ptr<f16>, #BL>
+    %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B>
+
+    %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
+
+    %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
+    %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
+    scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
+  }
+  tt.return %loop_cust_stages#2, %loop#2: tensor<128x128xf32, #C>, tensor<128x128xf32, #C>
+}
+
+// CHECK-LABEL: @indirect_load_cust_stages
+tt.func @indirect_load_cust_stages(%lb : index, %ub : index, %step : index,
+                  %a_ind_ptr_init : tensor<128x32x!tt.ptr<i32>, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32},
+                  %b_ind_ptr_init : tensor<32x128x!tt.ptr<i32>, #BL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32},
+                  %a_ptr_init : tensor<128x32x!tt.ptr<f16>, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32},
+                  %b_ptr_init : tensor<32x128x!tt.ptr<f16>, #BL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}) -> tensor<128x128xf32, #C> {
+  %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C>
+  %a_ind_off = arith.constant dense<4> : tensor<128x32xi32, #AL>
+  %b_ind_off = arith.constant dense<4> : tensor<32x128xi32, #BL>
+
+  %loop:5 = scf.for %iv = %lb to %ub step %step iter_args(%a_ind_ptr = %a_ind_ptr_init, %b_ind_ptr = %b_ind_ptr_init, %a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<i32>, #AL>, tensor<32x128x!tt.ptr<i32>, #BL>, tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
+    // CHECK: tt.load {{.*}} {tt.latency = 2 : i32}
+    %a_off = tt.load %a_ind_ptr : tensor<128x32x!tt.ptr<i32>, #AL>
+    // CHECK: tt.load {{.*}} {tt.latency = 2 : i32}
+    %b_off = tt.load %b_ind_ptr : tensor<32x128x!tt.ptr<i32>, #BL>
+    %next_a_ind_ptr = tt.addptr %a_ind_ptr, %a_ind_off : tensor<128x32x!tt.ptr<i32>, #AL>, tensor<128x32xi32, #AL>
+    %next_b_ind_ptr = tt.addptr %b_ind_ptr, %b_ind_off : tensor<32x128x!tt.ptr<i32>, #BL>, tensor<32x128xi32, #BL>
+    %next_a_ptr = tt.addptr %a_ptr, %a_off {tt.divisibility = dense<16> : tensor<128x32xi32>, tt.contiguity = dense<32> : tensor<128x32xi32>, tt.constancy = dense<1> : tensor<128x32xi32>} : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
+    %next_b_ptr = tt.addptr %b_ptr, %b_off {tt.divisibility = dense<16> : tensor<32x128xi32>, tt.contiguity = dense<32> : tensor<32x128xi32>, tt.constancy = dense<1> : tensor<32x128xi32>} : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
+    // CHECK: tt.load {{.*}} {tt.latency = 2 : i32}
+    %a_ = tt.load %next_a_ptr : tensor<128x32x!tt.ptr<f16>, #AL>
+    %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A>
+    // CHECK: tt.load {{.*}} {tt.latency = 2 : i32}
+    %b_ = tt.load %next_b_ptr : tensor<32x128x!tt.ptr<f16>, #BL>
+    %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B>
+
+    %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
+    scf.yield %next_a_ind_ptr, %next_b_ind_ptr, %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr<i32>, #AL>, tensor<32x128x!tt.ptr<i32>, #BL>, tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
+  } {tt.num_stages = 5 : i32}
+  tt.return %loop#4: tensor<128x128xf32, #C>
+}
+
+// CHECK-LABEL: @indirect_load_few_stages
+tt.func @indirect_load_few_stages(%lb : index, %ub : index, %step : index,
+                  %a_ind_ptr_init : tensor<128x32x!tt.ptr<i32>, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32},
+                  %b_ind_ptr_init : tensor<32x128x!tt.ptr<i32>, #BL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32},
+                  %a_ptr_init : tensor<128x32x!tt.ptr<f16>, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32},
+                  %b_ptr_init : tensor<32x128x!tt.ptr<f16>, #BL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}) -> tensor<128x128xf32, #C> {
+  %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C>
+  %a_ind_off = arith.constant dense<4> : tensor<128x32xi32, #AL>
+  %b_ind_off = arith.constant dense<4> : tensor<32x128xi32, #BL>
+
+  %loop:5 = scf.for %iv = %lb to %ub step %step iter_args(%a_ind_ptr = %a_ind_ptr_init, %b_ind_ptr = %b_ind_ptr_init, %a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<i32>, #AL>, tensor<32x128x!tt.ptr<i32>, #BL>, tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
+    // CHECK: tt.load
+    // CHECK-NOT: tt.latency
+    %a_off = tt.load %a_ind_ptr : tensor<128x32x!tt.ptr<i32>, #AL>
+    // CHECK: tt.load
+    // CHECK-NOT: tt.latency
+    %b_off = tt.load %b_ind_ptr : tensor<32x128x!tt.ptr<i32>, #BL>
+    %next_a_ind_ptr = tt.addptr %a_ind_ptr, %a_ind_off : tensor<128x32x!tt.ptr<i32>, #AL>, tensor<128x32xi32, #AL>
+    %next_b_ind_ptr = tt.addptr %b_ind_ptr, %b_ind_off : tensor<32x128x!tt.ptr<i32>, #BL>, tensor<32x128xi32, #BL>
+    %next_a_ptr = tt.addptr %a_ptr, %a_off {tt.divisibility = dense<16> : tensor<128x32xi32>, tt.contiguity = dense<32> : tensor<128x32xi32>, tt.constancy = dense<1> : tensor<128x32xi32>} : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
+    %next_b_ptr = tt.addptr %b_ptr, %b_off {tt.divisibility = dense<16> : tensor<32x128xi32>, tt.contiguity = dense<32> : tensor<32x128xi32>, tt.constancy = dense<1> : tensor<32x128xi32>} : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
+    // CHECK: tt.load {{.*}} {tt.latency = 1 : i32}
+    %a_ = tt.load %next_a_ptr : tensor<128x32x!tt.ptr<f16>, #AL>
+    %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A>
+    // CHECK: tt.load {{.*}} {tt.latency = 1 : i32}
+    %b_ = tt.load %next_b_ptr : tensor<32x128x!tt.ptr<f16>, #BL>
+    %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B>
+
+    %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
+    scf.yield %next_a_ind_ptr, %next_b_ind_ptr, %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr<i32>, #AL>, tensor<32x128x!tt.ptr<i32>, #BL>, tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
+  } {tt.num_stages = 2 : i32}
+  tt.return %loop#4: tensor<128x128xf32, #C>
+}
+
+// CHECK-LABEL: @non_dot_pipeline
+tt.func @non_dot_pipeline(%lb : index, %ub : index, %step : index,
+                  %a_ptr_init : tensor<128x32x!tt.ptr<f16>, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}) -> tensor<128x32xf16, #A> {
+  %c_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A>
+  %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL>
+
+  %loop:2 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xf16, #A>) {
+    // CHECK: tt.load {{.*}} {tt.latency = 2 : i32}
+    %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr<f16>, #AL>
+    %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A>
+
+    %c = arith.addf %a, %prev_c : tensor<128x32xf16, #A>
+
+    %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
+    scf.yield %next_a_ptr, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xf16, #A>
+  } {tt.num_stages = 3 : i32}
+  tt.return %loop#1: tensor<128x32xf16, #A>
+}
+
+// CHECK-LABEL: @no_pipeline
+tt.func @no_pipeline(%lb : index, %ub : index, %step : index,
+                  %a_ptr_init : tensor<128x32x!tt.ptr<f16>, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}) -> tensor<128x32xf16, #A> {
+  %c_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A>
+  %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL>
+
+  %loop:2 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xf16, #A>) {
+    // CHECK: tt.load
+    // CHECK-NOT: tt.latency
+    %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr<f16>, #AL>
+    %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A>
+
+    %c = arith.addf %a, %prev_c : tensor<128x32xf16, #A>
+
+    %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
+    scf.yield %next_a_ptr, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xf16, #A>
+  }
+  tt.return %loop#1: tensor<128x32xf16, #A>
+}
+
+// CHECK-LABEL: @intermediate_use_cust_stages
+tt.func @intermediate_use_cust_stages(%lb : index, %ub : index, %step : index,
+                  %a_ptr_init : tensor<128x32x!tt.ptr<f16>, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32},
+                  %b_ptr_init : tensor<32x128x!tt.ptr<f16>, #BL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}) -> tensor<128x128xf32, #C> {
+  %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C>
+  %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL>
+  %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>
+  %c2 = arith.constant dense<2.00> : tensor<32x128xf16, #BL>
+
+  %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
+    // CHECK: tt.load {{.*}} {tt.latency = 2 : i32}
+    %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr<f16>, #AL>
+    %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A>
+    // CHECK: tt.load {{.*}} {tt.latency = 2 : i32}
+    %b_ = tt.load %b_ptr : tensor<32x128x!tt.ptr<f16>, #BL>
+    %b_2 = arith.mulf %b_ , %c2 : tensor<32x128xf16, #BL>
+    %b = ttg.convert_layout %b_2 : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B>
+
+    %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
+
+    %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
+    %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
+    scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
+  } {tt.num_stages = 3 : i32}
+  tt.return %loop#2: tensor<128x128xf32, #C>
+}
+
+}
diff --git a/test/TritonGPU/pipeline-schedule-loop.mlir b/test/TritonGPU/pipeline-schedule-loop.mlir
new file mode 100644
index 000000000000..bd66562d528b
--- /dev/null
+++ b/test/TritonGPU/pipeline-schedule-loop.mlir
@@ -0,0 +1,337 @@
+// RUN: triton-opt %s -split-input-file -tritongpu-test-pipeline-schedule-loop -canonicalize | FileCheck %s
+
+#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}>
+#A = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth=2}>
+#B = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth=2}>
+#shared = #ttg.shared<{vec = 16, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = true}>
+#shared1 = #ttg.shared<{vec = 16, perPhase = 2, maxPhase = 4, order = [0, 1], hasLeadingOffset = true}>
+#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 128, 32]}>
+
+module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} {
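+// The tests below feed loops whose loads already carry tt.latency annotations to the scheduling pass and verify the resulting loop.stage / loop.cluster assignments.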
+// CHECK-LABEL: @one_dep
+tt.func @one_dep(%lb : index, %ub : index, %step : index,
+                  %a_ptr_init : tensor<128x32x!tt.ptr<f16>, #A>) -> tensor<128x32xf16, #A> {
+  %init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A>
+  %loop = scf.for %iv = %lb to %ub step %step iter_args(%acc = %init) -> (tensor<128x32xf16, #A>) {
+    // CHECK: tt.load {{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32}
+    %a = tt.load %a_ptr_init {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr<f16>, #A>
+    // CHECK: arith.addf {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32}
+    %res = arith.addf %acc, %a : tensor<128x32xf16, #A>
+    scf.yield %res : tensor<128x32xf16, #A>
+  }
+  tt.return %loop#0 : tensor<128x32xf16, #A>
+}
+
+// CHECK-LABEL: @parallel_deps
+tt.func @parallel_deps(%lb : index, %ub : index, %step : index,
+                  %a_ptr_init : tensor<128x32x!tt.ptr<f16>, #A>,
+                  %b_ptr_init : tensor<128x32x!tt.ptr<f16>, #A>) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) {
+  %init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A>
+  %loop:2 = scf.for %iv = %lb to %ub step %step iter_args(%acc_a = %init, %acc_b = %init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) {
+    // CHECK: tt.load {{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32}
+    %a = tt.load %a_ptr_init {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr<f16>, #A>
+    // CHECK: tt.load {{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32}
+    %b = tt.load %a_ptr_init {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr<f16>, #A>
+    // CHECK: arith.addf {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32}
+    %res_a = arith.addf %acc_a, %a : tensor<128x32xf16, #A>
+    // CHECK: arith.addf {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32}
+    %res_b = arith.addf %acc_b, %b : tensor<128x32xf16, #A>
+    scf.yield %res_a, %res_b : tensor<128x32xf16, #A>, tensor<128x32xf16, #A>
+  }
+  tt.return %loop#0, %loop#1 : tensor<128x32xf16, #A>, tensor<128x32xf16, #A>
+}
+
+// CHECK-LABEL: @parallel_deps_uneven1
+tt.func @parallel_deps_uneven1(%lb : index, %ub : index, %step : index,
+                  %a_ptr_init : tensor<128x32x!tt.ptr<f16>, #A>,
+                  %b_ptr_init : tensor<128x32x!tt.ptr<f16>, #A>) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) {
+  %init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A>
+  %loop:2 = scf.for %iv = %lb to %ub step %step iter_args(%acc_a = %init, %acc_b = %init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) {
+    // CHECK: tt.load {{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32}
+    %a = tt.load %a_ptr_init {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr<f16>, #A>
+    // CHECK: tt.load {{.*}} {loop.cluster = 1 : i32, loop.stage = 1 : i32}
+    %b = tt.load %a_ptr_init {tt.latency = 1 : i32} : tensor<128x32x!tt.ptr<f16>, #A>
+    // CHECK: arith.addf {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32}
+    %res_a = arith.addf %acc_a, %a : tensor<128x32xf16, #A>
+    // CHECK: arith.addf {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32}
+    %res_b = arith.addf %acc_b, %b : tensor<128x32xf16, #A>
+    scf.yield %res_a, %res_b : tensor<128x32xf16, #A>, tensor<128x32xf16, #A>
+  }
+  tt.return %loop#0, %loop#1 : tensor<128x32xf16, #A>, tensor<128x32xf16, #A>
+}
+
+// CHECK-LABEL: @parallel_deps_uneven2
+tt.func @parallel_deps_uneven2(%lb : index, %ub : index, %step : index,
+                  %a_ptr_init : tensor<128x32x!tt.ptr<f16>, #A>,
+                  %b_ptr_init : tensor<128x32x!tt.ptr<f16>, #A>) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) {
+  %init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A>
+  %loop:2 = scf.for %iv = %lb to %ub step %step iter_args(%acc_a = %init, %acc_b = %init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) {
+    // CHECK: tt.load {{.*}} {loop.cluster = 1 : i32, loop.stage = 1 : i32}
+    %a = tt.load %a_ptr_init {tt.latency = 1 : i32} : tensor<128x32x!tt.ptr<f16>, #A>
+    // CHECK: tt.load {{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32}
+    %b = tt.load %a_ptr_init {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr<f16>, #A>
+    // CHECK: arith.addf {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32}
+    %res_a = arith.addf %acc_a, %a : tensor<128x32xf16, #A>
+    // CHECK: arith.addf {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32}
+    %res_b = arith.addf %acc_b, %b : tensor<128x32xf16, #A>
+    scf.yield %res_a, %res_b : tensor<128x32xf16, #A>, tensor<128x32xf16, #A>
+  }
+  tt.return %loop#0, %loop#1 : tensor<128x32xf16, #A>, tensor<128x32xf16, #A>
+}
+
+// CHECK-LABEL: @direct_deps
+tt.func @direct_deps(%lb : index, %ub : index, %step : index,
+                  %a_ptr_init : tensor<128x32x!tt.ptr<f16>, #A>) -> tensor<128x32xf16, #A> {
+  %init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A>
+  %a_off = arith.constant dense<4> : tensor<128x32xi32, #A>
+  %loop:2 = scf.for %iv = %lb to %ub step %step iter_args(%acc = %init, %a_ptr = %a_ptr_init) -> (tensor<128x32xf16, #A>, tensor<128x32x!tt.ptr<f16>, #A>) {
+    // CHECK: tt.addptr {{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32}
+    %a_ptr_next = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #A>, tensor<128x32xi32, #A>
+    // CHECK: tt.load {{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32}
+    %a = tt.load %a_ptr_next {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr<f16>, #A>
+    // CHECK: arith.addf {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32}
+    %res = arith.addf %acc, %a : tensor<128x32xf16, #A>
+    scf.yield %res, %a_ptr_next : tensor<128x32xf16, #A>, tensor<128x32x!tt.ptr<f16>, #A>
+  }
+  tt.return %loop#0 : tensor<128x32xf16, #A>
+}
+
+// CHECK-LABEL: @dist1_deps
+tt.func @dist1_deps(%lb : index, %ub : index, %step : index,
+                  %a_ptr_init : tensor<128x32x!tt.ptr<f16>, #A>) -> tensor<128x32xf16, #A> {
+  %init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A>
+  %a_off = arith.constant dense<4> : tensor<128x32xi32, #A>
+  %loop:2 = scf.for %iv = %lb to %ub step %step iter_args(%acc = %init, %a_ptr = %a_ptr_init) -> (tensor<128x32xf16, #A>, tensor<128x32x!tt.ptr<f16>, #A>) {
+    // CHECK: tt.load {{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32}
+    %a = tt.load %a_ptr {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr<f16>, #A>
+    // CHECK: arith.addf {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32}
+    %res = arith.addf %acc, %a : tensor<128x32xf16, #A>
+    // CHECK: tt.addptr {{.*}} {loop.cluster = 2 : i32, loop.stage = 1 : i32}
+    %a_ptr_next = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #A>, tensor<128x32xi32, #A>
+    scf.yield %res, %a_ptr_next : tensor<128x32xf16, #A>, tensor<128x32x!tt.ptr<f16>, #A>
+  }
+  tt.return %loop#0 : tensor<128x32xf16, #A>
+}
+
+// CHECK-LABEL: @prologue_if
+tt.func @prologue_if(%lb : index, %ub : index, %step : index, %cnd : i1,
+                  %a_ptr_init : tensor<128x32x!tt.ptr<f16>, #A>) -> tensor<128x32xf16, #A> {
+  %init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A>
+  %a_off = arith.constant dense<4> : tensor<128x32xi32, #A>
+  %loop = scf.for %iv = %lb to %ub step %step iter_args(%acc = %init) -> (tensor<128x32xf16, #A>) {
+    // CHECK: scf.if
+    // CHECK: {loop.cluster = 0 : i32, loop.stage = 0 : i32}
+    %a_ptr = scf.if %cnd -> tensor<128x32x!tt.ptr<f16>, #A> {
+      %a_ptr_ret = tt.addptr %a_ptr_init, %a_off : tensor<128x32x!tt.ptr<f16>, #A>, tensor<128x32xi32, #A>
+      scf.yield %a_ptr_ret : tensor<128x32x!tt.ptr<f16>, #A>
+    } else {
+      scf.yield %a_ptr_init : tensor<128x32x!tt.ptr<f16>, #A>
+    }
+    // CHECK: tt.load {{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32}
+    %a = tt.load %a_ptr {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr<f16>, #A>
+    // CHECK: arith.addf {{.*}} {loop.cluster = 1 : i32, loop.stage = 2 : i32}
+    %res = arith.addf %acc, %a : tensor<128x32xf16, #A>
+    scf.yield %res : tensor<128x32xf16, #A>
+  }
+  tt.return %loop#0 : tensor<128x32xf16, #A>
+}
+
+// CHECK-LABEL: @independent_epilogue_if
+tt.func @independent_epilogue_if(%lb : index, %ub : index, %step : index, %cnd : i1,
+                  %a_ptr_init : tensor<128x32x!tt.ptr<f16>, #A>) -> tensor<128x32xf16, #A> {
+  %init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A>
+  %a_off = arith.constant dense<4> : tensor<128x32xi32, #A>
+  %loop = scf.for %iv = %lb to %ub step %step iter_args(%acc = %init) -> (tensor<128x32xf16, #A>) {
+    // CHECK: tt.load {{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32}
+    %a = tt.load %a_ptr_init {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr<f16>, #A>
+    // CHECK: arith.addf {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32}
+    %res = arith.addf %acc, %a : tensor<128x32xf16, #A>
+    // CHECK: scf.if
+    // CHECK: {loop.cluster = 4 : i32, loop.stage = 2 : i32}
+    scf.if %cnd {
+      tt.store %a_ptr_init, %init : tensor<128x32x!tt.ptr<f16>, #A>
+    }
+    scf.yield %res : tensor<128x32xf16, #A>
+  }
+  tt.return %loop#0 : tensor<128x32xf16, #A>
+}
+
+// CHECK-LABEL: @independent_last_stage
+tt.func @independent_last_stage(%lb : index, %ub : index, %step : index,
+                  %a_ptr_init : tensor<128x32x!tt.ptr<f16>, #A>) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) {
+  %init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A>
+  %loop:2 = scf.for %iv = %lb to %ub step %step iter_args(%acc = %init, %acc2 = %init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) {
+    // CHECK: tt.load {{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32}
+    %a = tt.load %a_ptr_init {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr<f16>, #A>
+    // CHECK: arith.addf {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32}
+    %res = arith.addf %acc, %a : tensor<128x32xf16, #A>
+    // CHECK: arith.addf {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32}
+    %res2 = arith.addf %acc2, %init : tensor<128x32xf16, #A>
+    scf.yield %res, %res2 : tensor<128x32xf16, #A>, tensor<128x32xf16, #A>
+  }
+  tt.return %loop#0, %loop#1 : tensor<128x32xf16, #A>, tensor<128x32xf16, #A>
+}
+
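+// A straightforward GEMM loop: the loads go to stage 0, the pointer increments to stage 1, and the layout conversions and tt.dot to the last stage.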
+// CHECK-LABEL: @basic_pipeline
+tt.func @basic_pipeline(%lb : index, %ub : index, %step : index,
+                  %a_ptr_init : tensor<128x32x!tt.ptr<f16>, #AL>,
+                  %b_ptr_init : tensor<32x128x!tt.ptr<f16>, #BL>) -> tensor<128x128xf32, #C> {
+  %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C>
+  %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL>
+  %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>
+
+  %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
+    // CHECK: tt.load {{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32}
+    %a_ = tt.load %a_ptr {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr<f16>, #AL>
+    // CHECK: ttg.convert_layout {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32}
+    %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A>
+    // CHECK: tt.load {{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32}
+    %b_ = tt.load %b_ptr {tt.latency = 2 : i32} : tensor<32x128x!tt.ptr<f16>, #BL>
+    // CHECK: ttg.convert_layout {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32}
+    %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B>
+
+    // CHECK: tt.dot {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32}
+    %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
+    // CHECK: tt.addptr {{.*}} {loop.cluster = 2 : i32, loop.stage = 1 : i32}
+    %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
+    // CHECK: tt.addptr {{.*}} {loop.cluster = 2 : i32, loop.stage = 1 : i32}
+    %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
+    scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
+  }
+  tt.return %loop#2: tensor<128x128xf32, #C>
+}
+
+// CHECK-LABEL: @unpipelined_load
+tt.func @unpipelined_load(%lb : index, %ub : index, %step : index,
+                  %a_ptr_init : tensor<128x32x!tt.ptr<f16>, #AL>,
+                  %b_ptr_init : tensor<32x128x!tt.ptr<f16>, #BL>) -> tensor<128x128xf32, #C> {
+  %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C>
+  %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL>
+  %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>
+
+  %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
+    // CHECK: tt.load {{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32}
+    %a_ = tt.load %a_ptr {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr<f16>, #AL>
+    // CHECK: ttg.convert_layout {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32}
+    %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A>
+    // load below should be in the same stage as tt.dot (not pipelined)
+    // CHECK: tt.load {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32}
+    %b_ = tt.load %b_ptr : tensor<32x128x!tt.ptr<f16>, #BL>
+    // CHECK: ttg.convert_layout {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32}
+    %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B>
+
+    // CHECK: tt.dot {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32}
+    %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
+    // CHECK: tt.addptr {{.*}} {loop.cluster = 2 : i32, loop.stage = 1 : i32}
+    %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
+    // addptr below should be scheduled to the last stage
+    // CHECK: tt.addptr {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32}
+    %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
+    scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
+  }
+  tt.return %loop#2: tensor<128x128xf32, #C>
+}
+
+// CHECK-LABEL: @epilogue_if
+tt.func @epilogue_if(%lb : index, %ub : index, %step : index, %cnd : i1,
+                  %a_ptr_init : tensor<128x32x!tt.ptr<f16>, #AL>,
+                  %b_ptr_init : tensor<32x128x!tt.ptr<f16>, #BL>,
+                  %c_ptr_store : tensor<128x128x!tt.ptr<f32>, #C>) -> tensor<128x128xf32, #C> {
+  %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C>
+  %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL>
+  %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>
+
+  %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
+    // CHECK: tt.load {{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32}
+    %a_ = tt.load %a_ptr {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr<f16>, #AL>
+    // CHECK: ttg.convert_layout {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32}
+    %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A>
+    // CHECK: tt.load {{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32}
+    %b_ = tt.load %b_ptr {tt.latency = 2 : i32} : tensor<32x128x!tt.ptr<f16>, #BL>
+    // CHECK: ttg.convert_layout {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32}
+    %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B>
+
+    // CHECK: tt.dot {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32}
+    %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
+    // CHECK: scf.if
+    // CHECK: {loop.cluster = 4 : i32, loop.stage = 2 : i32}
+    scf.if %cnd {
+      tt.store %c_ptr_store, %c : tensor<128x128x!tt.ptr<f32>, #C>
+    }
+    // CHECK: tt.addptr {{.*}} {loop.cluster = 2 : i32, loop.stage = 1 : i32}
+    %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
+    // CHECK: tt.addptr {{.*}} {loop.cluster = 2 : i32, loop.stage = 1 : i32}
+    %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
+    scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
+  }
+  tt.return %loop#2: tensor<128x128xf32, #C>
+}
+
+// CHECK-LABEL: @intermediate_use
+tt.func @intermediate_use(%lb : index, %ub : index, %step : index,
+                  %a_ptr_init : tensor<128x32x!tt.ptr<f16>, #AL>,
+                  %b_ptr_init : tensor<32x128x!tt.ptr<f16>, #BL>) -> tensor<128x128xf32, #C> {
+  %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C>
+  %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL>
+  %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>
+  %c2 = arith.constant dense<2.00> : tensor<32x128xf16, #BL>
+
+  %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
+    // CHECK: tt.load {{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32}
+    %a_ = tt.load %a_ptr {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr<f16>, #AL>
+    // CHECK: ttg.convert_layout {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32}
+    %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A>
+    // CHECK: tt.load {{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32}
+    %b_ = tt.load %b_ptr {tt.latency = 2 : i32} : tensor<32x128x!tt.ptr<f16>, #BL>
+    // CHECK: arith.mulf {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32}
+    %b_2 = arith.mulf %b_ , %c2 : tensor<32x128xf16, #BL>
+    // CHECK: ttg.convert_layout {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32}
+    %b = ttg.convert_layout %b_2 : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B>
+
+    // CHECK: tt.dot {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32}
+    %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
+    // CHECK: tt.addptr {{.*}} {loop.cluster = 2 : i32, loop.stage = 1 : i32}
+    %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
+    // CHECK: tt.addptr {{.*}} {loop.cluster = 2 : i32, loop.stage = 1 : i32}
+    %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
+    scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
+  }
+  tt.return %loop#2: tensor<128x128xf32, #C>
+}
+
+// CHECK-LABEL: @indirect_load
+tt.func @indirect_load(%lb : index, %ub : index, %step : index,
+                  %a_ind_ptr_init : tensor<128x32x!tt.ptr<i32>, #AL>,
+                  %a_ptr_init : tensor<128x32x!tt.ptr<f16>, #AL>,
+                  %b_ptr_init : tensor<32x128x!tt.ptr<f16>, #BL>) -> tensor<128x128xf32, #C> {
+  %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C>
+  %a_ind_off = arith.constant dense<4> : tensor<128x32xi32, #AL>
+  %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>
+
+  %loop:4 = scf.for %iv = %lb to %ub step %step iter_args(%a_ind_ptr = %a_ind_ptr_init, %a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<i32>, #AL>, tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
+    // CHECK: tt.load {{.*}} {loop.cluster = 4 : i32, loop.stage = 0 : i32}
+    %a_off = tt.load %a_ind_ptr {tt.latency = 1 : i32} : tensor<128x32x!tt.ptr<i32>, #AL>
+    %next_a_ind_ptr = tt.addptr %a_ind_ptr, %a_ind_off : tensor<128x32x!tt.ptr<i32>, #AL>, tensor<128x32xi32, #AL>
+    // CHECK: tt.addptr {{.*}} {loop.cluster = 2 : i32, loop.stage = 1 : i32}
+    %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
+    // addptr below is scheduled by scheduleDependencies to the same stage as the tt.load that uses it
+    // CHECK: tt.addptr {{.*}} {loop.cluster = 2 : i32, loop.stage = 1 : i32}
+    %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
+    // CHECK: tt.load {{.*}} {loop.cluster = 2 : i32, loop.stage = 1 : i32}
+    %a_ = tt.load %next_a_ptr {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr<f16>, #AL>
+    // CHECK: ttg.convert_layout {{.*}} {loop.cluster = 0 : i32, loop.stage = 3 : i32}
+    %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A>
+    // CHECK: tt.load {{.*}} {loop.cluster = 2 : i32, loop.stage = 1 : i32}
+    %b_ = tt.load %next_b_ptr {tt.latency = 2 : i32} : tensor<32x128x!tt.ptr<f16>, #BL>
+    // CHECK: ttg.convert_layout {{.*}} {loop.cluster = 0 : i32, loop.stage = 3 : i32}
+    %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B>
+
+    // CHECK: tt.dot {{.*}} {loop.cluster = 0 : i32, loop.stage = 3 : i32}
+    %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
+    scf.yield %next_a_ind_ptr, %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr<i32>, #AL>, tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
+  }
+  tt.return %loop#3: tensor<128x128xf32, #C>
+}
+}