Commit

[BW] Enable compilation until ptx for simple dot (#20)
* [BW] Enable compilation until ptx for simple dot

* turn off GPU tests
ThomasRaoux authored Jun 15, 2024
1 parent 51b2405 commit d23c360
Showing 5 changed files with 41 additions and 19 deletions.
11 changes: 6 additions & 5 deletions .buildkite/pipeline.yml
@@ -15,8 +15,9 @@ steps:
- cd python/test/unit/ && pytest -vs language/test_compile_only.py
agents:
queue: blackwell-crow-cpu-v54
- label: sample GPU tests
commands:
- echo "Running GPU tests"
agents:
queue: blackwell-crow-h100-1gpu-v54
# Turn off GPU tests until we need it.
#- label: sample GPU tests
# commands:
# - echo "Running GPU tests"
# agents:
# queue: blackwell-crow-h100-1gpu-v54
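For local reproduction, the same compile-only tests can be driven from Python without a GPU. A minimal sketch follows; the test path mirrors the pipeline step and the repository root is assumed as the working directory.

```python
# Run the compile-only suite the CPU queue runs; no GPU is required.
import sys
import pytest

# Mirrors: cd python/test/unit/ && pytest -vs language/test_compile_only.py
sys.exit(pytest.main(["-vs", "python/test/unit/language/test_compile_only.py"]))
```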
25 changes: 20 additions & 5 deletions python/test/unit/language/test_compile_only.py
@@ -24,7 +24,7 @@ def test_compile_only_dot() -> None:

@triton.jit
def simple_dot(a_base, b_base, out):
SIZE: tl.constexpr = 128
SIZE: tl.constexpr = 64
a_ptr = a_base + tl.arange(0, SIZE)[:, None] * SIZE + tl.arange(0, SIZE)[None, :]
b_ptr = b_base + tl.arange(0, SIZE)[:, None] * SIZE + tl.arange(0, SIZE)[None, :]
a = tl.load(a_ptr)
@@ -33,13 +33,10 @@ def simple_dot(a_base, b_base, out):
out_ptr = out + tl.arange(0, SIZE)[:, None] * SIZE + tl.arange(0, SIZE)[None, :]
tl.store(out_ptr, c)

os.environ["TRITON_DISABLE_LLIR"] = "1"
os.environ["TRITON_DISABLE_PTX"] = "1"
os.environ["TRITON_DISABLE_CUBIN"] = "1"
k = triton.compile(triton.compiler.ASTSource(fn=simple_dot, signature="*fp16,*fp16,*fp16", constants={}),
target=GPUTarget("cuda", 100, 32))
ttgir = k.asm["ttgir"]
print(ttgir)

pattern = (r"%(?P<TMEM_BASE>\d+) = triton_gpu\.local_alloc"
r"(.|\n)*"
@@ -56,7 +53,25 @@ def simple_dot(a_base, b_base, out):
r"triton_gpu\.local_load %(?P=TMEM_BASE)")

assert re.search(pattern, str(ttgir)), "The TTGIR does not match the expected pattern."


ptx = k.asm["ptx"]
pattern = (r"mov\.u32 %r(?P<G>\d+), global_smem;"
r"(.|\n)*"
r"tcgen05\.alloc\.cta_group::1\.sync\.aligned\.shared::cta\.b32 \[%r(?P=G)], 32"
r"(.|\n)*"
r"tcgen05\.relinquish_alloc_permit\.cta_group::1\.sync\.aligned"
r"(.|\n)*"
r"tcgen05\.mma\.cta_group::1::kind::f16"
r"(.|\n)*"
r"tcgen05.commit.cta_group::1.mbarrier::arrive::one.b64"
r"(.|\n)*"
r"mbarrier.try_wait.parity.shared.b64"
r"(.|\n)*"
r"tcgen05.ld.sync.aligned.32x32b.x32.b32"
r"(.|\n)*"
r"tcgen05.wait::ld.sync.aligned")
assert re.search(pattern, str(ptx)), "The PTX does not match the expected pattern."


def test_compile_only_k_loop() -> None:

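For context, the flow the updated test exercises can be reproduced as a standalone script. The sketch below reuses the API visible in the diff (`triton.compiler.ASTSource`, `GPUTarget("cuda", 100, 32)`, `k.asm`); the `GPUTarget` import path follows upstream Triton and the `TRITON_DISABLE_CUBIN` toggle is taken from the test file, so both are assumptions rather than guarantees for this branch.

```python
# Minimal compile-only sketch (no GPU needed), assuming this branch's API.
import os
import re

import triton
import triton.language as tl
from triton.backends.compiler import GPUTarget  # assumed import path


@triton.jit
def simple_dot(a_base, b_base, out):
    SIZE: tl.constexpr = 64
    offs = tl.arange(0, SIZE)[:, None] * SIZE + tl.arange(0, SIZE)[None, :]
    a = tl.load(a_base + offs)
    b = tl.load(b_base + offs)
    c = tl.dot(a, b)
    tl.store(out + offs, c)


# Keep cubin generation off; LLIR and PTX are now produced (hence the commit title).
os.environ["TRITON_DISABLE_CUBIN"] = "1"
k = triton.compile(
    triton.compiler.ASTSource(fn=simple_dot, signature="*fp16,*fp16,*fp16", constants={}),
    target=GPUTarget("cuda", 100, 32),  # sm_100 (Blackwell), 32 threads per warp
)
# The compiled artifact exposes each stage; the new assertions pattern-match "ttgir" and "ptx".
print(re.findall(r"tcgen05\.[\w.:]+", k.asm["ptx"]))
```

If compilation succeeds, the printed list should include the tcgen05.alloc, tcgen05.mma, tcgen05.commit, and tcgen05.ld instructions that the new PTX regex checks for in order.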
2 changes: 1 addition & 1 deletion test/Conversion/tritongpu_to_llvm_blackwell.mlir
@@ -30,7 +30,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 65544 : i32, triton_gpu.target = "cuda:100", triton_gpu.tensor_memory_size = 128 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
// CHECK-LABEL: @tensor_memory_ld
// CHECK: nvgpu.tensor_memory_base
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "tcgen05.ld.sync.aligned.32x32b.x128.b32 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46, $47, $48, $49, $50, $51, $52, $53, $54, $55, $56, $57, $58, $59, $60, $61, $62, $63, $64, $65, $66, $67, $68, $69, $70, $71, $72, $73, $74, $75, $76, $77, $78, $79, $80, $81, $82, $83, $84, $85, $86, $87, $88, $89, $90, $91, $92, $93, $94, $95, $96, $97, $98, $99, $100, $101, $102, $103, $104, $105, $106, $107, $108, $109, $110, $111, $112, $113, $114, $115, $116, $117, $118, $119, $120, $121, $122, $123, $124, $125, $126, $127}, [$128] $0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46, $47, $48, $49, $50, $51, $52, $53, $54, $55, $56, $57, $58, $59, $60, $61, $62, $63, $64, $65, $66, $67, $68, $69, $70, $71, $72, $73, $74, $75, $76, $77, $78, $79, $80, $81, $82, $83, $84, $85, $86, $87, $88, $89, $90, $91, $92, $93, $94, $95, $96, $97, $98, $99, $100, $101, $102, $103, $104, $105, $106, $107, $108, $109, $110, $111, $112, $113, $114, $115, $116, $117, $118, $119, $120, $121, $122, $123, $124, $125, $126, $127, $128;", "=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,r" %{{.*}} : (!llvm.ptr<3>) -> vector<128xi32>
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "tcgen05.ld.sync.aligned.32x32b.x128.b32 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46, $47, $48, $49, $50, $51, $52, $53, $54, $55, $56, $57, $58, $59, $60, $61, $62, $63, $64, $65, $66, $67, $68, $69, $70, $71, $72, $73, $74, $75, $76, $77, $78, $79, $80, $81, $82, $83, $84, $85, $86, $87, $88, $89, $90, $91, $92, $93, $94, $95, $96, $97, $98, $99, $100, $101, $102, $103, $104, $105, $106, $107, $108, $109, $110, $111, $112, $113, $114, $115, $116, $117, $118, $119, $120, $121, $122, $123, $124, $125, $126, $127}, [$128] $0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46, $47, $48, $49, $50, $51, $52, $53, $54, $55, $56, $57, $58, $59, $60, $61, $62, $63, $64, $65, $66, $67, $68, $69, $70, $71, $72, $73, $74, $75, $76, $77, $78, $79, $80, $81, $82, $83, $84, $85, $86, $87, $88, $89, $90, $91, $92, $93, $94, $95, $96, $97, $98, $99, $100, $101, $102, $103, $104, $105, $106, $107, $108, $109, $110, $111, $112, $113, $114, $115, $116, $117, $118, $119, $120, $121, $122, $123, $124, $125, $126, $127, $128;", "=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,r" %{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
tt.func public @tensor_memory_ld(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>) {
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked1>
%0 = triton_gpu.local_alloc %cst_0 {tensor_memory_offset = 0 : i32} : (tensor<128x128xf32, #blocked1>) -> !tt.memdesc<128x128xf32, #shared, #triton_nvidia_gpu.tensor_memory, mutable>
@@ -163,7 +163,9 @@ void convertDot(const LLVMTypeConverter *typeConverter,
}
}
}
createMMACommit(rewriter, loc, barrier);
auto smemObj = LLVM::getSharedMemoryObjectFromStruct(
op.getLoc(), barrier, i64_ty, rewriter);
createMMACommit(rewriter, loc, smemObj.getBase());
}

struct TCGen5MMAOpConversion
@@ -29,7 +29,9 @@ struct TensorMemoryAllocOpConversion
.getValue()
.getZExtValue();
Value allocAddress = add(baseInt, i32_val(offset));
auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 6);
// Cast to address space 3 as the shared memory object uses 3.
// TODO: clean this up and use either a int or ptr address space 6
auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3);
Value ptr = inttoptr(ptrTy, allocAddress);

SmallVector<unsigned> order(op.getType().getRank());
@@ -64,18 +66,20 @@ static Value createTensorMemoryLoad(Location loc, triton::gpu::LocalLoadOp op,
operands.push_back(ptxBuilder.newOperand(address, "r"));
auto &ld = *ptxBuilder.create<PTXInstr>(opcode);
ld(operands);
auto vecTy = vec_ty(i32_ty, numCol);
Value ret = ptxBuilder.launch(rewriter, loc, vecTy);
SmallVector<Type> elemTypes(numCol, i32_ty);
MLIRContext *ctx = op.getContext();
Type structTy = struct_ty(elemTypes);
Value ret = ptxBuilder.launch(rewriter, loc, structTy);
return ret;
}

static SmallVector<Value> unpackResults(Value packedValues, Type elemTy,
int numCols,
Location loc,
ConversionPatternRewriter &rewriter) {
SmallVector<Value> resultVals;
for (int i = 0; i < cast<VectorType>(packedValues.getType()).getNumElements();
i++) {
Value result = extract_element(i32_ty, packedValues, i32_val(i));
for (int i = 0; i < numCols; i++) {
Value result = extract_val(i32_ty, packedValues, i);
// TODO: handle types smaller than 32bits.
result = bitcast(result, elemTy);
resultVals.push_back(result);
@@ -120,7 +124,7 @@ struct TensorMemoryLoadOpConversion
Value packedValues =
createTensorMemoryLoad(loc, op, address, numCols, rewriter);
SmallVector<Value> resultVals = unpackResults(
packedValues, op.getType().getElementType(), loc, rewriter);
packedValues, op.getType().getElementType(), numCols, loc, rewriter);
Type structTy = getTypeConverter()->convertType(op.getType());
Value resultStruct =
packLLElements(loc, getTypeConverter(), resultVals, rewriter, structTy);
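Taken together, the last two hunks change how a tensor-memory load is materialized: `createTensorMemoryLoad` now launches the `tcgen05.ld` inline asm with an LLVM struct of `numCol` i32 results instead of a `vector<numCol x i32>`, and `unpackResults` walks that struct with `extract_val`, bitcasting each 32-bit word to the tensor's element type (types narrower than 32 bits remain a TODO). The NumPy sketch below only illustrates that final bitcast step; the function and variable names are hypothetical, not part of the repository.

```python
# Illustrative only: mimic, in NumPy, the per-word bitcast the conversion
# performs on the 32-bit registers returned by tcgen05.ld.
import numpy as np


def unpack_words(packed_i32, elem_dtype=np.float32):
    # Each lane of the load yields a 32-bit value; reinterpret its bits as the
    # element type. Matching the C++ TODO, sub-32-bit element types are not handled.
    words = np.asarray(packed_i32, dtype=np.int32)
    return words.view(elem_dtype)


# Example: 1.0f and 2.0f survive a round trip through the i32 packing.
packed = np.array([1.0, 2.0], dtype=np.float32).view(np.int32)
print(unpack_words(packed))  # [1. 2.]
```

The struct-of-i32 return type is also what the updated FileCheck line in tritongpu_to_llvm_blackwell.mlir now expects from the inline asm.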
