Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
Jokeren committed Nov 19, 2024
1 parent 689dcfe commit e9a1d0f
Show file tree
Hide file tree
Showing 8 changed files with 96 additions and 38 deletions.
8 changes: 6 additions & 2 deletions include/triton/Dialect/Triton/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@ mlir_tablegen(AttrInterfaces.h.inc -gen-attr-interface-decls)
mlir_tablegen(AttrInterfaces.cpp.inc -gen-attr-interface-defs)

set(LLVM_TARGET_DEFINITIONS TritonTypeInterfaces.td)
mlir_tablegen(TritonTypeInterfaces.h.inc -gen-type-interface-decls)
mlir_tablegen(TritonTypeInterfaces.cpp.inc -gen-type-interface-defs)
mlir_tablegen(TypeInterfaces.h.inc -gen-type-interface-decls)
mlir_tablegen(TypeInterfaces.cpp.inc -gen-type-interface-defs)

set(LLVM_TARGET_DEFINITIONS TritonOpInterfaces.td)
mlir_tablegen(OpInterfaces.h.inc -gen-op-interface-decls)
mlir_tablegen(OpInterfaces.cpp.inc -gen-op-interface-defs)

add_public_tablegen_target(TritonTableGen)
1 change: 1 addition & 0 deletions include/triton/Dialect/Triton/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "triton/Dialect/Triton/IR/Dialect.h.inc"
#include "triton/Dialect/Triton/IR/OpInterfaces.h.inc"
#include "triton/Dialect/Triton/IR/OpsEnums.h.inc"
#include "triton/Dialect/Triton/IR/Traits.h"
#include "triton/Dialect/Triton/IR/Types.h"
Expand Down
34 changes: 34 additions & 0 deletions include/triton/Dialect/Triton/IR/TritonOpInterfaces.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#ifndef TRITON_OP_INTERFACES
#define TRITON_OP_INTERFACES

include "mlir/IR/OpBase.td"


def TransposeOpInterface : OpInterface<"TransposeOpInterface"> {
let description = [{
This interface is implemented by operations that perform a transpose.
It provides methods to access common properties such as the order attribute and the source operand.
}];

let cppNamespace = "::mlir";

let methods = [
InterfaceMethod<
/*desc=*/[{
Get the source operand of the transposition.
}],
/*retType=*/"::mlir::Value",
/*methodName=*/"getSrc",
/*args=*/(ins)>,
InterfaceMethod<
/*desc=*/[{
Get the order of the transposition.
}],
/*retType=*/"::mlir::ArrayRef<int32_t>",
/*methodName=*/"getOrder",
/*args=*/(ins)>
];
}


#endif // TRITON_OP_INTERFACES
58 changes: 28 additions & 30 deletions include/triton/Dialect/Triton/IR/TritonOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface
include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface
include "triton/Dialect/Triton/IR/TritonTypeInterfaces.td"
include "triton/Dialect/Triton/IR/TritonOpInterfaces.td"


//
Expand All @@ -44,8 +45,7 @@ class TT_Op<string mnemonic, list<Trait> traits = []> :
def TT_IntToPtrOp : TT_Op<"int_to_ptr", [Elementwise,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
Pure,
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
Pure]> {
let summary = "Cast int64 to pointer";

let arguments = (ins TT_I64Like:$src);
Expand All @@ -58,8 +58,7 @@ def TT_IntToPtrOp : TT_Op<"int_to_ptr", [Elementwise,
def TT_PtrToIntOp : TT_Op<"ptr_to_int", [Elementwise,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
Pure,
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
Pure]> {
let summary = "Cast pointer to int64";

let arguments = (ins TT_PtrLike:$src);
Expand All @@ -73,8 +72,7 @@ def TT_PtrToIntOp : TT_Op<"ptr_to_int", [Elementwise,
def TT_BitcastOp : TT_Op<"bitcast", [Elementwise,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
Pure,
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
Pure]> {
let summary = "Cast between types of the same bitwidth";

let arguments = (ins TT_Type:$src);
Expand All @@ -89,8 +87,7 @@ def TT_BitcastOp : TT_Op<"bitcast", [Elementwise,
def TT_FpToFpOp : TT_Op<"fp_to_fp", [Elementwise,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
Pure,
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
Pure]> {
let summary = "Floating point casting for custom types";

let description = [{
Expand Down Expand Up @@ -118,8 +115,8 @@ def TT_FpToFpOp : TT_Op<"fp_to_fp", [Elementwise,
//

def TT_ClampFOp : TT_Op<"clampf", [Elementwise,
SameOperandsAndResultType,
Pure]> {
SameOperandsAndResultType,
Pure]> {
let summary = "Clamp operation for floating point types";

let description = [{
Expand Down Expand Up @@ -149,8 +146,8 @@ def TT_ClampFOp : TT_Op<"clampf", [Elementwise,
//

def TT_PreciseSqrtOp : TT_Op<"precise_sqrt", [Elementwise,
SameOperandsAndResultType,
Pure]> {
SameOperandsAndResultType,
Pure]> {
let summary = "Precise sqrt for floating point types";

let description = [{
Expand All @@ -165,8 +162,8 @@ def TT_PreciseSqrtOp : TT_Op<"precise_sqrt", [Elementwise,
}

def TT_PreciseDivFOp : TT_Op<"precise_divf", [Elementwise,
SameOperandsAndResultType,
Pure]> {
SameOperandsAndResultType,
Pure]> {
let summary = "Precise div for floating point types";

let description = [{
Expand All @@ -181,8 +178,8 @@ def TT_PreciseDivFOp : TT_Op<"precise_divf", [Elementwise,
}

def TT_MulhiUIOp : TT_Op<"mulhiui", [Elementwise,
SameOperandsAndResultType,
Pure]> {
SameOperandsAndResultType,
Pure]> {
let summary = "Most significant N bits of the 2N-bit product of two integers";

let description = [{
Expand All @@ -200,12 +197,12 @@ def TT_MulhiUIOp : TT_Op<"mulhiui", [Elementwise,
// Pointer Arith Ops
//
def TT_AddPtrOp : TT_Op<"addptr",
[Pure,
Elementwise,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
TypesMatchWith<"result type matches ptr type",
"result", "ptr", "$_self">]> {
[Pure,
Elementwise,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
TypesMatchWith<"result type matches ptr type",
"result", "ptr", "$_self">]> {
let arguments = (ins TT_PtrLike:$ptr, TT_IntLike:$offset);

let results = (outs TT_PtrLike:$result);
Expand Down Expand Up @@ -546,6 +543,7 @@ def TT_SplitOp : TT_Op<"split", [
}

def TT_TransOp : TT_Op<"trans", [Pure,
TransposeOpInterface,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
SameOperandsAndResultElementType]> {

Expand Down Expand Up @@ -677,10 +675,10 @@ def TT_DotOp : TT_Op<"dot", [Pure,
// DotScaled Op
//
def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure,
AttrSizedOperandSegments,
DotLike,
TypesMatchWith<"result's type matches accumulator's type",
"d", "c", "$_self">]> {
AttrSizedOperandSegments,
DotLike,
TypesMatchWith<"result's type matches accumulator's type",
"d", "c", "$_self">]> {
let summary = "dot_scaled";

let description = [{
Expand Down Expand Up @@ -783,10 +781,10 @@ def TT_ScanReturnOp: TT_Op<"scan.return",
// External Elementwise op
//
def TT_ExternElementwiseOp : TT_Op<"extern_elementwise", [Elementwise,
SameOperandsAndResultEncoding,
SameVariadicOperandSize,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
ConditionallySpeculatable]> {
SameOperandsAndResultEncoding,
SameVariadicOperandSize,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
ConditionallySpeculatable]> {

let description = [{
call an external function $symbol implemented in $libpath/$libname with $args
Expand Down
2 changes: 1 addition & 1 deletion include/triton/Dialect/Triton/IR/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#define GET_TYPEDEF_CLASSES
#include "triton/Dialect/Triton/IR/Types.h.inc"

#include "triton/Dialect/Triton/IR/TritonTypeInterfaces.h.inc"
#include "triton/Dialect/Triton/IR/TypeInterfaces.h.inc"

namespace mlir {

Expand Down
21 changes: 21 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td"
include "mlir/Dialect/Arith/IR/ArithBase.td"
include "triton/Dialect/Triton/IR/TritonTypes.td"
include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
include "triton/Dialect/Triton/IR/TritonOpInterfaces.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
Expand Down Expand Up @@ -221,6 +222,26 @@ def TTG_MemDescSubviewOp : TTG_Op<"memdesc_subview", [Pure]> {
let hasVerifier = 1;
}

def TTG_MemDescTransOp : TTG_Op<"memdesc_trans", [Pure, TransposeOpInterface]> {
let summary = "transpose the descriptor";

let description = [{
This operation returns a new descriptor
representing a transposed view of the buffer.
}];

let arguments = (ins TT_MemDescType:$src, Variadic<I32>:$order);

let arguments = (
ins TT_MemDescType:$src,
DenseI32ArrayAttr:$order
);

let results = (outs TT_MemDescType:$result);

let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)";
}

def TTG_LocalLoadOp : TTG_Op<"local_load", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let summary = "Load a buffer from local memory into a distributed tensor";

Expand Down
5 changes: 2 additions & 3 deletions lib/Analysis/Alias.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,8 @@ LogicalResult SharedMemoryAliasAnalysis::visitOperation(
if (isa<triton::gpu::LocalAllocOp>(op)) {
aliasInfo.insert(result);
pessimistic = false;
} else if (isa<triton::gpu::MemDescSubviewOp, triton::TransOp>(op)) {
// extract_slice %src
// trans %src
} else if (isa<triton::gpu::MemDescSubviewOp, triton::gpu::MemDescTransOp>(
op)) {
aliasInfo = AliasInfo(operands[0]->getValue());
pessimistic = false;
} else {
Expand Down
5 changes: 3 additions & 2 deletions lib/Dialect/Triton/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"
Expand All @@ -12,8 +11,10 @@
#include "mlir/IR/DialectImplementation.h"

#include "mlir/Transforms/InliningUtils.h"
#include "triton/Dialect/Triton/IR/AttrInterfaces.cpp.inc"
#include "triton/Dialect/Triton/IR/Dialect.cpp.inc"
#include "triton/Dialect/Triton/IR/TritonTypeInterfaces.cpp.inc"
#include "triton/Dialect/Triton/IR/OpInterfaces.cpp.inc"
#include "triton/Dialect/Triton/IR/TypeInterfaces.cpp.inc"

using namespace mlir;
using namespace mlir::triton;
Expand Down

0 comments on commit e9a1d0f

Please sign in to comment.