
Commit 2e59910

Merge branch 'main' of https://github.com/tile-ai/tilelang into sm89fix
2 parents 88abdf8 + 232782d commit 2e59910

11 files changed: +641 -58 lines changed


CMakeLists.txt

Lines changed: 14 additions & 0 deletions
@@ -4,6 +4,15 @@
 cmake_minimum_required(VERSION 3.18)
 project(TILE_LANG C CXX)

+option(TILE_LANG_STATIC_STDCPP "Statically link libstdc++ for TileLang libraries" ON)
+option(TILE_LANG_INSTALL_STATIC_LIB "Install the static library" ON)
+
+if(TILE_LANG_STATIC_STDCPP)
+  message(STATUS "Enabling static linking of C++ standard library")
+  # Note: We'll apply static linking flags selectively to avoid Python extension conflicts
+  # The flags will be applied per-target below rather than globally
+endif()
+
 # Set default build type to Release if not provided
 if(NOT CMAKE_BUILD_TYPE)
   set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type")

@@ -218,6 +227,11 @@ add_library(tilelang_static STATIC $<TARGET_OBJECTS:tilelang_objs>)
 add_dependencies(tilelang_static tvm_runtime)
 set_target_properties(tilelang_static PROPERTIES OUTPUT_NAME tilelang)

+# Apply static linking flags only to static library to avoid Python extension conflicts
+if(TILE_LANG_STATIC_STDCPP AND CMAKE_CXX_COMPILER_ID MATCHES "GNU")
+  target_link_options(tilelang_static PRIVATE -static-libstdc++ -static-libgcc)
+endif()
+
 # Debug build type-specific definitions
 if(CMAKE_BUILD_TYPE STREQUAL "Debug")
   target_compile_definitions(tilelang PRIVATE "TVM_LOG_DEBUG")

examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py

Lines changed: 4 additions & 12 deletions
@@ -389,9 +389,7 @@ def ref_program_twiddling(A, qB, Scale, Bias=None):
     """
     dtypeC = "bfloat16"
     B = torch_convert_bit_twiddling(qB)
-    for i in range(B.shape[0]):
-        for j in range(B.shape[1]):
-            B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
+    B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
     C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
     C = C.to(torch.__getattribute__(dtypeC))
     return C

@@ -414,9 +412,7 @@ def ref_program_twiddling_with_bias(A, qB, Scale, Bias):
     """
     dtypeC = "bfloat16"
     B = torch_convert_bit_twiddling(qB)
-    for i in range(B.shape[0]):
-        for j in range(B.shape[1]):
-            B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
+    B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
     C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias
     C = C.to(torch.__getattribute__(dtypeC))
     return C

@@ -440,9 +436,7 @@ def ref_program_simple(A, qB, Scale, Bias=None):
     """
     dtypeC = "bfloat16"
     B = torch_convert(qB)
-    for i in range(B.shape[0]):
-        for j in range(B.shape[1]):
-            B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
+    B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
     C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
     C = C.to(torch.__getattribute__(dtypeC))
     return C

@@ -470,9 +464,7 @@ def ref_program_simple_with_bias(A, qB, Scale, Bias):
     """
     dtypeC = "bfloat16"
     B = torch_convert(qB)
-    for i in range(B.shape[0]):
-        for j in range(B.shape[1]):
-            B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
+    B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
     C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias
     C = C.to(torch.__getattribute__(dtypeC))
     return C
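
All four reference programs above swap the same nested Python loop for one vectorized line. A quick self-contained sanity check of that equivalence (illustrative only; the 32-column scale groups follow the diff, every other name and shape is made up):

import torch

N, K = 4, 128  # K must be a multiple of the 32-column scale group
B_loop = torch.randn(N, K, dtype=torch.float32)
B_vec = B_loop.clone()
Scale = torch.randint(-4, 5, (N, K // 32)).to(torch.float32)

# Original formulation: one scale exponent per 32-column group, applied element-wise
for i in range(N):
    for j in range(K):
        B_loop[i][j] = B_loop[i][j] * (2**(Scale[i][j // 32]))

# Vectorized formulation from the diff
B_vec *= 2**(Scale[:, torch.arange(K, device=B_vec.device) // 32])

assert torch.allclose(B_loop, B_vec)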

examples/dequantize_gemm/example_dequant_gemm_fine_grained.py

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ def matmul(
     threads,
     num_bits=4,
 ):
-    from bitblas.quantization import _tir_packed_to_unsigned_convert
+    from tilelang.quantize import _tir_packed_to_unsigned_convert
     num_elems_per_byte = 8 // num_bits
     storage_dtype = "int8"
     storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))

examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py

Lines changed: 511 additions & 0 deletions
Large diffs are not rendered by default.

examples/dequantize_gemm/test_example_dequantize_gemm.py

Lines changed: 8 additions & 0 deletions
@@ -4,6 +4,7 @@
 import example_dequant_gemm_fp4_hopper
 import example_dequant_gemm_bf16_mxfp4_hopper
 import example_dequant_gemm_bf16_mxfp4_hopper_tma
+import example_dequant_groupedgemm_bf16_mxfp4_hopper
 import example_dequant_gemm_w4a8


@@ -31,6 +32,13 @@ def test_example_dequant_gemm_bf16_mxfp4_hopper_tma():


 @tilelang.testing.requires_cuda
+@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
+def test_example_dequant_groupedgemm_bf16_mxfp4_hopper():
+    example_dequant_groupedgemm_bf16_mxfp4_hopper.main()
+
+
+@tilelang.testing.requires_cuda
+@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
 def test_example_dequant_gemm_w4a8():
     example_dequant_gemm_w4a8.main()


examples/dequantize_gemm/utils.py

Lines changed: 76 additions & 32 deletions
@@ -3,8 +3,6 @@

 def torch_convert_bit_twiddling(tensor):
     """
-    Convert a 2-D uint8 tensor into a bfloat16 tensor by decoding pairs of input bytes with a bit-twiddling scheme.
-
     This function expects `tensor` to be a 2-D torch.Tensor of dtype `torch.uint8`. Each output element is produced by combining two input bytes and extracting a bf16-like 16-bit pattern according to one of four positional bit layouts (pos 0..3). The result is scaled by 2**126 to adjust the exponent bias and returned as dtype `torch.bfloat16`.

     Parameters:
@@ -16,38 +14,46 @@ def torch_convert_bit_twiddling(tensor):
     Raises:
         AssertionError: If any byte inputs used for a conversion are not dtype `torch.uint8`.
     """
+    assert tensor.dim() == 2 and tensor.dtype == torch.uint8
+    N, K = tensor.shape
+    assert K % 2 == 0, "Number of columns must be even"

-    def _convert(val0, val1, pos) -> torch.bfloat16:
-        assert val0.dtype == torch.uint8
-        assert val1.dtype == torch.uint8
-        val0 = val0.view(torch.uint8)
-        val1 = val1.view(torch.uint8)
-        val_concat = (val0.item() << 8) | val1.item()
-        mask = 0b1000000111000000
-        if pos == 0:
-            bf16 = val_concat & mask
-        elif pos == 1:
-            bf16 = (val_concat << 3) & mask
-        elif pos == 2:
-            bf16 = (val_concat << 6) & mask
-        elif pos == 3:
-            mask1 = 0b1000000000000000
-            mask2 = 0b0000000110000000
-            mask3 = 0b0000000001000000
-            bf16 = ((val_concat << 1) & mask1) | ((val_concat >> 3) & mask2) | (
-                (val_concat >> 7) & mask3)
-        bf16_new = torch.tensor([bf16], dtype=torch.uint16, device=val0.device).view(torch.bfloat16)
-        # Add bias for change from fp4 to bf16
-        bf16_new = bf16_new.item() * (2**126)
-        return bf16_new
+    # Combine pairs of uint8 values into uint32 for safe bitwise ops on CUDA
+    val0 = tensor[:, 0::2].to(torch.int32)
+    val1 = tensor[:, 1::2].to(torch.int32)
+    val_concat = (val0 << 8) | val1  # (N, K//2), uint32

-    N = tensor.shape[0]
-    K = tensor.shape[1]
-    new_tensor = torch.empty(N, K * 2, dtype=torch.bfloat16, device=tensor.device)
-    for i in range(new_tensor.shape[0]):
-        for j in range(new_tensor.shape[1]):
-            new_tensor[i][j] = _convert(tensor[i][j // 4 * 2], tensor[i][j // 4 * 2 + 1], j % 4)
-    return new_tensor
+    # Expand to match output shape where each pair generates 4 values
+    val_concat_expanded = val_concat.repeat_interleave(4, dim=1)  # (N, K//2*4)
+
+    # Positional encoding for bit-twiddling logic
+    pos = torch.arange(K * 2, device=tensor.device) % 4  # (K*2,)
+
+    # Bit masks for decoding (as uint32 for CUDA compatibility)
+    mask = 0b1000000111000000
+    mask1 = 0b1000000000000000
+    mask2 = 0b0000000110000000
+    mask3 = 0b0000000001000000
+
+    # Calculate results for all 4 positions in parallel
+    res0 = val_concat_expanded & mask
+    res1 = (val_concat_expanded << 3) & mask
+    res2 = (val_concat_expanded << 6) & mask
+    res3 = ((val_concat_expanded << 1) & mask1) | ((val_concat_expanded >> 3) & mask2) | (
+        (val_concat_expanded >> 7) & mask3)
+
+    # Select the correct result based on position
+    bf16 = torch.where(pos == 0, res0, torch.where(pos == 1, res1,
+                                                   torch.where(pos == 2, res2, res3)))
+
+    # Convert to uint16 for .view(torch.bfloat16)
+    bf16_uint16 = (bf16 & 0xFFFF).to(torch.uint16)
+    bf16_bf16 = bf16_uint16.view(torch.bfloat16)
+
+    # Avoid integer overflow by using a float32 multiplier for the exponent scaling
+    bf16_new = bf16_bf16 * (2.0**126)
+
+    return bf16_new


 def torch_convert(tensor, scale_size=None, Scale=None):
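
As a side note on the vectorized rewrite above, the following cross-check (not part of the commit) compares the new torch_convert_bit_twiddling against a scalar reference transcribed from the deleted _convert helper. It assumes it is run from examples/dequantize_gemm/ so that `utils` resolves to the file shown here:

import torch
from utils import torch_convert_bit_twiddling


def _convert_scalar(val0, val1, pos):
    # Scalar reference, mirroring the deleted _convert helper
    val_concat = (val0 << 8) | val1
    mask = 0b1000000111000000
    if pos == 0:
        bf16 = val_concat & mask
    elif pos == 1:
        bf16 = (val_concat << 3) & mask
    elif pos == 2:
        bf16 = (val_concat << 6) & mask
    else:
        bf16 = ((val_concat << 1) & 0b1000000000000000) | \
               ((val_concat >> 3) & 0b0000000110000000) | \
               ((val_concat >> 7) & 0b0000000001000000)
    out = torch.tensor([bf16], dtype=torch.uint16).view(torch.bfloat16)
    return out * (2.0**126)  # fp4 -> bf16 exponent-bias adjustment


qB = torch.randint(0, 256, (2, 8), dtype=torch.uint8)
B = torch_convert_bit_twiddling(qB)  # (2, 16) bfloat16: each byte pair yields 4 values

# Every output element should match the scalar decode of its source byte pair
for i in range(B.shape[0]):
    for j in range(B.shape[1]):
        ref = _convert_scalar(int(qB[i, j // 4 * 2]), int(qB[i, j // 4 * 2 + 1]), j % 4)
        assert torch.equal(B[i, j], ref[0])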
@@ -106,3 +112,41 @@ def print_bit(name, val):
     val_cpu = val.cpu().item()
     binary_repr = f'{val_cpu:032b}'
     print(name, binary_repr)
+
+
+def print_red_warning(message):
+    print(f"\033[31mWARNING: {message}\033[0m")
+
+
+def calc_sim(x, y, name="tensor"):
+    x, y = x.data.double(), y.data.double()
+    denominator = (x * x + y * y).sum()
+    if denominator == 0:
+        print_red_warning(f'{name} all zero')
+        return 1
+    sim = 2 * (x * y).sum() / denominator
+    return sim
+
+
+def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True):
+    x_mask = torch.isfinite(x)
+    y_mask = torch.isfinite(y)
+    if not torch.all(x_mask == y_mask):
+        print_red_warning(f'{name} Error: isfinite mask mismatch')
+        if raise_assert:
+            raise AssertionError
+    if not torch.isclose(
+            x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0,
+            equal_nan=True).all():
+        print_red_warning(f'{name} Error: nonfinite value mismatch')
+        if raise_assert:
+            raise AssertionError
+    x = x.masked_fill(~x_mask, 0)
+    y = y.masked_fill(~y_mask, 0)
+    sim = calc_sim(x, y, name)
+    diff = (1. - sim).item()
+    print(f'{diff=}')
+    if not (0 <= diff <= eps):
+        print_red_warning(f'{name} Error: {diff=}')
+        if raise_assert:
+            raise AssertionError
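
The new calc_sim / assert_similar helpers score two tensors by the symmetric similarity 2*sum(x*y) / sum(x*x + y*y) and flag diff = 1 - sim above eps. A minimal usage sketch (assumed to run from examples/dequantize_gemm/ so `utils` is importable; the tensors are made up):

import torch
from utils import assert_similar, calc_sim

ref = torch.randn(128, 256)
out = ref + 1e-4 * torch.randn_like(ref)  # stand-in for a kernel result with small error

sim = calc_sim(out, ref, name="C")
print(f"similarity = {float(sim):.8f}")

# Passes silently when diff = 1 - sim stays within eps; raises AssertionError otherwise
assert_similar(out, ref, eps=1e-6, name="C")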

maint/scripts/pypi.Dockerfile

Lines changed: 21 additions & 5 deletions
@@ -2,24 +2,40 @@ FROM nvidia/cuda:12.1.0-devel-ubuntu18.04

 RUN set -eux; \
     apt-get update; \
-    apt-get install -y wget curl libtinfo-dev zlib1g-dev libssl-dev build-essential libedit-dev libxml2-dev git; \
+    # Install gcc-9 and g++-9
+    apt-get install -y software-properties-common; \
+    add-apt-repository ppa:ubuntu-toolchain-r/test -y; \
+    apt-get update; \
+    apt-get install -y wget curl libtinfo-dev zlib1g-dev libssl-dev build-essential \
+        libedit-dev libxml2-dev git gcc-9 g++-9; \
+    # Switch default gcc/g++ to new version
+    update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-9 100; \
+    update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-9 100; \
+    update-alternatives --install /usr/bin/cc cc /usr/bin/gcc 100; \
+    update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++ 100; \
+    gcc --version; g++ --version; \
     curl -O https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh; \
     bash Miniconda3-latest-Linux-x86_64.sh -b -p /miniconda3; \
-    rm Miniconda3-latest-Linux-x86_64.sh
+    rm Miniconda3-latest-Linux-x86_64.sh;
+
+RUN apt-get update && apt-get install -y ninja-build

 ENV PATH=/miniconda3/bin/:$PATH

+# ✅ Accept Anaconda Terms of Service for both required channels
+RUN conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main && \
+    conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r
+
+# Create environments
 RUN set -eux; \
-    conda create -n py38 python=3.8 -y; \
     conda create -n py39 python=3.9 -y; \
     conda create -n py310 python=3.10 -y; \
     conda create -n py311 python=3.11 -y; \
     conda create -n py312 python=3.12 -y; \
-    ln -s /miniconda3/envs/py38/bin/python3.8 /usr/bin/python3.8; \
     ln -s /miniconda3/envs/py39/bin/python3.9 /usr/bin/python3.9; \
     ln -s /miniconda3/envs/py310/bin/python3.10 /usr/bin/python3.10; \
     ln -s /miniconda3/envs/py311/bin/python3.11 /usr/bin/python3.11; \
     ln -s /miniconda3/envs/py312/bin/python3.12 /usr/bin/python3.12; \
     conda install -y cmake patchelf

-WORKDIR /tilelang
+WORKDIR /tilelang

pyproject.toml

Lines changed: 1 addition & 5 deletions
@@ -4,13 +4,9 @@ requires = [
     "cmake>=3.26",
     "packaging",
     "setuptools>=61",
-    "torch",
     "wheel",
-    "tox",
-    "auditwheel",
     "patchelf",
-    "ninja",
-    "Cython",
+    "Cython>=3.0.0",
 ]
 build-backend = "setuptools.build_meta"


requirements-build.txt

Lines changed: 2 additions & 1 deletion
@@ -1,5 +1,5 @@
 # Should be mirrored in pyproject.toml
-Cython
+Cython>=3.0.0
 build
 cmake>=3.26
 packaging
@@ -9,3 +9,4 @@ wheel
 tox
 auditwheel
 patchelf
+ninja

tilelang/language/builtin.py

Lines changed: 2 additions & 2 deletions
@@ -331,13 +331,13 @@ def shfl_up(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr,


 def sync_threads():
-    """Synchronize all threads in a warp.
+    """Synchronize all threads in a block.
     """
     return tir.op.tvm_storage_sync("shared")


 def sync_global():
-    """Synchronize all threads in a block.
+    """Synchronize all threads in the entire grid.
     """
     tx, ty, tz = get_thread_bindings()
     ex, ey, ez = get_block_extents()
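
A minimal sketch of the distinction the corrected docstrings draw (assuming an installed tilelang build so the module above is importable; sync_global() only makes sense inside a kernel, so it is left as a comment):

from tilelang.language.builtin import sync_threads

# sync_threads() builds the block-scope barrier intrinsic tvm_storage_sync("shared"),
# i.e. every thread in the current block waits here (typically after shared-memory writes).
barrier = sync_threads()
print(barrier)

# sync_global(), by contrast, synchronizes the entire grid; it is constructed from
# get_thread_bindings() / get_block_extents(), so it can only be emitted inside a kernel.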
