34 changes: 30 additions & 4 deletions docs/sphinx_doc/source/tutorial/example_dpo.md
@@ -30,11 +30,21 @@ modelscope download --dataset HumanLLMs/Human-Like-DPO-Dataset --local_dir $DATA
huggingface-cli download HumanLLMs/Human-Like-DPO-Dataset --repo-type dataset --local-dir $DATASET_PATH/human_like_dpo_dataset
```

Below are some data samples in JSONL format:
```json
{"prompt":"Oh, I just saw the best meme - have you seen it?","chosen":"\ud83d\ude02 Ah, no I haven't! I'm dying to know, what's the meme about? Is it a funny cat or a ridiculous situation? Spill the beans! \ud83e\udd23","rejected":"I'm an artificial intelligence language model, I don't have personal experiences or opinions. However, I can provide you with information on highly-rated and critically acclaimed films, as well as recommendations based on specific genres or themes. Would you like me to suggest some notable movies or discuss a particular genre of interest?"}
{"prompt":"Have you tried any new hobbies or activities recently?","chosen":"You know, I've been meaning to try my hand at gardening, but I haven't gotten around to it yet. I've heard it's super relaxing and a great way to get some fresh air. Maybe I'll finally get around to buying some seeds and pots this weekend. What about you? Have you taken up anything new and exciting lately? \ud83c\udf31\ud83d\udc40","rejected":"I'm an artificial intelligence language model, and as such, I don't have personal experiences or engage in physical activities such as dining or cooking. My purpose is to provide information, answer questions, and assist with tasks to the best of my abilities, while maintaining a professional and impartial demeanor. If you have any specific questions or topics related to restaurants or recipes, I'd be happy to provide information or guidance."}
```

For more details on dataset downloading, refer to [ModelScope](https://modelscope.cn/docs/datasets/download) or [Huggingface](https://huggingface.co/docs/huggingface_hub/main/en/guides/cli#download-a-dataset-or-a-space).

Note that the dataset has the keys `prompt`, `chosen` and `rejected`. If not, pass the proper keys to the config.
Note that the dataset has the keys `prompt`, `chosen` and `rejected`. If your dataset uses different keys, pass the proper ones in the config.
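
As a quick sanity check, you can verify these keys with the `datasets` library; a minimal sketch, assuming the dataset's default `train` split:

```python
from datasets import load_dataset

# Minimal sanity check: confirm the preference-learning keys exist.
ds = load_dataset("HumanLLMs/Human-Like-DPO-Dataset", split="train")
assert {"prompt", "chosen", "rejected"} <= set(ds.column_names)
```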

For SFT, we download the dataset to the local directory `/PATH/TO/SFT_DATASET/`, which usually contains message-based data.
For SFT, we download the `open-r1/Mixture-of-Thoughts` dataset to the local directory `$DATASET_PATH/Mixture-of-Thoughts`, which contains message-based data. A simplified sample is shown below:

```json
{"messages": [{"content": "You will be given a competitive programming problem...","role": "user"},{"content": "<think>\n...</think>\n...This approach efficiently combines hashing and dynamic programming to solve the problem within the given constraints.","role": "assistant"}], "num_tokens": 22185, "source": "open-r1/codeforces-cots"}
```
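
Message lists like this are rendered into a single training sequence by the model's chat template. A minimal illustration with Hugging Face `transformers` (the model name here is only an assumption for the example):

```python
from transformers import AutoTokenizer

# Render a message-based sample into one training sequence.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
messages = [
    {"role": "user", "content": "You will be given a competitive programming problem..."},
    {"role": "assistant", "content": "<think>...</think>\n...final solution..."},
]
print(tokenizer.apply_chat_template(messages, tokenize=False))
```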

## Step 2: Setup Configuration

@@ -77,9 +87,22 @@ trainer:
  save_interval: 30
```

`buffer.trainer_input.experience_buffer` specifies the dataset to be used for training, including its name, storage type, path, and format.

- The name `human_like_dpo` is a unique identifier for this dataset configuration; you can use any other name as long as it is unique within the project.
- The storage type `file` means the dataset is stored as files on the local filesystem, and `path` points to the local directory containing the dataset. Note that the `file` storage type also supports Hugging Face dataset paths such as `HumanLLMs/Human-Like-DPO-Dataset`.
- The format specifies how the data is structured within the dataset. In this case, it is defined as follows:

  - `prompt_type: plaintext` indicates that the prompts are in plain text format.
  - `prompt_key: prompt` specifies the key in the dataset that contains the user prompts.
  - `chosen_key: chosen` specifies the key in the dataset that contains the chosen responses.
  - `rejected_key: rejected` specifies the key in the dataset that contains the rejected responses.
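
For reference, the same `format` section can be constructed programmatically. Below is a minimal sketch using the `FormatConfig` dataclass that appears in Trinity-RFT's tests; the import paths and the `PLAINTEXT` member name are assumptions and may differ across versions:

```python
from trinity.common.config import FormatConfig  # assumed import path
from trinity.common.constants import PromptType  # assumed import path

# Programmatic equivalent of the YAML `format` section above (sketch).
dpo_format = FormatConfig(
    prompt_type=PromptType.PLAINTEXT,
    prompt_key="prompt",
    chosen_key="chosen",
    rejected_key="rejected",
)
```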

For more configuration options, please refer to the {ref}`Configuration Guide <Configuration Guide>`.

### Configuration for SFT

We set the `algorithm_type` as `sft` to run SFT process. Then we modify the config file [`examples/sft_mot/sft.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/sft_mot/sft.yaml) with the following changes:
We set `algorithm_type` to `sft` to run the SFT process and then modify the config file [`examples/sft_mot/sft.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/sft_mot/sft.yaml) with the following changes:

```yaml
project: <project_name>
@@ -100,7 +123,7 @@ buffer:
    experience_buffer:
      name: <sft_dataset_name>
      storage_type: file
      path: /PATH/TO/SFT_DATASET/
      path: $DATASET_PATH/Mixture-of-Thoughts
      split: train
      format:
        prompt_type: messages
@@ -110,13 +133,16 @@ trainer:
  save_interval: 50
```

Here we set `buffer.trainer_input.experience_buffer.format.prompt_type` to `messages` because the source data is in message format. We also set `buffer.trainer_input.experience_buffer.format.messages_key` to `messages` to specify the key in the dataset that contains the messages.
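
Programmatically, this corresponds to a messages-type `FormatConfig` like the one used in the formatter tests (a sketch; the same import-path assumptions as in the DPO example apply):

```python
from trinity.common.config import FormatConfig  # assumed import path
from trinity.common.constants import PromptType  # assumed import path

# Messages-format configuration for SFT (sketch): the formatter reads the
# conversation from the column named by `messages_key`.
sft_format = FormatConfig(
    prompt_type=PromptType.MESSAGES,
    messages_key="messages",
)
```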

## Step 3: Run the Experiment

Run the DPO process with the following command:

```shell
trinity run --config examples/dpo_humanlike/dpo.yaml
```

or, for SFT:

```shell
trinity run --config examples/sft_mot/sft.yaml
```
2 changes: 1 addition & 1 deletion docs/sphinx_doc/source/tutorial/example_mix_algo.md
@@ -32,7 +32,7 @@ We prompt a powerful LLM to generate responses with the CoT process for some pre
```json
{
  "messages": [
    { "role": "system", "content": <system_prompt> },
    { "role": "system", "content": "<system_prompt>" },
    { "role": "user", "content": "What is the sum of 4 and 12?" },
    { "role": "assistant", "content": "<think>thinking process...</think>\n<answer>16</answer>" } ]
},
22 changes: 18 additions & 4 deletions docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -1,3 +1,4 @@
(Configuration Guide)=
# Configuration Guide

This section provides a detailed description of the configuration files used in **Trinity-RFT**.
@@ -231,6 +232,7 @@ buffer:
      split: test
      repeat_times: 1
      format:
        prompt_type: plaintext
        prompt_key: 'question'
        response_key: 'answer'
      rollout_args:
@@ -256,9 +258,6 @@ The configuration for each task dataset is defined as follows:
- `subset_name`: The subset name of the task dataset. Default is `None`.
- `split`: The split of the task dataset. Default is `train`.
- `repeat_times`: The number of rollouts generated for a task. If not set, it will be automatically set to `algorithm.repeat_times` for `taskset`, and `1` for `eval_tasksets`.
- `format`: Defines keys for prompts and responses in the dataset.
  - `prompt_key`: Specifies which column in the dataset contains the prompt data.
  - `response_key`: Specifies which column in the dataset contains the response data.
- `rollout_args`: The parameters for rollout.
  - `temperature`: The temperature for sampling.
- `default_workflow_type`: Type of workflow logic applied to this dataset. If not specified, the `buffer.default_workflow_type` is used.
@@ -277,6 +276,7 @@ buffer:
      name: countdown_buffer
      storage_type: queue
      path: sqlite:///countdown_buffer.db
      max_read_timeout: 1800

    sft_warmup_dataset:
      name: warmup_data
@@ -299,10 +299,24 @@ buffer:
- For `queue` storage type, this field is optional. You can specify a SQLite database or JSON file path here to back up the queue data.
- For `file` storage type, the path points to the directory containing the dataset files.
- For `sql` storage type, the path points to the SQLite database file.
- `format`: Defines keys for prompts and responses in the dataset (illustrative samples for both prompt types are sketched after this list).
  - `prompt_type`: Specifies the type of prompts in the dataset. Currently `plaintext` and `messages` are supported.
    - `plaintext`: The prompt is in string format.
    - `messages`: The prompt is organized as a message list.
  - `prompt_key`: Specifies which column in the dataset contains the user prompt data. Only for `plaintext`.
  - `response_key`: Specifies which column in the dataset contains the response data. Only for `plaintext`.
  - `system_prompt_key`: Specifies which column in the dataset contains the system prompt data. Only for `plaintext`.
  - `system_prompt`: Specifies the system prompt in string format. It has lower priority than `system_prompt_key`. Only for `plaintext`.
  - `messages_key`: Specifies which column in the dataset contains the messages data. Only for `messages`.
  - `tools_key`: Specifies which column in the dataset contains the tools data. Supported for both `plaintext` and `messages`; the tool data should be organized as a list of dicts.
  - `chosen_key`: Specifies which column in the dataset contains the DPO chosen data. Supported for both `plaintext` and `messages`; the data type should be consistent with the prompt type.
  - `rejected_key`: Similar to `chosen_key`, but it specifies which column in the dataset contains the DPO rejected data.
  - `enable_concatenated_multi_turn`: Enables concatenated multi-turn SFT data preprocessing. Only for `messages`, and only takes effect with the SFT algorithm.
  - `chat_template`: Specifies the chat template in string format. If not provided, `model.custom_chat_template` is used.
- `max_read_timeout`: The maximum waiting time (in seconds) to read new experience data. If exceeded, an incomplete batch is returned directly. Only takes effect when `storage_type` is `queue`. Default is 1800 seconds (30 minutes).
- `use_priority_queue`: Only takes effect when `storage_type` is `queue`. If set to `True`, the queue becomes a priority queue, which allows prioritizing certain experiences over others. Default is `False`.
- `reuse_cooldown_time`: Only takes effect when `storage_type` is `queue` and `use_priority_queue` is `True`. If set, it specifies the cooldown time (in seconds) for reusing experiences. If not specified, the default value is `None`, meaning experiences cannot be reused.
- `sft_warmup_dataset`: Optional dataset used for pre-training (SFT warmup).
- `sft_warmup_dataset`: Optional dataset used for pre-training (SFT warmup). Its configuration is similar to `experience_buffer`, but it is used only for SFT.
- `sft_warmup_steps`: Number of steps to use SFT warm-up before RL begins.
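
To make these keys concrete, here are two illustrative (hypothetical) samples, one per prompt type, with comments marking which `format` key selects each field:

```python
# Illustrative (hypothetical) samples for the two prompt types.

plaintext_sample = {              # prompt_type: plaintext
    "question": "1+1=",           # selected by prompt_key
    "answer": "2",                # selected by response_key
    "chosen": "2",                # selected by chosen_key (DPO)
    "rejected": "3",              # selected by rejected_key (DPO)
}

messages_sample = {               # prompt_type: messages
    "messages": [                 # selected by messages_key
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is 1+1?"},
        {"role": "assistant", "content": "2"},
    ],
}
```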

---
@@ -65,7 +65,7 @@ To handle differences in `Task` contents, Trinity-RFT provides a unified `Task`

In the math problem scenario, the `Task` dataset can be a `jsonl` file, where each line contains JSON with `question` and `answer` fields representing the problem description and standard answer, respectively. For example:

```
```json
{"question": "1+1=", "answer": "2"}
{"question": "2+2=", "answer": "4"}
...
2 changes: 0 additions & 2 deletions examples/sppo_gsm8k/gsm8k.yaml
@@ -20,8 +20,6 @@ cluster:
buffer:
  total_steps: 100
  batch_size: 96
  max_retry_times: 3
  max_retry_interval: 1
  explorer_input:
    taskset:
      name: gsm8k
48 changes: 42 additions & 6 deletions tests/buffer/formatter_test.py
@@ -37,12 +37,7 @@ def test_sft_messages_formatter(self):
self.assertIn("Hello", sequence)

# test tool
config = FormatConfig(
prompt_type=PromptType.MESSAGES,
messages_key="messages",
tools_key="tools",
)
formatter = FORMATTER.get("sft")(tokenizer=self.tokenizer, format_config=config)

sample = {
"messages": [
{
Expand Down Expand Up @@ -99,19 +94,60 @@ def test_sft_messages_formatter(self):
}
],
}
config = FormatConfig(
prompt_type=PromptType.MESSAGES,
messages_key="messages",
tools_key="tools",
enable_concatenated_multi_turn=False,
)
formatter = FORMATTER.get("sft")(tokenizer=self.tokenizer, format_config=config)
exp = formatter.format(sample)
self.assertIsInstance(exp, Experience)
self.assertIsNotNone(exp.tokens)
self.assertIsNotNone(exp.prompt_length)
self.assertTrue(exp.prompt_length < len(exp.tokens))
self.assertIsNotNone(exp.action_mask)
self.assertEqual(len(exp.action_mask) + exp.prompt_length, len(exp.tokens))
# assert action mask is all true
self.assertTrue(all(exp.action_mask.tolist()))
sequence = self.tokenizer.decode(exp.tokens)
self.assertIn("What's the weather like in Beijing today?", sequence)
self.assertIn("Let me get the weather for you.", sequence)
self.assertIn(
"The weather in Beijing today is sunny with a temperature of 22°C and humidity at 45%. It's a pleasant day!",
sequence,
)
self.assertIn("get_weather", sequence)

config = FormatConfig(
prompt_type=PromptType.MESSAGES,
messages_key="messages",
tools_key="tools",
enable_concatenated_multi_turn=True,
)
formatter = FORMATTER.get("sft")(tokenizer=self.tokenizer, format_config=config)
exp = formatter.format(sample)
self.assertIsInstance(exp, Experience)
self.assertIsNotNone(exp.tokens)
self.assertIsNotNone(exp.prompt_length)
self.assertTrue(exp.prompt_length < len(exp.tokens))
self.assertIsNotNone(exp.action_mask)
self.assertEqual(len(exp.action_mask) + exp.prompt_length, len(exp.tokens))
self.assertTrue(any(exp.action_mask.tolist()) and not all(exp.action_mask.tolist()))
prompt = self.tokenizer.decode(exp.tokens[: exp.prompt_length])
response = self.tokenizer.decode(exp.tokens[exp.prompt_length :])
self.assertIn("What's the weather like in Beijing today?", prompt)
self.assertNotIn("Let me get the weather for you.", prompt)
self.assertIn("Let me get the weather for you.", response)
self.assertNotIn(
"The weather in Beijing today is sunny with a temperature of 22°C and humidity at 45%. It's a pleasant day!",
prompt,
)
self.assertIn(
"The weather in Beijing today is sunny with a temperature of 22°C and humidity at 45%. It's a pleasant day!",
response,
)

def test_sft_plaintext_formatter(self):
# with system prompt key
config = FormatConfig(
85 changes: 80 additions & 5 deletions tests/common/vllm_test.py
@@ -40,6 +40,7 @@ def print_debug(*args):
print(*args)


# Qwen2.5 chat template with {% generation %} mark
CHAT_TEMPLATE = r"""
{%- if tools %}
{{- '<|im_start|>system\n' }}
@@ -67,19 +68,19 @@ def print_debug(*args):
{%- elif (message.role == "assistant" and not message.tool_calls) %}
{{- '<|im_start|>' + message.role + '\n'}}{% generation %}{{- message.content + '<|im_end|>' + '\n' }}{% endgeneration %}
{%- elif message.role == "assistant" %}
{{- '<|im_start|>' + message.role }}{% generation %}
{{- '<|im_start|>' + message.role + '\n'}}{% generation %}
{%- if message.content %}
{{- '\n' + message.content }}
{{- message.content }}
{%- endif %}
{%- for tool_call in message.tool_calls %}
{%- if tool_call.function is defined %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- '\n<tool_call>\n{"name": "' }}
{{- '<tool_call>\n{"name": "' }}
{{- tool_call.name }}
{{- '", "arguments": ' }}
{{- tool_call.arguments | tojson }}
{{- '}\n</tool_call>' }}
{{- '}\n</tool_call>\n' }}
{%- endfor %}
{{- '<|im_end|>\n' }}{% endgeneration %}
{%- elif message.role == "tool" %}
@@ -307,7 +308,7 @@ def test_api(self):


class TestTokenizer(unittest.TestCase):
    def test_assistant_token_mask(self):
    def test_action_mask(self):
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What's the weather like today?"},
@@ -338,6 +339,80 @@ def test_assistant_token_mask(self):
        self.assertTrue(torch.equal(action_mask, action_mask_hf))
        self.assertEqual(prompt_length, prompt_length_hf)

    def test_action_mask_with_tools(self):
        messages = [
            {
                "role": "system",
                "content": "You are a helpful assistant with access to various tools. Use them when needed to help users.",
            },
            {"role": "user", "content": "What's the weather like in Beijing today?"},
            {
                "role": "assistant",
                "content": "Let me get the weather for you.",
                "tool_calls": [
                    {
                        "id": "call_abc123",
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": '{"location": "Beijing", "unit": "celsius"}',
                        },
                    }
                ],
            },
            {
                "role": "tool",
                "content": '{"temperature": 22, "condition": "sunny", "humidity": 45}',
                "tool_call_id": "call_abc123",
            },
            {
                "role": "assistant",
                "content": "The weather in Beijing today is sunny with a temperature of 22°C and humidity at 45%. It's a pleasant day!",
            },
        ]
        tools = [
            {
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "description": "Get the current weather in a given location",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "location": {
                                "type": "string",
                                "description": "The city and state, e.g. San Francisco, CA",
                            },
                            "unit": {
                                "type": "string",
                                "enum": ["celsius", "fahrenheit"],
                                "description": "The temperature unit",
                            },
                        },
                        "required": ["location"],
                    },
                },
            }
        ]
        tokenizer = AutoTokenizer.from_pretrained(get_model_path())
        token_ids, action_mask, prompt_length = tokenize_and_mask_messages_default(
            tokenizer=tokenizer,
            messages=messages,
            tools=tools,
            chat_template=CHAT_TEMPLATE,
        )
        token_ids_hf, action_mask_hf, prompt_length_hf = tokenize_and_mask_messages_hf(
            tokenizer=tokenizer,
            messages=messages,
            tools=tools,
            chat_template=CHAT_TEMPLATE,
        )
        self.assertEqual(token_ids.shape, token_ids_hf.shape)
        self.assertEqual(action_mask.shape, action_mask_hf.shape)
        self.assertTrue(torch.equal(token_ids, token_ids_hf))
        self.assertTrue(torch.equal(action_mask, action_mask_hf))
        self.assertEqual(prompt_length, prompt_length_hf)


@parameterized_class(
    ("enable_thinking", "reasoning_parser"),
1 change: 1 addition & 0 deletions tests/tools.py
@@ -99,6 +99,7 @@ def get_unittest_dataset_config(
                prompt_type=PromptType.MESSAGES,
                messages_key="messages",
                tools_key="tools",
                enable_concatenated_multi_turn=True,
            ),
        )
    elif dataset_name == "dpo":
15 changes: 0 additions & 15 deletions trinity/buffer/reader/file_reader.py
@@ -157,18 +157,3 @@ def read(self, batch_size: Optional[int] = None) -> List:
                task = self.formatter.format(sample)
                tasks.append(task)
        return tasks


class RawDataReader(BaseFileReader):
    def __init__(self, meta: StorageConfig, config: Optional[BufferConfig]):
        self.returned = False
        self.dataset = load_dataset(meta.path, name=meta.subset_name, split=meta.split)

    def __len__(self):
        return len(self.dataset)

    def read(self, batch_size: Optional[int] = None) -> List:
        if self.returned:
            raise StopIteration
        self.returned = True
        return self.dataset.to_list()