[Core] enable out-of-tree model register (vllm-project#3871)
youkaichao authored Apr 7, 2024
1 parent 9e462d8 commit 1a3a0b9
Showing 6 changed files with 148 additions and 2 deletions.
5 changes: 4 additions & 1 deletion .buildkite/test-pipeline.yaml
@@ -34,7 +34,10 @@ steps:
command: pytest -v -s engine tokenization test_sequence.py test_config.py

- label: Entrypoints Test
command: pytest -v -s entrypoints
commands:
  # these tests have to be separated, because each one will allocate all possible GPU memory
- pytest -v -s entrypoints --ignore=entrypoints/test_server_oot_registration.py
- pytest -v -s entrypoints/test_server_oot_registration.py

- label: Examples Test
working_dir: "/vllm-workspace/examples"
27 changes: 27 additions & 0 deletions docs/source/models/adding_model.rst
@@ -21,6 +21,8 @@ This document provides a high-level guide on integrating a `HuggingFace Transformers`_ model into vLLM.
Start by forking our `GitHub`_ repository and then :ref:`build it from source <build_from_source>`.
This gives you the ability to modify the codebase and test your model.

.. tip::
    If you don't want to fork the repository and modify vLLM's codebase, please refer to the "Out-of-Tree Model Integration" section below.

1. Bring your model code
------------------------
@@ -94,3 +96,28 @@ This method should load the weights from the HuggingFace's checkpoint file and a
----------------------

Finally, include your :code:`*ForCausalLM` class in `vllm/model_executor/models/__init__.py <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/__init__.py>`_ and register it to the :code:`_MODEL_REGISTRY` in `vllm/model_executor/model_loader.py <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/model_loader.py>`_.
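
For reference, the in-tree registry shown further down in this diff (:code:`_MODELS` in :code:`vllm/model_executor/models/__init__.py`) maps an architecture name, as reported in the model's HuggingFace config, to a :code:`(module name, class name)` tuple. A minimal sketch of such an entry, using a hypothetical :code:`your_model` module:

.. code-block:: python

    _MODELS = {
        # existing entries elided; hypothetical new entry below
        "YourModelForCausalLM": ("your_model", "YourModelForCausalLM"),
    }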

6. Out-of-Tree Model Integration
--------------------------------------------

We also provide a way to integrate a model without modifying the vLLM codebase. Steps 2, 3, and 4 are still required, but you can skip steps 1 and 5.

Just add the following lines in your code:

.. code-block:: python

    from vllm import ModelRegistry
    from your_code import YourModelForCausalLM

    ModelRegistry.register_model("YourModelForCausalLM", YourModelForCausalLM)

If you are running the API server with :code:`python -m vllm.entrypoints.openai.api_server args`, you can wrap the entrypoint with the following code:

.. code-block:: python

    from vllm import ModelRegistry
    from your_code import YourModelForCausalLM

    ModelRegistry.register_model("YourModelForCausalLM", YourModelForCausalLM)

    import runpy
    runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__')

Save the above code in a file and run it with :code:`python your_file.py args`.
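
As a hedged sketch of what :code:`YourModelForCausalLM` could look like, the tests added in this commit take the shortcut of subclassing an existing in-tree model; the class below is illustrative only and simply reuses the parent implementation:

.. code-block:: python

    import torch

    from vllm.model_executor.models.opt import OPTForCausalLM
    from vllm.model_executor.sampling_metadata import SamplingMetadata


    class YourModelForCausalLM(OPTForCausalLM):
        """Illustrative out-of-tree model built on top of an in-tree one."""

        def compute_logits(self, hidden_states: torch.Tensor,
                           sampling_metadata: SamplingMetadata) -> torch.Tensor:
            # delegate to the parent; a real out-of-tree model would implement
            # its own forward pass and weight loading as described in steps 2-4
            return super().compute_logits(hidden_states, sampling_metadata)

Combined with the :code:`register_model` call above, vLLM will load this class whenever the model's config reports the :code:`YourModelForCausalLM` architecture.
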
66 changes: 66 additions & 0 deletions tests/entrypoints/test_server_oot_registration.py
@@ -0,0 +1,66 @@
import multiprocessing
import sys
import time

import torch
from openai import OpenAI, OpenAIError

from vllm import ModelRegistry
from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.utils import get_open_port


class MyOPTForCausalLM(OPTForCausalLM):

def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
# this dummy model always predicts the first token
logits = super().compute_logits(hidden_states, sampling_metadata)
logits.zero_()
logits[:, 0] += 1.0
return logits


def server_function(port):
# register our dummy model
ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)
sys.argv = ["placeholder.py"] + \
("--model facebook/opt-125m --dtype"
f" float32 --api-key token-abc123 --port {port}").split()
import runpy
runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__')


def test_oot_registration_for_api_server():
port = get_open_port()
server = multiprocessing.Process(target=server_function, args=(port, ))
server.start()
client = OpenAI(
base_url=f"http://localhost:{port}/v1",
api_key="token-abc123",
)
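    # the server runs in a separate process and may take a while to start up;
    # retry the request until the connection succeeds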
while True:
try:
completion = client.chat.completions.create(
model="facebook/opt-125m",
messages=[{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "Hello!"
}],
temperature=0,
)
break
except OpenAIError as e:
if "Connection error" in str(e):
time.sleep(3)
else:
raise e
server.kill()
generated_text = completion.choices[0].message.content
# make sure only the first token is generated
rest = generated_text.replace("<s>", "")
assert rest == ""
32 changes: 32 additions & 0 deletions tests/models/test_oot_registration.py
@@ -0,0 +1,32 @@
import torch

from vllm import LLM, ModelRegistry, SamplingParams
from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata


class MyOPTForCausalLM(OPTForCausalLM):

def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
# this dummy model always predicts the first token
logits = super().compute_logits(hidden_states, sampling_metadata)
logits.zero_()
logits[:, 0] += 1.0
return logits


def test_oot_registration():
# register our dummy model
ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)
prompts = ["Hello, my name is", "The text does not matter"]
sampling_params = SamplingParams(temperature=0)
llm = LLM(model="facebook/opt-125m")
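    # token id 0 decodes to the token that the dummy model always predicts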
first_token = llm.get_tokenizer().decode(0)
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
generated_text = output.outputs[0].text
# make sure only the first token is generated
rest = generated_text.replace(first_token, "")
assert rest == ""
2 changes: 2 additions & 0 deletions vllm/__init__.py
@@ -5,13 +5,15 @@
from vllm.engine.llm_engine import LLMEngine
from vllm.engine.ray_utils import initialize_ray_cluster
from vllm.entrypoints.llm import LLM
from vllm.model_executor.models import ModelRegistry
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import SamplingParams

__version__ = "0.4.0.post1"

__all__ = [
"LLM",
"ModelRegistry",
"SamplingParams",
"RequestOutput",
"CompletionOutput",
18 changes: 17 additions & 1 deletion vllm/model_executor/models/__init__.py
@@ -1,5 +1,5 @@
import importlib
from typing import List, Optional, Type
from typing import Dict, List, Optional, Type

import torch.nn as nn

@@ -55,6 +55,10 @@
"XverseForCausalLM": ("xverse", "XverseForCausalLM"),
}

# Architecture name -> model class, for out-of-tree models registered at runtime.
_OOT_MODELS: Dict[str, Type[nn.Module]] = {}

# Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS = []

@@ -74,6 +78,8 @@ class ModelRegistry:

@staticmethod
def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
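        # out-of-tree models registered at runtime take precedence over
        # in-tree models with the same architecture name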
if model_arch in _OOT_MODELS:
return _OOT_MODELS[model_arch]
if model_arch not in _MODELS:
return None
if is_hip():
@@ -95,6 +101,16 @@ def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
def get_supported_archs() -> List[str]:
return list(_MODELS.keys())

@staticmethod
def register_model(model_arch: str, model_cls: Type[nn.Module]):
if model_arch in _MODELS:
logger.warning(
f"Model architecture {model_arch} is already registered, "
"and will be overwritten by the new model "
f"class {model_cls.__name__}.")
global _OOT_MODELS
_OOT_MODELS[model_arch] = model_cls


__all__ = [
"ModelRegistry",
