4 changes: 3 additions & 1 deletion tests/plugins/vllm_add_dummy_platform/setup.py
@@ -10,5 +10,7 @@
    entry_points={
        'vllm.platform_plugins': [
            "dummy_platform_plugin = vllm_add_dummy_platform:dummy_platform_plugin"  # noqa
-        ]
+        ],
+        "vllm.general_plugins":
+        ["dummy_custom_ops = vllm_add_dummy_platform:register_ops"],
    })
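The new vllm.general_plugins entry point is what makes the custom-op registration below discoverable: vLLM's plugin loader walks this entry-point group and invokes each hook it finds. A minimal sketch of that discovery step, assuming Python 3.10+ importlib.metadata (illustrative only; the real loader is vllm.plugins.load_general_plugins):

# Illustrative sketch of entry-point discovery, not vLLM's actual code.
from importlib.metadata import entry_points

def load_general_plugins_sketch() -> None:
    for ep in entry_points(group="vllm.general_plugins"):
        register_fn = ep.load()  # e.g. vllm_add_dummy_platform:register_ops
        register_fn()  # the hook imports modules whose decorators register ops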
tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/__init__.py
@@ -6,3 +6,7 @@

def dummy_platform_plugin() -> Optional[str]:
    return "vllm_add_dummy_platform.dummy_platform.DummyPlatform"
+
+
+def register_ops():
+    import vllm_add_dummy_platform.dummy_custom_ops  # noqa
tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_attention_backend.py
@@ -1,10 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from vllm.attention.backends.flash_attn import FlashAttentionBackend
+from vllm.attention.backends.placeholder_attn import (
+    PlaceholderAttentionBackend)


-class DummyAttentionBackend(FlashAttentionBackend):
+class DummyAttentionBackend(PlaceholderAttentionBackend):

    @staticmethod
    def get_name() -> str:
tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_custom_ops.py
@@ -0,0 +1,20 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import torch
+
+from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
+
+
+# Register DummyRotaryEmbedding as the OOT replacement for RotaryEmbedding.
+@RotaryEmbedding.register_oot
+class DummyRotaryEmbedding(RotaryEmbedding):
+    """Dummy out-of-tree rotary positional embedding."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.addition_config = True
+
+    def forward_oot(self, *args,
+                    **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
+        return super().forward_oot(*args, **kwargs)
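Once this plugin package is installed and plugins are loaded, constructing the in-tree RotaryEmbedding transparently yields the subclass above, and CustomOp routes forward() to forward_oot() on out-of-tree platforms. A hedged usage sketch, assuming the dummy package from this test directory is installed:

import torch

from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.plugins import load_general_plugins

load_general_plugins()  # pulls in dummy_custom_ops via the entry point
rope = RotaryEmbedding(16, 16, 16, 16, True, torch.float16)
assert type(rope).__name__ == "DummyRotaryEmbedding"
assert rope.addition_config  # set by DummyRotaryEmbedding.__init__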
tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py
@@ -1,12 +1,29 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import TYPE_CHECKING

-from vllm.platforms.cuda import CudaPlatform
+from vllm.platforms.interface import Platform, PlatformEnum

+if TYPE_CHECKING:
+    from vllm.config import VllmConfig
+else:
+    VllmConfig = None
+from vllm import envs


-class DummyPlatform(CudaPlatform):
+class DummyPlatform(Platform):
+    _enum = PlatformEnum.OOT
    device_name = "DummyDevice"
+    device_type: str = "privateuseone"
+    dispatch_key: str = "PrivateUse1"
+
+    @classmethod
+    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
+        if envs.VLLM_USE_V1:
+            compilation_config = vllm_config.compilation_config
+            # Activate custom ops for v1.
+            compilation_config.custom_ops = ["all"]

    def get_attn_backend_cls(self, backend_name, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1, use_mla):
        return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend"  # noqa E501
14 changes: 14 additions & 0 deletions tests/plugins_tests/test_platform_plugins.py
@@ -5,6 +5,7 @@
import torch

from vllm.attention.selector import get_attn_backend
+from vllm.plugins import load_general_plugins
from vllm.utils import STR_BACKEND_ENV_VAR, STR_INVALID_VAL


@@ -32,3 +33,16 @@ def test_oot_attention_backend(monkeypatch: pytest.MonkeyPatch):
        m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL)
        backend = get_attn_backend(16, torch.float16, "auto", 16, False)
        assert backend.get_name() == "Dummy_Backend"
+
+
+def test_oot_custom_op(monkeypatch: pytest.MonkeyPatch):
+    # Simulate a workload: load plugins and construct a layer using them.
+    load_general_plugins()
+    from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
+    layer = RotaryEmbedding(16, 16, 16, 16, True, torch.float16)
+    assert layer.__class__.__name__ == "DummyRotaryEmbedding", (
+        f"Expected DummyRotaryEmbedding, got {layer.__class__.__name__}, "
+        "possibly because the custom op is not registered correctly.")
+    assert hasattr(layer, "addition_config"), (
+        "Expected DummyRotaryEmbedding to have an 'addition_config' "
+        "attribute, which is set by the custom op.")
56 changes: 56 additions & 0 deletions vllm/model_executor/custom_op.py
@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+from typing import Optional
+
import torch.nn as nn

from vllm.config import get_current_vllm_config
@@ -16,6 +18,24 @@ class CustomOp(nn.Module):
    Dispatches the forward method to the appropriate backend.
    """

+    def __new__(cls, *args, **kwargs):
+        try:
+            op_name = cls.__name__
+        except AttributeError:
+            raise TypeError(
+                f"Cannot instantiate '{cls.__name__}': its 'name' attribute "
+                f"was not set, possibly because it was not decorated with "
+                f"@CustomOp.register, or it's the CustomOp base class itself."
+            ) from None
+
+        if op_name not in cls.op_registry_oot:
+            op_cls_to_instantiate = cls
+        else:
+            op_cls_to_instantiate = cls.op_registry_oot[op_name]
+            logger.debug("Instantiating custom op: %s using %s", op_name,
+                         str(op_cls_to_instantiate))
+        return super().__new__(op_cls_to_instantiate)

    def __init__(self):
        super().__init__()
        self._forward_method = self.dispatch_forward()
@@ -138,6 +158,7 @@ def default_on() -> bool:
    # - MyOp.enabled()
    # - op_registry["my_op"].enabled()
    op_registry: dict[str, type['CustomOp']] = {}
+    op_registry_oot: dict[str, type['CustomOp']] = {}

    # Decorator to register custom ops.
    @classmethod
@@ -150,3 +171,38 @@ def decorator(op_cls):
            return op_cls

        return decorator

+    # Decorator to register out-of-tree (OOT) custom ops.
+    # For OOT custom ops:
+    #   if an in-tree layer class has an OOT layer registered under its
+    #   name, the OOT layer is instantiated in its place (see __new__).
+    # Example:
+    # - @UnquantizedFusedMoEMethod.register_oot
+    #   class HPUUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod)
+    # or
+    # - @CustomOp.register_oot(name="UnquantizedFusedMoEMethod")
+    @classmethod
+    def register_oot(cls, _decorated_op_cls=None, name: Optional[str] = None):
+
+        def decorator(op_cls):
+            reg_name = name if name is not None else cls.__name__
+            assert reg_name not in cls.op_registry_oot, \
+                f"Duplicate op name: {reg_name}"
+            op_cls.name = reg_name
+            cls.op_registry_oot[reg_name] = op_cls
+            return op_cls
+
+        if _decorated_op_cls is None:
+            # Called with parentheses: @CustomOp.register_oot()
+            # or @CustomOp.register_oot(name="...").
+            # _decorated_op_cls is None, so return the actual
+            # decorator function.
+            return decorator
+        elif isinstance(_decorated_op_cls, type):  # it's a class
+            # Called without parentheses: @CustomOp.register_oot.
+            # The first argument is the decorated class itself, so
+            # apply the decorator to it immediately.
+            return decorator(_decorated_op_cls)
+        else:
+            # Neither None nor a class: unsupported usage.
+            raise TypeError("Decorator can only be applied to classes.")
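Taken together, register_oot and __new__ implement a class-substitution pattern: registering an out-of-tree class under an in-tree name makes ordinary construction of the in-tree class return the replacement. A self-contained toy of the same pattern (simplified; the real CustomOp adds dispatch, logging, and the in-tree op_registry):

class Base:
    # Maps an in-tree class name to its out-of-tree replacement.
    _oot_registry: dict[str, type] = {}

    def __new__(cls, *args, **kwargs):
        # Swap in the registered replacement class, if any.
        target = cls._oot_registry.get(cls.__name__, cls)
        return super().__new__(target)

    @classmethod
    def register_oot(cls, op_cls):
        cls._oot_registry[cls.__name__] = op_cls
        return op_cls


class InTree(Base):
    pass


@InTree.register_oot
class OutOfTree(InTree):
    pass


assert type(InTree()) is OutOfTree  # construction swaps classes transparently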