Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, moe_align_block_size, try_get_optimal_moe_config)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.utils import direct_register_custom_op

Expand Down Expand Up @@ -238,7 +239,7 @@ def fused_marlin_moe(
max_workspace_size = (max(2 * N, K) // 64) * 16
workspace = torch.zeros(max_workspace_size,
dtype=torch.int,
device="cuda",
device=current_platform.device_type,
requires_grad=False)

if has_no_zp:
Expand Down
13 changes: 9 additions & 4 deletions vllm/model_executor/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from transformers import PretrainedConfig

from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform


def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -650,9 +651,13 @@ def __init__(
is_neox_style, dtype)

def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
pos_freqs = self.base**(torch.arange(
0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
self.rotary_dim)
pos_freqs = self.base**(
torch.arange(0,
self.rotary_dim,
2,
dtype=torch.float,
device=current_platform.device_type) /
self.rotary_dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

Expand All @@ -670,7 +675,7 @@ def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.scaling_factor)
t = torch.arange(self.max_position_embeddings * self.scaling_factor,
device="cuda",
device=current_platform.device_type,
dtype=torch.float32)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = (freqs.cos() * self.mscale)
Expand Down
4 changes: 3 additions & 1 deletion vllm/model_executor/layers/spec_decode_base_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import torch.jit
import torch.nn as nn

from vllm.platforms import current_platform


class SpecDecodeBaseSampler(nn.Module):
"""Base class for samplers used for Speculative Decoding verification
Expand Down Expand Up @@ -35,7 +37,7 @@ def __init__(self, strict_mode: bool = False):
def init_gpu_tensors(self, device: Union[int, str]) -> None:
assert self.num_accepted_tokens is None
if isinstance(device, int):
device = f"cuda:{device}"
device = f"{current_platform.device_type}:{device}"
elif not isinstance(device, str):
raise ValueError(f"Device must be int or str, get {type(device)}")
self.num_accepted_tokens = torch.tensor(0,
Expand Down
3 changes: 2 additions & 1 deletion vllm/model_executor/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -914,7 +914,8 @@ def _parse_quant_state(param_name: str,
if param_name + "." in k:
quant_state[k] = temp_state_dict[k]

return QuantState.from_dict(quant_state, device="cuda")
return QuantState.from_dict(quant_state,
device=current_platform.device_type)

# Second iterate over all prequant and normal weights
# pre quantized weights would have a quant_state
Expand Down
5 changes: 3 additions & 2 deletions vllm/model_executor/models/arctic.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.arctic import ArcticConfig

Expand Down Expand Up @@ -138,13 +139,13 @@ def __init__(self,
torch.empty(self.num_experts,
2 * self.intermediate_size,
self.hidden_size,
device="cuda",
device=current_platform.device_type,
dtype=self.params_dtype))
self.w2s = nn.Parameter(
torch.empty(self.num_experts,
self.hidden_size,
self.intermediate_size,
device="cuda",
device=current_platform.device_type,
dtype=self.params_dtype))
set_weight_attrs(self.ws, {
"weight_loader": self.weight_loader,
Expand Down
5 changes: 3 additions & 2 deletions vllm/model_executor/models/minicpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsPP
Expand Down Expand Up @@ -98,13 +99,13 @@ def __init__(
torch.empty(self.num_total_experts,
2 * self.intermediate_size,
self.hidden_size,
device="cuda",
device=current_platform.device_type,
dtype=self.params_dtype))
self.w2s = nn.Parameter(
torch.empty(self.num_total_experts,
self.hidden_size,
self.intermediate_size,
device="cuda",
device=current_platform.device_type,
dtype=self.params_dtype))

set_weight_attrs(self.ws, {
Expand Down
10 changes: 7 additions & 3 deletions vllm/model_executor/models/minicpmv.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors

from .idefics2_vision_model import Idefics2VisionTransformer
Expand Down Expand Up @@ -1184,7 +1185,8 @@ def init_resampler(self,
quant_config=quant_config,
prefix=prefix)

return resampler.to(device="cuda", dtype=torch.get_default_dtype())
return resampler.to(device=current_platform.device_type,
dtype=torch.get_default_dtype())

def get_vision_embedding(
self,
Expand Down Expand Up @@ -1266,7 +1268,8 @@ def init_resampler(self,
quant_config=quant_config,
prefix=prefix)

return resampler.to(device="cuda", dtype=torch.get_default_dtype())
return resampler.to(device=current_platform.device_type,
dtype=torch.get_default_dtype())

def get_vision_embedding(
self,
Expand Down Expand Up @@ -1360,7 +1363,8 @@ def init_resampler(self,
quant_config=quant_config,
prefix=prefix)

return resampler.to(device="cuda", dtype=torch.get_default_dtype())
return resampler.to(device=current_platform.device_type,
dtype=torch.get_default_dtype())

def get_vision_embedding(
self,
Expand Down