[DIALECT] Rename triton_gpu to ttg and triton_nvidia_gpu to ttng #5266

Merged: 4 commits, Nov 27, 2024
8 changes: 4 additions & 4 deletions bin/triton-tensor-layout.cpp
@@ -22,16 +22,16 @@ using namespace mlir;
// clang-format off
// Example usage:
//
-// triton-tensor-layout -l "#triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>" -t "tensor<128x256xf16>"
+// triton-tensor-layout -l "#ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>" -t "tensor<128x256xf16>"
//
// triton-tensor-layout -i input.mlir -t "tensor<1x128x128xf16>" -o output.txt
//
// triton-tensor-layout -i input.mlir -t "tensor<1x128x128xf16>" -o output.txt -alias-names="blocked,mma" -use-hw-view
//
// An input file usually looks like:
// '''
-// #mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 8], instrShape = [32, 32], isTransposed = false}>
-// #blocked = #triton_gpu.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 16, 4], warpsPerCTA = [1, 1, 8], order = [0, 1, 2]}>
+// #mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 8], instrShape = [32, 32], isTransposed = false}>
+// #blocked = #ttg.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 16, 4], warpsPerCTA = [1, 1, 8], order = [0, 1, 2]}>
// '''
// clang-format on

@@ -83,7 +83,7 @@ LogicalResult layoutPrint(RankedTensorType tensorType, raw_ostream &os) {
StringRef dialectName = tensorType.getEncoding().getDialect().getNamespace();

// Dispatch to the corresponding dialect helper function to print the layout.
-if (dialectName == "triton_gpu") {
+if (dialectName == "ttg") {
os << triton::gpu::getLayoutStr(tensorType, UseHWPointOfView);
return success();
}
4 changes: 2 additions & 2 deletions include/triton/Analysis/Allocation.h
@@ -180,8 +180,8 @@ class Allocation {
private:
/// A class that represents a shared memory buffer
struct BufferT {
-/// Explicit: triton_gpu.local_alloc
-/// Scratch: triton_gpu.convert_layout
+/// Explicit: ttg.local_alloc
+/// Scratch: ttg.convert_layout
/// Virtual: triton.call
enum class BufferKind { Explicit, Scratch, Virtual };

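(Aside: a minimal sketch of how these kinds line up with the renamed ops; the helper below is hypothetical and not part of this PR — only the op-to-kind pairing comes from the comments above, and BufferT is in reality private to Allocation.)

// Hypothetical dispatch illustrating the BufferKind mapping documented above.
static BufferT::BufferKind classifyBuffer(mlir::Operation *op) {
  if (llvm::isa<mlir::triton::gpu::LocalAllocOp>(op))
    return BufferT::BufferKind::Explicit; // ttg.local_alloc
  if (llvm::isa<mlir::triton::gpu::ConvertLayoutOp>(op))
    return BufferT::BufferKind::Scratch;  // scratch space for ttg.convert_layout
  return BufferT::BufferKind::Virtual;    // e.g. triton.call
}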
2 changes: 1 addition & 1 deletion include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -400,7 +400,7 @@ inline Value getGlobalScratchPtr(Location loc, RewriterBase &rewriter,

ModuleOp mod = funcOp.getOperation()->getParentOfType<ModuleOp>();
auto allocSizeAttr = mod.getOperation()->getAttrOfType<mlir::IntegerAttr>(
"triton_gpu.global_scratch_memory_size");
"ttg.global_scratch_memory_size");
if (!allocSizeAttr) {
return gmemBase;
}
@@ -12,11 +12,11 @@ template <typename T> class OperationPass;

namespace triton {

-constexpr static char AttrNumWarpsName[] = "triton_gpu.num-warps";
-constexpr static char AttrNumCTAsName[] = "triton_gpu.num-ctas";
-constexpr static char AttrTargetName[] = "triton_gpu.target";
+constexpr static char AttrNumWarpsName[] = "ttg.num-warps";
+constexpr static char AttrNumCTAsName[] = "ttg.num-ctas";
+constexpr static char AttrTargetName[] = "ttg.target";

-constexpr static char AttrNumThreadsPerWarp[] = "triton_gpu.threads-per-warp";
+constexpr static char AttrNumThreadsPerWarp[] = "ttg.threads-per-warp";

// Create the pass with numWarps passed from cl::opt.
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonToTritonGPUPass();
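(As a usage sketch, not part of the diff: a pass can read the renamed module attributes through these constants. Only the attribute names and constants come from the header; `mod` is assumed to be an mlir::ModuleOp already in scope, and the default is illustrative.)

// Query "ttg.num-warps" via the constant above, falling back to a default.
int numWarps = 4; // illustrative default
if (auto attr =
        mod->getAttrOfType<mlir::IntegerAttr>(mlir::triton::AttrNumWarpsName))
  numWarps = attr.getInt();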
8 changes: 4 additions & 4 deletions include/triton/Dialect/TritonGPU/IR/CMakeLists.txt
@@ -1,12 +1,12 @@
set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR})

set(LLVM_TARGET_DEFINITIONS TritonGPUOps.td)
-mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=triton_gpu)
-mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=triton_gpu)
+mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=ttg)
+mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=ttg)
mlir_tablegen(Ops.h.inc -gen-op-decls)
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
-mlir_tablegen(Types.h.inc -gen-typedef-decls -typedefs-dialect=triton_gpu)
-mlir_tablegen(Types.cpp.inc -gen-typedef-defs -typedefs-dialect=triton_gpu)
+mlir_tablegen(Types.h.inc -gen-typedef-decls -typedefs-dialect=ttg)
+mlir_tablegen(Types.cpp.inc -gen-typedef-defs -typedefs-dialect=ttg)
add_mlir_doc(TritonGPUDialect TritonGPUDialect dialects/ -gen-dialect-doc)
add_mlir_doc(TritonGPUOps TritonGPUOps dialects/ -gen-op-doc)
add_public_tablegen_target(TritonGPUTableGen)
6 changes: 3 additions & 3 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -616,7 +616,7 @@ Example 1, a row-major coalesced layout may partition a 16x16 tensor over 2 warp

for

-#triton_gpu.blocked_layout<{
+#ttg.blocked_layout<{
sizePerThread = {2, 2}
threadsPerWarp = {8, 4}
warpsPerCTA = {1, 2}
@@ -642,7 +642,7 @@ Example 2, a row-major coalesced layout may partition a 32x32 tensor over 2 warp
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
for

-#triton_gpu.blocked_layout<{
+#ttg.blocked_layout<{
sizePerThread = {2, 2}
threadsPerWarp = {8, 4}
warpsPerCTA = {1, 2}
@@ -672,7 +672,7 @@ CTA [1,0] CTA [1,1]
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
for

-#triton_gpu.blocked_layout<{
+#ttg.blocked_layout<{
sizePerThread = {2, 2}
threadsPerWarp = {8, 4}
warpsPerCTA = {1, 2}
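(Quick check on the numbers in the examples above: with sizePerThread = {2, 2}, threadsPerWarp = {8, 4}, and warpsPerCTA = {1, 2}, one CTA covers a [2*8*1, 2*4*2] = [16, 16] tile, so the 16x16 tensor in Example 1 is covered exactly once, while the 32x32 tensor in Example 2 wraps around the tile, which is why each thread's indices repeat in the printed layout.)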
18 changes: 9 additions & 9 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td
@@ -4,7 +4,7 @@
include "mlir/IR/OpBase.td"

def TritonGPU_Dialect : Dialect {
-let name = "triton_gpu";
+let name = "ttg";

let cppNamespace = "::mlir::triton::gpu";

@@ -21,24 +21,24 @@ def TritonGPU_Dialect : Dialect {
];

let extraClassDeclaration = [{
-static std::string getNumWarpsAttrName() { return "triton_gpu.num-warps"; }
+static std::string getNumWarpsAttrName() { return "ttg.num-warps"; }
static int getNumWarps(ModuleOp mod) {
-if (!mod->hasAttr("triton_gpu.num-warps"))
+if (!mod->hasAttr("ttg.num-warps"))
llvm::report_fatal_error(
"TritonGPU module should contain a triton_gpu.num-warps attribute");
return cast<IntegerAttr>(mod->getAttr("triton_gpu.num-warps")).getInt();
"TritonGPU module should contain a ttg.num-warps attribute");
return cast<IntegerAttr>(mod->getAttr("ttg.num-warps")).getInt();
}
static int getNumCTAs(ModuleOp mod) {
if (!mod->hasAttr("triton_gpu.num-ctas"))
if (!mod->hasAttr("ttg.num-ctas"))
return 1;
return cast<IntegerAttr>(mod->getAttr("triton_gpu.num-ctas")).getInt();
return cast<IntegerAttr>(mod->getAttr("ttg.num-ctas")).getInt();
}
void registerTypes();

-static std::string getThreadsPerWarpAttrName() { return "triton_gpu.threads-per-warp"; }
+static std::string getThreadsPerWarpAttrName() { return "ttg.threads-per-warp"; }

static int getThreadsPerWarp(ModuleOp mod) {
-Attribute threadsPerWarp = mod->getDiscardableAttr("triton_gpu.threads-per-warp");
+Attribute threadsPerWarp = mod->getDiscardableAttr("ttg.threads-per-warp");
if(!threadsPerWarp) {
return 32;
}
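(A minimal usage sketch, assuming a ModuleOp `mod` is in scope; the helper names come from the extraClassDeclaration above, everything else is illustrative.)

// Aborts with a fatal error if "ttg.num-warps" is missing.
int warps = mlir::triton::gpu::TritonGPUDialect::getNumWarps(mod);
// Returns 1 when "ttg.num-ctas" is absent.
int ctas = mlir::triton::gpu::TritonGPUDialect::getNumCTAs(mod);
// Returns 32 when "ttg.threads-per-warp" is absent.
int threads = mlir::triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);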
10 changes: 5 additions & 5 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
@@ -188,13 +188,13 @@ def TTG_LocalDeallocOp : TTG_Op<"local_dealloc", [MemoryEffects<[MemFree<SharedM

Because we assume a memdesc is dead at the first point that post-dominates
its uses, ops that wait for an async operation on a memdesc to complete
-(such as triton_nvidia_gpu.warp_group_dot_wait) should also take the memdesc as an
+(such as ttng.warp_group_dot_wait) should also take the memdesc as an
operand.
}];

let arguments = (ins TTG_MemDescType:$src);

-// Use qualified() otherwise "!triton_gpu.memdesc<X>" is printed as "<X>".
+// Use qualified() otherwise "!ttg.memdesc<X>" is printed as "<X>".
let assemblyFormat = [{$src attr-dict `:` qualified(type($src))}];
}

@@ -215,7 +215,7 @@ def TTG_MemDescSubviewOp : TTG_Op<"memdesc_subview", [Pure]> {
let arguments = (
ins TTG_MemDescType:$src, Variadic<I32>:$offsets);

-// Use qualified() otherwise "!triton_gpu.memdesc<X>" is printed as "<X>".
+// Use qualified() otherwise "!ttg.memdesc<X>" is printed as "<X>".
let assemblyFormat = [{$src `[` $offsets `]` attr-dict `:` qualified(type($src)) `->` qualified(type($result))}];

let results = (outs TTG_MemDescType:$result);
@@ -262,7 +262,7 @@ def TTG_LocalLoadOp : TTG_Op<"local_load", [DeclareOpInterfaceMethods<MemoryEffe
build($_builder, $_state, retType, src, /*token=*/static_cast<mlir::Value>(nullptr));
}]>];

-// Use qualified() otherwise "!triton_gpu.memdesc<X>" is printed as "<X>".
+// Use qualified() otherwise "!ttg.memdesc<X>" is printed as "<X>".
let assemblyFormat = [{$src (`token` $token^)? attr-dict `:` qualified(type($src)) `->` type($result)}];

let results = (outs TT_Tensor:$result);
@@ -277,7 +277,7 @@ def TTG_LocalStoreOp : TTG_Op<"local_store", [DeclareOpInterfaceMethods<MemoryEf
let arguments = (ins TT_Tensor:$src, TTG_MemDescType:$dst);

let hasVerifier = 1;
-// Use qualified() otherwise "!triton_gpu.memdesc<X>" is printed as "<X>".
+// Use qualified() otherwise "!ttg.memdesc<X>" is printed as "<X>".
let assemblyFormat = [{
$src `,` $dst attr-dict `:` type($src) `->` qualified(type($dst))
}];
8 changes: 4 additions & 4 deletions include/triton/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt
@@ -1,12 +1,12 @@
set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR})

set(LLVM_TARGET_DEFINITIONS TritonNvidiaGPUOps.td)
-mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=triton_nvidia_gpu)
-mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=triton_nvidia_gpu)
+mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=ttng)
+mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=ttng)
mlir_tablegen(Ops.h.inc -gen-op-decls)
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
-mlir_tablegen(Types.h.inc -gen-typedef-decls -typedefs-dialect=triton_nvidia_gpu)
-mlir_tablegen(Types.cpp.inc -gen-typedef-defs -typedefs-dialect=triton_nvidia_gpu)
+mlir_tablegen(Types.h.inc -gen-typedef-decls -typedefs-dialect=ttng)
+mlir_tablegen(Types.cpp.inc -gen-typedef-defs -typedefs-dialect=ttng)
add_mlir_doc(TritonNvidiaGPUDialect TritonNvidiaGPUDialect dialects/ -gen-dialect-doc)
add_mlir_doc(TritonNvidiaGPUOps TritonNvidiaGPUOps dialects/ -gen-op-doc)
add_public_tablegen_target(TritonNvidiaGPUTableGen)
@@ -25,7 +25,7 @@
include "mlir/IR/OpBase.td"

def TritonNvidiaGPU_Dialect : Dialect {
-let name = "triton_nvidia_gpu";
+let name = "ttng";

let cppNamespace = "::mlir::triton::nvidia_gpu";

@@ -43,18 +43,18 @@ def TritonNvidiaGPU_Dialect : Dialect {
];

let extraClassDeclaration = [{
-static std::string getNumWarpsAttrName() { return "triton_gpu.num-warps"; }
+static std::string getNumWarpsAttrName() { return "ttg.num-warps"; }
static int getNumWarps(ModuleOp mod) {
if(!mod->hasAttr("triton_gpu.num-warps"))
if(!mod->hasAttr("ttg.num-warps"))
llvm::report_fatal_error(
"TritonGPU module should contain a triton_gpu.num-warps attribute");
return cast<IntegerAttr>(mod->getAttr("triton_gpu.num-warps")).getInt();
"TritonGPU module should contain a ttg.num-warps attribute");
return cast<IntegerAttr>(mod->getAttr("ttg.num-warps")).getInt();
}
static int getNumCTAs(ModuleOp mod) {
if(!mod->hasAttr("triton_gpu.num-ctas"))
if(!mod->hasAttr("ttg.num-ctas"))
llvm::report_fatal_error(
"TritonGPU module should contain a triton_gpu.num-ctas attribute");
return cast<IntegerAttr>(mod->getAttr("triton_gpu.num-ctas")).getInt();
"TritonGPU module should contain a ttg.num-ctas attribute");
return cast<IntegerAttr>(mod->getAttr("ttg.num-ctas")).getInt();
}
void registerTypes();
}];
2 changes: 1 addition & 1 deletion lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp
@@ -44,7 +44,7 @@ struct AllocateSharedMemory
IntegerAttr::get(IntegerType::get(ctx, 32), offset));
});
});
mod->setAttr("triton_gpu.shared",
mod->setAttr("ttg.shared",
mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 32),
allocation.getSharedMemorySize()));
}
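(For illustration only: a downstream consumer could read back the total this pass records. The attribute name "ttg.shared" is from the diff; the lookup below is a hedged sketch assuming a ModuleOp `mod`.)

// Recover the total shared-memory requirement recorded under "ttg.shared".
int64_t sharedBytes = 0;
if (auto attr = mod->getAttrOfType<mlir::IntegerAttr>("ttg.shared"))
  sharedBytes = attr.getInt();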
2 changes: 1 addition & 1 deletion lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp
@@ -91,7 +91,7 @@ struct CallOpConversion : public ConvertOpToLLVMPattern<triton::CallOp> {
}

auto opOffsetAttr = caller->getAttrOfType<mlir::IntegerAttr>(
"triton_gpu.global_scratch_memory_offset");
"ttg.global_scratch_memory_offset");
Value opOffsetVal;
if (opOffsetAttr) {
auto opOffset = opOffsetAttr.getValue().getZExtValue();
22 changes: 11 additions & 11 deletions lib/Conversion/TritonGPUToLLVM/GlobalScratchMemoryAllocation.cpp
@@ -21,7 +21,7 @@ static void allocateGMem(Operation *parentOp,
// Recursively visit any dependency functions
parentOp->walk([&](triton::CallOp call) {
auto callable = call.resolveCallable();
-if (!callable->hasAttr("triton_gpu.global_scratch_memory_size")) {
+if (!callable->hasAttr("ttg.global_scratch_memory_size")) {
auto inserted = callStack.insert(parentOp);
assert(inserted && "call cycle detected");
allocateGMem(callable, callStack);
@@ -46,9 +46,9 @@
} else if (auto callOp = dyn_cast<triton::CallOp>(op)) {
auto callable = callOp.resolveCallable();
auto nbytes_attr = callable->getAttrOfType<IntegerAttr>(
"triton_gpu.global_scratch_memory_size");
"ttg.global_scratch_memory_size");
auto align_attr = callable->getAttrOfType<IntegerAttr>(
"triton_gpu.global_scratch_memory_alignment");
"ttg.global_scratch_memory_alignment");
assert(nbytes_attr);
assert(align_attr);

@@ -57,16 +57,16 @@
}
if (nbytes > 0) {
offset = roundUp(offset, align);
op->setAttr("triton_gpu.global_scratch_memory_offset",
op->setAttr("ttg.global_scratch_memory_offset",
builder.getI32IntegerAttr(offset));
offset += nbytes;
largestAlignment = std::max(largestAlignment, align);
}
});
int32_t totalMemorySize = roundUp(offset, largestAlignment);
parentOp->setAttr("triton_gpu.global_scratch_memory_size",
parentOp->setAttr("ttg.global_scratch_memory_size",
builder.getI32IntegerAttr(totalMemorySize));
parentOp->setAttr("triton_gpu.global_scratch_memory_alignment",
parentOp->setAttr("ttg.global_scratch_memory_alignment",
builder.getI32IntegerAttr(largestAlignment));
}

Expand All @@ -86,14 +86,14 @@ class TritonGPUGlobalScratchAllocationPass
if (func.getVisibility() == SymbolTable::Visibility::Public) {
assert(!seenKernel);
seenKernel = true;
-auto size = func->getAttrOfType<IntegerAttr>(
-"triton_gpu.global_scratch_memory_size");
+auto size =
+func->getAttrOfType<IntegerAttr>("ttg.global_scratch_memory_size");
auto align = func->getAttrOfType<IntegerAttr>(
"triton_gpu.global_scratch_memory_alignment");
"ttg.global_scratch_memory_alignment");
assert(size);
assert(align);
mod->setAttr("triton_gpu.global_scratch_memory_size", size);
mod->setAttr("triton_gpu.global_scratch_memory_alignment", align);
mod->setAttr("ttg.global_scratch_memory_size", size);
mod->setAttr("ttg.global_scratch_memory_alignment", align);
}
});
assert(seenKernel);
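(The pass leans on roundUp for offset alignment; a minimal sketch of that arithmetic follows, assuming the usual round-up-to-a-multiple semantics — the diff does not show roundUp's actual definition.)

// Round `bytes` up to the next multiple of `align`, e.g. roundUp(13, 8) == 16.
static int32_t roundUp(int32_t bytes, int32_t align) {
  return ((bytes + align - 1) / align) * align;
}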
2 changes: 1 addition & 1 deletion lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
@@ -47,7 +47,7 @@ struct GlobalScratchAllocOpConversion
Location loc = op.getLoc();

auto opOffsetAttr = op->getAttrOfType<mlir::IntegerAttr>(
"triton_gpu.global_scratch_memory_offset");
"ttg.global_scratch_memory_offset");
assert(opOffsetAttr);
auto opOffset = opOffsetAttr.getValue().getZExtValue();
