[BW] Add lowering support for tensor memory (#19)
* Handle tensor memory allocation in a naive bump-pointer way
* Handle basic lowering of tensor memory ld
* Handle lowering of the base pointer for the tensor memory
1 parent eb26aa4 · commit 51b2405
Showing 16 changed files with 469 additions and 77 deletions.
@@ -0,0 +1,72 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"

#include <memory>

#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc"

namespace {

using namespace mlir;
using namespace triton;
using namespace triton::gpu;
using namespace triton::nvidia_gpu;

class SyncMMALowering : public OpRewritePattern<TCGen5MMAOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TCGen5MMAOp op,
                                PatternRewriter &rewriter) const override {
    // If the op already carries a barrier it has asynchronous semantics and
    // its completion is tracked explicitly; skip the pattern.
    if (op.getBarrier())
      return failure();
    MLIRContext *ctx = op.getContext();
    Location loc = op.getLoc();
    // Allocate a single-element i64 mbarrier in shared memory, initialize it
    // with an arrival count of 1, and attach it to the MMA op.
    Attribute sharedMemorySpace = triton::gpu::SharedMemorySpaceAttr::get(ctx);
    auto barrierCTALayout = CTALayoutAttr::get(
        /*context=*/ctx, /*CTAsPerCGA=*/{1},
        /*CTASplitNum=*/{1}, /*CTAOrder=*/{0});
    auto barrierEncoding =
        SharedEncodingAttr::get(ctx, 1, 1, 1, {0}, barrierCTALayout);
    MemDescType barrierMemDescType =
        MemDescType::get({1}, rewriter.getI64Type(), barrierEncoding,
                         sharedMemorySpace, /*mutableMemory=*/true);
    Value barrierAlloc =
        rewriter.create<LocalAllocOp>(loc, barrierMemDescType, Value());
    rewriter.create<InitBarrierOp>(loc, barrierAlloc, 1);
    op.getBarrierMutable().assign(barrierAlloc);

    // Wait on phase 0 of the barrier right after the MMA op, then invalidate
    // it, so the surrounding code still observes synchronous completion.
    rewriter.setInsertionPointAfter(op);
    Value phase = rewriter.create<arith::ConstantIntOp>(loc, 0, 32);
    rewriter.create<WaitBarrierOp>(loc, barrierAlloc, phase);
    rewriter.create<InvalBarrierOp>(loc, barrierAlloc);
    return success();
  }
};

class TritonNvidiaGPUMMALoweringPass
    : public TritonNvidiaGPUMMALoweringPassBase<
          TritonNvidiaGPUMMALoweringPass> {
public:
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    ModuleOp m = getOperation();

    mlir::RewritePatternSet patterns(context);
    patterns.add<SyncMMALowering>(context);
    if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
      signalPassFailure();
  }
};

} // namespace

std::unique_ptr<Pass> mlir::createTritonNvidiaGPUMMALoweringPass() {
  return std::make_unique<TritonNvidiaGPUMMALoweringPass>();
}
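To make the rewrite concrete, here is a hedged before/after sketch of what SyncMMALowering produces; the op spellings, types, and layout attributes below are illustrative assumptions, not compiler output:

// Before: a tcgen05 MMA op with no barrier operand, i.e. synchronous semantics.
triton_nvidia_gpu.tc_gen5_mma %a, %b, %acc

// After: the pattern allocates a one-element i64 mbarrier in shared memory,
// initializes it with an arrival count of 1, attaches it to the MMA op, then
// waits on phase 0 and invalidates the barrier so callers still see the op
// complete synchronously.
%bar = triton_gpu.local_alloc : !tt.memdesc<1xi64, #shared_bar, #smem, mutable>
triton_nvidia_gpu.init_barrier %bar, 1
triton_nvidia_gpu.tc_gen5_mma %a, %b, %acc, %bar
%phase = arith.constant 0 : i32
triton_nvidia_gpu.wait_barrier %bar, %phase
triton_nvidia_gpu.inval_barrier %bar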
lib/Dialect/TritonNvidiaGPU/Transforms/TensorMemoryAllocation.cpp (68 changes: 68 additions & 0 deletions)
@@ -0,0 +1,68 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"

#include <array>
#include <memory>

#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc"

using namespace mlir;
using namespace triton;
using namespace triton::gpu;
using namespace triton::nvidia_gpu;

// Tensor memory is organized as 128 rows of 4-byte cells, and tcgen05
// allocations are measured in columns, so convert the tensor's byte size
// into a column count.
static int getMemSize(MemDescType memDescType) {
  int sizeInBytes = product(memDescType.getShape()) *
                    memDescType.getElementType().getIntOrFloatBitWidth() / 8;
  const int numRows = 128;
  const int bytesPerCell = 4;
  int numColumns = sizeInBytes / (numRows * bytesPerCell);
  return numColumns;
}

namespace {

class TritionTensorMemoryAllocationPass
    : public TritionTensorMemoryAllocationPassBase<
          TritionTensorMemoryAllocationPass> {
public:
  void runOnOperation() override {
    ModuleOp mod = getOperation();
    MLIRContext *ctx = &getContext();
    SmallVector<triton::gpu::LocalAllocOp> allocs;
    mod.walk([&](triton::gpu::LocalAllocOp alloc) { allocs.push_back(alloc); });
    int totalMemorySize = 0;
    // For now, a simple bump-pointer allocator: each tensor-memory alloc
    // takes the current running total as its offset.
    for (triton::gpu::LocalAllocOp alloc : allocs) {
      auto memDescType = alloc.getType();
      if (!isa<triton::nvidia_gpu::TensorMemorySpaceAttr>(
              memDescType.getMemorySpace()))
        continue;
      alloc->setAttr(
          "tensor_memory_offset",
          IntegerAttr::get(IntegerType::get(ctx, 32), totalMemorySize));
      totalMemorySize += getMemSize(memDescType);
    }
    std::array<int, 6> possibleAllocations = {0, 32, 64, 128, 256, 512};
    if (totalMemorySize > 512)
      llvm::report_fatal_error("Exceeded the maximum amount of tensor memory.");

    // Round the total up to the smallest legal tcgen05 allocation size.
    for (int size : possibleAllocations) {
      if (totalMemorySize <= size) {
        totalMemorySize = size;
        break;
      }
    }

    mod->setAttr("triton_gpu.tensor_memory_size",
                 mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 32),
                                        totalMemorySize));
  }
};

} // namespace

std::unique_ptr<Pass> mlir::createTensorMemoryAllocationPass() {
  return std::make_unique<TritionTensorMemoryAllocationPass>();
}
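As a sanity check on the arithmetic, here is a minimal standalone C++ sketch of the allocator's accounting; the function and variable names are illustrative, not from the patch:

// Minimal sketch of the bump-pointer accounting, assuming the same
// 128-row x 4-byte-cell view of tensor memory as the pass above.
#include <array>
#include <cstdio>

// Convert a tensor's byte size into tensor-memory columns.
static int getMemSizeInColumns(int numElements, int elementBits) {
  int sizeInBytes = numElements * elementBits / 8;
  const int numRows = 128;     // tensor memory has 128 rows
  const int bytesPerCell = 4;  // each cell holds 4 bytes
  return sizeInBytes / (numRows * bytesPerCell);
}

int main() {
  // Bump-pointer allocation: each alloc takes the current total as its offset.
  int totalColumns = 0;
  int offset = totalColumns;
  totalColumns += getMemSizeInColumns(128 * 128, 32); // 128x128 f32 -> 128 cols

  // Round up to the smallest legal tcgen05 allocation size.
  const std::array<int, 6> possibleAllocations = {0, 32, 64, 128, 256, 512};
  for (int size : possibleAllocations) {
    if (totalColumns <= size) {
      totalColumns = size;
      break;
    }
  }
  std::printf("offset = %d, tensor_memory_size = %d\n", offset, totalColumns);
  return 0;
}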
@@ -0,0 +1,18 @@
// RUN: triton-opt %s -split-input-file --convert-nv-gpu-to-llvm | FileCheck %s

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 65544 : i32, triton_gpu.target = "cuda:100", triton_gpu.tensor_memory_size = 128 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
  // CHECK-LABEL: @tensor_memory_base_lowering
  // CHECK: %[[SHMEM:.+]] = llvm.mlir.addressof @global_smem : !llvm.ptr<3>
  // CHECK: %[[A:.+]] = llvm.inline_asm has_side_effects
  // CHECK-SAME: "tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [$0], 128", "r" %[[SHMEM]] : (!llvm.ptr<3>) -> !llvm.void
  // CHECK: %[[AR:.+]] = llvm.load %[[SHMEM]] : !llvm.ptr<3> -> i32
  // CHECK: nvvm.barrier0
  // CHECK: "tcgen05.relinquish_alloc_permit.cta_group::1.sync.aligned", "" : () -> !llvm.void
  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "tcgen05.dealloc.cta_group::1.sync.aligned.shared::cta.b32 [$0], 128", "r" %{{.+}} : (!llvm.ptr<6>) -> !llvm.void
  llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
  llvm.func @tensor_memory_base_lowering() -> i32 attributes {nvvm.kernel = 1 : ui1, nvvm.maxntid = array<i32: 128>} {
    %263 = nvgpu.tensor_memory_base
    %264 = llvm.ptrtoint %263 : !llvm.ptr<6> to i32
    llvm.return %264 : i32
  }
}
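Reading the CHECK lines in order gives the intended lowering of nvgpu.tensor_memory_base: a tcgen05.alloc writes the tensor-memory base address into shared memory (@global_smem), the address is loaded back as an i32, the CTA synchronizes with nvvm.barrier0, the allocation permit is relinquished, and the 128 columns are deallocated before the kernel returns. The base itself is modeled as a pointer in address space 6, the tensor-memory address space.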
@@ -0,0 +1,14 @@
// RUN: triton-opt %s -split-input-file -triton-tensor-memory-allocation | FileCheck %s

#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 65536 : i32, triton_gpu.target = "cuda:100", "triton_gpu.threads-per-warp" = 32 : i32} {
  // CHECK: triton_gpu.tensor_memory_size = 128
  // CHECK: alloc_tensor_memory
  tt.func public @alloc_tensor_memory(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>) attributes {noinline = false} {
    %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
    // CHECK: triton_gpu.local_alloc %{{.+}} {tensor_memory_offset = 0 : i32}
    %0 = triton_gpu.local_alloc %cst : (tensor<128x128xf32, #blocked>) -> !tt.memdesc<128x128xf32, #shared, #triton_nvidia_gpu.tensor_memory, mutable>
    tt.return
  }
}
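The CHECK values line up with getMemSize: the 128x128xf32 allocation is 128 × 128 × 4 = 65536 bytes, i.e. 65536 / (128 rows × 4 bytes per cell) = 128 columns. Since 128 is already one of the legal allocation sizes, the rounding loop leaves it unchanged, and the single allocation gets tensor_memory_offset = 0.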