diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md
index 294c45f87d..4b324ba9d4 100644
--- a/docs/sphinx_doc/source/tutorial/trinity_configs.md
+++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -152,14 +152,22 @@ Defines the model paths and token limits.
 model:
   model_path: ${oc.env:MODEL_PATH}  # MODEL_PATH is an environment variable set in advance
   critic_model_path: ${model.model_path}  # use the value of model.model_path
-  max_response_tokens: 16384
   max_model_len: 20480
+  max_prompt_tokens: 4096
+  max_response_tokens: 16384
+  min_response_tokens: 1
 ```
 
 - `model_path`: Path to the model being trained.
 - `critic_model_path`: Optional path to a separate critic model. If empty, defaults to `model_path`.
-- `max_response_tokens`: Maximum number of tokens allowed in generated responses.
-- `max_model_len`: Maximum number of tokens in a sequence.
+- `max_model_len`: Maximum number of tokens in a sequence. It is recommended to set this value explicitly; if not set, it is inferred from the model configuration.
+- `max_response_tokens`: Maximum number of tokens allowed in generated responses. Only applies to the `chat` and `generate` methods of `InferenceModel`.
+- `max_prompt_tokens`: Maximum number of tokens allowed in prompts. Only applies to the `chat` and `generate` methods of `InferenceModel`.
+- `min_response_tokens`: Minimum number of tokens allowed in generated responses. Only applies to the `chat` and `generate` methods of `InferenceModel`. Defaults to `1` and must be less than `max_response_tokens`.
+
+```{tip}
+If you are using the OpenAI API provided by the Explorer, only `max_model_len` takes effect; the values of `max_response_tokens`, `max_prompt_tokens`, and `min_response_tokens` are ignored. When `max_tokens` is not specified explicitly, each API call generates up to `max_model_len - prompt_length` tokens, so make sure the prompt length stays below `max_model_len` when calling the API.
+```
 
 ---
 
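The token budget documented above can be summarized with a small sketch (illustrative only, using the example values from the YAML snippet; it is not code from this patch):

```python
# Illustrative only: mirrors the documented token budget for the model section.
max_model_len = 20480
max_prompt_tokens = 4096
max_response_tokens = 16384
min_response_tokens = 1

# Documented constraints on the budget.
assert max_prompt_tokens + max_response_tokens <= max_model_len
assert min_response_tokens < max_response_tokens

# For the OpenAI API served by the Explorer, only max_model_len applies:
prompt_length = 1500  # hypothetical prompt length in tokens
max_generated = max_model_len - prompt_length  # cap when max_tokens is not set
print(max_generated)  # 18980
```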
diff --git a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md
index 1710a45151..8219ba88c4 100644
--- a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md
+++ b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md
@@ -144,7 +144,7 @@ monitor:
 
 ---
 
-## 模型配置
+## Model 配置
 
 定义模型路径和 token 限制。
 
@@ -152,20 +152,28 @@ monitor:
 model:
   model_path: ${oc.env:MODEL_PATH}  # MODEL_PATH 是预先设置的环境变量
   critic_model_path: ${model.model_path}  # 使用 model.model_path 的值
-  max_response_tokens: 16384
   max_model_len: 20480
+  max_prompt_tokens: 4096
+  max_response_tokens: 16384
+  min_response_tokens: 1
 ```
 
 - `model_path`: 被训练模型的路径。
 - `critic_model_path`: 可选的独立 critic 模型路径。若为空,则默认为 `model_path`。
-- `max_response_tokens`: 模型生成的回复中允许的最大 token 数。
-- `max_model_len`: 序列中最大 token 数。
+- `max_model_len`: 该模型所支持的单个序列最大 token 数。
+- `max_prompt_tokens`: 输入 prompt 中允许的最大 token 数。仅对 `InferenceModel` 中的 `chat` 和 `generate` 方法生效。
+- `max_response_tokens`: 模型生成的回复中允许的最大 token 数。仅对 `InferenceModel` 中的 `chat` 和 `generate` 方法生效。
+- `min_response_tokens`: 模型生成的回复中允许的最小 token 数。仅对 `InferenceModel` 中的 `chat` 和 `generate` 方法生效。
+
+```{tip}
+如果使用的是 Explorer 提供的 openai API,则只有 `max_model_len` 会生效,而 `max_response_tokens`、`max_prompt_tokens` 和 `min_response_tokens` 的值将被忽略,在没有独立指定 `max_tokens` 时,每次 API 调用将生成最多 `max_model_len - prompt_length` 个 token,因此在使用时请确保 prompt 长度小于 `max_model_len`。
+```
 
 ---
 
-## 集群配置
+## Cluster 配置
 
-定义使用的节点数和每节点的 GPU 数。
+定义使用的集群包含的节点数和每节点的 GPU 数。
 
 ```yaml
 cluster:
@@ -178,7 +186,7 @@ cluster:
 
 ---
 
-## 缓冲区配置
+## Buffer 配置
 
 配置 explorer 和 trainer 使用的数据缓冲区(Buffer)。
 
diff --git a/examples/grpo_gsm8k_trainable_ruler/gsm8k_ruler.yaml b/examples/grpo_gsm8k_trainable_ruler/gsm8k_ruler.yaml
index e5ddbdb394..fbb3d8deb8 100644
--- a/examples/grpo_gsm8k_trainable_ruler/gsm8k_ruler.yaml
+++ b/examples/grpo_gsm8k_trainable_ruler/gsm8k_ruler.yaml
@@ -8,8 +8,8 @@ algorithm:
   repeat_times: 8
 model:
   model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}
-  max_prompt_tokens: 12288
-  max_response_tokens: 12288
+  max_prompt_tokens: 8000
+  max_response_tokens: 8000
   max_model_len: 16000  # slightly smaller than ppo_max_token_len_per_gpu (16384)
 cluster:
   node_num: 1
diff --git a/pyproject.toml b/pyproject.toml
index 248184ce37..e7767ec22b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,8 +22,8 @@ classifiers = [
 requires-python = ">=3.10,<3.13"
 dependencies = [
     "verl==0.5.0",
-    "ray[default]>=2.45.0",
-    "vllm>=0.9.1,<=0.10.0",
+    "ray[default]>=2.48.0",
+    "vllm>=0.9.1,<=0.10.2",
     "tensordict",
     "wandb",
     "omegaconf",
@@ -42,7 +42,7 @@ dependencies = [
     "jsonlines",
     "sortedcontainers",
     "word2number",
-    "transformers<4.54.0",  # TODO: remove when https://github.com/vllm-project/vllm-ascend/issues/2046 is fixed
+    "transformers",
 ]
 
 [project.scripts]
diff --git a/scripts/docker/Dockerfile b/scripts/docker/Dockerfile
index 23f7e594c4..153e7f1f9b 100644
--- a/scripts/docker/Dockerfile
+++ b/scripts/docker/Dockerfile
@@ -5,7 +5,7 @@
 # docker run -it --gpus all --shm-size="64g" --rm -v $PWD:/workspace -v :/data trinity-rft:latest
 
-FROM nvcr.io/nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04
+FROM nvcr.io/nvidia/cuda:12.6.3-cudnn-devel-ubuntu22.04
 
 WORKDIR /workspace
diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py
index d9c54deda5..40aa80d606 100644
--- a/tests/common/vllm_test.py
+++ b/tests/common/vllm_test.py
@@ -1,6 +1,7 @@
 import unittest
 
 import torch
+from openai import BadRequestError
 from parameterized import parameterized_class
 from transformers import AutoTokenizer
 
@@ -91,20 +92,15 @@ def print_debug(*args):
     (
         "tensor_parallel_size",
         "engine_num",
-        "use_v1",
         "repeat_times",
         "enable_history",
         "use_async",
         "max_model_len",
     ),
     [
-        (1, 2, False, 2, True, False, None),
-        (1, 2, False, 2, True, True, 20),
-        (1, 2, False, 2, True, False, 20),
-        (2, 2, False, 1, False, True, None),
-        (2, 2, True, 2, True, False, None),
-        (1, 2, True, 1, False, True, None),
-        (2, 1, True, 3, True, True, None),
+        (2, 2, 2, True, False, None),
+        (1, 2, 1, False, True, None),
+        (2, 1, 3, True, True, None),
     ],
 )
 class ModelWrapperTest(RayUnittestBaseAysnc):
@@ -116,7 +112,6 @@ def setUp(self):
         self.config.model.max_model_len = self.max_model_len
         self.config.explorer.rollout_model.engine_num = self.engine_num
         self.config.explorer.rollout_model.tensor_parallel_size = self.tensor_parallel_size
-        self.config.explorer.rollout_model.use_v1 = self.use_v1
         self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
         self.config.algorithm.repeat_times = self.repeat_times
         self.config.explorer.rollout_model.enable_history = self.enable_history
@@ -222,6 +217,61 @@ async def test_generate(
         self.assertTrue(len(history_experiences) == 0)
 
 
+@parameterized_class(
+    (
+        "max_model_len",
+        "max_prompt_tokens",
+        "max_response_tokens",
+    ),
+    [
+        (20, 19, None),
+        (20, None, 1),
+    ],
+)
+class TestModelLen(RayUnittestBase):
+    def setUp(self):
+        self.config = get_template_config()
+        self.config.mode = "explore"
+        self.config.model.model_path = get_model_path()
+        self.config.model.max_model_len = self.max_model_len
+        self.config.model.max_prompt_tokens = self.max_prompt_tokens
+        self.config.model.max_response_tokens = self.max_response_tokens
+        self.config.explorer.rollout_model.enable_openai_api = True
+        self.config.check_and_update()
+
+        self.engines, self.auxiliary_engines = create_inference_models(self.config)
+        self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm", enable_history=True)
+
+    def test_model_len(self):
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "What's the weather like today?"},
+        ]
+        response = self.model_wrapper.chat(messages)
+        self.assertEqual(len(response), 1)
+        self.assertEqual(len(response[0].tokens), self.max_model_len)
+        exps = self.model_wrapper.extract_experience_from_history()
+        self.assertEqual(len(exps), 1)
+        self.assertEqual(len(exps[0].tokens), self.max_model_len)
+
+        # max_prompt_tokens and max_response_tokens do not work with openai api
+        openai_client = self.model_wrapper.get_openai_client()
+        model_id = openai_client.models.list().data[0].id
+        with self.assertRaises(BadRequestError):
+            # the prompt is longer than max_model_len
+            openai_client.chat.completions.create(model=model_id, messages=messages, n=1)
+        exps = self.model_wrapper.extract_experience_from_history()
+        self.assertEqual(len(exps), 0)
+
+        response = openai_client.chat.completions.create(model=model_id, messages=messages[1:], n=1)
+        self.assertEqual(len(response.choices), 1)
+        print(response.choices[0].message.content)
+        exps = self.model_wrapper.extract_experience_from_history()
+        self.assertEqual(len(exps), 1)
+        # only generate max_model_len - prompt_len tokens
+        self.assertEqual(len(exps[0].tokens), self.max_model_len)
+
+
 class TestAPIServer(RayUnittestBase):
     def setUp(self):
         self.config = get_template_config()
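The new `TestModelLen` test exercises the behavior documented earlier: the OpenAI-compatible endpoint ignores `max_prompt_tokens` and `max_response_tokens`, so callers who want a per-request bound must pass `max_tokens` themselves. A minimal usage sketch, assuming a `ModelWrapper` with `enable_openai_api=True` as in the test (the `max_tokens=16` value is only an illustrative choice):

```python
# Sketch only: bounding the response length per request via the Explorer's OpenAI API.
# Assumes `model_wrapper` is a ModelWrapper set up as in TestModelLen above.
client = model_wrapper.get_openai_client()
model_id = client.models.list().data[0].id

messages = [{"role": "user", "content": "What's the weather like today?"}]
response = client.chat.completions.create(
    model=model_id,
    messages=messages,
    n=1,
    max_tokens=16,  # hypothetical cap; if omitted, up to max_model_len - prompt_length tokens
)
print(response.choices[0].message.content)
```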
diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py
index 1d0dd77fc3..f0612c0bda 100644
--- a/tests/explorer/explorer_test.py
+++ b/tests/explorer/explorer_test.py
@@ -46,7 +46,6 @@ def test_explorer(self):
             ]
         )
         self.config.name = f"explore-eval-{datetime.now().strftime('%Y%m%d%H%M%S')}"
-        self.config.explorer.rollout_model.use_v1 = True
         self.config.check_and_update()
         explore(self.config)
         parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
@@ -64,7 +63,6 @@ class TestExplorerCountdownNoEval(BaseExplorerCase):
     def test_explorer(self):
         self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
         self.config.name = f"explore-no-eval-{datetime.now().strftime('%Y%m%d%H%M%S')}"
-        self.config.explorer.rollout_model.use_v1 = False
         self.config.check_and_update()
         explore(self.config)
         parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
diff --git a/trinity/common/config.py b/trinity/common/config.py
index 9f345568fa..537d371bf6 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -223,11 +223,21 @@ class ModelConfig:
     # source model path
     model_path: str = ""
     critic_model_path: str = ""
+
+    custom_chat_template: Optional[str] = None
+
+    # the total number of tokens the model can handle
     max_model_len: Optional[int] = None
+
+    # Note: the following fields are only for the `chat`/`generate` methods in `InferenceModel`
+    # if you are using openai API, please set them when calling the API.
+
+    # the maximum number of tokens for the prompt
     max_prompt_tokens: Optional[int] = None
+    # the maximum number of tokens for the response
     max_response_tokens: Optional[int] = None
+    # the minimum number of tokens for the response
     min_response_tokens: int = 1
-    custom_chat_template: Optional[str] = None
 
 
 @dataclass
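The reorganized `ModelConfig` fields feed the vLLM sampling parameters, as the `vllm_model.py` hunk later in this diff shows. A condensed sketch of that mapping (illustrative, not the actual implementation; `config` stands for any object carrying the fields above):

```python
# Condensed sketch of how the ModelConfig token fields reach vLLM sampling (see vllm_model.py below).
from vllm import SamplingParams

def default_sampling_params(config) -> SamplingParams:
    return SamplingParams(
        n=1,
        temperature=0.0,
        max_tokens=config.max_response_tokens,             # upper bound on generated tokens
        min_tokens=config.min_response_tokens,             # lower bound on generated tokens
        truncate_prompt_tokens=config.max_prompt_tokens,   # prompt is truncated to this length
    )
```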
@@ -798,6 +808,89 @@ def check_and_set(name, registry, args_attr):
         check_and_set("kl_penalty_fn", KL_FN, "kl_penalty_fn_args")
         check_and_set("entropy_loss_fn", ENTROPY_LOSS_FN, "entropy_loss_fn_args")
 
+    def _check_model(self) -> None:
+        if not self.model.critic_model_path:
+            self.model.critic_model_path = self.model.model_path
+
+        # check max_model_len, max_prompt_tokens, max_response_tokens
+
+        # if all three are set, check if they are valid
+        if (
+            self.model.max_model_len is not None
+            and self.model.max_prompt_tokens is not None
+            and self.model.max_response_tokens is not None
+        ):
+            if (
+                self.model.max_prompt_tokens + self.model.max_response_tokens
+                > self.model.max_model_len
+            ):
+                raise ValueError(
+                    f"`max_prompt_tokens` + `max_response_tokens` ({self.model.max_prompt_tokens} + {self.model.max_response_tokens}) "
+                    f"exceeds `max_model_len` ({self.model.max_model_len}). Please adjust them accordingly."
+                )
+
+        # check max_model_len first
+        if self.model.max_model_len is None:
+            if (
+                self.model.max_prompt_tokens is not None
+                and self.model.max_response_tokens is not None
+            ):
+                self.model.max_model_len = (
+                    self.model.max_prompt_tokens + self.model.max_response_tokens
+                )
+                logger.warning(
+                    f"`max_model_len` is set to {self.model.max_model_len} from `max_prompt_tokens` and `max_response_tokens`."
+                )
+            else:
+                from transformers import AutoConfig, AutoTokenizer
+                from transformers.tokenization_utils_base import LARGE_INTEGER
+
+                tokenizer = AutoTokenizer.from_pretrained(self.model.model_path)
+                config = AutoConfig.from_pretrained(self.model.model_path)
+                max_model_len = min(
+                    getattr(tokenizer, "model_max_length", LARGE_INTEGER),
+                    getattr(config, "max_position_embeddings", LARGE_INTEGER),
+                )
+                if max_model_len >= LARGE_INTEGER:
+                    max_model_len = MAX_MODEL_LEN
+                    logger.warning(
+                        f"Failed to get `max_model_len` from model {self.model.model_path}, use {MAX_MODEL_LEN} instead."
+                    )
+                self.model.max_model_len = max_model_len
+
+        # both max_prompt_tokens and max_response_tokens are None
+        if self.model.max_prompt_tokens is None and self.model.max_response_tokens is None:
+            # default to max_model_len / 2
+            self.model.max_prompt_tokens = self.model.max_model_len // 2
+            self.model.max_response_tokens = self.model.max_model_len - self.model.max_prompt_tokens
+            logger.warning(
+                f"`max_prompt_tokens` and `max_response_tokens` are not set, set to {self.model.max_prompt_tokens} and {self.model.max_response_tokens} respectively."
+            )
+
+        # only max_prompt_tokens is None
+        if self.model.max_prompt_tokens is None and self.model.max_response_tokens is not None:
+            self.model.max_response_tokens = min(
+                self.model.max_response_tokens, self.model.max_model_len - 1
+            )
+            self.model.max_prompt_tokens = self.model.max_model_len - self.model.max_response_tokens
+            logger.warning(
+                f"`max_prompt_tokens` is set to {self.model.max_prompt_tokens}, `max_response_tokens` is set to {self.model.max_response_tokens}."
+            )
+
+        # only max_response_tokens is None
+        if self.model.max_response_tokens is None and self.model.max_prompt_tokens is not None:
+            self.model.max_prompt_tokens = min(
+                self.model.max_prompt_tokens, self.model.max_model_len - 1
+            )
+            self.model.max_response_tokens = self.model.max_model_len - self.model.max_prompt_tokens
+            logger.warning(
+                f"`max_response_tokens` is set to {self.model.max_response_tokens}, `max_prompt_tokens` is set to {self.model.max_prompt_tokens}."
+            )
+
+        if self.model.min_response_tokens >= self.model.max_response_tokens:  # type: ignore [operator]
+            self.model.min_response_tokens = max(self.model.max_response_tokens - 1, 0)  # type: ignore [operator]
+            logger.warning(f"`min_response_tokens` is set to {self.model.min_response_tokens}.")
+
     def __iter__(self):
         """Iterate over configs with each stage applied in order.
 
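When `max_model_len` cannot be derived from the other two limits, `_check_model` infers it from the tokenizer and model config. A standalone mirror of that inference (a sketch, not library code; the fallback value `4096` below is only a placeholder for the `MAX_MODEL_LEN` constant defined elsewhere in `config.py`):

```python
# Standalone mirror of the max_model_len inference used in _check_model (sketch only).
from transformers import AutoConfig, AutoTokenizer
from transformers.tokenization_utils_base import LARGE_INTEGER

MAX_MODEL_LEN = 4096  # placeholder fallback; the real constant lives in trinity/common/config.py

def infer_max_model_len(model_path: str) -> int:
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    config = AutoConfig.from_pretrained(model_path)
    max_model_len = min(
        getattr(tokenizer, "model_max_length", LARGE_INTEGER),
        getattr(config, "max_position_embeddings", LARGE_INTEGER),
    )
    # Some tokenizers report a huge sentinel value when no limit is configured.
    if max_model_len >= LARGE_INTEGER:
        max_model_len = MAX_MODEL_LEN
    return max_model_len
```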
@@ -848,47 +941,27 @@ def check_and_update(self) -> Config:  # noqa: C901
             logger.warning(f"Experiment [{ori_name}] already exists, renamed as {self.name}.")
         os.makedirs(self.checkpoint_job_dir, exist_ok=True)
 
-        # check and update model path
-        if self.explorer is not None:
-            self.explorer.rollout_model.model_path = self.model.model_path
-        if not self.model.critic_model_path:
-            self.model.critic_model_path = self.model.model_path
+        # check model
+        self._check_model()
 
         # check explorer
-        if self.model.max_model_len is None:
-            from transformers import AutoConfig, AutoTokenizer
-            from transformers.tokenization_utils_base import LARGE_INTEGER
-
-            tokenizer = AutoTokenizer.from_pretrained(self.model.model_path)
-            config = AutoConfig.from_pretrained(self.model.model_path)
-            max_model_len = min(
-                getattr(tokenizer, "model_max_length", LARGE_INTEGER),
-                getattr(config, "max_position_embeddings", LARGE_INTEGER),
-            )
-            if max_model_len >= LARGE_INTEGER:
-                max_model_len = MAX_MODEL_LEN
-                logger.warning(
-                    f"Failed to get `max_model_len` from model {self.model.model_path}, use {MAX_MODEL_LEN} instead."
-                )
-            self.model.max_model_len = max_model_len
-        if (
-            self.model.max_prompt_tokens is None
-            or self.model.max_prompt_tokens >= self.model.max_model_len
-        ):
-            self.model.max_prompt_tokens = self.model.max_model_len - 1
-            logger.warning(f"`max_prompt_tokens` is set to {self.model.max_prompt_tokens}.")
-        if (
-            self.model.max_response_tokens is None
-            or self.model.max_response_tokens > self.model.max_model_len
-        ):
-            self.model.max_response_tokens = self.model.max_model_len
-            logger.warning(f"`max_response_tokens` is set to {self.model.max_response_tokens}.")
-        if self.explorer.rollout_model.max_model_len is None:
+        if self.explorer is not None:
+            self.explorer.rollout_model.model_path = self.model.model_path
             self.explorer.rollout_model.max_model_len = self.model.max_model_len
-        if self.explorer.rollout_model.max_prompt_tokens is None:
             self.explorer.rollout_model.max_prompt_tokens = self.model.max_prompt_tokens
-        if self.explorer.rollout_model.max_response_tokens is None:
             self.explorer.rollout_model.max_response_tokens = self.model.max_response_tokens
+            self.explorer.rollout_model.min_response_tokens = self.model.min_response_tokens
+            for aux_model in self.explorer.auxiliary_models:
+                if not aux_model.model_path:
+                    raise ValueError("auxiliary model's model_path is required.")
+                if aux_model.max_model_len is None:
+                    aux_model.max_model_len = self.model.max_model_len
+                if aux_model.max_prompt_tokens is None:
+                    aux_model.max_prompt_tokens = self.model.max_prompt_tokens
+                if aux_model.max_response_tokens is None:
+                    aux_model.max_response_tokens = self.model.max_response_tokens
+                if aux_model.min_response_tokens is None:
+                    aux_model.min_response_tokens = self.model.min_response_tokens
 
         # check synchronizer
         self.synchronizer.ray_namespace = self.ray_namespace
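With this refactor, the token limits only need to be set once under `model`; `check_and_update` propagates them to the rollout model and to auxiliary models that leave them unset. A sketch of the expected post-check state (assuming `config` was loaded from a complete YAML with only `model.max_model_len` set and one auxiliary model configured):

```python
# Sketch: expected effect of check_and_update() after this change (field names taken from the diff).
config.model.max_model_len = 20480
config.model.max_prompt_tokens = None
config.model.max_response_tokens = None

config.check_and_update()

assert config.model.max_prompt_tokens == 10240           # defaulted to max_model_len // 2
assert config.model.max_response_tokens == 10240         # remainder of the budget
aux = config.explorer.auxiliary_models[0]
assert aux.max_model_len == config.model.max_model_len   # inherited because it was unset
```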
+ "This patch requires vllm version >= 0.8.5, <= 0.10.2." ) parser = FlexibleArgumentParser(description="Run the OpenAI API server.") diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index d08384484a..3984453e78 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -7,11 +7,13 @@ import ray import torch import vllm +from packaging.version import parse as parse_version from transformers import AutoProcessor from vllm.sampling_params import RequestOutputKind from trinity.common.config import InferenceModelConfig from trinity.common.experience import Experience +from trinity.common.models.api.vllm_patch import get_vllm_version from trinity.common.models.mm_utils import ( attach_images_to_messages, build_multi_modal_inputs, @@ -21,7 +23,7 @@ from trinity.utils.log import get_logger -# TODO: remove V0 when V1 is stable +# V0 engine is deprecated since vLLM v0.10.2, related code will be removed in the future. class vLLMRolloutModel(InferenceModel): """Wrapper around the vLLM engine to handle async requests. @@ -50,7 +52,7 @@ def __init__( n=1, temperature=0.0, max_tokens=config.max_response_tokens, - min_tokens=1, + min_tokens=config.min_response_tokens, truncate_prompt_tokens=config.max_prompt_tokens, skip_special_tokens=True, include_stop_str_in_output=False, @@ -73,11 +75,14 @@ def __init__( dtype=config.dtype, trust_remote_code=True, task="generate", - disable_log_requests=True, gpu_memory_utilization=config.gpu_memory_utilization, enable_chunked_prefill=config.enable_chunked_prefill, # max_num_batched_tokens=256, # you can further set this parameter to reduce the vllm peak memory usage ) + if get_vllm_version() > parse_version("0.10.0"): + engine_args.enable_log_requests = False + else: + engine_args.disable_log_requests = True self.async_llm = vllm.AsyncLLMEngine.from_engine_args(engine_args) self.processor = None self.tokenizer = None @@ -286,7 +291,8 @@ async def logprobs(self, token_ids: List[int]) -> torch.Tensor: to align with the actual response length. Args: - token_ids (List[int]): The input token ids (seq_length). + token_ids (List[int]): The input token ids (seq_length). Please make sure the length of + it does not exceed `max_model_len - 1`. Returns: A tensor of logprobs (seq_length - 1). 
diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py
index d280a90bd5..26021b2420 100644
--- a/trinity/common/verl_config.py
+++ b/trinity/common/verl_config.py
@@ -409,6 +409,28 @@ def synchronize_config(self, config: Config) -> None:  # noqa: C901
         self.critic.rollout_n = self.actor_rollout_ref.rollout.n
         self.critic.optim.total_training_steps = self.trainer.total_training_steps
 
+        if (
+            self.actor_rollout_ref.actor.ppo_max_token_len_per_gpu  # type: ignore [operator]
+            * self.actor_rollout_ref.actor.ulysses_sequence_parallel_size
+            < config.model.max_model_len
+        ):
+            self.actor_rollout_ref.actor.ppo_max_token_len_per_gpu = math.ceil(
+                config.model.max_model_len  # type: ignore [operator]
+                / self.actor_rollout_ref.actor.ulysses_sequence_parallel_size
+            )
+            logger.warning(
+                f"Warning: actor.ppo_max_token_len_per_gpu is automatically set to {self.actor_rollout_ref.actor.ppo_max_token_len_per_gpu} to match model.max_model_len ({config.model.max_model_len})"
+            )
+        if (
+            self.critic.ppo_max_token_len_per_gpu * self.critic.ulysses_sequence_parallel_size  # type: ignore [operator]
+            < config.model.max_model_len
+        ):
+            self.critic.ppo_max_token_len_per_gpu = math.ceil(
+                config.model.max_model_len / self.critic.ulysses_sequence_parallel_size  # type: ignore [operator]
+            )
+            logger.warning(
+                f"Warning: critic.ppo_max_token_len_per_gpu is automatically set to {self.critic.ppo_max_token_len_per_gpu} to match model.max_model_len ({config.model.max_model_len})"
+            )
         if config.trainer.actor_grad_clip is not None:
             self.actor_rollout_ref.actor.grad_clip = config.trainer.actor_grad_clip
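A worked example of the adjustment above, using illustrative numbers: with `max_model_len = 20480`, `ulysses_sequence_parallel_size = 2`, and a configured `ppo_max_token_len_per_gpu` of 8192, the product 8192 × 2 = 16384 is below 20480, so the per-GPU limit is raised to ceil(20480 / 2) = 10240.

```python
# Worked example of the ppo_max_token_len_per_gpu adjustment (illustrative numbers only).
import math

max_model_len = 20480
ulysses_sequence_parallel_size = 2
ppo_max_token_len_per_gpu = 8192

if ppo_max_token_len_per_gpu * ulysses_sequence_parallel_size < max_model_len:
    ppo_max_token_len_per_gpu = math.ceil(max_model_len / ulysses_sequence_parallel_size)

print(ppo_max_token_len_per_gpu)  # 10240
```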