16 changes: 4 additions & 12 deletions examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py
@@ -389,9 +389,7 @@ def ref_program_twiddling(A, qB, Scale, Bias=None):
     """
     dtypeC = "bfloat16"
     B = torch_convert_bit_twiddling(qB)
-    for i in range(B.shape[0]):
-        for j in range(B.shape[1]):
-            B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
+    B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
     C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
     C = C.to(torch.__getattribute__(dtypeC))
     return C
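All four reference hunks in this file make the same change: the per-element Python loop over B is replaced by a single broadcast multiply. A minimal equivalence sketch (hypothetical shapes and nonnegative integer scales; the real Scale holds the MXFP4 per-32-column exponents):

import torch

N, K = 4, 128
B = torch.randn(N, K)
Scale = torch.randint(0, 5, (N, K // 32))

# Loop form (removed above): scales column j by 2**Scale[i][j // 32]
B_loop = B.clone()
for i in range(N):
    for j in range(K):
        B_loop[i][j] = B_loop[i][j] * (2**(Scale[i][j // 32]))

# Broadcast form (added above): gather one scale per column, broadcast over rows
B_vec = B * 2**(Scale[:, torch.arange(K) // 32])

assert torch.allclose(B_loop, B_vec)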
@@ -414,9 +412,7 @@ def ref_program_twiddling_with_bias(A, qB, Scale, Bias):
     """
     dtypeC = "bfloat16"
     B = torch_convert_bit_twiddling(qB)
-    for i in range(B.shape[0]):
-        for j in range(B.shape[1]):
-            B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
+    B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
     C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias
     C = C.to(torch.__getattribute__(dtypeC))
     return C
@@ -440,9 +436,7 @@ def ref_program_simple(A, qB, Scale, Bias=None):
     """
     dtypeC = "bfloat16"
     B = torch_convert(qB)
-    for i in range(B.shape[0]):
-        for j in range(B.shape[1]):
-            B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
+    B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
     C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
     C = C.to(torch.__getattribute__(dtypeC))
     return C
@@ -470,9 +464,7 @@ def ref_program_simple_with_bias(A, qB, Scale, Bias):
     """
     dtypeC = "bfloat16"
     B = torch_convert(qB)
-    for i in range(B.shape[0]):
-        for j in range(B.shape[1]):
-            B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
+    B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
     C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias
     C = C.to(torch.__getattribute__(dtypeC))
     return C
@@ -23,7 +23,7 @@ def matmul(
     threads,
     num_bits=4,
 ):
-    from bitblas.quantization import _tir_packed_to_unsigned_convert
+    from tilelang.quantize import _tir_packed_to_unsigned_convert
     num_elems_per_byte = 8 // num_bits
     storage_dtype = "int8"
     storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))
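The only change here is the import's home: _tir_packed_to_unsigned_convert now ships in tilelang.quantize rather than BitBLAS. For context, the surrounding lines derive the packing factor from the storage dtype; restated as a standalone check (values copied from the diff itself):

num_bits = 4
storage_dtype = "int8"
storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))  # "int8" -> 8
num_elems_per_byte = 8 // num_bits
assert storage_nbit == 8 and num_elems_per_byte == 2  # two 4-bit values per int8 byte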

examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py: Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions examples/dequantize_gemm/test_example_dequantize_gemm.py
@@ -4,6 +4,7 @@
 import example_dequant_gemm_fp4_hopper
 import example_dequant_gemm_bf16_mxfp4_hopper
 import example_dequant_gemm_bf16_mxfp4_hopper_tma
+import example_dequant_groupedgemm_bf16_mxfp4_hopper
 import example_dequant_gemm_w4a8


@@ -31,6 +32,13 @@ def test_example_dequant_gemm_bf16_mxfp4_hopper_tma():


 @tilelang.testing.requires_cuda
+@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
+def test_example_dequant_groupedgemm_bf16_mxfp4_hopper():
+    example_dequant_groupedgemm_bf16_mxfp4_hopper.main()
+
+
+@tilelang.testing.requires_cuda
+@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
 def test_example_dequant_gemm_w4a8():
     example_dequant_gemm_w4a8.main()
108 changes: 76 additions & 32 deletions examples/dequantize_gemm/utils.py
@@ -3,8 +3,6 @@

 def torch_convert_bit_twiddling(tensor):
     """
-    Convert a 2-D uint8 tensor into a bfloat16 tensor by decoding pairs of input bytes with a bit-twiddling scheme.
-
     This function expects `tensor` to be a 2-D torch.Tensor of dtype `torch.uint8`. Each output element is produced by combining two input bytes and extracting a bf16-like 16-bit pattern according to one of four positional bit layouts (pos 0..3). The result is scaled by 2**126 to adjust the exponent bias and returned as dtype `torch.bfloat16`.

     Parameters:
@@ -16,38 +14,46 @@ def torch_convert_bit_twiddling(tensor):
     Raises:
         AssertionError: If any byte inputs used for a conversion are not dtype `torch.uint8`.
     """
+    assert tensor.dim() == 2 and tensor.dtype == torch.uint8
+    N, K = tensor.shape
+    assert K % 2 == 0, "Number of columns must be even"

-    def _convert(val0, val1, pos) -> torch.bfloat16:
-        assert val0.dtype == torch.uint8
-        assert val1.dtype == torch.uint8
-        val0 = val0.view(torch.uint8)
-        val1 = val1.view(torch.uint8)
-        val_concat = (val0.item() << 8) | val1.item()
-        mask = 0b1000000111000000
-        if pos == 0:
-            bf16 = val_concat & mask
-        elif pos == 1:
-            bf16 = (val_concat << 3) & mask
-        elif pos == 2:
-            bf16 = (val_concat << 6) & mask
-        elif pos == 3:
-            mask1 = 0b1000000000000000
-            mask2 = 0b0000000110000000
-            mask3 = 0b0000000001000000
-            bf16 = ((val_concat << 1) & mask1) | ((val_concat >> 3) & mask2) | (
-                (val_concat >> 7) & mask3)
-        bf16_new = torch.tensor([bf16], dtype=torch.uint16, device=val0.device).view(torch.bfloat16)
-        # Add bias for change from fp4 to bf16
-        bf16_new = bf16_new.item() * (2**126)
-        return bf16_new
+    # Combine pairs of uint8 values into uint32 for safe bitwise ops on CUDA
+    val0 = tensor[:, 0::2].to(torch.int32)
+    val1 = tensor[:, 1::2].to(torch.int32)
+    val_concat = (val0 << 8) | val1  # (N, K//2), uint32

-    N = tensor.shape[0]
-    K = tensor.shape[1]
-    new_tensor = torch.empty(N, K * 2, dtype=torch.bfloat16, device=tensor.device)
-    for i in range(new_tensor.shape[0]):
-        for j in range(new_tensor.shape[1]):
-            new_tensor[i][j] = _convert(tensor[i][j // 4 * 2], tensor[i][j // 4 * 2 + 1], j % 4)
-    return new_tensor
+    # Expand to match output shape where each pair generates 4 values
+    val_concat_expanded = val_concat.repeat_interleave(4, dim=1)  # (N, K//2*4)
+
+    # Positional encoding for bit-twiddling logic
+    pos = torch.arange(K * 2, device=tensor.device) % 4  # (K*2,)
+
+    # Bit masks for decoding (as uint32 for CUDA compatibility)
+    mask = 0b1000000111000000
+    mask1 = 0b1000000000000000
+    mask2 = 0b0000000110000000
+    mask3 = 0b0000000001000000
+
+    # Calculate results for all 4 positions in parallel
+    res0 = val_concat_expanded & mask
+    res1 = (val_concat_expanded << 3) & mask
+    res2 = (val_concat_expanded << 6) & mask
+    res3 = ((val_concat_expanded << 1) & mask1) | ((val_concat_expanded >> 3) & mask2) | (
+        (val_concat_expanded >> 7) & mask3)
+
+    # Select the correct result based on position
+    bf16 = torch.where(pos == 0, res0, torch.where(pos == 1, res1,
+                                                   torch.where(pos == 2, res2, res3)))
+
+    # Convert to uint16 for .view(torch.bfloat16)
+    bf16_uint16 = (bf16 & 0xFFFF).to(torch.uint16)
+    bf16_bf16 = bf16_uint16.view(torch.bfloat16)
+
+    # Avoid integer overflow by using a float32 multiplier for the exponent scaling
+    bf16_new = bf16_bf16 * (2.0**126)
+
+    return bf16_new
Comment on lines +53 to +56
Contributor

⚠️ Potential issue

Return dtype likely float32; violates docstring contract (bf16 expected)

Multiplying a bf16 tensor by a float literal promotes to float32 in PyTorch; the function then returns f32.

Apply:

-    # Avoid integer overflow by using a float32 multiplier for the exponent scaling
-    bf16_new = bf16_bf16 * (2.0**126)
-
-    return bf16_new
+    # Keep result in bf16 as documented
+    scale = torch.tensor(2.0**126, dtype=torch.bfloat16, device=tensor.device)
+    return (bf16_bf16 * scale)



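A minimal usage sketch of the vectorized decoder (illustrative shapes; assumes a PyTorch version where torch.uint16 supports view(torch.bfloat16)):

import torch
from utils import torch_convert_bit_twiddling  # utils.py in this example directory

qB = torch.randint(0, 256, (2, 4), dtype=torch.uint8)
B = torch_convert_bit_twiddling(qB)
print(B.shape)  # torch.Size([2, 8]): an (N, K) uint8 input yields (N, 2*K) outputs
print(B.dtype)  # the review comment above questions whether this stays bfloat16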
@@ -106,3 +112,41 @@ def print_bit(name, val):
     val_cpu = val.cpu().item()
     binary_repr = f'{val_cpu:032b}'
     print(name, binary_repr)
+
+
+def print_red_warning(message):
+    print(f"\033[31mWARNING: {message}\033[0m")
+
+
+def calc_sim(x, y, name="tensor"):
+    x, y = x.data.double(), y.data.double()
+    denominator = (x * x + y * y).sum()
+    if denominator == 0:
+        print_red_warning(f'{name} all zero')
+        return 1
+    sim = 2 * (x * y).sum() / denominator
+    return sim
+
+
+def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True):
+    x_mask = torch.isfinite(x)
+    y_mask = torch.isfinite(y)
+    if not torch.all(x_mask == y_mask):
+        print_red_warning(f'{name} Error: isfinite mask mismatch')
+        if raise_assert:
+            raise AssertionError
+    if not torch.isclose(
+            x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0,
+            equal_nan=True).all():
+        print_red_warning(f'{name} Error: nonfinite value mismatch')
+        if raise_assert:
+            raise AssertionError
+    x = x.masked_fill(~x_mask, 0)
+    y = y.masked_fill(~y_mask, 0)
+    sim = calc_sim(x, y, name)
+    diff = (1. - sim).item()
+    print(f'{diff=}')
+    if not (0 <= diff <= eps):
+        print_red_warning(f'{name} Error: {diff=}')
+        if raise_assert:
+            raise AssertionError
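A note on the new helpers: calc_sim computes 2*<x, y> / (||x||^2 + ||y||^2), which is at most 1 and equals 1 exactly when x == y, so the diff = 1 - sim that assert_similar checks is a scale-aware error measure. A quick illustrative check (run against the definitions above):

import torch

x = torch.tensor([1.0, 2.0, 3.0])
print(calc_sim(x, x.clone()))       # tensor(1., dtype=torch.float64)
print(calc_sim(x, torch.zeros(3)))  # tensor(0., dtype=torch.float64)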
4 changes: 2 additions & 2 deletions tilelang/language/builtin.py
@@ -331,13 +331,13 @@ def shfl_up(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr,


 def sync_threads():
-    """Synchronize all threads in a warp.
+    """Synchronize all threads in a block.
     """
     return tir.op.tvm_storage_sync("shared")


 def sync_global():
-    """Synchronize all threads in a block.
+    """Synchronize all threads in the entire grid.
     """
     tx, ty, tz = get_thread_bindings()
     ex, ey, ez = get_block_extents()
1 change: 1 addition & 0 deletions tilelang/quantize/__init__.py
@@ -5,6 +5,7 @@
     _tir_packed_to_fp4_to_f16,  # noqa: F401
     _tir_u8_to_f8_e4m3_to_f16,  # noqa: F401
     _tir_packed_to_unsigned_convert_with_zeros,  # noqa: F401
+    _tir_u8_to_f4_to_bf16,  # noqa: F401
 )

 from .utils import (