Skip to content

Commit

Permalink
[DIALECT] Rename `triton_gpu` to `ttg` and `triton_nvidia_gpu` to `ttng` (#5266)
Browse files Browse the repository at this point in the history

This may cause changes for downstream tasks, but we think it is beneficial
to shorten the dialect names and make them consistent. That is, we are using
`tt` to represent the `triton` dialect.
  • Loading branch information
Jokeren authored Nov 27, 2024
1 parent 2003685 commit 6d3ed0b
Show file tree
Hide file tree
Showing 96 changed files with 4,781 additions and 4,786 deletions.
8 changes: 4 additions & 4 deletions bin/triton-tensor-layout.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,16 @@ using namespace mlir;
// clang-format off
// Example usage:
//
// triton-tensor-layout -l "#triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>" -t "tensor<128x256xf16>"
// triton-tensor-layout -l "#ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>" -t "tensor<128x256xf16>"
//
// triton-tensor-layout -i input.mlir -t "tensor<1x128x128xf16>" -o output.txt
//
// triton-tensor-layout -i input.mlir -t "tensor<1x128x128xf16>" -o output.txt -alias-names="blocked,mma" -use-hw-view
//
// An input file usually looks like:
// '''
// #mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 8], instrShape = [32, 32], isTransposed = false}>
// #blocked = #triton_gpu.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 16, 4], warpsPerCTA = [1, 1, 8], order = [0, 1, 2]}>
// #mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 8], instrShape = [32, 32], isTransposed = false}>
// #blocked = #ttg.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 16, 4], warpsPerCTA = [1, 1, 8], order = [0, 1, 2]}>
// '''
// clang-format on

Expand Down Expand Up @@ -83,7 +83,7 @@ LogicalResult layoutPrint(RankedTensorType tensorType, raw_ostream &os) {
StringRef dialectName = tensorType.getEncoding().getDialect().getNamespace();

// Dispatch to the corresponding dialect helper function to print the layout.
if (dialectName == "triton_gpu") {
if (dialectName == "ttg") {
os << triton::gpu::getLayoutStr(tensorType, UseHWPointOfView);
return success();
}
Expand Down
4 changes: 2 additions & 2 deletions include/triton/Analysis/Allocation.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,8 @@ class Allocation {
private:
/// A class that represents a shared memory buffer
struct BufferT {
/// Explicit: triton_gpu.local_alloc
/// Scratch: triton_gpu.convert_layout
/// Explicit: ttg.local_alloc
/// Scratch: ttg.convert_layout
/// Virtual: triton.call
enum class BufferKind { Explicit, Scratch, Virtual };

Expand Down
2 changes: 1 addition & 1 deletion include/triton/Conversion/TritonGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ inline Value getGlobalScratchPtr(Location loc, RewriterBase &rewriter,

ModuleOp mod = funcOp.getOperation()->getParentOfType<ModuleOp>();
auto allocSizeAttr = mod.getOperation()->getAttrOfType<mlir::IntegerAttr>(
"triton_gpu.global_scratch_memory_size");
"ttg.global_scratch_memory_size");
if (!allocSizeAttr) {
return gmemBase;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ template <typename T> class OperationPass;

namespace triton {

constexpr static char AttrNumWarpsName[] = "triton_gpu.num-warps";
constexpr static char AttrNumCTAsName[] = "triton_gpu.num-ctas";
constexpr static char AttrTargetName[] = "triton_gpu.target";
constexpr static char AttrNumWarpsName[] = "ttg.num-warps";
constexpr static char AttrNumCTAsName[] = "ttg.num-ctas";
constexpr static char AttrTargetName[] = "ttg.target";

constexpr static char AttrNumThreadsPerWarp[] = "triton_gpu.threads-per-warp";
constexpr static char AttrNumThreadsPerWarp[] = "ttg.threads-per-warp";

// Create the pass with numWarps passed from cl::opt.
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonToTritonGPUPass();
Expand Down
8 changes: 4 additions & 4 deletions include/triton/Dialect/TritonGPU/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR})

set(LLVM_TARGET_DEFINITIONS TritonGPUOps.td)
mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=triton_gpu)
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=triton_gpu)
mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=ttg)
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=ttg)
mlir_tablegen(Ops.h.inc -gen-op-decls)
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
mlir_tablegen(Types.h.inc -gen-typedef-decls -typedefs-dialect=triton_gpu)
mlir_tablegen(Types.cpp.inc -gen-typedef-defs -typedefs-dialect=triton_gpu)
mlir_tablegen(Types.h.inc -gen-typedef-decls -typedefs-dialect=ttg)
mlir_tablegen(Types.cpp.inc -gen-typedef-defs -typedefs-dialect=ttg)
add_mlir_doc(TritonGPUDialect TritonGPUDialect dialects/ -gen-dialect-doc)
add_mlir_doc(TritonGPUOps TritonGPUOps dialects/ -gen-op-doc)
add_public_tablegen_target(TritonGPUTableGen)
Expand Down
6 changes: 3 additions & 3 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,7 @@ Example 1, a row-major coalesced layout may partition a 16x16 tensor over 2 warp

for

#triton_gpu.blocked_layout<{
#ttg.blocked_layout<{
sizePerThread = {2, 2}
threadsPerWarp = {8, 4}
warpsPerCTA = {1, 2}
Expand All @@ -642,7 +642,7 @@ Example 2, a row-major coalesced layout may partition a 32x32 tensor over 2 warp
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
for

#triton_gpu.blocked_layout<{
#ttg.blocked_layout<{
sizePerThread = {2, 2}
threadsPerWarp = {8, 4}
warpsPerCTA = {1, 2}
Expand Down Expand Up @@ -672,7 +672,7 @@ CTA [1,0] CTA [1,1]
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
for

#triton_gpu.blocked_layout<{
#ttg.blocked_layout<{
sizePerThread = {2, 2}
threadsPerWarp = {8, 4}
warpsPerCTA = {1, 2}
Expand Down
18 changes: 9 additions & 9 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
include "mlir/IR/OpBase.td"

def TritonGPU_Dialect : Dialect {
let name = "triton_gpu";
let name = "ttg";

let cppNamespace = "::mlir::triton::gpu";

Expand All @@ -21,24 +21,24 @@ def TritonGPU_Dialect : Dialect {
];

let extraClassDeclaration = [{
static std::string getNumWarpsAttrName() { return "triton_gpu.num-warps"; }
static std::string getNumWarpsAttrName() { return "ttg.num-warps"; }
static int getNumWarps(ModuleOp mod) {
if (!mod->hasAttr("triton_gpu.num-warps"))
if (!mod->hasAttr("ttg.num-warps"))
llvm::report_fatal_error(
"TritonGPU module should contain a triton_gpu.num-warps attribute");
return cast<IntegerAttr>(mod->getAttr("triton_gpu.num-warps")).getInt();
"TritonGPU module should contain a ttg.num-warps attribute");
return cast<IntegerAttr>(mod->getAttr("ttg.num-warps")).getInt();
}
static int getNumCTAs(ModuleOp mod) {
if (!mod->hasAttr("triton_gpu.num-ctas"))
if (!mod->hasAttr("ttg.num-ctas"))
return 1;
return cast<IntegerAttr>(mod->getAttr("triton_gpu.num-ctas")).getInt();
return cast<IntegerAttr>(mod->getAttr("ttg.num-ctas")).getInt();
}
void registerTypes();

static std::string getThreadsPerWarpAttrName() { return "triton_gpu.threads-per-warp"; }
static std::string getThreadsPerWarpAttrName() { return "ttg.threads-per-warp"; }

static int getThreadsPerWarp(ModuleOp mod) {
Attribute threadsPerWarp = mod->getDiscardableAttr("triton_gpu.threads-per-warp");
Attribute threadsPerWarp = mod->getDiscardableAttr("ttg.threads-per-warp");
if(!threadsPerWarp) {
return 32;
}
Expand Down
10 changes: 5 additions & 5 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -188,13 +188,13 @@ def TTG_LocalDeallocOp : TTG_Op<"local_dealloc", [MemoryEffects<[MemFree<SharedM

Because we assume a memdesc is dead at the first point that post-dominates
its uses, ops that wait for an async operation on a memdesc to complete
(such as triton_nvidia_gpu.warp_group_dot_wait) should also take the memdesc as an
(such as ttng.warp_group_dot_wait) should also take the memdesc as an
operand.
}];

let arguments = (ins TTG_MemDescType:$src);

// Use qualified() otherwise "!triton_gpu.memdesc<X>" is printed as "<X>".
// Use qualified() otherwise "!ttg.memdesc<X>" is printed as "<X>".
let assemblyFormat = [{$src attr-dict `:` qualified(type($src))}];
}

Expand All @@ -215,7 +215,7 @@ def TTG_MemDescSubviewOp : TTG_Op<"memdesc_subview", [Pure]> {
let arguments = (
ins TTG_MemDescType:$src, Variadic<I32>:$offsets);

// Use qualified() otherwise "!triton_gpu.memdesc<X>" is printed as "<X>".
// Use qualified() otherwise "!ttg.memdesc<X>" is printed as "<X>".
let assemblyFormat = [{$src `[` $offsets `]` attr-dict `:` qualified(type($src)) `->` qualified(type($result))}];

let results = (outs TTG_MemDescType:$result);
Expand Down Expand Up @@ -262,7 +262,7 @@ def TTG_LocalLoadOp : TTG_Op<"local_load", [DeclareOpInterfaceMethods<MemoryEffe
build($_builder, $_state, retType, src, /*token=*/static_cast<mlir::Value>(nullptr));
}]>];

// Use qualified() otherwise "!triton_gpu.memdesc<X>" is printed as "<X>".
// Use qualified() otherwise "!ttg.memdesc<X>" is printed as "<X>".
let assemblyFormat = [{$src (`token` $token^)? attr-dict `:` qualified(type($src)) `->` type($result)}];

let results = (outs TT_Tensor:$result);
Expand All @@ -277,7 +277,7 @@ def TTG_LocalStoreOp : TTG_Op<"local_store", [DeclareOpInterfaceMethods<MemoryEf
let arguments = (ins TT_Tensor:$src, TTG_MemDescType:$dst);

let hasVerifier = 1;
// Use qualified() otherwise "!triton_gpu.memdesc<X>" is printed as "<X>".
// Use qualified() otherwise "!ttg.memdesc<X>" is printed as "<X>".
let assemblyFormat = [{
$src `,` $dst attr-dict `:` type($src) `->` qualified(type($dst))
}];
Expand Down
8 changes: 4 additions & 4 deletions include/triton/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR})

set(LLVM_TARGET_DEFINITIONS TritonNvidiaGPUOps.td)
mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=triton_nvidia_gpu)
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=triton_nvidia_gpu)
mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=ttng)
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=ttng)
mlir_tablegen(Ops.h.inc -gen-op-decls)
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
mlir_tablegen(Types.h.inc -gen-typedef-decls -typedefs-dialect=triton_nvidia_gpu)
mlir_tablegen(Types.cpp.inc -gen-typedef-defs -typedefs-dialect=triton_nvidia_gpu)
mlir_tablegen(Types.h.inc -gen-typedef-decls -typedefs-dialect=ttng)
mlir_tablegen(Types.cpp.inc -gen-typedef-defs -typedefs-dialect=ttng)
add_mlir_doc(TritonNvidiaGPUDialect TritonNvidiaGPUDialect dialects/ -gen-dialect-doc)
add_mlir_doc(TritonNvidiaGPUOps TritonNvidiaGPUOps dialects/ -gen-op-doc)
add_public_tablegen_target(TritonNvidiaGPUTableGen)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
include "mlir/IR/OpBase.td"

def TritonNvidiaGPU_Dialect : Dialect {
let name = "triton_nvidia_gpu";
let name = "ttng";

let cppNamespace = "::mlir::triton::nvidia_gpu";

Expand All @@ -43,18 +43,18 @@ def TritonNvidiaGPU_Dialect : Dialect {
];

let extraClassDeclaration = [{
static std::string getNumWarpsAttrName() { return "triton_gpu.num-warps"; }
static std::string getNumWarpsAttrName() { return "ttg.num-warps"; }
static int getNumWarps(ModuleOp mod) {
if(!mod->hasAttr("triton_gpu.num-warps"))
if(!mod->hasAttr("ttg.num-warps"))
llvm::report_fatal_error(
"TritonGPU module should contain a triton_gpu.num-warps attribute");
return cast<IntegerAttr>(mod->getAttr("triton_gpu.num-warps")).getInt();
"TritonGPU module should contain a ttg.num-warps attribute");
return cast<IntegerAttr>(mod->getAttr("ttg.num-warps")).getInt();
}
static int getNumCTAs(ModuleOp mod) {
if(!mod->hasAttr("triton_gpu.num-ctas"))
if(!mod->hasAttr("ttg.num-ctas"))
llvm::report_fatal_error(
"TritonGPU module should contain a triton_gpu.num-ctas attribute");
return cast<IntegerAttr>(mod->getAttr("triton_gpu.num-ctas")).getInt();
"TritonGPU module should contain a ttg.num-ctas attribute");
return cast<IntegerAttr>(mod->getAttr("ttg.num-ctas")).getInt();
}
void registerTypes();
}];
Expand Down
2 changes: 1 addition & 1 deletion lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ struct AllocateSharedMemory
IntegerAttr::get(IntegerType::get(ctx, 32), offset));
});
});
mod->setAttr("triton_gpu.shared",
mod->setAttr("ttg.shared",
mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 32),
allocation.getSharedMemorySize()));
}
Expand Down
2 changes: 1 addition & 1 deletion lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ struct CallOpConversion : public ConvertOpToLLVMPattern<triton::CallOp> {
}

auto opOffsetAttr = caller->getAttrOfType<mlir::IntegerAttr>(
"triton_gpu.global_scratch_memory_offset");
"ttg.global_scratch_memory_offset");
Value opOffsetVal;
if (opOffsetAttr) {
auto opOffset = opOffsetAttr.getValue().getZExtValue();
Expand Down
22 changes: 11 additions & 11 deletions lib/Conversion/TritonGPUToLLVM/GlobalScratchMemoryAllocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ static void allocateGMem(Operation *parentOp,
// Recursively visit any dependency functions
parentOp->walk([&](triton::CallOp call) {
auto callable = call.resolveCallable();
if (!callable->hasAttr("triton_gpu.global_scratch_memory_size")) {
if (!callable->hasAttr("ttg.global_scratch_memory_size")) {
auto inserted = callStack.insert(parentOp);
assert(inserted && "call cycle detected");
allocateGMem(callable, callStack);
Expand All @@ -46,9 +46,9 @@ static void allocateGMem(Operation *parentOp,
} else if (auto callOp = dyn_cast<triton::CallOp>(op)) {
auto callable = callOp.resolveCallable();
auto nbytes_attr = callable->getAttrOfType<IntegerAttr>(
"triton_gpu.global_scratch_memory_size");
"ttg.global_scratch_memory_size");
auto align_attr = callable->getAttrOfType<IntegerAttr>(
"triton_gpu.global_scratch_memory_alignment");
"ttg.global_scratch_memory_alignment");
assert(nbytes_attr);
assert(align_attr);

Expand All @@ -57,16 +57,16 @@ static void allocateGMem(Operation *parentOp,
}
if (nbytes > 0) {
offset = roundUp(offset, align);
op->setAttr("triton_gpu.global_scratch_memory_offset",
op->setAttr("ttg.global_scratch_memory_offset",
builder.getI32IntegerAttr(offset));
offset += nbytes;
largestAlignment = std::max(largestAlignment, align);
}
});
int32_t totalMemorySize = roundUp(offset, largestAlignment);
parentOp->setAttr("triton_gpu.global_scratch_memory_size",
parentOp->setAttr("ttg.global_scratch_memory_size",
builder.getI32IntegerAttr(totalMemorySize));
parentOp->setAttr("triton_gpu.global_scratch_memory_alignment",
parentOp->setAttr("ttg.global_scratch_memory_alignment",
builder.getI32IntegerAttr(largestAlignment));
}

Expand All @@ -86,14 +86,14 @@ class TritonGPUGlobalScratchAllocationPass
if (func.getVisibility() == SymbolTable::Visibility::Public) {
assert(!seenKernel);
seenKernel = true;
auto size = func->getAttrOfType<IntegerAttr>(
"triton_gpu.global_scratch_memory_size");
auto size =
func->getAttrOfType<IntegerAttr>("ttg.global_scratch_memory_size");
auto align = func->getAttrOfType<IntegerAttr>(
"triton_gpu.global_scratch_memory_alignment");
"ttg.global_scratch_memory_alignment");
assert(size);
assert(align);
mod->setAttr("triton_gpu.global_scratch_memory_size", size);
mod->setAttr("triton_gpu.global_scratch_memory_alignment", align);
mod->setAttr("ttg.global_scratch_memory_size", size);
mod->setAttr("ttg.global_scratch_memory_alignment", align);
}
});
assert(seenKernel);
Expand Down
2 changes: 1 addition & 1 deletion lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ struct GlobalScratchAllocOpConversion
Location loc = op.getLoc();

auto opOffsetAttr = op->getAttrOfType<mlir::IntegerAttr>(
"triton_gpu.global_scratch_memory_offset");
"ttg.global_scratch_memory_offset");
assert(opOffsetAttr);
auto opOffset = opOffsetAttr.getValue().getZExtValue();

Expand Down
Loading

0 comments on commit 6d3ed0b

Please sign in to comment.