Commit de94289

[Core] Support weight_loader_v2 for UnquantizedLinearMethod (#23036)

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 1983609 commit de94289

File tree

3 files changed: +70 -12 lines changed

vllm/compilation/decorators.py

Lines changed: 37 additions & 6 deletions
@@ -8,6 +8,7 @@
 
 import torch
 import torch.nn as nn
+from packaging import version
 from torch._dynamo.symbolic_convert import InliningInstructionTranslator
 
 from vllm.compilation.counter import compilation_counter
@@ -300,13 +301,13 @@ def patched_inline_call(parent, func, args, kwargs):
                 logger.debug(
                     "enable_cpp_symbolic_shape_guards config not available")
 
-            with patch.object(InliningInstructionTranslator, 'inline_call',
-                              patched_inline_call), torch._dynamo.config.patch(
-                                  **dynamo_config_patches
-                              ), maybe_use_cudagraph_partition_wrapper(
-                                  self.vllm_config):
+            with patch.object(
+                    InliningInstructionTranslator, "inline_call",
+                    patched_inline_call), torch._dynamo.config.patch(
+                        **dynamo_config_patches
+                    ), maybe_use_cudagraph_partition_wrapper(
+                        self.vllm_config), _torch27_patch_tensor_subclasses():
                 output = self.compiled_callable(*args, **kwargs)
-
             return output
 
         # usually, capturing the model once is enough, and then we can
@@ -367,3 +368,33 @@ def customized_cudagraph_wrapper(f,
     if (compilation_config.cudagraph_mode != CUDAGraphMode.NONE
             and compilation_config.use_inductor_graph_partition):
         torch._inductor.utils.set_customized_partition_wrappers(None)
+
+
+@contextlib.contextmanager
+def _torch27_patch_tensor_subclasses():
+    """
+    Add support for using tensor subclasses (i.e. `BasevLLMParameter`, etc.)
+    when using torch 2.7.x. This enables using weight_loader_v2 and the use
+    of `BasevLLMParameter`s without having to replace them with regular
+    tensors before `torch.compile`-time.
+    """
+    from vllm.model_executor.parameter import (BasevLLMParameter,
+                                               ModelWeightParameter,
+                                               RowvLLMParameter,
+                                               _ColumnvLLMParameter)
+
+    def return_false(*args, **kwargs):
+        return False
+
+    if not (version.parse("2.7") <= version.parse(
+            torch.__version__) < version.parse("2.8")):
+        yield
+        return
+
+    with (torch._dynamo.config.patch("traceable_tensor_subclasses", [
+            BasevLLMParameter, ModelWeightParameter, _ColumnvLLMParameter,
+            RowvLLMParameter
+    ]),
+          patch("torch._dynamo.variables.torch.can_dispatch_torch_function",
+                return_false)):
+        yield
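
For context: on torch 2.7.x, dynamo only traces through a `torch.Tensor` subclass that is registered in `torch._dynamo.config.traceable_tensor_subclasses`, which is why the context manager above is a no-op outside the [2.7, 2.8) range. A minimal sketch of that mechanism, using a hypothetical `MyParam` subclass in place of vLLM's parameter classes:

import torch

class MyParam(torch.Tensor):

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Normalize kwargs=None to {}, mirroring the parameter.py change below
        if kwargs is None:
            kwargs = {}
        return super().__torch_function__(func, types, args, kwargs)

def forward(w: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    return x @ w.t()

w = torch.randn(8, 4).as_subclass(MyParam)
x = torch.randn(2, 4)

# On torch 2.7.x, registering the subclass lets dynamo trace through its
# __torch_function__ instead of falling back when it meets the subclass.
with torch._dynamo.config.patch("traceable_tensor_subclasses", [MyParam]):
    out = torch.compile(forward)(w, x)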

vllm/model_executor/layers/linear.py

Lines changed: 11 additions & 5 deletions
@@ -22,6 +22,7 @@
 # yapf: disable
 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            BlockQuantScaleParameter,
+                                           ModelWeightParameter,
                                            PackedColumnParameter,
                                            PackedvLLMParameter,
                                            PerTensorScaleParameter,
@@ -34,6 +35,7 @@
 logger = init_logger(__name__)
 
 WEIGHT_LOADER_V2_SUPPORTED = [
+    "UnquantizedLinearMethod",
     "CompressedTensorsLinearMethod",
     "CompressedTensorsLinearTransformMethod",
     "BitBLASLinearMethod",
@@ -196,10 +198,14 @@ def create_weights(self, layer: torch.nn.Module,
         # The amount of memory allocated for the weights is
         # sum(output_partition_sizes) * input_size_per_partition.
         try:
-            weight = Parameter(torch.empty(sum(output_partition_sizes),
-                                           input_size_per_partition,
-                                           dtype=params_dtype),
-                               requires_grad=False)
+            weight_loader = extra_weight_attrs.pop("weight_loader")
+            weight = ModelWeightParameter(data=torch.empty(
+                sum(output_partition_sizes),
+                input_size_per_partition,
+                dtype=params_dtype),
+                                          input_dim=1,
+                                          output_dim=0,
+                                          weight_loader=weight_loader)
         except torch.cuda.OutOfMemoryError as e:
             logger.error("Failed to create unquantized linear weights: %s", e)
             if torch.cuda.is_available():
@@ -212,7 +218,7 @@ def create_weights(self, layer: torch.nn.Module,
                 "Failed to create unquantized linear weights. "
                 "This may be caused by insufficient memory to allocate "
                 "the weight.") from e
-        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
+
         layer.register_parameter("weight", weight)
         set_weight_attrs(weight, extra_weight_attrs)
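
Compared with the old `Parameter` plus `set_weight_attrs` path, v2 moves the sharding metadata onto the parameter object itself: `ModelWeightParameter` carries `input_dim`, `output_dim`, and the weight loader, so loading can be delegated to the parameter. A rough sketch of that idea with a simplified stand-in class (not vLLM's actual implementation):

import torch

class ShardedWeight(torch.nn.Parameter):
    """Hypothetical stand-in for ModelWeightParameter."""

    def __new__(cls, data: torch.Tensor, input_dim: int, output_dim: int):
        self = super().__new__(cls, data, requires_grad=False)
        self.input_dim = input_dim
        self.output_dim = output_dim
        return self

    def load_row_parallel_weight(self, loaded: torch.Tensor, tp_rank: int,
                                 tp_size: int):
        # Row-parallel layers shard the input dim across ranks
        shard = self.shape[self.input_dim]
        assert loaded.shape[self.input_dim] == shard * tp_size
        self.data.copy_(loaded.narrow(self.input_dim, tp_rank * shard, shard))

# A rank holding the second of two input-dim shards of an 8x16 weight:
w = ShardedWeight(torch.empty(8, 8), input_dim=1, output_dim=0)
w.load_row_parallel_weight(torch.randn(8, 16), tp_rank=1, tp_size=2)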

vllm/model_executor/parameter.py

Lines changed: 22 additions & 1 deletion
@@ -61,9 +61,24 @@ def __init__(self, data: torch.Tensor, weight_loader: Callable):
         self.tp_size = get_tensor_model_parallel_world_size()
 
     @property
-    def weight_loader(self):
+    def weight_loader(self) -> Callable:
+        # NOTE(@ksayers) some models, such as mamba_mixer2, override the
+        # weight loader to support custom loading. In the future, model-specific
+        # weight loading should be implemented via Model.load_weights. In the
+        # meantime, support deleting and overriding the `weight_loader` attribute
+        if self._weight_loader is None:
+            raise AttributeError(f"{self.__class__.__name__} weight_loader "
+                                 "attribute has been deleted")
         return self._weight_loader
 
+    @weight_loader.setter
+    def weight_loader(self, value: Callable):
+        self._weight_loader = value
+
+    @weight_loader.deleter
+    def weight_loader(self):
+        self._weight_loader = None  # type: ignore[assignment]
+
     def _is_1d_and_scalar(self, loaded_weight: torch.Tensor):
         cond1 = self.data.ndim == 1 and self.data.numel() == 1
         cond2 = loaded_weight.ndim == 0 and loaded_weight.numel() == 1
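
The setter and deleter keep backward compatibility with code that treats `weight_loader` as a plain attribute. Since a property cannot be removed from an instance, the deleter stores None and the getter then raises AttributeError, which emulates deletion. A self-contained sketch of that protocol (constructing a real `BasevLLMParameter` needs vLLM's tensor-parallel state to be initialized, so a stand-in class is used here):

from typing import Callable, Optional

class ParamStub:

    def __init__(self, weight_loader: Callable):
        self._weight_loader: Optional[Callable] = weight_loader

    @property
    def weight_loader(self) -> Callable:
        if self._weight_loader is None:
            raise AttributeError("weight_loader attribute has been deleted")
        return self._weight_loader

    @weight_loader.setter
    def weight_loader(self, value: Callable):
        self._weight_loader = value

    @weight_loader.deleter
    def weight_loader(self):
        self._weight_loader = None

p = ParamStub(weight_loader=print)
p.weight_loader = print  # models like mamba_mixer2 override the loader...
del p.weight_loader      # ...or delete it entirely
try:
    _ = p.weight_loader  # the getter then raises AttributeError
except AttributeError:
    pass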
@@ -97,6 +112,12 @@ def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
         assert shard_id in qkv_idxs
         return qkv_idxs[shard_id]
 
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+        return super().__torch_function__(func, types, args, kwargs)
+
 
 class _ColumnvLLMParameter(BasevLLMParameter):
     """
