From 67ff543f7ddf0641d799c7889db1e1b8aec03c67 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Sun, 7 Jul 2024 19:51:31 +0000 Subject: [PATCH 01/33] add device_spec --- test/profiler/test_device_spec.py | 68 +++++ torchao/profiler/device_spec.py | 417 ++++++++++++++++++++++++++++++ 2 files changed, 485 insertions(+) create mode 100644 test/profiler/test_device_spec.py create mode 100644 torchao/profiler/device_spec.py diff --git a/test/profiler/test_device_spec.py b/test/profiler/test_device_spec.py new file mode 100644 index 000000000..d76fdd3a2 --- /dev/null +++ b/test/profiler/test_device_spec.py @@ -0,0 +1,68 @@ + +import pytest + +cuda_driver = pytest.importorskip("triton.runtime.driver", reason="requires triton cuda driver module") +import itertools +from contextlib import contextmanager +from unittest.mock import patch + +import torch + +from torchao.profiler.device_spec import ( + _AVAILABLE_GPU_SPECS, + CUDADeviceSpec, + get_chip_name, +) + +# -------------------- Device Spec Tests ------------------- # +DEVICE_NAMES = ["h100 sxm", "a100", "nvidia geforce rtx 4090"] +DTYPES = [torch.float32, torch.bfloat16, torch.float16] +USE_TENSORCORES = [True, False] +DEVICE_CONFIGS = itertools.product(DEVICE_NAMES, DTYPES, USE_TENSORCORES) + + +@contextmanager +def patch_device(device_name): + with patch("torch.cuda.get_device_name", return_value=device_name): + yield +@pytest.mark.parametrize("device_name, dtype, use_tensorcores", DEVICE_CONFIGS, ids=lambda x: str(x)) +def test_device_spec(device_name, dtype, use_tensorcores): + with patch_device(device_name): + device_spec = CUDADeviceSpec(dtype=dtype, use_tensorcores=use_tensorcores) + if dtype == torch.float32 and use_tensorcores: + dtype = "tfloat32" + chip_name = get_chip_name(device_name) + expected_flops = _AVAILABLE_GPU_SPECS[chip_name][dtype] + assert device_spec.flop_per_s == expected_flops + assert device_spec.flops_by_dtype[dtype] == expected_flops + assert device_spec.roofline_balancepoint == expected_flops / device_spec.bandwidth + + with pytest.raises(AssertionError): + device_spec.flop_per_s = None + print(device_spec.roofline_balancepoint) + # Prevent setting attributes not in named fields to guard against user error + with pytest.raises(AttributeError): + device_spec.FLOPs = None + +def test_empty_device_spec(): + device_name = "fake device" + with patch_device(device_name): + with pytest.raises(AssertionError): + device_spec = CUDADeviceSpec() + + # Ok to instantiate as long as fields are filled + device_spec = CUDADeviceSpec(name=device_name, + flop_per_s=1.0, + bandwidth=1.0, + dtype=torch.float32, + use_tensorcores=True) + device_name = DEVICE_NAMES[0] + + with patch_device(device_name): + # All critical fields will be auto-filled except for dtype (and vram, but vram is not used for downstream calcs atm) + device_spec = CUDADeviceSpec(dtype=torch.float32) + + # No dtype specified + with pytest.raises(AssertionError): + device_spec = CUDADeviceSpec() + diff --git a/torchao/profiler/device_spec.py b/torchao/profiler/device_spec.py new file mode 100644 index 000000000..6d8851d49 --- /dev/null +++ b/torchao/profiler/device_spec.py @@ -0,0 +1,417 @@ +import logging +from collections import defaultdict +from copy import copy +from dataclasses import dataclass, field, fields +from typing import Dict, Optional, Union + +import torch + +logger = logging.getLogger(__name__) + +# Copied from https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/fabric/utilities/throughput.py +_AVAILABLE_GPU_SPECS: Dict[str, 
Dict[Union[str, torch.dtype], float]] = { + # Hopper + # source: https://resources.nvidia.com/en-us-tensor-core + "h100 nvl": { + torch.float64: 67e12, + torch.float32: 133.8e12, + "tfloat32": 989.4e12, + torch.bfloat16: 1978.8e12, + torch.float16: 1978.8e12, + torch.int8: 3957.8e12, + }, + "h100 sxm": { + torch.float64: 33.5e12, + torch.float32: 66.9e12, + "tfloat32": 494.7e12, + torch.bfloat16: 989.4e12, + torch.float16: 989.4e12, + torch.int8: 1978.9e12, + }, + "h100 pcie": { + torch.float64: 25.6e12, + torch.float32: 51.2e12, + "tfloat32": 378e12, + torch.bfloat16: 756e12, + torch.float16: 756e12, + torch.int8: 1513e12, + }, + # Ada + # source: https://images.nvidia.com/aem-dam/Solutions/Data-Center/l4/nvidia-ada-gpu-architecture-whitepaper-v2.1.pdf + "rtx 4090": { + torch.float32: 82.6e12, + "tfloat32": 82.6e12, + torch.bfloat16: 82.6e12, + torch.float16: 82.6e12, + torch.int8: 660.6e12, + "int4": 1321.2e12, + }, + "rtx 4080": { + torch.float32: 48.7e12, + "tfloat32": 48.7e12, + torch.bfloat16: 48.7e12, + torch.float16: 48.7e12, + torch.int8: 389.9e12, + "int4": 779.8e12, + }, + "l4": { + torch.float32: 30.3e12, + "tfloat32": 60e12, + torch.bfloat16: 121e12, + torch.float16: 121e12, + torch.int8: 242e12, + "int4": 484e12, + }, + "l40": { + torch.float32: 90.5e12, + "tfloat32": 90.5e12, + torch.bfloat16: 181e12, + torch.float16: 181e12, + torch.int8: 362e12, + "int4": 724e12, + }, + # Ampere + # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf + # sxm and pcie have same flop counts + "a100": { + torch.float64: 9.7e12, + torch.float32: 19.5e12, + "tfloat32": 156e12, + torch.bfloat16: 312e12, + torch.float16: 312e12, + torch.int8: 624e12, + }, + "a6000": { + torch.float32: 38.7e12, + "tfloat32": 77.4e12, + torch.bfloat16: 38.7e12, + torch.float16: 38.7e12, + torch.int8: 309.7e12, + "int4": 619.3e12, + }, + "a40": { + torch.float32: 37.4e12, + "tfloat32": 74.8e12, + torch.bfloat16: 37.4e12, + torch.float16: 37.4e12, + torch.int8: 299.3e12, + "int4": 598.7e12, + }, + # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a10/pdf/a10-datasheet.pdf + "a10g": { + torch.float32: 31.2e12, + "tfloat32": 62.5e12, + torch.bfloat16: 125e12, + torch.float16: 125e12, + torch.int8: 250e12, + "int4": 500e12, + }, + "rtx 3090 ti": { + torch.float32: 40e12, + "tfloat32": 40e12, + torch.bfloat16: 40e12, + torch.float16: 40e12, + torch.int8: 320e12, + "int4": 640e12, + }, + "rtx 3090": { + torch.float32: 35.6e12, + "tfloat32": 35.6e12, + torch.bfloat16: 35.6e12, + torch.float16: 35.6e12, + torch.int8: 284e12, + "int4": 568e12, + }, + "rtx 3080 ti": { + torch.float32: 34.1e12, + "tfloat32": 34.1e12, + torch.bfloat16: 34.1e12, + torch.float16: 34.1e12, + torch.int8: 272.8e12, + "int4": 546.6e12, + }, + "rtx 3080": { + torch.float32: 29.8e12, + "tfloat32": 29.8e12, + torch.bfloat16: 29.8e12, + torch.float16: 29.8e12, + torch.int8: 238e12, + "int4": 476e12, + }, + "rtx 3070": { + torch.float32: 20.3e12, + "tfloat32": 20.3e12, + torch.bfloat16: 20.3e12, + torch.float16: 20.3e12, + torch.int8: 162.6e12, + "int4": 325.2e12, + }, + # Turing + # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/tesla-t4/t4-tensor-core-datasheet-951643.pdf + # sxm and pcie have same flop counts + "t4": { + torch.float32: 8.1e12, + torch.float16: 65e12, + torch.int8: 130e12, + "int4": 260e12, + }, + # 
https://www.nvidia.com/content/dam/en-zz/Solutions/design-visualization/quadro-product-literature/quadro-rtx-5000-data-sheet-us-nvidia-704120-r4-web.pdf + "quadro rtx 5000": { + torch.float32: 11.2e12, + torch.float16: 89.2e12, + }, + "rtx 2080 super": { + torch.float32: 11.2e12, + torch.float16: 22.3e12, + torch.int8: 178.4e12, + "int4": 356.8e12, + }, + "rtx 2080 ti": { + torch.float32: 14.2e12, + torch.float16: 28.5e12, + torch.int8: 227.7e12, + "int4": 455.4e12, + }, + "rtx 2080": { + torch.float32: 10.6e12, + torch.float16: 21.2e12, + torch.int8: 169.6e12, + "int4": 339.1e12, + }, + # https://www.nvidia.com/content/PDF/nvidia-ampere-ga-102-gpu-architecture-whitepaper-v2.pdf + "rtx 2070 super": { + torch.float32: 9.1e12, + torch.float16: 18.1e12, + torch.int8: 145e12, + "int4": 290e12, + }, + "titan rtx": { + torch.float32: 16.3e12, + torch.float16: 32.6e12, + torch.int8: 261e12, + "int4": 522e12, + }, + # Volta + # source: https://images.nvidia.com/content/technologies/volta/pdf/volta-v100-datasheet-update-us-1165301-r5.pdf + "v100 sxm": { + torch.float64: 7.8e12, + torch.float32: 15.7e12, + torch.float16: 125e12, + }, + "v100 pcie": { + torch.float64: 7e12, + torch.float32: 14e12, + torch.float16: 112e12, + }, + "v100s pcie": { + torch.float64: 8.2e12, + torch.float32: 16.4e12, + torch.float16: 130e12, + }, +} + + +# Adapted from https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/fabric/utilities/throughput.py +def get_chip_name(device: int = 0) -> str: + device_name = torch.cuda.get_device_name(device) + chip = device_name.lower() + + if "h100" in chip: + if "hbm3" in chip: + chip = "h100 sxm" + elif "nvl" in chip: + chip = "h100 nvl" + elif "pcie" in chip or "hbm2e" in chip: + chip = "h100 pcie" + elif "l4" in chip: + chip = "l40" if "tesla" in chip else "l4" + elif "geforce rtx" in chip: + number = chip.split(" ")[3] + extra = "" + if "super" in chip: + extra = " super" + elif "ti" in chip: + extra = " ti" + chip = f"rtx {number}{extra}" + elif "a6000" in chip: + chip = "a6000" + elif "a100" in chip: + chip = "a100" + elif "a40" in chip: + chip = "a40" + elif "a10g" in chip: + chip = "a10g" + elif "t4" in chip: + chip = "t4" + elif "quadro rtx 5000" in chip: + chip = "quadro rtx 5000" + elif "titan rtx" in chip: + chip = "titan rtx" + elif "v100-sxm" in chip: + chip = "v100 sxm" + elif "v100-pcie" in chip: + chip = "v100 pcie" + elif "v100s-pcie" in chip: + chip = "v100s pcie" + else: + chip = None + return chip + + +def get_vram(device: int = 0) -> int: + device_props = torch.cuda.get_device_properties(device) + return device_props.total_memory + + +def get_bandwidth(device: int = 0) -> int: + try: + from triton.testing import get_dram_gbps + + bandwidth = get_dram_gbps(device) * 1e9 + except ImportError: + print("Could not import triton to get DRAM Gbps. Please install triton") + bandwidth = None + return bandwidth + + +def get_flops_by_dtype(chip_name: str) -> dict[torch.dtype, float]: + return _AVAILABLE_GPU_SPECS.get(chip_name, None) + + # # Check for tfloat32 + # if ( + # dtype == torch.float32 + # and "tfloat32" in dtype_to_flops + # and torch.get_float32_matmul_precision() != "highest" + # ): + # logger.warning("Using tfloat32 tensorcores FLOPs") + # dtype = "tfloat32" + # if dtype not in dtype_to_flops: + # logger.warning(f"FLOPs not found for {dtype!r} on {chip!r}") + # return None + # return dtype_to_flops[dtype] + + +@dataclass +class DeviceSpec: + """ + Abstract device specs for theoretical peak performance calculations. 
+
+    Fields will be auto-populated in __post_init__ if not already specified
+    and if data is available
+    - bandwidth (bytes /s)
+    - flop_per_s (FLOP / s)
+    - vram (bytes)
+    - dtype (torch.dtype) dtype used for theoretical peak performance
+    - flops_by_dtype (dict[Union[torch.dtype, str], float]): mapping from dtype to FLOPs
+    """
+
+    device_type: str
+    name: Optional[str] = None
+    bandwidth: Optional[int] = None
+    flop_per_s: Optional[int] = None
+    vram: Optional[int] = None
+    dtype: Optional[torch.dtype] = None
+    flops_by_dtype: dict = field(default_factory=dict)
+
+    def _post_init_check(self):
+        assert self.bandwidth is not None, "GPU bandwidth is None - please specify the bandwidth in GB/s in order to enable speed of light calculations"
+        assert self.dtype is not None, "GPU dtype is None - please specify the dtype in order to enable speed of light calculations"
+        assert self.flop_per_s is not None, "GPU flop_per_s is None - please specify the flop_per_s in FLOP/s in order to enable speed of light calculations"
+        self.flops_by_dtype.update({self.dtype: self.flop_per_s})
+
+        # Not needed for downstream calculations atm, no need to assert
+        if self.vram is None:
+            print("GPU vram is None - please specify the vram in bytes")
+
+    def __setattr__(self, name, value):
+        # Check if the attribute is already defined
+        if name in {f.name for f in fields(self)}:
+            super().__setattr__(name, value)
+        else:
+            raise AttributeError(
+                f"Cannot add new attribute '{name}' to {self.__class__.__name__}"
+            )
+
+    def __str__(self):
+        if self.bandwidth is not None:
+            bw = round(self.bandwidth, 4)
+        if self.flop_per_s is not None:
+            tflops = round(self.flop_per_s / 1e12, 4)
+        if self.vram is not None:
+            vram_GB = round(self.vram / 1e9, 1)
+        return f"DeviceSpec(device_type={self.device_type}, name={self.name}, dtype={self.dtype}, bandwidth={bw}GB/s, flops={tflops}TFLOPs, vram={vram_GB}GB)"
+
+    @property
+    def roofline_balancepoint(self):
+        """
+        Arithmetic intensity (FLOP / byte) transition point from
+        memory-bound to compute-bound regime.
+
+        This is the ridgepoint of the roofline curve.
+ """ + assert ( + self.bandwidth is not None + ), "Please set bandwidth in order to calculate roofline balancepoint" + assert ( + self.flop_per_s is not None + ), "Please set flop_per_s in order to calculate roofline balancepoint" + + return self.flop_per_s / self.bandwidth + + +@dataclass +class CUDADeviceSpec(DeviceSpec): + """ + CUDA specs for theoretical peak performance + + Fields will be auto-populated in __post_init__ if not already specified + and if data is available + + See DeviceSpec for a list of available fields + See AVAILABLE_GPU_SPECS for a list of available chips + """ + + device_type: str = "cuda" + # Device index + device: Optional[int] = 0 + # Whether to use tfloat32 FLOPs for dtype == torch.float32 + # We assume that tensorcores will always be used for fp16, int8, and other sub-single precision dtypes + use_tensorcores: bool = True + + def __post_init__(self): + # Populate fields if not already populated + self.name = torch.cuda.get_device_name(self.device) + + # Memory bandwidth in bytes / s + if self.bandwidth is None: + self.bandwidth = get_bandwidth() + + # FLOPs + if self.flop_per_s is None: + chip_name = get_chip_name(self.device) + if chip_name is None: + print(f"No FLOPs data available for device name {self.name}") + else: + flops_by_dtype = get_flops_by_dtype(chip_name) + if flops_by_dtype is not None: + self.flops_by_dtype.update(flops_by_dtype) + + # Populate flops if not already populated + if flops_by_dtype is not None and self.dtype in flops_by_dtype: + self.flop_per_s = flops_by_dtype[self.dtype] + + if self.dtype == torch.float32: + use_tf32 = "tfloat32" in flops_by_dtype and self.use_tensorcores + + if use_tf32: + self.flop_per_s = flops_by_dtype["tfloat32"] + else: + print( + f"Could not find FLOPs for dtype {self.dtype} for device {self.name}" + ) + # Vram + if self.vram is None: + self.vram = get_vram() + + # Issue post check warnings + self._post_init_check() From dfc7f8c13e3fa978dc925417a2882497c3c90e27 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Sun, 7 Jul 2024 21:06:46 +0000 Subject: [PATCH 02/33] add performance counter --- test/profiler/test_performance_counter.py | 117 ++++++++++++++++++++++ torchao/profiler/performance_counter.py | 101 +++++++++++++++++++ 2 files changed, 218 insertions(+) create mode 100644 test/profiler/test_performance_counter.py create mode 100644 torchao/profiler/performance_counter.py diff --git a/test/profiler/test_performance_counter.py b/test/profiler/test_performance_counter.py new file mode 100644 index 000000000..bc607de71 --- /dev/null +++ b/test/profiler/test_performance_counter.py @@ -0,0 +1,117 @@ +import pytest + +# Skip if transformers is not installed +transformers = pytest.importorskip("transformers") +LlamaConfig = transformers.models.llama.modeling_llama.LlamaConfig +LlamaForCausalLM = transformers.models.llama.modeling_llama.LlamaForCausalLM +import torch + +from torchao.profiler.performance_counter import PerformanceCounterMode +from torchao.utils import TORCH_VERSION_AFTER_2_5 + + +def get_leaf_nodes(count_keys, module_name): + return [k for k in count_keys if k.endswith(module_name)] + +def attn_proj_io_check(model_config, batch_size, seqlen, element_size): + input_size = batch_size * seqlen * model_config.hidden_size * element_size + weight_size = model_config.hidden_size * model_config.hidden_size * element_size + output_size = batch_size * seqlen * model_config.hidden_size * element_size + return input_size + weight_size + output_size +def attn_io_check(model_config, batch_size, seqlen, 
element_size): + # queries, keys, values -> factor of 3 + input_size = (batch_size * seqlen * model_config.hidden_size * 3) * element_size + output_size = (batch_size * seqlen * model_config.hidden_size) * element_size + return input_size + output_size + +def ffn_io_check(model_config, batch_size, seqlen, element_size, module_name): + assert module_name in ["up_proj", "gate_proj", "down_proj"] + + if module_name == "down_proj": + input_size = batch_size * seqlen * model_config.intermediate_size * element_size + else: + input_size = batch_size * seqlen * model_config.hidden_size * element_size + weight_size = model_config.hidden_size * model_config.intermediate_size * element_size + if module_name == "down_proj": + output_size = batch_size * seqlen * model_config.hidden_size * element_size + else: + output_size = batch_size * seqlen * model_config.intermediate_size * element_size + + return input_size + weight_size + output_size + + +CONFIG_7B = (32, 4096, 11008, 32, 32000) +MEDIUM_CONFIG = [p // 2 for p in CONFIG_7B] +SMALL_CONFIG = [p // 4 for p in CONFIG_7B] + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not TORCH_VERSION_AFTER_2_5, reason="requires torch >= 2.5") +@pytest.mark.parametrize("num_hidden_layers, hidden_size, intermediate_size, num_attention_heads, vocab_size", [MEDIUM_CONFIG, SMALL_CONFIG]) +@pytest.mark.parametrize("batch_size, seqlen", [(1, 128),]) +@pytest.mark.parametrize("dtype", [torch.float16], ids=lambda p: str(p)) +def test_performance_counter(num_hidden_layers, hidden_size, intermediate_size, num_attention_heads, vocab_size, batch_size, seqlen, dtype): + + cfg = LlamaConfig(num_hidden_layers=num_hidden_layers, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_attention_heads=num_attention_heads, + vocab_size=vocab_size) + + # Note we set some options manually since the model doesn't seem to be initialized correctly + # when these options are set in LlamaConfig + cfg._attn_implementation = "sdpa" + model = LlamaForCausalLM(cfg).to(dtype).to("cuda") + model_config = model.config + element_size = dtype.itemsize + + input_ids = torch.randint(0, model_config.vocab_size, (batch_size, seqlen), device="cuda") + with torch.no_grad(): + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION): + with PerformanceCounterMode() as perf_counter: + _ = model(input_ids) + + summary_flops = perf_counter.get_summary_flop_counts() + summary_io = perf_counter.get_summary_io_counts() + flops_by_op = perf_counter.get_flop_counts() + io_by_op = perf_counter.get_io_counts() + assert len(summary_flops) == len(summary_io) + assert summary_flops.keys() == summary_io.keys() + + # Attn Projections + for k in ["q_proj", "k_proj", "v_proj"]: + # Flops check + proj_keys = get_leaf_nodes(summary_flops.keys(), k) + assert len(proj_keys) == model.config.num_hidden_layers + expected_flops = 2 * batch_size * seqlen * model_config.hidden_size * model_config.hidden_size + assert expected_flops == summary_flops[proj_keys[0]] + + # io movement check + expected_size = attn_proj_io_check(model_config, batch_size, seqlen, element_size) + assert expected_size == summary_io[proj_keys[0]] + + # Attention + attention_keys = get_leaf_nodes(summary_flops.keys(), "self_attn") + for k in attention_keys: + flops = flops_by_op[k] + io_movement = io_by_op[k] + for op, count in flops.items(): + if "attention" in op.__name__: + expected_flops = 2 * 2 * batch_size * seqlen * seqlen * model_config.hidden_size + 
assert expected_flops == count + for op, count in io_movement.items(): + if "attention" in op.__name__: + # Check approx equal due to other small artifacts returned by sdpa.mem_efficient_attention + # See #https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/cuda/attention.cu#L867 + # Check within 100 bytes + expected_size = attn_io_check(model_config, batch_size, seqlen, element_size) + assert abs(expected_size - count) < 100 + # FFN + for k in ["up_proj", "gate_proj", "down_proj"]: + proj_keys = get_leaf_nodes(summary_flops.keys(), k) + assert len(proj_keys) == model.config.num_hidden_layers + expected_flops = 2 * batch_size * seqlen * model_config.hidden_size * model_config.intermediate_size + assert expected_flops == summary_flops[proj_keys[0]] + + # io movement check + expected_size = ffn_io_check(model_config, batch_size, seqlen, element_size, k) + assert expected_size == summary_io[proj_keys[0]] diff --git a/torchao/profiler/performance_counter.py b/torchao/profiler/performance_counter.py new file mode 100644 index 000000000..a0e467935 --- /dev/null +++ b/torchao/profiler/performance_counter.py @@ -0,0 +1,101 @@ +# mypy: allow-untyped-defs +import math +from collections import defaultdict + +import torch +from torch.utils._pytree import tree_map +from torch.utils.flop_counter import FlopCounterMode + +aten = torch.ops.aten + +class PerformanceCounterMode(FlopCounterMode): + def __init__(self, display=False, depth=10, debug=False): + self.debug = debug + self.io_counts = defaultdict(lambda: defaultdict(int)) + super().__init__(display=display, depth=depth) + + def get_io_counts(self): + return {k: dict(v) for k,v in self.io_counts.items()} + + def get_total_io(self): + return sum(self.io_counts['Global'].values()) + + def _get_io_sizes(self, args): + sizes = tree_map(lambda x: x.numel() * x.element_size() if isinstance(x, torch.Tensor) else 0, args) + if not hasattr(sizes, "__len__"): + sizes = [sizes] + return sizes + + def get_summary_flop_counts(self): + flop_counts = self.get_flop_counts() + return {k: sum(v.values()) for k,v in flop_counts.items()} + + def get_summary_io_counts(self): + io_counts = self.get_io_counts() + return {k: sum(v.values()) for k,v in io_counts.items()} + + def _nearest_power_of_10(self, x): + if x == 0: + return x, 0 + + power = int(math.floor(math.log10(abs(x)) / 3)) + scaled_value = x / (10 ** (3 * power)) + + return scaled_value, power + + def pretty_summary_counts(self, type="flops", precision=2, depth=None): + assert type in ["flops", "io"] + metric_units = {0: '', 1: 'k', 2: 'M', 3: 'G', 4: 'T', 5: 'P', 6: 'E', 7: 'Z', 8: 'Y'} + + if depth is None: + depth = self.depth + summary_counts = self.get_summary_flop_counts() if type == "flops" else self.get_summary_io_counts() + keys_to_print = [k for k in summary_counts.keys() if len(k.split(".")) <= depth] + units = "FLOPs" if type == "flops" else "B" + summary_str = [] + for k in sorted(keys_to_print, key=lambda x: len(x.split("."))): + if k == "Global" or k is None: + continue + spaces = " " * (len(k.split(".")) - 1) + scaled_val, power = self._nearest_power_of_10(summary_counts[k]) + formatted_val = f"{scaled_val:.{precision}f}{metric_units[power]}{units}" + summary_str.append(f"{spaces}{k}: {formatted_val}") + + return "\n".join(summary_str) + + def _count_io(self, func_packet, out, args, kwargs): + arg_sizes = self._get_io_sizes(args) + kwargs_sizes = self._get_io_sizes(kwargs.values()) + out_sizes = self._get_io_sizes(out) + arg_size, kwargs_size, out_size = 
sum(arg_sizes), sum(kwargs_sizes), sum(out_sizes) + return arg_size, kwargs_size, out_size + + def _count_flops(self, func_packet, out, args, kwargs): + if func_packet in self.flop_registry: + flop_count_func = self.flop_registry[func_packet] + flop_count = flop_count_func(*args, **kwargs, out_val=out) # type: ignore[operator] + arg_size, kwarg_size, out_size = self._count_io(func_packet, out, args, kwargs) + total_size = arg_size + kwarg_size + out_size + + for par in set(self.mod_tracker.parents): + if self.debug: + print(f"Counting flops for {par}, {func_packet}: {flop_count}") + print(f"Counting io for {par}, {func_packet}: {sum([arg_size, kwarg_size, out_size])} = {arg_size} + {kwarg_size} + {out_size}") + self.flop_counts[par][func_packet] += flop_count + self.io_counts[par][func_packet] += total_size + + return out + +if __name__ == "__main__": + import torch + from transformers import AutoModelForCausalLM, LlamaForCausalLM + + model_id = "/home/ubuntu/gpt-fast-dev/checkpoints/7B" + model: LlamaForCausalLM = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.float16, low_cpu_mem_usage=True) + input_ids = torch.randint(0, model.config.vocab_size, (1, 16), dtype=torch.int64, device="cuda") + + with PerformanceCounterMode(display=False, depth=10) as perf_counter: + _ = model(input_ids) + + print(perf_counter.pretty_summary_counts(type="flops", depth=3)) + print(perf_counter.pretty_summary_counts(type="io", depth=3)) \ No newline at end of file From 0ac59f75ae39dab64f0a09e341431c6ca7a5cab0 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Sun, 7 Jul 2024 21:10:23 +0000 Subject: [PATCH 03/33] add more perf counter tools --- torchao/profiler/performance_counter.py | 235 ++++++++++++++++++++++-- 1 file changed, 224 insertions(+), 11 deletions(-) diff --git a/torchao/profiler/performance_counter.py b/torchao/profiler/performance_counter.py index a0e467935..3d2e1976a 100644 --- a/torchao/profiler/performance_counter.py +++ b/torchao/profiler/performance_counter.py @@ -1,11 +1,18 @@ -# mypy: allow-untyped-defs +import json import math +import textwrap +import time from collections import defaultdict +from contextlib import contextmanager +from copy import deepcopy +from functools import partial import torch from torch.utils._pytree import tree_map from torch.utils.flop_counter import FlopCounterMode +from .device_spec import DeviceSpec + aten = torch.ops.aten class PerformanceCounterMode(FlopCounterMode): @@ -86,16 +93,222 @@ def _count_flops(self, func_packet, out, args, kwargs): return out -if __name__ == "__main__": - import torch - from transformers import AutoModelForCausalLM, LlamaForCausalLM +class PerformanceTimer: + def __init__(self, name, precision=1, display=False, depth=10): + self.name = name + self.precision = precision + self.display = display + self.depth = depth + self.perf_counter = PerformanceCounterMode(display=display, depth=depth) + + def __enter__(self): + self.start = time.perf_counter() + self.perf_counter.__enter__() + return self + + def _print_exit_msg(self): + gflops = round(self.total_flops / 1e9, self.precision) + ms = round(self.elapsed * 1e3, self.precision) + if self.display: + print(f"{self.name.upper()}: Elapsed = {ms} ms, FLOPS = {gflops} GFLOPs") + + def __exit__(self, type, value, traceback): + self.end = time.perf_counter() + #Convert to ms + self.elapsed = (self.end - self.start) + self.perf_counter.__exit__(type, value, traceback) + if self.display: + self._print_exit_msg() + + @property + def total_flops(self): + return 
self.perf_counter.get_total_flops() + + @property + def total_io(self): + return self.perf_counter.get_total_io() + + @property + def flops_table(self): + return self.perf_counter.get_table() + + def get_summary_flop_counts(self): + return self.perf_counter.get_summary_flop_counts() + + def get_summary_io_counts(self): + return self.perf_counter.get_summary_io_counts() + + @property + def flop_counts(self): + return self.perf_counter.get_flop_counts() + + @property + def io_counts(self): + return self.perf_counter.get_io_counts() + + def get_pretty_summary(self, depth): + return self.perf_counter.pretty_summary_counts(depth=depth if depth is not None else self.depth) +class CUDAPerformanceTimer(PerformanceTimer): + + def __enter__(self): + self.start = torch.cuda.Event(enable_timing=True) + self.end = torch.cuda.Event(enable_timing=True) + self.start.record() + self.perf_counter = PerformanceCounterMode(display=self.display, depth=self.depth) + self.perf_counter.__enter__() + return self + + def __exit__(self, type, value, traceback): + self.end.record() + torch.cuda.synchronize() + # Convert from ms to s + self.elapsed = self.start.elapsed_time(self.end) * 1e-3 + self.perf_counter.__exit__(type, value, traceback) + + if self.display: + self._print_exit_msg() + +class PerformanceCounterManager: + COUNT_KEYS = ["label", "num_tokens", "elapsed", "throughput", "total_flops", "flops_table", "flop_counts"] + def __init__(self, depth=10, timer_cls=PerformanceTimer, verbose=False): + super().__init__() + self._counts = {} + self._depth = depth + self.timer_cls = timer_cls + self.verbose = verbose + + @contextmanager + def count(self, label: str, num_tokens: int): + perf_timer = self.timer_cls(name=label, depth=self._depth) + perf_timer.__enter__() + try: + yield self + finally: + perf_timer.__exit__(None, None, None) + self._counts[label] = {"label": label, + "num_tokens": num_tokens, + "elapsed": perf_timer.elapsed, + "token_throughput": num_tokens / perf_timer.elapsed, + "total_flops": perf_timer.total_flops, + "flops_throughput": perf_timer.total_flops / perf_timer.elapsed, + "total_io": perf_timer.total_io, + "io_throughput": perf_timer.total_io / perf_timer.elapsed, + "summary_flops": perf_timer.get_summary_flop_counts(), + "summary_io": perf_timer.get_summary_io_counts(), + "flop_counts": perf_timer.flop_counts, + "io_counts": perf_timer.io_counts, + "pretty_summary": perf_timer.get_pretty_summary(depth=self._depth), + } + @property + def counts(self): + return self._counts + def get_counts(self): + return self._counts + + @property + def total_flops(self): + return sum(count["total_flops"] for count in self._counts.values()) + + @property + def total_io(self): + return sum(count["total_io"] for count in self._counts.values()) + @property + def total_tokens(self): + return sum(count["num_tokens"] for count in self._counts.values()) - model_id = "/home/ubuntu/gpt-fast-dev/checkpoints/7B" - model: LlamaForCausalLM = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.float16, low_cpu_mem_usage=True) - input_ids = torch.randint(0, model.config.vocab_size, (1, 16), dtype=torch.int64, device="cuda") + @property + def total_time(self): + return sum(count["elapsed"] for count in self._counts.values()) + + def to_dict(self): + # Convert flop_counts from OpOverloadPackets to str + counts = deepcopy(self._counts) + for label,label_counts in counts.items(): + counts[label]['flop_counts'] = {mod: {str(op): count for op, count in op_count.items()} for mod, op_count in 
label_counts['flop_counts'].items()} + counts[label]['io_counts'] = {mod: {str(op): count for op, count in op_count.items()} for mod, op_count in label_counts['io_counts'].items()} - with PerformanceCounterMode(display=False, depth=10) as perf_counter: - _ = model(input_ids) + return counts + + def to_json(self): + return json.dumps(self.to_dict(), indent=2) + + def get_summary(self, device_spec: DeviceSpec=None): + token_throughput = self.total_tokens / self.total_time + achieved_bandwidth = self.total_io / self.total_time + achieved_flops_per_s = self.total_flops / self.total_time + stats = { + "total_tokens": self.total_tokens, + "total_time": self.total_time, + "total_flops": self.total_flops, + "total_io": self.total_io, + "token_throughput": token_throughput, + "achieved_bandwidth": achieved_bandwidth, + "achieved_flops_per_s": achieved_flops_per_s, + "arithmetic_intensity": self.total_flops / self.total_io + } + if device_spec is not None: + device_stats = { + "device_name": device_spec.name, + "theoretical_bandwidth": device_spec.bandwidth, + "theoretical_throughput": device_spec.flops, + "model_bandwidth_utilization": achieved_bandwidth / device_spec.bandwidth, + "model_flops_utilization": achieved_flops_per_s / device_spec.flops, + } + stats.update(device_stats) + return stats + def _format_single(self, label, counts, precision, verbose=False): + ms = round(counts['elapsed'] * 1e3, precision) + token_throughput = round(counts['token_throughput'], precision) + gflops = round(counts['total_flops'] / 1e9, precision) + gb = round(counts['total_io'] / 1e9, precision) + flop_throughput = round(gflops / counts['elapsed'], precision) + io_throughput = round(gb / counts['elapsed'], precision) + text = textwrap.dedent(f"""\ + {label.title()}: + Elapsed = {ms:,} ms + Tokens: + Total {counts['num_tokens']} + Throughput {token_throughput} tokens/s + IO: + Total {gb:,} GB + Throughput {io_throughput} GB/s + FLOPs: + Total {gflops:,} GFLOPs, + Throughput {flop_throughput:,} GFLOP/s""") + if verbose: + counts_by_module = counts['pretty_summary'] + text += textwrap.dedent(f"""\nCounts by Module:\n{counts_by_module}""") - print(perf_counter.pretty_summary_counts(type="flops", depth=3)) - print(perf_counter.pretty_summary_counts(type="io", depth=3)) \ No newline at end of file + return text + + def _format_totals(self, precision=2): + ms = round(self.total_time * 1e3, precision) + token_throughput = round(self.total_tokens / self.total_time, precision) + gflops = round(self.total_flops / 1e9, precision) + gb = round(self.total_io / 1e9, precision) + flop_throughput = round(gflops / self.total_time, precision) + io_throughput = round(gb / self.total_time, precision) + text = textwrap.dedent(f"""\ + FlopCounter Summary: + Total time = {ms:,} ms + Tokens: + Total {self.total_tokens} + Throughput {token_throughput:,} tokens/s + IO: + Total {gb:,} GB + Throughput {io_throughput:,} GB/s + FLOPs: + Total {gflops:,} GFLOPs + Throughput {flop_throughput:,} GFLOP/s""") + return text + + def print_summary(self, labels: list[str] = None, precision=2, verbose=None): + verbose = verbose if verbose is not None else self.verbose + _print = partial(print, flush=True, end='\n') + if labels is None: + text = self._format_totals(precision=precision) + _print(text) + else: + for label in labels: + text = self._format_single(label, self._counts[label], precision=precision, verbose=verbose) + _print(text) \ No newline at end of file From bf77aa8cac958beed8811d3114c96a8910245561 Mon Sep 17 00:00:00 2001 From: jeromeku 
Date: Sun, 7 Jul 2024 23:17:55 +0000 Subject: [PATCH 04/33] add performance counter manager test --- test/profiler/test_performance_counter.py | 84 ++++++++++++++++++++++- torchao/profiler/performance_counter.py | 16 +++-- 2 files changed, 95 insertions(+), 5 deletions(-) diff --git a/test/profiler/test_performance_counter.py b/test/profiler/test_performance_counter.py index bc607de71..920636c2d 100644 --- a/test/profiler/test_performance_counter.py +++ b/test/profiler/test_performance_counter.py @@ -4,12 +4,27 @@ transformers = pytest.importorskip("transformers") LlamaConfig = transformers.models.llama.modeling_llama.LlamaConfig LlamaForCausalLM = transformers.models.llama.modeling_llama.LlamaForCausalLM +import time +from contextlib import contextmanager +from unittest.mock import patch + import torch -from torchao.profiler.performance_counter import PerformanceCounterMode +from torchao.profiler.device_spec import CUDADeviceSpec +from torchao.profiler.performance_counter import ( + CUDAPerformanceTimer, + PerformanceCounterManager, + PerformanceCounterMode, + PerformanceTimer, +) from torchao.utils import TORCH_VERSION_AFTER_2_5 +@contextmanager +def patch_device(device_name): + with patch("torch.cuda.get_device_name", return_value=device_name): + yield + def get_leaf_nodes(count_keys, module_name): return [k for k in count_keys if k.endswith(module_name)] @@ -115,3 +130,70 @@ def test_performance_counter(num_hidden_layers, hidden_size, intermediate_size, # io movement check expected_size = ffn_io_check(model_config, batch_size, seqlen, element_size, k) assert expected_size == summary_io[proj_keys[0]] + +@pytest.mark.parametrize("shape", [(1, 1024, 4096, 4096), (128, 1, 1024, 4096)], ids=lambda p: ",".join(map(str, p))) +@pytest.mark.parametrize("timer_cls", [PerformanceTimer, CUDAPerformanceTimer], ids=lambda p: p.__name__) +@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str) +@pytest.mark.parametrize("device_name, bandwidth", [(None, 0), ("A100", 2e12)]) +def test_performance_counter_manager(shape, timer_cls, dtype, device_name, bandwidth): + + batch_size, query_len, in_features, out_features = shape + num_tokens = batch_size * query_len + element_size = dtype.itemsize + a = torch.randn(num_tokens, in_features, dtype=dtype, device="cuda") + b = torch.randn(in_features, out_features, dtype=dtype, device="cuda") + + cm = PerformanceCounterManager(timer_cls=timer_cls) + start = time.perf_counter() + with cm.count("a", num_tokens=num_tokens): + _ = torch.matmul(a, b) + end = time.perf_counter() + + elapsed = (end - start) + expected_flops = 2 * num_tokens * in_features * out_features + expected_io = (num_tokens * in_features + in_features * out_features + num_tokens * out_features) * element_size + assert cm.total_flops == expected_flops + counts = cm.get_counts() + assert "a" in counts + assert abs(counts['a']['elapsed'] - elapsed) < 1e-1 # +/- 100ms + assert counts['a']['total_flops'] == expected_flops + assert counts['a']['total_io'] == expected_io + assert counts['a']['token_throughput'] == counts['a']['num_tokens'] / counts['a']['elapsed'] + assert counts['a']['flops_throughput'] == counts['a']['total_flops'] / counts['a']['elapsed'] + assert counts['a']['io_throughput'] == counts['a']['total_io'] / counts['a']['elapsed'] + + start = time.perf_counter() + with cm.count("b", num_tokens=num_tokens): + _ = torch.matmul(a, b) + end = time.perf_counter() + elapsed = end - start + # cm.print_summary(labels=["a", "b"], verbose=True) + assert "a" in cm.counts + assert "b" in 
cm.counts + counts = cm.counts + assert abs(counts['b']['elapsed'] - elapsed) < 1e-1 # +/- 100ms + assert counts['b']['total_flops'] == expected_flops + assert counts['b']['total_io'] == expected_io + assert cm.total_flops == 2 * expected_flops + assert cm.total_io == 2 * expected_io + + if device_name is not None: + with patch_device(device_name): + device_spec = CUDADeviceSpec(dtype=dtype, bandwidth=bandwidth) + else: + device_spec = None + summary = cm.get_summary(device_spec=device_spec) + expected_tokens = 2 * num_tokens + expected_total_flops = 2 * expected_flops + expected_total_io = 2 * expected_io + expected_total_time = cm.total_time + expected_token_throughput = expected_tokens / expected_total_time + expected_io_throughput = expected_total_io / expected_total_time + expected_flops_throughput = expected_total_flops / expected_total_time + assert summary['total_tokens'] == expected_tokens + assert summary['total_io'] == expected_total_io + assert summary['total_flops'] == expected_total_flops + assert summary['total_time'] == expected_total_time + assert abs(summary['token_throughput'] - expected_token_throughput) < 1e-1 + assert abs(summary['io_throughput'] - expected_io_throughput) < 1e-1 + assert abs(summary['flops_throughput'] - expected_flops_throughput) < 1e-1 \ No newline at end of file diff --git a/torchao/profiler/performance_counter.py b/torchao/profiler/performance_counter.py index 3d2e1976a..30da45991 100644 --- a/torchao/profiler/performance_counter.py +++ b/torchao/profiler/performance_counter.py @@ -234,6 +234,8 @@ def to_json(self): def get_summary(self, device_spec: DeviceSpec=None): token_throughput = self.total_tokens / self.total_time + io_throughput = self.total_io / self.total_time + flops_throughput = self.total_flops / self.total_time achieved_bandwidth = self.total_io / self.total_time achieved_flops_per_s = self.total_flops / self.total_time stats = { @@ -242,20 +244,26 @@ def get_summary(self, device_spec: DeviceSpec=None): "total_flops": self.total_flops, "total_io": self.total_io, "token_throughput": token_throughput, + "io_throughput": io_throughput, + "flops_throughput": flops_throughput, "achieved_bandwidth": achieved_bandwidth, "achieved_flops_per_s": achieved_flops_per_s, "arithmetic_intensity": self.total_flops / self.total_io } if device_spec is not None: + theoretical_bandwidth = device_spec.bandwidth + theoretical_flop_per_s = device_spec.flop_per_s + device_stats = { "device_name": device_spec.name, - "theoretical_bandwidth": device_spec.bandwidth, - "theoretical_throughput": device_spec.flops, - "model_bandwidth_utilization": achieved_bandwidth / device_spec.bandwidth, - "model_flops_utilization": achieved_flops_per_s / device_spec.flops, + "theoretical_bandwidth": theoretical_bandwidth, + "theoretical_throughput": theoretical_flop_per_s, + "model_bandwidth_utilization": achieved_bandwidth / theoretical_bandwidth, + "model_flops_utilization": achieved_flops_per_s / theoretical_flop_per_s, } stats.update(device_stats) return stats + def _format_single(self, label, counts, precision, verbose=False): ms = round(counts['elapsed'] * 1e3, precision) token_throughput = round(counts['token_throughput'], precision) From e570b007a6062d7057fb99bd50e486f624a39995 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Sun, 7 Jul 2024 23:38:40 +0000 Subject: [PATCH 05/33] add mbu and mfu test --- test/profiler/test_performance_counter.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/test/profiler/test_performance_counter.py 
b/test/profiler/test_performance_counter.py index 920636c2d..a6dba672e 100644 --- a/test/profiler/test_performance_counter.py +++ b/test/profiler/test_performance_counter.py @@ -100,7 +100,7 @@ def test_performance_counter(num_hidden_layers, hidden_size, intermediate_size, expected_flops = 2 * batch_size * seqlen * model_config.hidden_size * model_config.hidden_size assert expected_flops == summary_flops[proj_keys[0]] - # io movement check + # io check expected_size = attn_proj_io_check(model_config, batch_size, seqlen, element_size) assert expected_size == summary_io[proj_keys[0]] @@ -127,7 +127,7 @@ def test_performance_counter(num_hidden_layers, hidden_size, intermediate_size, expected_flops = 2 * batch_size * seqlen * model_config.hidden_size * model_config.intermediate_size assert expected_flops == summary_flops[proj_keys[0]] - # io movement check + # io check expected_size = ffn_io_check(model_config, batch_size, seqlen, element_size, k) assert expected_size == summary_io[proj_keys[0]] @@ -167,7 +167,6 @@ def test_performance_counter_manager(shape, timer_cls, dtype, device_name, bandw _ = torch.matmul(a, b) end = time.perf_counter() elapsed = end - start - # cm.print_summary(labels=["a", "b"], verbose=True) assert "a" in cm.counts assert "b" in cm.counts counts = cm.counts @@ -196,4 +195,11 @@ def test_performance_counter_manager(shape, timer_cls, dtype, device_name, bandw assert summary['total_time'] == expected_total_time assert abs(summary['token_throughput'] - expected_token_throughput) < 1e-1 assert abs(summary['io_throughput'] - expected_io_throughput) < 1e-1 - assert abs(summary['flops_throughput'] - expected_flops_throughput) < 1e-1 \ No newline at end of file + assert abs(summary['flops_throughput'] - expected_flops_throughput) < 1e-1 + if device_spec is not None: + mbu = summary["model_bandwidth_utilization"] + mfu = summary["model_flops_utilization"] + expected_mbu = expected_io_throughput / bandwidth + expected_mfu = expected_flops_throughput / device_spec.flop_per_s + assert abs(mbu - expected_mbu) < 1e-1 + assert abs(mfu - expected_mfu) < 1e-1 \ No newline at end of file From 8e766b6b1a596bdfaf6832523d5f1b9c69972ecc Mon Sep 17 00:00:00 2001 From: jeromeku Date: Sun, 7 Jul 2024 23:44:59 +0000 Subject: [PATCH 06/33] refactor performance manager device spec --- test/profiler/test_performance_counter.py | 21 +++++++++++++-------- torchao/profiler/performance_counter.py | 6 ++++-- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/test/profiler/test_performance_counter.py b/test/profiler/test_performance_counter.py index a6dba672e..01fa8cc22 100644 --- a/test/profiler/test_performance_counter.py +++ b/test/profiler/test_performance_counter.py @@ -136,14 +136,24 @@ def test_performance_counter(num_hidden_layers, hidden_size, intermediate_size, @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str) @pytest.mark.parametrize("device_name, bandwidth", [(None, 0), ("A100", 2e12)]) def test_performance_counter_manager(shape, timer_cls, dtype, device_name, bandwidth): - + + # Set up inputs batch_size, query_len, in_features, out_features = shape num_tokens = batch_size * query_len element_size = dtype.itemsize a = torch.randn(num_tokens, in_features, dtype=dtype, device="cuda") b = torch.randn(in_features, out_features, dtype=dtype, device="cuda") - cm = PerformanceCounterManager(timer_cls=timer_cls) + # Setup device spec + if device_name is not None: + with patch_device(device_name): + device_spec = CUDADeviceSpec(dtype=dtype, bandwidth=bandwidth) + else: + 
device_spec = None + + cm = PerformanceCounterManager(timer_cls=timer_cls, device_spec=device_spec) + + # Start count start = time.perf_counter() with cm.count("a", num_tokens=num_tokens): _ = torch.matmul(a, b) @@ -176,12 +186,7 @@ def test_performance_counter_manager(shape, timer_cls, dtype, device_name, bandw assert cm.total_flops == 2 * expected_flops assert cm.total_io == 2 * expected_io - if device_name is not None: - with patch_device(device_name): - device_spec = CUDADeviceSpec(dtype=dtype, bandwidth=bandwidth) - else: - device_spec = None - summary = cm.get_summary(device_spec=device_spec) + summary = cm.get_summary() expected_tokens = 2 * num_tokens expected_total_flops = 2 * expected_flops expected_total_io = 2 * expected_io diff --git a/torchao/profiler/performance_counter.py b/torchao/profiler/performance_counter.py index 30da45991..cf9c14f71 100644 --- a/torchao/profiler/performance_counter.py +++ b/torchao/profiler/performance_counter.py @@ -170,11 +170,12 @@ def __exit__(self, type, value, traceback): class PerformanceCounterManager: COUNT_KEYS = ["label", "num_tokens", "elapsed", "throughput", "total_flops", "flops_table", "flop_counts"] - def __init__(self, depth=10, timer_cls=PerformanceTimer, verbose=False): + def __init__(self, depth=10, timer_cls: PerformanceTimer=PerformanceTimer, device_spec: DeviceSpec=None, verbose=False): super().__init__() self._counts = {} self._depth = depth self.timer_cls = timer_cls + self.device_spec = device_spec self.verbose = verbose @contextmanager @@ -232,7 +233,7 @@ def to_dict(self): def to_json(self): return json.dumps(self.to_dict(), indent=2) - def get_summary(self, device_spec: DeviceSpec=None): + def get_summary(self): token_throughput = self.total_tokens / self.total_time io_throughput = self.total_io / self.total_time flops_throughput = self.total_flops / self.total_time @@ -250,6 +251,7 @@ def get_summary(self, device_spec: DeviceSpec=None): "achieved_flops_per_s": achieved_flops_per_s, "arithmetic_intensity": self.total_flops / self.total_io } + device_spec = self.device_spec if device_spec is not None: theoretical_bandwidth = device_spec.bandwidth theoretical_flop_per_s = device_spec.flop_per_s From def868577a8c31499ee3f320e8b4532be1a8c322 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 8 Jul 2024 13:56:03 +0000 Subject: [PATCH 07/33] add perf stats --- torchao/profiler/performance_counter.py | 30 +++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/torchao/profiler/performance_counter.py b/torchao/profiler/performance_counter.py index cf9c14f71..c32177a64 100644 --- a/torchao/profiler/performance_counter.py +++ b/torchao/profiler/performance_counter.py @@ -5,7 +5,9 @@ from collections import defaultdict from contextlib import contextmanager from copy import deepcopy +from dataclasses import dataclass from functools import partial +from typing import Any, Dict import torch from torch.utils._pytree import tree_map @@ -168,6 +170,34 @@ def __exit__(self, type, value, traceback): if self.display: self._print_exit_msg() +@dataclass +class PerformanceStats: + label: str + num_tokens: int + elapsed: float + # token_throughput: float + total_flops: float + # flops_throughput: float + total_io: float + # io_throughput: float + summary_flops: Dict[str, int] + summary_io: Dict[str, int] + flop_counts: Dict[str, Dict[Any, int]] + io_counts: Dict[str, Dict[Any, int]] + pretty_summary: str + + @property + def token_throughput(self): + return self.num_tokens / self.elapsed + + @property + def flops_throughput(self): 
+ return self.total_flops / self.elapsed + + @property + def io_throughput(self): + return self.total_io / self.elapsed + class PerformanceCounterManager: COUNT_KEYS = ["label", "num_tokens", "elapsed", "throughput", "total_flops", "flops_table", "flop_counts"] def __init__(self, depth=10, timer_cls: PerformanceTimer=PerformanceTimer, device_spec: DeviceSpec=None, verbose=False): From cd02f4fc79caf803dd21918220ecadc3e366ce6a Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 8 Jul 2024 14:16:27 +0000 Subject: [PATCH 08/33] start perf counter manager test refactor --- test/profiler/test_performance_counter.py | 156 ++++++++++++---------- torchao/profiler/performance_counter.py | 53 +++++--- 2 files changed, 114 insertions(+), 95 deletions(-) diff --git a/test/profiler/test_performance_counter.py b/test/profiler/test_performance_counter.py index 01fa8cc22..39f3fe0e3 100644 --- a/test/profiler/test_performance_counter.py +++ b/test/profiler/test_performance_counter.py @@ -20,11 +20,6 @@ from torchao.utils import TORCH_VERSION_AFTER_2_5 -@contextmanager -def patch_device(device_name): - with patch("torch.cuda.get_device_name", return_value=device_name): - yield - def get_leaf_nodes(count_keys, module_name): return [k for k in count_keys if k.endswith(module_name)] @@ -131,80 +126,95 @@ def test_performance_counter(num_hidden_layers, hidden_size, intermediate_size, expected_size = ffn_io_check(model_config, batch_size, seqlen, element_size, k) assert expected_size == summary_io[proj_keys[0]] -@pytest.mark.parametrize("shape", [(1, 1024, 4096, 4096), (128, 1, 1024, 4096)], ids=lambda p: ",".join(map(str, p))) -@pytest.mark.parametrize("timer_cls", [PerformanceTimer, CUDAPerformanceTimer], ids=lambda p: p.__name__) -@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str) -@pytest.mark.parametrize("device_name, bandwidth", [(None, 0), ("A100", 2e12)]) -def test_performance_counter_manager(shape, timer_cls, dtype, device_name, bandwidth): - - # Set up inputs - batch_size, query_len, in_features, out_features = shape - num_tokens = batch_size * query_len - element_size = dtype.itemsize - a = torch.randn(num_tokens, in_features, dtype=dtype, device="cuda") - b = torch.randn(in_features, out_features, dtype=dtype, device="cuda") - - # Setup device spec +@contextmanager +def patch_device(device_name): + with patch("torch.cuda.get_device_name", return_value=device_name): + yield + +@pytest.fixture +def device_spec(device_name, bandwidth): if device_name is not None: with patch_device(device_name): - device_spec = CUDADeviceSpec(dtype=dtype, bandwidth=bandwidth) + device_spec = CUDADeviceSpec(dtype=torch.bfloat16, bandwidth=bandwidth) else: device_spec = None + return device_spec + +@pytest.mark.parametrize("shape", [(1, 1024, 4096, 4096), (128, 1, 1024, 4096)], ids=lambda p: ",".join(map(str, p))) +@pytest.mark.parametrize("timer_cls", [PerformanceTimer, CUDAPerformanceTimer], ids=lambda p: p.__name__) +@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str) +@pytest.mark.parametrize("device_spec", [(None, 0), ("A100", 2e12)]) +def test_performance_counter_manager(shape, timer_cls, dtype, device_spec): + print(f"Device Spec: {device_spec}") + + # # Set up inputs + # batch_size, query_len, in_features, out_features = shape + # num_tokens = batch_size * query_len + # element_size = dtype.itemsize + # a = torch.randn(num_tokens, in_features, dtype=dtype, device="cuda") + # b = torch.randn(in_features, out_features, dtype=dtype, device="cuda") + + # # Setup device spec + # if device_name is not 
None: + # with patch_device(device_name): + # device_spec = CUDADeviceSpec(dtype=dtype, bandwidth=bandwidth) + # else: + # device_spec = None - cm = PerformanceCounterManager(timer_cls=timer_cls, device_spec=device_spec) + # cm = PerformanceCounterManager(timer_cls=timer_cls, device_spec=device_spec) - # Start count - start = time.perf_counter() - with cm.count("a", num_tokens=num_tokens): - _ = torch.matmul(a, b) - end = time.perf_counter() + # # Start count + # start = time.perf_counter() + # with cm.count("a", num_tokens=num_tokens): + # _ = torch.matmul(a, b) + # end = time.perf_counter() - elapsed = (end - start) - expected_flops = 2 * num_tokens * in_features * out_features - expected_io = (num_tokens * in_features + in_features * out_features + num_tokens * out_features) * element_size - assert cm.total_flops == expected_flops - counts = cm.get_counts() - assert "a" in counts - assert abs(counts['a']['elapsed'] - elapsed) < 1e-1 # +/- 100ms - assert counts['a']['total_flops'] == expected_flops - assert counts['a']['total_io'] == expected_io - assert counts['a']['token_throughput'] == counts['a']['num_tokens'] / counts['a']['elapsed'] - assert counts['a']['flops_throughput'] == counts['a']['total_flops'] / counts['a']['elapsed'] - assert counts['a']['io_throughput'] == counts['a']['total_io'] / counts['a']['elapsed'] + # elapsed = (end - start) + # expected_flops = 2 * num_tokens * in_features * out_features + # expected_io = (num_tokens * in_features + in_features * out_features + num_tokens * out_features) * element_size + # assert cm.total_flops == expected_flops + # counts = cm.get_counts() + # assert "a" in counts + # assert abs(counts['a']['elapsed'] - elapsed) < 1e-1 # +/- 100ms + # assert counts['a']['total_flops'] == expected_flops + # assert counts['a']['total_io'] == expected_io + # assert counts['a']['token_throughput'] == counts['a']['num_tokens'] / counts['a']['elapsed'] + # assert counts['a']['flops_throughput'] == counts['a']['total_flops'] / counts['a']['elapsed'] + # assert counts['a']['io_throughput'] == counts['a']['total_io'] / counts['a']['elapsed'] - start = time.perf_counter() - with cm.count("b", num_tokens=num_tokens): - _ = torch.matmul(a, b) - end = time.perf_counter() - elapsed = end - start - assert "a" in cm.counts - assert "b" in cm.counts - counts = cm.counts - assert abs(counts['b']['elapsed'] - elapsed) < 1e-1 # +/- 100ms - assert counts['b']['total_flops'] == expected_flops - assert counts['b']['total_io'] == expected_io - assert cm.total_flops == 2 * expected_flops - assert cm.total_io == 2 * expected_io + # start = time.perf_counter() + # with cm.count("b", num_tokens=num_tokens): + # _ = torch.matmul(a, b) + # end = time.perf_counter() + # elapsed = end - start + # assert "a" in cm.counts + # assert "b" in cm.counts + # counts = cm.counts + # assert abs(counts['b']['elapsed'] - elapsed) < 1e-1 # +/- 100ms + # assert counts['b']['total_flops'] == expected_flops + # assert counts['b']['total_io'] == expected_io + # assert cm.total_flops == 2 * expected_flops + # assert cm.total_io == 2 * expected_io - summary = cm.get_summary() - expected_tokens = 2 * num_tokens - expected_total_flops = 2 * expected_flops - expected_total_io = 2 * expected_io - expected_total_time = cm.total_time - expected_token_throughput = expected_tokens / expected_total_time - expected_io_throughput = expected_total_io / expected_total_time - expected_flops_throughput = expected_total_flops / expected_total_time - assert summary['total_tokens'] == expected_tokens - assert 
summary['total_io'] == expected_total_io
+    # assert summary['total_flops'] == expected_total_flops
+    # assert summary['total_time'] == expected_total_time
+    # assert abs(summary['token_throughput'] - expected_token_throughput) < 1e-1
+    # assert abs(summary['io_throughput'] - expected_io_throughput) < 1e-1
+    # assert abs(summary['flops_throughput'] - expected_flops_throughput) < 1e-1
+    # if device_spec is not None:
+    #     mbu = summary["model_bandwidth_utilization"]
+    #     mfu = summary["model_flops_utilization"]
+    #     expected_mbu = expected_io_throughput / bandwidth
+    #     expected_mfu = expected_flops_throughput / device_spec.flop_per_s
+    #     assert abs(mbu - expected_mbu) < 1e-1
+    #     assert abs(mfu - expected_mfu) < 1e-1
\ No newline at end of file
diff --git a/torchao/profiler/performance_counter.py b/torchao/profiler/performance_counter.py
index c32177a64..4442e6ef2 100644
--- a/torchao/profiler/performance_counter.py
+++ b/torchao/profiler/performance_counter.py
@@ -7,7 +7,7 @@
 from collections import defaultdict
 from contextlib import contextmanager
 from copy import deepcopy
 from dataclasses import dataclass
 from functools import partial
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 
 import torch
 from torch.utils._pytree import tree_map
@@ -175,17 +175,15 @@ class PerformanceStats:
     label: str
     num_tokens: int
     elapsed: float
-    # token_throughput: float
-    total_flops: float
-    # flops_throughput: float
-    total_io: float
-    # io_throughput: float
+    total_flops: int
+    total_io: int
     summary_flops: Dict[str, int]
     summary_io: Dict[str, int]
     flop_counts: Dict[str, Dict[Any, int]]
     io_counts: Dict[str, Dict[Any, int]]
     pretty_summary: str
+    device_bandwidth: Optional[float] = None
+    device_flop_per_s: Optional[float] = None
 
     @property
     def token_throughput(self):
         return self.num_tokens / self.elapsed
@@ -198,11 +196,23 @@ def flops_throughput(self):
     def io_throughput(self):
         return self.total_io / self.elapsed
 
+    @property
+    def bandwidth_utilization(self):
+        if self.device_bandwidth is not None:
+            return self.io_throughput / self.device_bandwidth
+        else:
+            print("Device bandwidth is not specified. Please specify the device bandwidth to enable bandwidth utilization calculation")
+    @property
+    def flops_utilization(self):
+        if self.device_flop_per_s is not None:
+            return self.flops_throughput / self.device_flop_per_s
+        else:
+            print("Device flop_per_s is not specified. 
Please specify the device throughput to enable flops utilization calculation") class PerformanceCounterManager: COUNT_KEYS = ["label", "num_tokens", "elapsed", "throughput", "total_flops", "flops_table", "flop_counts"] def __init__(self, depth=10, timer_cls: PerformanceTimer=PerformanceTimer, device_spec: DeviceSpec=None, verbose=False): super().__init__() - self._counts = {} + self._counts: Dict[str, PerformanceStats] = {} self._depth = depth self.timer_cls = timer_cls self.device_spec = device_spec @@ -216,20 +226,19 @@ def count(self, label: str, num_tokens: int): yield self finally: perf_timer.__exit__(None, None, None) - self._counts[label] = {"label": label, - "num_tokens": num_tokens, - "elapsed": perf_timer.elapsed, - "token_throughput": num_tokens / perf_timer.elapsed, - "total_flops": perf_timer.total_flops, - "flops_throughput": perf_timer.total_flops / perf_timer.elapsed, - "total_io": perf_timer.total_io, - "io_throughput": perf_timer.total_io / perf_timer.elapsed, - "summary_flops": perf_timer.get_summary_flop_counts(), - "summary_io": perf_timer.get_summary_io_counts(), - "flop_counts": perf_timer.flop_counts, - "io_counts": perf_timer.io_counts, - "pretty_summary": perf_timer.get_pretty_summary(depth=self._depth), - } + stats = PerformanceStats(label=label, + num_tokens=num_tokens, + elapsed=perf_timer.elapsed, + total_flops=perf_timer.total_flops, + total_io=perf_timer.total_io, + summary_flops=perf_timer.get_summary_flop_counts(), + summary_io=perf_timer.get_summary_io_counts(), + flop_counts=perf_timer.flop_counts, + io_counts=perf_timer.io_counts, + pretty_summary=perf_timer.get_pretty_summary(depth=self._depth), + device_bandwidth=self.device_spec.bandwidth if self.device_spec.bandwidth is not None else None, + device_flop_per_s=self.device_spec.flop_per_s if self.device_spec.flop_per_s is not None else None) + self._counts[label] = stats @property def counts(self): return self._counts From ce2113140fc6bec744f50e09d2a2ed3697008925 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 8 Jul 2024 16:44:58 +0000 Subject: [PATCH 09/33] add stat print str --- test/profiler/test_performance_counter.py | 33 ++++++++--- torchao/profiler/performance_counter.py | 68 ++++++++++++++++++++--- 2 files changed, 86 insertions(+), 15 deletions(-) diff --git a/test/profiler/test_performance_counter.py b/test/profiler/test_performance_counter.py index 39f3fe0e3..8e4647902 100644 --- a/test/profiler/test_performance_counter.py +++ b/test/profiler/test_performance_counter.py @@ -15,6 +15,7 @@ CUDAPerformanceTimer, PerformanceCounterManager, PerformanceCounterMode, + PerformanceStats, PerformanceTimer, ) from torchao.utils import TORCH_VERSION_AFTER_2_5 @@ -140,6 +141,24 @@ def device_spec(device_name, bandwidth): device_spec = None return device_spec +TEST_STATS = [("with_device", 128, 0.1, 123e9, 123e6, {"a": 234e12, "b": 345e9}, 1, {"a": 1, "b": 2}, {"a": 1, "b": 2}, "1", 1e12, 23e12), + ("no_device", 128, 0.1, 123e9, 123e6, {"a": 234e12, "b": 345e9}, 1, {"a": 1, "b": 2}, {"a": 1, "b": 2}, "1", None, None)] +@pytest.mark.parametrize("label, num_tokens, duration, total_flops, total_io, summary_flops, summary_io, flop_counts, io_counts, pretty_summary, device_bandwidth, device_flop_per_s", TEST_STATS) +def test_performance_stats(label, num_tokens, duration, total_flops, total_io, summary_flops, summary_io, flop_counts, io_counts, pretty_summary, device_bandwidth, device_flop_per_s): + stats = PerformanceStats(label=label, + num_tokens=num_tokens, + duration=duration, + 
total_flops=total_flops, + total_io=total_io, + summary_flops=summary_flops, + summary_io=summary_io, + flop_counts=flop_counts, + io_counts=io_counts, + pretty_summary=pretty_summary, + device_bandwidth=device_bandwidth, + device_flop_per_s=device_flop_per_s) + print(stats) + @pytest.mark.parametrize("shape", [(1, 1024, 4096, 4096), (128, 1, 1024, 4096)], ids=lambda p: ",".join(map(str, p))) @pytest.mark.parametrize("timer_cls", [PerformanceTimer, CUDAPerformanceTimer], ids=lambda p: p.__name__) @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str) @@ -169,28 +188,28 @@ def test_performance_counter_manager(shape, timer_cls, dtype, device_spec): # _ = torch.matmul(a, b) # end = time.perf_counter() - # elapsed = (end - start) + # duration = (end - start) # expected_flops = 2 * num_tokens * in_features * out_features # expected_io = (num_tokens * in_features + in_features * out_features + num_tokens * out_features) * element_size # assert cm.total_flops == expected_flops # counts = cm.get_counts() # assert "a" in counts - # assert abs(counts['a']['elapsed'] - elapsed) < 1e-1 # +/- 100ms + # assert abs(counts['a']['duration'] - duration) < 1e-1 # +/- 100ms # assert counts['a']['total_flops'] == expected_flops # assert counts['a']['total_io'] == expected_io - # assert counts['a']['token_throughput'] == counts['a']['num_tokens'] / counts['a']['elapsed'] - # assert counts['a']['flops_throughput'] == counts['a']['total_flops'] / counts['a']['elapsed'] - # assert counts['a']['io_throughput'] == counts['a']['total_io'] / counts['a']['elapsed'] + # assert counts['a']['token_throughput'] == counts['a']['num_tokens'] / counts['a']['duration'] + # assert counts['a']['flops_throughput'] == counts['a']['total_flops'] / counts['a']['duration'] + # assert counts['a']['io_throughput'] == counts['a']['total_io'] / counts['a']['duration'] # start = time.perf_counter() # with cm.count("b", num_tokens=num_tokens): # _ = torch.matmul(a, b) # end = time.perf_counter() - # elapsed = end - start + # duration = end - start # assert "a" in cm.counts # assert "b" in cm.counts # counts = cm.counts - # assert abs(counts['b']['elapsed'] - elapsed) < 1e-1 # +/- 100ms + # assert abs(counts['b']['duration'] - duration) < 1e-1 # +/- 100ms # assert counts['b']['total_flops'] == expected_flops # assert counts['b']['total_io'] == expected_io # assert cm.total_flops == 2 * expected_flops diff --git a/torchao/profiler/performance_counter.py b/torchao/profiler/performance_counter.py index 4442e6ef2..c8f5dc705 100644 --- a/torchao/profiler/performance_counter.py +++ b/torchao/profiler/performance_counter.py @@ -110,14 +110,14 @@ def __enter__(self): def _print_exit_msg(self): gflops = round(self.total_flops / 1e9, self.precision) - ms = round(self.elapsed * 1e3, self.precision) + ms = round(self.duration * 1e3, self.precision) if self.display: print(f"{self.name.upper()}: Elapsed = {ms} ms, FLOPS = {gflops} GFLOPs") def __exit__(self, type, value, traceback): self.end = time.perf_counter() #Convert to ms - self.elapsed = (self.end - self.start) + self.duration = (self.end - self.start) self.perf_counter.__exit__(type, value, traceback) if self.display: self._print_exit_msg() @@ -164,17 +164,46 @@ def __exit__(self, type, value, traceback): self.end.record() torch.cuda.synchronize() # Convert from ms to s - self.elapsed = self.start.elapsed_time(self.end) * 1e-3 + self.duration = self.start.elapsed_time(self.end) * 1e-3 self.perf_counter.__exit__(type, value, traceback) if self.display: self._print_exit_msg() +def 
to_nearest_power_of_10(x, precision=2): + + # Dictionary mapping powers of 10 to their metric abbreviations + metric_units = { + 0: '', + -6: 'µ', + -3: 'm', + 6: 'M', + 9: 'G', + 12: 'T' + } + + # Determine the closest power of 10 + if x == 0: + return f"{x:.{precision}f}" + + power = int(math.floor(math.log10(abs(x)))) + # Adjust power to fit within the given metric units + powers = sorted(metric_units.keys()) + closest_power = min(powers, key=lambda p: abs(p - power)) + + # Calculate the value formatted to the closest power of 10 + value = x / 10**closest_power + + # Map the power to the metric unit + unit = metric_units.get(closest_power, f"e{closest_power}") + + return f"{value:,.{precision}f} {unit}" + @dataclass class PerformanceStats: label: str num_tokens: int - elapsed: float + duration: float total_flops: int total_io: int summary_flops: Dict[str, int] @@ -186,15 +215,15 @@ class PerformanceStats: device_flop_per_s: Optional[float] = None @property def token_throughput(self): - return self.num_tokens / self.elapsed + return self.num_tokens / self.duration @property def flops_throughput(self): - return self.total_flops / self.elapsed + return self.total_flops / self.duration @property def io_throughput(self): - return self.total_io / self.elapsed + return self.total_io / self.duration @property def bandwidth_utilization(self): @@ -202,12 +231,35 @@ def bandwidth_utilization(self): return self.io_throughput / self.device_bandwidth else: print("Device bandwidth is not specified. Please specify the device bandwidth to enable bandwidth utilization calculation") + return None @property def flops_utilization(self): - if self.device_throughput is not None: + if self.device_flop_per_s is not None: return self.flops_throughput / self.device_flop_per_s else: print("Device flop_per_s is not specified. 
Please specify the device throughput to enable flops utilization calculation") + return None + def _format(self, value, suffix): + return to_nearest_power_of_10(value) + suffix + def __str__(self): + txt = textwrap.dedent(f"""\ + {self.label}: + Duration = {self._format(self.duration, "s")} + Tokens + Total: {self.num_tokens} tokens + Throughput: {self.token_throughput:,.0f} tokens/s + IO + Total: {self._format(self.total_io, "B")} + Throughput: {self._format(self.io_throughput, "B/s")} + FLOPs + Total: {self._format(self.total_flops, "FLOPs")} + Throughput: {self._format(self.flops_throughput, "FLOPs/s")}""") + if self.bandwidth_utilization is not None: + txt += "\n" + textwrap.indent("""Utilization:\n""", " " * 2) + txt += textwrap.indent(f"""Bandwidth: {self.bandwidth_utilization:.1f}%""", " " * 4) + if self.flops_utilization is not None: + txt += "\n" + textwrap.indent(f"""FLOPs: {self.flops_utilization:.1f}%""", " " * 4) + return txt class PerformanceCounterManager: COUNT_KEYS = ["label", "num_tokens", "elapsed", "throughput", "total_flops", "flops_table", "flop_counts"] def __init__(self, depth=10, timer_cls: PerformanceTimer=PerformanceTimer, device_spec: DeviceSpec=None, verbose=False): From 92c4f587edd6b907c00a57eed21edbbe78accd68 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 8 Jul 2024 18:04:34 +0000 Subject: [PATCH 10/33] refactor performance counter with perf stats --- test/profiler/test_performance_counter.py | 13 ++- torchao/profiler/performance_counter.py | 107 ++++++++++++++-------- 2 files changed, 77 insertions(+), 43 deletions(-) diff --git a/test/profiler/test_performance_counter.py b/test/profiler/test_performance_counter.py index 8e4647902..4a485c0e7 100644 --- a/test/profiler/test_performance_counter.py +++ b/test/profiler/test_performance_counter.py @@ -141,20 +141,19 @@ def device_spec(device_name, bandwidth): device_spec = None return device_spec -TEST_STATS = [("with_device", 128, 0.1, 123e9, 123e6, {"a": 234e12, "b": 345e9}, 1, {"a": 1, "b": 2}, {"a": 1, "b": 2}, "1", 1e12, 23e12), - ("no_device", 128, 0.1, 123e9, 123e6, {"a": 234e12, "b": 345e9}, 1, {"a": 1, "b": 2}, {"a": 1, "b": 2}, "1", None, None)] -@pytest.mark.parametrize("label, num_tokens, duration, total_flops, total_io, summary_flops, summary_io, flop_counts, io_counts, pretty_summary, device_bandwidth, device_flop_per_s", TEST_STATS) -def test_performance_stats(label, num_tokens, duration, total_flops, total_io, summary_flops, summary_io, flop_counts, io_counts, pretty_summary, device_bandwidth, device_flop_per_s): +TEST_STATS = [("with_device", 128, 0.1, 123e9, 123e6, {"a": 234e12, "b": 345e9}, 1, {"a": 1, "b": 2}, {"a": 1, "b": 2}, 1e12, 23e12), + ("no_device", 128, 0.1, 123e9, 123e6, {"a": 234e12, "b": 345e9}, 1, {"a": 1, "b": 2}, {"a": 1, "b": 2}, None, None)] +@pytest.mark.parametrize("label, num_tokens, duration, total_flops, total_io, flops_summary, io_summary, flop_counts, io_counts, device_bandwidth, device_flop_per_s", TEST_STATS) +def test_performance_stats(label, num_tokens, duration, total_flops, total_io, flops_summary, io_summary, flop_counts, io_counts, device_bandwidth, device_flop_per_s): stats = PerformanceStats(label=label, num_tokens=num_tokens, duration=duration, total_flops=total_flops, total_io=total_io, - summary_flops=summary_flops, - summary_io=summary_io, + flops_summary=flops_summary, + io_summary=io_summary, flop_counts=flop_counts, io_counts=io_counts, - pretty_summary=pretty_summary, device_bandwidth=device_bandwidth, device_flop_per_s=device_flop_per_s) 
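+    # PerformanceStats derives token_throughput, io_throughput and flops_throughput from the raw
+    # duration / total_io / total_flops fields above; bandwidth_utilization and flops_utilization are
+    # only computed when device_bandwidth / device_flop_per_s are supplied, otherwise they return None.
+    # Printing the stats object below exercises the __str__ formatting path.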
print(stats) diff --git a/torchao/profiler/performance_counter.py b/torchao/profiler/performance_counter.py index c8f5dc705..5ab91ef73 100644 --- a/torchao/profiler/performance_counter.py +++ b/torchao/profiler/performance_counter.py @@ -206,11 +206,10 @@ class PerformanceStats: duration: float total_flops: int total_io: int - summary_flops: Dict[str, int] - summary_io: Dict[str, int] + flops_summary: Dict[str, int] + io_summary: Dict[str, int] flop_counts: Dict[str, Dict[Any, int]] io_counts: Dict[str, Dict[Any, int]] - pretty_summary: str device_bandwidth: Optional[float] = None device_flop_per_s: Optional[float] = None @property @@ -254,11 +253,16 @@ def __str__(self): FLOPs Total: {self._format(self.total_flops, "FLOPs")} Throughput: {self._format(self.flops_throughput, "FLOPs/s")}""") + if self.bandwidth_utilization is not None: - txt += "\n" + textwrap.indent("""Utilization:\n""", " " * 2) - txt += textwrap.indent(f"""Bandwidth: {self.bandwidth_utilization:.1f}%""", " " * 4) + indent_2 = " " * 2 + indent_4 = " " * 4 + txt += "\n" + textwrap.indent("""Utilization:\n""", indent_2) + txt += textwrap.indent(f"""Bandwidth: {self.bandwidth_utilization:.1f}%""", indent_4) + if self.flops_utilization is not None: - txt += "\n" + textwrap.indent(f"""FLOPs: {self.flops_utilization:.1f}%""", " " * 4) + txt += "\n" + textwrap.indent(f"""FLOPs: {self.flops_utilization:.1f}%""", indent_4) + return txt class PerformanceCounterManager: COUNT_KEYS = ["label", "num_tokens", "elapsed", "throughput", "total_flops", "flops_table", "flop_counts"] @@ -283,11 +287,10 @@ def count(self, label: str, num_tokens: int): elapsed=perf_timer.elapsed, total_flops=perf_timer.total_flops, total_io=perf_timer.total_io, - summary_flops=perf_timer.get_summary_flop_counts(), - summary_io=perf_timer.get_summary_io_counts(), + flops_summary=perf_timer.get_summary_flop_counts(), + io_summary=perf_timer.get_summary_io_counts(), flop_counts=perf_timer.flop_counts, io_counts=perf_timer.io_counts, - pretty_summary=perf_timer.get_pretty_summary(depth=self._depth), device_bandwidth=self.device_spec.bandwidth if self.device_spec.bandwidth is not None else None, device_flop_per_s=self.device_spec.flop_per_s if self.device_spec.flop_per_s is not None else None) self._counts[label] = stats @@ -323,13 +326,44 @@ def to_dict(self): def to_json(self): return json.dumps(self.to_dict(), indent=2) - - def get_summary(self): + + def _summarize(self, key): + return {label: getattr(self._counts[label], key) for label in self._counts.keys()} + + @property + def flops_summary(self): + return self._summarize(key="summary_flops") + + @property + def io_summary(self): + return self._summarize(key="summary_io") + + @property + def flop_counts_summary(self): + return self._summarize(key="flop_counts") + + @property + def io_counts_summary(self): + return self._summarize(key="io_counts") + @property + def stats_summary(self): token_throughput = self.total_tokens / self.total_time io_throughput = self.total_io / self.total_time flops_throughput = self.total_flops / self.total_time achieved_bandwidth = self.total_io / self.total_time achieved_flops_per_s = self.total_flops / self.total_time + + stats = PerformanceStats(label="Performance Summary", + num_tokens=self.total_tokens, + elapsed=self.total_time, + total_flops=self.total_flops, + total_io=self.total_io, + flops_summary=self.flops_summary, + io_summary=self.io_summary, + flop_counts=self.flop_counts_summary, + io_counts=self.io_counts_summary, + device_bandwidth=self.device_spec.bandwidth if 
self.device_spec.bandwidth is not None else None, + device_flop_per_s=self.device_spec.flop_per_s if self.device_spec.flop_per_s is not None else None) stats = { "total_tokens": self.total_tokens, "total_time": self.total_time, @@ -357,30 +391,30 @@ def get_summary(self): stats.update(device_stats) return stats - def _format_single(self, label, counts, precision, verbose=False): - ms = round(counts['elapsed'] * 1e3, precision) - token_throughput = round(counts['token_throughput'], precision) - gflops = round(counts['total_flops'] / 1e9, precision) - gb = round(counts['total_io'] / 1e9, precision) - flop_throughput = round(gflops / counts['elapsed'], precision) - io_throughput = round(gb / counts['elapsed'], precision) - text = textwrap.dedent(f"""\ - {label.title()}: - Elapsed = {ms:,} ms - Tokens: - Total {counts['num_tokens']} - Throughput {token_throughput} tokens/s - IO: - Total {gb:,} GB - Throughput {io_throughput} GB/s - FLOPs: - Total {gflops:,} GFLOPs, - Throughput {flop_throughput:,} GFLOP/s""") - if verbose: - counts_by_module = counts['pretty_summary'] - text += textwrap.dedent(f"""\nCounts by Module:\n{counts_by_module}""") + # def _format_single(self, label, counts, precision, verbose=False): + # ms = round(counts['elapsed'] * 1e3, precision) + # token_throughput = round(counts['token_throughput'], precision) + # gflops = round(counts['total_flops'] / 1e9, precision) + # gb = round(counts['total_io'] / 1e9, precision) + # flop_throughput = round(gflops / counts['elapsed'], precision) + # io_throughput = round(gb / counts['elapsed'], precision) + # text = textwrap.dedent(f"""\ + # {label.title()}: + # Elapsed = {ms:,} ms + # Tokens: + # Total {counts['num_tokens']} + # Throughput {token_throughput} tokens/s + # IO: + # Total {gb:,} GB + # Throughput {io_throughput} GB/s + # FLOPs: + # Total {gflops:,} GFLOPs, + # Throughput {flop_throughput:,} GFLOP/s""") + # if verbose: + # counts_by_module = counts['pretty_summary'] + # text += textwrap.dedent(f"""\nCounts by Module:\n{counts_by_module}""") - return text + # return text def _format_totals(self, precision=2): ms = round(self.total_time * 1e3, precision) @@ -411,5 +445,6 @@ def print_summary(self, labels: list[str] = None, precision=2, verbose=None): _print(text) else: for label in labels: - text = self._format_single(label, self._counts[label], precision=precision, verbose=verbose) - _print(text) \ No newline at end of file + _print(self._count[label]) + # text = self._format_single(label, self._counts[label], precision=precision, verbose=verbose) + # _print(text) \ No newline at end of file From 016404b528b1a174c574395722f9c4945f629e91 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 8 Jul 2024 18:14:32 +0000 Subject: [PATCH 11/33] more perf stats tests --- test/profiler/test_performance_counter.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/test/profiler/test_performance_counter.py b/test/profiler/test_performance_counter.py index 4a485c0e7..64e5ba341 100644 --- a/test/profiler/test_performance_counter.py +++ b/test/profiler/test_performance_counter.py @@ -156,8 +156,19 @@ def test_performance_stats(label, num_tokens, duration, total_flops, total_io, f io_counts=io_counts, device_bandwidth=device_bandwidth, device_flop_per_s=device_flop_per_s) - print(stats) + # Test derived metrics + assert stats.token_throughput == num_tokens / duration + assert stats.io_throughput == total_io / duration + assert stats.flops_throughput == total_flops / duration + if device_bandwidth is not None: + 
assert stats.bandwidth_utilization == stats.io_throughput / device_bandwidth + else: + assert stats.bandwidth_utilization is None + if device_flop_per_s is not None: + assert stats.flops_utilization == stats.flops_throughput / device_flop_per_s + else: + assert stats.flops_utilization is None @pytest.mark.parametrize("shape", [(1, 1024, 4096, 4096), (128, 1, 1024, 4096)], ids=lambda p: ",".join(map(str, p))) @pytest.mark.parametrize("timer_cls", [PerformanceTimer, CUDAPerformanceTimer], ids=lambda p: p.__name__) @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str) From deae11fcebb1bdd152d47605ae2334fa805ef7d1 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 8 Jul 2024 18:27:07 +0000 Subject: [PATCH 12/33] add perf stat print formatting tests --- test/profiler/test_performance_counter.py | 26 ++++++++++++++++++++++- torchao/profiler/performance_counter.py | 4 ++-- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/test/profiler/test_performance_counter.py b/test/profiler/test_performance_counter.py index 64e5ba341..f0d284d10 100644 --- a/test/profiler/test_performance_counter.py +++ b/test/profiler/test_performance_counter.py @@ -141,7 +141,7 @@ def device_spec(device_name, bandwidth): device_spec = None return device_spec -TEST_STATS = [("with_device", 128, 0.1, 123e9, 123e6, {"a": 234e12, "b": 345e9}, 1, {"a": 1, "b": 2}, {"a": 1, "b": 2}, 1e12, 23e12), +TEST_STATS = [("with_device", 128, 0.1, 123e9, 123e6, {"a": 234e12, "b": 345e9}, 1, {"a": 1, "b": 2}, {"a": 1, "b": 2}, 1e9, 23e9), ("no_device", 128, 0.1, 123e9, 123e6, {"a": 234e12, "b": 345e9}, 1, {"a": 1, "b": 2}, {"a": 1, "b": 2}, None, None)] @pytest.mark.parametrize("label, num_tokens, duration, total_flops, total_io, flops_summary, io_summary, flop_counts, io_counts, device_bandwidth, device_flop_per_s", TEST_STATS) def test_performance_stats(label, num_tokens, duration, total_flops, total_io, flops_summary, io_summary, flop_counts, io_counts, device_bandwidth, device_flop_per_s): @@ -169,6 +169,30 @@ def test_performance_stats(label, num_tokens, duration, total_flops, total_io, f assert stats.flops_utilization == stats.flops_throughput / device_flop_per_s else: assert stats.flops_utilization is None + + # Test str - stats should be formatted to closest power of 10 ** 3 with 2 decimal places of precision + stats_str = str(stats) + print(stats_str) + # Base Stats + expected_io_str = ".12 GB" + expected_flops_str = ".12 TFLOPs" + assert expected_io_str in stats_str + assert expected_flops_str in stats_str + + # Derived Stats + expected_io_throughput_str = "1.23 GB/s" + expected_flops_throughput_str = "1.23 TFLOPs/s" + assert expected_io_throughput_str in stats_str + assert expected_flops_throughput_str in stats_str + + # Utilization Stats + if device_bandwidth is not None: + expected_bandwidth_utilization_str = f"{stats.io_throughput / device_bandwidth:.2f}%" + assert expected_bandwidth_utilization_str in stats_str + if device_flop_per_s is not None: + expected_flops_utilization_str = f"{stats.flops_throughput / device_flop_per_s:.2f}%" + assert expected_flops_utilization_str in stats_str + @pytest.mark.parametrize("shape", [(1, 1024, 4096, 4096), (128, 1, 1024, 4096)], ids=lambda p: ",".join(map(str, p))) @pytest.mark.parametrize("timer_cls", [PerformanceTimer, CUDAPerformanceTimer], ids=lambda p: p.__name__) @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str) diff --git a/torchao/profiler/performance_counter.py b/torchao/profiler/performance_counter.py index 5ab91ef73..28a7bca9d 100644 --- 
a/torchao/profiler/performance_counter.py +++ b/torchao/profiler/performance_counter.py @@ -258,10 +258,10 @@ def __str__(self): indent_2 = " " * 2 indent_4 = " " * 4 txt += "\n" + textwrap.indent("""Utilization:\n""", indent_2) - txt += textwrap.indent(f"""Bandwidth: {self.bandwidth_utilization:.1f}%""", indent_4) + txt += textwrap.indent(f"""Bandwidth: {self.bandwidth_utilization:.2f}%""", indent_4) if self.flops_utilization is not None: - txt += "\n" + textwrap.indent(f"""FLOPs: {self.flops_utilization:.1f}%""", indent_4) + txt += "\n" + textwrap.indent(f"""FLOPs: {self.flops_utilization:.2f}%""", indent_4) return txt class PerformanceCounterManager: From 9cf1200b5a889dbfc16a85f5e4a41a166b9aa5ab Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 8 Jul 2024 19:26:28 +0000 Subject: [PATCH 13/33] fix device spec formatting --- test/profiler/test_performance_counter.py | 81 +++++++++-------- torchao/profiler/device_spec.py | 30 +++---- torchao/profiler/performance_counter.py | 103 ++++++---------------- 3 files changed, 81 insertions(+), 133 deletions(-) diff --git a/test/profiler/test_performance_counter.py b/test/profiler/test_performance_counter.py index f0d284d10..35d69774a 100644 --- a/test/profiler/test_performance_counter.py +++ b/test/profiler/test_performance_counter.py @@ -132,19 +132,10 @@ def patch_device(device_name): with patch("torch.cuda.get_device_name", return_value=device_name): yield -@pytest.fixture -def device_spec(device_name, bandwidth): - if device_name is not None: - with patch_device(device_name): - device_spec = CUDADeviceSpec(dtype=torch.bfloat16, bandwidth=bandwidth) - else: - device_spec = None - return device_spec - TEST_STATS = [("with_device", 128, 0.1, 123e9, 123e6, {"a": 234e12, "b": 345e9}, 1, {"a": 1, "b": 2}, {"a": 1, "b": 2}, 1e9, 23e9), ("no_device", 128, 0.1, 123e9, 123e6, {"a": 234e12, "b": 345e9}, 1, {"a": 1, "b": 2}, {"a": 1, "b": 2}, None, None)] -@pytest.mark.parametrize("label, num_tokens, duration, total_flops, total_io, flops_summary, io_summary, flop_counts, io_counts, device_bandwidth, device_flop_per_s", TEST_STATS) -def test_performance_stats(label, num_tokens, duration, total_flops, total_io, flops_summary, io_summary, flop_counts, io_counts, device_bandwidth, device_flop_per_s): +@pytest.mark.parametrize("label, num_tokens, duration, total_flops, total_io, flops_summary, io_summary, flop_counts, io_counts, device_bandwidth, device_flops_per_s", TEST_STATS) +def test_performance_stats(label, num_tokens, duration, total_flops, total_io, flops_summary, io_summary, flop_counts, io_counts, device_bandwidth, device_flops_per_s): stats = PerformanceStats(label=label, num_tokens=num_tokens, duration=duration, @@ -155,7 +146,7 @@ def test_performance_stats(label, num_tokens, duration, total_flops, total_io, f flop_counts=flop_counts, io_counts=io_counts, device_bandwidth=device_bandwidth, - device_flop_per_s=device_flop_per_s) + device_flops_per_s=device_flops_per_s) # Test derived metrics assert stats.token_throughput == num_tokens / duration @@ -165,8 +156,8 @@ def test_performance_stats(label, num_tokens, duration, total_flops, total_io, f assert stats.bandwidth_utilization == stats.io_throughput / device_bandwidth else: assert stats.bandwidth_utilization is None - if device_flop_per_s is not None: - assert stats.flops_utilization == stats.flops_throughput / device_flop_per_s + if device_flops_per_s is not None: + assert stats.flops_utilization == stats.flops_throughput / device_flops_per_s else: assert stats.flops_utilization is None @@ 
-189,44 +180,52 @@ def test_performance_stats(label, num_tokens, duration, total_flops, total_io, f if device_bandwidth is not None: expected_bandwidth_utilization_str = f"{stats.io_throughput / device_bandwidth:.2f}%" assert expected_bandwidth_utilization_str in stats_str - if device_flop_per_s is not None: - expected_flops_utilization_str = f"{stats.flops_throughput / device_flop_per_s:.2f}%" + if device_flops_per_s is not None: + expected_flops_utilization_str = f"{stats.flops_throughput / device_flops_per_s:.2f}%" assert expected_flops_utilization_str in stats_str + +@pytest.fixture +def device_spec(request): + device_name, bandwidth = request.param + if device_name is not None: + with patch_device(device_name): + return CUDADeviceSpec(dtype=torch.bfloat16, bandwidth=bandwidth) + else: + return None + +@pytest.mark.parametrize("device_spec",[(None, 0), ("A100", 2e12)], indirect=True, ids=lambda p: p[0]) +def test_device_mock(device_spec): + print(device_spec) + @pytest.mark.parametrize("shape", [(1, 1024, 4096, 4096), (128, 1, 1024, 4096)], ids=lambda p: ",".join(map(str, p))) @pytest.mark.parametrize("timer_cls", [PerformanceTimer, CUDAPerformanceTimer], ids=lambda p: p.__name__) @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str) -@pytest.mark.parametrize("device_spec", [(None, 0), ("A100", 2e12)]) +@pytest.mark.parametrize("device_spec", [(None, 0), ("A100", 2e12)], indirect=True, ids=lambda p: p[0]) def test_performance_counter_manager(shape, timer_cls, dtype, device_spec): print(f"Device Spec: {device_spec}") - # # Set up inputs - # batch_size, query_len, in_features, out_features = shape - # num_tokens = batch_size * query_len - # element_size = dtype.itemsize - # a = torch.randn(num_tokens, in_features, dtype=dtype, device="cuda") - # b = torch.randn(in_features, out_features, dtype=dtype, device="cuda") + # Set up inputs + batch_size, query_len, in_features, out_features = shape + num_tokens = batch_size * query_len + element_size = dtype.itemsize + a = torch.randn(num_tokens, in_features, dtype=dtype, device="cuda") + b = torch.randn(in_features, out_features, dtype=dtype, device="cuda") - # # Setup device spec - # if device_name is not None: - # with patch_device(device_name): - # device_spec = CUDADeviceSpec(dtype=dtype, bandwidth=bandwidth) - # else: - # device_spec = None - # cm = PerformanceCounterManager(timer_cls=timer_cls, device_spec=device_spec) + cm = PerformanceCounterManager(timer_cls=timer_cls, device_spec=device_spec) - # # Start count - # start = time.perf_counter() - # with cm.count("a", num_tokens=num_tokens): - # _ = torch.matmul(a, b) - # end = time.perf_counter() + # Start count + start = time.perf_counter() + with cm.count("a", num_tokens=num_tokens): + _ = torch.matmul(a, b) + end = time.perf_counter() - # duration = (end - start) - # expected_flops = 2 * num_tokens * in_features * out_features - # expected_io = (num_tokens * in_features + in_features * out_features + num_tokens * out_features) * element_size - # assert cm.total_flops == expected_flops - # counts = cm.get_counts() + duration = (end - start) + expected_flops = 2 * num_tokens * in_features * out_features + expected_io = (num_tokens * in_features + in_features * out_features + num_tokens * out_features) * element_size + assert cm.total_flops == expected_flops + counts = cm.get_counts() # assert "a" in counts # assert abs(counts['a']['duration'] - duration) < 1e-1 # +/- 100ms # assert counts['a']['total_flops'] == expected_flops @@ -268,6 +267,6 @@ def 
test_performance_counter_manager(shape, timer_cls, dtype, device_spec): # mbu = summary["model_bandwidth_utilization"] # mfu = summary["model_flops_utilization"] # expected_mbu = expected_io_throughput / bandwidth - # expected_mfu = expected_flops_throughput / device_spec.flop_per_s + # expected_mfu = expected_flops_throughput / device_spec.flops_per_s # assert abs(mbu - expected_mbu) < 1e-1 # assert abs(mfu - expected_mfu) < 1e-1 \ No newline at end of file diff --git a/torchao/profiler/device_spec.py b/torchao/profiler/device_spec.py index 6d8851d49..08caef853 100644 --- a/torchao/profiler/device_spec.py +++ b/torchao/profiler/device_spec.py @@ -299,7 +299,7 @@ class DeviceSpec: Fields will be auto-populated in __post_init__ if not already specified and if data is available - bandwidth (bytes /s) - - flop_per_s (FLOP / s) + - flops_per_s (FLOP / s) - vram (bytes) - dtype (torch.dtype) dtype used for theoretical peak performance - flops_by_dtype (dict[Union[torch.dtype, str], float]): mapping from dtype to FLOPs @@ -308,7 +308,7 @@ class DeviceSpec: device_type: str name: Optional[str] = None bandwidth: Optional[int] = None - flop_per_s: Optional[int] = None + flops_per_s: Optional[int] = None vram: Optional[int] = None dtype: Optional[torch.dtype] = None flops_by_dtype: dict = field(default_factory=dict) @@ -316,8 +316,8 @@ class DeviceSpec: def _post_init_check(self): assert self.bandwidth is not None, "GPU bandwidth is None - please specify the bandwidth in GB/s in order to enable speed of light calculations" assert self.dtype is not None, "GPU dtype is None - please specify the dtype in order to enable speed of light calculations" - assert self.flop_per_s is not None, "GPU flop_per_s is None - please specify the flop_per_s in FLOP/s in order to enable speed of light calculations" - self.flops_by_dtype.update({self.dtype: self.flop_per_s}) + assert self.flops_per_s is not None, "GPU flops_per_s is None - please specify the flops_per_s in FLOP/s in order to enable speed of light calculations" + self.flops_by_dtype.update({self.dtype: self.flops_per_s}) # Not needed for downstream calculations atm, no need to assert if self.vram is None: @@ -334,12 +334,12 @@ def __setattr__(self, name, value): def __str__(self): if self.bandwidth is not None: - bw = round(self.bandwidth, 4) - if self.flop_per_s is not None: - tflops = round(self.flop_per_s / 1e12, 4) + formatted_bw = f"{self.bandwidth / 1e9:,.1f}GB/s" + if self.flops_per_s is not None: + formatted_flops = f"{self.flops_per_s / 1e12:,.1f}TFLOPs" if self.vram is not None: - vram_GB = round(self.vram / 1e9, 1) - return f"DeviceSpec(device_type={self.device_type}, name={self.name}, dtype={self.dtype}, bandwidth={bw}GB/s, flops={tflops}TFLOPs, vram={vram_GB}GB)" + formatted_vram = f"{self.vram / 1e9:,.1f}GB" + return f"DeviceSpec(device_type={self.device_type}, name={self.name}, dtype={self.dtype}, bandwidth={formatted_bw}, flops={formatted_flops}, vram={formatted_vram})" @property def roofline_balancepoint(self): @@ -353,10 +353,10 @@ def roofline_balancepoint(self): self.bandwidth is not None ), "Please set bandwidth in order to calculate roofline balancepoint" assert ( - self.flop_per_s is not None - ), "Please set flop_per_s in order to calculate roofline balancepoint" + self.flops_per_s is not None + ), "Please set flops_per_s in order to calculate roofline balancepoint" - return self.flop_per_s / self.bandwidth + return self.flops_per_s / self.bandwidth @dataclass @@ -387,7 +387,7 @@ def __post_init__(self): self.bandwidth = 
get_bandwidth() # FLOPs - if self.flop_per_s is None: + if self.flops_per_s is None: chip_name = get_chip_name(self.device) if chip_name is None: print(f"No FLOPs data available for device name {self.name}") @@ -398,13 +398,13 @@ def __post_init__(self): # Populate flops if not already populated if flops_by_dtype is not None and self.dtype in flops_by_dtype: - self.flop_per_s = flops_by_dtype[self.dtype] + self.flops_per_s = flops_by_dtype[self.dtype] if self.dtype == torch.float32: use_tf32 = "tfloat32" in flops_by_dtype and self.use_tensorcores if use_tf32: - self.flop_per_s = flops_by_dtype["tfloat32"] + self.flops_per_s = flops_by_dtype["tfloat32"] else: print( f"Could not find FLOPs for dtype {self.dtype} for device {self.name}" diff --git a/torchao/profiler/performance_counter.py b/torchao/profiler/performance_counter.py index 28a7bca9d..39fa5c8b6 100644 --- a/torchao/profiler/performance_counter.py +++ b/torchao/profiler/performance_counter.py @@ -5,7 +5,7 @@ from collections import defaultdict from contextlib import contextmanager from copy import deepcopy -from dataclasses import dataclass +from dataclasses import asdict, dataclass from functools import partial from typing import Any, Dict, Optional @@ -112,7 +112,7 @@ def _print_exit_msg(self): gflops = round(self.total_flops / 1e9, self.precision) ms = round(self.duration * 1e3, self.precision) if self.display: - print(f"{self.name.upper()}: Elapsed = {ms} ms, FLOPS = {gflops} GFLOPs") + print(f"{self.name.upper()}: duration = {ms} ms, FLOPS = {gflops} GFLOPs") def __exit__(self, type, value, traceback): self.end = time.perf_counter() @@ -211,32 +211,32 @@ class PerformanceStats: flop_counts: Dict[str, Dict[Any, int]] io_counts: Dict[str, Dict[Any, int]] device_bandwidth: Optional[float] = None - device_flop_per_s: Optional[float] = None + device_flops_per_s: Optional[float] = None @property def token_throughput(self): return self.num_tokens / self.duration @property - def flops_throughput(self): + def achieved_flops_per_s(self): return self.total_flops / self.duration @property - def io_throughput(self): + def achieved_bandwidth(self): return self.total_io / self.duration @property def bandwidth_utilization(self): if self.device_bandwidth is not None: - return self.io_throughput / self.device_bandwidth + return self.achieved_bandwidth / self.device_bandwidth else: print("Device bandwidth is not specified. Please specify the device bandwidth to enable bandwidth utilization calculation") return None @property def flops_utilization(self): - if self.device_flop_per_s is not None: - return self.flops_throughput / self.device_flop_per_s + if self.device_flops_per_s is not None: + return self.achieved_flops_per_s / self.device_flops_per_s else: - print("Device flop_per_s is not specified. Please specify the device throughput to enable flops utilization calculation") + print("Device flops_per_s is not specified. 
Please specify the device throughput to enable flops utilization calculation") return None def _format(self, value, suffix): return to_nearest_power_of_10(value) + suffix @@ -254,9 +254,9 @@ def __str__(self): Total: {self._format(self.total_flops, "FLOPs")} Throughput: {self._format(self.flops_throughput, "FLOPs/s")}""") + indent_2 = " " * 2 + indent_4 = " " * 4 if self.bandwidth_utilization is not None: - indent_2 = " " * 2 - indent_4 = " " * 4 txt += "\n" + textwrap.indent("""Utilization:\n""", indent_2) txt += textwrap.indent(f"""Bandwidth: {self.bandwidth_utilization:.2f}%""", indent_4) @@ -264,8 +264,11 @@ def __str__(self): txt += "\n" + textwrap.indent(f"""FLOPs: {self.flops_utilization:.2f}%""", indent_4) return txt + + def to_dict(self): + return asdict(self) + class PerformanceCounterManager: - COUNT_KEYS = ["label", "num_tokens", "elapsed", "throughput", "total_flops", "flops_table", "flop_counts"] def __init__(self, depth=10, timer_cls: PerformanceTimer=PerformanceTimer, device_spec: DeviceSpec=None, verbose=False): super().__init__() self._counts: Dict[str, PerformanceStats] = {} @@ -284,7 +287,7 @@ def count(self, label: str, num_tokens: int): perf_timer.__exit__(None, None, None) stats = PerformanceStats(label=label, num_tokens=num_tokens, - elapsed=perf_timer.elapsed, + duration=perf_timer.duration, total_flops=perf_timer.total_flops, total_io=perf_timer.total_io, flops_summary=perf_timer.get_summary_flop_counts(), @@ -292,7 +295,7 @@ def count(self, label: str, num_tokens: int): flop_counts=perf_timer.flop_counts, io_counts=perf_timer.io_counts, device_bandwidth=self.device_spec.bandwidth if self.device_spec.bandwidth is not None else None, - device_flop_per_s=self.device_spec.flop_per_s if self.device_spec.flop_per_s is not None else None) + device_flops_per_s=self.device_spec.flops_per_s if self.device_spec.flops_per_s is not None else None) self._counts[label] = stats @property def counts(self): @@ -313,7 +316,7 @@ def total_tokens(self): @property def total_time(self): - return sum(count["elapsed"] for count in self._counts.values()) + return sum(count["duration"] for count in self._counts.values()) def to_dict(self): # Convert flop_counts from OpOverloadPackets to str @@ -327,35 +330,29 @@ def to_dict(self): def to_json(self): return json.dumps(self.to_dict(), indent=2) - def _summarize(self, key): + def _summarize_stat(self, key): return {label: getattr(self._counts[label], key) for label in self._counts.keys()} @property def flops_summary(self): - return self._summarize(key="summary_flops") + return self._summarize_stat(key="summary_flops") @property def io_summary(self): - return self._summarize(key="summary_io") + return self._summarize_stat(key="summary_io") @property def flop_counts_summary(self): - return self._summarize(key="flop_counts") + return self._summarize_stat(key="flop_counts") @property def io_counts_summary(self): - return self._summarize(key="io_counts") + return self._summarize_stat(key="io_counts") @property def stats_summary(self): - token_throughput = self.total_tokens / self.total_time - io_throughput = self.total_io / self.total_time - flops_throughput = self.total_flops / self.total_time - achieved_bandwidth = self.total_io / self.total_time - achieved_flops_per_s = self.total_flops / self.total_time - stats = PerformanceStats(label="Performance Summary", num_tokens=self.total_tokens, - elapsed=self.total_time, + duration=self.total_time, total_flops=self.total_flops, total_io=self.total_io, flops_summary=self.flops_summary, @@ -363,58 
+360,10 @@ def stats_summary(self): flop_counts=self.flop_counts_summary, io_counts=self.io_counts_summary, device_bandwidth=self.device_spec.bandwidth if self.device_spec.bandwidth is not None else None, - device_flop_per_s=self.device_spec.flop_per_s if self.device_spec.flop_per_s is not None else None) - stats = { - "total_tokens": self.total_tokens, - "total_time": self.total_time, - "total_flops": self.total_flops, - "total_io": self.total_io, - "token_throughput": token_throughput, - "io_throughput": io_throughput, - "flops_throughput": flops_throughput, - "achieved_bandwidth": achieved_bandwidth, - "achieved_flops_per_s": achieved_flops_per_s, - "arithmetic_intensity": self.total_flops / self.total_io - } - device_spec = self.device_spec - if device_spec is not None: - theoretical_bandwidth = device_spec.bandwidth - theoretical_flop_per_s = device_spec.flop_per_s - - device_stats = { - "device_name": device_spec.name, - "theoretical_bandwidth": theoretical_bandwidth, - "theoretical_throughput": theoretical_flop_per_s, - "model_bandwidth_utilization": achieved_bandwidth / theoretical_bandwidth, - "model_flops_utilization": achieved_flops_per_s / theoretical_flop_per_s, - } - stats.update(device_stats) + device_flops_per_s=self.device_spec.flops_per_s if self.device_spec.flops_per_s is not None else None) + return stats - # def _format_single(self, label, counts, precision, verbose=False): - # ms = round(counts['elapsed'] * 1e3, precision) - # token_throughput = round(counts['token_throughput'], precision) - # gflops = round(counts['total_flops'] / 1e9, precision) - # gb = round(counts['total_io'] / 1e9, precision) - # flop_throughput = round(gflops / counts['elapsed'], precision) - # io_throughput = round(gb / counts['elapsed'], precision) - # text = textwrap.dedent(f"""\ - # {label.title()}: - # Elapsed = {ms:,} ms - # Tokens: - # Total {counts['num_tokens']} - # Throughput {token_throughput} tokens/s - # IO: - # Total {gb:,} GB - # Throughput {io_throughput} GB/s - # FLOPs: - # Total {gflops:,} GFLOPs, - # Throughput {flop_throughput:,} GFLOP/s""") - # if verbose: - # counts_by_module = counts['pretty_summary'] - # text += textwrap.dedent(f"""\nCounts by Module:\n{counts_by_module}""") - - # return text def _format_totals(self, precision=2): ms = round(self.total_time * 1e3, precision) From 19a6a70bb5a05c57706a92b3acd96297c78370e9 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 8 Jul 2024 23:58:15 +0000 Subject: [PATCH 14/33] finish perf counter manager refactor --- test/profiler/test_performance_counter.py | 124 +++++++++++++--------- torchao/profiler/device_spec.py | 13 --- torchao/profiler/performance_counter.py | 56 +++++----- 3 files changed, 105 insertions(+), 88 deletions(-) diff --git a/test/profiler/test_performance_counter.py b/test/profiler/test_performance_counter.py index 35d69774a..39aa0fd8c 100644 --- a/test/profiler/test_performance_counter.py +++ b/test/profiler/test_performance_counter.py @@ -193,26 +193,46 @@ def device_spec(request): return CUDADeviceSpec(dtype=torch.bfloat16, bandwidth=bandwidth) else: return None +@pytest.fixture +def performance_counter_manager(device_spec, request): + shape, timer_cls, dtype = request.param + batch_size, query_len, in_features, out_features = shape + num_tokens = batch_size * query_len + element_size = dtype.itemsize + a = torch.randn(num_tokens, in_features, dtype=dtype, device="cuda") + b = torch.randn(in_features, out_features, dtype=dtype, device="cuda") + + cm = PerformanceCounterManager(timer_cls=timer_cls, 
device_spec=device_spec) + + # Start count + start = time.perf_counter() + with cm.count("a", num_tokens=num_tokens): + _ = torch.matmul(a, b) + end = time.perf_counter() + + duration_a = (end - start) + expected_flops = 2 * num_tokens * in_features * out_features + expected_io = (num_tokens * in_features + in_features * out_features + num_tokens * out_features) * element_size -@pytest.mark.parametrize("device_spec",[(None, 0), ("A100", 2e12)], indirect=True, ids=lambda p: p[0]) -def test_device_mock(device_spec): - print(device_spec) + start = time.perf_counter() + with cm.count("b", num_tokens=num_tokens): + _ = torch.matmul(a, b) + end = time.perf_counter() + duration_b = end - start @pytest.mark.parametrize("shape", [(1, 1024, 4096, 4096), (128, 1, 1024, 4096)], ids=lambda p: ",".join(map(str, p))) @pytest.mark.parametrize("timer_cls", [PerformanceTimer, CUDAPerformanceTimer], ids=lambda p: p.__name__) @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str) @pytest.mark.parametrize("device_spec", [(None, 0), ("A100", 2e12)], indirect=True, ids=lambda p: p[0]) def test_performance_counter_manager(shape, timer_cls, dtype, device_spec): - print(f"Device Spec: {device_spec}") - + FLOAT_TOL = 1e-5 # Set up inputs batch_size, query_len, in_features, out_features = shape num_tokens = batch_size * query_len element_size = dtype.itemsize a = torch.randn(num_tokens, in_features, dtype=dtype, device="cuda") b = torch.randn(in_features, out_features, dtype=dtype, device="cuda") - - + cm = PerformanceCounterManager(timer_cls=timer_cls, device_spec=device_spec) # Start count @@ -225,48 +245,54 @@ def test_performance_counter_manager(shape, timer_cls, dtype, device_spec): expected_flops = 2 * num_tokens * in_features * out_features expected_io = (num_tokens * in_features + in_features * out_features + num_tokens * out_features) * element_size assert cm.total_flops == expected_flops + assert cm.total_io == expected_io counts = cm.get_counts() - # assert "a" in counts - # assert abs(counts['a']['duration'] - duration) < 1e-1 # +/- 100ms - # assert counts['a']['total_flops'] == expected_flops - # assert counts['a']['total_io'] == expected_io - # assert counts['a']['token_throughput'] == counts['a']['num_tokens'] / counts['a']['duration'] - # assert counts['a']['flops_throughput'] == counts['a']['total_flops'] / counts['a']['duration'] - # assert counts['a']['io_throughput'] == counts['a']['total_io'] / counts['a']['duration'] + assert "a" in counts + # Check captured performance stats + psa: PerformanceStats = counts["a"] + # Raw metrics + assert abs(psa.duration - duration) < 1e-1 # +/- 100ms + assert psa.total_flops == expected_flops + assert psa.total_io == expected_io + # Derived metrics + assert psa.token_throughput == psa.num_tokens / psa.duration + assert psa.achieved_flops_per_s == psa.total_flops / psa.duration + assert psa.achieved_bandwidth == psa.total_io / psa.duration - # start = time.perf_counter() - # with cm.count("b", num_tokens=num_tokens): - # _ = torch.matmul(a, b) - # end = time.perf_counter() - # duration = end - start - # assert "a" in cm.counts - # assert "b" in cm.counts - # counts = cm.counts - # assert abs(counts['b']['duration'] - duration) < 1e-1 # +/- 100ms - # assert counts['b']['total_flops'] == expected_flops - # assert counts['b']['total_io'] == expected_io - # assert cm.total_flops == 2 * expected_flops - # assert cm.total_io == 2 * expected_io + start = time.perf_counter() + with cm.count("b", num_tokens=num_tokens): + _ = torch.matmul(a, b) + end = 
time.perf_counter() + duration = end - start + assert "b" in cm.counts + psb = cm.counts["b"] + assert abs(psb.duration - duration) < 1e-1 # +/- 100ms + assert psb.total_flops == expected_flops + assert psb.total_io == expected_io + + # Test that total properties account for both a and b + assert cm.total_time == psa.duration + psb.duration + assert cm.total_flops == 2 * expected_flops + assert cm.total_io == 2 * expected_io + + # Test stats_summary property, which returns a new PerformanceStats object with accumulated stats + summary: PerformanceStats = cm.stats_summary + # Raw stats + assert summary.num_tokens == psa.num_tokens + psb.num_tokens + assert summary.total_io == psa.total_io + psb.total_io + assert summary.total_flops == psa.total_flops + psb.total_flops + assert summary.duration == psa.duration + psb.duration + + # Derived stats + expected_token_throughput = (psa.num_tokens + psb.num_tokens) / (psa.duration + psb.duration) + expected_io_throughput = (psa.total_io + psb.total_io) / (psa.duration + psb.duration) + expected_flops_throughput = (psa.total_flops + psb.total_flops) / (psa.duration + psb.duration) + assert abs(summary.token_throughput - expected_token_throughput) < FLOAT_TOL + assert abs(summary.achieved_bandwidth - expected_io_throughput) < FLOAT_TOL + assert abs(summary.achieved_flops_per_s - expected_flops_throughput) < FLOAT_TOL - # summary = cm.get_summary() - # expected_tokens = 2 * num_tokens - # expected_total_flops = 2 * expected_flops - # expected_total_io = 2 * expected_io - # expected_total_time = cm.total_time - # expected_token_throughput = expected_tokens / expected_total_time - # expected_io_throughput = expected_total_io / expected_total_time - # expected_flops_throughput = expected_total_flops / expected_total_time - # assert summary['total_tokens'] == expected_tokens - # assert summary['total_io'] == expected_total_io - # assert summary['total_flops'] == expected_total_flops - # assert summary['total_time'] == expected_total_time - # assert abs(summary['token_throughput'] - expected_token_throughput) < 1e-1 - # assert abs(summary['io_throughput'] - expected_io_throughput) < 1e-1 - # assert abs(summary['flops_throughput'] - expected_flops_throughput) < 1e-1 - # if device_spec is not None: - # mbu = summary["model_bandwidth_utilization"] - # mfu = summary["model_flops_utilization"] - # expected_mbu = expected_io_throughput / bandwidth - # expected_mfu = expected_flops_throughput / device_spec.flops_per_s - # assert abs(mbu - expected_mbu) < 1e-1 - # assert abs(mfu - expected_mfu) < 1e-1 \ No newline at end of file + if device_spec is not None: + expected_bandwidth_utilization = expected_io_throughput / device_spec.bandwidth + expected_flops_utilization = expected_flops_throughput / device_spec.flops_per_s + assert abs(summary.bandwidth_utilization - expected_bandwidth_utilization) < FLOAT_TOL + assert abs(summary.flops_utilization - expected_flops_utilization) < FLOAT_TOL \ No newline at end of file diff --git a/torchao/profiler/device_spec.py b/torchao/profiler/device_spec.py index 08caef853..7c40c73e9 100644 --- a/torchao/profiler/device_spec.py +++ b/torchao/profiler/device_spec.py @@ -277,19 +277,6 @@ def get_bandwidth(device: int = 0) -> int: def get_flops_by_dtype(chip_name: str) -> dict[torch.dtype, float]: return _AVAILABLE_GPU_SPECS.get(chip_name, None) - # # Check for tfloat32 - # if ( - # dtype == torch.float32 - # and "tfloat32" in dtype_to_flops - # and torch.get_float32_matmul_precision() != "highest" - # ): - # logger.warning("Using 
tfloat32 tensorcores FLOPs") - # dtype = "tfloat32" - # if dtype not in dtype_to_flops: - # logger.warning(f"FLOPs not found for {dtype!r} on {chip!r}") - # return None - # return dtype_to_flops[dtype] - @dataclass class DeviceSpec: diff --git a/torchao/profiler/performance_counter.py b/torchao/profiler/performance_counter.py index 39fa5c8b6..7968d0505 100644 --- a/torchao/profiler/performance_counter.py +++ b/torchao/profiler/performance_counter.py @@ -199,8 +199,14 @@ def to_nearest_power_of_10(x, precision=2): return f"{value:,.{precision}f} {unit}" +class DictGetter: + def __getitem__(self, key): + if hasattr(self, key): + return getattr(self, key) + else: + raise KeyError(key) @dataclass -class PerformanceStats: +class PerformanceStats(DictGetter): label: str num_tokens: int duration: float @@ -294,8 +300,8 @@ def count(self, label: str, num_tokens: int): io_summary=perf_timer.get_summary_io_counts(), flop_counts=perf_timer.flop_counts, io_counts=perf_timer.io_counts, - device_bandwidth=self.device_spec.bandwidth if self.device_spec.bandwidth is not None else None, - device_flops_per_s=self.device_spec.flops_per_s if self.device_spec.flops_per_s is not None else None) + device_bandwidth=self.device_spec.bandwidth if self.device_spec is not None else None, + device_flops_per_s=self.device_spec.flops_per_s if self.device_spec is not None else None) self._counts[label] = stats @property def counts(self): @@ -305,41 +311,29 @@ def get_counts(self): @property def total_flops(self): - return sum(count["total_flops"] for count in self._counts.values()) + return sum(count.total_flops for count in self._counts.values()) @property def total_io(self): - return sum(count["total_io"] for count in self._counts.values()) + return sum(count.total_io for count in self._counts.values()) @property def total_tokens(self): - return sum(count["num_tokens"] for count in self._counts.values()) + return sum(count.num_tokens for count in self._counts.values()) @property def total_time(self): - return sum(count["duration"] for count in self._counts.values()) - - def to_dict(self): - # Convert flop_counts from OpOverloadPackets to str - counts = deepcopy(self._counts) - for label,label_counts in counts.items(): - counts[label]['flop_counts'] = {mod: {str(op): count for op, count in op_count.items()} for mod, op_count in label_counts['flop_counts'].items()} - counts[label]['io_counts'] = {mod: {str(op): count for op, count in op_count.items()} for mod, op_count in label_counts['io_counts'].items()} - - return counts - - def to_json(self): - return json.dumps(self.to_dict(), indent=2) + return sum(count.duration for count in self._counts.values()) def _summarize_stat(self, key): return {label: getattr(self._counts[label], key) for label in self._counts.keys()} @property def flops_summary(self): - return self._summarize_stat(key="summary_flops") + return self._summarize_stat(key="flops_summary") @property def io_summary(self): - return self._summarize_stat(key="summary_io") + return self._summarize_stat(key="io_summary") @property def flop_counts_summary(self): @@ -359,12 +353,11 @@ def stats_summary(self): io_summary=self.io_summary, flop_counts=self.flop_counts_summary, io_counts=self.io_counts_summary, - device_bandwidth=self.device_spec.bandwidth if self.device_spec.bandwidth is not None else None, - device_flops_per_s=self.device_spec.flops_per_s if self.device_spec.flops_per_s is not None else None) + device_bandwidth=self.device_spec.bandwidth if self.device_spec is not None else None, + 
device_flops_per_s=self.device_spec.flops_per_s if self.device_spec is not None else None) return stats - def _format_totals(self, precision=2): ms = round(self.total_time * 1e3, precision) token_throughput = round(self.total_tokens / self.total_time, precision) @@ -394,6 +387,17 @@ def print_summary(self, labels: list[str] = None, precision=2, verbose=None): _print(text) else: for label in labels: + text = str(self._count[label]) # delegate to __str__ of PerformanceStats object _print(self._count[label]) - # text = self._format_single(label, self._counts[label], precision=precision, verbose=verbose) - # _print(text) \ No newline at end of file + + def to_dict(self): + # Convert flop_counts from OpOverloadPackets to str + counts = deepcopy(self._counts) + for label,label_counts in counts.items(): + counts[label]['flop_counts'] = {mod: {str(op): count for op, count in op_count.items()} for mod, op_count in label_counts['flop_counts'].items()} + counts[label]['io_counts'] = {mod: {str(op): count for op, count in op_count.items()} for mod, op_count in label_counts['io_counts'].items()} + + return counts + + def to_json(self): + return json.dumps(self.to_dict(), indent=2) \ No newline at end of file From 22b1cabe7f5dd2e45ef8140a90c3dc1cf549ad0a Mon Sep 17 00:00:00 2001 From: jeromeku Date: Tue, 9 Jul 2024 00:23:22 +0000 Subject: [PATCH 15/33] add serialization test --- test/profiler/test_performance_counter.py | 54 ++++++++++++++++++++++- torchao/profiler/performance_counter.py | 41 +++++++++++++---- 2 files changed, 85 insertions(+), 10 deletions(-) diff --git a/test/profiler/test_performance_counter.py b/test/profiler/test_performance_counter.py index 39aa0fd8c..136b21324 100644 --- a/test/profiler/test_performance_counter.py +++ b/test/profiler/test_performance_counter.py @@ -4,6 +4,8 @@ transformers = pytest.importorskip("transformers") LlamaConfig = transformers.models.llama.modeling_llama.LlamaConfig LlamaForCausalLM = transformers.models.llama.modeling_llama.LlamaForCausalLM + +import json import time from contextlib import contextmanager from unittest.mock import patch @@ -224,7 +226,7 @@ def performance_counter_manager(device_spec, request): @pytest.mark.parametrize("timer_cls", [PerformanceTimer, CUDAPerformanceTimer], ids=lambda p: p.__name__) @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str) @pytest.mark.parametrize("device_spec", [(None, 0), ("A100", 2e12)], indirect=True, ids=lambda p: p[0]) -def test_performance_counter_manager(shape, timer_cls, dtype, device_spec): +def test_performance_counter_manager(shape, timer_cls, dtype, device_spec, tmpdir): FLOAT_TOL = 1e-5 # Set up inputs batch_size, query_len, in_features, out_features = shape @@ -295,4 +297,52 @@ def test_performance_counter_manager(shape, timer_cls, dtype, device_spec): expected_bandwidth_utilization = expected_io_throughput / device_spec.bandwidth expected_flops_utilization = expected_flops_throughput / device_spec.flops_per_s assert abs(summary.bandwidth_utilization - expected_bandwidth_utilization) < FLOAT_TOL - assert abs(summary.flops_utilization - expected_flops_utilization) < FLOAT_TOL \ No newline at end of file + assert abs(summary.flops_utilization - expected_flops_utilization) < FLOAT_TOL + else: + assert summary.bandwidth_utilization is None + assert summary.flops_utilization is None + + # Test json serialization + temp_path = tmpdir.mkdir('test_dir').join('performance_counter_manager.json') + with open(temp_path, "w") as f: + f.write(cm.to_json()) + with open(temp_path, 'r') as f: + perf_dict 
= json.load(f) + assert 'a' in perf_dict + assert 'b' in perf_dict + + #Test basic stats are recorded properly + assert perf_dict['a']['num_tokens'] == psa.num_tokens + assert perf_dict['a']['total_io'] == psa.total_io + assert perf_dict['a']['total_flops'] == psa.total_flops + assert perf_dict['a']['duration'] == psa.duration + + assert perf_dict['b']['num_tokens'] == psb.num_tokens + assert perf_dict['b']['total_io'] == psb.total_io + assert perf_dict['b']['total_flops'] == psb.total_flops + assert perf_dict['b']['duration'] == psb.duration + + # Test derived properties are present + perf_dict['a']['achieved_flops_per_s'] == psa.achieved_flops_per_s + perf_dict['a']['achieved_bandwidth'] == psa.achieved_bandwidth + perf_dict['b']['achieved_flops_per_s'] == psb.achieved_flops_per_s + perf_dict['b']['achieved_bandwidth'] == psb.achieved_bandwidth + + if device_spec is not None: + assert perf_dict['a']['device_flops_per_s'] == device_spec.flops_per_s + assert perf_dict['a']['device_bandwidth'] == device_spec.bandwidth + assert perf_dict['a']['bandwidth_utilization'] == psa.bandwidth_utilization + assert perf_dict['a']['flops_utilization'] == psa.flops_utilization + assert perf_dict['b']['device_flops_per_s'] == device_spec.flops_per_s + assert perf_dict['b']['device_bandwidth'] == device_spec.bandwidth + assert perf_dict['b']['bandwidth_utilization'] == psb.bandwidth_utilization + assert perf_dict['b']['flops_utilization'] == psb.flops_utilization + else: + assert perf_dict['a']['device_flops_per_s'] is None + assert perf_dict['a']['device_bandwidth'] is None + assert perf_dict['a']['bandwidth_utilization'] is None + assert perf_dict['a']['flops_utilization'] is None + assert perf_dict['b']['device_flops_per_s'] is None + assert perf_dict['b']['device_bandwidth'] is None + assert perf_dict['b']['bandwidth_utilization'] is None + assert perf_dict['b']['flops_utilization'] is None \ No newline at end of file diff --git a/torchao/profiler/performance_counter.py b/torchao/profiler/performance_counter.py index 7968d0505..b9d0d36e0 100644 --- a/torchao/profiler/performance_counter.py +++ b/torchao/profiler/performance_counter.py @@ -1,3 +1,4 @@ +import inspect import json import math import textwrap @@ -199,14 +200,31 @@ def to_nearest_power_of_10(x, precision=2): return f"{value:,.{precision}f} {unit}" -class DictGetter: +class DictMixin: + """ + Mixin to enable dict-like access to dataclass attributes + """ def __getitem__(self, key): if hasattr(self, key): return getattr(self, key) else: raise KeyError(key) + + def __setitem__(self, key, value): + setattr(self, key, value) + + def __contains__(self, key): + return hasattr(self, key) + + def __iter__(self): + for key in self.__dict__: + yield key + +# Function to get all property methods of a class +def get_property_methods(cls): + return [name for name, member in inspect.getmembers(cls, lambda m: isinstance(m, property))] @dataclass -class PerformanceStats(DictGetter): +class PerformanceStats(DictMixin): label: str num_tokens: int duration: float @@ -272,7 +290,12 @@ def __str__(self): return txt def to_dict(self): - return asdict(self) + d = asdict(self) + # Update dict with properties + props = get_property_methods(self.__class__) + d.update({prop: getattr(self, prop) for prop in props}) + + return d class PerformanceCounterManager: def __init__(self, depth=10, timer_cls: PerformanceTimer=PerformanceTimer, device_spec: DeviceSpec=None, verbose=False): @@ -379,24 +402,26 @@ def _format_totals(self, precision=2): Throughput {flop_throughput:,} 
GFLOP/s""") return text - def print_summary(self, labels: list[str] = None, precision=2, verbose=None): - verbose = verbose if verbose is not None else self.verbose + def print_summary(self, labels: list[str] = None): _print = partial(print, flush=True, end='\n') + # Delegate to __str__ of PerformanceStats for pretty printing if labels is None: - text = self._format_totals(precision=precision) + text = str(self.stats_summary) _print(text) else: for label in labels: - text = str(self._count[label]) # delegate to __str__ of PerformanceStats object + text = str(self._count[label]) _print(self._count[label]) def to_dict(self): # Convert flop_counts from OpOverloadPackets to str + # Then delegate to PerformanceStats `to_dict`, which updates with derived metrics (property methods) counts = deepcopy(self._counts) for label,label_counts in counts.items(): counts[label]['flop_counts'] = {mod: {str(op): count for op, count in op_count.items()} for mod, op_count in label_counts['flop_counts'].items()} counts[label]['io_counts'] = {mod: {str(op): count for op, count in op_count.items()} for mod, op_count in label_counts['io_counts'].items()} - + counts[label] = counts[label].to_dict() + return counts def to_json(self): From cc3c73c4e24c51f7da08c74bb92732f6b299e0c8 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Tue, 9 Jul 2024 17:37:55 +0000 Subject: [PATCH 16/33] refactor stats tests --- test/profiler/test_performance_counter.py | 413 ++++++++++++---------- test/profiler/utils.py | 49 +++ torchao/profiler/performance_counter.py | 15 +- 3 files changed, 293 insertions(+), 184 deletions(-) create mode 100644 test/profiler/utils.py diff --git a/test/profiler/test_performance_counter.py b/test/profiler/test_performance_counter.py index 136b21324..3bbaa2a42 100644 --- a/test/profiler/test_performance_counter.py +++ b/test/profiler/test_performance_counter.py @@ -6,13 +6,26 @@ LlamaForCausalLM = transformers.models.llama.modeling_llama.LlamaForCausalLM import json +import tempfile import time +import unittest from contextlib import contextmanager +from dataclasses import asdict +from pathlib import Path +from typing import Union from unittest.mock import patch import torch +from parameterized import parameterized_class +from utils import ( + PerfCounterResult, + PerfCounterTestConfig, + PerfStatsTestConfig, + get_test_name, + patch_device, +) -from torchao.profiler.device_spec import CUDADeviceSpec +from torchao.profiler.device_spec import CUDADeviceSpec, DeviceSpec from torchao.profiler.performance_counter import ( CUDAPerformanceTimer, PerformanceCounterManager, @@ -129,37 +142,50 @@ def test_performance_counter(num_hidden_layers, hidden_size, intermediate_size, expected_size = ffn_io_check(model_config, batch_size, seqlen, element_size, k) assert expected_size == summary_io[proj_keys[0]] -@contextmanager -def patch_device(device_name): - with patch("torch.cuda.get_device_name", return_value=device_name): - yield -TEST_STATS = [("with_device", 128, 0.1, 123e9, 123e6, {"a": 234e12, "b": 345e9}, 1, {"a": 1, "b": 2}, {"a": 1, "b": 2}, 1e9, 23e9), - ("no_device", 128, 0.1, 123e9, 123e6, {"a": 234e12, "b": 345e9}, 1, {"a": 1, "b": 2}, {"a": 1, "b": 2}, None, None)] -@pytest.mark.parametrize("label, num_tokens, duration, total_flops, total_io, flops_summary, io_summary, flop_counts, io_counts, device_bandwidth, device_flops_per_s", TEST_STATS) -def test_performance_stats(label, num_tokens, duration, total_flops, total_io, flops_summary, io_summary, flop_counts, io_counts, device_bandwidth, device_flops_per_s): - 
stats = PerformanceStats(label=label, - num_tokens=num_tokens, - duration=duration, - total_flops=total_flops, - total_io=total_io, - flops_summary=flops_summary, - io_summary=io_summary, - flop_counts=flop_counts, - io_counts=io_counts, - device_bandwidth=device_bandwidth, - device_flops_per_s=device_flops_per_s) + +PERFSTATS_TEST_CONFIGS = [PerfStatsTestConfig(label="with_device", + num_tokens=128, + duration=0.1, + total_flops=123e9, + total_io=123e6, + flops_summary={"a": 234e12, "b": 345e9}, + io_summary={"a": 1, "b": 2}, + flop_counts={"a": 234e12, "b": 345e9}, + io_counts={"a": 1, "b": 2}, + device_bandwidth=1e9, + device_flops_per_s=23e9), + PerfStatsTestConfig(label="no_device", + num_tokens=128, + duration=0.1, + total_flops=123e9, + total_io=123e6, + flops_summary={"a": 234e12, "b": 345e9}, + io_summary={"a": 1, "b": 2}, + flop_counts={"a": 234e12, "b": 345e9}, + io_counts={"a": 1, "b": 2}, + device_bandwidth=None, + device_flops_per_s=None)] +@pytest.mark.parametrize("cfg", PERFSTATS_TEST_CONFIGS, ids=lambda cfg: cfg.label) +def test_performance_stats(cfg: PerfStatsTestConfig): + stats = PerformanceStats(**asdict(cfg)) + num_tokens = cfg.num_tokens + duration = cfg.duration + total_flops = cfg.total_flops + total_io = cfg.total_io + device_bandwidth = cfg.device_bandwidth + device_flops_per_s = cfg.device_flops_per_s # Test derived metrics assert stats.token_throughput == num_tokens / duration - assert stats.io_throughput == total_io / duration - assert stats.flops_throughput == total_flops / duration + assert stats.achieved_bandwidth == total_io / duration + assert stats.achieved_flops_per_s == total_flops / duration if device_bandwidth is not None: - assert stats.bandwidth_utilization == stats.io_throughput / device_bandwidth + assert stats.bandwidth_utilization == stats.achieved_bandwidth / device_bandwidth else: assert stats.bandwidth_utilization is None if device_flops_per_s is not None: - assert stats.flops_utilization == stats.flops_throughput / device_flops_per_s + assert stats.flops_utilization == stats.achieved_flops_per_s / device_flops_per_s else: assert stats.flops_utilization is None @@ -180,169 +206,198 @@ def test_performance_stats(label, num_tokens, duration, total_flops, total_io, f # Utilization Stats if device_bandwidth is not None: - expected_bandwidth_utilization_str = f"{stats.io_throughput / device_bandwidth:.2f}%" + expected_bandwidth_utilization_str = f"{stats.achieved_bandwidth / device_bandwidth:.2f}%" assert expected_bandwidth_utilization_str in stats_str if device_flops_per_s is not None: - expected_flops_utilization_str = f"{stats.flops_throughput / device_flops_per_s:.2f}%" + expected_flops_utilization_str = f"{stats.achieved_flops_per_s / device_flops_per_s:.2f}%" assert expected_flops_utilization_str in stats_str -@pytest.fixture -def device_spec(request): - device_name, bandwidth = request.param - if device_name is not None: - with patch_device(device_name): - return CUDADeviceSpec(dtype=torch.bfloat16, bandwidth=bandwidth) - else: - return None -@pytest.fixture -def performance_counter_manager(device_spec, request): - shape, timer_cls, dtype = request.param - batch_size, query_len, in_features, out_features = shape - num_tokens = batch_size * query_len - element_size = dtype.itemsize - a = torch.randn(num_tokens, in_features, dtype=dtype, device="cuda") - b = torch.randn(in_features, out_features, dtype=dtype, device="cuda") - - cm = PerformanceCounterManager(timer_cls=timer_cls, device_spec=device_spec) - - # Start count - start = 
time.perf_counter() - with cm.count("a", num_tokens=num_tokens): - _ = torch.matmul(a, b) - end = time.perf_counter() +# @pytest.fixture +# def device_spec(request): +# device_name, bandwidth = request.param +# if device_name is not None: +# with patch_device(device_name): +# return CUDADeviceSpec(dtype=torch.bfloat16, bandwidth=bandwidth) +# else: +# return None - duration_a = (end - start) - expected_flops = 2 * num_tokens * in_features * out_features - expected_io = (num_tokens * in_features + in_features * out_features + num_tokens * out_features) * element_size - start = time.perf_counter() - with cm.count("b", num_tokens=num_tokens): - _ = torch.matmul(a, b) - end = time.perf_counter() - duration_b = end - start +PERFCOUNTERMANAGER_TEST_CONFIGS = [PerfCounterTestConfig("no_device", (1, 1024, 4096, 4096), PerformanceTimer, torch.bfloat16, (None, 0)), + PerfCounterTestConfig("a100", (1, 1024, 4096, 4096), CUDAPerformanceTimer, torch.bfloat16, ("A100", 2e12))] -@pytest.mark.parametrize("shape", [(1, 1024, 4096, 4096), (128, 1, 1024, 4096)], ids=lambda p: ",".join(map(str, p))) -@pytest.mark.parametrize("timer_cls", [PerformanceTimer, CUDAPerformanceTimer], ids=lambda p: p.__name__) -@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str) -@pytest.mark.parametrize("device_spec", [(None, 0), ("A100", 2e12)], indirect=True, ids=lambda p: p[0]) -def test_performance_counter_manager(shape, timer_cls, dtype, device_spec, tmpdir): - FLOAT_TOL = 1e-5 - # Set up inputs - batch_size, query_len, in_features, out_features = shape - num_tokens = batch_size * query_len - element_size = dtype.itemsize - a = torch.randn(num_tokens, in_features, dtype=dtype, device="cuda") - b = torch.randn(in_features, out_features, dtype=dtype, device="cuda") - - cm = PerformanceCounterManager(timer_cls=timer_cls, device_spec=device_spec) - - # Start count - start = time.perf_counter() - with cm.count("a", num_tokens=num_tokens): - _ = torch.matmul(a, b) - end = time.perf_counter() - - duration = (end - start) - expected_flops = 2 * num_tokens * in_features * out_features - expected_io = (num_tokens * in_features + in_features * out_features + num_tokens * out_features) * element_size - assert cm.total_flops == expected_flops - assert cm.total_io == expected_io - counts = cm.get_counts() - assert "a" in counts - # Check captured performance stats - psa: PerformanceStats = counts["a"] - # Raw metrics - assert abs(psa.duration - duration) < 1e-1 # +/- 100ms - assert psa.total_flops == expected_flops - assert psa.total_io == expected_io - # Derived metrics - assert psa.token_throughput == psa.num_tokens / psa.duration - assert psa.achieved_flops_per_s == psa.total_flops / psa.duration - assert psa.achieved_bandwidth == psa.total_io / psa.duration - - start = time.perf_counter() - with cm.count("b", num_tokens=num_tokens): - _ = torch.matmul(a, b) - end = time.perf_counter() - duration = end - start - assert "b" in cm.counts - psb = cm.counts["b"] - assert abs(psb.duration - duration) < 1e-1 # +/- 100ms - assert psb.total_flops == expected_flops - assert psb.total_io == expected_io - - # Test that total properties account for both a and b - assert cm.total_time == psa.duration + psb.duration - assert cm.total_flops == 2 * expected_flops - assert cm.total_io == 2 * expected_io - - # Test stats_summary property, which returns a new PerformanceStats object with accumulated stats - summary: PerformanceStats = cm.stats_summary - # Raw stats - assert summary.num_tokens == psa.num_tokens + psb.num_tokens - assert 
summary.total_io == psa.total_io + psb.total_io - assert summary.total_flops == psa.total_flops + psb.total_flops - assert summary.duration == psa.duration + psb.duration - # Derived stats - expected_token_throughput = (psa.num_tokens + psb.num_tokens) / (psa.duration + psb.duration) - expected_io_throughput = (psa.total_io + psb.total_io) / (psa.duration + psb.duration) - expected_flops_throughput = (psa.total_flops + psb.total_flops) / (psa.duration + psb.duration) - assert abs(summary.token_throughput - expected_token_throughput) < FLOAT_TOL - assert abs(summary.achieved_bandwidth - expected_io_throughput) < FLOAT_TOL - assert abs(summary.achieved_flops_per_s - expected_flops_throughput) < FLOAT_TOL - - if device_spec is not None: - expected_bandwidth_utilization = expected_io_throughput / device_spec.bandwidth - expected_flops_utilization = expected_flops_throughput / device_spec.flops_per_s - assert abs(summary.bandwidth_utilization - expected_bandwidth_utilization) < FLOAT_TOL - assert abs(summary.flops_utilization - expected_flops_utilization) < FLOAT_TOL - else: - assert summary.bandwidth_utilization is None - assert summary.flops_utilization is None - - # Test json serialization - temp_path = tmpdir.mkdir('test_dir').join('performance_counter_manager.json') - with open(temp_path, "w") as f: - f.write(cm.to_json()) - with open(temp_path, 'r') as f: - perf_dict = json.load(f) - assert 'a' in perf_dict - assert 'b' in perf_dict - - #Test basic stats are recorded properly - assert perf_dict['a']['num_tokens'] == psa.num_tokens - assert perf_dict['a']['total_io'] == psa.total_io - assert perf_dict['a']['total_flops'] == psa.total_flops - assert perf_dict['a']['duration'] == psa.duration +@parameterized_class([asdict(cfg) for cfg in PERFCOUNTERMANAGER_TEST_CONFIGS], class_name_func=get_test_name) +class TestPerformanceCounterManager(unittest.TestCase): + @classmethod + def setUpClass(cls): + shape, timer_cls, dtype = cls.shape, cls.timer_cls, cls.dtype + batch_size, query_len, in_features, out_features = shape + num_tokens = batch_size * query_len + element_size = dtype.itemsize + a = torch.randn(num_tokens, in_features, dtype=dtype, device="cuda") + b = torch.randn(in_features, out_features, dtype=dtype, device="cuda") + # Set up device spec + device_name, bandwidth = cls.device_spec + if device_name is not None: + with patch_device(device_name): + device_spec = CUDADeviceSpec(dtype=torch.bfloat16, bandwidth=bandwidth) + + else: + device_spec = None + + # Stateful class level objects, which will be used in individual tests + cls.cm = cm =PerformanceCounterManager(timer_cls=timer_cls, device_spec=device_spec) + cls.FLOAT_TOL = 1e-5 + cls.expected = expected = {} + + # Start count for a + start = time.perf_counter() + with cm.count("a", num_tokens=num_tokens): + _ = torch.matmul(a, b) + end = time.perf_counter() - assert perf_dict['b']['num_tokens'] == psb.num_tokens - assert perf_dict['b']['total_io'] == psb.total_io - assert perf_dict['b']['total_flops'] == psb.total_flops - assert perf_dict['b']['duration'] == psb.duration - - # Test derived properties are present - perf_dict['a']['achieved_flops_per_s'] == psa.achieved_flops_per_s - perf_dict['a']['achieved_bandwidth'] == psa.achieved_bandwidth - perf_dict['b']['achieved_flops_per_s'] == psb.achieved_flops_per_s - perf_dict['b']['achieved_bandwidth'] == psb.achieved_bandwidth + duration = (end - start) + expected_flops = 2 * num_tokens * in_features * out_features + expected_io = (num_tokens * in_features + in_features * out_features 
+ num_tokens * out_features) * element_size + + expected['a'] = PerfCounterResult(name="a", + duration=duration, + flops=expected_flops, + io=expected_io, + total_flops=expected_flops, + total_io=expected_io) + + # Start count for b + start = time.perf_counter() + with cm.count("b", num_tokens=num_tokens): + _ = torch.matmul(a, b) + end = time.perf_counter() + duration = end - start + + expected['b'] = PerfCounterResult(name="b", + duration=duration, + flops=expected_flops, + io=expected_io, + total_flops=cm.total_flops, + total_io=cm.total_io) + + def test_perf_stats_a(self): + cm: PerformanceCounterManager = self.cm + expected = self.expected['a'] + + counts = cm.get_counts() + assert "a" in counts + + # Check captured performance stats + psa: PerformanceStats = counts["a"] + # Raw metrics + # Duration won't be exact since timing external to the profiler + assert abs(psa.duration - expected.duration) < 1e-1 # +/- 100ms + assert psa.total_flops == expected.flops + assert psa.total_io == expected.io + + # Derived metrics + assert psa.token_throughput == psa.num_tokens / psa.duration + assert psa.achieved_flops_per_s == psa.total_flops / psa.duration + assert psa.achieved_bandwidth == psa.total_io / psa.duration - if device_spec is not None: - assert perf_dict['a']['device_flops_per_s'] == device_spec.flops_per_s - assert perf_dict['a']['device_bandwidth'] == device_spec.bandwidth - assert perf_dict['a']['bandwidth_utilization'] == psa.bandwidth_utilization - assert perf_dict['a']['flops_utilization'] == psa.flops_utilization - assert perf_dict['b']['device_flops_per_s'] == device_spec.flops_per_s - assert perf_dict['b']['device_bandwidth'] == device_spec.bandwidth - assert perf_dict['b']['bandwidth_utilization'] == psb.bandwidth_utilization - assert perf_dict['b']['flops_utilization'] == psb.flops_utilization - else: - assert perf_dict['a']['device_flops_per_s'] is None - assert perf_dict['a']['device_bandwidth'] is None - assert perf_dict['a']['bandwidth_utilization'] is None - assert perf_dict['a']['flops_utilization'] is None - assert perf_dict['b']['device_flops_per_s'] is None - assert perf_dict['b']['device_bandwidth'] is None - assert perf_dict['b']['bandwidth_utilization'] is None - assert perf_dict['b']['flops_utilization'] is None \ No newline at end of file + def test_perf_stats_b(self): + cm: PerformanceCounterManager = self.cm + assert "a" in cm.counts + assert "b" in cm.counts + psa = cm.counts["a"] + psb = cm.counts["b"] + expected = self.expected['b'] + assert abs(psb.duration - expected.duration) < 1e-1 # +/- 100ms + assert psb.total_flops == expected.flops + assert psb.total_io == expected.io + + # check that **total** flops and io after matmul `b` has run accounts for both matmuls + # also check that these global properties are updated correctly in the manager object + assert expected.total_flops == psa.total_flops + psb.total_flops == cm.total_flops + assert expected.total_io == psa.total_io + psb.total_io == cm.total_io + assert cm.total_time == psa.duration + psb.duration + + def test_stats_summary(self): + cm: PerformanceCounterManager = self.cm + FLOAT_TOL = self.FLOAT_TOL + psa = cm.counts["a"] + psb = cm.counts["b"] + summary: PerformanceStats = cm.stats_summary + + # Raw stats + assert summary.num_tokens == psa.num_tokens + psb.num_tokens + assert summary.total_io == psa.total_io + psb.total_io + assert summary.total_flops == psa.total_flops + psb.total_flops + assert summary.duration == psa.duration + psb.duration + + # Derived stats + expected_token_throughput 
= (psa.num_tokens + psb.num_tokens) / (psa.duration + psb.duration) + expected_io_throughput = (psa.total_io + psb.total_io) / (psa.duration + psb.duration) + expected_flops_throughput = (psa.total_flops + psb.total_flops) / (psa.duration + psb.duration) + assert abs(summary.token_throughput - expected_token_throughput) < FLOAT_TOL + assert abs(summary.achieved_bandwidth - expected_io_throughput) < FLOAT_TOL + assert abs(summary.achieved_flops_per_s - expected_flops_throughput) < FLOAT_TOL + + device_spec = cm.device_spec + if device_spec is not None: + expected_bandwidth_utilization = expected_io_throughput / device_spec.bandwidth + expected_flops_utilization = expected_flops_throughput / device_spec.flops_per_s + assert abs(summary.bandwidth_utilization - expected_bandwidth_utilization) < FLOAT_TOL + assert abs(summary.flops_utilization - expected_flops_utilization) < FLOAT_TOL + else: + assert summary.bandwidth_utilization is None + assert summary.flops_utilization is None + + def test_json(self): + cm: PerformanceCounterManager = self.cm + psa: PerformanceStats = cm.counts["a"] + psb: PerformanceStats = cm.counts["b"] + device_spec: Union[DeviceSpec, None] = cm.device_spec + + with tempfile.TemporaryDirectory() as tmp_dir: + json_path = Path(tmp_dir) / "test.json" + cm.to_json(json_path) + + with open(json_path, 'r') as f: + perf_dict = json.load(f) + + assert 'a' in perf_dict + assert 'b' in perf_dict + + #Test basic stats are recorded properly + assert perf_dict['a']['num_tokens'] == psa.num_tokens + assert perf_dict['a']['total_io'] == psa.total_io + assert perf_dict['a']['total_flops'] == psa.total_flops + assert perf_dict['a']['duration'] == psa.duration + + assert perf_dict['b']['num_tokens'] == psb.num_tokens + assert perf_dict['b']['total_io'] == psb.total_io + assert perf_dict['b']['total_flops'] == psb.total_flops + assert perf_dict['b']['duration'] == psb.duration + + # Test derived properties are present + perf_dict['a']['achieved_flops_per_s'] == psa.achieved_flops_per_s + perf_dict['a']['achieved_bandwidth'] == psa.achieved_bandwidth + perf_dict['b']['achieved_flops_per_s'] == psb.achieved_flops_per_s + perf_dict['b']['achieved_bandwidth'] == psb.achieved_bandwidth + + if device_spec is not None: + assert perf_dict['a']['device_flops_per_s'] == device_spec.flops_per_s + assert perf_dict['a']['device_bandwidth'] == device_spec.bandwidth + assert perf_dict['a']['bandwidth_utilization'] == psa.bandwidth_utilization + assert perf_dict['a']['flops_utilization'] == psa.flops_utilization + assert perf_dict['b']['device_flops_per_s'] == device_spec.flops_per_s + assert perf_dict['b']['device_bandwidth'] == device_spec.bandwidth + assert perf_dict['b']['bandwidth_utilization'] == psb.bandwidth_utilization + assert perf_dict['b']['flops_utilization'] == psb.flops_utilization + else: + assert perf_dict['a']['device_flops_per_s'] is None + assert perf_dict['a']['device_bandwidth'] is None + assert perf_dict['a']['bandwidth_utilization'] is None + assert perf_dict['a']['flops_utilization'] is None + assert perf_dict['b']['device_flops_per_s'] is None + assert perf_dict['b']['device_bandwidth'] is None + assert perf_dict['b']['bandwidth_utilization'] is None + assert perf_dict['b']['flops_utilization'] is None \ No newline at end of file diff --git a/test/profiler/utils.py b/test/profiler/utils.py new file mode 100644 index 000000000..f7b8269ba --- /dev/null +++ b/test/profiler/utils.py @@ -0,0 +1,49 @@ +from contextlib import contextmanager +from dataclasses import dataclass +from 
typing import Optional +from unittest.mock import patch + +import torch + +from torchao.profiler import PerformanceTimer + + +@contextmanager +def patch_device(device_name): + with patch("torch.cuda.get_device_name", return_value=device_name): + yield + +@dataclass(frozen=True) + +class PerfStatsTestConfig: + label: str + num_tokens: int + duration: float + total_flops: float + total_io: float + flops_summary: dict + io_summary: dict + flop_counts: dict + io_counts: dict + device_bandwidth: Optional[float] = None + device_flops_per_s: Optional[float] = None + +def get_test_name(cls, num, params_dict): + return f"{cls.__name__}_{num}_{params_dict['name']}" + +@dataclass(frozen=True) +class PerfCounterResult: + name: str + duration: float + flops: float + io: float + total_flops: float + total_io: float + +@dataclass +class PerfCounterTestConfig: + name: str + shape: tuple[int] + timer_cls: PerformanceTimer + dtype: torch.dtype + device_spec: tuple[Optional[str], int] diff --git a/torchao/profiler/performance_counter.py b/torchao/profiler/performance_counter.py index b9d0d36e0..5069b45d5 100644 --- a/torchao/profiler/performance_counter.py +++ b/torchao/profiler/performance_counter.py @@ -8,7 +8,8 @@ from copy import deepcopy from dataclasses import asdict, dataclass from functools import partial -from typing import Any, Dict, Optional +from pathlib import Path +from typing import Any, Dict, Optional, Union import torch from torch.utils._pytree import tree_map @@ -273,10 +274,10 @@ def __str__(self): Throughput: {self.token_throughput:,.0f} tokens/s IO Total: {self._format(self.total_io, "B")} - Throughput: {self._format(self.io_throughput, "B/s")} + Throughput: {self._format(self.achieved_bandwidth, "B/s")} FLOPs Total: {self._format(self.total_flops, "FLOPs")} - Throughput: {self._format(self.flops_throughput, "FLOPs/s")}""") + Throughput: {self._format(self.achieved_flops_per_s, "FLOPs/s")}""") indent_2 = " " * 2 indent_4 = " " * 4 @@ -424,5 +425,9 @@ def to_dict(self): return counts - def to_json(self): - return json.dumps(self.to_dict(), indent=2) \ No newline at end of file + def to_json(self, path: Union[str, Path] = None): + d = self.to_dict() + if path: + with open(path, 'w') as f: + f.write(json.dumps(d, indent=2)) + return d \ No newline at end of file From 0d2885ae06efde4d05052864596b20d0b005d4fa Mon Sep 17 00:00:00 2001 From: jeromeku Date: Tue, 9 Jul 2024 18:41:31 +0000 Subject: [PATCH 17/33] refactor remaining tests --- test/profiler/test_performance_counter.py | 225 +++++++++++----------- test/profiler/utils.py | 44 ++++- 2 files changed, 153 insertions(+), 116 deletions(-) diff --git a/test/profiler/test_performance_counter.py b/test/profiler/test_performance_counter.py index 3bbaa2a42..f5d488095 100644 --- a/test/profiler/test_performance_counter.py +++ b/test/profiler/test_performance_counter.py @@ -18,11 +18,16 @@ import torch from parameterized import parameterized_class from utils import ( + PerfCounterManagerTestConfig, PerfCounterResult, PerfCounterTestConfig, PerfStatsTestConfig, + attn_io_check, + ffn_io_check, + get_leaf_nodes, get_test_name, patch_device, + qkv_proj_io_check, ) from torchao.profiler.device_spec import CUDADeviceSpec, DeviceSpec @@ -35,115 +40,116 @@ ) from torchao.utils import TORCH_VERSION_AFTER_2_5 +PERFCOUNTER_TEST_CONFIGS = [PerfCounterTestConfig(name="3.5B", + batch_size=1, + seqlen=128, + dtype=torch.float16, + num_hidden_layers=32 // 2, + hidden_size=4096 // 2, + intermediate_size=11008 // 2, + num_attention_heads=32 // 2, + 
vocab_size=32000 // 2), + PerfCounterTestConfig(name="1.25B", + batch_size=1, + seqlen=128, + dtype=torch.float16, + num_hidden_layers=32 // 4, + hidden_size=4096 // 4, + intermediate_size=11008 // 4, + num_attention_heads=32 // 4, + vocab_size=32000 // 4), + PerfCounterTestConfig(name="tiny", + batch_size=1, + seqlen=128, + dtype=torch.float16, + num_hidden_layers=1, + hidden_size=4096 // 4, + intermediate_size=11008 // 4, + num_attention_heads=32 // 4, + vocab_size=32000 // 4)] -def get_leaf_nodes(count_keys, module_name): - return [k for k in count_keys if k.endswith(module_name)] - -def attn_proj_io_check(model_config, batch_size, seqlen, element_size): - input_size = batch_size * seqlen * model_config.hidden_size * element_size - weight_size = model_config.hidden_size * model_config.hidden_size * element_size - output_size = batch_size * seqlen * model_config.hidden_size * element_size - return input_size + weight_size + output_size -def attn_io_check(model_config, batch_size, seqlen, element_size): - # queries, keys, values -> factor of 3 - input_size = (batch_size * seqlen * model_config.hidden_size * 3) * element_size - output_size = (batch_size * seqlen * model_config.hidden_size) * element_size - return input_size + output_size - -def ffn_io_check(model_config, batch_size, seqlen, element_size, module_name): - assert module_name in ["up_proj", "gate_proj", "down_proj"] - - if module_name == "down_proj": - input_size = batch_size * seqlen * model_config.intermediate_size * element_size - else: - input_size = batch_size * seqlen * model_config.hidden_size * element_size - weight_size = model_config.hidden_size * model_config.intermediate_size * element_size - if module_name == "down_proj": - output_size = batch_size * seqlen * model_config.hidden_size * element_size - else: - output_size = batch_size * seqlen * model_config.intermediate_size * element_size - - return input_size + weight_size + output_size - - -CONFIG_7B = (32, 4096, 11008, 32, 32000) -MEDIUM_CONFIG = [p // 2 for p in CONFIG_7B] -SMALL_CONFIG = [p // 4 for p in CONFIG_7B] +@parameterized_class([asdict(cfg) for cfg in PERFCOUNTER_TEST_CONFIGS], class_name_func=get_test_name) +class PerformanceCounterTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + model_cfg = LlamaConfig(num_hidden_layers=cls.num_hidden_layers, + hidden_size=cls.hidden_size, + intermediate_size=cls.intermediate_size, + num_attention_heads=cls.num_attention_heads, + vocab_size=cls.vocab_size) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif(not TORCH_VERSION_AFTER_2_5, reason="requires torch >= 2.5") -@pytest.mark.parametrize("num_hidden_layers, hidden_size, intermediate_size, num_attention_heads, vocab_size", [MEDIUM_CONFIG, SMALL_CONFIG]) -@pytest.mark.parametrize("batch_size, seqlen", [(1, 128),]) -@pytest.mark.parametrize("dtype", [torch.float16], ids=lambda p: str(p)) -def test_performance_counter(num_hidden_layers, hidden_size, intermediate_size, num_attention_heads, vocab_size, batch_size, seqlen, dtype): + # Note we set some options manually since the model doesn't seem to be initialized correctly + # when these options are set in LlamaConfig + model_cfg._attn_implementation = "sdpa" + cls.model = model = LlamaForCausalLM(model_cfg).to(cls.dtype).to("cuda") + cls.model_config = model.config + cls.element_size = cls.dtype.itemsize + + input_ids = torch.randint(0, model.config.vocab_size, (cls.batch_size, cls.seqlen), device="cuda") + with torch.no_grad(): + with 
torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION): + with PerformanceCounterMode() as perf_counter: + _ = model(input_ids) + cls.perf_counter = perf_counter + cls.summary_flops = perf_counter.get_summary_flop_counts() + cls.summary_io = perf_counter.get_summary_io_counts() + cls.flops_by_op = perf_counter.get_flop_counts() + cls.io_by_op = perf_counter.get_io_counts() - cfg = LlamaConfig(num_hidden_layers=num_hidden_layers, - hidden_size=hidden_size, - intermediate_size=intermediate_size, - num_attention_heads=num_attention_heads, - vocab_size=vocab_size) + def test_qkv_proj(self): + batch_size, seqlen = self.batch_size, self.seqlen + element_size = self.element_size + + assert len(self.summary_flops) == len(self.summary_io) + assert self.summary_flops.keys() == self.summary_io.keys() - # Note we set some options manually since the model doesn't seem to be initialized correctly - # when these options are set in LlamaConfig - cfg._attn_implementation = "sdpa" - model = LlamaForCausalLM(cfg).to(dtype).to("cuda") - model_config = model.config - element_size = dtype.itemsize - - input_ids = torch.randint(0, model_config.vocab_size, (batch_size, seqlen), device="cuda") - with torch.no_grad(): - with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION): - with PerformanceCounterMode() as perf_counter: - _ = model(input_ids) + # Attn Projections + for k in ["q_proj", "k_proj", "v_proj"]: + # Flops check + proj_keys = get_leaf_nodes(self.summary_flops.keys(), k) + assert len(proj_keys) == self.model.config.num_hidden_layers + expected_flops = 2 * batch_size * seqlen * self.model_config.hidden_size * self.model_config.hidden_size + assert expected_flops == self.summary_flops[proj_keys[0]] + + # io check + expected_size = qkv_proj_io_check(self.model_config, batch_size, seqlen, element_size) + assert expected_size == self.summary_io[proj_keys[0]] - summary_flops = perf_counter.get_summary_flop_counts() - summary_io = perf_counter.get_summary_io_counts() - flops_by_op = perf_counter.get_flop_counts() - io_by_op = perf_counter.get_io_counts() - assert len(summary_flops) == len(summary_io) - assert summary_flops.keys() == summary_io.keys() - - # Attn Projections - for k in ["q_proj", "k_proj", "v_proj"]: - # Flops check - proj_keys = get_leaf_nodes(summary_flops.keys(), k) - assert len(proj_keys) == model.config.num_hidden_layers - expected_flops = 2 * batch_size * seqlen * model_config.hidden_size * model_config.hidden_size - assert expected_flops == summary_flops[proj_keys[0]] + def test_attn(self): + batch_size, seqlen = self.batch_size, self.seqlen + element_size = self.element_size + model_config = self.model.config - # io check - expected_size = attn_proj_io_check(model_config, batch_size, seqlen, element_size) - assert expected_size == summary_io[proj_keys[0]] - - # Attention - attention_keys = get_leaf_nodes(summary_flops.keys(), "self_attn") - for k in attention_keys: - flops = flops_by_op[k] - io_movement = io_by_op[k] - for op, count in flops.items(): - if "attention" in op.__name__: - expected_flops = 2 * 2 * batch_size * seqlen * seqlen * model_config.hidden_size - assert expected_flops == count - for op, count in io_movement.items(): - if "attention" in op.__name__: - # Check approx equal due to other small artifacts returned by sdpa.mem_efficient_attention - # See #https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/cuda/attention.cu#L867 - # Check within 100 bytes - expected_size = 
attn_io_check(model_config, batch_size, seqlen, element_size) - assert abs(expected_size - count) < 100 - # FFN - for k in ["up_proj", "gate_proj", "down_proj"]: - proj_keys = get_leaf_nodes(summary_flops.keys(), k) - assert len(proj_keys) == model.config.num_hidden_layers - expected_flops = 2 * batch_size * seqlen * model_config.hidden_size * model_config.intermediate_size - assert expected_flops == summary_flops[proj_keys[0]] + attention_keys = get_leaf_nodes(self.summary_flops.keys(), "self_attn") + for k in attention_keys: + flops = self.flops_by_op[k] + io_movement = self.io_by_op[k] + for op, count in flops.items(): + if "attention" in op.__name__: + expected_flops = 2 * 2 * batch_size * seqlen * seqlen * model_config.hidden_size + assert expected_flops == count + for op, count in io_movement.items(): + if "attention" in op.__name__: + # Check approx equal due to other small artifacts returned by sdpa.mem_efficient_attention + # See #https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/cuda/attention.cu#L867 + # Check within 100 bytes + expected_size = attn_io_check(model_config, batch_size, seqlen, element_size) + assert abs(expected_size - count) < 100 + + def test_ffn(self): + batch_size, seqlen = self.batch_size, self.seqlen + element_size = self.element_size - # io check - expected_size = ffn_io_check(model_config, batch_size, seqlen, element_size, k) - assert expected_size == summary_io[proj_keys[0]] - - - + for k in ["up_proj", "gate_proj", "down_proj"]: + proj_keys = get_leaf_nodes(self.summary_flops.keys(), k) + assert len(proj_keys) == self.model.config.num_hidden_layers + expected_flops = 2 * batch_size * seqlen * self.model_config.hidden_size * self.model_config.intermediate_size + assert expected_flops == self.summary_flops[proj_keys[0]] + + # io check + expected_size = ffn_io_check(self.model_config, batch_size, seqlen, element_size, k) + assert expected_size == self.summary_io[proj_keys[0]] + PERFSTATS_TEST_CONFIGS = [PerfStatsTestConfig(label="with_device", num_tokens=128, duration=0.1, @@ -213,19 +219,8 @@ def test_performance_stats(cfg: PerfStatsTestConfig): assert expected_flops_utilization_str in stats_str -# @pytest.fixture -# def device_spec(request): -# device_name, bandwidth = request.param -# if device_name is not None: -# with patch_device(device_name): -# return CUDADeviceSpec(dtype=torch.bfloat16, bandwidth=bandwidth) -# else: -# return None - - -PERFCOUNTERMANAGER_TEST_CONFIGS = [PerfCounterTestConfig("no_device", (1, 1024, 4096, 4096), PerformanceTimer, torch.bfloat16, (None, 0)), - PerfCounterTestConfig("a100", (1, 1024, 4096, 4096), CUDAPerformanceTimer, torch.bfloat16, ("A100", 2e12))] - +PERFCOUNTERMANAGER_TEST_CONFIGS = [PerfCounterManagerTestConfig("no_device", (1, 1024, 4096, 4096), PerformanceTimer, torch.bfloat16, (None, 0)), + PerfCounterManagerTestConfig("a100", (1, 1024, 4096, 4096), CUDAPerformanceTimer, torch.bfloat16, ("A100", 2e12))] @parameterized_class([asdict(cfg) for cfg in PERFCOUNTERMANAGER_TEST_CONFIGS], class_name_func=get_test_name) class TestPerformanceCounterManager(unittest.TestCase): diff --git a/test/profiler/utils.py b/test/profiler/utils.py index f7b8269ba..bce565141 100644 --- a/test/profiler/utils.py +++ b/test/profiler/utils.py @@ -14,7 +14,49 @@ def patch_device(device_name): yield @dataclass(frozen=True) +class PerfCounterTestConfig: + name: str + batch_size: int + seqlen: int + dtype: torch.dtype + num_hidden_layers: int + hidden_size: int + intermediate_size: int + num_attention_heads: int 
+ vocab_size: int + + +def get_leaf_nodes(count_keys, module_name): + return [k for k in count_keys if k.endswith(module_name)] + +def qkv_proj_io_check(model_config, batch_size, seqlen, element_size): + input_size = batch_size * seqlen * model_config.hidden_size * element_size + weight_size = model_config.hidden_size * model_config.hidden_size * element_size + output_size = batch_size * seqlen * model_config.hidden_size * element_size + return input_size + weight_size + output_size +def attn_io_check(model_config, batch_size, seqlen, element_size): + # queries, keys, values -> factor of 3 + input_size = (batch_size * seqlen * model_config.hidden_size * 3) * element_size + output_size = (batch_size * seqlen * model_config.hidden_size) * element_size + return input_size + output_size + +def ffn_io_check(model_config, batch_size, seqlen, element_size, module_name): + assert module_name in ["up_proj", "gate_proj", "down_proj"] + + if module_name == "down_proj": + input_size = batch_size * seqlen * model_config.intermediate_size * element_size + else: + input_size = batch_size * seqlen * model_config.hidden_size * element_size + weight_size = model_config.hidden_size * model_config.intermediate_size * element_size + if module_name == "down_proj": + output_size = batch_size * seqlen * model_config.hidden_size * element_size + else: + output_size = batch_size * seqlen * model_config.intermediate_size * element_size + + return input_size + weight_size + output_size + +@dataclass(frozen=True) class PerfStatsTestConfig: label: str num_tokens: int @@ -41,7 +83,7 @@ class PerfCounterResult: total_io: float @dataclass -class PerfCounterTestConfig: +class PerfCounterManagerTestConfig: name: str shape: tuple[int] timer_cls: PerformanceTimer From 7fc4c1e14fed7f9253ee090935747527ebada25b Mon Sep 17 00:00:00 2001 From: jeromeku Date: Tue, 9 Jul 2024 18:48:32 +0000 Subject: [PATCH 18/33] clean up tests --- test/profiler/test_performance_counter.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/test/profiler/test_performance_counter.py b/test/profiler/test_performance_counter.py index f5d488095..9cb919dbe 100644 --- a/test/profiler/test_performance_counter.py +++ b/test/profiler/test_performance_counter.py @@ -40,6 +40,8 @@ ) from torchao.utils import TORCH_VERSION_AFTER_2_5 +# ------------------- PerformanceCounter Tests ------------------- # + PERFCOUNTER_TEST_CONFIGS = [PerfCounterTestConfig(name="3.5B", batch_size=1, seqlen=128, @@ -68,6 +70,8 @@ num_attention_heads=32 // 4, vocab_size=32000 // 4)] +@unittest.skipIf(not TORCH_VERSION_AFTER_2_5, "PerformanceCounter requires torch >= 2.5+.") +@unittest.skipIf(not torch.cuda.is_available(), "PerformanceCounter requires CUDA") @parameterized_class([asdict(cfg) for cfg in PERFCOUNTER_TEST_CONFIGS], class_name_func=get_test_name) class PerformanceCounterTest(unittest.TestCase): @classmethod @@ -150,6 +154,8 @@ def test_ffn(self): expected_size = ffn_io_check(self.model_config, batch_size, seqlen, element_size, k) assert expected_size == self.summary_io[proj_keys[0]] +# ------------------- PerformanceStats Tests ------------------- # + PERFSTATS_TEST_CONFIGS = [PerfStatsTestConfig(label="with_device", num_tokens=128, duration=0.1, @@ -197,7 +203,7 @@ def test_performance_stats(cfg: PerfStatsTestConfig): # Test str - stats should be formatted to closest power of 10 ** 3 with 2 decimal places of precision stats_str = str(stats) - print(stats_str) + # Base Stats expected_io_str = ".12 GB" expected_flops_str = ".12 TFLOPs" @@ -218,10 
+224,13 @@ def test_performance_stats(cfg: PerfStatsTestConfig): expected_flops_utilization_str = f"{stats.achieved_flops_per_s / device_flops_per_s:.2f}%" assert expected_flops_utilization_str in stats_str +# ------------------- PerformanceCounterManager Tests ------------------- # PERFCOUNTERMANAGER_TEST_CONFIGS = [PerfCounterManagerTestConfig("no_device", (1, 1024, 4096, 4096), PerformanceTimer, torch.bfloat16, (None, 0)), PerfCounterManagerTestConfig("a100", (1, 1024, 4096, 4096), CUDAPerformanceTimer, torch.bfloat16, ("A100", 2e12))] +@unittest.skipIf(not TORCH_VERSION_AFTER_2_5, "PerformanceCounterManager requires torch >= 2.5+.") +@unittest.skipIf(not torch.cuda.is_available(), "PerformanceCounterManager requires CUDA") @parameterized_class([asdict(cfg) for cfg in PERFCOUNTERMANAGER_TEST_CONFIGS], class_name_func=get_test_name) class TestPerformanceCounterManager(unittest.TestCase): @classmethod From 0363e1bc5901276fd2db1c0ff6efb0f6dc8cbbcf Mon Sep 17 00:00:00 2001 From: jeromeku Date: Tue, 9 Jul 2024 18:54:30 +0000 Subject: [PATCH 19/33] clean up device_spec tests --- test/profiler/test_device_spec.py | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/test/profiler/test_device_spec.py b/test/profiler/test_device_spec.py index d76fdd3a2..fb45dca86 100644 --- a/test/profiler/test_device_spec.py +++ b/test/profiler/test_device_spec.py @@ -7,6 +7,7 @@ from unittest.mock import patch import torch +from utils import patch_device from torchao.profiler.device_spec import ( _AVAILABLE_GPU_SPECS, @@ -20,11 +21,6 @@ USE_TENSORCORES = [True, False] DEVICE_CONFIGS = itertools.product(DEVICE_NAMES, DTYPES, USE_TENSORCORES) - -@contextmanager -def patch_device(device_name): - with patch("torch.cuda.get_device_name", return_value=device_name): - yield @pytest.mark.parametrize("device_name, dtype, use_tensorcores", DEVICE_CONFIGS, ids=lambda x: str(x)) def test_device_spec(device_name, dtype, use_tensorcores): with patch_device(device_name): @@ -33,12 +29,12 @@ def test_device_spec(device_name, dtype, use_tensorcores): dtype = "tfloat32" chip_name = get_chip_name(device_name) expected_flops = _AVAILABLE_GPU_SPECS[chip_name][dtype] - assert device_spec.flop_per_s == expected_flops + assert device_spec.flops_per_s == expected_flops assert device_spec.flops_by_dtype[dtype] == expected_flops assert device_spec.roofline_balancepoint == expected_flops / device_spec.bandwidth with pytest.raises(AssertionError): - device_spec.flop_per_s = None + device_spec.flops_per_s = None print(device_spec.roofline_balancepoint) # Prevent setting attributes not in named fields to guard against user error with pytest.raises(AttributeError): @@ -48,21 +44,21 @@ def test_empty_device_spec(): device_name = "fake device" with patch_device(device_name): with pytest.raises(AssertionError): - device_spec = CUDADeviceSpec() + _ = CUDADeviceSpec() # Ok to instantiate as long as fields are filled - device_spec = CUDADeviceSpec(name=device_name, - flop_per_s=1.0, - bandwidth=1.0, - dtype=torch.float32, - use_tensorcores=True) + _ = CUDADeviceSpec(name=device_name, + flops_per_s=1.0, + bandwidth=1.0, + dtype=torch.float32, + use_tensorcores=True) device_name = DEVICE_NAMES[0] with patch_device(device_name): # All critical fields will be auto-filled except for dtype (and vram, but vram is not used for downstream calcs atm) - device_spec = CUDADeviceSpec(dtype=torch.float32) + _ = CUDADeviceSpec(dtype=torch.float32) # No dtype specified with pytest.raises(AssertionError): - device_spec = 
CUDADeviceSpec() + _ = CUDADeviceSpec() From 06f0b082d2e8cc10e521be28d8d6af38419dafa7 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Tue, 9 Jul 2024 23:37:33 +0000 Subject: [PATCH 20/33] add latency --- test/profiler/test_performance_counter.py | 16 ++++-- torchao/profiler/performance_counter.py | 65 +++++++++++++++++------ 2 files changed, 62 insertions(+), 19 deletions(-) diff --git a/test/profiler/test_performance_counter.py b/test/profiler/test_performance_counter.py index 9cb919dbe..6cb6efa4f 100644 --- a/test/profiler/test_performance_counter.py +++ b/test/profiler/test_performance_counter.py @@ -194,13 +194,17 @@ def test_performance_stats(cfg: PerfStatsTestConfig): assert stats.achieved_flops_per_s == total_flops / duration if device_bandwidth is not None: assert stats.bandwidth_utilization == stats.achieved_bandwidth / device_bandwidth + assert stats.theoretical_io_latency == total_io / device_bandwidth else: assert stats.bandwidth_utilization is None + assert stats.theoretical_io_latency is None if device_flops_per_s is not None: assert stats.flops_utilization == stats.achieved_flops_per_s / device_flops_per_s + assert stats.theoretical_compute_latency == total_flops / device_flops_per_s else: assert stats.flops_utilization is None - + assert stats.theoretical_compute_latency is None + # Test str - stats should be formatted to closest power of 10 ** 3 with 2 decimal places of precision stats_str = str(stats) @@ -218,12 +222,16 @@ def test_performance_stats(cfg: PerfStatsTestConfig): # Utilization Stats if device_bandwidth is not None: - expected_bandwidth_utilization_str = f"{stats.achieved_bandwidth / device_bandwidth:.2f}%" + expected_bandwidth_utilization_str = f"{stats.achieved_bandwidth / device_bandwidth:.2f} %" + expected_io_latency_str = f"{stats.theoretical_io_latency:.2f} s" assert expected_bandwidth_utilization_str in stats_str + assert expected_io_latency_str in stats_str + if device_flops_per_s is not None: - expected_flops_utilization_str = f"{stats.achieved_flops_per_s / device_flops_per_s:.2f}%" + expected_flops_utilization_str = f"{stats.achieved_flops_per_s / device_flops_per_s:.2f} %" + expected_compute_latency_str = f"{stats.theoretical_compute_latency:.2f} s" assert expected_flops_utilization_str in stats_str - + assert expected_compute_latency_str in stats_str # ------------------- PerformanceCounterManager Tests ------------------- # PERFCOUNTERMANAGER_TEST_CONFIGS = [PerfCounterManagerTestConfig("no_device", (1, 1024, 4096, 4096), PerformanceTimer, torch.bfloat16, (None, 0)), diff --git a/torchao/profiler/performance_counter.py b/torchao/profiler/performance_counter.py index 5069b45d5..5b14ccd96 100644 --- a/torchao/profiler/performance_counter.py +++ b/torchao/profiler/performance_counter.py @@ -3,6 +3,7 @@ import math import textwrap import time +import warnings from collections import defaultdict from contextlib import contextmanager from copy import deepcopy @@ -17,8 +18,20 @@ from .device_spec import DeviceSpec -aten = torch.ops.aten +# Set to keep track of issued warnings +_issued_warnings = set() + +def warn_once(message): + global _issued_warnings + if message not in _issued_warnings: + warnings.warn(message, CustomWarning) + _issued_warnings.add(message) + +# Define a custom warning category +class CustomWarning(UserWarning): + pass +aten = torch.ops.aten class PerformanceCounterMode(FlopCounterMode): def __init__(self, display=False, depth=10, debug=False): self.debug = debug @@ -249,22 +262,39 @@ def achieved_flops_per_s(self): def 
achieved_bandwidth(self): return self.total_io / self.duration + @property + def theoretical_io_latency(self): + if self.device_bandwidth is not None: + return self.total_io / self.device_bandwidth + else: + warn_once("Device bandwidth is not specified. Please specify the device bandwidth to enable io latency calculation") + return None + + @property + def theoretical_compute_latency(self): + if self.device_flops_per_s is not None: + return self.total_flops / self.device_flops_per_s + else: + warn_once("Device flops_per_s is not specified. Please specify the device throughput to enable compute latency calculation") + return None + @property def bandwidth_utilization(self): if self.device_bandwidth is not None: return self.achieved_bandwidth / self.device_bandwidth else: - print("Device bandwidth is not specified. Please specify the device bandwidth to enable bandwidth utilization calculation") + warn_once("Device bandwidth is not specified. Please specify the device bandwidth to enable bandwidth utilization calculation") return None @property def flops_utilization(self): if self.device_flops_per_s is not None: return self.achieved_flops_per_s / self.device_flops_per_s else: - print("Device flops_per_s is not specified. Please specify the device throughput to enable flops utilization calculation") + warn_once("Device flops_per_s is not specified. Please specify the device throughput to enable flops utilization calculation") return None def _format(self, value, suffix): return to_nearest_power_of_10(value) + suffix + def __str__(self): txt = textwrap.dedent(f"""\ {self.label}: @@ -275,18 +305,23 @@ def __str__(self): IO Total: {self._format(self.total_io, "B")} Throughput: {self._format(self.achieved_bandwidth, "B/s")} + Theoretical Latency: {self._format(self.theoretical_io_latency, "s") if self.theoretical_io_latency is not None else "N/A"} FLOPs Total: {self._format(self.total_flops, "FLOPs")} - Throughput: {self._format(self.achieved_flops_per_s, "FLOPs/s")}""") - - indent_2 = " " * 2 - indent_4 = " " * 4 - if self.bandwidth_utilization is not None: - txt += "\n" + textwrap.indent("""Utilization:\n""", indent_2) - txt += textwrap.indent(f"""Bandwidth: {self.bandwidth_utilization:.2f}%""", indent_4) - - if self.flops_utilization is not None: - txt += "\n" + textwrap.indent(f"""FLOPs: {self.flops_utilization:.2f}%""", indent_4) + Throughput: {self._format(self.achieved_flops_per_s, "FLOPs/s")} + Theoretical Latency: {self._format(self.theoretical_compute_latency, "s") if self.theoretical_compute_latency is not None else "N/A"} + Utilization + Bandwidth: {self._format(self.bandwidth_utilization, "%") if self.bandwidth_utilization is not None else "N/A"} + FLOPs: {self._format(self.flops_utilization, "%") if self.flops_utilization is not None else "N/A"}""") + + # indent_2 = " " * 2 + # indent_4 = " " * 4 + # if self.bandwidth_utilization is not None: + # txt += "\n" + textwrap.indent("""Utilization:\n""", indent_2) + # txt += textwrap.indent(f"""Bandwidth: {self.bandwidth_utilization:.2f}%""", indent_4) + + # if self.flops_utilization is not None: + # txt += "\n" + textwrap.indent(f"""FLOPs: {self.flops_utilization:.2f}%""", indent_4) return txt @@ -411,8 +446,8 @@ def print_summary(self, labels: list[str] = None): _print(text) else: for label in labels: - text = str(self._count[label]) - _print(self._count[label]) + text = str(self._counts[label]) + _print(self._counts[label]) def to_dict(self): # Convert flop_counts from OpOverloadPackets to str From 
95c1c28a6435d4c24035ab6c42a37b472cc86c5c Mon Sep 17 00:00:00 2001 From: jeromeku Date: Tue, 9 Jul 2024 23:45:56 +0000 Subject: [PATCH 21/33] add latency tests --- test/profiler/test_performance_counter.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/test/profiler/test_performance_counter.py b/test/profiler/test_performance_counter.py index 6cb6efa4f..b36b67299 100644 --- a/test/profiler/test_performance_counter.py +++ b/test/profiler/test_performance_counter.py @@ -239,7 +239,7 @@ def test_performance_stats(cfg: PerfStatsTestConfig): @unittest.skipIf(not TORCH_VERSION_AFTER_2_5, "PerformanceCounterManager requires torch >= 2.5+.") @unittest.skipIf(not torch.cuda.is_available(), "PerformanceCounterManager requires CUDA") -@parameterized_class([asdict(cfg) for cfg in PERFCOUNTERMANAGER_TEST_CONFIGS], class_name_func=get_test_name) +@parameterized_class([asdict(cfg) for cfg in PERFCOUNTERMANAGER_TEST_CONFIGS[-1:]], class_name_func=get_test_name) class TestPerformanceCounterManager(unittest.TestCase): @classmethod def setUpClass(cls): @@ -398,18 +398,14 @@ def test_json(self): if device_spec is not None: assert perf_dict['a']['device_flops_per_s'] == device_spec.flops_per_s assert perf_dict['a']['device_bandwidth'] == device_spec.bandwidth + assert perf_dict['a']['theoretical_io_latency'] == psa.theoretical_io_latency + assert perf_dict['a']['theoretical_compute_latency'] == psa.theoretical_compute_latency assert perf_dict['a']['bandwidth_utilization'] == psa.bandwidth_utilization assert perf_dict['a']['flops_utilization'] == psa.flops_utilization + assert perf_dict['b']['device_flops_per_s'] == device_spec.flops_per_s assert perf_dict['b']['device_bandwidth'] == device_spec.bandwidth + assert perf_dict['b']['theoretical_io_latency'] == psb.theoretical_io_latency + assert perf_dict['b']['theoretical_compute_latency'] == psb.theoretical_compute_latency assert perf_dict['b']['bandwidth_utilization'] == psb.bandwidth_utilization - assert perf_dict['b']['flops_utilization'] == psb.flops_utilization - else: - assert perf_dict['a']['device_flops_per_s'] is None - assert perf_dict['a']['device_bandwidth'] is None - assert perf_dict['a']['bandwidth_utilization'] is None - assert perf_dict['a']['flops_utilization'] is None - assert perf_dict['b']['device_flops_per_s'] is None - assert perf_dict['b']['device_bandwidth'] is None - assert perf_dict['b']['bandwidth_utilization'] is None - assert perf_dict['b']['flops_utilization'] is None \ No newline at end of file + assert perf_dict['b']['flops_utilization'] == psb.flops_utilization \ No newline at end of file From 09208c118fb16b39005e0be31d6b550f260cfc69 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Wed, 10 Jul 2024 00:16:43 +0000 Subject: [PATCH 22/33] fix formatting --- test/profiler/test_performance_counter.py | 6 ++--- torchao/profiler/performance_counter.py | 32 +++++++++-------------- 2 files changed, 16 insertions(+), 22 deletions(-) diff --git a/test/profiler/test_performance_counter.py b/test/profiler/test_performance_counter.py index b36b67299..db99bc247 100644 --- a/test/profiler/test_performance_counter.py +++ b/test/profiler/test_performance_counter.py @@ -222,13 +222,13 @@ def test_performance_stats(cfg: PerfStatsTestConfig): # Utilization Stats if device_bandwidth is not None: - expected_bandwidth_utilization_str = f"{stats.achieved_bandwidth / device_bandwidth:.2f} %" + expected_bandwidth_utilization_str = f"{stats.achieved_bandwidth / device_bandwidth:.4f}" expected_io_latency_str = 
f"{stats.theoretical_io_latency:.2f} s" assert expected_bandwidth_utilization_str in stats_str assert expected_io_latency_str in stats_str if device_flops_per_s is not None: - expected_flops_utilization_str = f"{stats.achieved_flops_per_s / device_flops_per_s:.2f} %" + expected_flops_utilization_str = f"{stats.achieved_flops_per_s / device_flops_per_s:.4f}" expected_compute_latency_str = f"{stats.theoretical_compute_latency:.2f} s" assert expected_flops_utilization_str in stats_str assert expected_compute_latency_str in stats_str @@ -239,7 +239,7 @@ def test_performance_stats(cfg: PerfStatsTestConfig): @unittest.skipIf(not TORCH_VERSION_AFTER_2_5, "PerformanceCounterManager requires torch >= 2.5+.") @unittest.skipIf(not torch.cuda.is_available(), "PerformanceCounterManager requires CUDA") -@parameterized_class([asdict(cfg) for cfg in PERFCOUNTERMANAGER_TEST_CONFIGS[-1:]], class_name_func=get_test_name) +@parameterized_class([asdict(cfg) for cfg in PERFCOUNTERMANAGER_TEST_CONFIGS], class_name_func=get_test_name) class TestPerformanceCounterManager(unittest.TestCase): @classmethod def setUpClass(cls): diff --git a/torchao/profiler/performance_counter.py b/torchao/profiler/performance_counter.py index 5b14ccd96..e4b234694 100644 --- a/torchao/profiler/performance_counter.py +++ b/torchao/profiler/performance_counter.py @@ -21,16 +21,17 @@ # Set to keep track of issued warnings _issued_warnings = set() + +# Define a custom warning category +class DeviceInfoMissing(UserWarning): + pass + def warn_once(message): global _issued_warnings if message not in _issued_warnings: - warnings.warn(message, CustomWarning) + warnings.warn(message, DeviceInfoMissing) _issued_warnings.add(message) -# Define a custom warning category -class CustomWarning(UserWarning): - pass - aten = torch.ops.aten class PerformanceCounterMode(FlopCounterMode): def __init__(self, display=False, depth=10, debug=False): @@ -292,9 +293,11 @@ def flops_utilization(self): else: warn_once("Device flops_per_s is not specified. 
Please specify the device throughput to enable flops utilization calculation") return None - def _format(self, value, suffix): - return to_nearest_power_of_10(value) + suffix - + def _format(self, value, suffix, precision=2, round=True): + if round: + return to_nearest_power_of_10(value, precision=precision) + suffix + return f"{value:.{precision}f} " + suffix + def __str__(self): txt = textwrap.dedent(f"""\ {self.label}: @@ -311,17 +314,8 @@ def __str__(self): Throughput: {self._format(self.achieved_flops_per_s, "FLOPs/s")} Theoretical Latency: {self._format(self.theoretical_compute_latency, "s") if self.theoretical_compute_latency is not None else "N/A"} Utilization - Bandwidth: {self._format(self.bandwidth_utilization, "%") if self.bandwidth_utilization is not None else "N/A"} - FLOPs: {self._format(self.flops_utilization, "%") if self.flops_utilization is not None else "N/A"}""") - - # indent_2 = " " * 2 - # indent_4 = " " * 4 - # if self.bandwidth_utilization is not None: - # txt += "\n" + textwrap.indent("""Utilization:\n""", indent_2) - # txt += textwrap.indent(f"""Bandwidth: {self.bandwidth_utilization:.2f}%""", indent_4) - - # if self.flops_utilization is not None: - # txt += "\n" + textwrap.indent(f"""FLOPs: {self.flops_utilization:.2f}%""", indent_4) + Bandwidth: {self._format(self.bandwidth_utilization, round=False, precision=4, suffix="%") if self.bandwidth_utilization is not None else "N/A"} + FLOPs: {self._format(self.flops_utilization, round=False, precision=4, suffix="%") if self.flops_utilization is not None else "N/A"}""") return txt From 4de130a09711b638ba016cc14c6dbf05ff47d5c2 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Wed, 10 Jul 2024 00:17:55 +0000 Subject: [PATCH 23/33] remove unused methods --- torchao/profiler/performance_counter.py | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/torchao/profiler/performance_counter.py b/torchao/profiler/performance_counter.py index e4b234694..14022c016 100644 --- a/torchao/profiler/performance_counter.py +++ b/torchao/profiler/performance_counter.py @@ -410,27 +410,6 @@ def stats_summary(self): device_flops_per_s=self.device_spec.flops_per_s if self.device_spec is not None else None) return stats - - def _format_totals(self, precision=2): - ms = round(self.total_time * 1e3, precision) - token_throughput = round(self.total_tokens / self.total_time, precision) - gflops = round(self.total_flops / 1e9, precision) - gb = round(self.total_io / 1e9, precision) - flop_throughput = round(gflops / self.total_time, precision) - io_throughput = round(gb / self.total_time, precision) - text = textwrap.dedent(f"""\ - FlopCounter Summary: - Total time = {ms:,} ms - Tokens: - Total {self.total_tokens} - Throughput {token_throughput:,} tokens/s - IO: - Total {gb:,} GB - Throughput {io_throughput:,} GB/s - FLOPs: - Total {gflops:,} GFLOPs - Throughput {flop_throughput:,} GFLOP/s""") - return text def print_summary(self, labels: list[str] = None): _print = partial(print, flush=True, end='\n') From 2afa14fa36be82c9cd24cc8019488aee5760a7b4 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Wed, 10 Jul 2024 12:24:59 +0000 Subject: [PATCH 24/33] add documentation --- torchao/profiler/device_spec.py | 37 +++++++---- torchao/profiler/performance_counter.py | 81 ++++++++++++++++++++++--- 2 files changed, 96 insertions(+), 22 deletions(-) diff --git a/torchao/profiler/device_spec.py b/torchao/profiler/device_spec.py index 7c40c73e9..807a9d1b0 100644 --- a/torchao/profiler/device_spec.py +++ b/torchao/profiler/device_spec.py @@ -1,13 
+1,15 @@ -import logging -from collections import defaultdict -from copy import copy from dataclasses import dataclass, field, fields from typing import Dict, Optional, Union import torch -logger = logging.getLogger(__name__) +"""This module contains the device specs for theoretical peak performance calculations. +- Contains a list of available chips and their corresponding theoretical peak FLOPs performance for various torch.dtypes. +- Exposes a DeviceSpec interface and a concrete CUDADeviceSpec implementation for CUDA gpus. Extendable to other device types. +- Where possible, the CUDADeviceSpec auto-populates its fields by utilizing `torch.cuda` API and `triton.runtime.driver`. + +""" # Copied from https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/fabric/utilities/throughput.py _AVAILABLE_GPU_SPECS: Dict[str, Dict[Union[str, torch.dtype], float]] = { # Hopper @@ -289,7 +291,7 @@ class DeviceSpec: - flops_per_s (FLOP / s) - vram (bytes) - dtype (torch.dtype) dtype used for theoretical peak performance - - flops_by_dtype (dict[Union[torch.dtype, str], float]): mapping from dtype to FLOPs + - flops_by_dtype (dict[Union[torch.dtype, str], float]): mapping from dtype to FLOP / s """ device_type: str @@ -349,13 +351,22 @@ def roofline_balancepoint(self): @dataclass class CUDADeviceSpec(DeviceSpec): """ - CUDA specs for theoretical peak performance - - Fields will be auto-populated in __post_init__ if not already specified - and if data is available - - See DeviceSpec for a list of available fields - See AVAILABLE_GPU_SPECS for a list of available chips + CUDA specs for theoretical peak performance, conformant with DeviceSpec interface. + + Fields will be auto-populated in __post_init__ if not specified + and if data is available. + + See _AVAILABLE_GPU_SPECS for a list of available chip data. + + Fields and expected units: + - device (int): CUDA device index + - name (str): name of the device + - bandwidth (bytes /s): memory bandwidth in bytes / s + - flops_per_s (FLOP / s): FLOPs per second + - vram (bytes): VRAM in bytes + - dtype (torch.dtype): dtype used for theoretical peak performance + - flops_by_dtype (dict[Union[torch.dtype, str], float]): mapping from dtype to FLOP / s + - use_tensorcores (bool): whether to use tensorcores if dtype == torch.float32 """ device_type: str = "cuda" @@ -373,7 +384,7 @@ def __post_init__(self): if self.bandwidth is None: self.bandwidth = get_bandwidth() - # FLOPs + # FLOPs / s if self.flops_per_s is None: chip_name = get_chip_name(self.device) if chip_name is None: diff --git a/torchao/profiler/performance_counter.py b/torchao/profiler/performance_counter.py index 14022c016..2366b3899 100644 --- a/torchao/profiler/performance_counter.py +++ b/torchao/profiler/performance_counter.py @@ -18,9 +18,11 @@ from .device_spec import DeviceSpec -# Set to keep track of issued warnings -_issued_warnings = set() +aten = torch.ops.aten +# TODO: Quick hack to track issued warnings to prevent excessive output each time a field is missing. +# Implement a cleaner solution. +_issued_warnings = set() # Define a custom warning category class DeviceInfoMissing(UserWarning): @@ -32,8 +34,21 @@ def warn_once(message): warnings.warn(message, DeviceInfoMissing) _issued_warnings.add(message) -aten = torch.ops.aten class PerformanceCounterMode(FlopCounterMode): + """ + ``PerformanceCounterMode`` extends FlopCounterMode to track IO in addition to flops. 
+ + It does this using a ``TorchDispatchMode`` per `FlopCounterMode` and tracks the + inputs and outputs of each operator, organized by module. + + In addition to the methods exposed by FlopCounterMode, the following methods are + available: + - ``get_io_counts``: returns a dictionary of module names and their associated IO counts by aten operator + - ``get_total_io``: returns the total number of IO operations across all modules + - ``get_summary_io_counts``: returns a summary of the IO counts for each module (totals by operator) + - ``get_summary_flop_counts``: returns a summary of the flop counts for each module (totals by operator) + """ + def __init__(self, display=False, depth=10, debug=False): self.debug = debug self.io_counts = defaultdict(lambda: defaultdict(int)) @@ -112,6 +127,23 @@ def _count_flops(self, func_packet, out, args, kwargs): return out class PerformanceTimer: + """ + Context manager that records the duration, io, and flops of a torch operator / module. + + Timing is done using `time.perf_counter` and can be overridden to use a different + timer (see `CUDAPerformanceTimer`). + + IO and FLOPs are recorded using `PerformanceCounterMode`. + + Available attributes: + name: str + precision: int + display: bool + depth (int): passed to `PerformanceCounterMode` if displaying and determines depth of module tree to display. + **Note**: these attributes are primarily used for debugging when using the `PerformanceTimer` standalone. + The PerformanceCounterManager class is a higher-level API that should be used instead. + + """ def __init__(self, name, precision=1, display=False, depth=10): self.name = name self.precision = precision @@ -166,8 +198,12 @@ def io_counts(self): def get_pretty_summary(self, depth): return self.perf_counter.pretty_summary_counts(depth=depth if depth is not None else self.depth) + class CUDAPerformanceTimer(PerformanceTimer): - + """ + `PerformanceTimer` that uses `cudaEvents` to record duration. + """ + def __enter__(self): self.start = torch.cuda.Event(enable_timing=True) self.end = torch.cuda.Event(enable_timing=True) @@ -217,8 +253,9 @@ def to_nearest_power_of_10(x, precision=2): class DictMixin: """ - Mixin to enable dict-like access to dataclass attributes + Enables dict-like interface to dataclasses. """ + def __getitem__(self, key): if hasattr(self, key): return getattr(self, key) @@ -235,11 +272,36 @@ def __iter__(self): for key in self.__dict__: yield key -# Function to get all property methods of a class -def get_property_methods(cls): - return [name for name, member in inspect.getmembers(cls, lambda m: isinstance(m, property))] +def _get_property_methods(cls): + return [name for name, _ in inspect.getmembers(cls, lambda m: isinstance(m, property))] + @dataclass class PerformanceStats(DictMixin): + """ + Data struct that stores performance statistics. 
+ + Attrs: + num_tokens (int): number of tokens processed + duration (float): duration in seconds + total_flops (int): total FLOPs + total_io (int): total data movement in bytes + flops_summary (Dict[str, int]): summary of FLOPs by module + io_summary (Dict[str, int]): summary of data movement in bytes by module + flop_counts (Dict[str, Dict[Any, int]]): FLOP counts by module and operation + io_counts (Dict[str, Dict[Any, int]]): data movement by module and operation + device_bandwidth (Optional[float]): device bandwidth in bytes per second + device_flops_per_s (Optional[float]): device FLOPs per second + + Additionally, the following derived properties are available: + token_throughput (float): number of tokens processed per second + achieved_flops_per_s (float): achieved FLOPs per second + achieved_bandwidth (float): achieved data movement in bytes per second + theoretical_io_latency (Optional[float]): theoretical I/O latency in seconds, set to None if + no device bandwidth is available. + theoretical_compute_latency (Optional[float]): theoretical compute latency in seconds, set to None if + no device FLOPs are available. + """ + label: str num_tokens: int duration: float @@ -251,6 +313,7 @@ class PerformanceStats(DictMixin): io_counts: Dict[str, Dict[Any, int]] device_bandwidth: Optional[float] = None device_flops_per_s: Optional[float] = None + @property def token_throughput(self): return self.num_tokens / self.duration @@ -322,7 +385,7 @@ def __str__(self): def to_dict(self): d = asdict(self) # Update dict with properties - props = get_property_methods(self.__class__) + props = _get_property_methods(self.__class__) d.update({prop: getattr(self, prop) for prop in props}) return d From ff2d1931989d2f6fa8fab98466334da66eed3001 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Wed, 10 Jul 2024 12:41:50 +0000 Subject: [PATCH 25/33] more docs --- torchao/profiler/device_spec.py | 14 +- torchao/profiler/performance_counter.py | 346 +++++++++++++++--------- 2 files changed, 230 insertions(+), 130 deletions(-) diff --git a/torchao/profiler/device_spec.py b/torchao/profiler/device_spec.py index 807a9d1b0..040367583 100644 --- a/torchao/profiler/device_spec.py +++ b/torchao/profiler/device_spec.py @@ -303,9 +303,15 @@ class DeviceSpec: flops_by_dtype: dict = field(default_factory=dict) def _post_init_check(self): - assert self.bandwidth is not None, "GPU bandwidth is None - please specify the bandwidth in GB/s in order to enable speed of light calculations" - assert self.dtype is not None, "GPU dtype is None - please specify the dtype in order to enable speed of light calculations" - assert self.flops_per_s is not None, "GPU flops_per_s is None - please specify the flops_per_s in FLOP/s in order to enable speed of light calculations" + assert ( + self.bandwidth is not None + ), "GPU bandwidth is None - please specify the bandwidth in GB/s in order to enable speed of light calculations" + assert ( + self.dtype is not None + ), "GPU dtype is None - please specify the dtype in order to enable speed of light calculations" + assert ( + self.flops_per_s is not None + ), "GPU flops_per_s is None - please specify the flops_per_s in FLOP/s in order to enable speed of light calculations" self.flops_by_dtype.update({self.dtype: self.flops_per_s}) # Not needed for downstream calculations atm, no need to assert @@ -357,7 +363,7 @@ class CUDADeviceSpec(DeviceSpec): and if data is available. See _AVAILABLE_GPU_SPECS for a list of available chip data. 
- + Fields and expected units: - device (int): CUDA device index - name (str): name of the device diff --git a/torchao/profiler/performance_counter.py b/torchao/profiler/performance_counter.py index 2366b3899..55f2953c2 100644 --- a/torchao/profiler/performance_counter.py +++ b/torchao/profiler/performance_counter.py @@ -24,16 +24,19 @@ # Implement a cleaner solution. _issued_warnings = set() + # Define a custom warning category class DeviceInfoMissing(UserWarning): pass + def warn_once(message): global _issued_warnings if message not in _issued_warnings: warnings.warn(message, DeviceInfoMissing) _issued_warnings.add(message) + class PerformanceCounterMode(FlopCounterMode): """ ``PerformanceCounterMode`` extends FlopCounterMode to track IO in addition to flops. @@ -48,48 +51,67 @@ class PerformanceCounterMode(FlopCounterMode): - ``get_summary_io_counts``: returns a summary of the IO counts for each module (totals by operator) - ``get_summary_flop_counts``: returns a summary of the flop counts for each module (totals by operator) """ - + def __init__(self, display=False, depth=10, debug=False): self.debug = debug self.io_counts = defaultdict(lambda: defaultdict(int)) super().__init__(display=display, depth=depth) - + def get_io_counts(self): - return {k: dict(v) for k,v in self.io_counts.items()} - + return {k: dict(v) for k, v in self.io_counts.items()} + def get_total_io(self): - return sum(self.io_counts['Global'].values()) + return sum(self.io_counts["Global"].values()) def _get_io_sizes(self, args): - sizes = tree_map(lambda x: x.numel() * x.element_size() if isinstance(x, torch.Tensor) else 0, args) + sizes = tree_map( + lambda x: x.numel() * x.element_size() + if isinstance(x, torch.Tensor) + else 0, + args, + ) if not hasattr(sizes, "__len__"): sizes = [sizes] return sizes - + def get_summary_flop_counts(self): flop_counts = self.get_flop_counts() - return {k: sum(v.values()) for k,v in flop_counts.items()} - + return {k: sum(v.values()) for k, v in flop_counts.items()} + def get_summary_io_counts(self): io_counts = self.get_io_counts() - return {k: sum(v.values()) for k,v in io_counts.items()} - + return {k: sum(v.values()) for k, v in io_counts.items()} + def _nearest_power_of_10(self, x): if x == 0: return x, 0 - + power = int(math.floor(math.log10(abs(x)) / 3)) scaled_value = x / (10 ** (3 * power)) - + return scaled_value, power - + def pretty_summary_counts(self, type="flops", precision=2, depth=None): assert type in ["flops", "io"] - metric_units = {0: '', 1: 'k', 2: 'M', 3: 'G', 4: 'T', 5: 'P', 6: 'E', 7: 'Z', 8: 'Y'} + metric_units = { + 0: "", + 1: "k", + 2: "M", + 3: "G", + 4: "T", + 5: "P", + 6: "E", + 7: "Z", + 8: "Y", + } if depth is None: depth = self.depth - summary_counts = self.get_summary_flop_counts() if type == "flops" else self.get_summary_io_counts() + summary_counts = ( + self.get_summary_flop_counts() + if type == "flops" + else self.get_summary_io_counts() + ) keys_to_print = [k for k in summary_counts.keys() if len(k.split(".")) <= depth] units = "FLOPs" if type == "flops" else "B" summary_str = [] @@ -98,59 +120,69 @@ def pretty_summary_counts(self, type="flops", precision=2, depth=None): continue spaces = " " * (len(k.split(".")) - 1) scaled_val, power = self._nearest_power_of_10(summary_counts[k]) - formatted_val = f"{scaled_val:.{precision}f}{metric_units[power]}{units}" + formatted_val = f"{scaled_val:.{precision}f}{metric_units[power]}{units}" summary_str.append(f"{spaces}{k}: {formatted_val}") - + return "\n".join(summary_str) - + def 
_count_io(self, func_packet, out, args, kwargs): arg_sizes = self._get_io_sizes(args) kwargs_sizes = self._get_io_sizes(kwargs.values()) out_sizes = self._get_io_sizes(out) - arg_size, kwargs_size, out_size = sum(arg_sizes), sum(kwargs_sizes), sum(out_sizes) + arg_size, kwargs_size, out_size = ( + sum(arg_sizes), + sum(kwargs_sizes), + sum(out_sizes), + ) return arg_size, kwargs_size, out_size - + def _count_flops(self, func_packet, out, args, kwargs): if func_packet in self.flop_registry: flop_count_func = self.flop_registry[func_packet] flop_count = flop_count_func(*args, **kwargs, out_val=out) # type: ignore[operator] - arg_size, kwarg_size, out_size = self._count_io(func_packet, out, args, kwargs) + arg_size, kwarg_size, out_size = self._count_io( + func_packet, out, args, kwargs + ) total_size = arg_size + kwarg_size + out_size for par in set(self.mod_tracker.parents): if self.debug: print(f"Counting flops for {par}, {func_packet}: {flop_count}") - print(f"Counting io for {par}, {func_packet}: {sum([arg_size, kwarg_size, out_size])} = {arg_size} + {kwarg_size} + {out_size}") + print( + f"Counting io for {par}, {func_packet}: {sum([arg_size, kwarg_size, out_size])} = {arg_size} + {kwarg_size} + {out_size}" + ) self.flop_counts[par][func_packet] += flop_count self.io_counts[par][func_packet] += total_size - + return out + class PerformanceTimer: """ Context manager that records the duration, io, and flops of a torch operator / module. - + Timing is done using `time.perf_counter` and can be overridden to use a different timer (see `CUDAPerformanceTimer`). - + IO and FLOPs are recorded using `PerformanceCounterMode`. - - Available attributes: + + Available attributes: name: str precision: int display: bool depth (int): passed to `PerformanceCounterMode` if displaying and determines depth of module tree to display. **Note**: these attributes are primarily used for debugging when using the `PerformanceTimer` standalone. The PerformanceCounterManager class is a higher-level API that should be used instead. 
- + """ + def __init__(self, name, precision=1, display=False, depth=10): self.name = name self.precision = precision - self.display = display + self.display = display self.depth = depth self.perf_counter = PerformanceCounterMode(display=display, depth=depth) - + def __enter__(self): self.start = time.perf_counter() self.perf_counter.__enter__() @@ -159,56 +191,61 @@ def __enter__(self): def _print_exit_msg(self): gflops = round(self.total_flops / 1e9, self.precision) ms = round(self.duration * 1e3, self.precision) - if self.display: + if self.display: print(f"{self.name.upper()}: duration = {ms} ms, FLOPS = {gflops} GFLOPs") def __exit__(self, type, value, traceback): self.end = time.perf_counter() - #Convert to ms - self.duration = (self.end - self.start) + # Convert to ms + self.duration = self.end - self.start self.perf_counter.__exit__(type, value, traceback) if self.display: - self._print_exit_msg() + self._print_exit_msg() @property def total_flops(self): return self.perf_counter.get_total_flops() - + @property def total_io(self): return self.perf_counter.get_total_io() - + @property def flops_table(self): return self.perf_counter.get_table() - + def get_summary_flop_counts(self): return self.perf_counter.get_summary_flop_counts() - + def get_summary_io_counts(self): return self.perf_counter.get_summary_io_counts() - + @property def flop_counts(self): return self.perf_counter.get_flop_counts() - + @property def io_counts(self): return self.perf_counter.get_io_counts() - + def get_pretty_summary(self, depth): - return self.perf_counter.pretty_summary_counts(depth=depth if depth is not None else self.depth) + return self.perf_counter.pretty_summary_counts( + depth=depth if depth is not None else self.depth + ) + class CUDAPerformanceTimer(PerformanceTimer): """ `PerformanceTimer` that uses `cudaEvents` to record duration. - """ - + """ + def __enter__(self): self.start = torch.cuda.Event(enable_timing=True) self.end = torch.cuda.Event(enable_timing=True) self.start.record() - self.perf_counter = PerformanceCounterMode(display=self.display, depth=self.depth) + self.perf_counter = PerformanceCounterMode( + display=self.display, depth=self.depth + ) self.perf_counter.__enter__() return self @@ -220,37 +257,31 @@ def __exit__(self, type, value, traceback): self.perf_counter.__exit__(type, value, traceback) if self.display: - self._print_exit_msg() + self._print_exit_msg() + def to_nearest_power_of_10(x, precision=2): - # Dictionary mapping powers of 10 to their metric abbreviations - metric_units = { - 0: '', - -6: 'µ', - -3: 'm', - 6: 'M', - 9: 'G', - 12: 'T' - } - + metric_units = {0: "", -6: "µ", -3: "m", 6: "M", 9: "G", 12: "T"} + # Determine the closest power of 10 if x == 0: return f"{x:.{precision}f}" - + power = int(math.floor(math.log10(abs(x)))) # Adjust power to fit within the given metric units powers = sorted(metric_units.keys()) closest_power = min(powers, key=lambda p: abs(p - power)) - + # Calculate the value formatted to the closest power of 10 value = x / 10**closest_power - + # Map the power to the metric unit unit = metric_units.get(closest_power, f"e{closest_power}") - + return f"{value:,.{precision}f} {unit}" + class DictMixin: """ Enables dict-like interface to dataclasses. 
@@ -261,25 +292,29 @@ def __getitem__(self, key): return getattr(self, key) else: raise KeyError(key) - + def __setitem__(self, key, value): setattr(self, key, value) - + def __contains__(self, key): return hasattr(self, key) - + def __iter__(self): for key in self.__dict__: yield key + def _get_property_methods(cls): - return [name for name, _ in inspect.getmembers(cls, lambda m: isinstance(m, property))] + return [ + name for name, _ in inspect.getmembers(cls, lambda m: isinstance(m, property)) + ] + @dataclass class PerformanceStats(DictMixin): """ Data struct that stores performance statistics. - + Attrs: num_tokens (int): number of tokens processed duration (float): duration in seconds @@ -291,7 +326,7 @@ class PerformanceStats(DictMixin): io_counts (Dict[str, Dict[Any, int]]): data movement by module and operation device_bandwidth (Optional[float]): device bandwidth in bytes per second device_flops_per_s (Optional[float]): device FLOPs per second - + Additionally, the following derived properties are available: token_throughput (float): number of tokens processed per second achieved_flops_per_s (float): achieved FLOPs per second @@ -301,7 +336,7 @@ class PerformanceStats(DictMixin): theoretical_compute_latency (Optional[float]): theoretical compute latency in seconds, set to None if no device FLOPs are available. """ - + label: str num_tokens: int duration: float @@ -312,26 +347,28 @@ class PerformanceStats(DictMixin): flop_counts: Dict[str, Dict[Any, int]] io_counts: Dict[str, Dict[Any, int]] device_bandwidth: Optional[float] = None - device_flops_per_s: Optional[float] = None - + device_flops_per_s: Optional[float] = None + @property def token_throughput(self): return self.num_tokens / self.duration - + @property def achieved_flops_per_s(self): return self.total_flops / self.duration - + @property def achieved_bandwidth(self): return self.total_io / self.duration - + @property def theoretical_io_latency(self): if self.device_bandwidth is not None: return self.total_io / self.device_bandwidth else: - warn_once("Device bandwidth is not specified. Please specify the device bandwidth to enable io latency calculation") + warn_once( + "Device bandwidth is not specified. Please specify the device bandwidth to enable io latency calculation" + ) return None @property @@ -339,28 +376,36 @@ def theoretical_compute_latency(self): if self.device_flops_per_s is not None: return self.total_flops / self.device_flops_per_s else: - warn_once("Device flops_per_s is not specified. Please specify the device throughput to enable compute latency calculation") + warn_once( + "Device flops_per_s is not specified. Please specify the device throughput to enable compute latency calculation" + ) return None - + @property def bandwidth_utilization(self): if self.device_bandwidth is not None: return self.achieved_bandwidth / self.device_bandwidth else: - warn_once("Device bandwidth is not specified. Please specify the device bandwidth to enable bandwidth utilization calculation") + warn_once( + "Device bandwidth is not specified. Please specify the device bandwidth to enable bandwidth utilization calculation" + ) return None + @property def flops_utilization(self): if self.device_flops_per_s is not None: return self.achieved_flops_per_s / self.device_flops_per_s else: - warn_once("Device flops_per_s is not specified. Please specify the device throughput to enable flops utilization calculation") + warn_once( + "Device flops_per_s is not specified. 
Please specify the device throughput to enable flops utilization calculation" + ) return None + def _format(self, value, suffix, precision=2, round=True): if round: return to_nearest_power_of_10(value, precision=precision) + suffix return f"{value:.{precision}f} " + suffix - + def __str__(self): txt = textwrap.dedent(f"""\ {self.label}: @@ -379,7 +424,7 @@ def __str__(self): Utilization Bandwidth: {self._format(self.bandwidth_utilization, round=False, precision=4, suffix="%") if self.bandwidth_utilization is not None else "N/A"} FLOPs: {self._format(self.flops_utilization, round=False, precision=4, suffix="%") if self.flops_utilization is not None else "N/A"}""") - + return txt def to_dict(self): @@ -389,15 +434,40 @@ def to_dict(self): d.update({prop: getattr(self, prop) for prop in props}) return d - + + class PerformanceCounterManager: - def __init__(self, depth=10, timer_cls: PerformanceTimer=PerformanceTimer, device_spec: DeviceSpec=None, verbose=False): + """ + Context manager-like class for tracking performance across multiple calls + to a Transformer model. + + Provides properties for accessing performance stats for data movement and FLOPs for each context as well as + summary stats across all contexts. + Additionally, if a device_spec is provided, theoretical peak bandwidth / FLOPs stats will be available. + + See `PerformanceStats` struct for description of tracked metrics. + + Example: + >>> manager = PerformanceCounterManager(device_spec=device_spec) + >>> with manager.count(label="prefill", num_tokens=x.numel()): + >>> out = model(encoded_prompt) + >>> manager.print_summary(labels=["prefill"]) # prints recorded stats for "prefill" context + >>> with manager.count(label="decode", num_tokens=1): + >>> out = model(out[-1]) + >>> manager.print_summary(labels=["decode"]) # prints recorded stats for "decode" context + >>> print(manager.print_summary) # prints accumulated stats across all contexts + """ + def __init__( + self, + depth=10, + timer_cls: PerformanceTimer = PerformanceTimer, + device_spec: DeviceSpec = None, + ): super().__init__() self._counts: Dict[str, PerformanceStats] = {} self._depth = depth self.timer_cls = timer_cls self.device_spec = device_spec - self.verbose = verbose @contextmanager def count(self, label: str, num_tokens: int): @@ -407,98 +477,122 @@ def count(self, label: str, num_tokens: int): yield self finally: perf_timer.__exit__(None, None, None) - stats = PerformanceStats(label=label, - num_tokens=num_tokens, - duration=perf_timer.duration, - total_flops=perf_timer.total_flops, - total_io=perf_timer.total_io, - flops_summary=perf_timer.get_summary_flop_counts(), - io_summary=perf_timer.get_summary_io_counts(), - flop_counts=perf_timer.flop_counts, - io_counts=perf_timer.io_counts, - device_bandwidth=self.device_spec.bandwidth if self.device_spec is not None else None, - device_flops_per_s=self.device_spec.flops_per_s if self.device_spec is not None else None) + stats = PerformanceStats( + label=label, + num_tokens=num_tokens, + duration=perf_timer.duration, + total_flops=perf_timer.total_flops, + total_io=perf_timer.total_io, + flops_summary=perf_timer.get_summary_flop_counts(), + io_summary=perf_timer.get_summary_io_counts(), + flop_counts=perf_timer.flop_counts, + io_counts=perf_timer.io_counts, + device_bandwidth=self.device_spec.bandwidth + if self.device_spec is not None + else None, + device_flops_per_s=self.device_spec.flops_per_s + if self.device_spec is not None + else None, + ) self._counts[label] = stats + @property def counts(self): return 
self._counts + def get_counts(self): - return self._counts + return self._counts @property def total_flops(self): return sum(count.total_flops for count in self._counts.values()) - + @property def total_io(self): return sum(count.total_io for count in self._counts.values()) + @property def total_tokens(self): return sum(count.num_tokens for count in self._counts.values()) - + @property def total_time(self): return sum(count.duration for count in self._counts.values()) - + def _summarize_stat(self, key): - return {label: getattr(self._counts[label], key) for label in self._counts.keys()} - + return { + label: getattr(self._counts[label], key) for label in self._counts.keys() + } + @property def flops_summary(self): return self._summarize_stat(key="flops_summary") - + @property def io_summary(self): return self._summarize_stat(key="io_summary") - + @property def flop_counts_summary(self): return self._summarize_stat(key="flop_counts") - + @property def io_counts_summary(self): return self._summarize_stat(key="io_counts") + @property def stats_summary(self): - stats = PerformanceStats(label="Performance Summary", - num_tokens=self.total_tokens, - duration=self.total_time, - total_flops=self.total_flops, - total_io=self.total_io, - flops_summary=self.flops_summary, - io_summary=self.io_summary, - flop_counts=self.flop_counts_summary, - io_counts=self.io_counts_summary, - device_bandwidth=self.device_spec.bandwidth if self.device_spec is not None else None, - device_flops_per_s=self.device_spec.flops_per_s if self.device_spec is not None else None) - + stats = PerformanceStats( + label="Performance Summary", + num_tokens=self.total_tokens, + duration=self.total_time, + total_flops=self.total_flops, + total_io=self.total_io, + flops_summary=self.flops_summary, + io_summary=self.io_summary, + flop_counts=self.flop_counts_summary, + io_counts=self.io_counts_summary, + device_bandwidth=self.device_spec.bandwidth + if self.device_spec is not None + else None, + device_flops_per_s=self.device_spec.flops_per_s + if self.device_spec is not None + else None, + ) + return stats - + def print_summary(self, labels: list[str] = None): - _print = partial(print, flush=True, end='\n') + _print = partial(print, flush=True, end="\n") # Delegate to __str__ of PerformanceStats for pretty printing if labels is None: text = str(self.stats_summary) _print(text) else: for label in labels: - text = str(self._counts[label]) + text = str(self._counts[label]) _print(self._counts[label]) - + def to_dict(self): # Convert flop_counts from OpOverloadPackets to str # Then delegate to PerformanceStats `to_dict`, which updates with derived metrics (property methods) counts = deepcopy(self._counts) - for label,label_counts in counts.items(): - counts[label]['flop_counts'] = {mod: {str(op): count for op, count in op_count.items()} for mod, op_count in label_counts['flop_counts'].items()} - counts[label]['io_counts'] = {mod: {str(op): count for op, count in op_count.items()} for mod, op_count in label_counts['io_counts'].items()} + for label, label_counts in counts.items(): + counts[label]["flop_counts"] = { + mod: {str(op): count for op, count in op_count.items()} + for mod, op_count in label_counts["flop_counts"].items() + } + counts[label]["io_counts"] = { + mod: {str(op): count for op, count in op_count.items()} + for mod, op_count in label_counts["io_counts"].items() + } counts[label] = counts[label].to_dict() - + return counts - + def to_json(self, path: Union[str, Path] = None): d = self.to_dict() if path: - with open(path, 
'w') as f: + with open(path, "w") as f: f.write(json.dumps(d, indent=2)) - return d \ No newline at end of file + return d From ea0f2b61b58eebb773c09f9c9d347efbb8bf56b4 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Wed, 10 Jul 2024 12:46:08 +0000 Subject: [PATCH 26/33] formatting --- torchao/profiler/performance_counter.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/torchao/profiler/performance_counter.py b/torchao/profiler/performance_counter.py index 55f2953c2..20733111a 100644 --- a/torchao/profiler/performance_counter.py +++ b/torchao/profiler/performance_counter.py @@ -440,13 +440,13 @@ class PerformanceCounterManager: """ Context manager-like class for tracking performance across multiple calls to a Transformer model. - - Provides properties for accessing performance stats for data movement and FLOPs for each context as well as - summary stats across all contexts. + + Provides properties for accessing performance stats for data movement and FLOPs for each context as well as + summary stats across all contexts. Additionally, if a device_spec is provided, theoretical peak bandwidth / FLOPs stats will be available. - + See `PerformanceStats` struct for description of tracked metrics. - + Example: >>> manager = PerformanceCounterManager(device_spec=device_spec) >>> with manager.count(label="prefill", num_tokens=x.numel()): @@ -455,8 +455,9 @@ class PerformanceCounterManager: >>> with manager.count(label="decode", num_tokens=1): >>> out = model(out[-1]) >>> manager.print_summary(labels=["decode"]) # prints recorded stats for "decode" context - >>> print(manager.print_summary) # prints accumulated stats across all contexts + >>> print(manager.print_summary) # prints accumulated stats across all contexts """ + def __init__( self, depth=10, From 996538f30e5d4c24238d9d75469a84a1c6a2ab72 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Wed, 10 Jul 2024 14:13:21 +0000 Subject: [PATCH 27/33] clean up warnings --- test/profiler/test_device_spec.py | 36 +- test/profiler/test_performance_counter.py | 463 ++++++++++++++-------- test/profiler/utils.py | 26 +- torchao/profiler/performance_counter.py | 26 +- 4 files changed, 341 insertions(+), 210 deletions(-) diff --git a/test/profiler/test_device_spec.py b/test/profiler/test_device_spec.py index fb45dca86..b1166f3ad 100644 --- a/test/profiler/test_device_spec.py +++ b/test/profiler/test_device_spec.py @@ -1,7 +1,8 @@ - import pytest -cuda_driver = pytest.importorskip("triton.runtime.driver", reason="requires triton cuda driver module") +cuda_driver = pytest.importorskip( + "triton.runtime.driver", reason="requires triton cuda driver module" +) import itertools from contextlib import contextmanager from unittest.mock import patch @@ -21,7 +22,10 @@ USE_TENSORCORES = [True, False] DEVICE_CONFIGS = itertools.product(DEVICE_NAMES, DTYPES, USE_TENSORCORES) -@pytest.mark.parametrize("device_name, dtype, use_tensorcores", DEVICE_CONFIGS, ids=lambda x: str(x)) + +@pytest.mark.parametrize( + "device_name, dtype, use_tensorcores", DEVICE_CONFIGS, ids=lambda x: str(x) +) def test_device_spec(device_name, dtype, use_tensorcores): with patch_device(device_name): device_spec = CUDADeviceSpec(dtype=dtype, use_tensorcores=use_tensorcores) @@ -31,8 +35,10 @@ def test_device_spec(device_name, dtype, use_tensorcores): expected_flops = _AVAILABLE_GPU_SPECS[chip_name][dtype] assert device_spec.flops_per_s == expected_flops assert device_spec.flops_by_dtype[dtype] == expected_flops - assert device_spec.roofline_balancepoint == 
expected_flops / device_spec.bandwidth - + assert ( + device_spec.roofline_balancepoint == expected_flops / device_spec.bandwidth + ) + with pytest.raises(AssertionError): device_spec.flops_per_s = None print(device_spec.roofline_balancepoint) @@ -40,25 +46,27 @@ def test_device_spec(device_name, dtype, use_tensorcores): with pytest.raises(AttributeError): device_spec.FLOPs = None + def test_empty_device_spec(): device_name = "fake device" with patch_device(device_name): with pytest.raises(AssertionError): _ = CUDADeviceSpec() - + # Ok to instantiate as long as fields are filled - _ = CUDADeviceSpec(name=device_name, - flops_per_s=1.0, - bandwidth=1.0, - dtype=torch.float32, - use_tensorcores=True) + _ = CUDADeviceSpec( + name=device_name, + flops_per_s=1.0, + bandwidth=1.0, + dtype=torch.float32, + use_tensorcores=True, + ) device_name = DEVICE_NAMES[0] - + with patch_device(device_name): # All critical fields will be auto-filled except for dtype (and vram, but vram is not used for downstream calcs atm) _ = CUDADeviceSpec(dtype=torch.float32) - + # No dtype specified with pytest.raises(AssertionError): _ = CUDADeviceSpec() - diff --git a/test/profiler/test_performance_counter.py b/test/profiler/test_performance_counter.py index db99bc247..c6b35cd27 100644 --- a/test/profiler/test_performance_counter.py +++ b/test/profiler/test_performance_counter.py @@ -33,54 +33,69 @@ from torchao.profiler.device_spec import CUDADeviceSpec, DeviceSpec from torchao.profiler.performance_counter import ( CUDAPerformanceTimer, - PerformanceCounterManager, PerformanceCounterMode, PerformanceStats, PerformanceTimer, + TransformerPerformanceCounter, ) from torchao.utils import TORCH_VERSION_AFTER_2_5 # ------------------- PerformanceCounter Tests ------------------- # -PERFCOUNTER_TEST_CONFIGS = [PerfCounterTestConfig(name="3.5B", - batch_size=1, - seqlen=128, - dtype=torch.float16, - num_hidden_layers=32 // 2, - hidden_size=4096 // 2, - intermediate_size=11008 // 2, - num_attention_heads=32 // 2, - vocab_size=32000 // 2), - PerfCounterTestConfig(name="1.25B", - batch_size=1, - seqlen=128, - dtype=torch.float16, - num_hidden_layers=32 // 4, - hidden_size=4096 // 4, - intermediate_size=11008 // 4, - num_attention_heads=32 // 4, - vocab_size=32000 // 4), - PerfCounterTestConfig(name="tiny", - batch_size=1, - seqlen=128, - dtype=torch.float16, - num_hidden_layers=1, - hidden_size=4096 // 4, - intermediate_size=11008 // 4, - num_attention_heads=32 // 4, - vocab_size=32000 // 4)] - -@unittest.skipIf(not TORCH_VERSION_AFTER_2_5, "PerformanceCounter requires torch >= 2.5+.") +PERFCOUNTER_TEST_CONFIGS = [ + PerfCounterTestConfig( + name="3.5B", + batch_size=1, + seqlen=128, + dtype=torch.float16, + num_hidden_layers=32 // 2, + hidden_size=4096 // 2, + intermediate_size=11008 // 2, + num_attention_heads=32 // 2, + vocab_size=32000 // 2, + ), + PerfCounterTestConfig( + name="1.25B", + batch_size=1, + seqlen=128, + dtype=torch.float16, + num_hidden_layers=32 // 4, + hidden_size=4096 // 4, + intermediate_size=11008 // 4, + num_attention_heads=32 // 4, + vocab_size=32000 // 4, + ), + PerfCounterTestConfig( + name="tiny", + batch_size=1, + seqlen=128, + dtype=torch.float16, + num_hidden_layers=1, + hidden_size=4096 // 4, + intermediate_size=11008 // 4, + num_attention_heads=32 // 4, + vocab_size=32000 // 4, + ), +] + + +@unittest.skipIf( + not TORCH_VERSION_AFTER_2_5, "PerformanceCounter requires torch >= 2.5+." 
+) @unittest.skipIf(not torch.cuda.is_available(), "PerformanceCounter requires CUDA") -@parameterized_class([asdict(cfg) for cfg in PERFCOUNTER_TEST_CONFIGS], class_name_func=get_test_name) +@parameterized_class( + [asdict(cfg) for cfg in PERFCOUNTER_TEST_CONFIGS], class_name_func=get_test_name +) class PerformanceCounterTest(unittest.TestCase): @classmethod def setUpClass(cls): - model_cfg = LlamaConfig(num_hidden_layers=cls.num_hidden_layers, - hidden_size=cls.hidden_size, - intermediate_size=cls.intermediate_size, - num_attention_heads=cls.num_attention_heads, - vocab_size=cls.vocab_size) + model_cfg = LlamaConfig( + num_hidden_layers=cls.num_hidden_layers, + hidden_size=cls.hidden_size, + intermediate_size=cls.intermediate_size, + num_attention_heads=cls.num_attention_heads, + vocab_size=cls.vocab_size, + ) # Note we set some options manually since the model doesn't seem to be initialized correctly # when these options are set in LlamaConfig @@ -88,10 +103,14 @@ def setUpClass(cls): cls.model = model = LlamaForCausalLM(model_cfg).to(cls.dtype).to("cuda") cls.model_config = model.config cls.element_size = cls.dtype.itemsize - - input_ids = torch.randint(0, model.config.vocab_size, (cls.batch_size, cls.seqlen), device="cuda") + + input_ids = torch.randint( + 0, model.config.vocab_size, (cls.batch_size, cls.seqlen), device="cuda" + ) with torch.no_grad(): - with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION): + with torch.nn.attention.sdpa_kernel( + torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION + ): with PerformanceCounterMode() as perf_counter: _ = model(input_ids) cls.perf_counter = perf_counter @@ -103,7 +122,7 @@ def setUpClass(cls): def test_qkv_proj(self): batch_size, seqlen = self.batch_size, self.seqlen element_size = self.element_size - + assert len(self.summary_flops) == len(self.summary_io) assert self.summary_flops.keys() == self.summary_io.keys() @@ -112,72 +131,101 @@ def test_qkv_proj(self): # Flops check proj_keys = get_leaf_nodes(self.summary_flops.keys(), k) assert len(proj_keys) == self.model.config.num_hidden_layers - expected_flops = 2 * batch_size * seqlen * self.model_config.hidden_size * self.model_config.hidden_size + expected_flops = ( + 2 + * batch_size + * seqlen + * self.model_config.hidden_size + * self.model_config.hidden_size + ) assert expected_flops == self.summary_flops[proj_keys[0]] - + # io check - expected_size = qkv_proj_io_check(self.model_config, batch_size, seqlen, element_size) + expected_size = qkv_proj_io_check( + self.model_config, batch_size, seqlen, element_size + ) assert expected_size == self.summary_io[proj_keys[0]] - + def test_attn(self): batch_size, seqlen = self.batch_size, self.seqlen element_size = self.element_size model_config = self.model.config - + attention_keys = get_leaf_nodes(self.summary_flops.keys(), "self_attn") for k in attention_keys: flops = self.flops_by_op[k] io_movement = self.io_by_op[k] for op, count in flops.items(): - if "attention" in op.__name__: - expected_flops = 2 * 2 * batch_size * seqlen * seqlen * model_config.hidden_size + if "attention" in op.__name__: + expected_flops = ( + 2 * 2 * batch_size * seqlen * seqlen * model_config.hidden_size + ) assert expected_flops == count for op, count in io_movement.items(): if "attention" in op.__name__: # Check approx equal due to other small artifacts returned by sdpa.mem_efficient_attention # See #https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/cuda/attention.cu#L867 # Check within 100 bytes - 
expected_size = attn_io_check(model_config, batch_size, seqlen, element_size) + expected_size = attn_io_check( + model_config, batch_size, seqlen, element_size + ) assert abs(expected_size - count) < 100 - + def test_ffn(self): batch_size, seqlen = self.batch_size, self.seqlen element_size = self.element_size - + for k in ["up_proj", "gate_proj", "down_proj"]: proj_keys = get_leaf_nodes(self.summary_flops.keys(), k) assert len(proj_keys) == self.model.config.num_hidden_layers - expected_flops = 2 * batch_size * seqlen * self.model_config.hidden_size * self.model_config.intermediate_size + expected_flops = ( + 2 + * batch_size + * seqlen + * self.model_config.hidden_size + * self.model_config.intermediate_size + ) assert expected_flops == self.summary_flops[proj_keys[0]] - + # io check - expected_size = ffn_io_check(self.model_config, batch_size, seqlen, element_size, k) + expected_size = ffn_io_check( + self.model_config, batch_size, seqlen, element_size, k + ) assert expected_size == self.summary_io[proj_keys[0]] - + + # ------------------- PerformanceStats Tests ------------------- # -PERFSTATS_TEST_CONFIGS = [PerfStatsTestConfig(label="with_device", - num_tokens=128, - duration=0.1, - total_flops=123e9, - total_io=123e6, - flops_summary={"a": 234e12, "b": 345e9}, - io_summary={"a": 1, "b": 2}, - flop_counts={"a": 234e12, "b": 345e9}, - io_counts={"a": 1, "b": 2}, - device_bandwidth=1e9, - device_flops_per_s=23e9), - PerfStatsTestConfig(label="no_device", - num_tokens=128, - duration=0.1, - total_flops=123e9, - total_io=123e6, - flops_summary={"a": 234e12, "b": 345e9}, - io_summary={"a": 1, "b": 2}, - flop_counts={"a": 234e12, "b": 345e9}, - io_counts={"a": 1, "b": 2}, - device_bandwidth=None, - device_flops_per_s=None)] +PERFSTATS_TEST_CONFIGS = [ + PerfStatsTestConfig( + label="with_device", + num_tokens=128, + duration=0.1, + total_flops=123e9, + total_io=123e6, + flops_summary={"a": 234e12, "b": 345e9}, + io_summary={"a": 1, "b": 2}, + flop_counts={"a": 234e12, "b": 345e9}, + io_counts={"a": 1, "b": 2}, + device_bandwidth=1e9, + device_flops_per_s=23e9, + ), + PerfStatsTestConfig( + label="no_device", + num_tokens=128, + duration=0.1, + total_flops=123e9, + total_io=123e6, + flops_summary={"a": 234e12, "b": 345e9}, + io_summary={"a": 1, "b": 2}, + flop_counts={"a": 234e12, "b": 345e9}, + io_counts={"a": 1, "b": 2}, + device_bandwidth=None, + device_flops_per_s=None, + ), +] + + @pytest.mark.parametrize("cfg", PERFSTATS_TEST_CONFIGS, ids=lambda cfg: cfg.label) def test_performance_stats(cfg: PerfStatsTestConfig): stats = PerformanceStats(**asdict(cfg)) @@ -187,24 +235,28 @@ def test_performance_stats(cfg: PerfStatsTestConfig): total_io = cfg.total_io device_bandwidth = cfg.device_bandwidth device_flops_per_s = cfg.device_flops_per_s - + # Test derived metrics assert stats.token_throughput == num_tokens / duration assert stats.achieved_bandwidth == total_io / duration assert stats.achieved_flops_per_s == total_flops / duration if device_bandwidth is not None: - assert stats.bandwidth_utilization == stats.achieved_bandwidth / device_bandwidth + assert ( + stats.bandwidth_utilization == stats.achieved_bandwidth / device_bandwidth + ) assert stats.theoretical_io_latency == total_io / device_bandwidth else: assert stats.bandwidth_utilization is None assert stats.theoretical_io_latency is None if device_flops_per_s is not None: - assert stats.flops_utilization == stats.achieved_flops_per_s / device_flops_per_s + assert ( + stats.flops_utilization == stats.achieved_flops_per_s / 
device_flops_per_s + ) assert stats.theoretical_compute_latency == total_flops / device_flops_per_s else: assert stats.flops_utilization is None assert stats.theoretical_compute_latency is None - + # Test str - stats should be formatted to closest power of 10 ** 3 with 2 decimal places of precision stats_str = str(stats) @@ -219,28 +271,52 @@ def test_performance_stats(cfg: PerfStatsTestConfig): expected_flops_throughput_str = "1.23 TFLOPs/s" assert expected_io_throughput_str in stats_str assert expected_flops_throughput_str in stats_str - + # Utilization Stats if device_bandwidth is not None: - expected_bandwidth_utilization_str = f"{stats.achieved_bandwidth / device_bandwidth:.4f}" + expected_bandwidth_utilization_str = ( + f"{stats.achieved_bandwidth / device_bandwidth:.4f}" + ) expected_io_latency_str = f"{stats.theoretical_io_latency:.2f} s" assert expected_bandwidth_utilization_str in stats_str assert expected_io_latency_str in stats_str - + if device_flops_per_s is not None: - expected_flops_utilization_str = f"{stats.achieved_flops_per_s / device_flops_per_s:.4f}" + expected_flops_utilization_str = ( + f"{stats.achieved_flops_per_s / device_flops_per_s:.4f}" + ) expected_compute_latency_str = f"{stats.theoretical_compute_latency:.2f} s" assert expected_flops_utilization_str in stats_str assert expected_compute_latency_str in stats_str -# ------------------- PerformanceCounterManager Tests ------------------- # -PERFCOUNTERMANAGER_TEST_CONFIGS = [PerfCounterManagerTestConfig("no_device", (1, 1024, 4096, 4096), PerformanceTimer, torch.bfloat16, (None, 0)), - PerfCounterManagerTestConfig("a100", (1, 1024, 4096, 4096), CUDAPerformanceTimer, torch.bfloat16, ("A100", 2e12))] -@unittest.skipIf(not TORCH_VERSION_AFTER_2_5, "PerformanceCounterManager requires torch >= 2.5+.") -@unittest.skipIf(not torch.cuda.is_available(), "PerformanceCounterManager requires CUDA") -@parameterized_class([asdict(cfg) for cfg in PERFCOUNTERMANAGER_TEST_CONFIGS], class_name_func=get_test_name) -class TestPerformanceCounterManager(unittest.TestCase): +# ------------------- TransformerPerformanceCounter Tests ------------------- # + +PERFCOUNTERMANAGER_TEST_CONFIGS = [ + PerfCounterManagerTestConfig( + "no_device", (1, 1024, 4096, 4096), PerformanceTimer, torch.bfloat16, (None, 0) + ), + PerfCounterManagerTestConfig( + "a100", + (1, 1024, 4096, 4096), + CUDAPerformanceTimer, + torch.bfloat16, + ("A100", 2e12), + ), +] + + +@unittest.skipIf( + not TORCH_VERSION_AFTER_2_5, "TransformerPerformanceCounter requires torch >= 2.5+." 
+) +@unittest.skipIf( + not torch.cuda.is_available(), "TransformerPerformanceCounter requires CUDA" +) +@parameterized_class( + [asdict(cfg) for cfg in PERFCOUNTERMANAGER_TEST_CONFIGS], + class_name_func=get_test_name, +) +class TestTransformerPerformanceCounter(unittest.TestCase): @classmethod def setUpClass(cls): shape, timer_cls, dtype = cls.shape, cls.timer_cls, cls.dtype @@ -254,50 +330,60 @@ def setUpClass(cls): if device_name is not None: with patch_device(device_name): device_spec = CUDADeviceSpec(dtype=torch.bfloat16, bandwidth=bandwidth) - + else: device_spec = None - + # Stateful class level objects, which will be used in individual tests - cls.cm = cm =PerformanceCounterManager(timer_cls=timer_cls, device_spec=device_spec) + cls.cm = cm = TransformerPerformanceCounter( + timer_cls=timer_cls, device_spec=device_spec + ) cls.FLOAT_TOL = 1e-5 cls.expected = expected = {} - + # Start count for a start = time.perf_counter() with cm.count("a", num_tokens=num_tokens): _ = torch.matmul(a, b) end = time.perf_counter() - duration = (end - start) + duration = end - start expected_flops = 2 * num_tokens * in_features * out_features - expected_io = (num_tokens * in_features + in_features * out_features + num_tokens * out_features) * element_size - - expected['a'] = PerfCounterResult(name="a", - duration=duration, - flops=expected_flops, - io=expected_io, - total_flops=expected_flops, - total_io=expected_io) - + expected_io = ( + num_tokens * in_features + + in_features * out_features + + num_tokens * out_features + ) * element_size + + expected["a"] = PerfCounterResult( + name="a", + duration=duration, + flops=expected_flops, + io=expected_io, + total_flops=expected_flops, + total_io=expected_io, + ) + # Start count for b start = time.perf_counter() with cm.count("b", num_tokens=num_tokens): _ = torch.matmul(a, b) end = time.perf_counter() duration = end - start - - expected['b'] = PerfCounterResult(name="b", - duration=duration, - flops=expected_flops, - io=expected_io, - total_flops=cm.total_flops, - total_io=cm.total_io) - + + expected["b"] = PerfCounterResult( + name="b", + duration=duration, + flops=expected_flops, + io=expected_io, + total_flops=cm.total_flops, + total_io=cm.total_io, + ) + def test_perf_stats_a(self): - cm: PerformanceCounterManager = self.cm - expected = self.expected['a'] - + cm: TransformerPerformanceCounter = self.cm + expected = self.expected["a"] + counts = cm.get_counts() assert "a" in counts @@ -305,39 +391,41 @@ def test_perf_stats_a(self): psa: PerformanceStats = counts["a"] # Raw metrics # Duration won't be exact since timing external to the profiler - assert abs(psa.duration - expected.duration) < 1e-1 # +/- 100ms + assert abs(psa.duration - expected.duration) < 1e-1 # +/- 100ms assert psa.total_flops == expected.flops assert psa.total_io == expected.io - + # Derived metrics assert psa.token_throughput == psa.num_tokens / psa.duration assert psa.achieved_flops_per_s == psa.total_flops / psa.duration assert psa.achieved_bandwidth == psa.total_io / psa.duration - + def test_perf_stats_b(self): - cm: PerformanceCounterManager = self.cm + cm: TransformerPerformanceCounter = self.cm assert "a" in cm.counts assert "b" in cm.counts psa = cm.counts["a"] psb = cm.counts["b"] - expected = self.expected['b'] - assert abs(psb.duration - expected.duration) < 1e-1 # +/- 100ms + expected = self.expected["b"] + assert abs(psb.duration - expected.duration) < 1e-1 # +/- 100ms assert psb.total_flops == expected.flops assert psb.total_io == expected.io - + # check that 
**total** flops and io after matmul `b` has run accounts for both matmuls # also check that these global properties are updated correctly in the manager object - assert expected.total_flops == psa.total_flops + psb.total_flops == cm.total_flops + assert ( + expected.total_flops == psa.total_flops + psb.total_flops == cm.total_flops + ) assert expected.total_io == psa.total_io + psb.total_io == cm.total_io assert cm.total_time == psa.duration + psb.duration - + def test_stats_summary(self): - cm: PerformanceCounterManager = self.cm + cm: TransformerPerformanceCounter = self.cm FLOAT_TOL = self.FLOAT_TOL psa = cm.counts["a"] psb = cm.counts["b"] summary: PerformanceStats = cm.stats_summary - + # Raw stats assert summary.num_tokens == psa.num_tokens + psb.num_tokens assert summary.total_io == psa.total_io + psb.total_io @@ -345,67 +433,98 @@ def test_stats_summary(self): assert summary.duration == psa.duration + psb.duration # Derived stats - expected_token_throughput = (psa.num_tokens + psb.num_tokens) / (psa.duration + psb.duration) - expected_io_throughput = (psa.total_io + psb.total_io) / (psa.duration + psb.duration) - expected_flops_throughput = (psa.total_flops + psb.total_flops) / (psa.duration + psb.duration) + expected_token_throughput = (psa.num_tokens + psb.num_tokens) / ( + psa.duration + psb.duration + ) + expected_io_throughput = (psa.total_io + psb.total_io) / ( + psa.duration + psb.duration + ) + expected_flops_throughput = (psa.total_flops + psb.total_flops) / ( + psa.duration + psb.duration + ) assert abs(summary.token_throughput - expected_token_throughput) < FLOAT_TOL assert abs(summary.achieved_bandwidth - expected_io_throughput) < FLOAT_TOL assert abs(summary.achieved_flops_per_s - expected_flops_throughput) < FLOAT_TOL - + device_spec = cm.device_spec if device_spec is not None: - expected_bandwidth_utilization = expected_io_throughput / device_spec.bandwidth - expected_flops_utilization = expected_flops_throughput / device_spec.flops_per_s - assert abs(summary.bandwidth_utilization - expected_bandwidth_utilization) < FLOAT_TOL - assert abs(summary.flops_utilization - expected_flops_utilization) < FLOAT_TOL + expected_bandwidth_utilization = ( + expected_io_throughput / device_spec.bandwidth + ) + expected_flops_utilization = ( + expected_flops_throughput / device_spec.flops_per_s + ) + assert ( + abs(summary.bandwidth_utilization - expected_bandwidth_utilization) + < FLOAT_TOL + ) + assert ( + abs(summary.flops_utilization - expected_flops_utilization) < FLOAT_TOL + ) else: assert summary.bandwidth_utilization is None assert summary.flops_utilization is None - + def test_json(self): - cm: PerformanceCounterManager = self.cm + cm: TransformerPerformanceCounter = self.cm psa: PerformanceStats = cm.counts["a"] psb: PerformanceStats = cm.counts["b"] device_spec: Union[DeviceSpec, None] = cm.device_spec - + with tempfile.TemporaryDirectory() as tmp_dir: json_path = Path(tmp_dir) / "test.json" - cm.to_json(json_path) + cm.to_json(json_path) - with open(json_path, 'r') as f: + with open(json_path, "r") as f: perf_dict = json.load(f) - assert 'a' in perf_dict - assert 'b' in perf_dict - - #Test basic stats are recorded properly - assert perf_dict['a']['num_tokens'] == psa.num_tokens - assert perf_dict['a']['total_io'] == psa.total_io - assert perf_dict['a']['total_flops'] == psa.total_flops - assert perf_dict['a']['duration'] == psa.duration - - assert perf_dict['b']['num_tokens'] == psb.num_tokens - assert perf_dict['b']['total_io'] == psb.total_io - assert 
perf_dict['b']['total_flops'] == psb.total_flops - assert perf_dict['b']['duration'] == psb.duration - + assert "a" in perf_dict + assert "b" in perf_dict + + # Test basic stats are recorded properly + assert perf_dict["a"]["num_tokens"] == psa.num_tokens + assert perf_dict["a"]["total_io"] == psa.total_io + assert perf_dict["a"]["total_flops"] == psa.total_flops + assert perf_dict["a"]["duration"] == psa.duration + + assert perf_dict["b"]["num_tokens"] == psb.num_tokens + assert perf_dict["b"]["total_io"] == psb.total_io + assert perf_dict["b"]["total_flops"] == psb.total_flops + assert perf_dict["b"]["duration"] == psb.duration + # Test derived properties are present - perf_dict['a']['achieved_flops_per_s'] == psa.achieved_flops_per_s - perf_dict['a']['achieved_bandwidth'] == psa.achieved_bandwidth - perf_dict['b']['achieved_flops_per_s'] == psb.achieved_flops_per_s - perf_dict['b']['achieved_bandwidth'] == psb.achieved_bandwidth - + perf_dict["a"]["achieved_flops_per_s"] == psa.achieved_flops_per_s + perf_dict["a"]["achieved_bandwidth"] == psa.achieved_bandwidth + perf_dict["b"]["achieved_flops_per_s"] == psb.achieved_flops_per_s + perf_dict["b"]["achieved_bandwidth"] == psb.achieved_bandwidth + if device_spec is not None: - assert perf_dict['a']['device_flops_per_s'] == device_spec.flops_per_s - assert perf_dict['a']['device_bandwidth'] == device_spec.bandwidth - assert perf_dict['a']['theoretical_io_latency'] == psa.theoretical_io_latency - assert perf_dict['a']['theoretical_compute_latency'] == psa.theoretical_compute_latency - assert perf_dict['a']['bandwidth_utilization'] == psa.bandwidth_utilization - assert perf_dict['a']['flops_utilization'] == psa.flops_utilization - - assert perf_dict['b']['device_flops_per_s'] == device_spec.flops_per_s - assert perf_dict['b']['device_bandwidth'] == device_spec.bandwidth - assert perf_dict['b']['theoretical_io_latency'] == psb.theoretical_io_latency - assert perf_dict['b']['theoretical_compute_latency'] == psb.theoretical_compute_latency - assert perf_dict['b']['bandwidth_utilization'] == psb.bandwidth_utilization - assert perf_dict['b']['flops_utilization'] == psb.flops_utilization \ No newline at end of file + assert perf_dict["a"]["device_flops_per_s"] == device_spec.flops_per_s + assert perf_dict["a"]["device_bandwidth"] == device_spec.bandwidth + assert ( + perf_dict["a"]["theoretical_io_latency"] + == psa.theoretical_io_latency + ) + assert ( + perf_dict["a"]["theoretical_compute_latency"] + == psa.theoretical_compute_latency + ) + assert ( + perf_dict["a"]["bandwidth_utilization"] == psa.bandwidth_utilization + ) + assert perf_dict["a"]["flops_utilization"] == psa.flops_utilization + + assert perf_dict["b"]["device_flops_per_s"] == device_spec.flops_per_s + assert perf_dict["b"]["device_bandwidth"] == device_spec.bandwidth + assert ( + perf_dict["b"]["theoretical_io_latency"] + == psb.theoretical_io_latency + ) + assert ( + perf_dict["b"]["theoretical_compute_latency"] + == psb.theoretical_compute_latency + ) + assert ( + perf_dict["b"]["bandwidth_utilization"] == psb.bandwidth_utilization + ) + assert perf_dict["b"]["flops_utilization"] == psb.flops_utilization diff --git a/test/profiler/utils.py b/test/profiler/utils.py index bce565141..ac5808cfb 100644 --- a/test/profiler/utils.py +++ b/test/profiler/utils.py @@ -13,6 +13,7 @@ def patch_device(device_name): with patch("torch.cuda.get_device_name", return_value=device_name): yield + @dataclass(frozen=True) class PerfCounterTestConfig: name: str @@ -29,17 +30,21 @@ class 
PerfCounterTestConfig: def get_leaf_nodes(count_keys, module_name): return [k for k in count_keys if k.endswith(module_name)] + def qkv_proj_io_check(model_config, batch_size, seqlen, element_size): - input_size = batch_size * seqlen * model_config.hidden_size * element_size - weight_size = model_config.hidden_size * model_config.hidden_size * element_size + input_size = batch_size * seqlen * model_config.hidden_size * element_size + weight_size = model_config.hidden_size * model_config.hidden_size * element_size output_size = batch_size * seqlen * model_config.hidden_size * element_size return input_size + weight_size + output_size + + def attn_io_check(model_config, batch_size, seqlen, element_size): # queries, keys, values -> factor of 3 input_size = (batch_size * seqlen * model_config.hidden_size * 3) * element_size output_size = (batch_size * seqlen * model_config.hidden_size) * element_size - return input_size + output_size - + return input_size + output_size + + def ffn_io_check(model_config, batch_size, seqlen, element_size, module_name): assert module_name in ["up_proj", "gate_proj", "down_proj"] @@ -47,11 +52,15 @@ def ffn_io_check(model_config, batch_size, seqlen, element_size, module_name): input_size = batch_size * seqlen * model_config.intermediate_size * element_size else: input_size = batch_size * seqlen * model_config.hidden_size * element_size - weight_size = model_config.hidden_size * model_config.intermediate_size * element_size + weight_size = ( + model_config.hidden_size * model_config.intermediate_size * element_size + ) if module_name == "down_proj": output_size = batch_size * seqlen * model_config.hidden_size * element_size else: - output_size = batch_size * seqlen * model_config.intermediate_size * element_size + output_size = ( + batch_size * seqlen * model_config.intermediate_size * element_size + ) return input_size + weight_size + output_size @@ -69,10 +78,12 @@ class PerfStatsTestConfig: io_counts: dict device_bandwidth: Optional[float] = None device_flops_per_s: Optional[float] = None - + + def get_test_name(cls, num, params_dict): return f"{cls.__name__}_{num}_{params_dict['name']}" + @dataclass(frozen=True) class PerfCounterResult: name: str @@ -82,6 +93,7 @@ class PerfCounterResult: total_flops: float total_io: float + @dataclass class PerfCounterManagerTestConfig: name: str diff --git a/torchao/profiler/performance_counter.py b/torchao/profiler/performance_counter.py index 20733111a..52bb61599 100644 --- a/torchao/profiler/performance_counter.py +++ b/torchao/profiler/performance_counter.py @@ -20,21 +20,13 @@ aten = torch.ops.aten -# TODO: Quick hack to track issued warnings to prevent excessive output each time a field is missing. -# Implement a cleaner solution. -_issued_warnings = set() - -# Define a custom warning category class DeviceInfoMissing(UserWarning): pass -def warn_once(message): - global _issued_warnings - if message not in _issued_warnings: - warnings.warn(message, DeviceInfoMissing) - _issued_warnings.add(message) +# Prevent excessive output +warnings.simplefilter("once", DeviceInfoMissing) class PerformanceCounterMode(FlopCounterMode): @@ -172,7 +164,7 @@ class PerformanceTimer: display: bool depth (int): passed to `PerformanceCounterMode` if displaying and determines depth of module tree to display. **Note**: these attributes are primarily used for debugging when using the `PerformanceTimer` standalone. - The PerformanceCounterManager class is a higher-level API that should be used instead. 
+ The TransformerPerformanceCounter class is a higher-level API that should be used instead. """ @@ -366,7 +358,7 @@ def theoretical_io_latency(self): if self.device_bandwidth is not None: return self.total_io / self.device_bandwidth else: - warn_once( + warnings.warn( "Device bandwidth is not specified. Please specify the device bandwidth to enable io latency calculation" ) return None @@ -376,7 +368,7 @@ def theoretical_compute_latency(self): if self.device_flops_per_s is not None: return self.total_flops / self.device_flops_per_s else: - warn_once( + warnings.warn( "Device flops_per_s is not specified. Please specify the device throughput to enable compute latency calculation" ) return None @@ -386,7 +378,7 @@ def bandwidth_utilization(self): if self.device_bandwidth is not None: return self.achieved_bandwidth / self.device_bandwidth else: - warn_once( + warnings.warn( "Device bandwidth is not specified. Please specify the device bandwidth to enable bandwidth utilization calculation" ) return None @@ -396,7 +388,7 @@ def flops_utilization(self): if self.device_flops_per_s is not None: return self.achieved_flops_per_s / self.device_flops_per_s else: - warn_once( + warnings.warn( "Device flops_per_s is not specified. Please specify the device throughput to enable flops utilization calculation" ) return None @@ -436,7 +428,7 @@ def to_dict(self): return d -class PerformanceCounterManager: +class TransformerPerformanceCounter: """ Context manager-like class for tracking performance across multiple calls to a Transformer model. @@ -448,7 +440,7 @@ class PerformanceCounterManager: See `PerformanceStats` struct for description of tracked metrics. Example: - >>> manager = PerformanceCounterManager(device_spec=device_spec) + >>> manager = TransformerPerformanceCounter(device_spec=device_spec) >>> with manager.count(label="prefill", num_tokens=x.numel()): >>> out = model(encoded_prompt) >>> manager.print_summary(labels=["prefill"]) # prints recorded stats for "prefill" context From eb797c0877983c1210bbbfb96f3c0e936cded86a Mon Sep 17 00:00:00 2001 From: jeromeku Date: Wed, 10 Jul 2024 16:05:44 +0000 Subject: [PATCH 28/33] rename duration -> latency --- test/profiler/test_performance_counter.py | 46 +++++++++++------------ test/profiler/utils.py | 4 +- torchao/profiler/performance_counter.py | 44 ++++++++++++---------- 3 files changed, 50 insertions(+), 44 deletions(-) diff --git a/test/profiler/test_performance_counter.py b/test/profiler/test_performance_counter.py index c6b35cd27..2cd1a3358 100644 --- a/test/profiler/test_performance_counter.py +++ b/test/profiler/test_performance_counter.py @@ -200,7 +200,7 @@ def test_ffn(self): PerfStatsTestConfig( label="with_device", num_tokens=128, - duration=0.1, + latency=0.1, total_flops=123e9, total_io=123e6, flops_summary={"a": 234e12, "b": 345e9}, @@ -213,7 +213,7 @@ def test_ffn(self): PerfStatsTestConfig( label="no_device", num_tokens=128, - duration=0.1, + latency=0.1, total_flops=123e9, total_io=123e6, flops_summary={"a": 234e12, "b": 345e9}, @@ -230,16 +230,16 @@ def test_ffn(self): def test_performance_stats(cfg: PerfStatsTestConfig): stats = PerformanceStats(**asdict(cfg)) num_tokens = cfg.num_tokens - duration = cfg.duration + latency = cfg.latency total_flops = cfg.total_flops total_io = cfg.total_io device_bandwidth = cfg.device_bandwidth device_flops_per_s = cfg.device_flops_per_s # Test derived metrics - assert stats.token_throughput == num_tokens / duration - assert stats.achieved_bandwidth == total_io / duration - assert 
stats.achieved_flops_per_s == total_flops / duration + assert stats.token_throughput == num_tokens / latency + assert stats.achieved_bandwidth == total_io / latency + assert stats.achieved_flops_per_s == total_flops / latency if device_bandwidth is not None: assert ( stats.bandwidth_utilization == stats.achieved_bandwidth / device_bandwidth @@ -347,7 +347,7 @@ def setUpClass(cls): _ = torch.matmul(a, b) end = time.perf_counter() - duration = end - start + latency = end - start expected_flops = 2 * num_tokens * in_features * out_features expected_io = ( num_tokens * in_features @@ -357,7 +357,7 @@ def setUpClass(cls): expected["a"] = PerfCounterResult( name="a", - duration=duration, + latency=latency, flops=expected_flops, io=expected_io, total_flops=expected_flops, @@ -369,11 +369,11 @@ def setUpClass(cls): with cm.count("b", num_tokens=num_tokens): _ = torch.matmul(a, b) end = time.perf_counter() - duration = end - start + latency = end - start expected["b"] = PerfCounterResult( name="b", - duration=duration, + latency=latency, flops=expected_flops, io=expected_io, total_flops=cm.total_flops, @@ -390,15 +390,15 @@ def test_perf_stats_a(self): # Check captured performance stats psa: PerformanceStats = counts["a"] # Raw metrics - # Duration won't be exact since timing external to the profiler - assert abs(psa.duration - expected.duration) < 1e-1 # +/- 100ms + # Latency won't be exact since timing external to the profiler + assert abs(psa.latency - expected.latency) < 1e-1 # +/- 100ms assert psa.total_flops == expected.flops assert psa.total_io == expected.io # Derived metrics - assert psa.token_throughput == psa.num_tokens / psa.duration - assert psa.achieved_flops_per_s == psa.total_flops / psa.duration - assert psa.achieved_bandwidth == psa.total_io / psa.duration + assert psa.token_throughput == psa.num_tokens / psa.latency + assert psa.achieved_flops_per_s == psa.total_flops / psa.latency + assert psa.achieved_bandwidth == psa.total_io / psa.latency def test_perf_stats_b(self): cm: TransformerPerformanceCounter = self.cm @@ -407,7 +407,7 @@ def test_perf_stats_b(self): psa = cm.counts["a"] psb = cm.counts["b"] expected = self.expected["b"] - assert abs(psb.duration - expected.duration) < 1e-1 # +/- 100ms + assert abs(psb.latency - expected.latency) < 1e-1 # +/- 100ms assert psb.total_flops == expected.flops assert psb.total_io == expected.io @@ -417,7 +417,7 @@ def test_perf_stats_b(self): expected.total_flops == psa.total_flops + psb.total_flops == cm.total_flops ) assert expected.total_io == psa.total_io + psb.total_io == cm.total_io - assert cm.total_time == psa.duration + psb.duration + assert cm.total_time == psa.latency + psb.latency def test_stats_summary(self): cm: TransformerPerformanceCounter = self.cm @@ -430,17 +430,17 @@ def test_stats_summary(self): assert summary.num_tokens == psa.num_tokens + psb.num_tokens assert summary.total_io == psa.total_io + psb.total_io assert summary.total_flops == psa.total_flops + psb.total_flops - assert summary.duration == psa.duration + psb.duration + assert summary.latency == psa.latency + psb.latency # Derived stats expected_token_throughput = (psa.num_tokens + psb.num_tokens) / ( - psa.duration + psb.duration + psa.latency + psb.latency ) expected_io_throughput = (psa.total_io + psb.total_io) / ( - psa.duration + psb.duration + psa.latency + psb.latency ) expected_flops_throughput = (psa.total_flops + psb.total_flops) / ( - psa.duration + psb.duration + psa.latency + psb.latency ) assert abs(summary.token_throughput - 
expected_token_throughput) < FLOAT_TOL assert abs(summary.achieved_bandwidth - expected_io_throughput) < FLOAT_TOL @@ -485,12 +485,12 @@ def test_json(self): assert perf_dict["a"]["num_tokens"] == psa.num_tokens assert perf_dict["a"]["total_io"] == psa.total_io assert perf_dict["a"]["total_flops"] == psa.total_flops - assert perf_dict["a"]["duration"] == psa.duration + assert perf_dict["a"]["latency"] == psa.latency assert perf_dict["b"]["num_tokens"] == psb.num_tokens assert perf_dict["b"]["total_io"] == psb.total_io assert perf_dict["b"]["total_flops"] == psb.total_flops - assert perf_dict["b"]["duration"] == psb.duration + assert perf_dict["b"]["latency"] == psb.latency # Test derived properties are present perf_dict["a"]["achieved_flops_per_s"] == psa.achieved_flops_per_s diff --git a/test/profiler/utils.py b/test/profiler/utils.py index ac5808cfb..7b2b99980 100644 --- a/test/profiler/utils.py +++ b/test/profiler/utils.py @@ -69,7 +69,7 @@ def ffn_io_check(model_config, batch_size, seqlen, element_size, module_name): class PerfStatsTestConfig: label: str num_tokens: int - duration: float + latency: float total_flops: float total_io: float flops_summary: dict @@ -87,7 +87,7 @@ def get_test_name(cls, num, params_dict): @dataclass(frozen=True) class PerfCounterResult: name: str - duration: float + latency: float flops: float io: float total_flops: float diff --git a/torchao/profiler/performance_counter.py b/torchao/profiler/performance_counter.py index 52bb61599..c7f78e681 100644 --- a/torchao/profiler/performance_counter.py +++ b/torchao/profiler/performance_counter.py @@ -151,7 +151,7 @@ def _count_flops(self, func_packet, out, args, kwargs): class PerformanceTimer: """ - Context manager that records the duration, io, and flops of a torch operator / module. + Context manager that records the latency, io, and flops of a torch operator / module. Timing is done using `time.perf_counter` and can be overridden to use a different timer (see `CUDAPerformanceTimer`). @@ -182,14 +182,14 @@ def __enter__(self): def _print_exit_msg(self): gflops = round(self.total_flops / 1e9, self.precision) - ms = round(self.duration * 1e3, self.precision) + ms = round(self.latency * 1e3, self.precision) if self.display: - print(f"{self.name.upper()}: duration = {ms} ms, FLOPS = {gflops} GFLOPs") + print(f"{self.name.upper()}: latency = {ms} ms, FLOPS = {gflops} GFLOPs") def __exit__(self, type, value, traceback): self.end = time.perf_counter() # Convert to ms - self.duration = self.end - self.start + self.latency = self.end - self.start self.perf_counter.__exit__(type, value, traceback) if self.display: self._print_exit_msg() @@ -228,7 +228,7 @@ def get_pretty_summary(self, depth): class CUDAPerformanceTimer(PerformanceTimer): """ - `PerformanceTimer` that uses `cudaEvents` to record duration. + `PerformanceTimer` that uses `cudaEvents` to record latency. 
""" def __enter__(self): @@ -245,7 +245,7 @@ def __exit__(self, type, value, traceback): self.end.record() torch.cuda.synchronize() # Convert from ms to s - self.duration = self.start.elapsed_time(self.end) * 1e-3 + self.latency = self.start.elapsed_time(self.end) * 1e-3 self.perf_counter.__exit__(type, value, traceback) if self.display: @@ -309,7 +309,7 @@ class PerformanceStats(DictMixin): Attrs: num_tokens (int): number of tokens processed - duration (float): duration in seconds + latency (float): latency in seconds total_flops (int): total FLOPs total_io (int): total data movement in bytes flops_summary (Dict[str, int]): summary of FLOPs by module @@ -331,7 +331,7 @@ class PerformanceStats(DictMixin): label: str num_tokens: int - duration: float + latency: float total_flops: int total_io: int flops_summary: Dict[str, int] @@ -343,15 +343,15 @@ class PerformanceStats(DictMixin): @property def token_throughput(self): - return self.num_tokens / self.duration + return self.num_tokens / self.latency @property def achieved_flops_per_s(self): - return self.total_flops / self.duration + return self.total_flops / self.latency @property def achieved_bandwidth(self): - return self.total_io / self.duration + return self.total_io / self.latency @property def theoretical_io_latency(self): @@ -401,7 +401,7 @@ def _format(self, value, suffix, precision=2, round=True): def __str__(self): txt = textwrap.dedent(f"""\ {self.label}: - Duration = {self._format(self.duration, "s")} + Latency = {self._format(self.latency, "s")} Tokens Total: {self.num_tokens} tokens Throughput: {self.token_throughput:,.0f} tokens/s @@ -473,7 +473,7 @@ def count(self, label: str, num_tokens: int): stats = PerformanceStats( label=label, num_tokens=num_tokens, - duration=perf_timer.duration, + latency=perf_timer.latency, total_flops=perf_timer.total_flops, total_io=perf_timer.total_io, flops_summary=perf_timer.get_summary_flop_counts(), @@ -510,7 +510,7 @@ def total_tokens(self): @property def total_time(self): - return sum(count.duration for count in self._counts.values()) + return sum(count.latency for count in self._counts.values()) def _summarize_stat(self, key): return { @@ -538,7 +538,7 @@ def stats_summary(self): stats = PerformanceStats( label="Performance Summary", num_tokens=self.total_tokens, - duration=self.total_time, + latency=self.total_time, total_flops=self.total_flops, total_io=self.total_io, flops_summary=self.flops_summary, @@ -555,17 +555,23 @@ def stats_summary(self): return stats - def print_summary(self, labels: list[str] = None): + def print_summary(self, labels: list[str] = None, show: bool=False): _print = partial(print, flush=True, end="\n") # Delegate to __str__ of PerformanceStats for pretty printing if labels is None: text = str(self.stats_summary) - _print(text) + if show: + _print(text) + return text else: + txts = [] for label in labels: text = str(self._counts[label]) - _print(self._counts[label]) - + if show: + _print(text) + txts.append(text) + return '\n'.join(txts) + def to_dict(self): # Convert flop_counts from OpOverloadPackets to str # Then delegate to PerformanceStats `to_dict`, which updates with derived metrics (property methods) From 7b578acd2cf165933f9b850eaf3bd22ef9156a61 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Wed, 10 Jul 2024 16:21:40 +0000 Subject: [PATCH 29/33] add gpt-fast example --- torchao/profiler/__init__.py | 51 ++++ torchao/profiler/performance_counter.py | 12 +- tutorials/profiler/README.md | 102 +++++++ tutorials/profiler/generate.py | 336 
++++++++++++++++++++++++ tutorials/profiler/model.py | 257 ++++++++++++++++++ tutorials/profiler/tokenizer.py | 112 ++++++++ 6 files changed, 864 insertions(+), 6 deletions(-) create mode 100644 torchao/profiler/__init__.py create mode 100644 tutorials/profiler/README.md create mode 100644 tutorials/profiler/generate.py create mode 100644 tutorials/profiler/model.py create mode 100644 tutorials/profiler/tokenizer.py diff --git a/torchao/profiler/__init__.py b/torchao/profiler/__init__.py new file mode 100644 index 000000000..6de242f42 --- /dev/null +++ b/torchao/profiler/__init__.py @@ -0,0 +1,51 @@ +import inspect + +import torch + +# Re-exports +from .device_spec import CUDADeviceSpec, DeviceSpec +from .performance_counter import ( + CUDAPerformanceTimer, + PerformanceCounterMode, + PerformanceStats, + PerformanceTimer, + TransformerPerformanceCounter, +) + +_HUGGINGFACE_CAUSAL_LM_BASE_CLASSES = [ + "causallm", + "pretrainedmodel", + "generationmixin", +] + + +def get_all_base_classes(object): + return [cls.__name__.lower() for cls in inspect.getmro(object.__class__)] + + +def total_model_params( + model: torch.nn.Module, + exclude_embeddings: bool = True, + embedding_key: str = "tok_embeddings", +) -> int: + num_params = sum(p.numel() for p in model.parameters()) + + # Exclude embeddings when calculating FLOP since they don't contribute to FLOP count + if exclude_embeddings: + # Not the cleanest, but check if any base class of the model is in _HUGGINGFACE_CAUSAL_LM_BASE_CLASSES + if ( + len( + set(get_all_base_classes(model)).intersection( + _HUGGINGFACE_CAUSAL_LM_BASE_CLASSES + ) + ) + > 0 + ): + num_params -= model.model.embed_tokens.weight.numel() + elif hasattr(model, embedding_key): + num_params -= getattr(model, embedding_key).weight.numel() + else: + raise ValueError( + f"Could not find embedding in model {model.__class__.__name__}, please specify embedding attribute key" + ) + return num_params diff --git a/torchao/profiler/performance_counter.py b/torchao/profiler/performance_counter.py index c7f78e681..d79625d55 100644 --- a/torchao/profiler/performance_counter.py +++ b/torchao/profiler/performance_counter.py @@ -555,23 +555,23 @@ def stats_summary(self): return stats - def print_summary(self, labels: list[str] = None, show: bool=False): + def print_summary(self, labels: list[str] = None, show: bool = False): _print = partial(print, flush=True, end="\n") # Delegate to __str__ of PerformanceStats for pretty printing if labels is None: text = str(self.stats_summary) - if show: - _print(text) + if show: + _print(text) return text else: txts = [] for label in labels: text = str(self._counts[label]) - if show: + if show: _print(text) txts.append(text) - return '\n'.join(txts) - + return "\n".join(txts) + def to_dict(self): # Convert flop_counts from OpOverloadPackets to str # Then delegate to PerformanceStats `to_dict`, which updates with derived metrics (property methods) diff --git a/tutorials/profiler/README.md b/tutorials/profiler/README.md new file mode 100644 index 000000000..6ede6b8e2 --- /dev/null +++ b/tutorials/profiler/README.md @@ -0,0 +1,102 @@ + +## Performance Profiling Example + +An minimal reproduction of `gpt-fast` that demonstrates usage of `torchao.profiler.TransformerPerformanceCounter`. 
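As a rough orientation before the CLI walkthrough, a minimal sketch of the same counter API used directly (the `nn.Linear` stand-in and tensor shapes are illustrative only, not part of the tutorial script):

```python
import torch

from torchao.profiler import CUDADeviceSpec, TransformerPerformanceCounter

# Stand-in module; any CUDA nn.Module forward pass can run under the counter.
model = torch.nn.Linear(4096, 4096, dtype=torch.bfloat16, device="cuda")
x = torch.randn(1, 6, 4096, dtype=torch.bfloat16, device="cuda")

device_spec = CUDADeviceSpec(dtype=torch.bfloat16)  # probes the local GPU for bandwidth / peak FLOPs
perf_counter = TransformerPerformanceCounter(device_spec=device_spec)

with perf_counter.count(label="prefill", num_tokens=x.shape[1]):
    _ = model(x)  # FLOPs, bytes moved, and wall-clock time are recorded for this block

print(perf_counter.print_summary(labels=["prefill"], show=False))
perf_counter.to_json("performance_stats.json")  # accumulated stats across all recorded labels
```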
+ +## Usage +```python +python generate.py --prompt "Hello my name is" --checkpoint_path path/to/model.pth --num_samples 1 --max_new_tokens 2 --save_path performance_stats.json +``` +where `checkpoint_path` is the checkpoint path of the converted model weights per `gpt-fast` and `save_path` specifies where to save performance stats. + + +Running the above command for `llama2-7b` should print the following, with accumulated stats saved to `performance_stats.json` + +``` +Loading model ... +Time to load model: 20.14 seconds + +============================== + +Using DeviceSpec(device_type=cuda, name=NVIDIA GeForce RTX 3090, dtype=torch.bfloat16, bandwidth=936.1GB/s, flops=35.6TFLOPs, vram=25.4GB) +Model Config: ModelArgs(block_size=2048, vocab_size=32000, n_layer=32, n_head=32, dim=4096, intermediate_size=11008, n_local_heads=32, head_dim=128, rope_base=10000, norm_eps=1e-05) +Active params, Total Params: 6607343616, 6738415616 + +============================== + +TransformerPerfCounter Metrics +PREFILL_SEQLEN-6: + Latency = 1.26 s + Tokens + Total: 6 tokens + Throughput: 5 tokens/s + IO + Total: 13.25 GB + Throughput: 10.54 GB/s + Theoretical Latency: 14.15 ms + FLOPs + Total: 79.31 GFLOPs + Throughput: 63.06 GFLOPs/s + Theoretical Latency: 2.23 ms + Utilization + Bandwidth: 0.0113 % + FLOPs: 0.0018 % + +============================== + +TransformerPerfCounter Metrics +DECODE_CTX-6_NUM_TOKS-1: + Latency = 0.16 s + Tokens + Total: 1 tokens + Throughput: 6 tokens/s + IO + Total: 13.22 GB + Throughput: 83.27 GB/s + Theoretical Latency: 14.13 ms + FLOPs + Total: 13.22 GFLOPs + Throughput: 83.24 GFLOPs/s + Theoretical Latency: 0.37 ms + Utilization + Bandwidth: 0.0890 % + FLOPs: 0.0023 % + +============================== + +Generated text for sample 0: Hello, my name is [Name + +GPTFast Sample Metrics + Time for inference 1: 6 prompt tokens 2 tokens generated, 1.57 sec total, 1.28 tokens/sec + Bandwidth achieved: 17.22 GB/s + +============================== + +GPTFast Aggregate Stats + Average tokens/sec: 1.28 + Memory used: 13.51 GB + +============================== + +TransformerPerfCounter +Performance Summary: + Latency = 1.42 s + Tokens + Total: 7 tokens + Throughput: 5 tokens/s + IO + Total: 26.47 GB + Throughput: 18.69 GB/s + Theoretical Latency: 28.28 ms + FLOPs + Total: 92.53 GFLOPs + Throughput: 65.33 GFLOPs/s + Theoretical Latency: 2.60 ms + Utilization + Bandwidth: 0.0200 % + FLOPs: 0.0018 % + +Saving performance results to performance_stats.json +``` + +**Note**: `generate.py` script is a stripped down version of the original `gpt-fast` script and currently does not support quantization, tensor parallelism, and speculative decoding, as the primary purpose is to demonstrate basic usage of the performance tracker. 
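The JSON written to `--save_path` is keyed by the labels recorded during generation, each entry carrying the raw counts (`num_tokens`, `latency`, `total_flops`, `total_io`) plus the derived throughput and utilization fields, mirroring the `to_json` round trip exercised by the tests earlier in this patch series. A small post-processing sketch (the label strings are illustrative; they depend on the prompt length and decode steps of the run):

```python
import json

with open("performance_stats.json") as f:
    stats = json.load(f)

# e.g. labels such as "PREFILL_SEQLEN-6" and "DECODE_CTX-6_NUM_TOKS-1"
for label, s in stats.items():
    gb_per_s = s["total_io"] / s["latency"] / 1e9     # achieved bandwidth
    tflops = s["total_flops"] / s["latency"] / 1e12   # achieved compute throughput
    print(f"{label}: {s['num_tokens']} tokens in {s['latency']:.2f} s, "
          f"{gb_per_s:.1f} GB/s, {tflops:.2f} TFLOP/s")
```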
\ No newline at end of file diff --git a/tutorials/profiler/generate.py b/tutorials/profiler/generate.py new file mode 100644 index 000000000..20e672a3a --- /dev/null +++ b/tutorials/profiler/generate.py @@ -0,0 +1,336 @@ +import pytest + +# Skip if transformers is not installed +transformers = pytest.importorskip("transformers") +LlamaConfig = transformers.models.llama.modeling_llama.LlamaConfig +LlamaForCausalLM = transformers.models.llama.modeling_llama.LlamaForCausalLM +# import sys +import textwrap +import time +from pathlib import Path +from typing import Optional, Tuple, Union + +import torch +from model import Transformer +from tokenizer import get_tokenizer +from torch.nn.attention import SDPBackend + +from torchao.profiler import ( + CUDADeviceSpec, + TransformerPerformanceCounter, + total_model_params, +) + +DEVICE_SPEC: CUDADeviceSpec +PERF_COUNTER: TransformerPerformanceCounter +PERF_COUNTER_PREFIX = "TransformerPerfCounter" +GPT_FAST_PREFIX = "GPTFast" +DELIMITER = "\n" + "=" * 30 + "\n" + + +def device_sync(device): + if "cuda" in device: + torch.cuda.synchronize(device) + elif ("cpu" in device) or ("mps" in device): + pass + else: + print(f"device={device} is not yet supported") + + +def multinomial_sample_one_no_sync( + probs_sort, +): # Does multinomial sampling without a cuda synchronization + q = torch.empty_like(probs_sort).exponential_(1) + return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) + + +def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None): + logits = logits / max(temperature, 1e-5) + + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + pivot = v.select(-1, -1).unsqueeze(-1) + logits = torch.where(logits < pivot, -float("Inf"), logits) + probs = torch.nn.functional.softmax(logits, dim=-1) + return probs + + +def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None): + probs = logits_to_probs(logits[0, -1], temperature, top_k) + idx_next = multinomial_sample_one_no_sync(probs) + return idx_next, probs + + +def prefill( + model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs +) -> torch.Tensor: + # input_pos: [B, S] + seqlen = input_pos.shape[-1] + num_tokens = input_pos.numel() + assert num_tokens == seqlen + + step_name = f"prefill_seqlen-{seqlen}".upper() + with PERF_COUNTER.count(step_name, num_tokens=num_tokens): + logits = model(x, input_pos) + next_token = sample(logits, **sampling_kwargs)[0] + print(DELIMITER) + stats_str = PERF_COUNTER.print_summary(labels=[step_name], show=False) + print(f"{PERF_COUNTER_PREFIX} Metrics\n{stats_str}") + + return next_token + + +def decode_one_token( + model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs +) -> Tuple[torch.Tensor, torch.Tensor]: + # input_pos: [B, 1] + context_len = input_pos[-1].item() + num_tokens = input_pos.numel() + assert input_pos.shape[-1] == 1 + assert num_tokens == 1 + + step_name = f"decode_ctx-{context_len}_num_toks-{num_tokens}".upper() + with PERF_COUNTER.count(step_name, num_tokens=num_tokens): + logits = model(x, input_pos) + next_token = sample(logits, **sampling_kwargs) + print(DELIMITER) + stats_str = PERF_COUNTER.print_summary(labels=[step_name], show=False) + print(f"{PERF_COUNTER_PREFIX} Metrics\n{stats_str}") + + return next_token + + +def decode_n_tokens( + model: Transformer, + cur_token: torch.Tensor, + input_pos: torch.Tensor, + num_new_tokens: int, + callback=lambda _: _, + **sampling_kwargs, +): + new_tokens, new_probs = 
[], [] + for i in range(num_new_tokens): + with torch.nn.attention.sdpa_kernel( + backends=[SDPBackend.FLASH_ATTENTION, SDPBackend.MATH] + ): # Actually better for Inductor to codegen attention here + next_token, next_prob = decode_one_token( + model, cur_token, input_pos, **sampling_kwargs + ) + input_pos += 1 + new_tokens.append(next_token.clone()) + callback(new_tokens[-1]) + new_probs.append(next_prob.clone()) + cur_token = next_token.view(1, -1) + + return new_tokens, new_probs + + +def model_forward(model, x, input_pos): + return model(x, input_pos) + + +@torch.no_grad() +def generate( + model: Transformer, + prompt: torch.Tensor, + max_new_tokens: int, + *, + callback=lambda x: x, + **sampling_kwargs, +) -> torch.Tensor: + """ + Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. + """ + # create an empty tensor of the expected final shape and fill in the current tokens + T = prompt.size(0) + T_new = T + max_new_tokens + max_seq_length = min(T_new, model.config.block_size) + + device, dtype = prompt.device, prompt.dtype + with torch.device(device): + model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) + + # create an empty tensor of the expected final shape and fill in the current tokens + empty = torch.empty(T_new, dtype=dtype, device=device) + empty[:T] = prompt + seq = empty + input_pos = torch.arange(0, T, device=device) + + next_token = prefill( + model, prompt.view(1, -1), input_pos, **sampling_kwargs + ).clone() + seq[T] = next_token + + input_pos = torch.tensor([T], device=device, dtype=torch.int) + + generated_tokens, _ = decode_n_tokens( + model, + next_token.view(1, -1), + input_pos, + max_new_tokens - 1, + callback=callback, + **sampling_kwargs, + ) + seq[T + 1 :] = torch.cat(generated_tokens) + + return seq + + +def encode_tokens(tokenizer, string, bos=True, device="cuda"): + tokens = tokenizer.encode(string) + if bos: + tokens = [tokenizer.bos_id()] + tokens + return torch.tensor(tokens, dtype=torch.int, device=device) + + +def _load_model(checkpoint_path, device, precision): + with torch.device("meta"): + model = Transformer.from_name(checkpoint_path.parent.name) + + checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) + if "model" in checkpoint and "stories" in str(checkpoint_path): + checkpoint = checkpoint["model"] + model.load_state_dict(checkpoint, assign=True) + + model = model.to(device=device, dtype=precision) + return model.eval() + + +def main( + prompt: str, + num_samples: int, + max_new_tokens: int, + top_k: int, + temperature: float, + checkpoint_path: Union[Path, str], + save_path: Union[Path, str], + device: str = "cuda", + precision: torch.dtype = torch.bfloat16, +) -> None: + """Generates text samples based on a pre-trained Transformer model and tokenizer.""" + assert checkpoint_path.is_file(), checkpoint_path + + tokenizer_path = checkpoint_path.parent / "tokenizer.model" + assert tokenizer_path.is_file(), str(tokenizer_path) + + print("Loading model ...") + t0 = time.time() + model = _load_model(checkpoint_path, device, precision) + + device_sync(device=device) # MKG + print(f"Time to load model: {time.time() - t0:.02f} seconds") + + global DEVICE_SPEC + global PERF_COUNTER + + DEVICE_SPEC = CUDADeviceSpec(dtype=precision) + PERF_COUNTER = TransformerPerformanceCounter(depth=3, device_spec=DEVICE_SPEC) + print(DELIMITER) + print(f"Using {DEVICE_SPEC}") + print(f"Model Config: {model.config}") + + num_active_params = total_model_params(model, 
exclude_embeddings=True) + num_params = total_model_params(model, exclude_embeddings=False) + model_size = num_params * precision.itemsize + print(f"Active params, Total Params: {num_active_params}, {num_params}") + + tokenizer = get_tokenizer(tokenizer_path, checkpoint_path) + + encoded = encode_tokens(tokenizer, prompt, bos=True, device=device) + prompt_length = encoded.size(0) + + torch.manual_seed(1234) + + aggregate_metrics = { + "tokens_per_sec": [], + } + + start = 0 + + for i in range(start, num_samples): + t0 = time.perf_counter() + + y = generate( + model, + encoded, + max_new_tokens, + temperature=temperature, + top_k=top_k, + ) + + t = time.perf_counter() - t0 + txt = tokenizer.decode(y.tolist()) + print(DELIMITER) + print(f"Generated text for sample {i}: {txt}\n") + + tokens_generated = y.size(0) - prompt_length + tokens_sec = tokens_generated / t + sample_metrics = textwrap.dedent(f"""\ + {GPT_FAST_PREFIX} Sample Metrics + Time for inference {i+1}: {prompt_length} prompt tokens {tokens_generated} tokens generated, {t:.02f} sec total, {tokens_sec:.02f} tokens/sec + Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s""") + print( + textwrap.indent( + sample_metrics, + prefix=" ", + predicate=lambda line: not line.startswith(GPT_FAST_PREFIX), + ) + ) + aggregate_metrics["tokens_per_sec"].append(tokens_sec) + + # First print aggregate stats from original gpt-fast script + print(DELIMITER) + gpt_stats = textwrap.dedent(f"""\ + {GPT_FAST_PREFIX} Aggregate Stats + Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f} + Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB""") + + print( + textwrap.indent( + gpt_stats, + prefix=" ", + predicate=lambda line: not line.startswith(GPT_FAST_PREFIX), + ) + ) + + # Print performance summary from TransformerPerformanceCounter + print(DELIMITER) + total_stats_str = PERF_COUNTER.print_summary(show=False) + print(f"{PERF_COUNTER_PREFIX}\n{total_stats_str}") + print(f"\nSaving performance results to {save_path}") + PERF_COUNTER.to_json(save_path) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="TransformerPerformanceCounter Example", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument( + "--prompt", type=str, default="Hello, my name is", help="Input prompt." + ) + parser.add_argument("--num_samples", type=int, default=1, help="Number of samples.") + parser.add_argument( + "--max_new_tokens", type=int, default=2, help="Maximum number of new tokens." + ) + parser.add_argument("--top_k", type=int, default=200, help="Top-k for sampling.") + parser.add_argument( + "--temperature", type=float, default=0.8, help="Temperature for sampling." + ) + parser.add_argument( + "--checkpoint_path", + type=Path, + default=Path("/home/ubuntu/gpt-fast-dev/checkpoints/7B/model.pth"), + help="Model checkpoint path.", + ) + parser.add_argument( + "--save_path", + type=Path, + default=Path("performance_stats.json"), + help="Path to save performance stats.", + ) + args = parser.parse_args() + main(**vars(args)) diff --git a/tutorials/profiler/model.py b/tutorials/profiler/model.py new file mode 100644 index 000000000..b89a19a0f --- /dev/null +++ b/tutorials/profiler/model.py @@ -0,0 +1,257 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
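The `model.py` added next is a trimmed copy of the gpt-fast Llama definition. As a quick sanity check of the `ModelArgs.__post_init__` logic it contains, the `"7B"` entry (`dim=4096`, `intermediate_size` left unset) works out to the `intermediate_size=11008` and `head_dim=128` echoed in the sample output above; a sketch of the arithmetic:

```python
# Reproduces ModelArgs.__post_init__ for the "7B" config (dim=4096, n_head=32).
dim, n_head = 4096, 32
hidden_dim = 4 * dim                # 16384
n_hidden = int(2 * hidden_dim / 3)  # 10922
# find_multiple(n_hidden, 256): round up to the next multiple of 256 (unchanged if already a multiple)
intermediate_size = n_hidden if n_hidden % 256 == 0 else n_hidden + 256 - (n_hidden % 256)
head_dim = dim // n_head            # 128
print(intermediate_size, head_dim)  # 11008 128
```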
+from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn import functional as F + + +def find_multiple(n: int, k: int) -> int: + if n % k == 0: + return n + return n + k - (n % k) + +@dataclass +class ModelArgs: + block_size: int = 2048 + vocab_size: int = 32000 + n_layer: int = 32 + n_head: int = 32 + dim: int = 4096 + intermediate_size: int = None + n_local_heads: int = -1 + head_dim: int = 64 + rope_base: float = 10000 + norm_eps: float = 1e-5 + + def __post_init__(self): + if self.n_local_heads == -1: + self.n_local_heads = self.n_head + if self.intermediate_size is None: + hidden_dim = 4 * self.dim + n_hidden = int(2 * hidden_dim / 3) + self.intermediate_size = find_multiple(n_hidden, 256) + self.head_dim = self.dim // self.n_head + + @classmethod + def from_name(cls, name: str): + if name in transformer_configs: + return cls(**transformer_configs[name]) + # fuzzy search + config = [config for config in transformer_configs if config.lower() in str(name).lower()] + + # We may have two or more configs matched (e.g. "7B" and "Mistral-7B"). Find the best config match, + # take longer name (as it have more symbols matched) + if len(config) > 1: + config.sort(key=len, reverse=True) + assert len(config[0]) != len(config[1]), name # make sure only one 'best' match + + return cls(**transformer_configs[config[0]]) + + +transformer_configs = { + "CodeLlama-7b-Python-hf": dict(block_size=16384, vocab_size=32000, n_layer=32, dim = 4096, rope_base=1000000), + "7B": dict(n_layer=32, n_head=32, dim=4096), + "13B": dict(n_layer=40, n_head=40, dim=5120), + "30B": dict(n_layer=60, n_head=52, dim=6656), + "34B": dict(n_layer=48, n_head=64, dim=8192, vocab_size=32000, n_local_heads=8, intermediate_size=22016, rope_base=1000000), # CodeLlama-34B-Python-hf + "70B": dict(n_layer=80, n_head=64, dim=8192, n_local_heads=8, intermediate_size=28672), + "Mistral-7B": dict(n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=32000), + "stories15M": dict(n_layer=6, n_head=6, dim=288), + "stories110M": dict(n_layer=12, n_head=12, dim=768), + + "llama-3-8b": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000), + "llama-3-70b": dict(block_size=8192, n_layer=80, n_head=64, n_local_heads=8, dim=8192, intermediate_size=28672, vocab_size=128256, rope_base=500000), +} + +class KVCache(nn.Module): + def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16): + super().__init__() + cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) + self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype)) + self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype)) + + def update(self, input_pos, k_val, v_val): + # input_pos: [S], k_val: [B, H, S, D] + assert input_pos.shape[0] == k_val.shape[2] + + k_out = self.k_cache + v_out = self.v_cache + k_out[:, :, input_pos] = k_val + v_out[:, :, input_pos] = v_val + + return k_out, v_out + +class Transformer(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.config = config + + self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim) + self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer)) + self.norm = RMSNorm(config.dim, eps=config.norm_eps) + self.output = nn.Linear(config.dim, config.vocab_size, bias=False) + + self.freqs_cis: Optional[Tensor] = 
None + self.mask_cache: Optional[Tensor] = None + self.max_batch_size = -1 + self.max_seq_length = -1 + + def setup_caches(self, max_batch_size, max_seq_length): + if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size: + return + head_dim = self.config.dim // self.config.n_head + max_seq_length = find_multiple(max_seq_length, 8) + self.max_seq_length = max_seq_length + self.max_batch_size = max_batch_size + dtype = self.output.weight.dtype + # For quantized layers, dtype is encoded in scales + if hasattr(self.output, "scales"): + dtype = self.output.scales.dtype + elif hasattr(self.output, "scales_and_zeros"): + dtype = self.output.scales_and_zeros.dtype + for b in self.layers: + b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype) + + self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype) + self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)) + + def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: + assert self.freqs_cis is not None, "Caches must be initialized first" + mask = self.causal_mask[None, None, input_pos] + freqs_cis = self.freqs_cis[input_pos] + x = self.tok_embeddings(idx) + + for i, layer in enumerate(self.layers): + x = layer(x, input_pos, freqs_cis, mask) + x = self.norm(x) + logits = self.output(x) + return logits + + @classmethod + def from_name(cls, name: str): + return cls(ModelArgs.from_name(name)) + + +class TransformerBlock(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.attention = Attention(config) + self.feed_forward = FeedForward(config) + self.ffn_norm = RMSNorm(config.dim, config.norm_eps) + self.attention_norm = RMSNorm(config.dim, config.norm_eps) + + def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor: + h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos) + out = h + self.feed_forward(self.ffn_norm(h)) + return out + + +class Attention(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + assert config.dim % config.n_head == 0 + + total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim + # key, query, value projections for all heads, but in a batch + self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False) + self.wo = nn.Linear(config.dim, config.dim, bias=False) + self.kv_cache = None + + self.n_head = config.n_head + self.head_dim = config.head_dim + self.n_local_heads = config.n_local_heads + self.dim = config.dim + self._register_load_state_dict_pre_hook(self.load_hook) + + def load_hook(self, state_dict, prefix, *args): + if prefix + "wq.weight" in state_dict: + wq = state_dict.pop(prefix + "wq.weight") + wk = state_dict.pop(prefix + "wk.weight") + wv = state_dict.pop(prefix + "wv.weight") + state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) + + def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: + bsz, seqlen, _ = x.shape + + kv_size = self.n_local_heads * self.head_dim + q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) + + q = q.view(bsz, seqlen, self.n_head, self.head_dim) + k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim) + v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim) + + q = apply_rotary_emb(q, freqs_cis) + k = apply_rotary_emb(k, freqs_cis) + + q, k, v = 
map(lambda x: x.transpose(1, 2), (q, k, v)) + + if self.kv_cache is not None: + k, v = self.kv_cache.update(input_pos, k, v) + + k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) + + y = self.wo(y) + return y + + +class FeedForward(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False) + self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False) + self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False) + + def forward(self, x: Tensor) -> Tensor: + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) + + def forward(self, x: Tensor) -> Tensor: + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +def precompute_freqs_cis( + seq_len: int, n_elem: int, base: int = 10000, + dtype: torch.dtype = torch.bfloat16 +) -> Tensor: + freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)) + t = torch.arange(seq_len, device=freqs.device) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) + return cache.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: + xshaped = x.float().reshape(*x.shape[:-1], -1, 2) + freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1], + xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1], + ], + -1, + ) + + x_out2 = x_out2.flatten(3) + return x_out2.type_as(x) diff --git a/tutorials/profiler/tokenizer.py b/tutorials/profiler/tokenizer.py new file mode 100644 index 000000000..f60b3c13e --- /dev/null +++ b/tutorials/profiler/tokenizer.py @@ -0,0 +1,112 @@ +import os +import sentencepiece as spm +import tiktoken +from tiktoken.load import load_tiktoken_bpe +from pathlib import Path +from typing import Dict + +class TokenizerInterface: + def __init__(self, model_path): + self.model_path = model_path + + def encode(self, text): + raise NotImplementedError("This method should be overridden by subclasses.") + + def decode(self, tokens): + raise NotImplementedError("This method should be overridden by subclasses.") + + def bos_id(self): + raise NotImplementedError("This method should be overridden by subclasses.") + + def eos_id(self): + raise NotImplementedError("This method should be overridden by subclasses.") + +class SentencePieceWrapper(TokenizerInterface): + def __init__(self, model_path): + super().__init__(model_path) + self.processor = spm.SentencePieceProcessor(str(model_path)) + + def encode(self, text): + return self.processor.EncodeAsIds(text) + + def decode(self, tokens): + return self.processor.DecodeIds(tokens) + + def bos_id(self): + return self.processor.bos_id() + + def eos_id(self): + return self.processor.eos_id() + +class TiktokenWrapper(TokenizerInterface): + """ + Tokenizing and encoding/decoding text using the Tiktoken tokenizer. 
+ """ + + special_tokens: Dict[str, int] + + num_reserved_special_tokens = 256 + + pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501 + + def __init__(self, model_path): + super().__init__(model_path) + assert os.path.isfile(model_path), str(model_path) + mergeable_ranks = load_tiktoken_bpe(str(model_path)) + num_base_tokens = len(mergeable_ranks) + special_tokens = [ + "<|begin_of_text|>", + "<|end_of_text|>", + "<|reserved_special_token_0|>", + "<|reserved_special_token_1|>", + "<|reserved_special_token_2|>", + "<|reserved_special_token_3|>", + "<|start_header_id|>", + "<|end_header_id|>", + "<|reserved_special_token_4|>", + "<|eot_id|>", # end of turn + ] + [ + f"<|reserved_special_token_{i}|>" + for i in range(5, self.num_reserved_special_tokens - 5) + ] + self.special_tokens = { + token: num_base_tokens + i for i, token in enumerate(special_tokens) + } + self.model = tiktoken.Encoding( + name=Path(model_path).name, + pat_str=self.pat_str, + mergeable_ranks=mergeable_ranks, + special_tokens=self.special_tokens, + ) + # BOS / EOS token IDs + self._bos_id: int = self.special_tokens["<|begin_of_text|>"] + self._eos_id: int = self.special_tokens["<|end_of_text|>"] + + def encode(self, text): + return self.model.encode(text) + + def decode(self, tokens): + return self.model.decode(tokens) + + def bos_id(self): + return self._bos_id + + def eos_id(self): + return self._eos_id + +def get_tokenizer(tokenizer_model_path, model_name): + """ + Factory function to get the appropriate tokenizer based on the model name. + + Args: + - tokenizer_model_path (str): The file path to the tokenizer model. + - model_name (str): The name of the model, used to determine the tokenizer type. + + Returns: + - TokenizerInterface: An instance of a tokenizer. 
+ """ + + if "llama-3" in str(model_name).lower(): + return TiktokenWrapper(tokenizer_model_path) + else: + return SentencePieceWrapper(tokenizer_model_path) From 842215dcfb09f3205921ca38fccfe6f44730b562 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Wed, 10 Jul 2024 18:31:14 +0000 Subject: [PATCH 30/33] linting and formatting --- test/profiler/test_device_spec.py | 2 -- torchao/profiler/__init__.py | 11 +++++++++++ tutorials/profiler/generate.py | 5 ++++- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/test/profiler/test_device_spec.py b/test/profiler/test_device_spec.py index b1166f3ad..1ede428fe 100644 --- a/test/profiler/test_device_spec.py +++ b/test/profiler/test_device_spec.py @@ -4,8 +4,6 @@ "triton.runtime.driver", reason="requires triton cuda driver module" ) import itertools -from contextlib import contextmanager -from unittest.mock import patch import torch from utils import patch_device diff --git a/torchao/profiler/__init__.py b/torchao/profiler/__init__.py index 6de242f42..45e0322de 100644 --- a/torchao/profiler/__init__.py +++ b/torchao/profiler/__init__.py @@ -12,6 +12,17 @@ TransformerPerformanceCounter, ) +__all__ = [ + "CUDAPerformanceTimer", + "PerformanceCounterMode", + "PerformanceStats", + "PerformanceTimer", + "TransformerPerformanceCounter", + "CUDADeviceSpec", + "DeviceSpec", + "total_model_params", +] + _HUGGINGFACE_CAUSAL_LM_BASE_CLASSES = [ "causallm", "pretrainedmodel", diff --git a/tutorials/profiler/generate.py b/tutorials/profiler/generate.py index 20e672a3a..28cd317fb 100644 --- a/tutorials/profiler/generate.py +++ b/tutorials/profiler/generate.py @@ -213,6 +213,7 @@ def main( tokenizer_path = checkpoint_path.parent / "tokenizer.model" assert tokenizer_path.is_file(), str(tokenizer_path) + print(f"{GPT_FAST_PREFIX}") print("Loading model ...") t0 = time.time() model = _load_model(checkpoint_path, device, precision) @@ -226,6 +227,7 @@ def main( DEVICE_SPEC = CUDADeviceSpec(dtype=precision) PERF_COUNTER = TransformerPerformanceCounter(depth=3, device_spec=DEVICE_SPEC) print(DELIMITER) + print(f"{PERF_COUNTER_PREFIX}") print(f"Using {DEVICE_SPEC}") print(f"Model Config: {model.config}") @@ -261,6 +263,7 @@ def main( t = time.perf_counter() - t0 txt = tokenizer.decode(y.tolist()) print(DELIMITER) + print(f"{GPT_FAST_PREFIX}") print(f"Generated text for sample {i}: {txt}\n") tokens_generated = y.size(0) - prompt_length @@ -323,7 +326,7 @@ def main( parser.add_argument( "--checkpoint_path", type=Path, - default=Path("/home/ubuntu/gpt-fast-dev/checkpoints/7B/model.pth"), + default=Path("./checkpoints/7B/model.pth"), help="Model checkpoint path.", ) parser.add_argument( From 4968ddf5bedd57532a49677c231ed1c3118ba213 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Wed, 10 Jul 2024 19:01:46 +0000 Subject: [PATCH 31/33] update profiler tutorial readme --- tutorials/profiler/README.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tutorials/profiler/README.md b/tutorials/profiler/README.md index 6ede6b8e2..668325f7c 100644 --- a/tutorials/profiler/README.md +++ b/tutorials/profiler/README.md @@ -99,4 +99,7 @@ Performance Summary: Saving performance results to performance_stats.json ``` -**Note**: `generate.py` script is a stripped down version of the original `gpt-fast` script and currently does not support quantization, tensor parallelism, and speculative decoding, as the primary purpose is to demonstrate basic usage of the performance tracker. 
\ No newline at end of file +**Notes** +- `generate.py` script is a stripped down version of the original `gpt-fast` script and currently does not support quantization, tensor parallelism, and speculative decoding, as the primary purpose is to demonstrate basic usage of the performance tracker. +- The discrepancy between `gpt-fast` token throughput and that of `TransformerPerformanceCounter` is due to the fact that `gpt-fast` does not account for all prefill tokens + - `gpt-fast` only counts generated tokens -- so even though `prefill` technically generated `len(prompt) + 1` tokens, it counts the number of tokens generated during this phase as `1`, whereas `TransformerPerformanceCounter` includes all `prefill` tokens in the total token count. \ No newline at end of file From 2b0b86fa19fdda54c54d45ad1a710d931742886e Mon Sep 17 00:00:00 2001 From: jeromeku Date: Sat, 27 Jul 2024 12:50:50 +0000 Subject: [PATCH 32/33] move total_model_params to utils --- torchao/profiler/__init__.py | 41 +-------------------------------- torchao/profiler/utils.py | 44 ++++++++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 40 deletions(-) create mode 100644 torchao/profiler/utils.py diff --git a/torchao/profiler/__init__.py b/torchao/profiler/__init__.py index 45e0322de..e748438e8 100644 --- a/torchao/profiler/__init__.py +++ b/torchao/profiler/__init__.py @@ -1,6 +1,3 @@ -import inspect - -import torch # Re-exports from .device_spec import CUDADeviceSpec, DeviceSpec @@ -11,6 +8,7 @@ PerformanceTimer, TransformerPerformanceCounter, ) +from .utils import total_model_params __all__ = [ "CUDAPerformanceTimer", @@ -23,40 +21,3 @@ "total_model_params", ] -_HUGGINGFACE_CAUSAL_LM_BASE_CLASSES = [ - "causallm", - "pretrainedmodel", - "generationmixin", -] - - -def get_all_base_classes(object): - return [cls.__name__.lower() for cls in inspect.getmro(object.__class__)] - - -def total_model_params( - model: torch.nn.Module, - exclude_embeddings: bool = True, - embedding_key: str = "tok_embeddings", -) -> int: - num_params = sum(p.numel() for p in model.parameters()) - - # Exclude embeddings when calculating FLOP since they don't contribute to FLOP count - if exclude_embeddings: - # Not the cleanest, but check if any base class of the model is in _HUGGINGFACE_CAUSAL_LM_BASE_CLASSES - if ( - len( - set(get_all_base_classes(model)).intersection( - _HUGGINGFACE_CAUSAL_LM_BASE_CLASSES - ) - ) - > 0 - ): - num_params -= model.model.embed_tokens.weight.numel() - elif hasattr(model, embedding_key): - num_params -= getattr(model, embedding_key).weight.numel() - else: - raise ValueError( - f"Could not find embedding in model {model.__class__.__name__}, please specify embedding attribute key" - ) - return num_params diff --git a/torchao/profiler/utils.py b/torchao/profiler/utils.py new file mode 100644 index 000000000..9276dd37b --- /dev/null +++ b/torchao/profiler/utils.py @@ -0,0 +1,44 @@ +import inspect + +import torch + +_HUGGINGFACE_CAUSAL_LM_BASE_CLASSES = [ + "causallm", + "pretrainedmodel", + "generationmixin", +] + + +def _get_all_base_classes(object): + return [cls.__name__.lower() for cls in inspect.getmro(object.__class__)] + + +def total_model_params( + model: torch.nn.Module, + exclude_embeddings: bool = True, + embedding_key: str = "tok_embeddings", +) -> int: + """ + Calculate total params of a HuggingFace CausalLM model or gpt-fast model + """ + num_params = sum(p.numel() for p in model.parameters()) + + # Exclude embeddings when calculating FLOP since they don't contribute to FLOP count + if 
exclude_embeddings: + # Not the cleanest, but check if any base class of the model is in _HUGGINGFACE_CAUSAL_LM_BASE_CLASSES + if ( + len( + set(_get_all_base_classes(model)).intersection( + _HUGGINGFACE_CAUSAL_LM_BASE_CLASSES + ) + ) + > 0 + ): + num_params -= model.model.embed_tokens.weight.numel() + elif hasattr(model, embedding_key): + num_params -= getattr(model, embedding_key).weight.numel() + else: + raise ValueError( + f"Could not find embedding in model {model.__class__.__name__}, please specify embedding attribute key" + ) + return num_params From b39443a3117030178398e0171210b48f803a7c7b Mon Sep 17 00:00:00 2001 From: jeromeku Date: Sat, 27 Jul 2024 14:53:57 +0000 Subject: [PATCH 33/33] remove tutorials/profiler --- .../_models/llama/perf_profile.py | 119 +++++++- tutorials/profiler/README.md | 105 ------- tutorials/profiler/model.py | 257 ------------------ tutorials/profiler/tokenizer.py | 112 -------- 4 files changed, 111 insertions(+), 482 deletions(-) rename tutorials/profiler/generate.py => torchao/_models/llama/perf_profile.py (75%) delete mode 100644 tutorials/profiler/README.md delete mode 100644 tutorials/profiler/model.py delete mode 100644 tutorials/profiler/tokenizer.py diff --git a/tutorials/profiler/generate.py b/torchao/_models/llama/perf_profile.py similarity index 75% rename from tutorials/profiler/generate.py rename to torchao/_models/llama/perf_profile.py index 28cd317fb..1a0d4e36c 100644 --- a/tutorials/profiler/generate.py +++ b/torchao/_models/llama/perf_profile.py @@ -1,20 +1,123 @@ -import pytest +""" + +## Performance Profiling Example + +An minimal version of `gpt-fast generate.py` that demonstrates usage of `torchao.profiler.TransformerPerformanceCounter`. +- Outputs from gpt-fast are prefixed with GPT-Fast +- Outputs from `torchao.profiler.TransformerPerformanceCounter` are prefixed with `TransformerPerfCounter`. + +## Usage +```python +python perf_profile.py --prompt "Hello my name is" --checkpoint_path path/to/model.pth --num_samples 1 --max_new_tokens 2 --save_path performance_stats.json +``` +where `checkpoint_path` is the checkpoint path of the converted model weights per `gpt-fast` and `save_path` specifies where to save performance stats. + + +Running the above command for `llama2-7b` should print the following, with accumulated stats saved to `performance_stats.json` + +``` +Loading model ... 
+Time to load model: 20.14 seconds + +============================== + +Using DeviceSpec(device_type=cuda, name=NVIDIA GeForce RTX 3090, dtype=torch.bfloat16, bandwidth=936.1GB/s, flops=35.6TFLOPs, vram=25.4GB) +Model Config: ModelArgs(block_size=2048, vocab_size=32000, n_layer=32, n_head=32, dim=4096, intermediate_size=11008, n_local_heads=32, head_dim=128, rope_base=10000, norm_eps=1e-05) +Active params, Total Params: 6607343616, 6738415616 + +============================== + +TransformerPerfCounter Metrics +PREFILL_SEQLEN-6: + Latency = 1.26 s + Tokens + Total: 6 tokens + Throughput: 5 tokens/s + IO + Total: 13.25 GB + Throughput: 10.54 GB/s + Theoretical Latency: 14.15 ms + FLOPs + Total: 79.31 GFLOPs + Throughput: 63.06 GFLOPs/s + Theoretical Latency: 2.23 ms + Utilization + Bandwidth: 0.0113 % + FLOPs: 0.0018 % + +============================== + +TransformerPerfCounter Metrics +DECODE_CTX-6_NUM_TOKS-1: + Latency = 0.16 s + Tokens + Total: 1 tokens + Throughput: 6 tokens/s + IO + Total: 13.22 GB + Throughput: 83.27 GB/s + Theoretical Latency: 14.13 ms + FLOPs + Total: 13.22 GFLOPs + Throughput: 83.24 GFLOPs/s + Theoretical Latency: 0.37 ms + Utilization + Bandwidth: 0.0890 % + FLOPs: 0.0023 % + +============================== + +Generated text for sample 0: Hello, my name is [Name + +GPTFast Sample Metrics + Time for inference 1: 6 prompt tokens 2 tokens generated, 1.57 sec total, 1.28 tokens/sec + Bandwidth achieved: 17.22 GB/s + +============================== + +GPTFast Aggregate Stats + Average tokens/sec: 1.28 + Memory used: 13.51 GB + +============================== + +TransformerPerfCounter +Performance Summary: + Latency = 1.42 s + Tokens + Total: 7 tokens + Throughput: 5 tokens/s + IO + Total: 26.47 GB + Throughput: 18.69 GB/s + Theoretical Latency: 28.28 ms + FLOPs + Total: 92.53 GFLOPs + Throughput: 65.33 GFLOPs/s + Theoretical Latency: 2.60 ms + Utilization + Bandwidth: 0.0200 % + FLOPs: 0.0018 % + +Saving performance results to performance_stats.json +``` + +**Notes** +- The discrepancy between `gpt-fast` token throughput and that of `TransformerPerformanceCounter` is due to the fact that gpt-fast` only counts generated tokens (no prefill) +-- so even though the `prefill` phase technically generates `len(prompt) + 1` tokens, it counts the number of tokens generated during this phase as `1`, +whereas `TransformerPerformanceCounter` includes all `prefill` tokens in the total token count. +""" -# Skip if transformers is not installed -transformers = pytest.importorskip("transformers") -LlamaConfig = transformers.models.llama.modeling_llama.LlamaConfig -LlamaForCausalLM = transformers.models.llama.modeling_llama.LlamaForCausalLM -# import sys import textwrap import time from pathlib import Path from typing import Optional, Tuple, Union import torch -from model import Transformer -from tokenizer import get_tokenizer from torch.nn.attention import SDPBackend +from torchao._models.llama.model import Transformer +from torchao._models.llama.tokenizer import get_tokenizer from torchao.profiler import ( CUDADeviceSpec, TransformerPerformanceCounter, diff --git a/tutorials/profiler/README.md b/tutorials/profiler/README.md deleted file mode 100644 index 668325f7c..000000000 --- a/tutorials/profiler/README.md +++ /dev/null @@ -1,105 +0,0 @@ - -## Performance Profiling Example - -An minimal reproduction of `gpt-fast` that demonstrates usage of `torchao.profiler.TransformerPerformanceCounter`. 
- -## Usage -```python -python generate.py --prompt "Hello my name is" --checkpoint_path path/to/model.pth --num_samples 1 --max_new_tokens 2 --save_path performance_stats.json -``` -where `checkpoint_path` is the checkpoint path of the converted model weights per `gpt-fast` and `save_path` specifies where to save performance stats. - - -Running the above command for `llama2-7b` should print the following, with accumulated stats saved to `performance_stats.json` - -``` -Loading model ... -Time to load model: 20.14 seconds - -============================== - -Using DeviceSpec(device_type=cuda, name=NVIDIA GeForce RTX 3090, dtype=torch.bfloat16, bandwidth=936.1GB/s, flops=35.6TFLOPs, vram=25.4GB) -Model Config: ModelArgs(block_size=2048, vocab_size=32000, n_layer=32, n_head=32, dim=4096, intermediate_size=11008, n_local_heads=32, head_dim=128, rope_base=10000, norm_eps=1e-05) -Active params, Total Params: 6607343616, 6738415616 - -============================== - -TransformerPerfCounter Metrics -PREFILL_SEQLEN-6: - Latency = 1.26 s - Tokens - Total: 6 tokens - Throughput: 5 tokens/s - IO - Total: 13.25 GB - Throughput: 10.54 GB/s - Theoretical Latency: 14.15 ms - FLOPs - Total: 79.31 GFLOPs - Throughput: 63.06 GFLOPs/s - Theoretical Latency: 2.23 ms - Utilization - Bandwidth: 0.0113 % - FLOPs: 0.0018 % - -============================== - -TransformerPerfCounter Metrics -DECODE_CTX-6_NUM_TOKS-1: - Latency = 0.16 s - Tokens - Total: 1 tokens - Throughput: 6 tokens/s - IO - Total: 13.22 GB - Throughput: 83.27 GB/s - Theoretical Latency: 14.13 ms - FLOPs - Total: 13.22 GFLOPs - Throughput: 83.24 GFLOPs/s - Theoretical Latency: 0.37 ms - Utilization - Bandwidth: 0.0890 % - FLOPs: 0.0023 % - -============================== - -Generated text for sample 0: Hello, my name is [Name - -GPTFast Sample Metrics - Time for inference 1: 6 prompt tokens 2 tokens generated, 1.57 sec total, 1.28 tokens/sec - Bandwidth achieved: 17.22 GB/s - -============================== - -GPTFast Aggregate Stats - Average tokens/sec: 1.28 - Memory used: 13.51 GB - -============================== - -TransformerPerfCounter -Performance Summary: - Latency = 1.42 s - Tokens - Total: 7 tokens - Throughput: 5 tokens/s - IO - Total: 26.47 GB - Throughput: 18.69 GB/s - Theoretical Latency: 28.28 ms - FLOPs - Total: 92.53 GFLOPs - Throughput: 65.33 GFLOPs/s - Theoretical Latency: 2.60 ms - Utilization - Bandwidth: 0.0200 % - FLOPs: 0.0018 % - -Saving performance results to performance_stats.json -``` - -**Notes** -- `generate.py` script is a stripped down version of the original `gpt-fast` script and currently does not support quantization, tensor parallelism, and speculative decoding, as the primary purpose is to demonstrate basic usage of the performance tracker. -- The discrepancy between `gpt-fast` token throughput and that of `TransformerPerformanceCounter` is due to the fact that `gpt-fast` does not account for all prefill tokens - - `gpt-fast` only counts generated tokens -- so even though `prefill` technically generated `len(prompt) + 1` tokens, it counts the number of tokens generated during this phase as `1`, whereas `TransformerPerformanceCounter` includes all `prefill` tokens in the total token count. \ No newline at end of file diff --git a/tutorials/profiler/model.py b/tutorials/profiler/model.py deleted file mode 100644 index b89a19a0f..000000000 --- a/tutorials/profiler/model.py +++ /dev/null @@ -1,257 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
- -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -from dataclasses import dataclass -from typing import Optional - -import torch -import torch.nn as nn -from torch import Tensor -from torch.nn import functional as F - - -def find_multiple(n: int, k: int) -> int: - if n % k == 0: - return n - return n + k - (n % k) - -@dataclass -class ModelArgs: - block_size: int = 2048 - vocab_size: int = 32000 - n_layer: int = 32 - n_head: int = 32 - dim: int = 4096 - intermediate_size: int = None - n_local_heads: int = -1 - head_dim: int = 64 - rope_base: float = 10000 - norm_eps: float = 1e-5 - - def __post_init__(self): - if self.n_local_heads == -1: - self.n_local_heads = self.n_head - if self.intermediate_size is None: - hidden_dim = 4 * self.dim - n_hidden = int(2 * hidden_dim / 3) - self.intermediate_size = find_multiple(n_hidden, 256) - self.head_dim = self.dim // self.n_head - - @classmethod - def from_name(cls, name: str): - if name in transformer_configs: - return cls(**transformer_configs[name]) - # fuzzy search - config = [config for config in transformer_configs if config.lower() in str(name).lower()] - - # We may have two or more configs matched (e.g. "7B" and "Mistral-7B"). Find the best config match, - # take longer name (as it have more symbols matched) - if len(config) > 1: - config.sort(key=len, reverse=True) - assert len(config[0]) != len(config[1]), name # make sure only one 'best' match - - return cls(**transformer_configs[config[0]]) - - -transformer_configs = { - "CodeLlama-7b-Python-hf": dict(block_size=16384, vocab_size=32000, n_layer=32, dim = 4096, rope_base=1000000), - "7B": dict(n_layer=32, n_head=32, dim=4096), - "13B": dict(n_layer=40, n_head=40, dim=5120), - "30B": dict(n_layer=60, n_head=52, dim=6656), - "34B": dict(n_layer=48, n_head=64, dim=8192, vocab_size=32000, n_local_heads=8, intermediate_size=22016, rope_base=1000000), # CodeLlama-34B-Python-hf - "70B": dict(n_layer=80, n_head=64, dim=8192, n_local_heads=8, intermediate_size=28672), - "Mistral-7B": dict(n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=32000), - "stories15M": dict(n_layer=6, n_head=6, dim=288), - "stories110M": dict(n_layer=12, n_head=12, dim=768), - - "llama-3-8b": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000), - "llama-3-70b": dict(block_size=8192, n_layer=80, n_head=64, n_local_heads=8, dim=8192, intermediate_size=28672, vocab_size=128256, rope_base=500000), -} - -class KVCache(nn.Module): - def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16): - super().__init__() - cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) - self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype)) - self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype)) - - def update(self, input_pos, k_val, v_val): - # input_pos: [S], k_val: [B, H, S, D] - assert input_pos.shape[0] == k_val.shape[2] - - k_out = self.k_cache - v_out = self.v_cache - k_out[:, :, input_pos] = k_val - v_out[:, :, input_pos] = v_val - - return k_out, v_out - -class Transformer(nn.Module): - def __init__(self, config: ModelArgs) -> None: - super().__init__() - self.config = config - - self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim) - self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer)) - self.norm = RMSNorm(config.dim, 
eps=config.norm_eps) - self.output = nn.Linear(config.dim, config.vocab_size, bias=False) - - self.freqs_cis: Optional[Tensor] = None - self.mask_cache: Optional[Tensor] = None - self.max_batch_size = -1 - self.max_seq_length = -1 - - def setup_caches(self, max_batch_size, max_seq_length): - if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size: - return - head_dim = self.config.dim // self.config.n_head - max_seq_length = find_multiple(max_seq_length, 8) - self.max_seq_length = max_seq_length - self.max_batch_size = max_batch_size - dtype = self.output.weight.dtype - # For quantized layers, dtype is encoded in scales - if hasattr(self.output, "scales"): - dtype = self.output.scales.dtype - elif hasattr(self.output, "scales_and_zeros"): - dtype = self.output.scales_and_zeros.dtype - for b in self.layers: - b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype) - - self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype) - self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)) - - def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: - assert self.freqs_cis is not None, "Caches must be initialized first" - mask = self.causal_mask[None, None, input_pos] - freqs_cis = self.freqs_cis[input_pos] - x = self.tok_embeddings(idx) - - for i, layer in enumerate(self.layers): - x = layer(x, input_pos, freqs_cis, mask) - x = self.norm(x) - logits = self.output(x) - return logits - - @classmethod - def from_name(cls, name: str): - return cls(ModelArgs.from_name(name)) - - -class TransformerBlock(nn.Module): - def __init__(self, config: ModelArgs) -> None: - super().__init__() - self.attention = Attention(config) - self.feed_forward = FeedForward(config) - self.ffn_norm = RMSNorm(config.dim, config.norm_eps) - self.attention_norm = RMSNorm(config.dim, config.norm_eps) - - def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor: - h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos) - out = h + self.feed_forward(self.ffn_norm(h)) - return out - - -class Attention(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - assert config.dim % config.n_head == 0 - - total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim - # key, query, value projections for all heads, but in a batch - self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False) - self.wo = nn.Linear(config.dim, config.dim, bias=False) - self.kv_cache = None - - self.n_head = config.n_head - self.head_dim = config.head_dim - self.n_local_heads = config.n_local_heads - self.dim = config.dim - self._register_load_state_dict_pre_hook(self.load_hook) - - def load_hook(self, state_dict, prefix, *args): - if prefix + "wq.weight" in state_dict: - wq = state_dict.pop(prefix + "wq.weight") - wk = state_dict.pop(prefix + "wk.weight") - wv = state_dict.pop(prefix + "wv.weight") - state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) - - def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: - bsz, seqlen, _ = x.shape - - kv_size = self.n_local_heads * self.head_dim - q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) - - q = q.view(bsz, seqlen, self.n_head, self.head_dim) - k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim) - v = v.view(bsz, seqlen, 
self.n_local_heads, self.head_dim) - - q = apply_rotary_emb(q, freqs_cis) - k = apply_rotary_emb(k, freqs_cis) - - q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) - - if self.kv_cache is not None: - k, v = self.kv_cache.update(input_pos, k, v) - - k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) - v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) - y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) - - y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) - - y = self.wo(y) - return y - - -class FeedForward(nn.Module): - def __init__(self, config: ModelArgs) -> None: - super().__init__() - self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False) - self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False) - self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False) - - def forward(self, x: Tensor) -> Tensor: - return self.w2(F.silu(self.w1(x)) * self.w3(x)) - - -class RMSNorm(nn.Module): - def __init__(self, dim: int, eps: float = 1e-5): - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - - def _norm(self, x): - return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) - - def forward(self, x: Tensor) -> Tensor: - output = self._norm(x.float()).type_as(x) - return output * self.weight - - -def precompute_freqs_cis( - seq_len: int, n_elem: int, base: int = 10000, - dtype: torch.dtype = torch.bfloat16 -) -> Tensor: - freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)) - t = torch.arange(seq_len, device=freqs.device) - freqs = torch.outer(t, freqs) - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) - cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) - return cache.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: - xshaped = x.float().reshape(*x.shape[:-1], -1, 2) - freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) - x_out2 = torch.stack( - [ - xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1], - xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1], - ], - -1, - ) - - x_out2 = x_out2.flatten(3) - return x_out2.type_as(x) diff --git a/tutorials/profiler/tokenizer.py b/tutorials/profiler/tokenizer.py deleted file mode 100644 index f60b3c13e..000000000 --- a/tutorials/profiler/tokenizer.py +++ /dev/null @@ -1,112 +0,0 @@ -import os -import sentencepiece as spm -import tiktoken -from tiktoken.load import load_tiktoken_bpe -from pathlib import Path -from typing import Dict - -class TokenizerInterface: - def __init__(self, model_path): - self.model_path = model_path - - def encode(self, text): - raise NotImplementedError("This method should be overridden by subclasses.") - - def decode(self, tokens): - raise NotImplementedError("This method should be overridden by subclasses.") - - def bos_id(self): - raise NotImplementedError("This method should be overridden by subclasses.") - - def eos_id(self): - raise NotImplementedError("This method should be overridden by subclasses.") - -class SentencePieceWrapper(TokenizerInterface): - def __init__(self, model_path): - super().__init__(model_path) - self.processor = spm.SentencePieceProcessor(str(model_path)) - - def encode(self, text): - return self.processor.EncodeAsIds(text) - - def decode(self, tokens): - return self.processor.DecodeIds(tokens) - - def bos_id(self): - return self.processor.bos_id() - - def eos_id(self): - return 
self.processor.eos_id() - -class TiktokenWrapper(TokenizerInterface): - """ - Tokenizing and encoding/decoding text using the Tiktoken tokenizer. - """ - - special_tokens: Dict[str, int] - - num_reserved_special_tokens = 256 - - pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501 - - def __init__(self, model_path): - super().__init__(model_path) - assert os.path.isfile(model_path), str(model_path) - mergeable_ranks = load_tiktoken_bpe(str(model_path)) - num_base_tokens = len(mergeable_ranks) - special_tokens = [ - "<|begin_of_text|>", - "<|end_of_text|>", - "<|reserved_special_token_0|>", - "<|reserved_special_token_1|>", - "<|reserved_special_token_2|>", - "<|reserved_special_token_3|>", - "<|start_header_id|>", - "<|end_header_id|>", - "<|reserved_special_token_4|>", - "<|eot_id|>", # end of turn - ] + [ - f"<|reserved_special_token_{i}|>" - for i in range(5, self.num_reserved_special_tokens - 5) - ] - self.special_tokens = { - token: num_base_tokens + i for i, token in enumerate(special_tokens) - } - self.model = tiktoken.Encoding( - name=Path(model_path).name, - pat_str=self.pat_str, - mergeable_ranks=mergeable_ranks, - special_tokens=self.special_tokens, - ) - # BOS / EOS token IDs - self._bos_id: int = self.special_tokens["<|begin_of_text|>"] - self._eos_id: int = self.special_tokens["<|end_of_text|>"] - - def encode(self, text): - return self.model.encode(text) - - def decode(self, tokens): - return self.model.decode(tokens) - - def bos_id(self): - return self._bos_id - - def eos_id(self): - return self._eos_id - -def get_tokenizer(tokenizer_model_path, model_name): - """ - Factory function to get the appropriate tokenizer based on the model name. - - Args: - - tokenizer_model_path (str): The file path to the tokenizer model. - - model_name (str): The name of the model, used to determine the tokenizer type. - - Returns: - - TokenizerInterface: An instance of a tokenizer. - """ - - if "llama-3" in str(model_name).lower(): - return TiktokenWrapper(tokenizer_model_path) - else: - return SentencePieceWrapper(tokenizer_model_path)
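
As a quick, standalone illustration of the token-accounting note in the `perf_profile.py` docstring above, the sketch below recomputes the two throughput figures from the example run. It is not part of any patched file; the numbers are copied from the sample output and the variable names are illustrative only.

```python
# Figures from the example llama2-7b run above (max_new_tokens=2); illustrative only.
prompt_len = 6                    # prompt tokens
max_new_tokens = 2                # tokens produced after the prompt
prefill_s, decode_s = 1.26, 0.16  # latencies of the instrumented prefill / decode regions
gptfast_total_s = 1.57            # gpt-fast's end-to-end generation time

# gpt-fast counts only generated tokens (prefill is credited with a single token).
gptfast_tok_per_s = max_new_tokens / gptfast_total_s        # ~1.28 tokens/s

# TransformerPerformanceCounter counts every prefill token plus the decoded tokens,
# and only times the instrumented regions: 6 prefill + 1 decode = 7 tokens over 1.42 s.
counter_tokens = prompt_len + (max_new_tokens - 1)
counter_tok_per_s = counter_tokens / (prefill_s + decode_s)  # ~4.9 tokens/s

print(f"GPTFast: {gptfast_tok_per_s:.2f} tok/s | TransformerPerfCounter: {counter_tok_per_s:.2f} tok/s")
```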