
Commit 7717d16

Merge branch 'main' into main
2 parents: 6eecc9d + f0d6669

71 files changed: +544 additions, -421 deletions


benchmark/matmul/benchmark_matmul.py

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ def ref_program(A, B):
 def get_configs(args, kwargs):
     """
     Generate a list of configuration dictionaries that will be used for tuning.
-
+
     Parameters
     ----------
     with_roller : bool

benchmark/matmul/benchmark_matmul_intrinsic.py

Lines changed: 1 addition & 1 deletion
@@ -165,7 +165,7 @@ def ref_program(A, B):
 def get_configs(args, kwargs):
     """
     Generate a list of configuration dictionaries that will be used for tuning.
-
+
     Parameters
     ----------
     with_roller : bool

benchmark/matmul/benchmark_matmul_sp.py

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ def ref_program(A, B):
 def get_configs(M, N, K):
     """
     Generate a list of configuration dictionaries that will be used for tuning.
-
+
     Parameters
     ----------
     with_roller : bool

benchmark/matmul_fp8/benchmark_matmul.py

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ def ref_program(A, B):
 def get_configs(args, kwargs):
     """
     Generate a list of configuration dictionaries that will be used for tuning.
-
+
     Parameters
     ----------
     with_roller : bool

examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py

Lines changed: 7 additions & 7 deletions
@@ -84,12 +84,12 @@ def bitnet_158_int8xint2_prefill(
 ):
     """
     Create a TVM GPU prim_func implementing a block-tiled matrix multiply that multiplies dense A by compressed/interleaved low-precision B (2-bit packed into int8 storage), decoding B to int8 on-chip and accumulating into C.
-
+
     The returned prim_func expects:
     - A: shape (M, K) with dtype `in_dtype` ("float16" or "int8").
     - B: compressed storage with shape (N, K/4) and int8 storage layout (packing 4 2-bit elements per byte).
     - C: output buffer shape (M, N) with dtype `out_dtype` ("float16", "float32", or "int32").
-
+
     Details:
     - Builds a tiled, pipelined kernel using shared memory and warp-level MMA intrinsics (INT4TensorCoreIntrinEmitter). B is loaded from compressed storage, decoded to int8 in threads (via decode_i2u_to_i8s / decode_i2s_to_i8s), and dequantized into a shared buffer used by the MMA emitter.
     - Tiling parameters:
@@ -99,7 +99,7 @@ def bitnet_158_int8xint2_prefill(
     - micro sizes are fixed (16x16x16, except micro_k=32 when accum_dtype == "int32").
     - Uses 2-stage pipelining by default to overlap loads and compute and applies a swizzle layout to improve L2 behavior.
     - Assertions: raises AssertionError if in_dtype or out_dtype are not among supported values.
-
+
     Parameters:
     M, N, K (int): Global matrix dimensions.
     in_dtype (str): Input and decoded B element dtype; "float16" or "int8".
@@ -111,7 +111,7 @@ def bitnet_158_int8xint2_prefill(
     warp_row_tiles (int): Tiles per warp in row dimension.
     warp_col_tiles (int): Tiles per warp in column dimension.
     chunk (int): K-length per block (block_K).
-
+
     Returns:
     T.prim_func: A TVM prim_func implementing the described GPU kernel suitable for compilation and execution.
     """
@@ -187,18 +187,18 @@ def main(
 ):
     """
     GPU kernel entry that performs a blocked, pipelined matrix multiplication A @ B.T writing into C.
-
+
     This kernel:
     - Loads tiles of A and a compressed/interleaved representation of B from global memory into shared memory.
     - Decodes B's packed low-precision format (storage_dtype, e.g., 2-bit packed) into element values of `in_dtype` in shared memory via an external decode routine.
     - Uses Warp/MMA tiled fragments and an INT4/INT2-capable MMA emitter to compute accumulation across K in a pipelined fashion with configurable stages.
     - Writes accumulated tile results from shared memory back to global C with the expected block/micro-tile indexing.
-
+
     Parameters:
     A: Input matrix buffer of shape A_shape and element type `in_dtype`. Represents the MxK activations.
     B: Compressed/interleaved weight buffer of shape B_shape and storage type `storage_dtype`. Must contain B in the packed low-precision layout expected by the decode routine used by this kernel.
     C: Output buffer of shape (M, N) and type `out_dtype`; receives the resulting matrix (accumulated values are produced in `accum_dtype` and stored into C).
-
+
     Side effects:
     Writes results into C. Calls external device decode functions to expand B from its packed representation into shared memory before computation.
     """

examples/bitnet-1.58b/vllm_workspace/utils.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 def check_outputs_equal(outputs_0_lst: List[TokensText], outputs_1_lst: List[TokensText],
                         name_0: str, name_1: str):
     """
-    Compare the two sequences generated by different models,
+    Compare the two sequences generated by different models,
     which should be equal.
     """
     assert len(outputs_0_lst) == len(outputs_1_lst)

examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py

Lines changed: 37 additions & 37 deletions
@@ -10,15 +10,15 @@
 def get_configs():
     """
     Return a list of tuning configuration dictionaries for the autotuned matmul kernel.
-
+
     Each dictionary is a single combination (Cartesian product) of the following parameters:
     - block_M: tile size for M dimension (one of 64, 128, 256)
     - block_N: tile size for N dimension (one of 64, 128, 256)
-    - block_K: tile size for K dimension
+    - block_K: tile size for K dimension
     - num_stages: pipeline stages for K-loop (0 or 2)
     - threads: number of threads to launch (128, 256, or 512)
     - split: K-splitting factor (1 or 2)
-
+
     Returns:
     list[dict]: List of configuration dicts usable by the autotuner, where each dict maps
     the parameter name to its chosen value.
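
A sketch of how such a Cartesian-product configuration list can be built (illustrative only; the block_K candidate values are an assumption, since the docstring above does not list them):

import itertools

def get_configs():
    # Cartesian product over the tuning parameters listed above.
    space = {
        "block_M": [64, 128, 256],
        "block_N": [64, 128, 256],
        "block_K": [64, 128],  # assumed candidates; not stated in the docstring
        "num_stages": [0, 2],
        "threads": [128, 256, 512],
        "split": [1, 2],
    }
    keys = list(space)
    return [dict(zip(keys, combo)) for combo in itertools.product(*space.values())]

print(len(get_configs()))  # 3 * 3 * 2 * 2 * 3 * 2 = 216 configurations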
@@ -62,30 +62,30 @@ def matmul(M,
     split=1):
     """
     Builds a parameterized TileLang/TIR matrix-multiplication kernel that dequantizes 4-bit FP inputs to BF16 on-the-fly and computes C = A @ B^T.
-
+
     This function returns a tiled, autotunable prim_func implementing a block-wise GEMM with shared-memory buffering and a pipelined K-loop. The kernel accepts:
     - A: dense input of shape (M, K) with dtype `in_dtype`.
     - B: packed quantized input of shape (N, QK) where QK = K / (8 / num_bits) stored as `uint8`.
     - C: output of shape (M, N) with dtype `out_dtype`.
-
+
     The generated kernel supports two dequantization paths:
     - fast_dequant (fast_dequant=True): calls an external mxfp dequantization intrinsic (twiddling-based) loaded from a C source returned by get_mxfp_intrin_group.
     - simple dequant (fast_dequant=False): performs a pure-TIR FP4 -> BF16 conversion per element.
-
+
     Important behavior and requirements:
     - num_bits (default 4) is the bit-width of the quantized elements; storage_dtype is uint8 and num_elems_per_byte = 8 // num_bits.
     - QK = K // num_elems_per_byte and Block_QK = block_K // num_elems_per_byte determine B and shared-buffer shapes.
     - Asserts that K % (block_K * split) == 0; K must be divisible by block_K * split for the tiling to be valid.
     - When fast_dequant is True, a valid mxfp intrinsic group (C source and function name) must be available via tilelang.quantize.get_mxfp_intrin_group.
     - The kernel launches a 2D grid over ceildiv(N, block_N) and ceildiv(M, block_M) and uses `threads` threads per block with `num_stages` pipeline stages.
-
+
     Parameters that alter kernel layout/behavior (brief):
     - block_M, block_N, block_K: tile sizes for M, N, and K dimensions.
     - num_stages: number of software pipeline stages for the K-loop.
     - threads: number of threads used per kernel block.
     - split: extra K-splitting factor; K must be divisible by block_K * split.
     - source_format, num_bits: describe the quantized data layout passed to the mxfp intrinsics.
-
+
     Returns:
     A TileLang/TIR prim_func (the compiled `main`) implementing the described dequantize-then-GEMM kernel.
     """
@@ -124,12 +124,12 @@ def matmul(M,
 def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype="bfloat16"):
     """
     Create a TileLang macro that performs fast, twiddling-based dequantization from packed FP4 to BF16 using an external runtime plugin.
-
+
     This function validates the requested input/output datatypes and returns a TileLang `@T.macro` named `fast_dequant_bf16_fp4_twiddling` which:
     - Loads compressed FP4 bytes from a shared buffer into per-thread local registers (vectorized loads).
     - Invokes an external dequantization routine (via `T.call_extern`) to expand the packed FP4 values into BF16 in registers.
     - Writes the dequantized BF16 values back to a shared dequantized buffer for use by the kernel.
-
+
     Notes and preconditions:
     - Asserts that `in_dtype == "fp4"` and `out_dtype == "bfloat16"`.
     - The generated macro depends on several surrounding-scope symbols (e.g., `import_source`, `func_name`, `block_K`, `Block_QK`, `threads`, `num_elems_per_byte`, `storage_dtype`, and `out_dtype`) and expects them to be defined consistently in the enclosing kernel.
@@ -149,17 +149,17 @@ def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared):
     # import fast_dequantize plugin
     """
     Fast dequantization kernel routine that converts packed FP4 values in shared memory to BF16 and writes the results back into a shared dequantized buffer.
-
+
     This function is intended to run inside a tiled GPU kernel: each thread loads a small packed segment from the quantized shared buffer `B_shared` into a per-thread local register buffer, calls an external dequantization routine (provided by the runtime plugin imported from `import_source` and identified by `func_name`) to expand the packed values to BF16 in a per-thread local output buffer, and stores the expanded values into `B_dequantize_shared`. It performs vectorized per-thread loads and stores and is sized according to the surrounding kernel's tiling and threading parameters.
-
+
     Parameters:
     B_shared: Shared-memory buffer containing packed quantized values (packed FP4 layout).
     B_dequantize_shared: Shared-memory buffer to receive dequantized BF16 values (written in-place by this routine).
-
+
     Side effects:
     - Imports the external dequantization plugin via `import_source` and invokes `func_name`.
     - Writes dequantized BF16 results into `B_dequantize_shared`.
-
+
     Notes:
     - This routine expects the surrounding kernel to define and provide the tiling/threading constants (e.g., thread count, local buffer sizes, block dimensions) and the runtime plugin identifiers (`import_source`, `func_name`).
     - No value is returned; results are produced by mutation of `B_dequantize_shared`.
@@ -197,18 +197,18 @@ def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared):
 def get_simple_dequant_func(in_dtype="fp4", out_dtype="bfloat16"):
     """
     Create a simple TIR dequantization macro that converts packed 4-bit FP (FP4) stored in uint8 into bfloat16.
-
+
     The returned macro (named `simple_dequant_bf16_fp4`) expects B_shared and B_dequantize_shared buffers (shapes and a few loop/constant names like
     `B_shared_shape`, `B_dequantize_shared_shape`, `storage_dtype`, `out_dtype`, `num_bits`, `num_elems_per_byte`, `block_N`, and `block_K`) to be available in the surrounding TIR scope. It:
     - Unpacks 4-bit FP values from the packed uint8 representation in B_shared.
     - Converts each 4-bit value to a bfloat16 element using an internal helper `_tir_u8_to_f4_to_bf16`.
     - Writes the dequantized bfloat16 block into B_dequantize_shared.
-
+
     Constraints:
     - Supports only in_dtype="fp4" and out_dtype="bfloat16".
     - The helper assumes nbit == 4 and produces bfloat16 values.
     - The macro uses a fixed test-scale of 0 (no per-element scaling) as written.
-
+
     Returns:
     A TIR macro function performing the described in-place block dequantization from packed uint8 FP4 to bfloat16.
     """
@@ -219,22 +219,22 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr,
                           scale: tir.PrimExpr, dtype: str):
     """
     Convert a 4-bit FP4 value packed in a uint8 byte into a bfloat16 value.
-
+
     This helper extracts the 4-bit field located at the bit position `pos` within the
     byte `val`, interprets it as an FP4 (sign, exponent, mantissa) value, applies an
     exponent `scale` offset to align it with bfloat16 exponent bias, clamps the
     resulting exponent to 8 bits, and returns the assembled bfloat16 bit pattern.
-
+
     Parameters:
     nbit (int): Number of bits in the packed element; must be 4.
     val (tir.PrimExpr): A uint8 value containing packed FP4 elements.
     pos (tir.PrimExpr): Index (0-based) of which FP4 nibble inside `val` to extract.
     scale (tir.PrimExpr): Exponent offset applied when converting FP4 exponent to bfloat16.
     dtype (str): Target dtype string; must be "bfloat16".
-
+
     Returns:
     tir.PrimExpr: A bfloat16-typed PrimExpr containing the converted value.
-
+
     Notes:
     - The function asserts `nbit == 4`, `dtype == "bfloat16"`, and that `val.dtype` is "uint8".
     - The conversion uses a fixed mapping from FP4 exponent/mantissa layout into bfloat16
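
The diff above truncates the docstring's description of the bit mapping. As a rough illustration of the kind of conversion it describes, here is a self-contained Python sketch assuming an e2m1 FP4 layout (1 sign, 2 exponent, 1 mantissa bit) with exponent bias 1; the actual constants in `_tir_u8_to_f4_to_bf16` may differ:

import numpy as np

def u8_to_f4_to_bf16_bits(val, pos, scale=0):
    # Extract the FP4 nibble at position `pos` (0 or 1) from the byte `val`,
    # then assemble a bfloat16 bit pattern: sign(1) | exponent(8) | mantissa(7).
    f4 = (val >> (4 * pos)) & 0xF
    s = (f4 >> 3) & 0x1   # sign
    e = (f4 >> 1) & 0x3   # 2-bit exponent (assumed e2m1 layout)
    m = f4 & 0x1          # 1-bit mantissa
    # Re-bias: assumed FP4 bias 1 -> bfloat16 bias 127; `scale` offsets the exponent.
    # e == 0 (zero/subnormal) is flushed to zero here for brevity.
    e_bf16 = ((e - 1 + 127 + scale) & 0xFF) if e != 0 else 0
    return (s << 15) | (e_bf16 << 7) | (m << 6)

bits = u8_to_f4_to_bf16_bits(0b0011, pos=0)  # e2m1 nibble 0b0011 encodes +1.5
f32 = np.array([bits << 16], dtype=np.uint32).view(np.float32)[0]
print(f32)  # 1.5 (bfloat16 is the top 16 bits of a float32)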
@@ -262,16 +262,16 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr,
 def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared):
     """
     Dequantize a packed FP4 uint8 shared buffer into BF16 and store the result into a shared dequantized buffer.
-
+
     This helper:
     - Loads B_shared into a local fragment, converts each packed FP4 element to BF16 using `_tir_u8_to_f4_to_bf16`, and writes the dequantized values into B_dequantize_shared.
     - Iterates in parallel over the logical block columns (block_N) and block_K, unpacking elements from bytes using `num_elems_per_byte`.
     - Uses a fixed scale of 0 in the conversion (placeholder for testing); `num_bits` and `num_elems_per_byte` are expected to be available from the enclosing scope.
-
+
     Parameters:
     B_shared: shared-memory buffer containing packed FP4 data (uint8-packed).
     B_dequantize_shared: shared-memory buffer to receive BF16 dequantized values.
-
+
     Side effects:
     Writes dequantized BF16 values into B_dequantize_shared. No return value.
     """
@@ -298,7 +298,7 @@ def main(
 ):
     """
     Kernel entry for the tiled, pipelined matmul used by the generated prim_func.
-
+
     This function implements a block-wise GEMM over a 2D grid (grid dims: ceildiv(N, block_N) x ceildiv(M, block_M)) with a thread block of `threads`. For each output block it:
     - Allocates shared buffers for A, the packed/quantized B, and a dequantized B tile.
     - Allocates a fragment accumulator (C_local) and a shared output tile (C_shared) with a swizzled layout.
@@ -307,16 +307,16 @@ def main(
     - Dequantizes B into B_dequantize_shared using either the fast (twiddling/external) or the simple (pure-TIR) dequantization routine.
     - Performs a GEMM accumulating into C_local with B transposed.
     - Stores the accumulated block from C_local back to the global output C via C_shared.
-
+
     Parameters:
     - A: input tile of shape (M, K) with dtype `in_dtype`.
     - B: packed/quantized input of shape (N, QK) with storage dtype `storage_dtype` (quantized FP4 packing).
     - C: output tensor of shape (M, N) with dtype `out_dtype`.
-
+
     Side effects:
     - Writes the computed output block into the global tensor `C`.
     - Uses and updates shared memory buffers and per-thread accumulators.
-
+
     No value is returned.
     """
     with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
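
The block structure described above can be summarized in plain Python (a NumPy sketch of the loop nest only; no shared memory, swizzling, dequantization, or pipelining):

import numpy as np

def blocked_gemm_bt(A, B, block_M=64, block_N=64, block_K=64):
    # C = A @ B^T computed block by block, mirroring the kernel's 2D grid
    # (ceildiv(N, block_N) x ceildiv(M, block_M)) and its K loop.
    M, K = A.shape
    N, _ = B.shape
    C = np.zeros((M, N), dtype=np.float32)
    for bn in range(0, N, block_N):           # grid dimension bx
        for bm in range(0, M, block_M):       # grid dimension by
            acc = np.zeros((block_M, block_N), dtype=np.float32)  # C_local
            for bk in range(0, K, block_K):   # pipelined K loop in the kernel
                acc += A[bm:bm + block_M, bk:bk + block_K] @ B[bn:bn + block_N, bk:bk + block_K].T
            C[bm:bm + block_M, bn:bn + block_N] = acc             # store via C_shared
    return C

A = np.random.randn(256, 256).astype(np.float32)
B = np.random.randn(256, 256).astype(np.float32)
assert np.allclose(blocked_gemm_bt(A, B), A @ B.T, atol=1e-3)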
@@ -352,14 +352,14 @@ def main(
 def ref_program_twiddling(A, qB):
     """
     Compute reference BF16 matrix multiply using bit-twiddled FP4 quantized B.
-
+
     Converts qB (a bit-twiddled, packed FP4 representation of matrix B) back to floating,
     performs C = A @ B^T in full precision, and returns the result converted to bfloat16.
-
+
     Parameters:
     A (torch.Tensor): Left operand with shape (M, K). Treated as floating-point (converted to torch.float for compute).
     qB (torch.Tensor): Bit-twiddled, packed FP4 representation of B (quantized). Shape corresponds to B's packed layout.
-
+
     Returns:
     torch.Tensor: Result matrix C with shape (M, N) in bfloat16.
     """
@@ -373,13 +373,13 @@ def ref_program_twiddling(A, qB):
 def ref_program_simple(A, qB):
     """
     Compute a reference BF16 matrix multiply using a simple (non-twiddled) dequantization of qB.
-
+
     Converts the quantized tensor `qB` to full-precision values via `torch_convert`, computes C = A @ B^T in float32, and casts the result to bfloat16 before returning.
-
+
     Parameters:
     A (torch.Tensor): Left input matrix with shape (M, K).
     qB (torch.Tensor): Quantized representation of the right matrix; expected to be compatible with `torch_convert` and represent a matrix whose transpose will be multiplied by A.
-
+
     Returns:
     torch.Tensor: Resulting matrix C in bfloat16 with shape (M, N).
     """
@@ -393,16 +393,16 @@ def ref_program_simple(A, qB):
 def main(m=256, n=256, k=256, fast_dequant=True, tune=False):
     """
     Run and benchmark the tiled, optionally autotuned FP4->BF16 GEMM kernel and validate results against a PyTorch reference.
-
+
     This function builds a matmul kernel (either with autotuning or fixed tiling), obtains a profiler, validates numerical correctness against the appropriate reference implementation (bit-twiddled fast dequantization or simple dequantization), and runs a benchmark that prints measured latency (ms) and effective TFLOPs.
-
+
     Parameters:
     m (int): Number of rows of A and output C (default 256).
     n (int): Number of columns of B and output C (default 256).
     k (int): Inner dimension (columns of A, rows of B) (default 256).
     fast_dequant (bool): If True use the fast twiddling dequantization path and validate against the twiddling reference; otherwise use the simple dequant path (default True).
     tune (bool): If True build the kernel with autotuning configurations; if False use a fixed tiling and threading configuration for reproducible benchmarking (default False).
-
+
     Side effects:
     - Prints latency and TFLOPs to stdout.
     - Raises an assertion via the profiler if the kernel's outputs do not match the chosen reference within the tolerances (rtol=0.01, atol=0.01).
