
Commit 87494ec

add clear_npu_memory() method
Signed-off-by: shen-shanshan <467638484@qq.com>
1 parent: 2c7dd85

File tree

4 files changed: 11 additions, 12 deletions

vllm_ascend/platform.py
vllm_ascend/utils.py
vllm_ascend/worker/worker.py
vllm_ascend/worker/worker_v1.py

vllm_ascend/platform.py

Lines changed: 0 additions & 7 deletions

@@ -15,7 +15,6 @@
 # This file is a part of the vllm-ascend project.
 #
 
-import gc
 import os
 from datetime import timedelta
 from typing import TYPE_CHECKING, Optional, Tuple
@@ -107,12 +106,6 @@ def synchronize(cls):
     def mem_get_info(cls) -> Tuple[int, int]:
         return torch.npu.mem_get_info()
 
-    @classmethod
-    def clear_npu_memory(cls):
-        gc.collect()
-        torch.npu.empty_cache()
-        torch.npu.reset_peak_memory_stats()
-
     @classmethod
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         # initialize ascend config from vllm additional_config
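
Net effect of this file: clear_npu_memory() is no longer a classmethod on NPUPlatform; it re-appears as a plain function in vllm_ascend/utils.py (next file), and call sites migrate accordingly:

    # before this commit
    from vllm_ascend.platform import NPUPlatform
    NPUPlatform.clear_npu_memory()

    # after this commit
    from vllm_ascend.utils import clear_npu_memory
    clear_npu_memory()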

vllm_ascend/utils.py

Lines changed: 7 additions & 0 deletions

@@ -18,6 +18,7 @@
 #
 
 import atexit
+import gc
 import math
 from contextlib import contextmanager, nullcontext
 from enum import Enum
@@ -294,3 +295,9 @@ def get_fused_moe_state(ep_size: int, with_prefill: bool):
         return FusedMoEState.All2All
     else:
         return FusedMoEState.MC2
+
+
+def clear_npu_memory():
+    gc.collect()
+    torch.npu.empty_cache()
+    torch.npu.reset_peak_memory_stats()
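
The helper chains three steps: gc.collect() drops unreachable Python objects that may still pin NPU tensors, torch.npu.empty_cache() returns the caching allocator's unused blocks to the device, and torch.npu.reset_peak_memory_stats() zeroes the peak counters so a later profiling pass measures only its own usage. A minimal usage sketch, assuming torch_npu is installed and an NPU device is visible:

    import torch
    import torch_npu  # noqa: F401  (registers the torch.npu namespace)

    from vllm_ascend.utils import clear_npu_memory

    free_before, total = torch.npu.mem_get_info()
    clear_npu_memory()
    free_after, _ = torch.npu.mem_get_info()
    print(f"free NPU memory: {free_before / 1024**3:.2f} GiB -> "
          f"{free_after / 1024**3:.2f} GiB (of {total / 1024**3:.2f} GiB)")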

vllm_ascend/worker/worker.py

Lines changed: 2 additions & 3 deletions

@@ -23,7 +23,6 @@
 
 import msgpack  # type: ignore
 import torch
-import torch.distributed
 import zmq
 from torch import nn
 from vllm import envs
@@ -51,7 +50,7 @@
 from vllm_ascend.device_allocator.camem import CaMemAllocator
 from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
 from vllm_ascend.platform import NPUPlatform
-from vllm_ascend.utils import try_register_lib
+from vllm_ascend.utils import clear_npu_memory, try_register_lib
 from vllm_ascend.worker.model_runner import NPUModelRunner
 from vllm_ascend.worker.pooling_model_runner import NPUPoolingModelRunner
 
@@ -280,7 +279,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
         """
         # Profile the memory usage of the model and get the maximum number of
         # cache blocks that can be allocated with the remaining free memory.
-        NPUPlatform.empty_cache()
+        clear_npu_memory()
 
         # Execute a forward pass with dummy inputs to profile the memory usage
         # of the model.
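
Behavior change in this file: determine_num_available_blocks() previously only emptied the allocator cache via NPUPlatform.empty_cache(); it now also runs a garbage-collection pass and resets the peak-memory counters before the dummy forward pass. The code after this hunk (not shown) derives the KV-cache block count from the remaining memory; a rough sketch of that arithmetic, where total_npu_memory, peak_memory, and cache_block_size are illustrative names rather than the exact identifiers in worker.py:

    def num_npu_blocks(total_npu_memory: int, peak_memory: int,
                       cache_block_size: int,
                       gpu_memory_utilization: float) -> int:
        # budget = fraction of device memory vLLM may use, minus the peak
        # usage observed during the profiling forward pass
        usable = int(total_npu_memory * gpu_memory_utilization) - peak_memory
        return max(usable // cache_block_size, 0)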

vllm_ascend/worker/worker_v1.py

Lines changed: 2 additions & 2 deletions

@@ -42,7 +42,7 @@
 from vllm_ascend.device_allocator.camem import CaMemAllocator
 from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
 from vllm_ascend.platform import NPUPlatform
-from vllm_ascend.utils import try_register_lib
+from vllm_ascend.utils import clear_npu_memory, try_register_lib
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
 
 
@@ -136,7 +136,7 @@ def init_device(self):
     def determine_available_memory(self) -> int:
         # Profile the memory usage of the model and get the maximum number of
         # cache blocks that can be allocated with the remaining free memory.
-        NPUPlatform.clear_npu_memory()
+        clear_npu_memory()
 
         # Execute a forward pass with dummy inputs to profile the memory usage
         # of the model.
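
For the v1 worker the behavior is unchanged: NPUPlatform.clear_npu_memory() already performed the same three steps, so only the import source moves. The peak-stats reset is what makes the subsequent profiling trustworthy; a sketch of the failure mode it avoids, assuming torch.npu mirrors the torch.cuda memory-stats API here:

    import torch
    import torch_npu  # noqa: F401  (registers the torch.npu namespace)

    from vllm_ascend.utils import clear_npu_memory

    # without a reset, the counter still holds an earlier high-water mark
    stale_peak = torch.npu.max_memory_allocated()

    clear_npu_memory()  # includes torch.npu.reset_peak_memory_stats()
    # ... dummy forward pass with the real model ...
    peak = torch.npu.max_memory_allocated()  # peak of the profiling run only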
