Merge branch 'main' into mogball/unittests
Mogball authored Nov 14, 2024
2 parents 342b4fe + 38f6a6d commit f7f6168
Showing 18 changed files with 372 additions and 147 deletions.
2 changes: 1 addition & 1 deletion cmake/AddTritonUnitTest.cmake
@@ -35,7 +35,7 @@ function(add_triton_ut)
# Without the TEST_DISCOVERY_TIMEOUT, the tests randomly time out on my mac
# laptop. I think the issue may be that the very first time you run a program
# it's a bit slow.
gtest_discover_tests(${__NAME} PROPERTIES TEST_DISCOVERY_TIMEOUT 60)
gtest_discover_tests(${__NAME} DISCOVERY_TIMEOUT 60)

# Add the unit test to the top-level unit test target.
add_dependencies(TritonUnitTests ${__NAME})
15 changes: 12 additions & 3 deletions include/triton/Analysis/Allocation.h
@@ -18,6 +18,12 @@ namespace mlir {
namespace triton {
class AllocationAnalysis;

/// Callback to allow backends to specify target-specific scratch sizes for
/// some operations.
using AllocationAnalysisScratchSizeFn = std::function<unsigned(Operation *)>;

unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op);

// To convert a tensor from one layout to another, we need to allocate a
// temporary buffer (i.e., scratch buffer) in shared memory. The conversion may
// require multiple iterations, with each iteration involving multiple
@@ -102,7 +108,8 @@ class Allocation {
explicit Allocation(Operation *operation) : operation(operation) {}

/// Runs allocation analysis on the given top-level operation.
void run(FuncAllocMapT &funcAllocMap);
void run(FuncAllocMapT &funcAllocMap,
triton::AllocationAnalysisScratchSizeFn scratchSizeGetter);

/// Returns the operation this analysis was constructed from.
Operation *getOperation() const { return operation; }
@@ -250,7 +257,9 @@ class ModuleAllocation : public CallGraph<Allocation> {
public:
using FuncOffsetMapT = DenseMap<FunctionOpInterface, Value>;

explicit ModuleAllocation(ModuleOp moduleOp)
ModuleAllocation(ModuleOp moduleOp,
triton::AllocationAnalysisScratchSizeFn scratchSizeGetter =
triton::defaultAllocationAnalysisScratchSizeFn)
: CallGraph<Allocation>(moduleOp) {
walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
// Pre-order edge walk callback
@@ -259,7 +268,7 @@ class ModuleAllocation : public CallGraph<Allocation> {
[&](FunctionOpInterface funcOp) {
auto [iter, inserted] = funcMap.try_emplace(funcOp, funcOp);
if (inserted)
iter->second.run(funcMap);
iter->second.run(funcMap, scratchSizeGetter);
});
}

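The new scratchSizeGetter parameter is the extension point this merge introduces: a backend can hand ModuleAllocation its own AllocationAnalysisScratchSizeFn and fall back to the default policy for every other op. Below is a minimal sketch of such a hook, assuming the usual Triton headers and namespaces; the callback name, the chosen op, and the zero-byte override are hypothetical and not part of this commit.

#include "triton/Analysis/Allocation.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

using namespace mlir;

// Hypothetical target-specific policy: pretend this backend lowers HistogramOp
// without shared-memory scratch, and defer to the in-tree default for every
// other operation.
static unsigned myTargetScratchSizeFn(Operation *op) {
  if (isa<triton::HistogramOp>(op))
    return 0;
  return triton::defaultAllocationAnalysisScratchSizeFn(op);
}

void runAllocationWithCustomPolicy(ModuleOp module) {
  // The callback is the optional second constructor argument; leaving it out
  // keeps the previous behavior (defaultAllocationAnalysisScratchSizeFn).
  ModuleAllocation allocation(module, myTargetScratchSizeFn);
  (void)allocation; // offsets and shared-memory sizes are queried as before
}
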
139 changes: 71 additions & 68 deletions lib/Analysis/Allocation.cpp
@@ -118,13 +118,70 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
return scratchConfig;
}

unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) {
if (auto reduceOp = dyn_cast<ReduceOp>(op)) {
ReduceOpHelper helper(reduceOp);
return helper.getScratchSizeInBytes();
}
if (auto scanOp = dyn_cast<ScanOp>(op)) {
ScanLoweringHelper helper(scanOp);
return helper.getScratchSizeInBytes();
}
if (auto histogram = dyn_cast<HistogramOp>(op)) {
auto dstTy = histogram.getType();
int threadsPerWarp = gpu::TritonGPUDialect::getThreadsPerWarp(
op->getParentOfType<ModuleOp>());
return std::max<int>(dstTy.getNumElements(), threadsPerWarp) *
std::max<int>(8, dstTy.getElementTypeBitWidth()) / 8;
}
if (auto cvtLayout = dyn_cast<gpu::ConvertLayoutOp>(op)) {
auto srcTy = cvtLayout.getSrc().getType();
auto dstTy = cvtLayout.getType();
auto srcEncoding = srcTy.getEncoding();
auto dstEncoding = dstTy.getEncoding();
if (mlir::isa<gpu::SharedEncodingAttr>(srcEncoding) ||
mlir::isa<gpu::SharedEncodingAttr>(dstEncoding)) {
// Conversions from/to shared memory do not need scratch memory.
return 0;
}
// ConvertLayoutOp with both input/output non-shared_layout
// TODO: Besides of implementing ConvertLayoutOp via shared memory, it's
// also possible to realize it with other approaches in restricted
// conditions, such as warp-shuffle
auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy);
auto elems = getNumScratchElements(scratchConfig.paddedRepShape);
return isa<PointerType>(srcTy.getElementType())
? elems * kPtrBitWidth / 8
: elems * std::max<int>(8, srcTy.getElementTypeBitWidth()) / 8;
}
if (isa<AtomicRMWOp, AtomicCASOp>(op)) {
auto value = op->getOperand(0);
// only scalar requires scratch memory
// make it explicit for readability
if (dyn_cast<RankedTensorType>(value.getType())) {
return 0;
}
auto smemShape = getRepShapeForAtomic(op->getResult(0));
auto elems = getNumScratchElements(smemShape);
auto elemTy = cast<PointerType>(value.getType()).getPointeeType();
assert(!isa<PointerType>(elemTy) && "unexpected pointer type");
return elems * std::max<int>(8, elemTy.getIntOrFloatBitWidth()) / 8;
}
if (auto createTensormap = dyn_cast<ExperimentalTensormapCreateOp>(op)) {
constexpr int32_t kTMASize = 128;
return kTMASize;
}
return 0;
}

class AllocationAnalysis {
public:
AllocationAnalysis(Operation *operation,
Allocation::FuncAllocMapT *funcAllocMap,
Allocation *allocation)
Allocation *allocation,
AllocationAnalysisScratchSizeFn scratchSizeGetter)
: operation(operation), funcAllocMap(funcAllocMap),
allocation(allocation) {
allocation(allocation), scratchSizeGetter(scratchSizeGetter) {
run();
}

@@ -177,77 +234,19 @@ class AllocationAnalysis {

/// Initializes temporary shared memory for a given operation.
void getScratchValueSize(Operation *op) {
const size_t scratchAlignment = 128;
if (auto reduceOp = dyn_cast<ReduceOp>(op)) {
ReduceOpHelper helper(reduceOp);
unsigned bytes = helper.getScratchSizeInBytes();
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
scratchAlignment);
} else if (auto scanOp = dyn_cast<ScanOp>(op)) {
ScanLoweringHelper helper(scanOp);
unsigned bytes = helper.getScratchSizeInBytes();
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
scratchAlignment);
} else if (auto histogram = dyn_cast<HistogramOp>(op)) {
auto dstTy = histogram.getType();
int threadsPerWarp = gpu::TritonGPUDialect::getThreadsPerWarp(
op->getParentOfType<ModuleOp>());
auto bytes = std::max<int>(dstTy.getNumElements(), threadsPerWarp) *
std::max<int>(8, dstTy.getElementTypeBitWidth()) / 8;
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
scratchAlignment);
} else if (auto cvtLayout = dyn_cast<gpu::ConvertLayoutOp>(op)) {
auto srcTy = cvtLayout.getSrc().getType();
auto dstTy = cvtLayout.getType();
auto srcEncoding = srcTy.getEncoding();
auto dstEncoding = dstTy.getEncoding();
if (mlir::isa<gpu::SharedEncodingAttr>(srcEncoding) ||
mlir::isa<gpu::SharedEncodingAttr>(dstEncoding)) {
// Conversions from/to shared memory do not need scratch memory.
return;
}
// ConvertLayoutOp with both input/output non-shared_layout
// TODO: Besides of implementing ConvertLayoutOp via shared memory, it's
// also possible to realize it with other approaches in restricted
// conditions, such as warp-shuffle
auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy);
auto elems = getNumScratchElements(scratchConfig.paddedRepShape);
auto bytes =
isa<PointerType>(srcTy.getElementType())
? elems * kPtrBitWidth / 8
: elems * std::max<int>(8, srcTy.getElementTypeBitWidth()) / 8;
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
scratchAlignment);
} else if (isa<AtomicRMWOp, AtomicCASOp>(op)) {
auto value = op->getOperand(0);
// only scalar requires scratch memory
// make it explicit for readability
if (dyn_cast<RankedTensorType>(value.getType())) {
// nothing to do
} else {
auto smemShape = getRepShapeForAtomic(op->getResult(0));
auto elems = getNumScratchElements(smemShape);
auto elemTy = cast<PointerType>(value.getType()).getPointeeType();
assert(!isa<PointerType>(elemTy) && "unexpected pointer type");
auto bytes =
elems * std::max<int>(8, elemTy.getIntOrFloatBitWidth()) / 8;
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
scratchAlignment);
}
} else if (auto callOp = dyn_cast<CallOpInterface>(op)) {
constexpr size_t scratchAlignment = 128;
if (auto callOp = dyn_cast<CallOpInterface>(op)) {
auto callable = callOp.resolveCallable();
auto funcOp = dyn_cast<FunctionOpInterface>(callable);
auto *funcAlloc = &(*funcAllocMap)[funcOp];
auto bytes = funcAlloc->getSharedMemorySize();
maybeAddScratchBuffer<BufferT::BufferKind::Virtual>(op, bytes,
scratchAlignment);
} else if (auto createTensormap =
dyn_cast<ExperimentalTensormapCreateOp>(op)) {
constexpr int32_t kTMASize = 128;
constexpr int32_t kTMAAlign = 128;
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, kTMASize,
kTMAAlign);
return;
}
unsigned bytes = scratchSizeGetter(op);
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
scratchAlignment);
}

void getValueAlias(Value value, SharedMemoryAliasAnalysis &analysis) {
@@ -547,12 +546,16 @@ class AllocationAnalysis {
Allocation::FuncAllocMapT *funcAllocMap;
Allocation *allocation;
BufferRangeMapT bufferRange;
AllocationAnalysisScratchSizeFn scratchSizeGetter;
};

} // namespace triton

void Allocation::run(FuncAllocMapT &funcAllocMap) {
triton::AllocationAnalysis(getOperation(), &funcAllocMap, this);
void Allocation::run(
FuncAllocMapT &funcAllocMap,
triton::AllocationAnalysisScratchSizeFn scratchSizeGetter) {
triton::AllocationAnalysis(getOperation(), &funcAllocMap, this,
scratchSizeGetter);
}

std::map<Operation *, SmallVector<Allocation::BufferId>>
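As a quick sanity check on the HistogramOp formula in defaultAllocationAnalysisScratchSizeFn above, here is a self-contained sketch of the same arithmetic with made-up values; the helper name and the numbers are illustrative only.

#include <algorithm>
#include <cstdio>

// Mirrors the HistogramOp byte count above: one slot per bin (but at least one
// per thread in a warp), each slot padded to a whole number of bytes.
static unsigned histogramScratchBytes(int numBins, int threadsPerWarp,
                                      int elemBitWidth) {
  return std::max(numBins, threadsPerWarp) * std::max(8, elemBitWidth) / 8;
}

int main() {
  // A 16-bin i32 histogram with 32 threads per warp:
  // max(16, 32) * max(8, 32) / 8 = 32 * 32 / 8 = 128 bytes of scratch.
  std::printf("%u bytes\n", histogramScratchBytes(16, 32, 32));
  return 0;
}
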
13 changes: 9 additions & 4 deletions lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
@@ -138,12 +138,17 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {

// FIXME [Dot LL]
// Do for all DotOperandEncodingAttr once we have LLs for all of them
static bool isSupportedDotOpLayout(Attribute layout) {
static bool isSupportedDotOpLayout(RankedTensorType type) {
auto layout = type.getEncoding();
auto bitwidth = type.getElementType().getIntOrFloatBitWidth();
if (auto dot = dyn_cast<DotOperandEncodingAttr>(layout)) {
auto kWidth = dot.getKWidth();
// Use when the SharedToDotOperandMMAv2OrV3 is known to be buggy:
// - kWidth == 8
// - kWidth == 4, bitwidth = 32
if (auto mma = dyn_cast<NvidiaMmaEncodingAttr>(dot.getParent())) {
bool legacyLoweringIsBuggy = dot.getKWidth() >= 8;
bool legacyLoweringIsBuggy =
kWidth >= 8 || (kWidth == 4 && bitwidth == 32);
return legacyLoweringIsBuggy && mma.isAmpere();
}
if (isa<AMDMfmaEncodingAttr>(dot.getParent()))
@@ -162,7 +167,7 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
if (isa<SharedEncodingAttr>(srcLayout) &&
(isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr>(
dstLayout) ||
isSupportedDotOpLayout(dstLayout))) {
isSupportedDotOpLayout(dstTy))) {
return lowerSharedToDistributed(op, adaptor, getTypeConverter(),
rewriter);
}
Expand Down Expand Up @@ -202,7 +207,7 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
auto dstShape = dstTy.getShape();
auto srcSharedLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
auto dstLayout = dstTy.getEncoding();
assert((dstShape.size() <= 2 || isSupportedDotOpLayout(dstLayout)) &&
assert((dstShape.size() <= 2 || isSupportedDotOpLayout(dstTy)) &&
"Unexpected rank of ConvertLayout(shared->distributed)");

auto smemObj = LLVM::getSharedMemoryObjectFromStruct(
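The condition added to isSupportedDotOpLayout above can be read as a small predicate over kWidth and the element bit width. The sketch below only restates the guard from the diff for readability; the function name is made up and the code is not part of the patch.

// Restates the guard: this pattern takes over from the legacy
// SharedToDotOperandMMAv2OrV3 lowering when kWidth >= 8 (any bit width), or
// when kWidth == 4 and the element type is 32 bits wide, provided the parent
// MMA encoding is Ampere.
static bool bypassesLegacyDotOperandLowering(unsigned kWidth, unsigned bitwidth,
                                             bool parentIsAmpereMma) {
  bool legacyLoweringIsBuggy = kWidth >= 8 || (kWidth == 4 && bitwidth == 32);
  return legacyLoweringIsBuggy && parentIsAmpereMma;
}
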
43 changes: 30 additions & 13 deletions python/test/regression/test_cast_matmul.py
@@ -13,6 +13,11 @@
import triton.language as tl

input_dtypes = ["float16", "float32", "float64"]
if triton.runtime.driver.active.get_current_target().backend == "cuda":
input_dtypes += ["int8", "float8_e5m2"]
cc = torch.cuda.get_device_capability(0)
if cc >= (8, 9):
input_dtypes += ["float8_e4m3fn"]
out_dtypes = ["float16", "float32"]


@@ -63,37 +68,49 @@ def matmul_kernel(A, B, C, M, N, K, #
tl.store(C, acc, mask=mask)


@pytest.mark.parametrize("M, K, N, w_dtype, x_dtype, out_dtype",
[(M, K, N, w, x, o) #
for (M, K, N) in [(128, 128, 128), (1280, 768, 1024)] #
@pytest.mark.parametrize("M, K, N, BLOCK_K, w_dtype, x_dtype, out_dtype",
[(M, K, N, BLOCK_K, w, x, o) #
for BLOCK_K in [16, 32] #
for (M, K, N) in [(128, 128, 128), (768, 768, 1024)] #
for w in input_dtypes
for x in input_dtypes #
for o in out_dtypes])
def test_cast_matmul(M, K, N, w_dtype, x_dtype, out_dtype):
def test_cast_matmul(M, K, N, BLOCK_K, w_dtype, x_dtype, out_dtype):
if x_dtype == w_dtype:
pytest.skip("skip the same input dtype")
device = torch.cuda.current_device()
x_dtype = getattr(torch, x_dtype)
w_dtype = getattr(torch, w_dtype)
a = torch.randn((M, K), device=device, dtype=x_dtype)
b = torch.randn((K, N), device=device, dtype=w_dtype)
x_dtype: torch.dtype = getattr(torch, x_dtype)
w_dtype: torch.dtype = getattr(torch, w_dtype)

def init_tensor(dtype, shape):
if dtype == torch.int8:
return torch.randint(0, 2, shape, device=device, dtype=dtype)
elif dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
return torch.randn(shape, device=device, dtype=torch.float16).to(dtype)
else:
return torch.randn(shape, device=device, dtype=dtype)

torch.manual_seed(42)
a = init_tensor(w_dtype, (M, K))
b = init_tensor(x_dtype, (K, N))

torch_dtype = getattr(torch, out_dtype)
triton_dtype = getattr(tl, out_dtype) # <- here force dot_out_dtype
out_torch = torch.matmul(a.to(torch_dtype), b.to(torch_dtype))
out_triton = torch.empty((M, N), device=device, dtype=torch_dtype)

# launch kernel
BLOCK_M, BLOCK_N, BLOCK_K = 16, 16, 32
grid = ((triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N)), 1)
block_m, block_n, block_k = 16, 16, BLOCK_K
grid = ((triton.cdiv(M, block_m) * triton.cdiv(N, block_n)), 1)

matmul_kernel[grid](
a, b, out_triton, M, N, K, #
a.stride(0), a.stride(1), #
b.stride(0), b.stride(1), #
out_triton.stride(0), out_triton.stride(1), dot_out_dtype=triton_dtype, #
GROUP_M=8, #
BLOCK_M=BLOCK_M, #
BLOCK_N=BLOCK_N, #
BLOCK_K=BLOCK_K)
BLOCK_M=block_m, #
BLOCK_N=block_n, #
BLOCK_K=block_k)

torch.testing.assert_close(out_torch, out_triton, atol=0.3, rtol=0.01)