Skip to content

Commit a8f6e76

Browse files
committed
[1/N][Refactor] torchair model runner refactor
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
1 parent 807f089 commit a8f6e76

File tree

3 files changed: +42 −2 lines changed
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#
2+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
3+
# Copyright 2023 The vLLM team.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
# This file is a part of the vllm-ascend project.
17+
# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
18+
#
19+
20+
21+
import torch
22+
import torch._dynamo.cache_size
23+
from vllm.config import VllmConfig
24+
25+
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
26+
27+
28+
class NPUTorchairModelRunner(NPUModelRunner):
    """Model runner for torchair graph mode on Ascend NPUs.

    Currently a thin placeholder subclass of ``NPUModelRunner``; this commit
    ("[1/N] torchair model runner refactor") only introduces the class, and
    torchair-specific behavior is presumably added in follow-up steps.
    """

    def __init__(self, vllm_config: VllmConfig, device: torch.device) -> None:
        # Delegates all initialization to the base NPU model runner.
        super().__init__(vllm_config, device)

vllm_ascend/torchair/torchair_worker.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from vllm.logger import logger
1818

1919
import vllm_ascend.envs as envs_ascend
20+
from vllm_ascend.torchair.torchair_model_runner import NPUTorchairModelRunner
2021
from vllm_ascend.torchair.utils import (check_kv_cache_bytes_cache_exist,
2122
check_torchair_cache_exist,
2223
delete_torchair_cache_file,
@@ -52,3 +53,9 @@ def determine_available_memory(self) -> int:
5253
self.model_runner.new_kv_cache_bytes = available_kv_cache_memory
5354

5455
return available_kv_cache_memory
56+
57+
def init_device(self):
    """Initialize the NPU device, then attach a torchair-specific runner.

    Overrides the base worker so that ``self.model_runner`` is an
    ``NPUTorchairModelRunner`` instead of the default ``NPUModelRunner``.
    """
    npu_device = self._init_device()
    # The runner is constructed only after _init_device() has finished,
    # so it can rely on the device (and self.device) being set up.
    self.model_runner = NPUTorchairModelRunner(self.vllm_config, npu_device)

vllm_ascend/worker/worker_v1.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,17 +130,19 @@ def initialize_cache(self, num_gpu_blocks: int,
130130
self.cache_config.num_gpu_blocks = num_gpu_blocks
131131
self.cache_config.num_cpu_blocks = num_cpu_blocks
132132

133-
def init_device(self):
133+
def _init_device(self):
    """Bind this worker to its local NPU and prepare the runtime state.

    Sets the active device, records the initial free-memory figure,
    brings up the distributed environment, and seeds the RNGs.

    Returns:
        The ``torch.device`` for this worker's local NPU rank.
    """
    npu_device = torch.device(f"npu:{self.local_rank}")
    NPUPlatform.set_device(npu_device)
    NPUPlatform.empty_cache()
    # Snapshot device memory before any model allocations — presumably the
    # free-memory figure (mem_get_info()[0]), used later for profiling.
    self.init_npu_memory = NPUPlatform.mem_get_info()[0]
    # Bring up the distributed environment (process groups etc.).
    self._init_worker_distributed_environment()
    # Seed all RNGs for reproducibility across workers.
    NPUPlatform.seed_everything(self.model_config.seed)
    return npu_device
143143

144+
def init_device(self):
    """Set up the NPU and attach the default ``NPUModelRunner``."""
    npu_device = self._init_device()
    # Create the runner here rather than in __init__ so that device and
    # distributed initialization have already completed (self.device is valid).
    self.model_runner = NPUModelRunner(self.vllm_config, npu_device)
146148

0 commit comments

Comments (0)