```diff
@@ -10,6 +10,8 @@
 from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
     GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
     GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
+from vllm.model_executor.layers.quantization.utils.allspark_utils import (
+    ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, ALLSPARK_SUPPORTED_QUANT_TYPES)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
     MARLIN_SUPPORTED_GROUP_SIZES, query_marlin_supported_quant_types)
@@ -18,12 +20,12 @@
 from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
     marlin_24_quantize)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    gptq_pack, gptq_quantize_weights, sort_weights)
+    gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights)
 from vllm.scalar_type import ScalarType
 from vllm.utils import FlexibleArgumentParser
 
 DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
-DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
+DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
 
 ACT_ORDER_OPTS = [False, True]
 K_FULL_OPTS = [False, True]
@@ -81,6 +83,27 @@ def bench_run(results: List[benchmark.Measurement], model: str,
                                           GPTQ_MARLIN_24_MAX_PARALLEL)
     marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int)
 
+    # AllSpark W8A16 quant
+    as_supported_case = (quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES
+                         and group_size == -1 and not act_order and is_k_full)
+    if as_supported_case:
+        properties = torch.cuda.get_device_properties(b.device.index)
+        sm_count = properties.multi_processor_count
+        sm_version = properties.major * 10 + properties.minor
+
+        supported_arch = (sm_version >= 80 and sm_version < 90)
+        as_supported_case = as_supported_case and supported_arch
+        if supported_arch:
+            has_zp = False
+            w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size,
+                                                has_zp)
+            qw = qw.to(torch.uint8)
+
+            qw_reorder, s_reorder, zp_reorder = \
+                ops.allspark_repack_weight(
+                qw, s, zp, has_zp)
+            CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
+
     globals = {
         # Gen params
         "quant_type": quant_type,
@@ -109,10 +132,19 @@ def bench_run(results: List[benchmark.Measurement], model: str,
         # GPTQ params
         "q_w_gptq": q_w_gptq,
         "repack_sort_indices": repack_sort_indices,
+        # AllSpark W8A16 params
+        "qw_reorder": qw_reorder if as_supported_case else None,
+        "s_reorder": s_reorder if as_supported_case else None,
+        "zp_reorder": zp_reorder if as_supported_case else None,
+        "sm_count": sm_count if as_supported_case else None,
+        "sm_version": sm_version if as_supported_case else None,
+        "CUBLAS_M_THRESHOLD":
+        CUBLAS_M_THRESHOLD if as_supported_case else None,
         # Kernels
         "gptq_marlin_gemm": ops.gptq_marlin_gemm,
         "gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
         "gptq_marlin_repack": ops.gptq_marlin_repack,
+        "allspark_w8a16_gemm": ops.allspark_w8a16_gemm,
     }
 
     min_run_time = 1
@@ -172,6 +204,17 @@ def bench_run(results: List[benchmark.Measurement], model: str,
             description="gptq_marlin_repack",
         ).blocked_autorange(min_run_time=min_run_time))
 
+    if as_supported_case:
+        results.append(
+            benchmark.Timer(
+                stmt=
+                "output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)",  # noqa: E501
+                globals=globals,
+                label=label,
+                sub_label=sub_label,
+                description="allspark_w8a16_gemm_fp32",
+            ).blocked_autorange(min_run_time=min_run_time))
+
 
 def main(args):
     print("Benchmarking models:")
```
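For context on what the new branch measures: it quantizes the FP16 weight to 8-bit with no zero point (`quantize_weights`), repacks it into AllSpark's layout (`ops.allspark_repack_weight`), and then times `ops.allspark_w8a16_gemm` on Ampere (SM 8.x) GPUs. The sketch below retraces that flow outside the Timer harness. It is only a minimal illustration: the tensor shapes and the `ALLSPARK_SUPPORTED_QUANT_TYPES[0]` pick are assumptions, and the repack/GEMM calls simply mirror the argument order used in this diff (including the two trailing booleans, whose meaning the diff does not document).

```python
import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
    ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, ALLSPARK_SUPPORTED_QUANT_TYPES)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    quantize_weights)

# Illustrative problem sizes; the benchmark derives these from the model layers.
size_m, size_k, size_n = 16, 4096, 4096
group_size = -1  # per-channel scales, the only grouping the AllSpark path accepts
quant_type = ALLSPARK_SUPPORTED_QUANT_TYPES[0]  # assumes an indexable collection

a = torch.randn((size_m, size_k), dtype=torch.half, device="cuda")
b = torch.rand((size_k, size_n), dtype=torch.half, device="cuda")

props = torch.cuda.get_device_properties(a.device.index)
sm_count = props.multi_processor_count
sm_version = props.major * 10 + props.minor
assert 80 <= sm_version < 90, "the benchmarked AllSpark path targets Ampere only"

# Quantize to 8-bit without zero points, then repack into AllSpark's weight layout.
has_zp = False
w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size, has_zp)
qw = qw.to(torch.uint8)
qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight(qw, s, zp, has_zp)

# Same call the Timer stmt above executes; the threshold presumably decides when
# large-M GEMMs are handed to a cuBLAS-backed path (inferred from the constant's name).
output = ops.allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n,
                                 group_size, sm_count, sm_version,
                                 ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, False, True)
print(output.shape)  # expected (size_m, size_n)
```

The enlarged `DEFAULT_BATCH_SIZES` (now up to 8192) presumably sweep M across the `CUBLAS_M_THRESHOLD` crossover region, though the diff itself does not state that motivation.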