diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index 881b73246f..6aa2b4cedc 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -15,7 +15,6 @@
 # This file is a part of the vllm-ascend project.
 #
 
-import gc
 import os
 from datetime import timedelta
 from typing import TYPE_CHECKING, Optional, Tuple
@@ -108,12 +107,6 @@ def synchronize(cls):
     def mem_get_info(cls) -> Tuple[int, int]:
         return torch.npu.mem_get_info()
 
-    @classmethod
-    def clear_npu_memory(cls):
-        gc.collect()
-        torch.npu.empty_cache()
-        torch.npu.reset_peak_memory_stats()
-
     @classmethod
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         # initialize ascend config from vllm additional_config
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 1a59036fa4..9e672a3905 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -18,6 +18,7 @@
 #
 
 import atexit
+import gc
 import math
 from contextlib import contextmanager, nullcontext
 from enum import Enum
@@ -405,3 +406,9 @@ def get_fused_moe_state(ep_size: int, with_prefill: bool):
         return FusedMoEState.All2All
     else:
         return FusedMoEState.MC2
+
+
+def clear_npu_memory():
+    gc.collect()
+    torch.npu.empty_cache()
+    torch.npu.reset_peak_memory_stats()
diff --git a/vllm_ascend/worker/worker.py b/vllm_ascend/worker/worker.py
index bffc6a8de8..3fa63f05e9 100644
--- a/vllm_ascend/worker/worker.py
+++ b/vllm_ascend/worker/worker.py
@@ -23,7 +23,6 @@
 
 import msgpack  # type: ignore
 import torch
-import torch.distributed
 import zmq
 from torch import nn
 from vllm import envs
@@ -52,7 +51,7 @@
 from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
-                               is_310p, try_register_lib)
+                               clear_npu_memory, is_310p, try_register_lib)
 from vllm_ascend.worker.model_runner import NPUModelRunner
 from vllm_ascend.worker.pooling_model_runner import NPUPoolingModelRunner
 
@@ -281,7 +280,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
         """
         # Profile the memory usage of the model and get the maximum number of
         # cache blocks that can be allocated with the remaining free memory.
-        NPUPlatform.empty_cache()
+        clear_npu_memory()
 
         # Execute a forward pass with dummy inputs to profile the memory usage
         # of the model.
diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py
index 6fe84a4580..089add7c76 100644
--- a/vllm_ascend/worker/worker_v1.py
+++ b/vllm_ascend/worker/worker_v1.py
@@ -42,7 +42,7 @@
 from vllm_ascend.device_allocator.camem import CaMemAllocator
 from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
 from vllm_ascend.platform import NPUPlatform
-from vllm_ascend.utils import try_register_lib
+from vllm_ascend.utils import clear_npu_memory, try_register_lib
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
 
@@ -136,7 +136,7 @@ def init_device(self):
     def determine_available_memory(self) -> int:
         # Profile the memory usage of the model and get the maximum number of
         # cache blocks that can be allocated with the remaining free memory.
-        NPUPlatform.clear_npu_memory()
+        clear_npu_memory()
 
         # Execute a forward pass with dummy inputs to profile the memory usage
         # of the model.
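
For reference, a minimal standalone sketch of the helper this change relocates into vllm_ascend/utils.py, together with a hypothetical caller (not part of the diff). It assumes torch_npu is installed so that the torch.npu namespace used below is available; the caller name profile_free_npu_memory is illustrative only, standing in for the workers' determine_available_memory path.

# Sketch only: mirrors the clear_npu_memory() helper added to vllm_ascend/utils.py
# and shows how a worker calls it before memory profiling.
import gc

import torch
import torch_npu  # noqa: F401  # registers the torch.npu backend


def clear_npu_memory():
    # Drop unreachable Python objects, return cached blocks to the NPU
    # allocator, and reset peak-memory statistics so the profiling run
    # starts from a clean baseline.
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()


def profile_free_npu_memory() -> int:
    # Hypothetical caller: clear NPU memory first, then read back the
    # free byte count, as the workers do before the dummy forward pass.
    clear_npu_memory()
    free_bytes, total_bytes = torch.npu.mem_get_info()
    return free_bytes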