[TRANSFORM] Disable block ptr in the TMA mode (triton-lang#14)
* Update

* Update

* Update

* Update

* Update

* Update
Jokeren authored Jul 21, 2023
1 parent a67b21b commit 1a7a7ac
Showing 19 changed files with 343 additions and 326 deletions.
8 changes: 6 additions & 2 deletions TODO.md
@@ -1,5 +1,9 @@
# List of TODOs before merging into main

## debug flags

* **ENABLE_TMA**: This flag enables TMA-related code.
* **ENABLE_MMA_V3**: This flag enables MMA V3-related code. (A sketch of reading such a flag appears after this file's diff.)
## cleanups

* Hard-coded alignment; we don't know why it exists or why it is that number
@@ -17,7 +21,7 @@ https://github.com/openai/triton-hopper/blob/1ada046fdaef13f94dc7e2f6e6d0966e5d1
* The linearize/delinearize helpers have been duplicated (most likely due to layering problems); they should be merged
* Try scf.if in pipeline to replace remui when indices are at the boundaries
* Clean up waitIdx, phase, and other indices. Now there are a bunch of loop-carried variables.
https://github.com/openai/triton-hopper/blob/9453151688804ebaf8bebca38a62ada5bb343d3c/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp#L166
* Get rid of the hacky mode variable in pipeline
https://github.com/openai/triton-hopper/blob/9453151688804ebaf8bebca38a62ada5bb343d3c/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp#L180
* The pipeline shouldn't have special handling for Hopper; ideally it is architecture-agnostic
@@ -34,4 +38,4 @@ https://github.com/openai/triton-hopper/blob/1ada046fdaef13f94dc7e2f6e6d0966e5d1
* We currently rely on the `cuda-python` package, which prevents us from building Triton on any node without CUDA installed. We should invoke TMA-related functions in our thin CUDA wrapper.
https://github.com/openai/triton-hopper/blob/b6a6b32b0ee79e93247d20c95f15fd75039a40b9/python/triton/compiler/utils.py#L3
* Pipeline doesn't handle block ptrs correctly
* Pipeline doesn't handle TMAs correctly
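The `ENABLE_TMA` and `ENABLE_MMA_V3` entries above describe boolean debug flags that gate the TMA and MMA V3 code paths. As a minimal sketch of how such an environment flag could be read, assuming a plain getenv-based helper rather than whatever utility the repository actually uses:

```cpp
#include <cstdlib>
#include <cstring>
#include <iostream>

// Hypothetical helper, not the repository's code: treat the variable as
// enabled when it is set to "1" or "true".
static bool getBoolEnv(const char *name) {
  const char *value = std::getenv(name);
  if (!value)
    return false;
  return std::strcmp(value, "1") == 0 || std::strcmp(value, "true") == 0;
}

int main() {
  // Gate TMA- and MMA-V3-related code behind the debug flags from TODO.md.
  if (getBoolEnv("ENABLE_TMA"))
    std::cout << "TMA code path enabled\n";
  if (getBoolEnv("ENABLE_MMA_V3"))
    std::cout << "MMA V3 code path enabled\n";
  return 0;
}
```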
9 changes: 3 additions & 6 deletions include/triton/Analysis/Utility.h
@@ -11,11 +11,6 @@

namespace mlir {

// TODO: Move utility functions that belong to ConvertLayoutOp to class
// ConvertLayoutOpHelper in the future
bool convertLayoutUseDSmem(const Attribute &srcLayout,
const Attribute &dstLayout);

class ReduceOpHelper {
public:
explicit ReduceOpHelper(triton::ReduceOp op)
@@ -128,7 +123,9 @@ bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);
bool isMmaToMmaShortcut(triton::gpu::MmaEncodingAttr &src,
triton::gpu::MmaEncodingAttr &dst);

Type getElementType(Value value);
// TODO: Move utility functions that belong to ConvertLayoutOp to class
// ConvertLayoutOpHelper in the future
bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout);

template <typename T_OUT, typename T_IN>
inline SmallVector<T_OUT> convertType(ArrayRef<T_IN> in) {
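Several signatures in this commit switch from `const Attribute &` to plain `Attribute` (here `shouldUseDistSmem`, later `isExpensiveCat`, `emitCTAOffsetForLayout`, and `parseBoolAttrValue`). MLIR attributes are value-semantic handles that wrap a pointer to uniqued, immutable storage, so copying one costs a single pointer copy and pass-by-value is the idiomatic style. A small self-contained illustration with a stand-in handle type (not MLIR's actual classes):

```cpp
#include <cassert>

// Stand-in for an MLIR-style attribute handle: a thin wrapper around a
// pointer to immutable, uniqued storage. Copying it copies one pointer.
struct AttrHandle {
  const void *impl = nullptr;
  bool operator==(AttrHandle other) const { return impl == other.impl; }
};

// Pass-by-value, matching the new signatures in this commit: no aliasing,
// no lifetime questions, and no extra indirection compared with
// `const AttrHandle &`.
bool sameLayout(AttrHandle srcLayout, AttrHandle dstLayout) {
  return srcLayout == dstLayout;
}

int main() {
  int storage = 0;
  AttrHandle a{&storage}, b{&storage}, c{};
  assert(sameLayout(a, b));
  assert(!sameLayout(a, c));
  return 0;
}
```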
2 changes: 1 addition & 1 deletion include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -106,7 +106,7 @@ bool isaDistributedLayout(Attribute layout);

bool isSharedEncoding(Value value);

bool isExpensiveCat(CatOp cat, Attribute &targetEncoding);
bool isExpensiveCat(CatOp cat, Attribute targetEncoding);

} // namespace gpu
} // namespace triton
2 changes: 1 addition & 1 deletion lib/Analysis/Allocation.cpp
@@ -58,7 +58,7 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();

if (convertLayoutUseDSmem(srcLayout, dstLayout)) {
if (shouldUseDistSmem(srcLayout, dstLayout)) {
// TODO: padding to avoid bank conflicts
return convertType<unsigned, int64_t>(getShapePerCTA(srcTy));
}
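When distributed shared memory is used, the scratch buffer shape above comes from `convertType<unsigned, int64_t>(getShapePerCTA(srcTy))`, built on the `convertType` template declared in `include/triton/Analysis/Utility.h` (shown earlier). The template body is hidden by the collapsed diff, so the element-wise cast below is an assumption about what it does, not the repository's exact code:

```cpp
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

// Sketch of the convertType helper declared in Utility.h above: cast each
// element of the input range to the output type.
template <typename T_OUT, typename T_IN>
inline llvm::SmallVector<T_OUT> convertType(llvm::ArrayRef<T_IN> in) {
  llvm::SmallVector<T_OUT> out;
  out.reserve(in.size());
  for (const T_IN &v : in)
    out.push_back(static_cast<T_OUT>(v));
  return out;
}

// Usage mirroring the call site in lib/Analysis/Allocation.cpp:
//   SmallVector<unsigned> scratchShape =
//       convertType<unsigned, int64_t>(getShapePerCTA(srcTy));
```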
10 changes: 1 addition & 9 deletions lib/Analysis/Utility.cpp
@@ -46,8 +46,7 @@ bool ReduceOpHelper::isFastReduction() {
// (2) numCTAs > 1 but srcCTALayout == dstCTALayout
// TODO: Case with SliceLayout as srcLayout and numCTAs > 1 is to be implemented
// in the future
bool convertLayoutUseDSmem(const Attribute &srcLayout,
const Attribute &dstLayout) {
bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout) {
unsigned numCTAs = triton::gpu::getNumCTAs(srcLayout);
assert(numCTAs == triton::gpu::getNumCTAs(dstLayout) &&
"Invalid layout conversion: the numbers of CTAs of src and dst "
@@ -382,13 +381,6 @@ bool supportMMA(Value value, int version) {
(elemTy.isInteger(8) && version >= 2);
}

Type getElementType(Value value) {
auto type = value.getType();
if (auto tensorType = type.dyn_cast<RankedTensorType>())
return tensorType.getElementType();
return type;
}

bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
// dot_op<opIdx=0, parent=#mma> = #mma
// when #mma = MmaEncoding<version=2, warpsPerCTA=[..., 1]>
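The comment preserved above lists the situations in which a layout conversion does not need distributed shared memory: a single CTA, or multiple CTAs whose source and destination CTA layouts already match (the SliceLayout case is still a TODO). An abstract, self-contained sketch of that rule, with plain structs standing in for the real layout attributes and only the numCTAs assertion taken from the visible code; this is an illustration, not the repository's implementation:

```cpp
#include <cassert>

// Plain stand-ins for layout attributes; only the fields needed to express
// the rule described in the comment above.
struct LayoutInfo {
  unsigned numCTAs;
  unsigned ctaLayoutId; // proxy for "which CTA layout this encoding uses"
};

bool shouldUseDistSmemSketch(LayoutInfo src, LayoutInfo dst) {
  // Mirrors the assertion visible in shouldUseDistSmem: src and dst must
  // agree on the number of CTAs.
  assert(src.numCTAs == dst.numCTAs &&
         "src and dst must have the same number of CTAs");
  if (src.numCTAs == 1)
    return false; // (1) a single CTA never needs distributed shared memory
  // (2) multiple CTAs but identical CTA layouts: data stays within its CTA
  return src.ctaLayoutId != dst.ctaLayoutId;
}

int main() {
  assert(!shouldUseDistSmemSketch({1, 0}, {1, 0}));
  assert(!shouldUseDistSmemSketch({4, 7}, {4, 7}));
  assert(shouldUseDistSmemSketch({4, 7}, {4, 9}));
  return 0;
}
```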
9 changes: 5 additions & 4 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -418,8 +418,9 @@ struct ConvertLayoutOpConversion
}

LogicalResult
lowerDistToDistWithDSmem(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
lowerDistToDistWithDistSmem(triton::gpu::ConvertLayoutOp op,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();

Value src = op.getSrc();
@@ -515,8 +516,8 @@ struct ConvertLayoutOpConversion
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();

if (convertLayoutUseDSmem(srcLayout, dstLayout))
return lowerDistToDistWithDSmem(op, adaptor, rewriter);
if (shouldUseDistSmem(srcLayout, dstLayout))
return lowerDistToDistWithDistSmem(op, adaptor, rewriter);

auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
7 changes: 7 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -302,6 +302,13 @@ static SmallVector<Value> reorderValues(const SmallVector<Value> &values,
llvm_unreachable("unimplemented code path");
}

inline Type getElementType(Value value) {
auto type = value.getType();
if (auto tensorType = type.dyn_cast<RankedTensorType>())
return tensorType.getElementType();
return type;
}

inline SmallVector<Value> unpackI32(const SmallVector<Value> &inValues,
Type srcTy,
ConversionPatternRewriter &rewriter,
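The `getElementType` helper added above (moved here from `lib/Analysis/Utility.cpp`) peels the element type off ranked tensor values and passes scalar types through unchanged. A standard-C++ analogue of that dispatch, with stand-in classes rather than MLIR's, only to show the two behaviors:

```cpp
#include <iostream>

// Stand-in classes, not MLIR's: a "ranked tensor" type wraps an element
// type, every other type is treated as already scalar.
struct Type { virtual ~Type() = default; };
struct FloatType : Type {};
struct RankedTensorType : Type {
  const Type *elementType;
  explicit RankedTensorType(const Type *elem) : elementType(elem) {}
};

// Same shape as the helper in the diff: peel the element type off tensors,
// return scalar types unchanged.
const Type *getElementType(const Type *type) {
  if (auto *tensorType = dynamic_cast<const RankedTensorType *>(type))
    return tensorType->elementType;
  return type;
}

int main() {
  FloatType f16;
  RankedTensorType tensor(&f16);
  std::cout << (getElementType(&tensor) == &f16) << "\n"; // 1: tensor -> element
  std::cout << (getElementType(&f16) == &f16) << "\n";    // 1: scalar passthrough
  return 0;
}
```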
2 changes: 1 addition & 1 deletion lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
@@ -597,7 +597,7 @@ class ConvertTritonGPUOpToLLVMPatternBase {

SmallVector<Value> emitCTAOffsetForLayout(Location loc,
ConversionPatternRewriter &rewriter,
const Attribute &layout,
Attribute layout,
ArrayRef<int64_t> shape) const {
unsigned rank = shape.size();
SmallVector<unsigned> CTAsPerCGA = triton::gpu::getCTAsPerCGA(layout);
7 changes: 3 additions & 4 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -573,7 +573,7 @@ bool isSharedEncoding(Value value) {
return false;
}

bool isExpensiveCat(CatOp cat, Attribute &targetEncoding) {
bool isExpensiveCat(CatOp cat, Attribute targetEncoding) {
// If the new elements per thread is less than the old one, we will need to do
// convert encoding that goes through shared memory anyway. So we consider it
// as expensive.
@@ -622,9 +622,8 @@ static LogicalResult parseIntAttrValue(AsmParser &parser, Attribute attr,
return success();
}

static LogicalResult parseBoolAttrValue(AsmParser &parser,
const Attribute &attr, bool &value,
StringRef desc) {
static LogicalResult parseBoolAttrValue(AsmParser &parser, Attribute attr,
bool &value, StringRef desc) {
auto boolAttr = attr.dyn_cast<BoolAttr>();
if (!boolAttr) {
parser.emitError(parser.getNameLoc(), "expected an bool type in ") << desc;