|
| 1 | +import tilelang |
| 2 | +from tilelang import language as T |
| 3 | +from typing import Optional, Callable, Any |
| 4 | +import torch |
| 5 | +from tilelang import DataType |
| 6 | +from tilelang.quantize import ( |
| 7 | + _tir_packed_int_to_int_convert,) |
| 8 | + |
| 9 | + |
| 10 | +def dequantize_gemv( |
| 11 | + M: int, |
| 12 | + N: int, |
| 13 | + K: int, |
| 14 | + in_dtype: str, |
| 15 | + out_dtype: str, |
| 16 | + accum_dtype: str, |
| 17 | + num_bits: int = 4, |
| 18 | + storage_dtype: str = "int8", |
| 19 | + source_format: str = "uint", |
| 20 | + n_partition: int = 4, |
| 21 | + reduce_thread: int = 32, |
| 22 | + fast_decoding: bool = False, |
| 23 | + trans_A: bool = False, |
| 24 | + trans_B: bool = True, |
| 25 | + group_size: int = -1, |
| 26 | + with_scaling: bool = False, |
| 27 | +) -> Callable[..., Any]: |
| 28 | + |
| 29 | + assert n_partition is not None, "n_partition must be provided" |
| 30 | + assert reduce_thread is not None, ( |
| 31 | + "reduce_thread must be provided currently, as related bitblas.gpu.gemv.GEMV" |
| 32 | + "sch_outer_reduction_with_config is not implemented") |
| 33 | + |
| 34 | + assert trans_A is False, "Dequantize only implement for trans_A=False currently" |
| 35 | + assert trans_B is True, "Dequantize only implement for trans_B=TRue currently" |
| 36 | + storage_type = "".join(c for c in storage_dtype if not c.isdigit()) |
| 37 | + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) |
| 38 | + num_elems_per_byte = storage_nbit // num_bits |
| 39 | + |
| 40 | + MAX_TRANSACTION_SIZE_IN_BITS = 128 |
| 41 | + micro_size_k = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits |
| 42 | + micro_size_k_compressed = micro_size_k // num_elems_per_byte |
| 43 | + block_K = reduce_thread * micro_size_k |
| 44 | + |
| 45 | + if group_size == -1: |
| 46 | + group_size = K |
| 47 | + |
| 48 | + A_shape = (M, K) |
| 49 | + B_shape = (N, K // storage_nbit * num_bits) |
| 50 | + C_shape = (M, N) |
| 51 | + |
| 52 | + dp4a_size = 4 |
| 53 | + use_dp4a = in_dtype == "int8" and accum_dtype == "int32" |
| 54 | + |
| 55 | + import_source: Optional[str] = None |
| 56 | + func_name: str = "" |
| 57 | + if fast_decoding is True: |
| 58 | + # Lazy import to decrease the startup time |
| 59 | + # as intrin registry may take a while to load |
| 60 | + from tilelang.quantize import get_lop3_intrin_group |
| 61 | + |
| 62 | + lop3_intrin_info = get_lop3_intrin_group( |
| 63 | + out_dtype=in_dtype, |
| 64 | + source_format=source_format, |
| 65 | + source_bit=num_bits, |
| 66 | + storage_dtype=storage_dtype, |
| 67 | + with_scaling=with_scaling, |
| 68 | + with_zeros=False, |
| 69 | + ) |
| 70 | + import_source = lop3_intrin_info["c_source"] |
| 71 | + func_name = lop3_intrin_info["func_name"] |
| 72 | + assert import_source is not None, "lop3_intrin_info is not found" |
| 73 | + assert func_name is not None, "lop3_intrin_info is not found" |
| 74 | + import_source = import_source |
| 75 | + |
| 76 | + @T.prim_func |
| 77 | + def main( |
| 78 | + A: T.Tensor[A_shape, in_dtype], |
| 79 | + B: T.Tensor[B_shape, storage_dtype], |
| 80 | + C: T.Tensor[C_shape, out_dtype], |
| 81 | + ): |
| 82 | + with T.Kernel( |
| 83 | + T.ceildiv(N, n_partition), |
| 84 | + M, |
| 85 | + threads=(reduce_thread, n_partition), |
| 86 | + ) as ( |
| 87 | + bx, |
| 88 | + by, |
| 89 | + ): |
| 90 | + A_local = T.alloc_local((micro_size_k,), in_dtype) |
| 91 | + B_quant_local = T.alloc_local([micro_size_k_compressed], storage_dtype) |
| 92 | + B_dequantize_local = T.alloc_local([micro_size_k], in_dtype) |
| 93 | + accum_res = T.alloc_local((1,), accum_dtype) |
| 94 | + reduced_accum_res = T.alloc_local((1,), accum_dtype) |
| 95 | + |
| 96 | + kr = T.thread_binding(0, reduce_thread, thread="threadIdx.x") |
| 97 | + ni = T.thread_binding(0, n_partition, thread="threadIdx.y") |
| 98 | + |
| 99 | + T.import_source(import_source) |
| 100 | + |
| 101 | + T.clear(accum_res) |
| 102 | + for ko in T.serial(T.ceildiv(K, block_K)): |
| 103 | + for v in T.vectorized(micro_size_k): |
| 104 | + A_local[v] = A[by, ko * block_K + kr * micro_size_k + v] |
| 105 | + |
| 106 | + for v in T.vectorized(micro_size_k_compressed): |
| 107 | + B_quant_local[v] = B[ |
| 108 | + bx * n_partition + ni, |
| 109 | + ko * (reduce_thread * micro_size_k_compressed) + |
| 110 | + kr * micro_size_k_compressed + v, |
| 111 | + ] |
| 112 | + |
| 113 | + if fast_decoding: |
| 114 | + T.call_extern( |
| 115 | + func_name, |
| 116 | + T.address_of(B_quant_local[0]), |
| 117 | + T.address_of(B_dequantize_local[0]), |
| 118 | + dtype=in_dtype, |
| 119 | + ) |
| 120 | + else: |
| 121 | + for ki in T.serial(micro_size_k): |
| 122 | + B_dequantize_local[ki] = _tir_packed_int_to_int_convert( |
| 123 | + storage_type, |
| 124 | + storage_nbit)(num_bits, B_quant_local[ki // num_elems_per_byte], |
| 125 | + ki % num_elems_per_byte, in_dtype) |
| 126 | + |
| 127 | + if use_dp4a: |
| 128 | + for ki in T.serial(micro_size_k // dp4a_size): |
| 129 | + T.dp4a( |
| 130 | + A_local[ki * dp4a_size], |
| 131 | + B_dequantize_local[ki * dp4a_size], |
| 132 | + accum_res[0], |
| 133 | + ) |
| 134 | + else: |
| 135 | + for ki in T.serial(micro_size_k): |
| 136 | + accum_res[0] += A_local[ki] * B_dequantize_local[ki] |
| 137 | + |
| 138 | + with T.attr( |
| 139 | + T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), |
| 140 | + "reduce_scope", |
| 141 | + T.reinterpret(T.uint64(0), dtype="handle"), |
| 142 | + ): |
| 143 | + T.evaluate( |
| 144 | + T.tvm_thread_allreduce( |
| 145 | + T.uint32(1), |
| 146 | + accum_res[0], |
| 147 | + True, |
| 148 | + reduced_accum_res[0], |
| 149 | + kr, |
| 150 | + dtype="handle", |
| 151 | + )) |
| 152 | + if kr == 0: |
| 153 | + C[by, bx * n_partition + ni] = reduced_accum_res[0] |
| 154 | + |
| 155 | + return main |
| 156 | + |
| 157 | + |
| 158 | +def main() -> None: |
| 159 | + M = 1 |
| 160 | + N = 1024 |
| 161 | + K = 1024 |
| 162 | + in_dtype = "float16" |
| 163 | + out_dtype = "float16" |
| 164 | + accum_dtype = "float16" |
| 165 | + num_bits = 4 |
| 166 | + storage_dtype = "int8" |
| 167 | + source_format = "uint" |
| 168 | + n_partition = 4 |
| 169 | + reduce_thread = 32 |
| 170 | + fast_decoding = True |
| 171 | + trans_A = False |
| 172 | + trans_B = True |
| 173 | + group_size = -1 |
| 174 | + with_scaling = False |
| 175 | + |
| 176 | + program = dequantize_gemv(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits, storage_dtype, |
| 177 | + source_format, n_partition, reduce_thread, fast_decoding, trans_A, |
| 178 | + trans_B, group_size, with_scaling) |
| 179 | + |
| 180 | + kernel = tilelang.compile(program) |
| 181 | + |
| 182 | + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) |
| 183 | + num_elems_per_byte = storage_nbit // num_bits |
| 184 | + A = torch.rand(M, K, dtype=getattr(torch, in_dtype)).cuda() |
| 185 | + qB = torch.randint( |
| 186 | + 0, 127, (N, K // num_elems_per_byte), dtype=getattr(torch, storage_dtype)).cuda() |
| 187 | + C = torch.zeros(M, N, dtype=getattr(torch, accum_dtype)).cuda() |
| 188 | + |
| 189 | + if fast_decoding: |
| 190 | + from tilelang.quantize.utils import interleave_weight |
| 191 | + qB = interleave_weight(qB, num_bits, in_dtype) |
| 192 | + kernel(A, qB, C) |
| 193 | + |
| 194 | + # int4 reference |
| 195 | + B = ( |
| 196 | + torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, |
| 197 | + dtype=torch.half).to(torch.half).to(A.device)) |
| 198 | + for j in range(B.shape[1]): |
| 199 | + B[:, j] = ((qB[:, j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half) |
| 200 | + |
| 201 | + # Get Reference Result |
| 202 | + ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype)) |
| 203 | + print("C: ", C) |
| 204 | + print("Ref C: ", ref_c) |
| 205 | + # doesn't apply scaling, the absolute error is large |
| 206 | + torch.testing.assert_close(C, ref_c, atol=1e3, rtol=1e-1) |
| 207 | + |
| 208 | + |
| 209 | +if __name__ == "__main__": |
| 210 | + main() |
0 commit comments