Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
e1fce95
installation script fix
LeiWang1999 Jan 20, 2025
cfff033
readme typo fix
LeiWang1999 Jan 20, 2025
39803df
doc fix for dequantize gemm
LeiWang1999 Jan 20, 2025
22f2f21
Merge branch 'main' of https://github.com/tile-ai/tilelang into fix_doc
LeiWang1999 Jan 22, 2025
35400de
[Doc] remove CODE_OF_CONDUCT.md and SECURITY.md; update references in…
LeiWang1999 Jan 22, 2025
e1f9728
[Doc] add unit tests for AnnotateDeviceRegions transform; remove SUPP…
LeiWang1999 Jan 22, 2025
ef64424
update license
LeiWang1999 Jan 23, 2025
4f71edb
[Enhancement] add tensor supply handling for unsigned integers; impro…
LeiWang1999 Jan 23, 2025
3ab2364
[Refactor] improve code readability by reformatting function signatur…
LeiWang1999 Jan 23, 2025
c8599df
Merge branch 'main' of https://github.com/tile-ai/tilelang into test_…
LeiWang1999 Jan 23, 2025
eb33049
[Refactor] replace torch.manual_seed with tilelang.testing.set_random…
LeiWang1999 Jan 23, 2025
0698ec0
[Refactor] unify thread binding variable naming across kernel and exa…
LeiWang1999 Jan 23, 2025
a63e600
[Refactor] remove unused thread binding parameter from matrix multipl…
LeiWang1999 Jan 23, 2025
a492c82
[Refactor] remove unused thread binding parameter from matrix multipl…
LeiWang1999 Jan 23, 2025
abc17fb
[Refactor] enable main testing function in tilelang kernel gemm test
LeiWang1999 Jan 23, 2025
8923345
bug fix
LeiWang1999 Jan 23, 2025
85444e7
Merge branch 'main' of https://github.com/tile-ai/tilelang into test_…
LeiWang1999 Jan 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions examples/dequantize_gemm/example_dequant_gemm_fine_grained.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def main(
B_dequantize_local = T.alloc_local((warp_cols * local_size), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype)
reduced_accum_res = T.alloc_local(0, accum_dtype)
thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
thread_binding = T.thread_binding(0, threads, "threadIdx.x")
rk = T.thread_binding(0, reduce_k, "threadIdx.y")

T.annotate_layout({
Expand All @@ -279,7 +279,7 @@ def main(
for i in T.serial(block_N * (block_K // reduce_k) // num_elems_per_byte //
(threads * vec_load_qb)):
for v in T.vectorized(0, vec_load_qb):
t = thread_bindings
t = thread_binding
idx = i * threads * vec_load_qb * reduce_k + rk * threads * vec_load_qb + t * vec_load_qb + v
vkk = idx % (micro_size_k // num_elems_per_byte)
vjj = (idx // (micro_size_k // num_elems_per_byte)) % micro_size_y
Expand All @@ -299,7 +299,6 @@ def main(
A_local,
A_shared,
ki,
thread_bindings=thread_bindings,
rk=rk,
)

Expand All @@ -308,7 +307,6 @@ def main(
B_local,
B_shared,
ki,
thread_bindings=thread_bindings,
rk=rk,
)

Expand Down Expand Up @@ -343,7 +341,6 @@ def main(
mma_emitter.stmatrix(
C_local,
C_shared,
thread_bindings=thread_bindings,
)

for i, j in T.Parallel(block_M, (block_N // reduce_k)):
Expand Down
15 changes: 6 additions & 9 deletions examples/gemm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ def tl_matmul(
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)

thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
thread_binding = T.thread_binding(0, threads, "threadIdx.x")

T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
Expand Down Expand Up @@ -367,16 +367,14 @@ def tl_matmul(
mma_emitter.ldmatrix_a(
A_local,
A_shared,
ki,
thread_bindings=thread_bindings,
ki
)

# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
thread_bindings=thread_bindings,
ki
)

# Perform Matrix Multiplication
Expand All @@ -386,7 +384,6 @@ def tl_matmul(
mma_emitter.stmatrix(
C_local,
C_shared,
thread_bindings=thread_bindings,
)

# Store shared into global
Expand Down Expand Up @@ -416,10 +413,10 @@ def tl_matmul(
```python
for ki in T.serial(0, (block_K // micro_size_k)):
# Warp-synchronous load for A
mma_emitter.ldmatrix_a(A_local, A_shared, ki, thread_bindings=thread_bindings)
mma_emitter.ldmatrix_a(A_local, A_shared, ki)

# Warp-synchronous load for B
mma_emitter.ldmatrix_b(B_local, B_shared, ki, thread_bindings=thread_bindings)
mma_emitter.ldmatrix_b(B_local, B_shared, ki)
```
Internally, these calls orchestrate how each thread in the warp issues the correct load instructions, performs address calculations, and stores the data into registers.

Expand All @@ -437,7 +434,7 @@ def tl_matmul(
5. **Store Results via `stmatrix`**
Finally, you write the results from the warp-level fragments back to shared memory or global memory. This step might happen multiple times in a loop or just once at the end. The code snippet:
```python
mma_emitter.stmatrix(C_local, C_shared, thread_bindings=thread_bindings)
mma_emitter.stmatrix(C_local, C_shared)
```
orchestrates the warp-synchronous stores, ensuring each thread places the correct fragment element into the correct location of the shared or global buffer.

Expand Down
22 changes: 3 additions & 19 deletions examples/gemm/example_gemm_intrinsics.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,6 @@ def main(
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)

thread_bindings = T.thread_binding(0, threads, "threadIdx.x")

T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
Expand All @@ -141,30 +139,16 @@ def main(
for ki in T.serial(0, (block_K // micro_size_k)):

# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
A_shared,
ki,
thread_bindings=thread_bindings,
)
mma_emitter.ldmatrix_a(A_local, A_shared, ki)

# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
thread_bindings=thread_bindings,
)
mma_emitter.ldmatrix_b(B_local, B_shared, ki)

# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local)

# Perform STMatrix
mma_emitter.stmatrix(
C_local,
C_shared,
thread_bindings=thread_bindings,
)
mma_emitter.stmatrix(C_local, C_shared)

# Store shared into global
for i, j in T.Parallel(block_M, block_N):
Expand Down
6 changes: 0 additions & 6 deletions testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,6 @@ def main(
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)

thread_bindings = T.thread_binding(0, threads, "threadIdx.x")

T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
Expand Down Expand Up @@ -128,15 +126,13 @@ def main(
A_local,
A_shared,
ki,
thread_bindings=thread_bindings,
)

# Load B into fragment
mfma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
thread_bindings=thread_bindings,
)

# Perform Matrix Multiplication
Expand All @@ -147,7 +143,6 @@ def main(
mfma_emitter.stmatrix(
C_local,
C_shared,
thread_bindings=thread_bindings,
)

# Store shared into global
Expand All @@ -162,7 +157,6 @@ def main(
mfma_emitter.stmatrix(
C_local,
C,
thread_bindings=thread_bindings,
pid_m=by,
pid_n=bx,
)
Expand Down
5 changes: 0 additions & 5 deletions testing/python/dynamic/test_tilelang_dynamic_symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,6 @@ def main(
B_local = T.alloc_local((warp_cols * local_size), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype)

thread_bindings = T.thread_binding(0, threads, "threadIdx.x")

T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
Expand Down Expand Up @@ -142,15 +140,13 @@ def main(
A_local,
A_shared,
ki,
thread_bindings=thread_bindings,
)

# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
thread_bindings=thread_bindings,
)

# Perform Matrix Multiplication
Expand All @@ -160,7 +156,6 @@ def main(
mma_emitter.stmatrix(
C_local,
C_shared,
thread_bindings=thread_bindings,
)

# Store shared into global
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ def main(
B_dequantize_local = T.alloc_local((warp_cols * local_size), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype)
reduced_accum_res = T.alloc_local(0, accum_dtype)
thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
thread_binding = T.thread_binding(0, threads, "threadIdx.x")
rk = T.thread_binding(0, reduce_k, "threadIdx.y")

T.annotate_layout({
Expand All @@ -479,7 +479,7 @@ def main(
for i in T.serial(block_N * (block_K // reduce_k) // num_elems_per_byte //
(threads * vec_load_qb)):
for v in T.vectorized(0, vec_load_qb):
t = thread_bindings
t = thread_binding
idx = i * threads * vec_load_qb * reduce_k + rk * threads * vec_load_qb + t * vec_load_qb + v
vkk = idx % (micro_size_k // num_elems_per_byte)
vjj = (idx // (micro_size_k // num_elems_per_byte)) % micro_size_y
Expand All @@ -499,7 +499,6 @@ def main(
A_local,
A_shared,
ki,
thread_bindings=thread_bindings,
rk=rk,
)

Expand All @@ -508,7 +507,6 @@ def main(
B_local,
B_shared,
ki,
thread_bindings=thread_bindings,
rk=rk,
)

Expand Down Expand Up @@ -543,7 +541,6 @@ def main(
mma_emitter.stmatrix(
C_local,
C_shared,
thread_bindings=thread_bindings,
)

for i, j in T.Parallel(block_M, (block_N // reduce_k)):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def main(
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)

thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
thread_binding = T.thread_binding(0, threads, "threadIdx.x")

T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
Expand Down Expand Up @@ -148,15 +148,13 @@ def main(
A_local,
A_shared,
ki,
thread_bindings=thread_bindings,
)

# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
thread_bindings=thread_bindings,
)

# Perform Matrix Multiplication
Expand All @@ -166,7 +164,6 @@ def main(
mma_emitter.stmatrix(
C_local,
C_shared,
thread_bindings=thread_bindings,
)

# Store shared into global
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def main(
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)

thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
thread_binding = T.thread_binding(0, threads, "threadIdx.x")

T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
Expand Down Expand Up @@ -138,16 +138,14 @@ def main(
A_local,
A_shared,
ki,
thread_bindings=thread_bindings,
)
)

# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
thread_bindings=thread_bindings,
)
)

# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local)
Expand All @@ -156,7 +154,6 @@ def main(
mma_emitter.stmatrix(
C_local,
C_shared,
thread_bindings=thread_bindings,
)

# Store shared into global
Expand Down Expand Up @@ -297,7 +294,7 @@ def main(
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)

thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
thread_binding = T.thread_binding(0, threads, "threadIdx.x")

T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
Expand Down Expand Up @@ -328,16 +325,14 @@ def main(
A_local,
A_shared,
ki,
thread_bindings=thread_bindings,
)
)

# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
thread_bindings=thread_bindings,
)
)

# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local)
Expand All @@ -346,7 +341,6 @@ def main(
mma_emitter.stmatrix(
C_local,
C_shared,
thread_bindings=thread_bindings,
)

# Store shared into global
Expand Down
Loading
Loading