Create a TVM GPU prim_func implementing a block-tiled matrix multiply that multiplies dense A by compressed/interleaved low-precision B (2-bit values packed into int8 storage), decoding B to int8 on-chip and accumulating into C.

The returned prim_func expects:
- A: shape (M, K) with dtype `in_dtype` ("float16" or "int8").
- B: compressed storage with shape (N, K/4) and an int8 storage layout (4 2-bit elements packed per byte; see the sketch after this list).
- C: output buffer of shape (M, N) with dtype `out_dtype` ("float16", "float32", or "int32").
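A minimal NumPy sketch of the 2-bit packing convention described for B (4 elements per int8 byte; the low-bits-first field order is an assumption for illustration, and the actual kernel uses an interleaved layout decoded on device):

```python
import numpy as np

def pack_2bit(vals: np.ndarray) -> np.ndarray:
    """Pack unsigned 2-bit values (0..3) into int8 storage, 4 per byte."""
    vals = vals.reshape(-1, 4)
    packed = vals[:, 0] | (vals[:, 1] << 2) | (vals[:, 2] << 4) | (vals[:, 3] << 6)
    return packed.astype(np.int8)

def unpack_2bit(packed: np.ndarray) -> np.ndarray:
    """Inverse of pack_2bit: expand each byte back into 4 small int8 elements."""
    b = packed.astype(np.uint8)
    return np.stack([(b >> s) & 0x3 for s in (0, 2, 4, 6)], axis=-1).reshape(-1).astype(np.int8)

vals = np.array([0, 1, 2, 3, 3, 2, 1, 0], dtype=np.uint8)
assert np.array_equal(unpack_2bit(pack_2bit(vals)), vals)  # round-trip check
```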
Details:
- Builds a tiled, pipelined kernel using shared memory and warp-level MMA intrinsics (INT4TensorCoreIntrinEmitter). B is loaded from compressed storage, decoded to int8 per thread (via decode_i2u_to_i8s / decode_i2s_to_i8s), and dequantized into a shared buffer consumed by the MMA emitter.

warp_row_tiles (int): Tiles per warp in the row dimension.
warp_col_tiles (int): Tiles per warp in the column dimension.
chunk (int): K-length per block (block_K).

Returns:
T.prim_func: A TVM prim_func implementing the described GPU kernel, suitable for compilation and execution.
"""
@@ -187,18 +187,18 @@ def main(
):
"""
GPU kernel entry that performs a blocked, pipelined matrix multiplication A @ B.T, writing into C.

This kernel:
- Loads tiles of A and a compressed/interleaved representation of B from global memory into shared memory.
- Decodes B's packed low-precision format (storage_dtype, e.g., 2-bit packed) into element values of `in_dtype` in shared memory via an external decode routine.
- Uses warp/MMA tiled fragments and an INT4/INT2-capable MMA emitter to accumulate across K in a pipelined fashion with a configurable number of stages.
- Writes accumulated tile results from shared memory back to global C with the expected block/micro-tile indexing.

Parameters:
A: Input matrix buffer of shape A_shape and element type `in_dtype`. Represents the MxK activations.
B: Compressed/interleaved weight buffer of shape B_shape and storage type `storage_dtype`. Must contain B in the packed low-precision layout expected by the decode routine used by this kernel.
C: Output buffer of shape (M, N) and type `out_dtype`; receives the resulting matrix (accumulated values are produced in `accum_dtype` and stored into C).

Side effects:
Writes results into C. Calls external device decode functions to expand B from its packed representation into shared memory before computation.
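The source only names the decode intrinsics; for the signed variant (decode_i2s_to_i8s), each 2-bit field would be sign-extended into int8. A host-side sketch of that mapping, under the same assumed field layout as the packing sketch above:

```python
def sign_extend_2bit(v: int) -> int:
    """Interpret a 2-bit field as two's complement: 0, 1, 2, 3 -> 0, 1, -2, -1."""
    return v - 4 if v & 0x2 else v
```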
examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py (37 additions, 37 deletions)
@@ -10,15 +10,15 @@
def get_configs():
"""
Return a list of tuning configuration dictionaries for the autotuned matmul kernel.

Each dictionary is a single combination (from the Cartesian product) of the following parameters:
- block_M: tile size for the M dimension (one of 64, 128, 256)
- block_N: tile size for the N dimension (one of 64, 128, 256)
- block_K: tile size for the K dimension
- num_stages: pipeline stages for the K-loop (0 or 2)
- threads: number of threads to launch (128, 256, or 512)
- split: K-splitting factor (1 or 2)

Returns:
list[dict]: List of configuration dicts usable by the autotuner, where each dict maps
the parameter name to its chosen value.
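A minimal sketch of how such a config list could be generated from the value sets listed above; the real get_configs() may differ in names and in the block_K values, which the docstring does not enumerate:

```python
from itertools import product

def get_configs_sketch():
    space = {
        "block_M": [64, 128, 256],
        "block_N": [64, 128, 256],
        "block_K": [128],          # assumed; the docstring does not list block_K's values
        "num_stages": [0, 2],
        "threads": [128, 256, 512],
        "split": [1, 2],
    }
    keys = list(space.keys())
    return [dict(zip(keys, combo)) for combo in product(*space.values())]
```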
@@ -62,30 +62,30 @@ def matmul(M,
split=1):
"""
Builds a parameterized TileLang/TIR matrix-multiplication kernel that dequantizes 4-bit FP inputs to BF16 on the fly and computes C = A @ B^T.

This function returns a tiled, autotunable prim_func implementing a block-wise GEMM with shared-memory buffering and a pipelined K-loop. The kernel accepts:
- A: dense input of shape (M, K) with dtype `in_dtype`.
- B: packed quantized input of shape (N, QK), where QK = K / (8 / num_bits), stored as `uint8`.
- C: output of shape (M, N) with dtype `out_dtype`.

The generated kernel supports two dequantization paths:
- fast dequant (fast_dequant=True): calls an external mxfp dequantization intrinsic (twiddling-based) loaded from a C source returned by get_mxfp_intrin_group.
- simple dequant (fast_dequant=False): performs a pure-TIR FP4 -> BF16 conversion per element.

Important behavior and requirements:
- num_bits (default 4) is the bit-width of the quantized elements; storage_dtype is uint8 and num_elems_per_byte = 8 // num_bits.
- QK = K // num_elems_per_byte and Block_QK = block_K // num_elems_per_byte determine the B and shared-buffer shapes.
- Asserts that K % (block_K * split) == 0; K must be divisible by block_K * split for the tiling to be valid.
- When fast_dequant is True, a valid mxfp intrinsic group (C source and function name) must be available via tilelang.quantize.get_mxfp_intrin_group.
- The kernel launches a 2D grid over ceildiv(N, block_N) and ceildiv(M, block_M) and uses `threads` threads per block with `num_stages` pipeline stages.

Parameters that alter kernel layout/behavior (brief):
- block_M, block_N, block_K: tile sizes for the M, N, and K dimensions.
- num_stages: number of software pipeline stages for the K-loop.
- threads: number of threads used per kernel block.
- split: extra K-splitting factor; K must be divisible by block_K * split.
- source_format, num_bits: describe the quantized data layout passed to the mxfp intrinsics.

Returns:
A TileLang/TIR prim_func (the compiled `main`) implementing the described dequantize-then-GEMM kernel.
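A quick sanity check of the packing arithmetic described above; these are plain-Python restatements of the documented relations, with block_K chosen arbitrarily for illustration:

```python
M, N, K = 256, 256, 256
num_bits = 4
num_elems_per_byte = 8 // num_bits        # 2 FP4 values per uint8 byte
QK = K // num_elems_per_byte              # packed-B columns: 128
block_K, split = 128, 1
assert K % (block_K * split) == 0         # tiling validity, as asserted by matmul()

print((M, K), (N, QK))                    # A is (256, 256); packed B is (256, 128)
```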
Create a TileLang macro that performs fast, twiddling-based dequantization from packed FP4 to BF16 using an external runtime plugin.

This function validates the requested input/output datatypes and returns a TileLang `@T.macro` named `fast_dequant_bf16_fp4_twiddling` which:
- Loads compressed FP4 bytes from a shared buffer into per-thread local registers (vectorized loads).
- Invokes an external dequantization routine (via `T.call_extern`) to expand the packed FP4 values into BF16 in registers.
- Writes the dequantized BF16 values back to a shared dequantized buffer for use by the kernel.

Notes and preconditions:
- Asserts that `in_dtype == "fp4"` and `out_dtype == "bfloat16"`.
- The generated macro depends on several surrounding-scope symbols (e.g., `import_source`, `func_name`, `block_K`, `Block_QK`, `threads`, `num_elems_per_byte`, `storage_dtype`, and `out_dtype`) and expects them to be defined consistently in the enclosing kernel.
Fast dequantization kernel routine that converts packed FP4 values in shared memory to BF16 and writes the results back into a shared dequantized buffer.

This function is intended to run inside a tiled GPU kernel: each thread loads a small packed segment from the quantized shared buffer `B_shared` into a per-thread local register buffer, calls an external dequantization routine (provided by the runtime plugin imported from `import_source` and identified by `func_name`) to expand the packed values to BF16 in a per-thread local output buffer, and stores the expanded values into `B_dequantize_shared`. It performs vectorized per-thread loads and stores and is sized according to the surrounding kernel's tiling and threading parameters.

B_dequantize_shared: Shared-memory buffer that receives the dequantized BF16 values (written in place by this routine).

Side effects:
- Imports the external dequantization plugin via `import_source` and invokes `func_name`.
- Writes dequantized BF16 results into `B_dequantize_shared`.

Notes:
- This routine expects the surrounding kernel to define and provide the tiling/threading constants (e.g., thread count, local buffer sizes, block dimensions) and the runtime plugin identifiers (`import_source`, `func_name`).
- No value is returned; results are produced by mutation of `B_dequantize_shared`.
Create a simple TIR dequantization macro that converts packed 4-bit FP (FP4) stored in uint8 into bfloat16.

The returned macro (named `simple_dequant_bf16_fp4`) expects the B_shared and B_dequantize_shared buffers (plus shape/constant names such as `B_shared_shape`, `B_dequantize_shared_shape`, `storage_dtype`, `out_dtype`, `num_bits`, `num_elems_per_byte`, `block_N`, and `block_K`) to be available in the surrounding TIR scope. It:
- Unpacks 4-bit FP values from the packed uint8 representation in B_shared.
- Converts each 4-bit value to a bfloat16 element using the internal helper `_tir_u8_to_f4_to_bf16`.
- Writes the dequantized bfloat16 block into B_dequantize_shared.

Constraints:
- Supports only in_dtype="fp4" and out_dtype="bfloat16".
- The helper assumes nbit == 4 and produces bfloat16 values.
- The macro uses a fixed test scale of 0 (no per-element scaling) as written.

Returns:
A TIR macro function performing the described in-place block dequantization from packed uint8 FP4 to bfloat16.
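A host-side sketch of the per-element conversion the macro performs, assuming the common e2m1 FP4 layout (1 sign, 2 exponent, 1 mantissa bit) and two nibbles per byte; the real `_tir_u8_to_f4_to_bf16` works on bit patterns in TIR, but the mapping it realizes should be equivalent to a lookup like this:

```python
import numpy as np

# e2m1 magnitudes indexed by the low 3 bits (exponent << 1 | mantissa).
FP4_E2M1 = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=np.float32)

def dequant_fp4_bytes(packed: np.ndarray) -> np.ndarray:
    """Expand a uint8 array of packed FP4 pairs into float32 (BF16-representable)."""
    lo = packed & 0x0F                      # first FP4 element of each byte (assumed order)
    hi = (packed >> 4) & 0x0F               # second FP4 element
    nibbles = np.stack([lo, hi], axis=-1).reshape(-1)
    sign = np.where(nibbles & 0x8, -1.0, 1.0).astype(np.float32)
    return sign * FP4_E2M1[nibbles & 0x7]

print(dequant_fp4_bytes(np.array([0x31, 0xB4], dtype=np.uint8)))  # [0.5 1.5 2. -1.5]
```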
Dequantize a packed FP4 uint8 shared buffer into BF16 and store the result into a shared dequantized buffer.

This helper:
- Loads B_shared into a local fragment, converts each packed FP4 element to BF16 using `_tir_u8_to_f4_to_bf16`, and writes the dequantized values into B_dequantize_shared.
- Iterates in parallel over the logical block columns (block_N) and block_K, unpacking elements from bytes using `num_elems_per_byte`.
- Uses a fixed scale of 0 in the conversion (a placeholder for testing); `num_bits` and `num_elems_per_byte` are expected to be available from the enclosing scope.

Parameters:
B_shared: shared-memory buffer containing packed FP4 data (uint8-packed).
B_dequantize_shared: shared-memory buffer that receives the BF16 dequantized values.

Side effects:
Writes dequantized BF16 values into B_dequantize_shared. No return value.
"""
@@ -298,7 +298,7 @@ def main(
):
"""
Kernel entry for the tiled, pipelined matmul used by the generated prim_func.

This function implements a block-wise GEMM over a 2D grid (grid dims: ceildiv(N, block_N) x ceildiv(M, block_M)) with a thread block of `threads` threads. For each output block it:
- Allocates shared buffers for A, the packed/quantized B, and a dequantized B tile.
- Allocates a fragment accumulator (C_local) and a shared output tile (C_shared) with a swizzled layout.
@@ -307,16 +307,16 @@ def main(
- Dequantizes B into B_dequantize_shared using either the fast (twiddling/external) or the simple (pure-TIR) dequantization routine.
- Performs a GEMM accumulating into C_local with B transposed.
- Stores the accumulated block from C_local back to the global output C via C_shared.

Parameters:
- A: input tile of shape (M, K) with dtype `in_dtype`.
- B: packed/quantized input of shape (N, QK) with storage dtype `storage_dtype` (quantized FP4 packing).
- C: output tensor of shape (M, N) with dtype `out_dtype`.

Side effects:
- Writes the computed output block into the global tensor `C`.
- Uses and updates shared-memory buffers and per-thread accumulators.

No value is returned.
"""
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
@@ -352,14 +352,14 @@ def main(
def ref_program_twiddling(A, qB):
"""
Compute a reference BF16 matrix multiply using bit-twiddled FP4 quantized B.

Converts qB (a bit-twiddled, packed FP4 representation of matrix B) back to floating point, performs C = A @ B^T in full precision, and returns the result converted to bfloat16.

Parameters:
A (torch.Tensor): Left operand with shape (M, K). Treated as floating-point (converted to torch.float for compute).
qB (torch.Tensor): Bit-twiddled, packed FP4 representation of B (quantized). Shape corresponds to B's packed layout.

Returns:
torch.Tensor: Result matrix C with shape (M, N) in bfloat16.
Compute a reference BF16 matrix multiply using a simple (non-twiddled) dequantization of qB.

Converts the quantized tensor `qB` to full-precision values via `torch_convert`, computes C = A @ B^T in float32, and casts the result to bfloat16 before returning.

Parameters:
A (torch.Tensor): Left input matrix with shape (M, K).
qB (torch.Tensor): Quantized representation of the right matrix; expected to be compatible with `torch_convert` and to represent a matrix whose transpose is multiplied by A.

Returns:
torch.Tensor: Resulting matrix C in bfloat16 with shape (M, N).
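A minimal sketch of the reference computation both helpers share, assuming `torch_convert` (or the twiddling equivalent) has already expanded qB into a dense float tensor B:

```python
import torch

def ref_matmul_bf16(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # Accumulate in full precision, then cast once at the end,
    # mirroring the reference programs described above.
    return (A.float() @ B.float().T).to(torch.bfloat16)

# Usage with random data standing in for the dequantized B:
A = torch.randn(256, 256)
B = torch.randn(256, 256)
C = ref_matmul_bf16(A, B)   # (M, N) in bfloat16
```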
Run and benchmark the tiled, optionally autotuned FP4 -> BF16 GEMM kernel and validate the results against a PyTorch reference.

This function builds a matmul kernel (either with autotuning or a fixed tiling), obtains a profiler, validates numerical correctness against the appropriate reference implementation (bit-twiddled fast dequantization or simple dequantization), and runs a benchmark that prints the measured latency (ms) and effective TFLOPs.

Parameters:
m (int): Number of rows of A and of output C (default 256).
n (int): Number of columns of B and of output C (default 256).
k (int): Inner dimension (columns of A, rows of B) (default 256).
fast_dequant (bool): If True, use the fast twiddling dequantization path and validate against the twiddling reference; otherwise use the simple dequant path (default True).
tune (bool): If True, build the kernel with autotuning configurations; if False, use a fixed tiling and threading configuration for reproducible benchmarking (default False).

Side effects:
- Prints latency and TFLOPs to stdout.
- Raises an assertion via the profiler if the kernel's outputs do not match the chosen reference within the tolerances (rtol=0.01, atol=0.01).
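A minimal invocation sketch, assuming the benchmark entry point is named `main` (as is typical in these example scripts) with the defaults documented above:

```python
if __name__ == "__main__":
    # Validate and benchmark the fixed-tiling fast-dequant path at 256x256x256.
    main(m=256, n=256, k=256, fast_dequant=True, tune=False)
```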