diff --git a/docs/sphinx_doc/source/tutorial/example_dpo.md b/docs/sphinx_doc/source/tutorial/example_dpo.md
index cb4b978c7d..08a150d477 100644
--- a/docs/sphinx_doc/source/tutorial/example_dpo.md
+++ b/docs/sphinx_doc/source/tutorial/example_dpo.md
@@ -30,11 +30,21 @@ modelscope download --dataset HumanLLMs/Human-Like-DPO-Dataset --local_dir $DATA
huggingface-cli download HumanLLMs/Human-Like-DPO-Dataset --repo-type dataset --local-dir $DATASET_PATH/human_like_dpo_dataset
```
+Below are some data samples in JSONL format:
+```json
+{"prompt":"Oh, I just saw the best meme - have you seen it?","chosen":"\ud83d\ude02 Ah, no I haven't! I'm dying to know, what's the meme about? Is it a funny cat or a ridiculous situation? Spill the beans! \ud83e\udd23","rejected":"I'm an artificial intelligence language model, I don't have personal experiences or opinions. However, I can provide you with information on highly-rated and critically acclaimed films, as well as recommendations based on specific genres or themes. Would you like me to suggest some notable movies or discuss a particular genre of interest?"}
+{"prompt":"Have you tried any new hobbies or activities recently?","chosen":"You know, I've been meaning to try my hand at gardening, but I haven't gotten around to it yet. I've heard it's super relaxing and a great way to get some fresh air. Maybe I'll finally get around to buying some seeds and pots this weekend. What about you? Have you taken up anything new and exciting lately? \ud83c\udf31\ud83d\udc40","rejected":"I'm an artificial intelligence language model, and as such, I don't have personal experiences or engage in physical activities such as dining or cooking. My purpose is to provide information, answer questions, and assist with tasks to the best of my abilities, while maintaining a professional and impartial demeanor. If you have any specific questions or topics related to restaurants or recipes, I'd be happy to provide information or guidance."}
+```
+
More details of dataset downloading are referred to [ModelScope](https://modelscope.cn/docs/datasets/download) or [Huggingface](https://huggingface.co/docs/huggingface_hub/main/en/guides/cli#download-a-dataset-or-a-space).
-Note that the dataset has the keys `prompt`, `chosen` and `rejected`. If not, pass the proper keys to the config.
+Note that this dataset provides the keys `prompt`, `chosen`, and `rejected`. If you use a different dataset, pass the proper keys in the config.
-For SFT, we download the dataset to the local directory `/PATH/TO/SFT_DATASET/`, which usually contains message-based data.
+For SFT, we download the `open-r1/Mixture-of-Thoughts` dataset to the local directory `$DATASET_PATH/Mixture-of-Thoughts`. It contains message-based data; a simplified sample is shown below.
+
+```json
+{"messages": [{"content": "You will be given a competitive programming problem...","role": "user"},{"content": "\n...\n...This approach efficiently combines hashing and dynamic programming to solve the problem within the given constraints.","role": "assistant"}], "num_tokens": 22185, "source": "open-r1/codeforces-cots"}
+```
## Step 2: Setup Configuration
@@ -77,9 +87,22 @@ trainer:
save_interval: 30
```
+`buffer.trainer_input.experience_buffer` specifies the dataset to be used for training, including its name, storage type, path, and format.
+
+- The name `human_like_dpo` is a unique identifier for this dataset configuration; you can use any other name as long as it is unique within the project.
+- The storage type `file` means the dataset is stored on the local filesystem, and `path` points to the local directory containing the dataset. Note that the `file` storage type also accepts a Hugging Face dataset path such as `HumanLLMs/Human-Like-DPO-Dataset`.
+- The format specifies how the data is structured within the dataset. In this case, it is defined as follows:
+
+ - `prompt_type: plaintext` indicates that the prompts are in plain text format.
+ - `prompt_key: prompt` specifies the key in the dataset that contains the user prompts.
+ - `chosen_key: chosen` specifies the key in the dataset that contains the chosen responses.
+ - `rejected_key: rejected` specifies the key in the dataset that contains the rejected responses.
+
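+For illustration, the settings described above correspond to an `experience_buffer` block like the following (a minimal sketch; the path matches the download location from Step 1):
+
+```yaml
+buffer:
+  trainer_input:
+    experience_buffer:
+      name: human_like_dpo
+      storage_type: file
+      path: $DATASET_PATH/human_like_dpo_dataset
+      format:
+        prompt_type: plaintext
+        prompt_key: prompt
+        chosen_key: chosen
+        rejected_key: rejected
+```
+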
+For more configuration options, please refer to the {ref}`Configuration Guide <Configuration Guide>`.
+
### Configuration for SFT
-We set the `algorithm_type` as `sft` to run SFT process. Then we modify the config file [`examples/sft_mot/sft.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/sft_mot/sft.yaml) with the following changes:
+We set `algorithm_type` to `sft` to run the SFT process, and then modify the config file [`examples/sft_mot/sft.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/sft_mot/sft.yaml) with the following changes:
```yaml
project:
@@ -100,7 +123,7 @@ buffer:
experience_buffer:
name:
storage_type: file
- path: /PATH/TO/SFT_DATASET/
+ path: $DATASET_PATH/Mixture-of-Thoughts
split: train
format:
prompt_type: messages
@@ -110,6 +133,8 @@ trainer:
save_interval: 50
```
+Here we set `buffer.trainer_input.experience_buffer.format.prompt_type` to `messages` because the source data is in message format, and we set `buffer.trainer_input.experience_buffer.format.messages_key` to `messages` to specify the key in the dataset that contains the messages, as sketched below.
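+
+Putting the two settings together, the `format` block looks like this (a minimal sketch consistent with the config above):
+
+```yaml
+format:
+  prompt_type: messages
+  messages_key: messages
+```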
+
## Step 3: Run the Experiment
Run DPO process with the following command:
@@ -117,6 +142,7 @@ Run DPO process with the following command:
```shell
trinity run --config examples/dpo_humanlike/dpo.yaml
```
+
or, for SFT:
```shell
diff --git a/docs/sphinx_doc/source/tutorial/example_mix_algo.md b/docs/sphinx_doc/source/tutorial/example_mix_algo.md
index 5df890cd96..05462e23b9 100644
--- a/docs/sphinx_doc/source/tutorial/example_mix_algo.md
+++ b/docs/sphinx_doc/source/tutorial/example_mix_algo.md
@@ -32,7 +32,7 @@ We prompt a powerful LLM to generate responses with the CoT process for some pre
```json
{
"messages": [
- { "role": "system", "content": },
+ { "role": "system", "content": "" },
{ "role": "user", "content": "What is the sum of 4 and 12?" },
{ "role": "assistant", "content": "thinking process...\n16" } ]
},
diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md
index 6b109f261a..5fffc98f0a 100644
--- a/docs/sphinx_doc/source/tutorial/trinity_configs.md
+++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -1,3 +1,4 @@
+(Configuration Guide)=
# Configuration Guide
This section provides a detailed description of the configuration files used in **Trinity-RFT**.
@@ -231,6 +232,7 @@ buffer:
split: test
repeat_times: 1
format:
+      prompt_type: plaintext
prompt_key: 'question'
response_key: 'answer'
rollout_args:
@@ -256,9 +258,6 @@ The configuration for each task dataset is defined as follows:
- `subset_name`: The subset name of the task dataset. Default is `None`.
- `split`: The split of the task dataset. Default is `train`.
- `repeat_times`: The number of rollouts generated for a task. If not set, it will be automatically set to `algorithm.repeat_times` for `taskset`, and `1` for `eval_tasksets`.
-- `format`: Defines keys for prompts and responses in the dataset.
- - `prompt_key`: Specifies which column in the dataset contains the prompt data.
- - `response_key`: Specifies which column in the dataset contains the response data.
- `rollout_args`: The parameters for rollout.
- `temperature`: The temperature for sampling.
- `default_workflow_type`: Type of workflow logic applied to this dataset. If not specified, the `buffer.default_workflow_type` is used.
@@ -277,6 +276,7 @@ buffer:
name: countdown_buffer
storage_type: queue
path: sqlite:///countdown_buffer.db
+ max_read_timeout: 1800
sft_warmup_dataset:
name: warmup_data
@@ -299,10 +299,24 @@ buffer:
- For `queue` storage type, this field is optional. You can specify a SQLite database or JSON file path here to back up the queue data.
- For `file` storage type, the path points to the directory containing the dataset files.
- For `sql` storage type, the path points to the SQLite database file.
+    - `format`: Defines keys for prompts and responses in the dataset (see the sketch after this list).
+      - `prompt_type`: Specifies the type of prompts in the dataset. Currently `plaintext` and `messages` are supported.
+ - `plaintext`: The prompt is in string format.
+ - `messages`: The prompt is organized as a message list.
+ - `prompt_key`: Specifies which column in the dataset contains the user prompt data. Only for `plaintext`.
+ - `response_key`: Specifies which column in the dataset contains the response data. Only for `plaintext`.
+ - `system_prompt_key`: Specifies which column in the dataset contains the system prompt data. Only for `plaintext`.
+ - `system_prompt`: Specifies the system prompt in string format. It has lower priority than `system_prompt_key`. Only for `plaintext`.
+ - `messages_key`: Specifies which column in the dataset contains the messages data. Only for `messages`.
+      - `tools_key`: Specifies which column in the dataset contains the tools data. Supported for both `plaintext` and `messages`; the tool data should be organized as a list of dicts.
+      - `chosen_key`: Specifies which column in the dataset contains the DPO chosen data. Supported for both `plaintext` and `messages`; the data type should be consistent with the prompt type.
+      - `rejected_key`: Similar to `chosen_key`, but specifies which column in the dataset contains the DPO rejected data.
+      - `enable_concatenated_multi_turn`: Enables concatenated multi-turn SFT data preprocessing. Only for `messages`, and only takes effect with the SFT algorithm.
+      - `chat_template`: Specifies the chat template in string format. If not provided, `model.custom_chat_template` is used.
- `max_read_timeout`: The maximum waiting time (in seconds) to read new experience data. If exceeded, an incomplete batch will be returned directly. Only take effect when `storage_type` is `queue`. Default is 1800 seconds (30 minutes).
- `use_priority_queue`: Only take effect when `storage_type` is `queue`. If set to `True`, the queue will be a priority queue, which allows for prioritizing certain experiences over others. Default is `False`.
- `reuse_cooldown_time`: Only take effect when `storage_type` is `queue` and `use_priority_queue` is `True`. If set, it specifies the cooldown time (in seconds) for reusing experiences. If not specified, the default value is `None`, meaning experiences can not be reused.
-- `sft_warmup_dataset`: Optional dataset used for pre-training (SFT warmup).
+- `sft_warmup_dataset`: Optional dataset used for pre-training (SFT warmup). Its configuration is similar to that of `experience_buffer`, but it is used only for SFT.
- `sft_warmup_steps`: Number of steps to use SFT warm-up before RL begins.
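+
+As an illustration, a `format` block for a plain-text DPO dataset might look like the following (a minimal sketch; the key names are the defaults described above, and the `system_prompt` value is a hypothetical example):
+
+```yaml
+format:
+  prompt_type: plaintext
+  prompt_key: prompt
+  chosen_key: chosen
+  rejected_key: rejected
+  system_prompt: "You are a helpful assistant."
+```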
---
diff --git a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md
index c395880a55..a478eb7e18 100644
--- a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md
+++ b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md
@@ -65,7 +65,7 @@ To handle differences in `Task` contents, Trinity-RFT provides a unified `Task`
In the math problem scenario, the `Task` dataset can be a `jsonl` file, where each line contains JSON with `question` and `answer` fields representing the problem description and standard answer, respectively. For example:
-```
+```json
{"question": "1+1=", "answer": "2"}
{"question": "2+2=", "answer": "4"}
...
diff --git a/examples/sppo_gsm8k/gsm8k.yaml b/examples/sppo_gsm8k/gsm8k.yaml
index d72459afe3..f6b08c8ca9 100644
--- a/examples/sppo_gsm8k/gsm8k.yaml
+++ b/examples/sppo_gsm8k/gsm8k.yaml
@@ -20,8 +20,6 @@ cluster:
buffer:
total_steps: 100
batch_size: 96
- max_retry_times: 3
- max_retry_interval: 1
explorer_input:
taskset:
name: gsm8k
diff --git a/tests/buffer/formatter_test.py b/tests/buffer/formatter_test.py
index 9da05e65af..7ddc0f88d1 100644
--- a/tests/buffer/formatter_test.py
+++ b/tests/buffer/formatter_test.py
@@ -37,12 +37,7 @@ def test_sft_messages_formatter(self):
self.assertIn("Hello", sequence)
# test tool
- config = FormatConfig(
- prompt_type=PromptType.MESSAGES,
- messages_key="messages",
- tools_key="tools",
- )
- formatter = FORMATTER.get("sft")(tokenizer=self.tokenizer, format_config=config)
+
sample = {
"messages": [
{
@@ -99,19 +94,60 @@ def test_sft_messages_formatter(self):
}
],
}
+ config = FormatConfig(
+ prompt_type=PromptType.MESSAGES,
+ messages_key="messages",
+ tools_key="tools",
+ enable_concatenated_multi_turn=False,
+ )
+ formatter = FORMATTER.get("sft")(tokenizer=self.tokenizer, format_config=config)
exp = formatter.format(sample)
self.assertIsInstance(exp, Experience)
self.assertIsNotNone(exp.tokens)
self.assertIsNotNone(exp.prompt_length)
self.assertTrue(exp.prompt_length < len(exp.tokens))
+ self.assertIsNotNone(exp.action_mask)
+ self.assertEqual(len(exp.action_mask) + exp.prompt_length, len(exp.tokens))
+ # assert action mask is all true
+ self.assertTrue(all(exp.action_mask.tolist()))
sequence = self.tokenizer.decode(exp.tokens)
self.assertIn("What's the weather like in Beijing today?", sequence)
+ self.assertIn("Let me get the weather for you.", sequence)
self.assertIn(
"The weather in Beijing today is sunny with a temperature of 22°C and humidity at 45%. It's a pleasant day!",
sequence,
)
self.assertIn("get_weather", sequence)
+ config = FormatConfig(
+ prompt_type=PromptType.MESSAGES,
+ messages_key="messages",
+ tools_key="tools",
+ enable_concatenated_multi_turn=True,
+ )
+ formatter = FORMATTER.get("sft")(tokenizer=self.tokenizer, format_config=config)
+ exp = formatter.format(sample)
+ self.assertIsInstance(exp, Experience)
+ self.assertIsNotNone(exp.tokens)
+ self.assertIsNotNone(exp.prompt_length)
+ self.assertTrue(exp.prompt_length < len(exp.tokens))
+ self.assertIsNotNone(exp.action_mask)
+ self.assertEqual(len(exp.action_mask) + exp.prompt_length, len(exp.tokens))
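+        # With concatenated multi-turn enabled, only assistant tokens are unmasked,
+        # so the mask should be mixed (neither all True nor all False).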
+ self.assertTrue(any(exp.action_mask.tolist()) and not all(exp.action_mask.tolist()))
+ prompt = self.tokenizer.decode(exp.tokens[: exp.prompt_length])
+ response = self.tokenizer.decode(exp.tokens[exp.prompt_length :])
+ self.assertIn("What's the weather like in Beijing today?", prompt)
+ self.assertNotIn("Let me get the weather for you.", prompt)
+ self.assertIn("Let me get the weather for you.", response)
+ self.assertNotIn(
+ "The weather in Beijing today is sunny with a temperature of 22°C and humidity at 45%. It's a pleasant day!",
+ prompt,
+ )
+ self.assertIn(
+ "The weather in Beijing today is sunny with a temperature of 22°C and humidity at 45%. It's a pleasant day!",
+ response,
+ )
+
def test_sft_plaintext_formatter(self):
# with system prompt key
config = FormatConfig(
diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py
index a24c21e0b4..158f56108d 100644
--- a/tests/common/vllm_test.py
+++ b/tests/common/vllm_test.py
@@ -40,6 +40,7 @@ def print_debug(*args):
print(*args)
+# Qwen2.5 chat template with the {% generation %} marker
CHAT_TEMPLATE = r"""
{%- if tools %}
{{- '<|im_start|>system\n' }}
@@ -67,19 +68,19 @@ def print_debug(*args):
{%- elif (message.role == "assistant" and not message.tool_calls) %}
{{- '<|im_start|>' + message.role + '\n'}}{% generation %}{{- message.content + '<|im_end|>' + '\n' }}{% endgeneration %}
{%- elif message.role == "assistant" %}
- {{- '<|im_start|>' + message.role }}{% generation %}
+ {{- '<|im_start|>' + message.role + '\n'}}{% generation %}
{%- if message.content %}
- {{- '\n' + message.content }}
+ {{- message.content }}
{%- endif %}
{%- for tool_call in message.tool_calls %}
{%- if tool_call.function is defined %}
{%- set tool_call = tool_call.function %}
{%- endif %}
- {{- '\n\n{"name": "' }}
+ {{- '\n{"name": "' }}
{{- tool_call.name }}
{{- '", "arguments": ' }}
{{- tool_call.arguments | tojson }}
- {{- '}\n' }}
+ {{- '}\n\n' }}
{%- endfor %}
{{- '<|im_end|>\n' }}{% endgeneration %}
{%- elif message.role == "tool" %}
@@ -307,7 +308,7 @@ def test_api(self):
class TestTokenizer(unittest.TestCase):
- def test_assistant_token_mask(self):
+ def test_action_mask(self):
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What's the weather like today?"},
@@ -338,6 +339,80 @@ def test_assistant_token_mask(self):
self.assertTrue(torch.equal(action_mask, action_mask_hf))
self.assertEqual(prompt_length, prompt_length_hf)
+ def test_action_mask_with_tools(self):
+ messages = [
+ {
+ "role": "system",
+ "content": "You are a helpful assistant with access to various tools. Use them when needed to help users.",
+ },
+ {"role": "user", "content": "What's the weather like in Beijing today?"},
+ {
+ "role": "assistant",
+ "content": "Let me get the weather for you.",
+ "tool_calls": [
+ {
+ "id": "call_abc123",
+ "type": "function",
+ "function": {
+ "name": "get_weather",
+ "arguments": '{"location": "Beijing", "unit": "celsius"}',
+ },
+ }
+ ],
+ },
+ {
+ "role": "tool",
+ "content": '{"temperature": 22, "condition": "sunny", "humidity": 45}',
+ "tool_call_id": "call_abc123",
+ },
+ {
+ "role": "assistant",
+ "content": "The weather in Beijing today is sunny with a temperature of 22°C and humidity at 45%. It's a pleasant day!",
+ },
+ ]
+ tools = [
+ {
+ "type": "function",
+ "function": {
+ "name": "get_weather",
+ "description": "Get the current weather in a given location",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "location": {
+ "type": "string",
+ "description": "The city and state, e.g. San Francisco, CA",
+ },
+ "unit": {
+ "type": "string",
+ "enum": ["celsius", "fahrenheit"],
+ "description": "The temperature unit",
+ },
+ },
+ "required": ["location"],
+ },
+ },
+ }
+ ]
+ tokenizer = AutoTokenizer.from_pretrained(get_model_path())
+ token_ids, action_mask, prompt_length = tokenize_and_mask_messages_default(
+ tokenizer=tokenizer,
+ messages=messages,
+ tools=tools,
+ chat_template=CHAT_TEMPLATE,
+ )
+ token_ids_hf, action_mask_hf, prompt_length_hf = tokenize_and_mask_messages_hf(
+ tokenizer=tokenizer,
+ messages=messages,
+ tools=tools,
+ chat_template=CHAT_TEMPLATE,
+ )
+ self.assertEqual(token_ids.shape, token_ids_hf.shape)
+ self.assertEqual(action_mask.shape, action_mask_hf.shape)
+ self.assertTrue(torch.equal(token_ids, token_ids_hf))
+ self.assertTrue(torch.equal(action_mask, action_mask_hf))
+ self.assertEqual(prompt_length, prompt_length_hf)
+
@parameterized_class(
("enable_thinking", "reasoning_parser"),
diff --git a/tests/tools.py b/tests/tools.py
index ccadb80ac5..6490218a77 100644
--- a/tests/tools.py
+++ b/tests/tools.py
@@ -99,6 +99,7 @@ def get_unittest_dataset_config(
prompt_type=PromptType.MESSAGES,
messages_key="messages",
tools_key="tools",
+ enable_concatenated_multi_turn=True,
),
)
elif dataset_name == "dpo":
diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py
index b111d2de91..f0a7e7a185 100644
--- a/trinity/buffer/reader/file_reader.py
+++ b/trinity/buffer/reader/file_reader.py
@@ -157,18 +157,3 @@ def read(self, batch_size: Optional[int] = None) -> List:
task = self.formatter.format(sample)
tasks.append(task)
return tasks
-
-
-class RawDataReader(BaseFileReader):
- def __init__(self, meta: StorageConfig, config: Optional[BufferConfig]):
- self.returned = False
- self.dataset = load_dataset(meta.path, name=meta.subset_name, split=meta.split)
-
- def __len__(self):
- return len(self.dataset)
-
- def read(self, batch_size: Optional[int] = None) -> List:
- if self.returned:
- raise StopIteration
- self.returned = True
- return self.dataset.to_list()
diff --git a/trinity/buffer/schema/formatter.py b/trinity/buffer/schema/formatter.py
index 4304da82ed..bf74589266 100644
--- a/trinity/buffer/schema/formatter.py
+++ b/trinity/buffer/schema/formatter.py
@@ -4,6 +4,7 @@
from trinity.common.config import FormatConfig, StorageConfig
from trinity.common.constants import PromptType
from trinity.common.experience import Experience
+from trinity.common.models.utils import get_action_mask_method
from trinity.common.rewards import REWARD_FUNCTIONS
from trinity.common.workflows import WORKFLOWS, Task
from trinity.utils.registry import Registry
@@ -99,10 +100,14 @@ class SFTFormatter(ExperienceFormatter):
def __init__(self, tokenizer, format_config: FormatConfig):
self.tokenizer = tokenizer
self.prompt_type = format_config.prompt_type
+ self.enable_concatenated_multi_turn = format_config.enable_concatenated_multi_turn
+ self.chat_template = format_config.chat_template or tokenizer.chat_template
# For messages type
if self.prompt_type == PromptType.MESSAGES:
self.messages_key = format_config.messages_key
self.tools_key = format_config.tools_key
+ if format_config.enable_concatenated_multi_turn:
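+                # Choose a masking strategy from the chat template: templates containing
+                # a `{% generation %}` marker use the HF assistant-token mask; otherwise
+                # a per-message tokenization fallback is used.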
+ self.action_mask_method = get_action_mask_method(self.chat_template)
# For plaintext type
elif self.prompt_type == PromptType.PLAINTEXT:
self.prompt_key = format_config.prompt_key
@@ -114,19 +119,52 @@ def __init__(self, tokenizer, format_config: FormatConfig):
raise ValueError(f"Unsupported prompt_type: {self.prompt_type}")
def _messages_to_experience(
- self, messages: List[Dict], tools: Optional[List[Dict]] = None
+ self,
+ messages: List[Dict],
+ tools: Optional[List[Dict]] = None,
) -> Experience:
+ """Convert messages and tools into an Experience object.
+
+ Args:
+ messages (List[Dict]): The list of message dictionaries.
+ tools (Optional[List[Dict]], optional): The list of tool dictionaries. Defaults to None.
+
+ Returns:
+ Experience: The resulting Experience object.
+ """
tokens = self.tokenizer.apply_chat_template(
- messages, tools=tools, add_generation_prompt=False, return_tensors="pt"
- )[0]
- prompt_tokens_ids = self.tokenizer.apply_chat_template(
- messages[:-1], tools=tools, add_generation_prompt=True, return_tensors="pt"
+ messages,
+ tools=tools,
+ add_generation_prompt=False,
+ return_tensors="pt",
+ chat_template=self.chat_template,
)[0]
- return Experience(
- tokens=tokens,
- prompt_length=len(prompt_tokens_ids),
- messages=messages,
- )
+ if self.enable_concatenated_multi_turn:
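+            # Build a per-token action mask so that only assistant tokens are trained
+            # on; only the response part of the mask is stored on the Experience.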
+ token_ids, action_mask, prompt_length = self.action_mask_method(
+ tokenizer=self.tokenizer,
+ messages=messages,
+ tools=tools,
+ chat_template=self.chat_template,
+ )
+ return Experience(
+ tokens=token_ids,
+ action_mask=action_mask[prompt_length:],
+ prompt_length=prompt_length,
+ messages=messages,
+ )
+ else:
+ prompt_tokens_ids = self.tokenizer.apply_chat_template(
+ messages[:-1],
+ tools=tools,
+ add_generation_prompt=True,
+ return_tensors="pt",
+ chat_template=self.chat_template,
+ )[0]
+ return Experience(
+ tokens=tokens,
+ prompt_length=len(prompt_tokens_ids),
+ messages=messages,
+ )
def format(self, sample: Dict) -> Experience:
if self.prompt_type == PromptType.MESSAGES:
@@ -181,6 +219,7 @@ class DPOFormatter(ExperienceFormatter):
def __init__(self, tokenizer, format_config: FormatConfig):
self.tokenizer = tokenizer
self.prompt_type = format_config.prompt_type
+ self.chat_template = format_config.chat_template
if self.prompt_type == PromptType.PLAINTEXT:
self.prompt_key = format_config.prompt_key
self.chosen_key = format_config.chosen_key
@@ -199,13 +238,22 @@ def _messages_to_experience(
self, prompt_messages, chosen_messages, rejected_messages
) -> Experience:
prompt_tokens = self.tokenizer.apply_chat_template(
- prompt_messages, add_generation_prompt=True, return_tensors="pt"
+ prompt_messages,
+ add_generation_prompt=True,
+ return_tensors="pt",
+ chat_template=self.chat_template,
)[0]
chosen_tokens = self.tokenizer.apply_chat_template(
- prompt_messages + chosen_messages, add_generation_prompt=False, return_tensors="pt"
+ prompt_messages + chosen_messages,
+ add_generation_prompt=False,
+ return_tensors="pt",
+ chat_template=self.chat_template,
)[0][len(prompt_tokens) :]
rejected_tokens = self.tokenizer.apply_chat_template(
- prompt_messages + rejected_messages, add_generation_prompt=False, return_tensors="pt"
+ prompt_messages + rejected_messages,
+ add_generation_prompt=False,
+ return_tensors="pt",
+ chat_template=self.chat_template,
)[0][len(prompt_tokens) :]
return Experience(
tokens=prompt_tokens,
diff --git a/trinity/common/config.py b/trinity/common/config.py
index 1382a561dc..052eb97c1b 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -52,6 +52,12 @@ class FormatConfig:
chosen_key: str = "chosen"
rejected_key: str = "rejected"
+ # for multi-turn sft
+ enable_concatenated_multi_turn: bool = False
+
+ # for sft / dpo, if None, use model.custom_chat_template
+ chat_template: Optional[str] = None
+
@dataclass
class GenerationConfig:
@@ -619,14 +625,19 @@ def _check_buffer(self) -> None: # noqa: C901
)
self.buffer.trainer_input.experience_buffer.storage_type = StorageType.QUEUE
- if self.buffer.trainer_input.experience_buffer is not None:
- from trinity.algorithm.algorithm import ALGORITHM_TYPE
+ from trinity.algorithm.algorithm import ALGORITHM_TYPE
+
+ self.buffer.trainer_input.experience_buffer.schema_type = ALGORITHM_TYPE.get(
+ self.algorithm.algorithm_type
+ ).schema
+
+ if self.buffer.trainer_input.experience_buffer.ray_namespace is None:
+ self.buffer.trainer_input.experience_buffer.ray_namespace = self.ray_namespace
- self.buffer.trainer_input.experience_buffer.schema_type = ALGORITHM_TYPE.get(
- self.algorithm.algorithm_type
- ).schema
- if self.buffer.trainer_input.experience_buffer.ray_namespace is None:
- self.buffer.trainer_input.experience_buffer.ray_namespace = self.ray_namespace
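+        # Fall back to the model's custom chat template when the buffer config does not set one.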
+ if self.buffer.trainer_input.experience_buffer.format.chat_template is None:
+ self.buffer.trainer_input.experience_buffer.format.chat_template = (
+ self.model.custom_chat_template
+ )
# create buffer.cache_dir at ///buffer
self.buffer.cache_dir = os.path.abspath(os.path.join(self.checkpoint_job_dir, "buffer"))
@@ -653,6 +664,10 @@ def _check_buffer(self) -> None: # noqa: C901
)
if self.buffer.trainer_input.sft_warmup_dataset.ray_namespace is None:
self.buffer.trainer_input.sft_warmup_dataset.ray_namespace = self.ray_namespace
+ if self.buffer.trainer_input.sft_warmup_dataset.format.chat_template is None:
+ self.buffer.trainer_input.sft_warmup_dataset.format.chat_template = (
+ self.model.custom_chat_template
+ )
# check input/output buffers in pipelines
if self.data_processor.experience_pipeline is not None:
diff --git a/trinity/common/models/utils.py b/trinity/common/models/utils.py
index 49760739eb..67f7262241 100644
--- a/trinity/common/models/utils.py
+++ b/trinity/common/models/utils.py
@@ -2,7 +2,7 @@
import os
import re
from concurrent.futures import ThreadPoolExecutor
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
from torch.distributed._tensor import DTensor, Placement, Shard
@@ -13,14 +13,16 @@
def tokenize_and_mask_messages_hf(
tokenizer: Any,
messages: List[dict],
+ tools: Optional[List[dict]] = None,
chat_template: Optional[str] = None,
) -> Tuple[torch.Tensor, torch.Tensor, int]:
"""Calculate the assistant token mask with `chat_template`.
Args:
tokenizer (Any): The tokenizer.
- chat_template (str): The chat template with `{% generation %}` symbol.
messages (List[dict]): Messages with `role` and `content` fields.
+ tools (Optional[List[dict]]): The list of tool dictionaries.
+        chat_template (Optional[str]): The chat template, which must contain the `{% generation %}` marker.
Returns:
`torch.Tensor`: The token_ids (sequence_length)
@@ -29,6 +31,7 @@ def tokenize_and_mask_messages_hf(
"""
token_dict = tokenizer.apply_chat_template(
messages,
+ tools=tools,
chat_template=chat_template,
add_generation_prompt=False,
padding=False,
@@ -46,14 +49,16 @@ def tokenize_and_mask_messages_hf(
def tokenize_and_mask_messages_default(
tokenizer: Any,
messages: List[dict],
+ tools: Optional[List[dict]] = None,
chat_template: Optional[str] = None,
) -> Tuple[torch.Tensor, torch.Tensor, int]:
"""Calculate the assistant token mask.
Args:
tokenizer (Any): The tokenizer.
- chat_template (str): The chat template with `{% generation %}` symbol.
messages (List[dict]): Messages with `role` and `content` fields.
+ tools (Optional[List[dict]]): The list of tool dictionaries.
+        chat_template (Optional[str]): The chat template to apply.
Returns:
`torch.Tensor`: The token_ids (sequence_length)
@@ -69,6 +74,7 @@ def tokenize_and_mask_messages_default(
tokens = tokenizer.apply_chat_template(
messages,
+ tools=tools,
chat_template=chat_template,
add_generation_prompt=False,
padding=False,
@@ -81,6 +87,7 @@ def tokenize_and_mask_messages_default(
if message["role"] == "assistant":
prompt_token_ids = tokenizer.apply_chat_template(
messages[:idx],
+ tools=tools,
chat_template=chat_template,
add_generation_prompt=True,
padding=False,
@@ -91,6 +98,7 @@ def tokenize_and_mask_messages_default(
prompt_length = prompt_token_ids.shape[1]
prompt_response_token_ids = tokenizer.apply_chat_template(
messages[: idx + 1],
+ tools=tools,
chat_template=chat_template,
add_generation_prompt=False,
padding=False,
@@ -104,6 +112,24 @@ def tokenize_and_mask_messages_default(
return tokens[0], assistant_token_mask, prompt_length
+def get_action_mask_method(chat_template: Optional[str] = None) -> Callable:
+ """Get the action mask method according to the chat template.
+
+ Args:
+        chat_template (Optional[str]): The chat template. If the `{% generation %}` marker is present, the HF tokenizer's `return_assistant_tokens_mask` is used.
+
+ Returns:
+ The action mask method.
+ """
+ if chat_template is None:
+ return tokenize_and_mask_messages_default
+ # check if the chat template contains `{% generation %}` symbol
+ elif re.search(r"\{\%-?\s*generation\s*-?\%\}", chat_template):
+ return tokenize_and_mask_messages_hf
+ else:
+ return tokenize_and_mask_messages_default
+
+
def get_checkpoint_dir_with_step_num(
checkpoint_root_path: str,
trainer_type: str = "verl",
diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py
index 8cd3469d3f..3860cb05c7 100644
--- a/trinity/common/models/vllm_model.py
+++ b/trinity/common/models/vllm_model.py
@@ -1,8 +1,6 @@
-"""A wrapper around the vllm.AsyncEngine to handle async requests.
-"""
+"""A wrapper around the vllm.AsyncEngine to handle async requests."""
import os
-import re
from typing import Any, Dict, List, Optional, Sequence, Union
import aiohttp
@@ -14,10 +12,7 @@
from trinity.common.config import InferenceModelConfig
from trinity.common.experience import Experience
from trinity.common.models.model import InferenceModel
-from trinity.common.models.utils import (
- tokenize_and_mask_messages_default,
- tokenize_and_mask_messages_hf,
-)
+from trinity.common.models.utils import get_action_mask_method
from trinity.utils.log import get_logger
@@ -83,17 +78,7 @@ def __init__(
self.chat_template = None
if self.config.chat_template:
self.chat_template = self.config.chat_template
- if self.chat_template is None or not re.search(
- r"\{\%-?\s*generation\s*-?\%\}", self.chat_template
- ):
- self.logger.warning(
- "The provided chat template does not support `return_assitant_tokens_mask`. "
- "The default assistant mask method will be used, which may cause performance "
- "degradation and lead to incorrect results."
- )
- self.action_mask_method = tokenize_and_mask_messages_default
- else:
- self.action_mask_method = tokenize_and_mask_messages_hf
+ self.action_mask_method = get_action_mask_method(self.chat_template)
self.state_dict_meta = None
self.model_version = 0 # TODO: resume the value from the checkpoint
self.api_server_host = None
@@ -216,14 +201,21 @@ async def _generate_internal(self, prompt: Any, **kwargs) -> Any:
raise RuntimeError("[vLLM] The request is not finished. This should not happen.")
- async def convert_messages_to_experience(self, messages: List[dict]) -> Experience:
+ async def convert_messages_to_experience(
+ self,
+ messages: List[dict],
+ tools: Optional[List[dict]] = None,
+ ) -> Experience:
"""Convert a list of messages into an experience."""
if self.tokenizer is None:
self.tokenizer = await self.async_llm.get_tokenizer()
if self.chat_template is None:
self.chat_template = self.tokenizer.get_chat_template()
token_ids, action_mask, prompt_length = self.action_mask_method(
- self.tokenizer, messages, self.chat_template
+ tokenizer=self.tokenizer,
+ messages=messages,
+ tools=tools,
+ chat_template=self.chat_template,
) # (seq_length, ), (seq_length, )
logprobs = await self.logprobs(token_ids=token_ids.tolist()) # (seq_length - 1,)
return Experience(