Skip to content

Commit 949f6b8

Browse files
committed
[model][refactor] remove cuda hard code in models and layers
Signed-off-by: Mengqing Cao <cmq0113@163.com>
1 parent 7f6bae5 commit 949f6b8

File tree

7 files changed

+18
-13
lines changed

7 files changed

+18
-13
lines changed

vllm/model_executor/layers/fused_moe/fused_marlin_moe.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from vllm.model_executor.layers.fused_moe.fused_moe import (
99
fused_topk, moe_align_block_size, try_get_optimal_moe_config)
10+
from vllm.platforms import current_platform
1011
from vllm.scalar_type import scalar_types
1112
from vllm.utils import direct_register_custom_op
1213

@@ -238,7 +239,7 @@ def fused_marlin_moe(
238239
max_workspace_size = (max(2 * N, K) // 64) * 16
239240
workspace = torch.zeros(max_workspace_size,
240241
dtype=torch.int,
241-
device="cuda",
242+
device=current_platform.device_type,
242243
requires_grad=False)
243244

244245
if has_no_zp:

vllm/model_executor/layers/rotary_embedding.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from transformers import PretrainedConfig
3131

3232
from vllm.model_executor.custom_op import CustomOp
33+
from vllm.platforms import current_platform
3334

3435

3536
def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -651,7 +652,7 @@ def __init__(
651652

652653
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
653654
pos_freqs = self.base**(torch.arange(
654-
0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
655+
0, self.rotary_dim, 2, dtype=torch.float, device=current_platform.device_type) /
655656
self.rotary_dim)
656657
inv_freq_extrapolation = 1.0 / pos_freqs
657658
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
@@ -670,7 +671,7 @@ def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
670671
def _compute_cos_sin_cache(self) -> torch.Tensor:
671672
inv_freq = self._compute_inv_freq(self.scaling_factor)
672673
t = torch.arange(self.max_position_embeddings * self.scaling_factor,
673-
device="cuda",
674+
device=current_platform.device_type,
674675
dtype=torch.float32)
675676
freqs = torch.einsum("i,j -> ij", t, inv_freq)
676677
cos = (freqs.cos() * self.mscale)

vllm/model_executor/layers/spec_decode_base_sampler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import torch
77
import torch.jit
88
import torch.nn as nn
9-
9+
from vllm.platforms import current_platform
1010

1111
class SpecDecodeBaseSampler(nn.Module):
1212
"""Base class for samplers used for Speculative Decoding verification
@@ -35,7 +35,7 @@ def __init__(self, strict_mode: bool = False):
3535
def init_gpu_tensors(self, device: Union[int, str]) -> None:
3636
assert self.num_accepted_tokens is None
3737
if isinstance(device, int):
38-
device = f"cuda:{device}"
38+
device = f"{current_platform.device_type}:{device}"
3939
elif not isinstance(device, str):
4040
raise ValueError(f"Device must be int or str, get {type(device)}")
4141
self.num_accepted_tokens = torch.tensor(0,

vllm/model_executor/model_loader/loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -914,7 +914,7 @@ def _parse_quant_state(param_name: str,
914914
if param_name + "." in k:
915915
quant_state[k] = temp_state_dict[k]
916916

917-
return QuantState.from_dict(quant_state, device="cuda")
917+
return QuantState.from_dict(quant_state, device=current_platform.device_type)
918918

919919
# Second iterate over all prequant and normal weights
920920
# pre quantized weights would have a quant_state

vllm/model_executor/models/arctic.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
3131
from vllm.model_executor.sampling_metadata import SamplingMetadata
3232
from vllm.model_executor.utils import set_weight_attrs
33+
from vllm.platforms import current_platform
3334
from vllm.sequence import IntermediateTensors
3435
from vllm.transformers_utils.configs.arctic import ArcticConfig
3536

@@ -138,13 +139,13 @@ def __init__(self,
138139
torch.empty(self.num_experts,
139140
2 * self.intermediate_size,
140141
self.hidden_size,
141-
device="cuda",
142+
device=current_platform.device_type,
142143
dtype=self.params_dtype))
143144
self.w2s = nn.Parameter(
144145
torch.empty(self.num_experts,
145146
self.hidden_size,
146147
self.intermediate_size,
147-
device="cuda",
148+
device=current_platform.device_type,
148149
dtype=self.params_dtype))
149150
set_weight_attrs(self.ws, {
150151
"weight_loader": self.weight_loader,

vllm/model_executor/models/minicpm.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
5252
from vllm.model_executor.sampling_metadata import SamplingMetadata
5353
from vllm.model_executor.utils import set_weight_attrs
54+
from vllm.platforms import current_platform
5455
from vllm.sequence import IntermediateTensors
5556

5657
from .interfaces import SupportsLoRA, SupportsPP
@@ -98,13 +99,13 @@ def __init__(
9899
torch.empty(self.num_total_experts,
99100
2 * self.intermediate_size,
100101
self.hidden_size,
101-
device="cuda",
102+
device=current_platform.device_type,
102103
dtype=self.params_dtype))
103104
self.w2s = nn.Parameter(
104105
torch.empty(self.num_total_experts,
105106
self.hidden_size,
106107
self.intermediate_size,
107-
device="cuda",
108+
device=current_platform.device,
108109
dtype=self.params_dtype))
109110

110111
set_weight_attrs(self.ws, {

vllm/model_executor/models/minicpmv.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
from vllm.multimodal.processing import (BaseMultiModalProcessor,
6060
BaseProcessingInfo, PromptReplacement)
6161
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
62+
from vllm.platforms import current_platform
6263
from vllm.sequence import IntermediateTensors
6364

6465
from .idefics2_vision_model import Idefics2VisionTransformer
@@ -1184,7 +1185,7 @@ def init_resampler(self,
11841185
quant_config=quant_config,
11851186
prefix=prefix)
11861187

1187-
return resampler.to(device="cuda", dtype=torch.get_default_dtype())
1188+
return resampler.to(device=current_platform.device_type, dtype=torch.get_default_dtype())
11881189

11891190
def get_vision_embedding(
11901191
self,
@@ -1266,7 +1267,7 @@ def init_resampler(self,
12661267
quant_config=quant_config,
12671268
prefix=prefix)
12681269

1269-
return resampler.to(device="cuda", dtype=torch.get_default_dtype())
1270+
return resampler.to(device=current_platform.device_type, dtype=torch.get_default_dtype())
12701271

12711272
def get_vision_embedding(
12721273
self,
@@ -1360,7 +1361,7 @@ def init_resampler(self,
13601361
quant_config=quant_config,
13611362
prefix=prefix)
13621363

1363-
return resampler.to(device="cuda", dtype=torch.get_default_dtype())
1364+
return resampler.to(device=current_platform.device_type, dtype=torch.get_default_dtype())
13641365

13651366
def get_vision_embedding(
13661367
self,

0 commit comments

Comments
 (0)