 from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
     GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
     GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
+from vllm.model_executor.layers.quantization.utils.allspark_utils import (
+    ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, ALLSPARK_SUPPORTED_QUANT_TYPES)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
     MARLIN_SUPPORTED_GROUP_SIZES, query_marlin_supported_quant_types)

 from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
     marlin_24_quantize)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    gptq_pack, gptq_quantize_weights, sort_weights)
+    gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights)
 from vllm.scalar_type import ScalarType
 from vllm.utils import FlexibleArgumentParser

 DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
-DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
+DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]

 ACT_ORDER_OPTS = [False, True]
 K_FULL_OPTS = [False, True]
@@ -81,6 +83,27 @@ def bench_run(results: List[benchmark.Measurement], model: str,
                                           GPTQ_MARLIN_24_MAX_PARALLEL)
     marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int)

+    # AllSpark W8A16 quant
+    as_supported_case = (quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES
+                         and group_size == -1 and not act_order and is_k_full)
+    if as_supported_case:
+        properties = torch.cuda.get_device_properties(b.device.index)
+        sm_count = properties.multi_processor_count
+        sm_version = properties.major * 10 + properties.minor
+
+        supported_arch = (sm_version >= 80 and sm_version < 90)
+        as_supported_case = as_supported_case and supported_arch
+        if supported_arch:
+            has_zp = False
+            w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size,
+                                                has_zp)
+            qw = qw.to(torch.uint8)
+
+            qw_reorder, s_reorder, zp_reorder = \
+                ops.allspark_repack_weight(
+                qw, s, zp, has_zp)
+            CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
+
     globals = {
         # Gen params
         "quant_type": quant_type,
@@ -109,10 +132,19 @@ def bench_run(results: List[benchmark.Measurement], model: str,
         # GPTQ params
         "q_w_gptq": q_w_gptq,
         "repack_sort_indices": repack_sort_indices,
+        # AllSpark W8A16 params
+        "qw_reorder": qw_reorder if as_supported_case else None,
+        "s_reorder": s_reorder if as_supported_case else None,
+        "zp_reorder": zp_reorder if as_supported_case else None,
+        "sm_count": sm_count if as_supported_case else None,
+        "sm_version": sm_version if as_supported_case else None,
+        "CUBLAS_M_THRESHOLD":
+        CUBLAS_M_THRESHOLD if as_supported_case else None,
         # Kernels
         "gptq_marlin_gemm": ops.gptq_marlin_gemm,
         "gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
         "gptq_marlin_repack": ops.gptq_marlin_repack,
+        "allspark_w8a16_gemm": ops.allspark_w8a16_gemm,
     }

     min_run_time = 1
@@ -172,6 +204,17 @@ def bench_run(results: List[benchmark.Measurement], model: str,
             description="gptq_marlin_repack",
         ).blocked_autorange(min_run_time=min_run_time))

+    if as_supported_case:
+        results.append(
+            benchmark.Timer(
+                stmt=
+                "output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)",  # noqa: E501
+                globals=globals,
+                label=label,
+                sub_label=sub_label,
+                description="allspark_w8a16_gemm_fp32",
+            ).blocked_autorange(min_run_time=min_run_time))
+

 def main(args):
     print("Benchmarking models:")
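
Note on the Ampere-only gate in the second hunk: `sm_version = major * 10 + minor` together with the `sm_version >= 80 and sm_version < 90` check restricts the AllSpark W8A16 path to SM 8.x GPUs. A minimal standalone sketch of the same check, using only the public PyTorch device API (the helper name below is illustrative, not part of this PR):

    import torch

    def is_ampere_class(device: int = 0) -> bool:
        # Mirrors the gate in the diff: AllSpark W8A16 is only benchmarked
        # on SM 8.x (Ampere) devices, i.e. 80 <= sm_version < 90.
        major, minor = torch.cuda.get_device_capability(device)
        sm_version = major * 10 + minor
        return 80 <= sm_version < 90

On an A100 (SM 80) or an RTX 30xx card (SM 86) this returns True; on Hopper (SM 90) it returns False, and the benchmark then runs only the Marlin paths, matching the `supported_arch` handling above.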