
Commit a2ccb8f

[Refactor] Introduce quantize components of TileLang and add testing for dequant gemm example (tile-ai#494)
* Remove deprecated example_dequant_gemm.py and add DataType import in __init__.py
* lint fix
* lint fix
* Refactor dequantization examples to use tilelang imports and update data type handling in quantization utilities
* lint fix
1 parent 9e7d43d commit a2ccb8f

File tree

11 files changed: +2054 -14 lines changed


examples/dequantize_gemm/example_dequant_gemm.py

Whitespace-only changes.

examples/dequantize_gemm/example_dequant_gemm_fine_grained.py

Lines changed: 6 additions & 1 deletion
@@ -433,5 +433,10 @@ def test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4():
         256, 1024, 512, "float16", "float16", "float16", 3)


+def main():
+    test_run_dequantize_gemm()
+    test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4()
+
+
 if __name__ == "__main__":
-    tilelang.testing.main()
+    main()
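
This refactor swaps the `tilelang.testing.main()` entry point for an explicit `main()`, so the example's checks can be invoked from another module as well as from the command line. A minimal sketch of such a driver (hypothetical -- the test module added in this commit covers the gemv and fp4 Hopper examples, not this file):

# Hypothetical driver; assumes examples/dequantize_gemm is on sys.path.
import example_dequant_gemm_fine_grained

example_dequant_gemm_fine_grained.main()  # runs both correctness checks above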

examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py

Lines changed: 16 additions & 13 deletions
@@ -266,19 +266,12 @@ def ref_program(A, qB):
     return C.transpose(0, 1)


-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--m', type=int, default=256, help='M')
-    parser.add_argument('--n', type=int, default=256, help='N')
-    parser.add_argument('--k', type=int, default=256, help='K')
-    parser.add_argument('--tune', action='store_true', help='tune configs')
-    args = parser.parse_args()
-    M, N, K = args.m, args.n, args.k
-    total_flops = 2 * M * N * K
+def main(m=256, n=256, k=256, tune=False):
+    total_flops = 2 * m * n * k

-    if (not args.tune):
+    if (not tune):
         program = matmul(
-            M, N, K, "float16", "float16", "float32", num_bits=4, tune=args.tune)(
+            m, n, k, "float16", "float16", "float32", num_bits=4, tune=tune)(
                 block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1)
         kernel = tilelang.compile(program, out_idx=[2])
         profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer)
@@ -291,10 +284,20 @@ def ref_program(A, qB):
         print("Tile-lang: {:.2f} ms".format(latency))
         print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
     else:
-        best_result = matmul(M, N, K, "float16", "float16", "float32", num_bits=4, tune=args.tune)
+        best_result = matmul(m, n, k, "float16", "float16", "float32", num_bits=4, tune=tune)
         best_latency = best_result.latency
         best_config = best_result.config
-        ref_latency = best_result.ref_latency
         print(f"Best latency: {best_latency}")
         print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
         print(f"Best config: {best_config}")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--m', type=int, default=256, help='M')
+    parser.add_argument('--n', type=int, default=256, help='N')
+    parser.add_argument('--k', type=int, default=256, help='K')
+    parser.add_argument('--tune', action='store_true', help='tune configs')
+    args = parser.parse_args()
+    M, N, K = args.m, args.n, args.k
+    main(M, N, K, args.tune)
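
Because argparse now lives only under the `__main__` guard, `main()` can be called with plain keyword arguments, which is what lets the new test module drive it without constructing an argv. A minimal sketch using the defaults from the diff:

# Equivalent to `python example_dequant_gemm_fp4_hopper.py` with no flags;
# pass tune=True to exercise the --tune path instead.
import example_dequant_gemm_fp4_hopper

example_dequant_gemm_fp4_hopper.main(m=256, n=256, k=256, tune=False)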
examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py

Lines changed: 210 additions & 0 deletions
@@ -0,0 +1,210 @@
import tilelang
from tilelang import language as T
from typing import Optional, Callable, Any
import torch
from tilelang import DataType
from tilelang.quantize import (
    _tir_packed_int_to_int_convert,)


def dequantize_gemv(
    M: int,
    N: int,
    K: int,
    in_dtype: str,
    out_dtype: str,
    accum_dtype: str,
    num_bits: int = 4,
    storage_dtype: str = "int8",
    source_format: str = "uint",
    n_partition: int = 4,
    reduce_thread: int = 32,
    fast_decoding: bool = False,
    trans_A: bool = False,
    trans_B: bool = True,
    group_size: int = -1,
    with_scaling: bool = False,
) -> Callable[..., Any]:

    assert n_partition is not None, "n_partition must be provided"
    assert reduce_thread is not None, (
        "reduce_thread must be provided currently, as the related bitblas.gpu.gemv.GEMV "
        "sch_outer_reduction_with_config is not implemented")

    assert trans_A is False, "Dequantize is only implemented for trans_A=False currently"
    assert trans_B is True, "Dequantize is only implemented for trans_B=True currently"
    # Split the storage dtype into base type and bit width, e.g. "int8" -> ("int", 8).
    storage_type = "".join(c for c in storage_dtype if not c.isdigit())
    storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))
    num_elems_per_byte = storage_nbit // num_bits

    # Each thread loads 128 bits per transaction; micro_size_k is the number of
    # in_dtype elements that fit in one such transaction.
    MAX_TRANSACTION_SIZE_IN_BITS = 128
    micro_size_k = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits
    micro_size_k_compressed = micro_size_k // num_elems_per_byte
    block_K = reduce_thread * micro_size_k

    if group_size == -1:
        group_size = K

    A_shape = (M, K)
    B_shape = (N, K // storage_nbit * num_bits)  # == (N, K // num_elems_per_byte)
    C_shape = (M, N)

    dp4a_size = 4
    use_dp4a = in_dtype == "int8" and accum_dtype == "int32"

    import_source: Optional[str] = None
    func_name: str = ""
    if fast_decoding:
        # Lazy import to reduce startup time, as loading the intrin registry
        # may take a while.
        from tilelang.quantize import get_lop3_intrin_group

        lop3_intrin_info = get_lop3_intrin_group(
            out_dtype=in_dtype,
            source_format=source_format,
            source_bit=num_bits,
            storage_dtype=storage_dtype,
            with_scaling=with_scaling,
            with_zeros=False,
        )
        import_source = lop3_intrin_info["c_source"]
        func_name = lop3_intrin_info["func_name"]
        assert import_source is not None, "c_source is not found in lop3_intrin_info"
        assert func_name is not None, "func_name is not found in lop3_intrin_info"

    @T.prim_func
    def main(
            A: T.Tensor[A_shape, in_dtype],
            B: T.Tensor[B_shape, storage_dtype],
            C: T.Tensor[C_shape, out_dtype],
    ):
        with T.Kernel(
                T.ceildiv(N, n_partition),
                M,
                threads=(reduce_thread, n_partition),
        ) as (bx, by):
            A_local = T.alloc_local((micro_size_k,), in_dtype)
            B_quant_local = T.alloc_local([micro_size_k_compressed], storage_dtype)
            B_dequantize_local = T.alloc_local([micro_size_k], in_dtype)
            accum_res = T.alloc_local((1,), accum_dtype)
            reduced_accum_res = T.alloc_local((1,), accum_dtype)

            kr = T.thread_binding(0, reduce_thread, thread="threadIdx.x")
            ni = T.thread_binding(0, n_partition, thread="threadIdx.y")

            T.import_source(import_source)

            T.clear(accum_res)
            for ko in T.serial(T.ceildiv(K, block_K)):
                for v in T.vectorized(micro_size_k):
                    A_local[v] = A[by, ko * block_K + kr * micro_size_k + v]

                for v in T.vectorized(micro_size_k_compressed):
                    B_quant_local[v] = B[
                        bx * n_partition + ni,
                        ko * (reduce_thread * micro_size_k_compressed) +
                        kr * micro_size_k_compressed + v,
                    ]

                if fast_decoding:
                    # Decode a packed vector at once with the LOP3 intrinsic.
                    T.call_extern(
                        func_name,
                        T.address_of(B_quant_local[0]),
                        T.address_of(B_dequantize_local[0]),
                        dtype=in_dtype,
                    )
                else:
                    # Scalar fallback: extract each num_bits-wide field.
                    for ki in T.serial(micro_size_k):
                        B_dequantize_local[ki] = _tir_packed_int_to_int_convert(
                            storage_type,
                            storage_nbit)(num_bits, B_quant_local[ki // num_elems_per_byte],
                                          ki % num_elems_per_byte, in_dtype)

                if use_dp4a:
                    for ki in T.serial(micro_size_k // dp4a_size):
                        T.dp4a(
                            A_local[ki * dp4a_size],
                            B_dequantize_local[ki * dp4a_size],
                            accum_res[0],
                        )
                else:
                    for ki in T.serial(micro_size_k):
                        accum_res[0] += A_local[ki] * B_dequantize_local[ki]

            # Reduce the per-thread partial sums across threadIdx.x.
            with T.attr(
                    T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
            ):
                T.evaluate(
                    T.tvm_thread_allreduce(
                        T.uint32(1),
                        accum_res[0],
                        True,
                        reduced_accum_res[0],
                        kr,
                        dtype="handle",
                    ))
            if kr == 0:
                C[by, bx * n_partition + ni] = reduced_accum_res[0]

    return main


def main() -> None:
    M = 1
    N = 1024
    K = 1024
    in_dtype = "float16"
    out_dtype = "float16"
    accum_dtype = "float16"
    num_bits = 4
    storage_dtype = "int8"
    source_format = "uint"
    n_partition = 4
    reduce_thread = 32
    fast_decoding = True
    trans_A = False
    trans_B = True
    group_size = -1
    with_scaling = False

    program = dequantize_gemv(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits, storage_dtype,
                              source_format, n_partition, reduce_thread, fast_decoding, trans_A,
                              trans_B, group_size, with_scaling)

    kernel = tilelang.compile(program)

    storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))
    num_elems_per_byte = storage_nbit // num_bits
    A = torch.rand(M, K, dtype=getattr(torch, in_dtype)).cuda()
    qB = torch.randint(
        0, 127, (N, K // num_elems_per_byte), dtype=getattr(torch, storage_dtype)).cuda()
    C = torch.zeros(M, N, dtype=getattr(torch, accum_dtype)).cuda()

    if fast_decoding:
        from tilelang.quantize.utils import interleave_weight
        qB = interleave_weight(qB, num_bits, in_dtype)
    kernel(A, qB, C)

    # int4 reference: expand each int8 byte into two 4-bit fields.
    B = torch.zeros(
        qB.shape[0], qB.shape[1] * num_elems_per_byte, dtype=torch.half, device=A.device)
    for j in range(B.shape[1]):
        B[:, j] = ((qB[:, j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half)

    # Get the reference result
    ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype))
    print("C: ", C)
    print("Ref C: ", ref_c)
    # No scaling is applied, so the absolute tolerance is large.
    torch.testing.assert_close(C, ref_c, atol=1e3, rtol=1e-1)


if __name__ == "__main__":
    main()
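
The int4 reference loop above unpacks two 4-bit fields per int8 byte: even output columns take the low nibble of byte j // 2 and odd columns the high nibble. A vectorized PyTorch sketch of the same unpacking (illustrative helper, not part of the commit):

import torch

def unpack_uint4(qB: torch.Tensor) -> torch.Tensor:
    # Expand (N, K // 2) packed-int8 uint4 weights into (N, K) fp16 values.
    out = torch.empty(qB.shape[0], qB.shape[1] * 2, dtype=torch.half, device=qB.device)
    out[:, 0::2] = (qB & 0xF).to(torch.half)         # low nibble -> even columns
    out[:, 1::2] = ((qB >> 4) & 0xF).to(torch.half)  # high nibble -> odd columns
    return out

Note that with fast_decoding enabled the kernel decodes through the LOP3 intrinsic, which expects the weights pre-permuted by interleave_weight; that is why the example interleaves qB before launching the kernel.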
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
import tilelang.testing

import example_dequant_gemv_fp16xint4
import example_dequant_gemm_fp4_hopper


@tilelang.testing.requires_cuda
def test_example_dequant_gemv_fp16xint4():
    example_dequant_gemv_fp16xint4.main()


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_dequant_gemm_fp4_hopper():
    example_dequant_gemm_fp4_hopper.main()


if __name__ == "__main__":
    tilelang.testing.main()
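
Both tests are gated on CUDA availability, and the Hopper fp4 test additionally requires compute capability 9.0 or newer. A rough sketch of what the version gate checks (an assumption about the decorator's semantics, written with PyTorch for illustration):

import torch

def hopper_available() -> bool:
    # Mirrors the assumed behavior of requires_cuda_compute_version_ge(9, 0):
    # skip unless the GPU reports compute capability >= (9, 0).
    if not torch.cuda.is_available():
        return False
    return torch.cuda.get_device_capability() >= (9, 0)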
