Commit 23eca9c

[model][refactor] remove cuda hard code in models and layers (#13658)
1 parent: 437b76f

File tree

7 files changed: +29 -14 lines changed
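
Every edit in this commit follows the same pattern: a tensor, parameter, or module that was pinned to device="cuda" now uses current_platform.device_type, the backend string exposed by the active Platform subclass in vllm.platforms. A minimal sketch of the idea (the allocation size here is illustrative, not from the diff):

    import torch

    from vllm.platforms import current_platform

    # current_platform resolves at import time to the Platform subclass
    # for the detected backend; device_type is a plain string such as
    # "cuda" or "cpu".
    device = current_platform.device_type

    # Before this commit: torch.zeros(..., device="cuda"), which fails on
    # non-CUDA backends. The same allocation now lands on whichever
    # backend is active.
    workspace = torch.zeros(1024, dtype=torch.int, device=device)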

vllm/model_executor/layers/fused_moe/fused_marlin_moe.py

Lines changed: 2 additions & 1 deletion

@@ -7,6 +7,7 @@
 
 from vllm.model_executor.layers.fused_moe.fused_moe import (
     fused_topk, moe_align_block_size, try_get_optimal_moe_config)
+from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
 from vllm.utils import direct_register_custom_op
 
@@ -238,7 +239,7 @@ def fused_marlin_moe(
     max_workspace_size = (max(2 * N, K) // 64) * 16
     workspace = torch.zeros(max_workspace_size,
                             dtype=torch.int,
-                            device="cuda",
+                            device=current_platform.device_type,
                             requires_grad=False)
 
     if has_no_zp:
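
The workspace sizing above is a closed-form bound, so a single worked example shows the arithmetic; N and K are hypothetical shard sizes:

    # Hypothetical shapes, chosen only to illustrate the arithmetic.
    N, K = 4096, 11008
    max_workspace_size = (max(2 * N, K) // 64) * 16
    # max(8192, 11008) = 11008; 11008 // 64 = 172; 172 * 16 = 2752
    assert max_workspace_size == 2752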

vllm/model_executor/layers/rotary_embedding.py

Lines changed: 9 additions & 4 deletions

@@ -30,6 +30,7 @@
 from transformers import PretrainedConfig
 
 from vllm.model_executor.custom_op import CustomOp
+from vllm.platforms import current_platform
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -650,9 +651,13 @@ def __init__(
                          is_neox_style, dtype)
 
     def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
-        pos_freqs = self.base**(torch.arange(
-            0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
-                                self.rotary_dim)
+        pos_freqs = self.base**(
+            torch.arange(0,
+                         self.rotary_dim,
+                         2,
+                         dtype=torch.float,
+                         device=current_platform.device_type) /
+            self.rotary_dim)
         inv_freq_extrapolation = 1.0 / pos_freqs
         inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
 
@@ -670,7 +675,7 @@ def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
     def _compute_cos_sin_cache(self) -> torch.Tensor:
         inv_freq = self._compute_inv_freq(self.scaling_factor)
         t = torch.arange(self.max_position_embeddings * self.scaling_factor,
-                         device="cuda",
+                         device=current_platform.device_type,
                          dtype=torch.float32)
         freqs = torch.einsum("i,j -> ij", t, inv_freq)
         cos = (freqs.cos() * self.mscale)
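
The reformatted _compute_inv_freq computes the same two frequency tables as before, only on the platform device. A self-contained sketch with hypothetical hyperparameters (base=10000, rotary_dim=128, scaling_factor=4.0):

    import torch

    from vllm.platforms import current_platform

    # Hypothetical hyperparameters, for illustration only.
    base, rotary_dim, scaling_factor = 10000.0, 128, 4.0

    pos_freqs = base**(torch.arange(0, rotary_dim, 2,
                                    dtype=torch.float,
                                    device=current_platform.device_type) /
                       rotary_dim)
    inv_freq_extrapolation = 1.0 / pos_freqs                     # unscaled RoPE
    inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)  # scaled RoPE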

vllm/model_executor/layers/spec_decode_base_sampler.py

Lines changed: 3 additions & 1 deletion

@@ -7,6 +7,8 @@
 import torch.jit
 import torch.nn as nn
 
+from vllm.platforms import current_platform
+
 
 class SpecDecodeBaseSampler(nn.Module):
     """Base class for samplers used for Speculative Decoding verification
@@ -35,7 +37,7 @@ def __init__(self, strict_mode: bool = False):
     def init_gpu_tensors(self, device: Union[int, str]) -> None:
         assert self.num_accepted_tokens is None
         if isinstance(device, int):
-            device = f"cuda:{device}"
+            device = f"{current_platform.device_type}:{device}"
         elif not isinstance(device, str):
             raise ValueError(f"Device must be int or str, get {type(device)}")
         self.num_accepted_tokens = torch.tensor(0,
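
init_gpu_tensors accepts either a device index or a device string; a sketch of the normalization it now performs, restated as a hypothetical free function:

    from typing import Union

    from vllm.platforms import current_platform

    def normalize_device(device: Union[int, str]) -> str:
        """Hypothetical restatement of the logic in init_gpu_tensors."""
        if isinstance(device, int):
            # An index becomes "<backend>:<index>", e.g. "cuda:0" or "cpu:0".
            return f"{current_platform.device_type}:{device}"
        if not isinstance(device, str):
            raise ValueError(f"Device must be int or str, get {type(device)}")
        return device

    print(normalize_device(0))  # "cuda:0" on a CUDA build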

vllm/model_executor/model_loader/loader.py

Lines changed: 2 additions & 1 deletion

@@ -914,7 +914,8 @@ def _parse_quant_state(param_name: str,
             if param_name + "." in k:
                 quant_state[k] = temp_state_dict[k]
 
-        return QuantState.from_dict(quant_state, device="cuda")
+        return QuantState.from_dict(quant_state,
+                                    device=current_platform.device_type)
 
     # Second iterate over all prequant and normal weights
     # pre quantized weights would have a quant_state

vllm/model_executor/models/arctic.py

Lines changed: 3 additions & 2 deletions

@@ -30,6 +30,7 @@
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.arctic import ArcticConfig
 
@@ -138,13 +139,13 @@ def __init__(self,
             torch.empty(self.num_experts,
                         2 * self.intermediate_size,
                         self.hidden_size,
-                        device="cuda",
+                        device=current_platform.device_type,
                         dtype=self.params_dtype))
         self.w2s = nn.Parameter(
             torch.empty(self.num_experts,
                         self.hidden_size,
                         self.intermediate_size,
-                        device="cuda",
+                        device=current_platform.device_type,
                         dtype=self.params_dtype))
         set_weight_attrs(self.ws, {
             "weight_loader": self.weight_loader,

vllm/model_executor/models/minicpm.py

Lines changed: 3 additions & 2 deletions

@@ -51,6 +51,7 @@
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsLoRA, SupportsPP
@@ -98,13 +99,13 @@ def __init__(
             torch.empty(self.num_total_experts,
                         2 * self.intermediate_size,
                         self.hidden_size,
-                        device="cuda",
+                        device=current_platform.device_type,
                         dtype=self.params_dtype))
         self.w2s = nn.Parameter(
             torch.empty(self.num_total_experts,
                         self.hidden_size,
                         self.intermediate_size,
-                        device="cuda",
+                        device=current_platform.device_type,
                         dtype=self.params_dtype))
 
         set_weight_attrs(self.ws, {
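
arctic.py above and minicpm.py here allocate their fused expert weights identically, so one standalone sketch covers both; the MoE dimensions are hypothetical and deliberately small:

    import torch
    import torch.nn as nn

    from vllm.platforms import current_platform

    # Hypothetical, deliberately small MoE dimensions.
    num_experts, hidden_size, intermediate_size = 8, 64, 128
    params_dtype = torch.get_default_dtype()

    # ws fuses the gate and up projections (hence 2 * intermediate_size);
    # w2s is the down projection back to hidden_size. Both now live on the
    # platform's device instead of being pinned to CUDA.
    ws = nn.Parameter(
        torch.empty(num_experts,
                    2 * intermediate_size,
                    hidden_size,
                    device=current_platform.device_type,
                    dtype=params_dtype))
    w2s = nn.Parameter(
        torch.empty(num_experts,
                    hidden_size,
                    intermediate_size,
                    device=current_platform.device_type,
                    dtype=params_dtype))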

vllm/model_executor/models/minicpmv.py

Lines changed: 7 additions & 3 deletions

@@ -59,6 +59,7 @@
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 
 from .idefics2_vision_model import Idefics2VisionTransformer
@@ -1184,7 +1185,8 @@ def init_resampler(self,
                           quant_config=quant_config,
                           prefix=prefix)
 
-        return resampler.to(device="cuda", dtype=torch.get_default_dtype())
+        return resampler.to(device=current_platform.device_type,
+                            dtype=torch.get_default_dtype())
 
     def get_vision_embedding(
         self,
@@ -1266,7 +1268,8 @@ def init_resampler(self,
                           quant_config=quant_config,
                           prefix=prefix)
 
-        return resampler.to(device="cuda", dtype=torch.get_default_dtype())
+        return resampler.to(device=current_platform.device_type,
+                            dtype=torch.get_default_dtype())
 
     def get_vision_embedding(
         self,
@@ -1360,7 +1363,8 @@ def init_resampler(self,
                           quant_config=quant_config,
                           prefix=prefix)
 
-        return resampler.to(device="cuda", dtype=torch.get_default_dtype())
+        return resampler.to(device=current_platform.device_type,
+                            dtype=torch.get_default_dtype())
 
     def get_vision_embedding(
         self,
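
All three MiniCPM-V resampler variants end with the same move-and-cast; a minimal sketch, with nn.Linear standing in for the real resampler module:

    import torch
    import torch.nn as nn

    from vllm.platforms import current_platform

    # A stand-in module; the real code moves the freshly built resampler.
    resampler = nn.Linear(64, 64)

    # One .to() call both casts parameters to the default dtype and places
    # them on whatever backend current_platform resolved to.
    resampler = resampler.to(device=current_platform.device_type,
                             dtype=torch.get_default_dtype())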
