diff --git a/examples/dequantize_gemm/example_dequant_gemm.py b/examples/dequantize_gemm/example_dequant_gemm.py deleted file mode 100644 index be440a92a..000000000 --- a/examples/dequantize_gemm/example_dequant_gemm.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (c) Tile-AI Corporation. -# Licensed under the MIT License. diff --git a/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py b/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py index 78ef0762e..8bdfe3bdb 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py +++ b/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py @@ -435,5 +435,10 @@ def test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(): 256, 1024, 512, "float16", "float16", "float16", 3) +def main(): + test_run_dequantize_gemm() + test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4() + + if __name__ == "__main__": - tilelang.testing.main() + main() diff --git a/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py b/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py index a7b352edd..aaafaf14b 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py +++ b/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py @@ -269,19 +269,12 @@ def ref_program(A, qB): return C.transpose(0, 1) -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--m', type=int, default=256, help='M') - parser.add_argument('--n', type=int, default=256, help='N') - parser.add_argument('--k', type=int, default=256, help='K') - parser.add_argument('--tune', action='store_true', help='tune configs') - args = parser.parse_args() - M, N, K = args.m, args.n, args.k - total_flops = 2 * M * N * K +def main(m=256, n=256, k=256, tune=False): + total_flops = 2 * m * n * k - if (not args.tune): + if (not tune): program = matmul( - M, N, K, "float16", "float16", "float32", num_bits=4, tune=args.tune)( + m, n, k, "float16", "float16", "float32", num_bits=4, tune=tune)( block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1) kernel = tilelang.compile(program, out_idx=[2]) profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer) @@ -294,10 +287,20 @@ def ref_program(A, qB): print("Tile-lang: {:.2f} ms".format(latency)) print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) else: - best_result = matmul(M, N, K, "float16", "float16", "float32", num_bits=4, tune=args.tune) + best_result = matmul(m, n, k, "float16", "float16", "float32", num_bits=4, tune=tune) best_latency = best_result.latency best_config = best_result.config - ref_latency = best_result.ref_latency print(f"Best latency: {best_latency}") print(f"Best TFlops: {total_flops / best_latency * 1e-9}") print(f"Best config: {best_config}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--m', type=int, default=256, help='M') + parser.add_argument('--n', type=int, default=256, help='N') + parser.add_argument('--k', type=int, default=256, help='K') + parser.add_argument('--tune', action='store_true', help='tune configs') + args = parser.parse_args() + M, N, K = args.m, args.n, args.k + main(M, N, K, args.tune) diff --git a/examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py b/examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py new file mode 100644 index 000000000..5e42869f2 --- /dev/null +++ b/examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py @@ -0,0 +1,210 @@ +import tilelang +from tilelang import 
language as T +from typing import Optional, Callable, Any +import torch +from tilelang import DataType +from tilelang.quantize import ( + _tir_packed_int_to_int_convert,) + + +def dequantize_gemv( + M: int, + N: int, + K: int, + in_dtype: str, + out_dtype: str, + accum_dtype: str, + num_bits: int = 4, + storage_dtype: str = "int8", + source_format: str = "uint", + n_partition: int = 4, + reduce_thread: int = 32, + fast_decoding: bool = False, + trans_A: bool = False, + trans_B: bool = True, + group_size: int = -1, + with_scaling: bool = False, +) -> Callable[..., Any]: + + assert n_partition is not None, "n_partition must be provided" + assert reduce_thread is not None, ( + "reduce_thread must be provided currently, as related bitblas.gpu.gemv.GEMV" + "sch_outer_reduction_with_config is not implemented") + + assert trans_A is False, "Dequantize only implement for trans_A=False currently" + assert trans_B is True, "Dequantize only implement for trans_B=TRue currently" + storage_type = "".join(c for c in storage_dtype if not c.isdigit()) + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + num_elems_per_byte = storage_nbit // num_bits + + MAX_TRANSACTION_SIZE_IN_BITS = 128 + micro_size_k = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits + micro_size_k_compressed = micro_size_k // num_elems_per_byte + block_K = reduce_thread * micro_size_k + + if group_size == -1: + group_size = K + + A_shape = (M, K) + B_shape = (N, K // storage_nbit * num_bits) + C_shape = (M, N) + + dp4a_size = 4 + use_dp4a = in_dtype == "int8" and accum_dtype == "int32" + + import_source: Optional[str] = None + func_name: str = "" + if fast_decoding is True: + # Lazy import to decrease the startup time + # as intrin registry may take a while to load + from tilelang.quantize import get_lop3_intrin_group + + lop3_intrin_info = get_lop3_intrin_group( + out_dtype=in_dtype, + source_format=source_format, + source_bit=num_bits, + storage_dtype=storage_dtype, + with_scaling=with_scaling, + with_zeros=False, + ) + import_source = lop3_intrin_info["c_source"] + func_name = lop3_intrin_info["func_name"] + assert import_source is not None, "lop3_intrin_info is not found" + assert func_name is not None, "lop3_intrin_info is not found" + import_source = import_source + + @T.prim_func + def main( + A: T.Tensor[A_shape, in_dtype], + B: T.Tensor[B_shape, storage_dtype], + C: T.Tensor[C_shape, out_dtype], + ): + with T.Kernel( + T.ceildiv(N, n_partition), + M, + threads=(reduce_thread, n_partition), + ) as ( + bx, + by, + ): + A_local = T.alloc_local((micro_size_k,), in_dtype) + B_quant_local = T.alloc_local([micro_size_k_compressed], storage_dtype) + B_dequantize_local = T.alloc_local([micro_size_k], in_dtype) + accum_res = T.alloc_local((1,), accum_dtype) + reduced_accum_res = T.alloc_local((1,), accum_dtype) + + kr = T.thread_binding(0, reduce_thread, thread="threadIdx.x") + ni = T.thread_binding(0, n_partition, thread="threadIdx.y") + + T.import_source(import_source) + + T.clear(accum_res) + for ko in T.serial(T.ceildiv(K, block_K)): + for v in T.vectorized(micro_size_k): + A_local[v] = A[by, ko * block_K + kr * micro_size_k + v] + + for v in T.vectorized(micro_size_k_compressed): + B_quant_local[v] = B[ + bx * n_partition + ni, + ko * (reduce_thread * micro_size_k_compressed) + + kr * micro_size_k_compressed + v, + ] + + if fast_decoding: + T.call_extern( + func_name, + T.address_of(B_quant_local[0]), + T.address_of(B_dequantize_local[0]), + dtype=in_dtype, + ) + else: + for ki in T.serial(micro_size_k): + 
B_dequantize_local[ki] = _tir_packed_int_to_int_convert( + storage_type, + storage_nbit)(num_bits, B_quant_local[ki // num_elems_per_byte], + ki % num_elems_per_byte, in_dtype) + + if use_dp4a: + for ki in T.serial(micro_size_k // dp4a_size): + T.dp4a( + A_local[ki * dp4a_size], + B_dequantize_local[ki * dp4a_size], + accum_res[0], + ) + else: + for ki in T.serial(micro_size_k): + accum_res[0] += A_local[ki] * B_dequantize_local[ki] + + with T.attr( + T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), + ): + T.evaluate( + T.tvm_thread_allreduce( + T.uint32(1), + accum_res[0], + True, + reduced_accum_res[0], + kr, + dtype="handle", + )) + if kr == 0: + C[by, bx * n_partition + ni] = reduced_accum_res[0] + + return main + + +def main() -> None: + M = 1 + N = 1024 + K = 1024 + in_dtype = "float16" + out_dtype = "float16" + accum_dtype = "float16" + num_bits = 4 + storage_dtype = "int8" + source_format = "uint" + n_partition = 4 + reduce_thread = 32 + fast_decoding = True + trans_A = False + trans_B = True + group_size = -1 + with_scaling = False + + program = dequantize_gemv(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits, storage_dtype, + source_format, n_partition, reduce_thread, fast_decoding, trans_A, + trans_B, group_size, with_scaling) + + kernel = tilelang.compile(program) + + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + num_elems_per_byte = storage_nbit // num_bits + A = torch.rand(M, K, dtype=getattr(torch, in_dtype)).cuda() + qB = torch.randint( + 0, 127, (N, K // num_elems_per_byte), dtype=getattr(torch, storage_dtype)).cuda() + C = torch.zeros(M, N, dtype=getattr(torch, accum_dtype)).cuda() + + if fast_decoding: + from tilelang.quantize.utils import interleave_weight + qB = interleave_weight(qB, num_bits, in_dtype) + kernel(A, qB, C) + + # int4 reference + B = ( + torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, + dtype=torch.half).to(torch.half).to(A.device)) + for j in range(B.shape[1]): + B[:, j] = ((qB[:, j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half) + + # Get Reference Result + ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype)) + print("C: ", C) + print("Ref C: ", ref_c) + # doesn't apply scaling, the absolute error is large + torch.testing.assert_close(C, ref_c, atol=1e3, rtol=1e-1) + + +if __name__ == "__main__": + main() diff --git a/examples/dequantize_gemm/test_example_dequantize_gemm.py b/examples/dequantize_gemm/test_example_dequantize_gemm.py new file mode 100644 index 000000000..d8ce85d61 --- /dev/null +++ b/examples/dequantize_gemm/test_example_dequantize_gemm.py @@ -0,0 +1,21 @@ +# Copyright (c) Tile-AI Corporation. +# Licensed under the MIT License. +import tilelang.testing + +import example_dequant_gemv_fp16xint4 +import example_dequant_gemm_fp4_hopper + + +@tilelang.testing.requires_cuda +def test_example_dequant_gemv_fp16xint4(): + example_dequant_gemv_fp16xint4.main() + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_example_dequant_gemm_fp4_hopper(): + example_dequant_gemm_fp4_hopper.main() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py b/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py new file mode 100644 index 000000000..01c15e268 --- /dev/null +++ b/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py @@ -0,0 +1,231 @@ +# Copyright (c) Tile-AI Corporation. 
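For readers of the dequantize GEMV example above: the reference check unpacks two uint4 values from each int8 storage byte, low nibble first. A minimal standalone sketch of that layout (the function name and tensor names here are illustrative, not part of this patch):

import torch

def unpack_uint4_reference(qB: torch.Tensor) -> torch.Tensor:
    # qB: (N, K // 2) int8 storage, two 4-bit values per byte; the low nibble holds the even k index.
    n_rows, k_half = qB.shape
    B = torch.empty(n_rows, k_half * 2, dtype=torch.half, device=qB.device)
    B[:, 0::2] = (qB & 0xF).to(torch.half)         # even columns come from the low nibble
    B[:, 1::2] = ((qB >> 4) & 0xF).to(torch.half)  # odd columns come from the high nibble
    return B

When fast_decoding is enabled, the packed tensor is additionally permuted with interleave_weight(qB, num_bits, in_dtype) before launch so the LOP3 decoder sees the element order it expects; the sketch above corresponds to the plain, non-interleaved layout used by the reference check.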
+# Licensed under the MIT License. + +import torch +import torch.nn.functional as F +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +import itertools +import argparse +from functools import partial + + +def get_configs(): + block_M = [128] + block_N = [128] + num_stages = [2] + threads = [256] + _configs = list(itertools.product(block_M, block_N, num_stages, threads)) + + configs = [{ + 'block_M': c[0], + 'block_N': c[1], + 'num_stages': c[2], + 'threads': c[3] + } for c in _configs] + return configs + + +def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=False): + scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + q_shape = [batch, heads, seq_q, dim] + kv_shape = [batch, heads, seq_kv, dim] + dtype = "float16" + accum_dtype = "float" + + def kernel_func(block_M, block_N, num_stages, threads): + + @T.macro + def MMA0( + K: T.Tensor(kv_shape, dtype), + Q_shared: T.SharedBuffer([block_M, dim], dtype), + K_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + k: T.int32, + bx: T.int32, + by: T.int32, + bz: T.int32, + ): + past_len = seq_kv - seq_q + T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + q_idx = bx * block_M + i + past_len + k_idx = k * block_N + j + acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) + else: + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def MMA1( + V: T.Tensor(kv_shape, dtype), + V_shared: T.SharedBuffer([block_M, dim], dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + k: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def Softmax( + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), + ): + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # in the first ceil_div(kBlockM, kBlockN) steps. + # for i in T.Parallel(block_M): + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. 
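                    # Note: `scale` above already folds in log2(e) (scale = (1/dim)**0.5 * 1.44269504),
                    # so exp2(acc_s * scale - scores_max * scale) == exp((acc_s - scores_max) / sqrt(dim));
                    # the running rescale factor scores_scale computed above uses the same base-2 form.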
+ acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + @T.macro + def Rescale( + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + ): + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), + ): + with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = ( + T.min(T.ceildiv(seq_kv, block_N), T.ceildiv( + (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_kv, block_N)) + + for k in T.Pipelined( + loop_range, + num_stages=num_stages, + order=[-1, 0, 3, 1, -1, 2], + stage=[-1, 0, 0, 1, -1, 1], + group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]): + MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, + scores_sum, logsum) + Rescale(acc_o, scores_scale) + MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + + return main + + if tune: + + @autotune(configs=get_configs(), warmup=10, rep=10) + @jit(out_idx=[3], supply_type=tilelang.TensorSupplyType.Integer, ref_prog=None) + def kernel(block_M=None, block_N=None, num_stages=None, threads=None): + return kernel_func(block_M, block_N, num_stages, threads) + + return kernel() + else: + + def kernel(block_M, block_N, num_stages, threads): + return kernel_func(block_M, block_N, num_stages, threads) + + return kernel + + +def ref_program(Q, K, V, is_causal): + dim = Q.size(-1) + scores = torch.einsum('bhqd,bhkd->bhqk', Q, K) + scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) + if is_causal: + seq_q = Q.size(2) + seq_kv = K.size(2) + mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device)) + mask = mask.unsqueeze(0).unsqueeze(0) + scores = scores.masked_fill(mask == 0, float('-inf')) + attention_weights = F.softmax(scores, dim=-1) + output = torch.einsum('bhqk,bhkd->bhqd', attention_weights, V) + return output + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--batch', type=int, default=8, help='batch size') + parser.add_argument('--heads', type=int, default=32, help='heads') + parser.add_argument('--seq_q', type=int, default=4096, help='query sequence length') + 
parser.add_argument('--seq_kv', type=int, default=4096, help='key/value sequence length') + parser.add_argument('--dim', type=int, default=128, help='dim') + parser.add_argument('--is_causal', action='store_true', help='causal') + parser.add_argument('--tune', action='store_true', help='tune configs') + args = parser.parse_args() + batch, heads, seq_q, seq_kv, dim, is_causal = args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.is_causal + flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim + total_flops = 2 * flops_per_matmul + if is_causal: + total_flops *= 0.5 + + if (not args.tune): + program = flashattn( + batch, heads, seq_q, seq_kv, dim, is_causal, tune=args.tune)( + block_M=128, block_N=128, num_stages=2, threads=256) + ref_program = partial(ref_program, is_causal=is_causal) + kernel = tilelang.compile(program, out_idx=[3]) + + profiler = kernel.get_profiler() + profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) + print("All checks pass.") + latency = profiler.do_bench(ref_program, warmup=500) + print("Ref: {:.2f} ms".format(latency)) + print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = profiler.do_bench(warmup=500) + print("Tile-lang: {:.2f} ms".format(latency)) + print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + else: + best_result = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=args.tune) + best_latency = best_result.latency + best_config = best_result.config + ref_latency = best_result.ref_latency + print(f"Best latency: {best_latency}") + print(f"Best TFlops: {total_flops / best_latency * 1e-9}") + print(f"Best config: {best_config}") diff --git a/tilelang/__init__.py b/tilelang/__init__.py index bb7c0c5ec..65cbfb11f 100644 --- a/tilelang/__init__.py +++ b/tilelang/__init__.py @@ -61,6 +61,7 @@ def _init_logger(): import tvm import tvm._ffi.base +from tvm import DataType # noqa: F401 from . import libinfo diff --git a/tilelang/quantize/__init__.py b/tilelang/quantize/__init__.py new file mode 100644 index 000000000..1e1dd3f9c --- /dev/null +++ b/tilelang/quantize/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Tile-AI Corporation. +# Licensed under the MIT License. +from .quantization import ( + _tir_packed_int_to_int_convert, # noqa: F401 + _tir_packed_to_signed_convert, # noqa: F401 + _tir_packed_to_unsigned_convert, # noqa: F401 + _tir_packed_to_fp4_to_f16, # noqa: F401 + _tir_u8_to_f8_e4m3_to_f16, # noqa: F401 + _tir_packed_to_unsigned_convert_with_zeros, # noqa: F401 +) + +from .utils import ( + gen_quant4, # noqa: F401 + general_compress, # noqa: F401 + interleave_weight, # noqa: F401 +) + +from .lop3 import get_lop3_intrin_group # noqa: F401 diff --git a/tilelang/quantize/lop3.py b/tilelang/quantize/lop3.py new file mode 100644 index 000000000..0886b3015 --- /dev/null +++ b/tilelang/quantize/lop3.py @@ -0,0 +1,1202 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from typing import Dict, Literal + +decode_i4_to_f16 = """ +template +__device__ void decode_i4b_to_f16(T1 *_i4s, T2 *B_local_decode, const int N = 8) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x000f000f; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = isSigned ? 
0x64086408 : 0x64006400; + uint const i4s = *reinterpret_cast(_i4s); +#pragma unroll + for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + } +} + +template +__device__ void decode_i4s_to_f16(T1 *_i4s, T2 *B_local_decode, const int N = 8) +{ + decode_i4b_to_f16(_i4s, B_local_decode, N); +} + +template +__device__ void decode_i4u_to_f16(T1 *_i4u, T2 *B_local_decode, const int N = 8) +{ + decode_i4b_to_f16(_i4u, B_local_decode, N); +} +""" + +decode_i4_to_f16_scale = """ +template +__device__ void decode_i4b_to_f16_scale(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x000f000f; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + // Minus 7 to scale the value to signed + static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400; + uint const i4s = *reinterpret_cast(_i4s); + T3 const scale_r = *scale; + uint const packed_scales = __pack_half2(scale_r, scale_r); + +#pragma unroll + // decode 2 elems at one time. + for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0)); + } +} + +template +__device__ void decode_i4s_to_f16_scale(T1 *_i4s, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8) +{ + decode_i4b_to_f16_scale(_i4s, B_local_decode, N, scale); +} + +template +__device__ void decode_i4u_to_f16_scale(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8) +{ + decode_i4b_to_f16_scale(_i4u, B_local_decode, N, scale); +} + +""" + +decode_i4_to_f16_scale_offset = """ +template +__device__ void decode_i4b_to_f16_scale_offset(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr, const int offset = 0) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x000f000f; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + // Minus 7 to scale the value to signed + static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400; + uint const i4s = *reinterpret_cast(_i4s); + T3 const scale_l = *scale; + T3 const scale_r = *(scale + offset); + uint const packed_scales_l = __pack_half2(scale_l, scale_l); + uint const packed_scales_r = __pack_half2(scale_r, scale_r); + +#pragma unroll + // decode 2 elems at one time. 
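    // The lop3.b32 below uses immLut = (0xf0 & 0xcc) | 0xaa = 0xEA, which encodes f(a, b, c) = (a & b) | c:
    // a single instruction masks a 4-bit nibble out of the shifted i4s word and ORs in 0x6400 (fp16 1024.0),
    // yielding the packed half2 value (1024 + n) per lane. The sub.f16x2 against MEDIAN_NUM then removes
    // the 1024 bias (plus an extra 8 on the signed path, mapping 0..15 onto -8..7).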
+ for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + } + #pragma unroll + for (int i = 0; i < (N / 4); i++) + { + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales_l), "r"(0)); + } +#pragma unroll + for (int i = (N / 4); i < (N / 2); i++) + { + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales_r), "r"(0)); + } +} + +template +__device__ void decode_i4s_to_f16_scale_offset(T1 *_i4s, T2 *B_local_decode, T3 *scale = nullptr, const int offset = 0, const int N = 8) +{ + decode_i4b_to_f16_scale_offset(_i4s, B_local_decode, N, scale, offset); +} + +template +__device__ void decode_i4u_to_f16_scale_offset(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, const int offset = 0, const int N = 8) +{ + decode_i4b_to_f16_scale_offset(_i4u, B_local_decode, N, scale, offset); +} + +""" + +decode_i4_to_f16_scale_zeros_original = """ +template +__device__ void decode_i4b_to_f16_zeros_original(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr, const T4 *zeros = nullptr) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x000f000f; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + // Minus 7 to scale the value to signed + static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400; + uint const i4s = *reinterpret_cast(_i4s); + T3 const scale_r = *scale; + uint const packed_scales = __pack_half2(scale_r, scale_r); + // input zeros maybe int32(qzeros) or half format + T4 const zero_r = *zeros; + uint const packed_zeros = __pack_half2(zero_r, zero_r); + + +#pragma unroll + // decode 2 elems at one time. + for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_zeros)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0)); + } +} + +template +__device__ void decode_i4u_to_f16_scale_zeros_original(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int N = 8) +{ + decode_i4b_to_f16_zeros_original(_i4u, B_local_decode, N, scale, zeros); +} +""" + +decode_i4_to_f16_scale_zeros_original_offset = """ +template +__device__ void decode_i4b_to_f16_zeros_original_offset(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr, const T4 *zeros = nullptr, const int offset = 0) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x000f000f; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + // Minus 7 to scale the value to signed + static constexpr uint MEDIAN_NUM = isSigned ? 
0x64086408 : 0x64006400; + uint const i4s = *reinterpret_cast(_i4s); + T3 const scale_l = *scale; + T3 const scale_r = *(scale + offset); + uint const packed_scales_l = __pack_half2(scale_l, scale_l); + uint const packed_scales_r = __pack_half2(scale_r, scale_r); + // input zeros maybe int32(qzeros) or half format + T3 const zeros_l = *zeros; + T3 const zeros_r = *(zeros + offset); + uint const packed_zeros_l = __pack_half2(zeros_l, zeros_l); + uint const packed_zeros_r = __pack_half2(zeros_r, zeros_r); + +#pragma unroll + // decode 2 elems at one time. + for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + } + +#pragma unroll + for (int i = 0; i < (N / 4); i++) + { + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_zeros_l)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales_l), "r"(0)); + } +#pragma unroll + for (int i = (N / 4); i < (N / 2); i++) + { + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_zeros_r)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales_r), "r"(0)); + } +} + +template +__device__ void decode_i4u_to_f16_scale_zeros_original_offset(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int offset = 0, const int N = 8) +{ + decode_i4b_to_f16_zeros_original_offset(_i4u, B_local_decode, N, scale, zeros, offset); +} +""" + +decode_i4_to_f16_scale_zeros_rescale = """ +template +__device__ void decode_i4b_to_f16_scale_zeros_rescale(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr, const T4 *zeros = nullptr) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x000f000f; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + // Minus 7 to scale the value to signed + static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400; + uint const i4s = *reinterpret_cast(_i4s); + T3 const scale_r = *scale; + uint const packed_scales = __pack_half2(scale_r, scale_r); + T4 const zero_r = *zeros; + uint const packed_zeros = 0x80008000 | __pack_half2(zero_r, zero_r); + +#pragma unroll + // decode 2 elems at one time. 
+ for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(packed_zeros)); + } +} + +template +__device__ void decode_i4u_to_f16_scale_zeros_rescale(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int N = 8) +{ + decode_i4b_to_f16_scale_zeros_rescale(_i4u, B_local_decode, N, scale, zeros); +} + +""" + +decode_i4_to_f16_scale_zeros_rescale_offset = """ +template +__device__ void decode_i4b_to_f16_scale_zeros_rescale_offset(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr, const T4 *zeros = nullptr, const int offset = 0) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x000f000f; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + // Minus 7 to scale the value to signed + static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400; + uint const i4s = *reinterpret_cast(_i4s); + T3 const scale_l = *scale; + T3 const scale_r = *(scale + offset); + uint const packed_scales_l = __pack_half2(scale_l, scale_l); + uint const packed_scales_r = __pack_half2(scale_r, scale_r); + // input zeros maybe int32(qzeros) or half format + T3 const zeros_l = *zeros; + T3 const zeros_r = *(zeros + offset); + uint const packed_zeros_l = 0x80008000 | __pack_half2(zeros_l, zeros_l); + uint const packed_zeros_r = 0x80008000 | __pack_half2(zeros_r, zeros_r); + +#pragma unroll + // decode 2 elems at one time. + for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + } +#pragma unroll + for (int i = 0; i < (N / 4); i++) + { + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales_l), "r"(packed_zeros_l)); + } +#pragma unroll + for (int i = (N / 4); i < (N / 2); i++) + { + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales_r), "r"(packed_zeros_r)); + } +} + +template +__device__ void decode_i4u_to_f16_scale_zeros_rescale_offset(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int offset = 0, const int N = 8) +{ + decode_i4b_to_f16_scale_zeros_rescale_offset(_i4u, B_local_decode, N, scale, zeros, offset); +} + +""" + +decode_i4_to_f16_scale_zeros_quantized = """ +template +__device__ void decode_i4b_to_f16_scale_zeros_quantized(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr, const T4 *zeros = nullptr) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x000f000f; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + // Minus 7 to scale the value to signed + uint const i4s = *reinterpret_cast(_i4s); + T3 const scale_r = *scale; + uint const packed_scales = __pack_half2(scale_r, scale_r); + // input zeros maybe int32(qzeros) or half format + int16_t const zero_r = *((int16_t*)zeros); + uint median_num = ((0xe400 | zero_r) << 16) | (0xe400 | zero_r); + +#pragma unroll + // decode 2 elems at one time. 
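    // packed_zeros_l / packed_zeros_r OR in the fp16 sign bit (0x8000), so the fma.rn.f16x2 loops
    // below effectively compute h * scale + (-zero), i.e. the "rescale" zeros mode
    // (target = dequantized_weight * scale - zero_point).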
+ for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + + asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(median_num)); + + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0)); + } +} + +template +__device__ void decode_i4u_to_f16_scale_zeros_quantized(storage_dtype *_i4u, target_dtype *B_local_decode, scale_dtype *scale = nullptr, zero_dtype *zeros = nullptr, const int N = 8) +{ + decode_i4b_to_f16_scale_zeros_quantized(_i4u, B_local_decode, N, scale, zeros); +} +""" + +decode_i4_to_f16_scale_zeros_quantized_offset = """ +template +__device__ void decode_i4b_to_f16_scale_zeros_quantized_offset(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr, const T1 *qzeros = nullptr, const int scale_offset = 0, const int qzeros_offset = 0, const int group_offset = 0) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x000f000f; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + // Minus 7 to scale the value to signed + uint const i4s = *reinterpret_cast(_i4s); + + T3 const scale_l = *scale; + T3 const scale_r = *(scale + scale_offset); + uint const packed_scales_l = __pack_half2(scale_l, scale_l); + uint const packed_scales_r = __pack_half2(scale_r, scale_r); + + const int num_elems_per_storage_dtype = sizeof(T1) * 8 / 4; + + T1 const qzeros_l = *qzeros; + T1 const qzeros_r = *(qzeros + qzeros_offset); + int16_t const zero_l = (qzeros_l >> (group_offset * 4) & 0xf); + int16_t const zero_r = (qzeros_r >> (group_offset * 4) & 0xf); + + uint median_num_l = ((0xe400 | zero_l) << 16) | (0xe400 | zero_l); + uint median_num_r = ((0xe400 | zero_r) << 16) | (0xe400 | zero_r); + +#pragma unroll + // decode 2 elems at one time. + for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + } + #pragma unroll + for (int i = 0; i < (N / 4); i++) + { + asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(median_num_l)); + + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales_l), "r"(0)); + } +#pragma unroll + for (int i = (N / 4); i < (N / 2); i++) + { + asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(median_num_r)); + + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales_r), "r"(0)); + } +} + +template +__device__ void decode_i4u_to_f16_scale_zeros_quantized_offset(storage_dtype *_i4u, target_dtype *B_local_decode, scale_dtype *scale = nullptr, storage_dtype *qzeros = nullptr, const int scale_offset = 0, const int zero_offset = 0, const int group_offset = 0, const int N = 8) +{ + decode_i4b_to_f16_scale_zeros_quantized_offset(_i4u, B_local_decode, N, scale, qzeros, scale_offset, zero_offset, group_offset); +} +""" + +decode_i2_to_f16 = """ +template +__device__ void decode_i2b_to_f16(T1 *_i2s, T2 *B_local_decode, const int N = 8) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00030003; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = isSigned ? 
0x64026402 : 0x64006400; + int16_t const i2s_i16 = *reinterpret_cast(_i2s); + // decode 2 elems at one time. + // interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0} + // only decode for {x,x,x,x,e7,e5,e3,e1,x,x,x,x,e6,e4,e2,e0} + // otherwise the pointer of _i2s should be moved to + int i2s = (i2s_i16 & 0x00ff); + i2s |= ((i2s_i16 & 0xff00) << 8); + +#pragma unroll + for (int i = 0; i < (N / 2); i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + } +} + +template +__device__ void decode_i2s_to_f16(T1 *_i2s, T2 *B_local_decode, const int N = 8) +{ + decode_i2b_to_f16(_i2s, B_local_decode, N); +} + +template +__device__ void decode_i2u_to_f16(T1 *_i2u, T2 *B_local_decode, const int N = 8) +{ + decode_i2b_to_f16(_i2u, B_local_decode, N); +} +""" + +decode_i2_to_f16_scale = """ +template +__device__ void decode_i2b_to_f16_scale(T1 *_i2s, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00030003; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = isSigned ? 0x64026402 : 0x64006400; + int16_t const i2s_i16 = *reinterpret_cast(_i2s); + // decode 2 elems at one time. + // interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0} + // only decode for {x,x,x,x,e7,e5,e3,e1,x,x,x,x,e6,e4,e2,e0} + // otherwise the pointer of _i2s should be moved to + int i2s = (i2s_i16 & 0x00ff); + i2s |= ((i2s_i16 & 0xff00) << 8); + +#pragma unroll + for (int i = 0; i < (N / 2); i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*scale, *scale)), "r"(0)); + } +} + +template +__device__ void decode_i2s_to_f16_scale(T1 *_i2s, T2 *B_local_decode, T3 *scale, const int N = 8) +{ + decode_i2b_to_f16_scale(_i2s, B_local_decode, scale, N); +} + +template +__device__ void decode_i2u_to_f16_scale(T1 *_i2u, T2 *B_local_decode, T3 *scale, const int N = 8) +{ + decode_i2b_to_f16_scale(_i2u, B_local_decode, scale, N); +} +""" + +decode_i2_to_f16_scale_zeros_original_offset = """ +template +__device__ void decode_i2b_to_f16_scale_zeros_original_offset(T1 *_i2s, T2 *B_local_decode, T3 *scale = nullptr, T3 *zeros = nullptr, const int offset = 0, const int N = 8) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00030003; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = isSigned ? 0x64026402 : 0x64006400; + int16_t const i2s_i16 = *reinterpret_cast(_i2s); + // decode 2 elems at one time. 
+ // interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0} + // only decode for {x,x,x,x,e7,e5,e3,e1,x,x,x,x,e6,e4,e2,e0} + // otherwise the pointer of _i2s should be moved to + int i2s = (i2s_i16 & 0x00ff); + i2s |= ((i2s_i16 & 0xff00) << 8); + + T3 const zeros_l = *zeros; + T3 const zeros_r = *(zeros + offset); + uint const packed_zeros_l = __pack_half2(zeros_l, zeros_l); + uint const packed_zeros_r = __pack_half2(zeros_r, zeros_r); + + T3 const scale_l = *scale; + T3 const scale_r = *(scale + offset); + uint const packed_scales_l = __pack_half2(scale_l, scale_l); + uint const packed_scales_r = __pack_half2(scale_r, scale_r); + +#pragma unroll + for (int i = 0; i < (N / 2); i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + } + #pragma unroll + for (int i = 0; i < (N / 4); i++) + { + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_zeros_l)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales_l), "r"(0)); + } +#pragma unroll + for (int i = (N / 4); i < (N / 2); i++) + { + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_zeros_r)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales_r), "r"(0)); + } +} + +template +__device__ void decode_i2u_to_f16_scale_zeros_original_offset(T1 *_i2u, T2 *B_local_decode, T3 *scale, T3 *zeros, const int offset = 0, const int N = 8) +{ + decode_i2b_to_f16_scale_zeros_original(_i2u, B_local_decode, scale, zeros, offset, N); +} +""" + +decode_i2_to_f16_scale_zeros_original = """ +template +__device__ void decode_i2b_to_f16_scale_zeros_original(T1 *_i2s, T2 *B_local_decode, T3 *scale = nullptr, T3 *zeros = nullptr, const int N = 8) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00030003; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = isSigned ? 0x64026402 : 0x64006400; + int16_t const i2s_i16 = *reinterpret_cast(_i2s); + // decode 2 elems at one time. 
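    // The two storage bytes are spread 16 bits apart below (the low byte stays at bits 0..7, the high
    // byte moves to bits 16..23), so that after the >> (2 * i) shift each 16-bit lane of the lop3 sees
    // its own 2-bit element in the interleaved order listed in the next comment.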
+ // interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0} + // only decode for {x,x,x,x,e7,e5,e3,e1,x,x,x,x,e6,e4,e2,e0} + // otherwise the pointer of _i2s should be moved to + int i2s = (i2s_i16 & 0x00ff); + i2s |= ((i2s_i16 & 0xff00) << 8); + +#pragma unroll + for (int i = 0; i < (N / 2); i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*zeros, *zeros))); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*scale, *scale)), "r"(0)); + } +} + +template +__device__ void decode_i2u_to_f16_scale_zeros_original(T1 *_i2u, T2 *B_local_decode, T3 *scale, T3 *zeros, const int N = 8) +{ + decode_i2b_to_f16_scale_zeros_original(_i2u, B_local_decode, scale, zeros, N); +} +""" + +decode_i2_to_f16_scale_zeros_rescale = """ +template +__device__ void decode_i2b_to_f16_scale_zeros_rescale(T1 *_i2s, T2 *B_local_decode, T3 *scale = nullptr, T3 *zeros = nullptr, const int N = 8) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00030003; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = isSigned ? 0x64026402 : 0x64006400; + int16_t const i2s_i16 = *reinterpret_cast(_i2s); + // decode 2 elems at one time. + // interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0} + // only decode for {x,x,x,x,e7,e5,e3,e1,x,x,x,x,e6,e4,e2,e0} + // otherwise the pointer of _i2s should be moved to + int i2s = (i2s_i16 & 0x00ff); + i2s |= ((i2s_i16 & 0xff00) << 8); + +#pragma unroll + for (int i = 0; i < (N / 2); i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*scale, *scale)), "r"(0)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*zeros, *zeros))); + } +} + +template +__device__ void decode_i2u_to_f16_scale_zeros_rescale(T1 *_i2u, T2 *B_local_decode, T3 *scale, T3 *zeros, const int N = 8) +{ + decode_i2b_to_f16_scale_zeros_rescale(_i2u, B_local_decode, scale, zeros, N); +} +""" + +decode_i2_to_f16_scale_zeros_quantized = """ +template +__device__ void decode_i2b_to_f16_scale_zeros_quantized(T1 *_i2s, T2 *B_local_decode, const int N = 8, T3 *scale = nullptr, T4 *zeros = nullptr) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00030003; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = isSigned ? 0x64016401 : 0x64006400; + int16_t const i2s_i16 = *reinterpret_cast(_i2s); + T3 const scale_r = *scale; + uint const packed_scales = __pack_half2(scale_r, scale_r); + int16_t const zero_r = *((int16_t*)zeros); + uint median_num = ((0xe400 | zero_r) << 16) | (0xe400 | zero_r); + + // decode 2 elems at one time. 
+ // interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0} + // only decode for {x,x,x,x,e7,e5,e3,e1,x,x,x,x,e6,e4,e2,e0} + // otherwise the pointer of _i2s should be moved to + int i2s = (i2s_i16 & 0x00ff); + i2s |= ((i2s_i16 & 0xff00) << 8); + +#pragma unroll + for (int i = 0; i < (N / 2); i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(median_num)); + + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0)); + } +} +template +__device__ void decode_i2u_to_f16_scale_zeros_quantized(T1 *_i2u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int N = 8) +{ + decode_i2b_to_f16_scale_zeros_quantized(_i2u, B_local_decode, N, scale, zeros); +} +""" + +decode_i1_to_f16 = """ +/* +Kind 0: original +Kind 1: rescale +Kind 2: quantized +# documents for zeros_mode: +# original: target = (dequantize_weight - zero_point) * scale +# rescale: target = dequantize_weight * scale - zero_point +# quantized: target = (dequantize_weight - dequantize_zeros) * scale +# Notice: only support "original" and "rescale" now +zeros_mode: Literal["original", "rescale", "quantized"] = "original" +*/ +template +__device__ void decode_i1b_to_f16(T1 *_i1s, T2 *B_local_decode, const int N = 8, half *scale = nullptr, half *zeros = nullptr) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00010001; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = isSigned ? 0x64006400 : 0x64006400; + static constexpr uint TRANSFORM_SUBTRACT = 0xbc00bc00; // for signed int 2x - 1 + // interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0} + // only decode e7,e5,e3,e1,e8,e6,e4,e2,e0 + int8_t const i1s_i16 = *reinterpret_cast(_i1s); + int i1s = (i1s_i16 & 0x0f); + i1s |= ((i1s_i16 & 0xf0) << 12); +#pragma unroll + // decode 2 elems at one time. 
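    // In the loop below, the signed path computes 2 * x - 1: h[i] is doubled with add.f16x2 and then
    // TRANSFORM_SUBTRACT (0xbc00bc00, a packed pair of fp16 -1.0) is added, mapping the decoded
    // bits {0, 1} onto {-1, +1}.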
+ for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + if constexpr (isSigned) + { + asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(h[i])); + asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(TRANSFORM_SUBTRACT)); + } + if constexpr (withZeros && ZerosKind == 0) + { + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*zeros, *zeros))); + } + if constexpr (withScaling) + { + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*scale, *scale)), "r"(0)); + } + if constexpr (withZeros && ZerosKind == 1) + { + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*zeros, *zeros))); + } + } +} + +template +__device__ void decode_i1s_to_f16(T1 *_i1s, T2 *B_local_decode, const int N = 8) +{ + decode_i1b_to_f16(_i1s, B_local_decode, N); +} + +template +__device__ void decode_i1u_to_f16(T1 *_i1u, T2 *B_local_decode, const int N = 8) +{ + decode_i1b_to_f16(_i1u, B_local_decode, N); +} +""" + +decode_i1_to_f16_scale = """ +template +__device__ void decode_i1u_to_f16_scale(T1 *_i1s, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00010001; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = 0x64006400; + // interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0} + // only decode e7,e5,e3,e1,e8,e6,e4,e2,e0 + int8_t const i1s_i16 = *reinterpret_cast(_i1s); + int i1s = (i1s_i16 & 0x0f); + i1s |= ((i1s_i16 & 0xf0) << 12); + T3 const scale_r = *scale; + uint const packed_scales = __pack_half2(scale_r, scale_r); +#pragma unroll + // decode 2 elems at one time. + for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0)); + } +} + +template +__device__ void decode_i1s_to_f16_scale(T1 *_i1s, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00010001; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = 0x64006400; + static constexpr uint TRANSFORM_SUBTRACT = 0xbc00bc00; // for signed int 2x - 1 + // interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0} + // only decode e7,e5,e3,e1,e8,e6,e4,e2,e0 + + int8_t const i1s_i16 = *reinterpret_cast(_i1s); + int i1s = (i1s_i16 & 0x0f); + i1s |= ((i1s_i16 & 0xf0) << 12); + T3 const scale_r = *scale; + uint const packed_scales = __pack_half2(scale_r, scale_r); +#pragma unroll + // decode 2 elems at one time. 
+ for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(h[i])); + asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(TRANSFORM_SUBTRACT)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0)); + } +} +""" + +decode_i1_to_f16_scale_zeros_original = """ +template +__device__ void decode_i1b_to_f16_zeros_original(T1 *_i1s, T2 *B_local_decode, const int N = 8, T3 *scale = nullptr, T4 *zeros = nullptr) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00010001; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = 0x64006400; + // interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0} + // only decode e7,e5,e3,e1,e8,e6,e4,e2,e0 + int8_t const i1s_i16 = *reinterpret_cast(_i1s); + int i1s = (i1s_i16 & 0x0f); + i1s |= ((i1s_i16 & 0xf0) << 12); + T3 const scale_r = *scale; + uint const packed_scales = __pack_half2(scale_r, scale_r); + // input zeros maybe int32(qzeros) or half format + T4 const zero_r = *zeros; + uint const packed_zeros = __pack_half2(zero_r, zero_r); + +#pragma unroll + // decode 2 elems at one time. + for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_zeros)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0)); + } +} +template +__device__ void decode_i1u_to_f16_scale_zeros_original(T1 *_i1u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int N = 8) +{ + decode_i1b_to_f16_zeros_original(_i1u, B_local_decode, N, scale, zeros); +} +""" + +decode_i1_to_f16_scale_zeros_rescale = """ +template +__device__ void decode_i1b_to_f16_scale_zeros_rescale(T1 *_i1s, T2 *B_local_decode, const int N = 8, T3 *scale = nullptr, T4 *zeros = nullptr) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00010001; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = 0x64006400; + // interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0} + // only decode e7,e5,e3,e1,e8,e6,e4,e2,e0 + int8_t const i1s_i16 = *reinterpret_cast(_i1s); + int i1s = (i1s_i16 & 0x0f); + i1s |= ((i1s_i16 & 0xf0) << 12); + T3 const scale_r = *scale; + uint const packed_scales = __pack_half2(scale_r, scale_r); + T4 const zero_r = *zeros; + uint const packed_zeros = 0x80008000 | __pack_half2(zero_r, zero_r); + +#pragma unroll + // decode 2 elems at one time. 
+ for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(packed_zeros)); + } +} + +template +__device__ void decode_i1u_to_f16_scale_zeros_rescale(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int N = 8) +{ + decode_i1b_to_f16_scale_zeros_rescale(_i4u, B_local_decode, N, scale, zeros); +} +""" + +decode_i1s_to_i8s = """template +__device__ void decode_i1s_to_i8s(T1 *_i1b, T2 *_i8s, const int N = 16) +{ + int i8s[4]; + // vector load + *reinterpret_cast(i8s) = *reinterpret_cast(_i8s); + int16_t i1b_i16 = *reinterpret_cast(_i1b); + // permutate: {e0,e4,e8,e12,e2,e6,e10,e14,e1,e5,e9,e13,e3,e7,e11,e15} + // into: {e0,e4,e8,e12,x,x,x,x,e1,e5,e9,x,x,x,x,e13,e2,e6,e10,e14,e1,e5,e9,e13,e3,e7,e11,e15,x,x,x,x} + int i1b = (i1b_i16 & 0x0f0f); + i1b |= ((i1b_i16 & 0xf0f0) << 12); + // i1b {0..,e15,e14,e13,e12,e11,e10,e9,e8,e7,e6,e5,e4,e3,e2,e1,e0} + // interleave {0..,e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0} + // First, we extract the i1b and construct an intermediate fp16 number. + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010 + static constexpr uint BOTTOM_MASK = 0x01010101; // 0x1 -> 0b01 select 0,1 + static constexpr uint I8s_MAGIC_NUM = 0x00000000; + static constexpr uint TRANSFORM_SUBTRACT = 0xffffffff; // for signed int 2x - 1 + + for (int i = 0; i < N / 4; i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(i8s[i]) + : "r"(i1b >> i), "n"(BOTTOM_MASK), "n"(I8s_MAGIC_NUM), "n"(immLut)); + i8s[i] = __vadd4(i8s[i], i8s[i]); + i8s[i] = __vadd4(i8s[i], TRANSFORM_SUBTRACT); + } + *reinterpret_cast(_i8s) = *reinterpret_cast(i8s); +} + +template +__device__ void decode_i1u_to_i8s(T1 *_i1b, T2 *_i8s, const int N = 16) +{ + int *i8s = reinterpret_cast(_i8s); + int16_t i1b_i16 = *reinterpret_cast(_i1b); + // permutate: {e0,e4,e8,e12,e2,e6,e10,e14,e1,e5,e9,e13,e3,e7,e11,e15} + // into: {e0,e4,e8,e12,x,x,x,x,e1,e5,e9,x,x,x,x,e13,e2,e6,e10,e14,e1,e5,e9,e13,e3,e7,e11,e15,x,x,x,x} + int i1b = (i1b_i16 & 0x0f0f); + i1b |= ((i1b_i16 & 0xf0f0) << 12); + // i1b {0..,e15,e14,e13,e12,e11,e10,e9,e8,e7,e6,e5,e4,e3,e2,e1,e0} + // interleave {0..,e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0} + // First, we extract the i1b and construct an intermediate fp16 number. + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010 + static constexpr uint BOTTOM_MASK = 0x01010101; // 0x1 -> 0b01 select 0,1 + static constexpr uint I8s_MAGIC_NUM = 0x00000000; + static constexpr uint MEDIAN_NUM = 0x00000000; + + for (int i = 0; i < N / 4; i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(i8s[i]) + : "r"(i1b >> i), "n"(BOTTOM_MASK), "n"(I8s_MAGIC_NUM), "n"(immLut)); + } +} + +""" + +decode_i2s_to_i8s = """template +__device__ void decode_i2s_to_i8s(T1 *_i2b, T2 *_i8s, const int N = 16) +{ + // convert 8 int2b_t to 8 int8b_t -> 2 int32 + uint *i8s = reinterpret_cast(_i8s); + + // i2b = {e7,e6,e5,e4,e3,e2,e1,e0} + // also require interleave {e7,e3,e6,e2,e5,e1,e4,e0} + uint const i2b = *reinterpret_cast(_i2b); + + // First, we extract the i4s and construct an intermediate fp16 number. 
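    // (Despite the comment above, this path builds int8 results directly: each lop3 selects one 2-bit
    // field per output byte via BOTTOM_MASK = 0x03030303, and __vsub4 subtracts the per-byte bias
    // MEDIAN_NUM = 0x02020202 to recenter the signed range.)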
+ static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010 + static constexpr uint BOTTOM_MASK = 0x03030303; // 0xf -> 0b11 select 0,3 + static constexpr uint I8s_MAGIC_NUM = 0x00000000; // 1024 + static constexpr uint MEDIAN_NUM = 0x02020202; +#pragma unroll + for (int i = 0; i < (N / 4); i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(i8s[i]) + : "r"(i2b >> (2 * i)), "n"(BOTTOM_MASK), "n"(I8s_MAGIC_NUM), "n"(immLut)); + i8s[i] = __vsub4(i8s[i], MEDIAN_NUM); + } +} +template +__device__ void decode_i2u_to_i8s(T1 *_i2b, T2 *_i8s, const int N = 16) +{ + // convert 8 int2b_t to 8 int8b_t -> 2 int32 + uint *i8s = reinterpret_cast(_i8s); + + // i2b = {e7,e6,e5,e4,e3,e2,e1,e0} + // also require interleave {e7,e3,e6,e2,e5,e1,e4,e0} + uint const i2b = *reinterpret_cast(_i2b); + + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010 + static constexpr uint BOTTOM_MASK = 0x03030303; // 0xf -> 0b11 select 0,3 + static constexpr uint I8s_MAGIC_NUM = 0x00000000; // 1024 + +#pragma unroll + for (int i = 0; i < (N / 4); i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(i8s[i]) + : "r"(i2b >> (2 * i)), "n"(BOTTOM_MASK), "n"(I8s_MAGIC_NUM), "n"(immLut)); + } +} +""" + +decode_i4s_to_i8s = """template +__device__ void decode_i4s_to_i8s(T1 *_i4b, T2 *_i8s, const int N = 16) +{ + uint *i8s = reinterpret_cast(_i8s); + uint *i4b = reinterpret_cast(_i4b); + // First, we extract the i4s and construct an intermediate i8 number. + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x0f0f0f0f; // 0xf -> 0b1111 select 0,4,8,12 + static constexpr uint I4b_TO_I8s_MAGIC_NUM = 0x00000000; // 0 + static constexpr uint MEDIAN_NUM = 0x07070707; +#pragma unroll + for (int i = 0; i < (N / 8); i++) + { + // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(i8s[i]) + : "r"(i4b[0] >> (4 * i)), "n"(BOTTOM_MASK), "n"(I4b_TO_I8s_MAGIC_NUM), "n"(immLut)); + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(i8s[i + 2]) + : "r"(i4b[1] >> (4 * i)), "n"(BOTTOM_MASK), "n"(I4b_TO_I8s_MAGIC_NUM), "n"(immLut)); + i8s[i] = __vsubss4(i8s[i], MEDIAN_NUM); + i8s[i + 2] = __vsubss4(i8s[i + 2], MEDIAN_NUM); + } +} + +template +__device__ void decode_i4u_to_i8s(T1 *_i4b, T2 *_i8s, const int N = 16) +{ + uint *i8s = reinterpret_cast(_i8s); + uint *i4b = reinterpret_cast(_i4b); + // First, we extract the i4s and construct an intermediate i8 number. + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x0f0f0f0f; // 0xf -> 0b1111 select 0,4,8,12 + static constexpr uint I4b_TO_I8s_MAGIC_NUM = 0x00000000; // 0 +#pragma unroll + for (int i = 0; i < (N / 8); i++) + { + // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(i8s[i]) + : "r"(i4b[0] >> (4 * i)), "n"(BOTTOM_MASK), "n"(I4b_TO_I8s_MAGIC_NUM), "n"(immLut)); + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(i8s[i + 2]) + : "r"(i4b[1] >> (4 * i)), "n"(BOTTOM_MASK), "n"(I4b_TO_I8s_MAGIC_NUM), "n"(immLut)); + } +} +""" + +decode_i2s_to_i4s = r""" +template +__device__ void decode_i2b_to_i4s(T1 *_i2b, T2 *_i4s, const int N = 16) +{ + uint *i4s = reinterpret_cast(_i4s); + uint *i2b = reinterpret_cast(_i2b); + // First, we extract the i4s and construct an intermediate i8 number. 
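+    // Each lop3 below gathers eight 2-bit codes at once: BOTTOM_MASK keeps the low two bits
+    // of every output nibble, so one uint32 of output packs eight int4 values. The signed
+    // recentering is still marked TODO below and is not applied here.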
+ static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x33333333; // 0xf -> 0b1111 select 0,2,4,6,8,10,12 + static constexpr uint I4b_TO_I8s_MAGIC_NUM = 0x00000000; // 0 + static constexpr uint MEDIAN_NUM = isSigned ? 0x33333333 : 0x00000000; + +#pragma unroll + for (int i = 0; i < (N / 8); i++) + { + // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(i4s[i]) + : "r"(i2b[i / 2] >> (2 * (i % 2))), "n"(BOTTOM_MASK), "n"(I4b_TO_I8s_MAGIC_NUM), "n"(immLut)); + if constexpr (isSigned) + { + // TODO(lei): uint4 sub should be enhanced. + // 0x03 0x03 0x03 0x03 + // i4s[i] = (((i4s[i] << 1) | i4s[i]) << 1) | i4s[i]; + } + } +} + +template +__device__ void decode_i2s_to_i4s(T1 *_i4s, T2 *B_local_decode, const int N = 16) +{ + decode_i2b_to_i4s(_i4s, B_local_decode, N); +} + +template +__device__ void decode_i2u_to_i4s(T1 *_i4u, T2 *B_local_decode, const int N = 16) +{ + decode_i2b_to_i4s(_i4u, B_local_decode, N); +} +""" + + +def get_lop3_intrin_group( + out_dtype: Literal["float16", "int8", "int4"], + source_format: Literal["int", "uint"] = "uint", + source_bit: int = 4, + storage_dtype: Literal["int32", "int8"] = "int8", + with_scaling: bool = False, + with_zeros: bool = False, + zeros_mode: Literal["original", "rescale", "quantized"] = "original", + storage_scope: str = "local", +) -> Dict[str, str]: + """ + This function is used to get the intrinsic group of the LOP3 operation to avoid the overhead of fast decoding. + LOP3 is a type of logic operation that takes three inputs. The intrinsic group refers to the set of + intrinsic operations that can be performed on these inputs. This function retrieves and returns this group. + + Parameters + ---------- + in_dtype : Literal["int8"] + The data type of the input. It should be "int8". + + out_dtype : Literal["float16", "int8", "int4"] + The data type of the output. It can be either "float16" or "int8" or "int4". + + storage_nbit : int, optional + The number of bits used for storage. By default, it is 4. + + with_scale : bool, optional + A boolean parameter that indicates whether scaling should be applied. By default, it is False. + + with_zeros : bool, optional + A boolean parameter that indicates whether zeros should be used. By default, it is False. + + zeros_mode : Literal["original", "rescale", "quantized"], optional + The mode of zeros. It can be either "original", "rescale", or "quantized". By default, it is "original". + + storage_scope : Literal["local", "warp"], optional + The scope of the storage. It can be either "local" or "warp". By default, it is "local". + + Returns + ------- + Dict[str, str] + A dictionary mapping the names of the intrinsics to their corresponding implementations. + """ + assert out_dtype in [ + "float16", "int8", "int4" + ], (f"Invalid out_dtype: {out_dtype}. Expected 'float16' or 'int8' or 'int4' .") + + dtype_mapping = {"float16": "f16", "int4": "i4", "int8": "i8", "int32": "i32"} + target_dtype = dtype_mapping[out_dtype] + + if source_format not in ["int", "uint"]: + raise ValueError( + f"Invalid source_format. 
Expected 'int' or 'uint', but got {source_format}.") + if with_zeros and source_format == "int": + raise ValueError(f"Zeros are not supported for signed integers, but got {source_format}") + + source_symbol = "i" if source_format == "int" else "u" + + import_c_map = { + "i4_to_f16": decode_i4_to_f16, + "i2_to_f16": decode_i2_to_f16, + "i1_to_f16": decode_i1_to_f16, + "i4_to_f16_scale": decode_i4_to_f16_scale, + "i4_to_f16_scale_offset": decode_i4_to_f16_scale_offset, + "i2_to_f16_scale": decode_i2_to_f16_scale, + "i1_to_f16_scale": decode_i1_to_f16_scale, + "i4_to_f16_scale_zeros_original": decode_i4_to_f16_scale_zeros_original, + "i4_to_f16_scale_zeros_original_offset": decode_i4_to_f16_scale_zeros_original_offset, + "i2_to_f16_scale_zeros_original": decode_i2_to_f16_scale_zeros_original, + "i1_to_f16_scale_zeros_original": decode_i1_to_f16_scale_zeros_original, + "i4_to_f16_scale_zeros_rescale": decode_i4_to_f16_scale_zeros_rescale, + "i4_to_f16_scale_zeros_rescale_offset": decode_i4_to_f16_scale_zeros_rescale_offset, + "i2_to_f16_scale_zeros_rescale": decode_i2_to_f16_scale_zeros_rescale, + "i1_to_f16_scale_zeros_rescale": decode_i1_to_f16_scale_zeros_rescale, + "i4_to_f16_scale_zeros_quantized": decode_i4_to_f16_scale_zeros_quantized, + "i2_to_f16_scale_zeros_quantized": decode_i2_to_f16_scale_zeros_quantized, + "i4_to_f16_scale_zeros_quantized_offset": decode_i4_to_f16_scale_zeros_quantized_offset, + "i1_to_i8": decode_i1s_to_i8s, + "i2_to_i8": decode_i2s_to_i8s, + "i4_to_i8": decode_i4s_to_i8s, + "i2_to_i4": decode_i2s_to_i4s, + } + key = f"i{source_bit}_to_{target_dtype}" + if with_scaling: + key += "_scale" + if with_zeros: + key += f"_zeros_{zeros_mode}" + + is_ladder_stage3 = (storage_scope == "warp") and with_scaling + if is_ladder_stage3: + key += "_offset" + + if out_dtype == "float16": + d4f = "f16" + elif out_dtype == "int8": + d4f = "i8s" + elif out_dtype == "int4": + d4f = "i4s" + else: + raise ValueError("Unsupported target dtype: {}".format(target_dtype)) + source_symbol = "u" if source_format == "uint" else "s" + func_name = "decode_i{}{}_to_{}".format(source_bit, source_symbol, d4f) + if with_scaling: + func_name += "_scale" + if with_zeros: + func_name += f"_zeros_{zeros_mode}" + if is_ladder_stage3: + func_name += "_offset" + + return { + "func_name": func_name, + "c_source": import_c_map[key], + } diff --git a/tilelang/quantize/quantization.py b/tilelang/quantize/quantization.py new file mode 100644 index 000000000..4cc931c46 --- /dev/null +++ b/tilelang/quantize/quantization.py @@ -0,0 +1,234 @@ +# Copyright 2018 The apache/tvm Authors. All Rights Reserved. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# Copyright (c) Tile-AI Corporation. +# Licensed under the MIT License. 
+# The code below is mostly copied from mlc.ai quantization.py in mlc-llm. +# pylint: disable=invalid-name,missing-function-docstring,unused-variable +"""TIR computation utilities for quantization.""" + +from tilelang import tvm as tvm +from tvm import tir + + +# fmt: off +def _tir_f32x2_to_bf16x2_to_u32(v0: tir.PrimExpr, v1: tir.PrimExpr, round_to_even: bool = True): + mask = tir.const((1 << 16) - 1, "uint32") + res = [] + for data in [v0, v1]: + u32_val = tir.reinterpret("uint32", data) + if round_to_even: + rounding_bias = ((u32_val >> tir.const(16, "uint32")) + & tir.const(1, "uint32")) + tir.const(0x7FFF, "uint32") + u32_val += rounding_bias + res.append((u32_val >> tir.const(16, "uint32")) & mask) + return res[0] | (res[1] << tir.const(16, "uint32")) + + +def _tir_u32_to_bf16x2_to_f32x2(x: tir.PrimExpr): + mask = tir.const((1 << 16) - 1, "uint32") + x0 = x & mask + x1 = (x >> 16) & mask + return (tir.reinterpret("float32", x << tir.const(16, "uint32")) for x in [x0, x1]) + + +def _tir_u32_to_int_to_float(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): + assert val.dtype == "uint32" + mask = tvm.tir.const((1 << nbit) - 1, "uint32") + return tir.Cast(dtype, (val >> (pos * nbit).astype("uint32")) & mask) + + +def _tir_packed_uint_to_uint_to_float(storage_nbit: int): + storage_dtype = "uint" + str(storage_nbit) + + def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): + assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}" + max_int_value = (1 << (nbit - 1)) - 1 + return ((val >> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & tir.const( + (1 << nbit) - 1, "uint32")).astype(dtype) - tir.const(max_int_value, dtype) + + return f_convert + + +def _tir_packed_int_to_int_to_float(storage_nbit: int): + storage_dtype = "int" + str(storage_nbit) + + def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): + assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}" + mask = tir.const((1 << nbit) - 1, "int32") + unextended = (val >> (pos.astype("int32") * tir.const(nbit, "int32"))) & mask + return tir.Cast( + dtype, (unextended << tir.const(32 - nbit, "int32")) >> tir.const(32 - nbit, "int32")) + + return f_convert + + +def _tir_f32_to_uint_to_f4(val: tir.PrimExpr): + assert val.dtype == "float32" + val_u32 = tir.reinterpret("uint32", val) + # e_f32 > 120 -> e_f4 = min(e_f32 - 120 + M_h, 7) + # e_f32 == 120 -> e_f4 = 1 + # e_f32 < 120 -> e_f4 = 0 + m_h = (val_u32 >> tir.const(22, "uint32")) & tir.const(1, "uint32") + e_f32 = (val_u32 >> tir.const(23, "uint32")) & tir.const(255, "uint32") + s = (val_u32 >> tir.const(31, "uint32")) + e_f4 = tir.Select( + e_f32 > tir.const(120, "uint32"), + tir.Min(e_f32 - tir.const(120, "uint32") + m_h, tir.const(7, "uint32")), + tir.Select(e_f32 == tir.const(120, "uint32"), tir.const(1, "uint32"), + tir.const(0, "uint32"))) + return (s << tir.const(3, "uint32")) | e_f4 + + +def _tir_f16_to_uint_to_f4(val: tir.PrimExpr): + assert val.dtype == "float16" + val_u32 = tir.Cast("uint32", tir.reinterpret("uint16", val)) + m_h = (val_u32 >> tir.const(9, "uint32")) & tir.const(1, "uint32") + e_f16 = (val_u32 >> tir.const(10, "uint32")) & tir.const(31, "uint32") + s = (val_u32 >> tir.const(15, "uint32")) + e_f4 = tir.Select( + e_f16 > tir.const(8, "uint32"), + tir.Min(e_f16 - tir.const(8, "uint32") + m_h, tir.const(7, "uint32")), + tir.Select(e_f16 == tir.const(8, "uint32"), tir.const(1, "uint32"), tir.const(0, "uint32"))) + return (s << tir.const(3, "uint32")) | e_f4 + + +def 
_tir_u32_to_f4_to_f32(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): + assert nbit == 4 + assert dtype == "float32" + assert val.dtype == "uint32" + # e_f4 == 0 -> e_f32 = 0 + # e_f4 != 0 -> e_f32 = e_f4 + 120 = e_f4 | (1111000)_2 + mask = tvm.tir.const((1 << nbit) - 1, "uint32") + f4 = (val >> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & mask + s = f4 >> tir.const(3, "uint32") + e_f4 = f4 & tir.const(7, "uint32") + e_f32 = e_f4 | tir.const(120, "uint32") + val_f32 = tir.reinterpret("float32", + (e_f32 | (s << tir.const(8, "uint32"))) << tir.const(23, "uint32")) + return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float32"), val_f32) + + +def _tir_packed_to_fp4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): + assert nbit == 4 + assert dtype == "float16" + assert val.dtype == "uint32" + # e_f4 == 0 -> e_f16 = 0 + # e_f4 != 0 -> e_f16 = e_f4 + 8 = e_f4 | (1000)_2 + mask = tvm.tir.const((1 << nbit) - 1, "uint16") + f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask + s = f4 >> tir.const(3, "uint16") + e_f4 = f4 & tir.const(7, "uint16") + e_f16 = e_f4 | tir.const(8, "uint16") + val_f16 = tir.reinterpret("float16", + ((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16")).astype("uint16")) + return tir.Select(e_f4 == tir.const(0, "uint16"), tir.const(0, "float16"), val_f16) + +def _tir_packed_to_fp4_to_f16(storage_type="uint", storage_nbit=8): + storage_dtype = storage_type + str(storage_nbit) + + def f_convert(nbit: int, val: tvm.tir.PrimExpr, pos: tvm.tir.PrimExpr, dtype: str): + assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}" + mask = tvm.tir.const((1 << nbit) - 1, storage_dtype) + f4 = ((val >> (pos * nbit).astype(storage_dtype)) & mask).astype(storage_dtype) + f4 = (val >> (pos.astype(storage_dtype) * tir.const(nbit, storage_dtype))) & mask + s = f4 >> tir.const(3, storage_dtype) + e_f4 = f4 & tir.const(7, storage_dtype) + e_f16 = e_f4 | tir.const(8, storage_dtype) + val_f16 = tir.reinterpret("float16", + ((e_f16 | (s << tir.const(5, storage_dtype))) << tir.const(10, storage_dtype)).astype("uint16")) + return tir.Select(e_f4 == tir.const(0, storage_dtype), tir.const(0, "float16"), val_f16) + + return f_convert + +def _tir_u8_to_f8_e4m3_to_f16_naive(nbit: int, val: tir.PrimExpr, dtype: str): + assert nbit == 8 + assert dtype == "float16" + s_f16 = (val >> tir.const(7, "uint16")) << tir.const(15, "uint16") + e4 = val & tir.const(0x40, "uint16") + prefix = tir.Select(e4 == tir.const(0, "uint16"), tir.const(0x2000, "uint16"), + tir.const(0x4000, "uint16")) + e_f16 = (((val & tir.const(63, "uint16")) << tir.const(7, "uint16"))) | prefix + return tir.reinterpret("float16", s_f16 | e_f16) + + +def _tir_u8_to_f8_e4m3_to_f16(nbit: int, val: tir.PrimExpr, dtype: str): + assert nbit == 8 + assert dtype == "float16" + s_f16 = (val >> tir.const(7, "uint16")) << tir.const(15, "uint16") + e4 = val & tir.const(0x40, "uint16") + e_f16 = (((val & tir.const(63, "uint16")) << tir.const(7, "uint16"))) | (e4 << tir.const(8, "uint16")) | (e4 << tir.const(7, "uint16")) + e_f16 = e_f16 ^ tir.const(0x2000, "uint16") + return tir.reinterpret("float16", s_f16 | e_f16) + + +def _tir_u8_to_f8_e5m2_to_f16(nbit: int, val: tir.PrimExpr, dtype: str): + assert nbit == 8 + assert dtype == "float16" + return tir.reinterpret("e5m2_float8", val).astype("float16") + + +def _tir_packed_to_signed_convert(storage_type="uint", storage_nbit=8): + storage_dtype = storage_type + str(storage_nbit) + + def f_convert(nbit: 
int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): + assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}" + max_int_value = (1 << (nbit - 1)) + return ((val >> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & tir.const( + (1 << nbit) - 1, "uint32")).astype(dtype) - tir.const(max_int_value, dtype) + + return f_convert + + +def _tir_packed_to_unsigned_convert(storage_type="uint", storage_nbit=8): + storage_dtype = storage_type + str(storage_nbit) + + def f_convert(nbit: int, val: tvm.tir.PrimExpr, pos: tvm.tir.PrimExpr, dtype: str): + assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}" + mask = tvm.tir.const((1 << nbit) - 1, storage_dtype) + return ((val >> (pos * nbit).astype(storage_dtype)) & mask).astype(dtype) + + return f_convert + + +def _tir_packed_to_unsigned_convert_with_zeros(storage_type="uint", storage_nbit=8): + storage_dtype = storage_type + str(storage_nbit) + + def f_convert(nbit: int, val: tvm.tir.PrimExpr, pos: tvm.tir.PrimExpr, zero: tvm.tir.PrimExpr, + dtype: str): + assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}" + mask = tvm.tir.const((1 << nbit) - 1, storage_dtype) + return (((val >> (pos * nbit).astype(storage_dtype)) & mask) - zero).astype(dtype) + + return f_convert + + +def _tir_packed_int_to_int_convert(storage_type="uint", storage_nbit=8): + storage_dtype = storage_type + str(storage_nbit) + + def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): + assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}" + mask = tir.const((1 << nbit) - 1, "int32") + unextended = (val >> (pos.astype("int32") * tir.const(nbit, "int32"))) & mask + return tir.Cast( + dtype, (unextended << tir.const(32 - nbit, "int32")) >> tir.const(32 - nbit, "int32")) + + return f_convert + + +# fmt: on diff --git a/tilelang/quantize/utils.py b/tilelang/quantize/utils.py new file mode 100644 index 000000000..29f800cfe --- /dev/null +++ b/tilelang/quantize/utils.py @@ -0,0 +1,126 @@ +# Copyright (c) Tile-AI Corporation. +# Licensed under the MIT License. + + +def gen_quant4(k, n, groupsize=-1): + import torch + import torch.nn as nn + maxq = 2**4 + w = torch.randn((k, n), dtype=torch.half, device="cpu") + + original_w = w.clone() + + if groupsize == -1: + groupsize = k + + if groupsize != -1: + w = w.reshape((-1, groupsize, n)) + w = w.permute(1, 0, 2) + w = w.reshape((groupsize, -1)) + + s = torch.max(torch.abs(w), 0, keepdim=True)[0] + s *= 2 / maxq + + # Quantize. + w = torch.round(w / s).int() + + # Unsigned storage. + w += (maxq) // 2 + + w = torch.clamp(w, 0, maxq) + + # Dequantize. 
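+    # w currently holds unsigned codes offset by maxq // 2; subtracting the offset and
+    # multiplying by s reconstructs the fp16 reference weights.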
+ ref = (w - (maxq) // 2).half() * s + + if groupsize != -1: + + def reshape(w): + w = w.reshape((groupsize, -1, n)) + w = w.permute(1, 0, 2) + w = w.reshape((k, n)).contiguous() + return w + + ref = reshape(ref) + w = reshape(w) + + s = s.reshape((-1, n)).contiguous() + linear = nn.Linear(k, n, bias=False) + linear.weight.data = ref.t() + + return original_w, linear, s, (w - (maxq) // 2) + + +def general_compress(lowprecision_weight, source_bits=4, storage_dtype=None): + import torch + if storage_dtype is None: + storage_dtype = torch.int8 + elems_per_byte = 8 // source_bits + if lowprecision_weight.dtype == torch.float16: + lowprecision_weight = lowprecision_weight.to(torch.int8) + int8_weight = torch.zeros( + (*lowprecision_weight.shape[:-1], lowprecision_weight.shape[-1] // elems_per_byte), + dtype=torch.int8, + device=lowprecision_weight.device) + for j in range(lowprecision_weight.shape[-1] // elems_per_byte): + for k in range(elems_per_byte): + int8_weight[..., j] |= (lowprecision_weight[..., j * elems_per_byte + k] << + (source_bits * k)).to(torch.int8) + + return int8_weight.to(storage_dtype) + + +# interleave weight numpy implementation +def interleave_weight(qweight, nbits=4, target_dtype="float16"): + """Interleave the weight to the target data type. + + Args: + qweight (_type_): _description_ + nbits (int, optional): _description_. Defaults to 4. + target_dtype (str, optional): _description_. Defaults to "float16". + + Returns: + _type_: _description_ + + Example: + qweight = torch.randint(0, 127, (10, 10), dtype=torch.int8).cuda() + interleave_weight(qweight, 4, "float16") + """ + import torch + assert target_dtype in ["float16", "int8"] + # reinterpret the data type of qweight to int32 + qweight = qweight.view(torch.int32) + new_qweight = torch.zeros_like(qweight) + bits_stride = 8 if target_dtype == "int8" else 16 + mask = (1 << nbits) - 1 # for 4bit the val is 0x0000000f + num_groups = 32 // bits_stride + elems_per_group = bits_stride // nbits + for i in range(num_groups): + for j in range(elems_per_group): + offset = i * elems_per_group + j + shift = (offset % num_groups) * bits_stride + (offset // num_groups) * nbits + new_qweight |= ((qweight >> (nbits * offset)) & mask) << shift + + if nbits == 1 and target_dtype == "int8": + # special handling for 1b interleave + n16_weight = new_qweight & torch.int32(0xF0F00F0F) + n16_weight |= ((new_qweight & torch.int32(0x000000F0)) >> 4) << 16 + n16_weight |= ((new_qweight & torch.int32(0x0000F000)) >> 12) << 24 + n16_weight |= ((new_qweight & torch.int32(0x000F0000)) >> 16) << 4 + n16_weight |= ((new_qweight & torch.int32(0x0F000000)) >> 24) << 12 + return n16_weight.view(torch.int8) + elif nbits == 2 and target_dtype == "float16": + n8_weight = new_qweight & torch.int32(0xFF0000FF) + n8_weight |= ((new_qweight & torch.int32(0x0000FF00)) >> 8) << 16 + n8_weight |= ((new_qweight & torch.int32(0x00FF0000)) >> 16) << 8 + return n8_weight.view(torch.int8) + elif nbits == 1 and target_dtype == "float16": + n8_weight = new_qweight & torch.int32(0xF000000F) + n8_weight |= ((new_qweight & torch.int32(0x000000F0)) >> 4) << 8 + n8_weight |= ((new_qweight & torch.int32(0x00000F00)) >> 8) << 16 + n8_weight |= ((new_qweight & torch.int32(0x0000F000)) >> 12) << 24 + n8_weight |= ((new_qweight & torch.int32(0x000F0000)) >> 16) << 4 + n8_weight |= ((new_qweight & torch.int32(0x00F00000)) >> 20) << 12 + n8_weight |= ((new_qweight & torch.int32(0x0F000000)) >> 24) << 20 + return n8_weight.view(torch.int8) + + return new_qweight.view(torch.int8)
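
For orientation, a minimal sketch of how these host-side helpers compose with the LOP3 decode intrinsics defined earlier. This is illustrative only: the import paths, tensor shapes, and the chosen decode configuration are assumptions, not part of this change.

    import torch
    # import paths are assumptions, adjust to the actual package layout
    from tilelang.quantize.utils import general_compress, interleave_weight
    from tilelang.quantize.lop3 import get_lop3_intrin_group

    # unsigned 4-bit codes, one per int8 element
    qweight = torch.randint(0, 16, (1024, 1024), dtype=torch.int8)
    # pack two 4-bit codes into each int8 (the last dim shrinks by 8 // source_bits)
    packed = general_compress(qweight, source_bits=4, storage_dtype=torch.int8)
    # permute the packed nibbles into the layout expected by the f16 fast-decode path
    interleaved = interleave_weight(packed, nbits=4, target_dtype="float16")

    # pick the matching device-side decode function and its CUDA source
    group = get_lop3_intrin_group(
        out_dtype="float16",
        source_format="uint",
        source_bit=4,
        with_scaling=True,
        with_zeros=True,
        zeros_mode="original",
    )
    # group["func_name"] -> "decode_i4u_to_f16_scale_zeros_original"
    # group["c_source"]  -> CUDA template string to splice into the generated kernel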