# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm_test_utils.monitor import monitor

from vllm.utils.mem_utils import MemorySnapshot, memory_profiling

from ..utils import create_new_process_for_each_test


@create_new_process_for_each_test()
def test_memory_profiling():
    # Fake out some model loading + inference memory usage to test profiling.
    # Memory used by other processes will show up as CUDA usage outside of torch.
    from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary

    lib = CudaRTLibrary()
    # 512 MiB allocation outside of this instance
    handle1 = lib.cudaMalloc(512 * 1024 * 1024)

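    # Snapshot the current GPU memory state (including the 512 MiB allocated
    # above); memory_profiling uses it as the baseline for later increases.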
    baseline_snapshot = MemorySnapshot()

    # load weights

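    # 128 * 1024 * 1024 float32 elements * 4 bytes each = 512 MiB of torch memory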
    weights = torch.randn(128, 1024, 1024, device="cuda", dtype=torch.float32)

    weights_memory = 128 * 1024 * 1024 * 4  # 512 MiB

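    # Non-torch memory = total memory the device reports as used minus memory
    # reserved by torch's caching allocator; the monitor samples this value.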
    def measure_current_non_torch():
        free, total = torch.cuda.mem_get_info()
        current_used = total - free
        current_torch = torch.cuda.memory_reserved()
        current_non_torch = current_used - current_torch
        return current_non_torch

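    # Profile a fake inference phase: memory_profiling splits the observed
    # growth into torch and non-torch usage, while monitor() keeps sampling
    # measure_current_non_torch for the duration of the block.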
    with (
        memory_profiling(
            baseline_snapshot=baseline_snapshot, weights_memory=weights_memory
        ) as result,
        monitor(measure_current_non_torch) as monitored_values,
    ):
        # make a memory spike, 1 GiB
        spike = torch.randn(256, 1024, 1024, device="cuda", dtype=torch.float32)
        del spike
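        # The spike is freed right away, so it should only be visible in the
        # profiler's peak torch statistics, not in the final usage numbers.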

        # Add some extra non-torch memory 256 MiB (simulate NCCL)
        handle2 = lib.cudaMalloc(256 * 1024 * 1024)

    # The monitored difference is exact: the only non-torch allocation made
    # inside the block is the 256 MiB cudaMalloc above.
    measured_diff = monitored_values.values[-1] - monitored_values.values[0]
    assert measured_diff == 256 * 1024 * 1024

    # Check that the profiled non-torch increase is within 5% of the expected
    # value. The tolerance covers the CUDA runtime's own allocations, which we
    # cannot control at byte granularity and which add a small error
    # (<10 MiB in practice).
    non_torch_ratio = result.non_torch_increase / (256 * 1024 * 1024)  # noqa
    assert abs(non_torch_ratio - 1) <= 0.05
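    # The 1 GiB spike is the only torch allocation made inside the profiled
    # block, so it accounts for the entire torch peak increase.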
    assert result.torch_peak_increase == 1024 * 1024 * 1024
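    # Cleanup: release the weights tensor and the raw cudaMalloc allocations.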
    del weights
    lib.cudaFree(handle1)
    lib.cudaFree(handle2)