14 changes: 11 additions & 3 deletions docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -152,14 +152,22 @@ Defines the model paths and token limits.
model:
model_path: ${oc.env:MODEL_PATH} # MODEL_PATH is an environment variable set in advance
critic_model_path: ${model.model_path} # use the value of model.model_path
max_response_tokens: 16384
max_model_len: 20480
max_prompt_tokens: 4096
max_response_tokens: 16384
min_response_tokens: 1
```

- `model_path`: Path to the model being trained.
- `critic_model_path`: Optional path to a separate critic model. If empty, defaults to `model_path`.
- `max_response_tokens`: Maximum number of tokens allowed in generated responses.
- `max_model_len`: Maximum number of tokens in a sequence.
- `max_model_len`: Maximum number of tokens in a sequence. It is recommended to set this value manually. If not set, it will be inferred from the model configuration.
- `max_response_tokens`: Maximum number of tokens allowed in generated responses. Applies only to the `chat` and `generate` methods of `InferenceModel`.
- `max_prompt_tokens`: Maximum number of tokens allowed in prompts. Applies only to the `chat` and `generate` methods of `InferenceModel`.
- `min_response_tokens`: Minimum number of tokens allowed in generated responses. Applies only to the `chat` and `generate` methods of `InferenceModel`. Defaults to `1` and must be less than `max_response_tokens`; see the quick check below.
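
Taken together, these limits have to fit within the sequence budget. The snippet below is only an illustrative sanity check using the example values above; Trinity-RFT itself may enforce different or additional constraints:

```python
# Illustrative only: these values mirror the example configuration above.
max_model_len = 20480
max_prompt_tokens = 4096
max_response_tokens = 16384
min_response_tokens = 1

# The minimum response length must stay below the maximum response length.
assert min_response_tokens < max_response_tokens

# Prompt and response budgets together should not exceed the sequence budget.
assert max_prompt_tokens + max_response_tokens <= max_model_len
```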

```{tip}
If you are using the OpenAI API provided by Explorer, only `max_model_len` takes effect; the values of `max_response_tokens`, `max_prompt_tokens`, and `min_response_tokens` are ignored. When `max_tokens` is not specified explicitly, each API call generates up to `max_model_len - prompt_length` tokens, so please ensure that the prompt length is less than `max_model_len` when using the API (see the sketch below).
```
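
A minimal sketch of how this budget plays out when calling the OpenAI API served by Explorer. The base URL, API key, and prompt-length estimate below are placeholders; passing `max_tokens` explicitly is just one way to keep a request within `max_model_len`:

```python
# Sketch only: the endpoint, API key, and token estimate are placeholders.
from openai import OpenAI

MAX_MODEL_LEN = 20480  # should match model.max_model_len in the YAML above

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
model_id = client.models.list().data[0].id

messages = [{"role": "user", "content": "What's the weather like today?"}]

# Very rough prompt-length estimate; use the model's tokenizer for an accurate count.
prompt_tokens_estimate = 64

# Cap the response explicitly so prompt + response stays within max_model_len.
# If max_tokens is omitted, the server fills the remaining
# max_model_len - prompt_length budget instead.
response = client.chat.completions.create(
    model=model_id,
    messages=messages,
    max_tokens=MAX_MODEL_LEN - prompt_tokens_estimate,
)
print(response.choices[0].message.content)
```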

---

22 changes: 15 additions & 7 deletions docs/sphinx_doc/source_zh/tutorial/trinity_configs.md
@@ -144,28 +144,36 @@ monitor:

---

## 模型配置
## Model 配置

定义模型路径和 token 限制。

```yaml
model:
model_path: ${oc.env:MODEL_PATH} # MODEL_PATH 是预先设置的环境变量
critic_model_path: ${model.model_path} # 使用 model.model_path 的值
max_response_tokens: 16384
max_model_len: 20480
max_prompt_tokens: 4096
max_response_tokens: 16384
min_response_tokens: 1
```

- `model_path`: 被训练模型的路径。
- `critic_model_path`: 可选的独立 critic 模型路径。若为空,则默认为 `model_path`。
- `max_response_tokens`: 模型生成的回复中允许的最大 token 数。
- `max_model_len`: 序列中最大 token 数。
- `max_model_len`: 该模型所支持的单个序列最大 token 数。
- `max_prompt_tokens`: 输入 prompt 中允许的最大 token 数。仅对 `InferenceModel` 中的 `chat` 和 `generate` 方法生效。
- `max_response_tokens`: 模型生成的回复中允许的最大 token 数。仅对 `InferenceModel` 中的 `chat` 和 `generate` 方法生效。
- `min_response_tokens`: 模型生成的回复中允许的最小 token 数。仅对 `InferenceModel` 中的 `chat` 和 `generate` 方法生效。

```{tip}
如果使用的是 Explorer 提供的 openai API,则只有 `max_model_len` 会生效,而 `max_response_tokens`、`max_prompt_tokens` 和 `min_response_tokens` 的值将被忽略。在没有独立指定 `max_tokens` 时,每次 API 调用将最多生成 `max_model_len - prompt_length` 个 token,因此使用时请确保 prompt 长度小于 `max_model_len`。
```

---

## 集群配置
## Cluster 配置

定义使用的节点数和每节点的 GPU 数。
定义使用的集群包含的节点数和每节点的 GPU 数。

```yaml
cluster:
@@ -178,7 +186,7 @@

---

## 缓冲区配置
## Buffer 配置

配置 explorer 和 trainer 使用的数据缓冲区(Buffer)。

4 changes: 2 additions & 2 deletions examples/grpo_gsm8k_trainable_ruler/gsm8k_ruler.yaml
@@ -8,8 +8,8 @@ algorithm:
repeat_times: 8
model:
model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}
max_prompt_tokens: 12288
max_response_tokens: 12288
max_prompt_tokens: 8000
max_response_tokens: 8000
max_model_len: 16000 # slightly smaller than ppo_max_token_len_per_gpu (16384)
cluster:
node_num: 1
6 changes: 3 additions & 3 deletions pyproject.toml
@@ -22,8 +22,8 @@ classifiers = [
requires-python = ">=3.10,<3.13"
dependencies = [
"verl==0.5.0",
"ray[default]>=2.45.0",
"vllm>=0.9.1,<=0.10.0",
"ray[default]>=2.48.0",
"vllm>=0.9.1,<=0.10.2",
"tensordict",
"wandb",
"omegaconf",
@@ -42,7 +42,7 @@ dependencies = [
"jsonlines",
"sortedcontainers",
"word2number",
"transformers<4.54.0", # TODO: remove when https://github.com/vllm-project/vllm-ascend/issues/2046 is fixed
"transformers",
]

[project.scripts]
2 changes: 1 addition & 1 deletion scripts/docker/Dockerfile
@@ -5,7 +5,7 @@
# docker run -it --gpus all --shm-size="64g" --rm -v $PWD:/workspace -v <root_path_of_data_and_checkpoints>:/data trinity-rft:latest


FROM nvcr.io/nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04
FROM nvcr.io/nvidia/cuda:12.6.3-cudnn-devel-ubuntu22.04

WORKDIR /workspace

68 changes: 59 additions & 9 deletions tests/common/vllm_test.py
@@ -1,6 +1,7 @@
import unittest

import torch
from openai import BadRequestError
from parameterized import parameterized_class
from transformers import AutoTokenizer

@@ -91,20 +92,15 @@ def print_debug(*args):
(
"tensor_parallel_size",
"engine_num",
"use_v1",
"repeat_times",
"enable_history",
"use_async",
"max_model_len",
),
[
(1, 2, False, 2, True, False, None),
(1, 2, False, 2, True, True, 20),
(1, 2, False, 2, True, False, 20),
(2, 2, False, 1, False, True, None),
(2, 2, True, 2, True, False, None),
(1, 2, True, 1, False, True, None),
(2, 1, True, 3, True, True, None),
(2, 2, 2, True, False, None),
(1, 2, 1, False, True, None),
(2, 1, 3, True, True, None),
],
)
class ModelWrapperTest(RayUnittestBaseAysnc):
@@ -116,7 +112,6 @@ def setUp(self):
self.config.model.max_model_len = self.max_model_len
self.config.explorer.rollout_model.engine_num = self.engine_num
self.config.explorer.rollout_model.tensor_parallel_size = self.tensor_parallel_size
self.config.explorer.rollout_model.use_v1 = self.use_v1
self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
self.config.algorithm.repeat_times = self.repeat_times
self.config.explorer.rollout_model.enable_history = self.enable_history
@@ -222,6 +217,61 @@ async def test_generate(
self.assertTrue(len(history_experiences) == 0)


@parameterized_class(
(
"max_model_len",
"max_prompt_tokens",
"max_response_tokens",
),
[
(20, 19, None),
(20, None, 1),
],
)
class TestModelLen(RayUnittestBase):
def setUp(self):
self.config = get_template_config()
self.config.mode = "explore"
self.config.model.model_path = get_model_path()
self.config.model.max_model_len = self.max_model_len
self.config.model.max_prompt_tokens = self.max_prompt_tokens
self.config.model.max_response_tokens = self.max_response_tokens
self.config.explorer.rollout_model.enable_openai_api = True
self.config.check_and_update()

self.engines, self.auxiliary_engines = create_inference_models(self.config)
self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm", enable_history=True)

def test_model_len(self):
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What's the weather like today?"},
]
response = self.model_wrapper.chat(messages)
self.assertEqual(len(response), 1)
self.assertEqual(len(response[0].tokens), self.max_model_len)
exps = self.model_wrapper.extract_experience_from_history()
self.assertEqual(len(exps), 1)
self.assertEqual(len(exps[0].tokens), self.max_model_len)

# max_prompt_tokens and max_response_tokens do not work with openai api
openai_client = self.model_wrapper.get_openai_client()
model_id = openai_client.models.list().data[0].id
with self.assertRaises(BadRequestError):
# the prompt is longer than max_model_len
openai_client.chat.completions.create(model=model_id, messages=messages, n=1)
exps = self.model_wrapper.extract_experience_from_history()
self.assertEqual(len(exps), 0)

response = openai_client.chat.completions.create(model=model_id, messages=messages[1:], n=1)
self.assertEqual(len(response.choices), 1)
print(response.choices[0].message.content)
exps = self.model_wrapper.extract_experience_from_history()
self.assertEqual(len(exps), 1)
# only generate max_model_len - prompt_len tokens
self.assertEqual(len(exps[0].tokens), self.max_model_len)


class TestAPIServer(RayUnittestBase):
def setUp(self):
self.config = get_template_config()
2 changes: 0 additions & 2 deletions tests/explorer/explorer_test.py
@@ -46,7 +46,6 @@ def test_explorer(self):
]
)
self.config.name = f"explore-eval-{datetime.now().strftime('%Y%m%d%H%M%S')}"
self.config.explorer.rollout_model.use_v1 = True
self.config.check_and_update()
explore(self.config)
parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
@@ -64,7 +63,6 @@ class TestExplorerCountdownNoEval(BaseExplorerCase):
def test_explorer(self):
self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
self.config.name = f"explore-no-eval-{datetime.now().strftime('%Y%m%d%H%M%S')}"
self.config.explorer.rollout_model.use_v1 = False
self.config.check_and_update()
explore(self.config)
parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))