
Commit eeab422

Merge branch 'main' of https://github.com/tile-ai/tilelang into test_0825
2 parents: 09cd8f9 + b39aaf5

File tree

5 files changed (+256, -122 lines)


examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py

Lines changed: 94 additions & 12 deletions
@@ -90,6 +90,7 @@ def matmul(M,
            num_bits=4,
            scale_size=32,
            fast_dequant=True,
+           with_bias=False,
            block_M=256,
            block_N=128,
            block_K=128,
@@ -120,7 +121,8 @@ def matmul(M,
         num_stages (int, optional): pipelining stages for K loop (default 2).
         threads (int, optional): threads per block used by the kernel (default 256).
         split (int, optional): split factor along K used by the scheduler (default 1).
-
+        with_bias (bool, optional): whether to add Bias to the output (default False).
+
     Returns:
         A T.prim_func implementing the tiled, pipelined GEMM that:
         - loads tiled blocks of A and packed B to shared memory,
@@ -139,9 +141,11 @@ def matmul(M,
     Block_QK = block_K // num_elems_per_byte
     A_shape = (M, K)
     B_shape = (N, QK)
+    Bias_shape = (M, N)
     Scale_shape = (N, K // scale_size)
     A_shared_shape = (block_M, block_K)
     B_shared_shape = (block_N, Block_QK)
+    Bias_shared_shape = (block_M, block_N)
     B_dequantize_shared_shape = (block_N, block_K)
     assert K % (block_K * split) == 0

@@ -311,6 +315,7 @@ def main(
             A: T.Tensor(A_shape, in_dtype),
             B: T.Tensor(B_shape, storage_dtype),
             Scale: T.Tensor(Scale_shape, storage_dtype),
+            Bias: T.Tensor(Bias_shape, out_dtype),
             C: T.Tensor((M, N), out_dtype),
     ):
         """
@@ -328,7 +333,7 @@ def main(
             A_shared = T.alloc_shared(A_shared_shape, in_dtype)
             B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
             B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype)
-
+            Bias_shared = T.alloc_shared(Bias_shared_shape, out_dtype)
             C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
             C_shared = T.alloc_shared((block_M, block_N), out_dtype)

@@ -337,10 +342,22 @@ def main(
                 B_shared: tilelang.layout.make_swizzled_layout(B_shared),
                 C_shared: tilelang.layout.make_swizzled_layout(C_shared),
             })
+
+            if with_bias:
+                T.annotate_layout({
+                    Bias_shared: tilelang.layout.make_swizzled_layout(Bias_shared),
+                })
+
             if threads == 512:
                 T.disable_warp_group_reg_alloc()

-            T.clear(C_local)
+            if with_bias:
+                T.copy(Bias[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N],
+                       Bias_shared)
+                T.copy(Bias_shared, C_local)
+            else:
+                T.clear(C_local)
+
             for k in T.Pipelined(K // block_K, num_stages=num_stages):
                 T.copy(A[by * block_M, k * block_K], A_shared)
                 T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
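When with_bias is set, the hunk above seeds the accumulator C_local with the corresponding Bias tile before the K loop instead of clearing it, so the epilogue needs no separate bias add. A minimal PyTorch sketch of why pre-loading the accumulator is equivalent to adding the bias after the reduction (sizes and names here are illustrative only, not the kernel's actual tiles):

import torch

M, N, K, block_K = 8, 8, 32, 8  # toy sizes, not the kernel defaults
A = torch.randn(M, K)
B = torch.randn(N, K)           # B is stored as (N, K) and used as B^T
Bias = torch.randn(M, N)

# Reference: accumulate over K blocks, add the bias afterwards.
acc = torch.zeros(M, N)
for k0 in range(0, K, block_K):
    acc += A[:, k0:k0 + block_K] @ B[:, k0:k0 + block_K].T
post_add = acc + Bias

# The with_bias path: start the accumulator from the bias tile,
# then run the same blocked reduction.
acc = Bias.clone()
for k0 in range(0, K, block_K):
    acc += A[:, k0:k0 + block_K] @ B[:, k0:k0 + block_K].T
pre_load = acc

torch.testing.assert_close(post_add, pre_load)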
@@ -356,7 +373,7 @@ def main(
     return main


-def ref_program_twiddling(A, qB, Scale):
+def ref_program_twiddling(A, qB, Scale, Bias=None):
     """
     Compute A @ B^T where B is reconstructed from bit-twiddled 4-bit quantized data and per-block scales, returning bfloat16 results.

@@ -380,7 +397,32 @@ def ref_program_twiddling(A, qB, Scale):
     return C


-def ref_program_simple(A, qB, Scale):
+def ref_program_twiddling_with_bias(A, qB, Scale, Bias):
+    """
+    Compute A @ B^T where B is reconstructed from bit-twiddled 4-bit quantized data and per-block scales, returning bfloat16 results.
+
+    Converts the quantized matrix `qB` to floating-point via `torch_convert_bit_twiddling`, applies a per-element scale factor of 2^(Scale - 127) (where Scale indexes are grouped by 32 columns of B), computes the matrix product A · B^T in float, and casts the result to bfloat16.
+
+    Parameters:
+        A (torch.Tensor): Left operand with shape (M, K), used in floating precision.
+        qB (torch.Tensor): Quantized representation of B (packed 4-bit values) compatible with torch_convert_bit_twiddling.
+        Scale (torch.Tensor): Per-column-group scale values; Scale indices correspond to groups of 32 columns in B.
+        Bias (torch.Tensor): Bias tensor with shape (M, N).
+
+    Returns:
+        torch.Tensor: Resulting matrix C with shape (M, N) in bfloat16.
+    """
+    dtypeC = "bfloat16"
+    B = torch_convert_bit_twiddling(qB)
+    for i in range(B.shape[0]):
+        for j in range(B.shape[1]):
+            B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
+    C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias
+    C = C.to(torch.__getattribute__(dtypeC))
+    return C
+
+
+def ref_program_simple(A, qB, Scale, Bias=None):
     """
     Compute a BF16 matrix product A · B^T from a quantized B with simple (non-twiddling) dequantization.

@@ -406,7 +448,37 @@ def ref_program_simple(A, qB, Scale):
     return C


-def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, tune=False):
+def ref_program_simple_with_bias(A, qB, Scale, Bias):
+    """
+    Compute a BF16 matrix product A · B^T from a quantized B with simple (non-twiddling) dequantization.
+
+    Converts the quantized tensor `qB` to floating B via `torch_convert`, applies a per-element scale factor computed as 2^(Scale[i][j//32] - 127) (Scale supplies exponent offsets in 32-column groups), then computes C = A · B^T and returns the result converted to bfloat16.
+
+    Parameters:
+
+    Returns:
+    - A: 2D tensor representing the left operand (will be cast to float32 for the matmul).
+    - qB: Quantized representation of B accepted by `torch_convert`.
+    - Scale: 2D tensor of exponent offsets; Scale[i][g] is applied to columns j where g == j // 32.
+    - Bias: 2D tensor representing the Bias (will be cast to float32 for the matmul).
+
+
+    Returns:
+    - 2D bfloat16 tensor C containing the matrix product A · B^T.
+
+    No in-place modification is performed on inputs (a local floating copy of B is scaled).
+    """
+    dtypeC = "bfloat16"
+    B = torch_convert(qB)
+    for i in range(B.shape[0]):
+        for j in range(B.shape[1]):
+            B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
+    C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias
+    C = C.to(torch.__getattribute__(dtypeC))
+    return C
+
+
+def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False, tune=False):
     """
     Run and validate the tiled quantized matmul kernel, then benchmark its latency and report TFLOPS.

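The reference programs above apply the per-group scale with an explicit double loop over B. For readers following the math, a vectorized sketch of the same computation is shown below; it is not part of the committed file, and it assumes B has already been dequantized by the example's torch_convert / torch_convert_bit_twiddling helpers and that scale_size is 32.

import torch

def scale_by_groups(B, Scale, scale_size=32):
    # Scale[i, g] scales columns j of row i where g == j // scale_size,
    # mirroring the `2 ** Scale[i][j // 32]` factor in the loops above.
    exponents = Scale.repeat_interleave(scale_size, dim=1)[:, :B.shape[1]]
    return B * torch.pow(2.0, exponents.to(torch.float))

def ref_with_bias_vectorized(A, B_dequant, Scale, Bias, scale_size=32):
    # B_dequant is the already-dequantized (N, K) weight matrix.
    B = scale_by_groups(B_dequant, Scale, scale_size)
    C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias
    return C.to(torch.bfloat16)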
@@ -435,7 +507,8 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, tune=False):
             "float32",
             num_bits=4,
             scale_size=scale_size,
-            fast_dequant=fast_dequant)
+            fast_dequant=fast_dequant,
+            with_bias=with_bias)
     else:
         kernel = matmul(
             m,
@@ -452,14 +525,21 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, tune=False):
             num_stages=2,
             threads=256,
             split=1,
-            fast_dequant=fast_dequant)
+            fast_dequant=fast_dequant,
+            with_bias=with_bias)

     profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto)

     if fast_dequant:
-        profiler.assert_allclose(ref_program_twiddling, rtol=0.01, atol=0.01)
+        if with_bias:
+            profiler.assert_allclose(ref_program_twiddling_with_bias, rtol=0.01, atol=0.01)
+        else:
+            profiler.assert_allclose(ref_program_twiddling, rtol=0.01, atol=0.01)
     else:
-        profiler.assert_allclose(ref_program_simple, rtol=0.01, atol=0.01)
+        if with_bias:
+            profiler.assert_allclose(ref_program_simple_with_bias, rtol=0.01, atol=0.01)
+        else:
+            profiler.assert_allclose(ref_program_simple, rtol=0.01, atol=0.01)
     print("All checks pass.")
     latency = profiler.do_bench(warmup=500)
     print("Tile-lang: {:.2f} ms".format(latency))
@@ -469,5 +549,7 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, tune=False):
 if __name__ == "__main__":
     M, N, K = 256, 256, 256
     scale_size = 32
-    main(M, N, K, scale_size, fast_dequant=True)
-    main(M, N, K, scale_size, fast_dequant=False)
+    main(M, N, K, scale_size, fast_dequant=True, with_bias=True)
+    main(M, N, K, scale_size, fast_dequant=False, with_bias=True)
+    main(M, N, K, scale_size, fast_dequant=True, with_bias=False)
+    main(M, N, K, scale_size, fast_dequant=False, with_bias=False)

src/tl_templates/cuda/common.h

Lines changed: 0 additions & 49 deletions
@@ -240,53 +240,4 @@ template <int barrier_id = 0, int thread_count = 0>
 TL_DEVICE void __sync_thread_partial() {
   asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(thread_count));
 }
-
-// Template parameter:
-//   thread_extent: the logical size (in number of threads) of each "group"
-//   within which we want to elect exactly ONE representative
-//   thread.
-template <int thread_extent> TL_DEVICE bool tl_shuffle_elect() {
-
-  // Special case: thread_extent == 0 means "elect exactly one thread
-  // in the entire thread block", i.e., the leader of the first warp of the
-  // block.
-  if constexpr (thread_extent == 0) {
-    // cutlass::canonical_warp_idx_sync():
-    //   Returns the warp ID within the thread block in a "canonical" way
-    //   (0 for the first warp, 1 for the second, ...).
-    // cute::elect_one_sync():
-    //   Elect exactly one lane in the warp to return true (typically lane 0),
-    //   other lanes return false.
-    // The condition ensures that:
-    //   (1) We are in warp 0 of the block.
-    //   (2) We are the elected lane in this warp.
-    return cutlass::canonical_warp_idx_sync() == 0 && cute::elect_one_sync();
-  }
-
-  // General case: thread_extent != 0
-  // (threadIdx.x / 32) is the warp index in the block.
-  // (thread_extent / 32) is the number of warps in one group of size
-  // thread_extent. We take warp_id % num_warps_in_group to get the warp's index
-  // within the group.
-  // __shfl_sync(mask, value, srcLane): broadcast 'value' from srcLane to all
-  // lanes in the warp. Here it broadcasts the group-local warp index from lane
-  // 0. Comparing to 0 selects only the group's warp 0.
-  return __shfl_sync(0xffffffff, // full warp mask
-                     (threadIdx.x / 32) %
-                         (thread_extent / 32), // warp index within group
-                     0 // take the value from lane 0
-                     ) == 0 &&
-         // Within that group leader warp, elect exactly one lane (typically
-         // lane 0) to be the single representative for the group.
-         cute::elect_one_sync();
-}
-
-template <uint32_t RegCount> TL_DEVICE void warpgroup_reg_alloc() {
-  asm volatile("setmaxnreg.inc.sync.aligned.u32 %0;\n" : : "n"(RegCount));
-}
-
-template <uint32_t RegCount> TL_DEVICE void warpgroup_reg_dealloc() {
-  asm volatile("setmaxnreg.dec.sync.aligned.u32 %0;\n" : : "n"(RegCount));
-}
-
 } // namespace tl
