[BW] Add lowering support for tensor memory (#19)
* Handle tensor memory allocation in a naive bump pointer way
* Handle basic lowering of tensor memory ld
* Handle lowering of the base pointer for the tensor memory
ThomasRaoux authored Jun 14, 2024
1 parent eb26aa4 commit 51b2405
Showing 16 changed files with 469 additions and 77 deletions.
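At a high level, the new allocation pass walks tensor-memory local_alloc ops, hands each one a running offset (measured in tensor-memory columns), and records the rounded-up total on the module. A minimal before/after sketch, distilled from the tests added in this commit:

// Input: a local_alloc into the tensor memory space, no offset assigned yet.
%0 = triton_gpu.local_alloc %cst : (tensor<128x128xf32, #blocked>) -> !tt.memdesc<128x128xf32, #shared, #triton_nvidia_gpu.tensor_memory, mutable>

// After -triton-tensor-memory-allocation: the op carries its bump-pointer offset,
// and the module records the total, rounded up to a supported allocation size:
//   module attributes {..., triton_gpu.tensor_memory_size = 128 : i32}
%0 = triton_gpu.local_alloc %cst {tensor_memory_offset = 0 : i32} : (tensor<128x128xf32, #blocked>) -> !tt.memdesc<128x128xf32, #shared, #triton_nvidia_gpu.tensor_memory, mutable>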
4 changes: 4 additions & 0 deletions include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h
@@ -52,6 +52,10 @@ createTritonNvidiaGPUFenceInsertionPass(int computeCapability = 90);

std::unique_ptr<Pass> createTritonNvidiaGPUTMALoweringPass();

std::unique_ptr<Pass> createTensorMemoryAllocationPass();

std::unique_ptr<Pass> createTritonNvidiaGPUMMALoweringPass();

/// Generate the code for registering passes.
#define GEN_PASS_REGISTRATION
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc"
29 changes: 28 additions & 1 deletion include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td
@@ -59,7 +59,6 @@ def TritonGPUFenceInsertion : Pass<"triton-nvidia-gpu-fence-insertion", "mlir::M
];
}


def TritonNvidiaGPUTMALoweringPass : Pass<"triton-nvidia-tma-lowering", "mlir::ModuleOp"> {
let summary = "lower to TMA load/store operations";

@@ -74,4 +73,32 @@ def TritonNvidiaGPUTMALoweringPass : Pass<"triton-nvidia-tma-lowering", "mlir::M
];
}

def TritionTensorMemoryAllocationPass : Pass<"triton-tensor-memory-allocation", "mlir::ModuleOp"> {
let summary = "Assign tensor memory allocation";

let description = [{
Decide on tensor memory allocation and assign attributes to each allocation.
}];

let constructor = "mlir::createTensorMemoryAllocationPass()";

let dependentDialects = [
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"
];
}

def TritonNvidiaGPUMMALoweringPass : Pass<"triton-nvidia-mma-lowering", "mlir::ModuleOp"> {
let summary = "lower mma operations if needed";

let description = [{
Lower MMA ops to prepare for conversion to LLVM.
}];

let constructor = "mlir::createTritonNvidiaGPUMMALoweringPass()";

let dependentDialects = [
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"
];
}

#endif
2 changes: 2 additions & 0 deletions lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt
@@ -1,6 +1,8 @@
add_triton_library(TritonNvidiaGPUTransforms
FenceInsertion.cpp
MMALowering.cpp
PlanCTA.cpp
TensorMemoryAllocation.cpp
TMALowering.cpp

DEPENDS
72 changes: 72 additions & 0 deletions lib/Dialect/TritonNvidiaGPU/Transforms/MMALowering.cpp
@@ -0,0 +1,72 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"

#include <memory>

#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc"

namespace {

using namespace mlir;
using namespace triton;
using namespace triton::gpu;
using namespace triton::nvidia_gpu;

class SyncMMALowering : public OpRewritePattern<TCGen5MMAOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(TCGen5MMAOp op,
PatternRewriter &rewriter) const override {
// If the op already has a barrier, it is asynchronous; skip the pattern.
if (op.getBarrier())
return failure();
MLIRContext *ctx = op.getContext();
Location loc = op.getLoc();
Attribute sharedMemorySpace = triton::gpu::SharedMemorySpaceAttr::get(ctx);
auto barrierCTALayout = CTALayoutAttr::get(
/*context=*/ctx, /*CTAsPerCGA=*/{1},
/*CTASplitNum=*/{1}, /*CTAOrder=*/{0});
auto barrierEncoding =
SharedEncodingAttr::get(ctx, 1, 1, 1, {0}, barrierCTALayout);
MemDescType barrierMemDescType =
MemDescType::get({1}, rewriter.getI64Type(), barrierEncoding,
sharedMemorySpace, /*mutableMemory=*/true);
Value barrierAlloc =
rewriter.create<LocalAllocOp>(loc, barrierMemDescType, Value());
rewriter.create<InitBarrierOp>(loc, barrierAlloc, 1);
op.getBarrierMutable().assign(barrierAlloc);

rewriter.setInsertionPointAfter(op);
Value phase = rewriter.create<arith::ConstantIntOp>(loc, 0, 32);
rewriter.create<WaitBarrierOp>(loc, barrierAlloc, phase);
rewriter.create<InvalBarrierOp>(loc, barrierAlloc);
return success();
}
};

class TritonNvidiaGPUMMALoweringPass
: public TritonNvidiaGPUMMALoweringPassBase<
TritonNvidiaGPUMMALoweringPass> {
public:
void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp m = getOperation();

mlir::RewritePatternSet patterns(context);
patterns.add<SyncMMALowering>(context);
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
signalPassFailure();
}
};

} // namespace

std::unique_ptr<Pass> mlir::createTritonNvidiaGPUMMALoweringPass() {
return std::make_unique<TritonNvidiaGPUMMALoweringPass>();
}
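In IR terms, SyncMMALowering turns a barrier-less MMA into one that completes through a freshly allocated single-element mbarrier and is waited on immediately afterwards. A schematic sketch of the inserted ops (op mnemonics are inferred from the C++ op classes above and are not verbatim textual IR):

%bar = triton_gpu.local_alloc : !tt.memdesc<1xi64, #shared_bar, #triton_gpu.shared_memory, mutable>
triton_nvidia_gpu.init_barrier %bar, 1
triton_nvidia_gpu.tc_gen5_mma ..., %bar          // barrier operand now assigned
triton_nvidia_gpu.wait_barrier %bar, %c0_i32     // wait on phase 0
triton_nvidia_gpu.inval_barrier %bar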
68 changes: 68 additions & 0 deletions lib/Dialect/TritonNvidiaGPU/Transforms/TensorMemoryAllocation.cpp
@@ -0,0 +1,68 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"

#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc"

using namespace mlir;
using namespace triton;
using namespace triton::gpu;
using namespace triton::nvidia_gpu;

// Tensor memory is organized as 128 rows of 4-byte cells and is allocated in
// columns; return the number of columns this memdesc needs. For example, a
// 128x128 f32 allocation is 128*128*4 = 65536 bytes -> 128 columns.
static int getMemSize(MemDescType memDescType) {
int sizeInBytes = product(memDescType.getShape()) *
memDescType.getElementType().getIntOrFloatBitWidth() / 8;
const int numRows = 128;
int numColumns = sizeInBytes / (numRows * 4);
return numColumns;
}

namespace {

class TritionTensorMemoryAllocationPass
: public TritionTensorMemoryAllocationPassBase<
TritionTensorMemoryAllocationPass> {
public:
void runOnOperation() override {
ModuleOp mod = getOperation();
MLIRContext *ctx = &getContext();
SmallVector<triton::gpu::LocalAllocOp> allocs;
mod.walk([&](triton::gpu::LocalAllocOp alloc) { allocs.push_back(alloc); });
int totalMemorySize = 0;
// For now, use a simple bump-pointer allocator; offsets are in tensor-memory columns.
for (triton::gpu::LocalAllocOp alloc : allocs) {
auto memDescType = alloc.getType();
if (!isa<triton::nvidia_gpu::TensorMemorySpaceAttr>(
memDescType.getMemorySpace()))
continue;
alloc->setAttr(
"tensor_memory_offset",
IntegerAttr::get(IntegerType::get(ctx, 32), totalMemorySize));
totalMemorySize += getMemSize(memDescType);
}
std::array<int, 6> possibleAllocations = {0, 32, 64, 128, 256, 512};
if (totalMemorySize > 512)
llvm::report_fatal_error("Exceeded the maximum amount of tensor memory.");

for (int size : possibleAllocations) {
if (totalMemorySize <= size) {
totalMemorySize = size;
break;
}
}

mod->setAttr("triton_gpu.tensor_memory_size",
mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 32),
totalMemorySize));
}
};

} // namespace

std::unique_ptr<Pass> mlir::createTensorMemoryAllocationPass() {
return std::make_unique<TritionTensorMemoryAllocationPass>();
}
18 changes: 18 additions & 0 deletions test/Conversion/nvgpu_to_llvm.mlir
@@ -0,0 +1,18 @@
// RUN: triton-opt %s -split-input-file --convert-nv-gpu-to-llvm | FileCheck %s

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 65544 : i32, triton_gpu.target = "cuda:100", triton_gpu.tensor_memory_size = 128 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
// CHECK-LABEL: @tensor_memory_base_lowering
// CHECK: %[[SHMEM:.+]] = llvm.mlir.addressof @global_smem : !llvm.ptr<3>
// CHECK: %[[A:.+]] = llvm.inline_asm has_side_effects
// CHECK-SAME: "tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [$0], 128", "r" %[[SHMEM]] : (!llvm.ptr<3>) -> !llvm.void
// CHECK: %[[AR:.+]] = llvm.load %[[SHMEM]] : !llvm.ptr<3> -> i32
// CHECK: nvvm.barrier0
// CHECK: "tcgen05.relinquish_alloc_permit.cta_group::1.sync.aligned", "" : () -> !llvm.void
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "tcgen05.dealloc.cta_group::1.sync.aligned.shared::cta.b32 [$0], 128", "r" %{{.+}} : (!llvm.ptr<6>) -> !llvm.void
llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
llvm.func @tensor_memory_base_lowering() -> i32 attributes {nvvm.kernel = 1 : ui1, nvvm.maxntid = array<i32: 128>} {
%263 = nvgpu.tensor_memory_base
%264 = llvm.ptrtoint %263 : !llvm.ptr<6> to i32
llvm.return %264 : i32
}
}
22 changes: 12 additions & 10 deletions test/Conversion/tritongpu_to_llvm_blackwell.mlir
@@ -23,16 +23,18 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :

// -----

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.shared" = 1024 : i32, "triton_gpu.tensor_memory_size" = 32 : i32} {
// CHECK-LABEL: @tensor_mem
// CHECK: %[[SHMEM:.+]] = llvm.mlir.addressof @global_smem : !llvm.ptr<3>
// CHECK: %[[A:.+]] = llvm.inline_asm has_side_effects
// CHECK-SAME: "tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [$0], 32", "r" %[[SHMEM]] : (!llvm.ptr<3>) -> !llvm.void
// CHECK: %[[AR:.+]] = llvm.load %[[SHMEM]] : !llvm.ptr<3> -> i32
// CHECK: nvvm.barrier0
// CHECK: "tcgen05.relinquish_alloc_permit.cta_group::1.sync.aligned", "" : () -> !llvm.void
// CHECK: %4 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "tcgen05.dealloc.cta_group::1.sync.aligned.shared::cta.b32 [$0], 32", "r" %2 : (i32) -> !llvm.void
tt.func @tensor_mem() {
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 65544 : i32, triton_gpu.target = "cuda:100", triton_gpu.tensor_memory_size = 128 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
// CHECK-LABEL: @tensor_memory_ld
// CHECK: nvgpu.tensor_memory_base
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "tcgen05.ld.sync.aligned.32x32b.x128.b32 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46, $47, $48, $49, $50, $51, $52, $53, $54, $55, $56, $57, $58, $59, $60, $61, $62, $63, $64, $65, $66, $67, $68, $69, $70, $71, $72, $73, $74, $75, $76, $77, $78, $79, $80, $81, $82, $83, $84, $85, $86, $87, $88, $89, $90, $91, $92, $93, $94, $95, $96, $97, $98, $99, $100, $101, $102, $103, $104, $105, $106, $107, $108, $109, $110, $111, $112, $113, $114, $115, $116, $117, $118, $119, $120, $121, $122, $123, $124, $125, $126, $127}, [$128] $0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46, $47, $48, $49, $50, $51, $52, $53, $54, $55, $56, $57, $58, $59, $60, $61, $62, $63, $64, $65, $66, $67, $68, $69, $70, $71, $72, $73, $74, $75, $76, $77, $78, $79, $80, $81, $82, $83, $84, $85, $86, $87, $88, $89, $90, $91, $92, $93, $94, $95, $96, $97, $98, $99, $100, $101, $102, $103, $104, $105, $106, $107, $108, $109, $110, $111, $112, $113, $114, $115, $116, $117, $118, $119, $120, $121, $122, $123, $124, $125, $126, $127, $128;", "=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,r" %{{.*}} : (!llvm.ptr<3>) -> vector<128xi32>
tt.func public @tensor_memory_ld(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>) {
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked1>
%0 = triton_gpu.local_alloc %cst_0 {tensor_memory_offset = 0 : i32} : (tensor<128x128xf32, #blocked1>) -> !tt.memdesc<128x128xf32, #shared, #triton_nvidia_gpu.tensor_memory, mutable>
%20 = triton_gpu.local_load %0 : !tt.memdesc<128x128xf32, #shared, #triton_nvidia_gpu.tensor_memory, mutable> -> tensor<128x128xf32, #blocked1>
tt.return
}
}
14 changes: 14 additions & 0 deletions test/TritonNvidiaGPU/test_tensor_memory_allocation.mlir
@@ -0,0 +1,14 @@
// RUN: triton-opt %s -split-input-file -triton-tensor-memory-allocation | FileCheck %s

#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 65536 : i32, triton_gpu.target = "cuda:100", "triton_gpu.threads-per-warp" = 32 : i32} {
// CHECK: triton_gpu.tensor_memory_size = 128
// CHECK: alloc_tensor_memory
tt.func public @alloc_tensor_memory(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>) attributes {noinline = false} {
%cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
// CHECK: triton_gpu.local_alloc %{{.+}} {tensor_memory_offset = 0 : i32}
%0 = triton_gpu.local_alloc %cst : (tensor<128x128xf32, #blocked>) -> !tt.memdesc<128x128xf32, #shared, #triton_nvidia_gpu.tensor_memory, mutable>
tt.return
}
}
2 changes: 2 additions & 0 deletions third_party/nvidia/backend/compiler.py
@@ -205,10 +205,12 @@ def make_llir(src, metadata, options, capability):
pm = ir.pass_manager(mod.context)
pm.enable_debug()
nvidia.passes.ttgpuir.add_decompose_unsupported_conversions(pm)
nvidia.passes.ttnvgpuir.add_lower_mma(pm)
passes.ttgpuir.add_combine_tensor_select_and_if(pm)
passes.convert.add_scf_to_cf(pm)
passes.convert.add_index_to_llvmir(pm)
passes.ttgpuir.add_allocate_shared_memory(pm)
nvidia.passes.ttnvgpuir.add_allocate_tensor_memory(pm)
nvidia.passes.ttgpuir.add_to_llvmir(pm, capability)
nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm)
passes.convert.add_arith_to_llvmir(pm)
11 changes: 11 additions & 0 deletions third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td
@@ -31,6 +31,7 @@ include "NVGPUAttrDefs.td"

def LLVM_PointerGlobal : LLVM_PointerInAddressSpace<1>;
def LLVM_PointerShared : LLVM_PointerInAddressSpace<3>;
def LLVM_PointerTensorMemory : LLVM_PointerInAddressSpace<6>;

class NVGPU_Op<string mnemonic, list<Trait> traits = []> :
LLVM_OpBase<NVGPU_Dialect, mnemonic, traits>;
@@ -136,4 +137,14 @@ def NVGPU_ClusterCTAIdOp : NVGPU_Op<"cluster_id", [Pure]> {
let assemblyFormat = "attr-dict";
}

def NVGPU_TensorMemoryBaseAddress : NVGPU_Op<"tensor_memory_base", [Pure]> {
let description = [{
Op representing the base address of tensor memory in a kernel.
This is used to simplify lowering from TritonGPU to LLVM.
}];
let results = (outs LLVM_PointerTensorMemory:$result);
let assemblyFormat = "attr-dict";
}


#endif
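For reference, the new op is exercised in the nvgpu_to_llvm.mlir test above: lowering materializes the tensor-memory base pointer in address space 6, which can then be converted or offset like any other LLVM pointer:

%base = nvgpu.tensor_memory_base
%addr = llvm.ptrtoint %base : !llvm.ptr<6> to i32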