33 changes: 26 additions & 7 deletions tests/common/vllm_test.py
@@ -101,13 +101,23 @@ def print_debug(*args):


@parameterized_class(
("tensor_parallel_size", "engine_num", "use_v1", "repeat_times", "enable_history", "use_async"),
(
"tensor_parallel_size",
"engine_num",
"use_v1",
"repeat_times",
"enable_history",
"use_async",
"max_model_len",
),
[
(1, 2, False, 2, True, False),
(2, 2, False, 1, False, True),
(2, 2, True, 2, True, False),
(1, 2, True, 1, False, True),
(2, 1, True, 3, True, True),
(1, 2, False, 2, True, False, None),
(1, 2, False, 2, True, True, 20),
(1, 2, False, 2, True, False, 20),
(2, 2, False, 1, False, True, None),
(2, 2, True, 2, True, False, None),
(1, 2, True, 1, False, True, None),
(2, 1, True, 3, True, True, None),
],
)
class ModelWrapperTest(RayUnittestBaseAysnc):
@@ -116,13 +126,17 @@ def setUp(self):
self.config = get_template_config()
self.config.mode = "explore"
self.config.model.model_path = get_model_path()
self.config.model.max_model_len = self.max_model_len
self.config.explorer.rollout_model.engine_num = self.engine_num
self.config.explorer.rollout_model.tensor_parallel_size = self.tensor_parallel_size
self.config.explorer.rollout_model.use_v1 = self.use_v1
self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
self.config.algorithm.repeat_times = self.repeat_times
self.config.explorer.rollout_model.enable_history = self.enable_history
self.config.check_and_update()
from pprint import pprint

pprint(self.config)
self.engines, self.auxiliary_engines = create_inference_models(self.config)
self.model_wrapper = ModelWrapper(
self.engines[0], model_type="vllm_async", enable_history=self.enable_history
@@ -191,7 +205,12 @@ async def test_generate(
"content": results[0].response_text,
}
)
exp = self.model_wrapper.convert_messages_to_experience(messages)
if self.max_model_len is not None:
with self.assertRaises(ValueError):
exp = self.model_wrapper.convert_messages_to_experience(messages)
return
else:
exp = self.model_wrapper.convert_messages_to_experience(messages)
tokenizer = AutoTokenizer.from_pretrained(self.config.model.model_path)
result_dict = tokenizer.apply_chat_template(
messages,
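Note: the new `max_model_len` column in the parameterization above drives the added `assertRaises(ValueError)` branch: with a context window of only 20 tokens, `convert_messages_to_experience` is expected to reject the conversation rather than truncate it silently. The sketch below uses hypothetical helper names and is not the project code; where exactly the error is raised is not shown in this diff.

```python
# Hypothetical sketch of the length guard the new test cases exercise;
# `check_context_window` is illustrative only and assumes a Hugging Face
# tokenizer is already loaded.
def check_context_window(messages, tokenizer, max_model_len: int) -> list:
    token_ids = tokenizer.apply_chat_template(messages, tokenize=True)
    if len(token_ids) > max_model_len:
        raise ValueError(
            f"conversation is {len(token_ids)} tokens, "
            f"exceeding max_model_len={max_model_len}"
        )
    return token_ids

# With max_model_len=20, as in the new parameterizations, almost any real chat
# exchange trips the guard, which is what assertRaises(ValueError) verifies.
```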
49 changes: 33 additions & 16 deletions trinity/common/config.py
@@ -10,6 +10,7 @@

from trinity.common.constants import (
EXPLORER_NAME,
MAX_MODEL_LEN,
TRAINER_NAME,
OpType,
PromptType,
@@ -178,7 +179,7 @@ class ModelConfig:
model_path: str = ""
critic_model_path: str = ""
max_model_len: Optional[int] = None
max_prompt_tokens: Optional[int] = None # deprecated
max_prompt_tokens: Optional[int] = None
max_response_tokens: Optional[int] = None
custom_chat_template: Optional[str] = None

@@ -203,7 +204,7 @@ class InferenceModelConfig:
# if not set, use `model.max_model_len`
max_model_len: Optional[int] = None
# if not set, use `model.max_prompt_tokens`
max_prompt_tokens: Optional[int] = None # deprecated
max_prompt_tokens: Optional[int] = None
# if not set, use `model.max_response_tokens`
max_response_tokens: Optional[int] = None

@@ -769,24 +770,40 @@ def check_and_update(self) -> None: # noqa: C901
self.model.critic_model_path = self.model.model_path

# check explorer
if self.model.max_model_len is None:
from transformers import AutoConfig, AutoTokenizer
from transformers.tokenization_utils_base import LARGE_INTEGER

tokenizer = AutoTokenizer.from_pretrained(self.model.model_path)
config = AutoConfig.from_pretrained(self.model.model_path)
max_model_len = min(
getattr(tokenizer, "model_max_length", LARGE_INTEGER),
getattr(config, "max_position_embeddings", LARGE_INTEGER),
)
if max_model_len >= LARGE_INTEGER:
max_model_len = MAX_MODEL_LEN
logger.warning(
f"Failed to get `max_model_len` from model {self.model.model_path}, use {MAX_MODEL_LEN} instead."
)
self.model.max_model_len = max_model_len
if (
self.model.max_prompt_tokens is None
or self.model.max_prompt_tokens >= self.model.max_model_len
):
self.model.max_prompt_tokens = self.model.max_model_len - 1
logger.warning(f"`max_prompt_tokens` is set to {self.model.max_prompt_tokens}.")
if (
self.model.max_response_tokens is None
or self.model.max_response_tokens > self.model.max_model_len
):
self.model.max_response_tokens = self.model.max_model_len
logger.warning(f"`max_response_tokens` is set to {self.model.max_response_tokens}.")
if self.explorer.rollout_model.max_model_len is None:
self.explorer.rollout_model.max_model_len = self.model.max_model_len
if self.explorer.rollout_model.max_prompt_tokens is None:
self.explorer.rollout_model.max_prompt_tokens = self.model.max_prompt_tokens
if self.explorer.rollout_model.max_response_tokens is None:
self.explorer.rollout_model.max_response_tokens = self.model.max_response_tokens
if self.explorer.rollout_model.max_model_len is None:
self.explorer.rollout_model.max_model_len = self.model.max_model_len
if (
self.explorer.rollout_model.max_model_len is None
and self.explorer.rollout_model.max_prompt_tokens is not None
and self.explorer.rollout_model.max_response_tokens is not None
):
logger.warning(
"`max_prompt_tokens` is deprecated, please set `max_model_len` directly."
)
self.explorer.rollout_model.max_model_len = (
self.explorer.rollout_model.max_prompt_tokens
+ self.explorer.rollout_model.max_response_tokens
)

# check synchronizer
self.synchronizer.ray_namespace = self.ray_namespace
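In short, the hunk above derives `model.max_model_len` from the tokenizer's `model_max_length` and the model config's `max_position_embeddings`, falls back to the new `MAX_MODEL_LEN` constant (4096) when neither gives a usable value, and then clamps `max_prompt_tokens` to `max_model_len - 1` and `max_response_tokens` to `max_model_len`. A rough standalone sketch of that resolution order (hypothetical helper, not the project API):

```python
# Rough sketch of the fallback order above, isolated from the Config classes.
# `resolve_max_model_len` is a hypothetical helper for illustration only.
from typing import Optional

LARGE_INTEGER = int(1e20)  # transformers' "effectively unset" sentinel is of this magnitude
MAX_MODEL_LEN = 4096       # default added in trinity/common/constants.py

def resolve_max_model_len(
    tokenizer_max_length: Optional[int],
    max_position_embeddings: Optional[int],
) -> int:
    candidate = min(
        tokenizer_max_length or LARGE_INTEGER,
        max_position_embeddings or LARGE_INTEGER,
    )
    # Neither the tokenizer nor the model config gave a sane limit: use the constant.
    return MAX_MODEL_LEN if candidate >= LARGE_INTEGER else candidate

max_model_len = resolve_max_model_len(None, 32768)  # -> 32768
max_prompt_tokens = max_model_len - 1                # mirrors the `max_model_len - 1` clamp
max_response_tokens = max_model_len                  # response alone may fill the window
```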
5 changes: 5 additions & 0 deletions trinity/common/constants.py
@@ -16,6 +16,11 @@
PLUGIN_DIRS_ENV_VAR = "TRINITY_PLUGIN_DIRS"


# constants

MAX_MODEL_LEN = 4096


# enumerate types


15 changes: 12 additions & 3 deletions trinity/common/models/vllm_model.py
@@ -53,7 +53,7 @@ def __init__(
temperature=0.0,
max_tokens=config.max_response_tokens,
min_tokens=1,
truncate_prompt_tokens=config.max_model_len - 1, # type: ignore [operator]
truncate_prompt_tokens=config.max_prompt_tokens,
skip_special_tokens=True,
include_stop_str_in_output=False,
output_kind=RequestOutputKind.FINAL_ONLY,
@@ -100,6 +100,10 @@ def __init__(
self.api_server_host = None
self.api_server_port = None

async def _initialize_tokenizer(self):
self.tokenizer = await self.async_llm.get_tokenizer()
self.tokenizer.truncation_side = "left"

async def chat(self, messages: List[Dict], **kwargs) -> Sequence[Experience]:
"""Chat with the model with a list of messages in async.

@@ -111,7 +115,7 @@ async def chat(self, messages: List[Dict], **kwargs) -> Sequence[Experience]:
A list of experiences.
"""
if self.tokenizer is None:
self.tokenizer = await self.async_llm.get_tokenizer()
await self._initialize_tokenizer()
if self.chat_template is None:
self.chat_template = self.tokenizer.get_chat_template()
if messages[-1]["role"] == "assistant":
@@ -141,7 +145,12 @@ async def generate(self, prompt: str, **kwargs) -> Sequence[Experience]:
Returns:
A list of experiences.
"""
output = await self._generate_internal(prompt=prompt, **kwargs)
if self.tokenizer is None:
await self._initialize_tokenizer()
token_ids = self.tokenizer( # type: ignore
prompt, truncation=True, max_length=self.config.max_prompt_tokens, return_tensors="pt"
)["input_ids"][0].tolist()
output = await self._generate_internal(prompt={"prompt_token_ids": token_ids}, **kwargs)
experiences = [
Experience(
tokens=torch.cat(
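With this change, `generate` tokenizes the prompt itself with left-side truncation to `max_prompt_tokens` and passes `prompt_token_ids` to the engine, while the default sampling parameters now truncate to `max_prompt_tokens` rather than `max_model_len - 1`. A standalone sketch of the truncation step, assuming a Hugging Face tokenizer; the checkpoint name below is only a placeholder:

```python
# Standalone sketch of the left-truncation step added to `generate` above.
# The checkpoint name is a placeholder; any Hugging Face tokenizer behaves the same way.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer.truncation_side = "left"  # drop the oldest tokens, keep the most recent ones

max_prompt_tokens = 8
prompt = "A fairly long prompt that will not fit into the configured token budget."
token_ids = tokenizer(
    prompt, truncation=True, max_length=max_prompt_tokens, return_tensors="pt"
)["input_ids"][0].tolist()

assert len(token_ids) <= max_prompt_tokens
# The truncated ids are then submitted as {"prompt_token_ids": token_ids}, so the
# engine never receives a prompt longer than `max_prompt_tokens`.
```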