Skip to content

Commit 5b48f8d

Browse files
authored
[Debug] Introduce T.print for buffer and variables logging on frontend (#45)
* [Doc] Update documentation structure and content: add overview section, revise project name, and change theme to Furo * [Feature] Add device-side debug printing functions and integrate into kernel interface * lint fix * remove debug print * implement test for debug * lint fix * add some comments * Enhance fragment design and assert fragment print * enhance debug print * add test for msg * lint fix
1 parent 61d9016 commit 5b48f8d

File tree

8 files changed

+396
-10
lines changed

8 files changed

+396
-10
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,6 @@ models/frozenmodels/
7676

7777
# build sdist
7878
build_sdist/
79+
80+
# do not ignore the debug testing folder (the leading "!" re-includes it)
81+
!testing/python/debug

src/target/codegen_cuda.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ std::string CodeGenTileLangCUDA::Finish() {
8383
decl_stream << "#include <tl_templates/cuda/reduce.h>\n";
8484
decl_stream << "#include <tl_templates/cuda/ldsm.h>\n";
8585
decl_stream << "#include <tl_templates/cuda/threadblock_swizzle.h>\n";
86+
decl_stream << "#include <tl_templates/cuda/debug.h>\n";
8687
decl_stream << "\n";
8788
return CodeGenC::Finish();
8889
}

src/tl_templates/cuda/debug.h

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
#pragma once
2+
3+
#include "common.h"
4+
#include <stdio.h>
5+
6+
// Template declaration for device-side debug printing (variable only)
7+
// Device-side debug printing of a single value.
//
// Every specialization emits exactly one line through device-side printf that
// contains the caller-supplied message, the calling thread's blockIdx and
// threadIdx, the dtype name and the value itself.
//
// `msg` is a `const char *` so that string literals can be passed directly:
// a literal has type `const char[N]`, and binding it to a non-const `char *`
// is ill-formed since C++11. Callers that still hold a `char *` keep working
// through the implicit `char *` -> `const char *` conversion.
//
// Only the specializations below are defined; using any other T is a link-time
// error because the primary template has no definition.
template <typename T> __device__ void debug_print_var(const char *msg, T var);

// Specialization for integer type
template <> __device__ void debug_print_var<int>(const char *msg, int var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=int "
         "value=%d\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, var);
}

// Specialization for float type
template <> __device__ void debug_print_var<float>(const char *msg, float var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=float "
         "value=%f\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, var);
}

// Specialization for half type
// The value is converted to float because printf has no half-precision
// conversion specifier.
template <> __device__ void debug_print_var<half>(const char *msg, half var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=half "
         "value=%f\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, (float)var);
}

// Specialization for half_t type (converted to float for printing)
template <>
__device__ void debug_print_var<half_t>(const char *msg, half_t var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=half_t "
         "value=%f\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, (float)var);
}

// Specialization for bfloat16_t type (converted to float for printing)
template <>
__device__ void debug_print_var<bfloat16_t>(const char *msg, bfloat16_t var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): "
         "dtype=bfloat16_t value=%f\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, (float)var);
}

// Specialization for double type
template <>
__device__ void debug_print_var<double>(const char *msg, double var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=double "
         "value=%lf\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, var);
}
57+
58+
#pragma once
59+
60+
#include "common.h"
61+
#include <stdio.h>
62+
63+
// Template declaration for device-side debug printing (buffer only)
64+
// Device-side debug printing of one buffer element.
//
// Every specialization emits exactly one line through device-side printf that
// contains the caller-supplied message, the calling thread's blockIdx and
// threadIdx, the buffer name, the element index, the dtype name and the value.
//
// `msg` and `buf_name` are `const char *` so that string literals can be
// passed directly: a literal has type `const char[N]`, and binding it to a
// non-const `char *` is ill-formed since C++11. Callers that still hold a
// `char *` keep working through the implicit conversion to `const char *`.
//
// Only the specializations below are defined; using any other T is a link-time
// error because the primary template has no definition.
template <typename T>
__device__ void debug_print_buffer_value(const char *msg, const char *buf_name,
                                         int index, T var);

// Specialization for integer type
template <>
__device__ void debug_print_buffer_value<int>(const char *msg,
                                              const char *buf_name, int index,
                                              int var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
         "index=%d, dtype=int value=%d\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, buf_name, index, var);
}

// Specialization for float type
template <>
__device__ void debug_print_buffer_value<float>(const char *msg,
                                                const char *buf_name, int index,
                                                float var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
         "index=%d, dtype=float value=%f\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, buf_name, index, var);
}

// Specialization for half type
// The value is converted to float because printf has no half-precision
// conversion specifier.
template <>
__device__ void debug_print_buffer_value<half>(const char *msg,
                                               const char *buf_name, int index,
                                               half var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
         "index=%d, dtype=half value=%f\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, buf_name, index, (float)var);
}

// Specialization for half_t type (converted to float for printing)
template <>
__device__ void debug_print_buffer_value<half_t>(const char *msg,
                                                 const char *buf_name,
                                                 int index, half_t var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
         "index=%d, dtype=half_t value=%f\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, buf_name, index, (float)var);
}

// Specialization for bfloat16_t type (converted to float for printing)
template <>
__device__ void debug_print_buffer_value<bfloat16_t>(const char *msg,
                                                     const char *buf_name,
                                                     int index,
                                                     bfloat16_t var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
         "index=%d, dtype=bfloat16_t value=%f\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, buf_name, index, (float)var);
}

// Specialization for double type
template <>
__device__ void debug_print_buffer_value<double>(const char *msg,
                                                 const char *buf_name,
                                                 int index, double var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
         "index=%d, dtype=double value=%lf\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, buf_name, index, var);
}
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# type: ignore
2+
3+
import tilelang
4+
import tilelang.testing
5+
import tilelang.language as T
6+
7+
8+
def debug_print_buffer(M=16, N=16):
    """Compile and run once a kernel that calls ``T.print`` on a shared buffer.

    The kernel is launched on a 4 x 4 x 2 grid with 256 threads per block,
    allocates an ``(M, N)`` float16 shared buffer and prints it without any
    guard. ``Q`` is only part of the kernel signature; the body never reads it.
    """
    dtype = "float16"

    # Kernel body kept exactly as written: it is consumed by T.prim_func.
    @T.prim_func
    def program(Q: T.Buffer((M, N), dtype)):
        with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
            shared_buf = T.alloc_shared([M, N], dtype)
            T.print(shared_buf)

    # Keep the compiled kernel bound to a name while the profiler runs it.
    compiled = tilelang.JITKernel(program, target="cuda")
    runner = compiled.get_profiler()
    runner.run_once()
20+
21+
22+
def test_debug_print_buffer():
    """Smoke test: printing a 16 x 16 shared buffer compiles and runs."""
    debug_print_buffer(M=16, N=16)
24+
25+
26+
def debug_print_buffer_conditional(M=16, N=16):
    """Compile and run once a kernel that prints a shared buffer from one block.

    Same launch shape as the unconditional variant (4 x 4 x 2 grid, 256
    threads), but ``T.print`` is only reached when ``bx``, ``by`` and ``bz``
    are all zero.
    """
    dtype = "float16"

    # Kernel body kept exactly as written: it is consumed by T.prim_func.
    @T.prim_func
    def program(Q: T.Buffer((M, N), dtype)):
        with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
            shared_buf = T.alloc_shared([M, N], dtype)

            if bx == 0 and by == 0 and bz == 0:
                T.print(shared_buf)

    # Keep the compiled kernel bound to a name while the profiler runs it.
    compiled = tilelang.JITKernel(program, target="cuda")
    runner = compiled.get_profiler()
    runner.run_once()
40+
41+
42+
def test_debug_print_buffer_conditional():
    """Smoke test: block-guarded printing of a 16 x 16 shared buffer runs."""
    debug_print_buffer_conditional(M=16, N=16)
44+
45+
46+
def debug_print_value_conditional(M=16, N=16):
    """Compile and run once a kernel that prints a scalar expression.

    Inside a 4 x 4 x 2 grid with 256 threads per block, the thread whose
    binding from ``T.get_thread_binding()`` equals 0 prints ``bx + by + bz``.
    ``M`` and ``N`` only shape the unused ``Q`` argument.
    """
    dtype = "float16"

    # Kernel body kept exactly as written: it is consumed by T.prim_func.
    @T.prim_func
    def program(Q: T.Buffer((M, N), dtype)):
        with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
            tid = T.get_thread_binding()
            if tid == 0:
                T.print(bx + by + bz)

    # Keep the compiled kernel bound to a name while the profiler runs it.
    compiled = tilelang.JITKernel(program, target="cuda")
    runner = compiled.get_profiler()
    runner.run_once()
59+
60+
61+
def test_debug_print_value_conditional():
    """Smoke test: thread-guarded printing of a scalar expression runs."""
    debug_print_value_conditional(M=16, N=16)
63+
64+
65+
def debug_print_register_files(M=16, N=16):
    """Compile and run once a kernel that prints every element of a fragment.

    The buffer comes from ``T.alloc_fragment`` (despite its ``shared_buf``
    name) and each element ``[i, j]`` is printed from inside a
    ``T.Parallel(M, N)`` loop. Launch shape: 4 x 4 x 2 grid, 256 threads.
    """
    dtype = "float16"

    # Kernel body kept exactly as written: it is consumed by T.prim_func.
    @T.prim_func
    def program(Q: T.Buffer((M, N), dtype)):
        with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
            shared_buf = T.alloc_fragment([M, N], dtype)
            for i, j in T.Parallel(M, N):
                T.print(shared_buf[i, j])

    # Keep the compiled kernel bound to a name while the profiler runs it.
    compiled = tilelang.JITKernel(program, target="cuda")
    runner = compiled.get_profiler()
    runner.run_once()
78+
79+
80+
def test_debug_print_register_files():
    """Smoke test: element-wise printing of a 16 x 16 fragment runs."""
    debug_print_register_files(M=16, N=16)
82+
83+
84+
def debug_print_msg(M=16, N=16):
    """Compile and run once a kernel that prints a value with a custom message.

    Identical to the scalar-printing program except that ``T.print`` receives
    ``msg="hello world"``. Launch shape: 4 x 4 x 2 grid, 256 threads.
    """
    dtype = "float16"

    # Kernel body kept exactly as written: it is consumed by T.prim_func.
    @T.prim_func
    def program(Q: T.Buffer((M, N), dtype)):
        with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
            tid = T.get_thread_binding()
            if tid == 0:
                T.print(bx + by + bz, msg="hello world")

    # Keep the compiled kernel bound to a name while the profiler runs it.
    compiled = tilelang.JITKernel(program, target="cuda")
    runner = compiled.get_profiler()
    runner.run_once()
97+
98+
99+
def test_debug_print_msg():
    """Smoke test: printing with a custom ``msg`` compiles and runs."""
    debug_print_msg(M=16, N=16)
101+
102+
103+
if __name__ == "__main__":
104+
tilelang.testing.main()

tilelang/language/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from tilelang.layout import Layout, Fragment # noqa: F401
99
from .parallel import Parallel # noqa: F401
1010
from .pipeline import Pipelined # noqa: F401
11-
from .kernel import Kernel, KernelLaunchFrame # noqa: F401
11+
from .kernel import Kernel, KernelLaunchFrame, get_thread_binding # noqa: F401
1212
from .allocate import (
1313
alloc_local, # noqa: F401
1414
alloc_shared, # noqa: F401
@@ -24,6 +24,7 @@
2424
reduce_sum, # noqa: F401
2525
reduce_abssum, # noqa: F401
2626
)
27+
from .print import print # noqa: F401
2728
from .customize import (
2829
atomic_add, # noqa: F401
2930
atomic_addx2, # noqa: F401

tilelang/language/kernel.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,13 @@ def get_thread_binding(self, dim: int = 0) -> Var:
132132
"""
133133
return self.frames[-4 + dim].iter_var.var
134134

135+
def get_thread_bindings(self) -> List[Var]:
    """
    Returns the thread bindings for all three dimensions as a list.

    The list is ordered like the ``dim`` argument of ``get_thread_binding``:
    index 0 corresponds to threadIdx.x, index 1 to threadIdx.y, and index 2
    to threadIdx.z.
    """
    # frames[-4], frames[-3] and frames[-2] are the same entries that
    # get_thread_binding(dim) selects for dim = 0, 1, 2.
    return [frame.iter_var.var for frame in self.frames[-4:-1]]
141+
135142
def get_num_threads(self) -> int:
136143
"""
137144
Returns the thread indices from the topmost frame.
@@ -213,3 +220,15 @@ def Kernel(
213220
attrs["pragma_import_c"] = prelude
214221

215222
return _ffi_api.KernelLaunch(blocks, threads, attrs)
223+
224+
225+
def get_thread_binding(dim: int = 0) -> Var:
    """Return the thread binding for ``dim`` from the current kernel launch frame.

    Forwards to ``get_thread_binding(dim)`` on the frame returned by
    ``KernelLaunchFrame.Current()``.
    """
    current_frame = KernelLaunchFrame.Current()
    return current_frame.get_thread_binding(dim)
229+
230+
231+
def get_thread_bindings() -> List[Var]:
    """Return all three thread bindings from the current kernel launch frame.

    Forwards to ``get_thread_bindings()`` on the frame returned by
    ``KernelLaunchFrame.Current()``.
    """
    current_frame = KernelLaunchFrame.Current()
    return current_frame.get_thread_bindings()

0 commit comments

Comments
 (0)