diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py
index b35c0ecb31..a24c21e0b4 100644
--- a/tests/common/vllm_test.py
+++ b/tests/common/vllm_test.py
@@ -101,13 +101,23 @@ def print_debug(*args):
 
 
 @parameterized_class(
-    ("tensor_parallel_size", "engine_num", "use_v1", "repeat_times", "enable_history", "use_async"),
+    (
+        "tensor_parallel_size",
+        "engine_num",
+        "use_v1",
+        "repeat_times",
+        "enable_history",
+        "use_async",
+        "max_model_len",
+    ),
     [
-        (1, 2, False, 2, True, False),
-        (2, 2, False, 1, False, True),
-        (2, 2, True, 2, True, False),
-        (1, 2, True, 1, False, True),
-        (2, 1, True, 3, True, True),
+        (1, 2, False, 2, True, False, None),
+        (1, 2, False, 2, True, True, 20),
+        (1, 2, False, 2, True, False, 20),
+        (2, 2, False, 1, False, True, None),
+        (2, 2, True, 2, True, False, None),
+        (1, 2, True, 1, False, True, None),
+        (2, 1, True, 3, True, True, None),
     ],
 )
 class ModelWrapperTest(RayUnittestBaseAysnc):
@@ -116,6 +126,7 @@ def setUp(self):
         self.config = get_template_config()
         self.config.mode = "explore"
         self.config.model.model_path = get_model_path()
+        self.config.model.max_model_len = self.max_model_len
         self.config.explorer.rollout_model.engine_num = self.engine_num
         self.config.explorer.rollout_model.tensor_parallel_size = self.tensor_parallel_size
         self.config.explorer.rollout_model.use_v1 = self.use_v1
@@ -123,6 +134,9 @@ def setUp(self):
         self.config.algorithm.repeat_times = self.repeat_times
         self.config.explorer.rollout_model.enable_history = self.enable_history
         self.config.check_and_update()
+        from pprint import pprint
+
+        pprint(self.config)
         self.engines, self.auxiliary_engines = create_inference_models(self.config)
         self.model_wrapper = ModelWrapper(
             self.engines[0], model_type="vllm_async", enable_history=self.enable_history
         )
@@ -191,7 +205,12 @@ async def test_generate(
                 "content": results[0].response_text,
             }
         )
-        exp = self.model_wrapper.convert_messages_to_experience(messages)
+        if self.max_model_len is not None:
+            with self.assertRaises(ValueError):
+                exp = self.model_wrapper.convert_messages_to_experience(messages)
+            return
+        else:
+            exp = self.model_wrapper.convert_messages_to_experience(messages)
         tokenizer = AutoTokenizer.from_pretrained(self.config.model.model_path)
         result_dict = tokenizer.apply_chat_template(
             messages,
diff --git a/trinity/common/config.py b/trinity/common/config.py
index 982d6049b3..0f624d72f9 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -10,6 +10,7 @@
 
 from trinity.common.constants import (
     EXPLORER_NAME,
+    MAX_MODEL_LEN,
     TRAINER_NAME,
     OpType,
     PromptType,
@@ -178,7 +179,7 @@ class ModelConfig:
     model_path: str = ""
     critic_model_path: str = ""
     max_model_len: Optional[int] = None
-    max_prompt_tokens: Optional[int] = None  # deprecated
+    max_prompt_tokens: Optional[int] = None
     max_response_tokens: Optional[int] = None
     custom_chat_template: Optional[str] = None
 
@@ -203,7 +204,7 @@ class InferenceModelConfig:
     # if not set, use `model.max_model_len`
     max_model_len: Optional[int] = None
     # if not set, use `model.max_prompt_tokens`
-    max_prompt_tokens: Optional[int] = None  # deprecated
+    max_prompt_tokens: Optional[int] = None
     # if not set, use `model.max_response_tokens`
     max_response_tokens: Optional[int] = None
 
@@ -769,24 +770,40 @@ def check_and_update(self) -> None:  # noqa: C901
             self.model.critic_model_path = self.model.model_path
 
         # check explorer
+        if self.model.max_model_len is None:
+            from transformers import AutoConfig, AutoTokenizer
+            from transformers.tokenization_utils_base import LARGE_INTEGER
+
+            tokenizer = AutoTokenizer.from_pretrained(self.model.model_path)
+            config = AutoConfig.from_pretrained(self.model.model_path)
+            max_model_len = min(
+                getattr(tokenizer, "model_max_length", LARGE_INTEGER),
+                getattr(config, "max_position_embeddings", LARGE_INTEGER),
+            )
+            if max_model_len >= LARGE_INTEGER:
+                max_model_len = MAX_MODEL_LEN
+                logger.warning(
+                    f"Failed to get `max_model_len` from model {self.model.model_path}, use {MAX_MODEL_LEN} instead."
+                )
+            self.model.max_model_len = max_model_len
+        if (
+            self.model.max_prompt_tokens is None
+            or self.model.max_prompt_tokens >= self.model.max_model_len
+        ):
+            self.model.max_prompt_tokens = self.model.max_model_len - 1
+            logger.warning(f"`max_prompt_tokens` is set to {self.model.max_prompt_tokens}.")
+        if (
+            self.model.max_response_tokens is None
+            or self.model.max_response_tokens > self.model.max_model_len
+        ):
+            self.model.max_response_tokens = self.model.max_model_len
+            logger.warning(f"`max_response_tokens` is set to {self.model.max_response_tokens}.")
+        if self.explorer.rollout_model.max_model_len is None:
+            self.explorer.rollout_model.max_model_len = self.model.max_model_len
         if self.explorer.rollout_model.max_prompt_tokens is None:
             self.explorer.rollout_model.max_prompt_tokens = self.model.max_prompt_tokens
         if self.explorer.rollout_model.max_response_tokens is None:
             self.explorer.rollout_model.max_response_tokens = self.model.max_response_tokens
-        if self.explorer.rollout_model.max_model_len is None:
-            self.explorer.rollout_model.max_model_len = self.model.max_model_len
-        if (
-            self.explorer.rollout_model.max_model_len is None
-            and self.explorer.rollout_model.max_prompt_tokens is not None
-            and self.explorer.rollout_model.max_response_tokens is not None
-        ):
-            logger.warning(
-                "`max_prompt_tokens` is deprecated, please set `max_model_len` directly."
-            )
-            self.explorer.rollout_model.max_model_len = (
-                self.explorer.rollout_model.max_prompt_tokens
-                + self.explorer.rollout_model.max_response_tokens
-            )
 
         # check synchronizer
         self.synchronizer.ray_namespace = self.ray_namespace
diff --git a/trinity/common/constants.py b/trinity/common/constants.py
index e4428ed16b..7ed45e7d53 100644
--- a/trinity/common/constants.py
+++ b/trinity/common/constants.py
@@ -16,6 +16,11 @@
 PLUGIN_DIRS_ENV_VAR = "TRINITY_PLUGIN_DIRS"
 
 
+# constants
+
+MAX_MODEL_LEN = 4096
+
+
 # enumerate types
 
 
diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py
index c06be47fc7..f169573a47 100644
--- a/trinity/common/models/vllm_model.py
+++ b/trinity/common/models/vllm_model.py
@@ -53,7 +53,7 @@ def __init__(
             temperature=0.0,
             max_tokens=config.max_response_tokens,
             min_tokens=1,
-            truncate_prompt_tokens=config.max_model_len - 1,  # type: ignore [operator]
+            truncate_prompt_tokens=config.max_prompt_tokens,
             skip_special_tokens=True,
             include_stop_str_in_output=False,
             output_kind=RequestOutputKind.FINAL_ONLY,
@@ -100,6 +100,10 @@ def __init__(
         self.api_server_host = None
         self.api_server_port = None
 
+    async def _initialize_tokenizer(self):
+        self.tokenizer = await self.async_llm.get_tokenizer()
+        self.tokenizer.truncation_side = "left"
+
     async def chat(self, messages: List[Dict], **kwargs) -> Sequence[Experience]:
         """Chat with the model with a list of messages in async.
 
@@ -111,7 +115,7 @@ async def chat(self, messages: List[Dict], **kwargs) -> Sequence[Experience]:
             A list of experiences.
         """
         if self.tokenizer is None:
-            self.tokenizer = await self.async_llm.get_tokenizer()
+            await self._initialize_tokenizer()
         if self.chat_template is None:
             self.chat_template = self.tokenizer.get_chat_template()
         if messages[-1]["role"] == "assistant":
@@ -141,7 +145,12 @@ async def generate(self, prompt: str, **kwargs) -> Sequence[Experience]:
         Returns:
             A list of experiences.
         """
-        output = await self._generate_internal(prompt=prompt, **kwargs)
+        if self.tokenizer is None:
+            await self._initialize_tokenizer()
+        token_ids = self.tokenizer(  # type: ignore
+            prompt, truncation=True, max_length=self.config.max_prompt_tokens, return_tensors="pt"
+        )["input_ids"][0].tolist()
+        output = await self._generate_internal(prompt={"prompt_token_ids": token_ids}, **kwargs)
         experiences = [
             Experience(
                 tokens=torch.cat(