|
| 1 | +import pickle as pkl |
| 2 | +import time |
| 3 | +from dataclasses import dataclass |
| 4 | +from itertools import product |
| 5 | +from typing import Callable, Iterable, List, Optional |
| 6 | + |
| 7 | +import torch |
| 8 | +import torch.utils.benchmark as TBenchmark |
| 9 | +from torch.utils.benchmark import Measurement as TMeasurement |
| 10 | +from tqdm import tqdm |
| 11 | + |
| 12 | +import vllm._custom_ops as ops |
| 13 | +from vllm.model_executor.layers.layernorm import RMSNorm |
| 14 | + |
| 15 | + |
@dataclass
class bench_params_t:
    """One benchmark configuration.

    Fields, in constructor order: ``num_tokens``, ``hidden_size``,
    ``add_residual`` and ``dtype``.
    """
    num_tokens: int
    hidden_size: int
    add_residual: bool
    dtype: torch.dtype

    def description(self):
        """Return a one-line summary of this configuration as a string."""
        parts = (
            f'N {self.num_tokens}',
            f'D {self.hidden_size}',
            f'R {self.add_residual}',
            f'DT {self.dtype}',
        )
        return ' x '.join(parts)
| 28 | + |
| 29 | + |
def get_bench_params() -> List[bench_params_t]:
    """Build the full sweep of benchmark configurations.

    Returns:
        One ``bench_params_t`` per element of the cartesian product of the
        token counts, hidden sizes, residual flags and dtypes defined below,
        in ``itertools.product`` order.
    """
    ## Test Fixtures
    NUM_TOKENS = [2**x for x in range(11)]
    # ``range`` excludes its stop value, so the stop must be 8192 + 1 for the
    # 8192 hidden size to be included. The previous stop of 8129 ended the
    # sweep at 7168.
    HIDDEN_SIZES = list(range(1024, 8193, 1024))
    ADD_RESIDUAL = [True, False]
    DTYPES = [torch.bfloat16, torch.float]

    return [
        bench_params_t(*combo)
        for combo in product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES)
    ]
| 41 | + |
| 42 | + |
| 43 | +# Reference impls |
def unfused_int8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor,
                      residual: Optional[torch.Tensor],
                      quant_dtype: torch.dtype):
    """Run RMSNorm and int8 quantization as two separate calls.

    ``quant_dtype`` is accepted so the signature matches the other
    implementations but is not used here. Nothing is returned; the function
    is only run for timing.
    """
    # Norm
    normed = rms_norm_layer.forward_cuda(x, residual)
    if residual is not None:
        # With a residual the layer result is a pair; only its first element
        # is passed on to the quantization step.
        normed, _ = normed

    # Quant
    quantized, _, _ = ops.scaled_int8_quant(normed)
| 56 | + |
| 57 | + |
def unfused_fp8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor,
                     residual: Optional[torch.Tensor],
                     quant_dtype: torch.dtype):
    """Run RMSNorm and fp8 quantization as two separate calls.

    ``quant_dtype`` is accepted so the signature matches the other
    implementations but is not used here. Nothing is returned; the function
    is only run for timing.
    """
    # Norm
    normed = rms_norm_layer.forward_cuda(x, residual)
    if residual is not None:
        # With a residual the layer result is a pair; only its first element
        # is passed on to the quantization step.
        normed, _ = normed

    # Quant
    quantized, _ = ops.scaled_fp8_quant(normed)
| 70 | + |
| 71 | + |
def fused_impl(
        rms_norm_layer: RMSNorm,  # this stores the weights
        x: torch.Tensor,
        residual: Optional[torch.Tensor],
        quant_dtype: torch.dtype):
    """Run the fused RMSNorm + dynamic per-token quantization op.

    Only ``rms_norm_layer.weight`` is read from the layer; the epsilon is the
    fixed value 1e-6. Nothing is returned; the function is only run for
    timing.
    """
    epsilon = 1e-6
    weight = rms_norm_layer.weight
    fused_result = ops.rms_norm_dynamic_per_token_quant(x,
                                                        weight,
                                                        epsilon,
                                                        quant_dtype,
                                                        residual=residual)
    out, _ = fused_result
| 82 | + |
| 83 | + |
| 84 | +# Bench functions |
def bench_fn(rms_norm_layer: RMSNorm, x: torch.Tensor,
             residual: Optional[torch.Tensor], quant_dtype: torch.dtype,
             label: str, sub_label: str, fn: Callable, description: str,
             *, min_run_time: float = 1) -> TMeasurement:
    """Time ``fn(rms_norm_layer, x, residual, quant_dtype)``.

    Args:
        rms_norm_layer: Layer object forwarded to ``fn``.
        x: Input tensor forwarded to ``fn``.
        residual: Residual tensor forwarded to ``fn``, or ``None`` when the
            configuration has no residual.
        quant_dtype: Quantized dtype forwarded to ``fn``.
        label: Label attached to the measurement.
        sub_label: Sub-label attached to the measurement.
        fn: Callable under test; invoked with the four values above.
        description: Description attached to the measurement.
        min_run_time: Value passed to ``blocked_autorange`` as its
            ``min_run_time``. Defaults to 1, the value that was previously
            hard-coded.

    Returns:
        The measurement produced by ``Timer.blocked_autorange``.
    """
    # Deliberately not called ``globals`` so the builtin is not shadowed.
    timer_globals = {
        "rms_norm_layer": rms_norm_layer,
        "x": x,
        "residual": residual,
        "quant_dtype": quant_dtype,
        "fn": fn,
    }
    timer = TBenchmark.Timer(
        stmt="fn(rms_norm_layer, x, residual, quant_dtype)",
        globals=timer_globals,
        label=label,
        sub_label=sub_label,
        description=description,
    )
    return timer.blocked_autorange(min_run_time=min_run_time)
| 105 | + |
def bench(params: bench_params_t, label: str, sub_label: str) \
        -> Iterable[TMeasurement]:
    """Benchmark every implementation for one configuration.

    Builds the layer and random inputs described by ``params``, then times
    the unfused int8, unfused fp8, fused int8 and fused fp8 variants in that
    order. The comparison table is printed before the measurements are
    returned as a list.
    """
    # Layer under test, cast to the benchmark dtype.
    layer = RMSNorm(params.hidden_size, 1e-6).to(dtype=params.dtype)
    # Make weights
    layer.weight.data.normal_(mean=1.0, std=0.1)

    # Make inputs: random values scaled by 1 / hidden_size.
    scale = 1 / params.hidden_size
    x = torch.randn(params.num_tokens,
                    params.hidden_size,
                    dtype=params.dtype,
                    device='cuda') * scale
    if params.add_residual:
        residual = (torch.randn_like(x) * scale).to(device='cuda')
    else:
        residual = None

    # (quant dtype, implementation, description), in the order they are timed.
    cases = (
        (torch.int8, unfused_int8_impl, "unfused_int8_impl"),
        (torch.float8_e4m3fn, unfused_fp8_impl, "unfused_fp8_impl"),
        (torch.int8, fused_impl, "fused_int8_impl"),
        (torch.float8_e4m3fn, fused_impl, "fused_fp8_impl"),
    )
    timers = [
        bench_fn(layer, x, residual, quant_dtype, label, sub_label, impl,
                 impl_description)
        for quant_dtype, impl, impl_description in cases
    ]

    print_timers(timers)

    return timers
| 147 | + |
| 148 | + |
| 149 | +# Result reporting |
| 150 | +# and benchmark entry point |
def print_timers(timers: Iterable[TMeasurement]):
    """Print a ``TBenchmark.Compare`` table for the given measurements."""
    TBenchmark.Compare(timers).print()
| 154 | + |
| 155 | + |
def main():
    """Run the whole sweep, print the combined table and pickle the results.

    The measurements are written to ``rms_norm_dpt_quant-<timestamp>.pkl`` in
    the current working directory.
    """
    torch.set_default_device('cuda')

    all_timers = []
    for config in tqdm(get_bench_params()):
        all_timers += bench(config, "rms-norm-dynamic-per-token-quant",
                            config.description())
    print_timers(all_timers)

    # pickle all the results
    out_path = f"rms_norm_dpt_quant-{int(time.time())}.pkl"
    with open(out_path, "wb") as out_file:
        pkl.dump(all_timers, out_file)
| 170 | + |
| 171 | + |
# Run the benchmark only when this file is executed as a script.
if __name__ == "__main__":
    main()
0 commit comments