Merged
2 changes: 0 additions & 2 deletions benchmark/config/countdown-template.yaml
@@ -46,8 +46,6 @@ buffer:
   priority_fn: linear_decay
   decay: 0.1
   sft_warmup_steps: 0
-  max_retry_times: 3
-  max_retry_interval: 1
 explorer:
   runner_num: 32
   max_timeout: 900
2 changes: 0 additions & 2 deletions benchmark/config/gsm8k-template.yaml
@@ -51,8 +51,6 @@ buffer:
   priority_fn: linear_decay
   decay: 0.1
   sft_warmup_steps: 0
-  max_retry_times: 3
-  max_retry_interval: 1
 explorer:
   runner_per_model: 8
   max_timeout: 900
4 changes: 2 additions & 2 deletions docs/sphinx_doc/source/_templates/versions.html
@@ -2,7 +2,7 @@
 <div class="rst-versions" data-toggle="rst-versions" role="note" aria-label="versions">
   <span class="rst-current-version" data-toggle="rst-current-version">
     <span class="fa fa-book"> Other Versions</span>
-    v: {{ current_version.name }}
+    <b>{{ current_version.name }}</b>
     <span class="fa fa-caret-down"></span>
   </span>
   <div class="rst-other-versions">
@@ -18,7 +18,7 @@
 <dl>
   <dt>Branches</dt>
   {%- for item in versions.branches %}
-  <dd><a href="{{ item.url }}">{{ item.name }}</a> <b>(latest)</b></dd>
+  <dd><b><a href="{{ item.url }}">{{ item.name }}</a> (latest)</b></dd>
   {%- endfor %}
 </dl>
 {%- endif %}
2 changes: 1 addition & 1 deletion docs/sphinx_doc/source/tutorial/example_mix_algo.md
@@ -54,7 +54,7 @@ class MIXAlgorithm(AlgorithmType):
     use_reference: bool = True
     compute_advantage_in_trainer: bool = False
     can_balance_batch: bool = True
-    schema: type = ExperienceModel
+    schema: str = "experience"
 
     @classmethod
     def default_config(cls) -> Dict:
2 changes: 0 additions & 2 deletions docs/sphinx_doc/source/tutorial/example_step_wise.md
@@ -107,8 +107,6 @@ buffer:
   total_epochs: 20
   batch_size: 16
   train_batch_size: 7680 # here: batch_size * repeat_times * max_env_steps
-  max_retry_times: 3
-  max_retry_interval: 1
   explorer_input:
     taskset:
       name: alfworld
3 changes: 1 addition & 2 deletions docs/sphinx_doc/source/tutorial/faq.md
@@ -120,7 +120,7 @@ from sqlalchemy import create_engine
 from sqlalchemy.exc import OperationalError
 from sqlalchemy.orm import sessionmaker
 from sqlalchemy.pool import NullPool
-from trinity.common.schema import ExperienceModel
+from trinity.common.schema.sql_schema import ExperienceModel
 
 engine = create_engine(buffer.trainer_input.experience_buffer.path)
 session = sessionmaker(bind=engine)
@@ -129,7 +129,6 @@ sess = session()
 MAX_EXPERIENCES = 4
 experiences = (
     sess.query(ExperienceModel)
-    .with_for_update()
     .limit(MAX_EXPERIENCES)
     .all()
 )
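The read pattern in the updated FAQ snippet can be exercised end to end with a stand-in model, since `ExperienceModel` itself ships inside Trinity-RFT (`trinity.common.schema.sql_schema`). The table name, columns, and in-memory database below are illustrative only:

```python
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()


class ExperienceStub(Base):
    """Illustrative stand-in for Trinity-RFT's ExperienceModel."""

    __tablename__ = "experiences"
    id = Column(Integer, primary_key=True)
    payload = Column(String)


# In Trinity-RFT the engine URL would come from
# buffer.trainer_input.experience_buffer.path; an in-memory DB is used here.
engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
Session = sessionmaker(bind=engine)
sess = Session()

# Seed a few fake experiences so the query has something to return.
sess.add_all([ExperienceStub(payload=f"exp-{i}") for i in range(10)])
sess.commit()

# Plain read without .with_for_update(), matching the updated snippet.
MAX_EXPERIENCES = 4
experiences = (
    sess.query(ExperienceStub)
    .limit(MAX_EXPERIENCES)
    .all()
)
print(len(experiences))  # 4
```

Note the design point behind the diff: SQLite has no real `SELECT ... FOR UPDATE` row locking, so a plain `limit(...).all()` read is the portable form.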
83 changes: 35 additions & 48 deletions docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -35,7 +35,7 @@ synchronizer:
   # Model weight synchronization settings
   ...
 monitor:
-  # Monitoring configurations (e.g., WandB or TensorBoard)
+  # Monitoring configurations (e.g., WandB, TensorBoard or MLFlow)
   ...
 service:
   # Services to use
@@ -48,10 +48,12 @@ log:
   ...
 ```
 
-Each of these sections will be explained in detail below.
+Each of these sections will be explained in detail below. For additional details about specific parameters not covered here, please refer to the [source code](https://github.com/modelscope/Trinity-RFT/blob/main/trinity/common/config.py).
 
-```{note}
-For additional details about specific parameters not covered here, please refer to the [source code](https://github.com/modelscope/Trinity-RFT/blob/main/trinity/common/config.py).
+```{tip}
+Trinity-RFT uses [OmegaConf](https://omegaconf.readthedocs.io/en/latest/) to load YAML configuration files.
+It supports advanced features such as [variable interpolation](https://omegaconf.readthedocs.io/en/latest/usage.html#variable-interpolation) and [environment variable substitution](https://omegaconf.readthedocs.io/en/latest/custom_resolvers.html#oc-env).
+You can use these features to simplify your configuration.
 ```
 
 ---
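As a quick illustration of the OmegaConf features mentioned in the tip, a config fragment can pull a value from the environment and reuse another config value by reference (the keys shown are examples from this document, the values are placeholders):

```yaml
model:
  model_path: ${oc.env:MODEL_PATH}        # substituted from the MODEL_PATH environment variable
  critic_model_path: ${model.model_path}  # interpolates another key's resolved value
```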
Expand All @@ -64,7 +66,7 @@ These are general settings that apply to the entire experiment.
project: Trinity-RFT
name: example
mode: both
checkpoint_root_dir: /PATH/TO/CHECKPOINT
checkpoint_root_dir: ${oc.env:CHECKPOINT_ROOT_DIR} # CHECKPOINT_ROOT_DIR is an environment variable set in advance
```

- `project`: The name of the project.
@@ -115,13 +117,25 @@ Used to log training metrics during execution.
 ```yaml
 monitor:
   monitor_type: wandb
+  monitor_args:
+    base_url: http://localhost:8080
+    api_key: your_api_key
   enable_ray_timeline: False
 ```
 
 - `monitor_type`: Type of monitoring system. Options:
   - `wandb`: Logs to [Weights & Biases](https://docs.wandb.ai/quickstart/). Requires logging in and setting `WANDB_API_KEY`. Project and run names match the `project` and `name` fields in global configs.
   - `tensorboard`: Logs to [TensorBoard](https://www.tensorflow.org/tensorboard). Files are saved under `<checkpoint_root_dir>/<project>/<name>/monitor/tensorboard`.
-- `enable_ray_timeline`: Whether to export the ray timeline. If set to `True`, a `timeline.json` file will be exported to `<checkpoint_root_dir>/<project>/<name>/monitor`. You can view the timeline file in Chrome at [chrome://tracing](chrome://tracing).
+  - `mlflow`: Logs to [MLFlow](https://mlflow.org/). If [MLFlow authentication](https://mlflow.org/docs/latest/ml/auth/) is set up, set `MLFLOW_TRACKING_USERNAME` and `MLFLOW_TRACKING_PASSWORD` as environment variables before running.
+- `monitor_args`: Dictionary of arguments passed to the monitor at initialization.
+  - For `wandb`:
+    - `base_url`: Overrides `WANDB_BASE_URL` if set.
+    - `api_key`: Overrides `WANDB_API_KEY` if set.
+  - For `mlflow`:
+    - `uri`: The URI of your MLFlow instance. Strongly recommended to set; defaults to `http://localhost:5000`.
+    - `username`: Overrides `MLFLOW_TRACKING_USERNAME` if set.
+    - `password`: Overrides `MLFLOW_TRACKING_PASSWORD` if set.
+- `enable_ray_timeline`: If `True`, exports a `timeline.json` file to `<checkpoint_root_dir>/<project>/<name>/monitor`. Viewable in Chrome at [chrome://tracing](chrome://tracing).
 
 ---

Expand All @@ -131,8 +145,8 @@ Defines the model paths and token limits.

```yaml
model:
model_path: /PATH/TO/MODEL/
critic_model_path: ''
model_path: ${oc.env:MODEL_PATH} # MODEL_PATH is an environment variable set in advance
critic_model_path: ${model.model_path} # use the value of model.model_path
max_response_tokens: 16384
max_model_len: 20480
```
@@ -174,10 +188,6 @@ buffer:
   ...
   eval_tasksets:
     ...
-
-  explorer_output:
-    ...
-
   trainer_input:
     experience_buffer:
       ...
@@ -255,41 +265,6 @@ The configuration for each task dataset is defined as follows:
 - `default_reward_fn_type`: Reward function used during exploration. If not specified, the `buffer.default_reward_fn_type` is used.
 - `workflow_args`: A dictionary of arguments used to supplement dataset-level parameters.
 
-
-### Explorer Output
-
-In [`explore` mode](#global-configuration), since there is no trainer, users can configure an experience buffer via `buffer.explorer_output`, rather than using `buffer.trainer_input`, which will be introduced in the next section.
-
-```{note}
-For `both` and `train` modes, users should use `buffer.trainer_input.experience_buffer` instead of `buffer.explorer_output`.
-```
-
-```yaml
-buffer:
-  ...
-  explorer_output:
-    name: countdown_buffer
-    storage_type: queue
-    path: sqlite:///countdown_buffer.db
-    wrap_in_ray: True
-    max_read_timeout: 1800
-```
-
-- `name`: The name of the experience buffer. This name will be used as the Ray actor's name, so it must be unique.
-- `storage_type`: The storage type for the experience buffer.
-  - `queue`: Experience data is stored in a queue. This storage type is recommended for most use cases.
-  - `sql`: Experience data is stored in a SQL database. If your database only supports local access (e.g., SQLite), set `wrap_in_ray` to `True` to wrap the database in a Ray actor, enabling remote access from other nodes.
-  - `file`: Experience data is stored in a JSON file. This storage type should be used only for debugging purposes in `explore` mode.
-- `path`: The path to the experience buffer.
-  - For `queue` storage type, this field is optional. You can specify a SQLite database or JSON file path here to back up the queue data.
-  - For `file` storage type, the path points to the directory containing the dataset files.
-  - For `sql` storage type, the path points to the SQLite database file.
-- `wrap_in_ray`: Whether to wrap the experience buffer in a Ray actor. Only take effect when `storage_type` is `sql` or `file`. The `queue` storage always uses a Ray actor.
-- `max_read_timeout`: The maximum waiting time (in seconds) to read new experience data. If exceeded, an incomplete batch will be returned directly. Only take effect when `storage_type` is `queue`. Default is 1800 seconds (30 minutes).
-- `use_priority_queue`: Only take effect when `storage_type` is `queue`. If set to `True`, the queue will be a priority queue, which allows for prioritizing certain experiences over others. Default is `False`.
-- `reuse_cooldown_time`: Only take effect when `storage_type` is `queue` and `use_priority_queue` is `True`. If set, it specifies the cooldown time (in seconds) for reusing experiences. If not specified, the default value is `None`, meaning experiences can not be reused.

### Trainer Input

Defines the experience buffer and optional SFT warm-up dataset.
@@ -314,7 +289,19 @@ buffer:
     sft_warmup_steps: 0
 ```
 
-- `experience_buffer`: Experience buffer used by the trainer, which is logically equivalent to `buffer.explorer_output`.
+- `experience_buffer`: The input of the trainer and also the output of the explorer. This field is required even in `explore` mode.
+  - `name`: The name of the experience buffer. This name will be used as the Ray actor's name, so it must be unique.
+  - `storage_type`: The storage type for the experience buffer.
+    - `queue`: Experience data is stored in a queue. This storage type is recommended for most use cases.
+    - `sql`: Experience data is stored in a SQL database.
+    - `file`: Experience data is stored in a JSON file. This storage type should be used only for debugging purposes in `explore` mode.
+  - `path`: The path to the experience buffer.
+    - For the `queue` storage type, this field is optional. You can specify a SQLite database or JSON file path here to back up the queue data.
+    - For the `file` storage type, the path points to the directory containing the dataset files.
+    - For the `sql` storage type, the path points to the SQLite database file.
+  - `max_read_timeout`: The maximum waiting time (in seconds) to read new experience data. If exceeded, an incomplete batch will be returned directly. Only takes effect when `storage_type` is `queue`. Default is 1800 seconds (30 minutes).
+  - `use_priority_queue`: Only takes effect when `storage_type` is `queue`. If set to `True`, the queue will be a priority queue, which allows prioritizing certain experiences over others. Default is `False`.
+  - `reuse_cooldown_time`: Only takes effect when `storage_type` is `queue` and `use_priority_queue` is `True`. If set, it specifies the cooldown time (in seconds) for reusing experiences. Defaults to `None`, meaning experiences cannot be reused.
 - `sft_warmup_dataset`: Optional dataset used for pre-training (SFT warmup).
 - `sft_warmup_steps`: Number of steps to use SFT warm-up before RL begins.
 
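Putting the fields above together, a trainer input block might look like the following sketch (the buffer name and path are illustrative, not defaults):

```yaml
buffer:
  trainer_input:
    experience_buffer:
      name: my_experience_buffer   # must be unique; reused as the Ray actor name
      storage_type: queue          # queue | sql | file
      path: sqlite:///my_experience_buffer.db  # optional for queue: backs up the queue data
      max_read_timeout: 1800
      use_priority_queue: false
    sft_warmup_steps: 0
```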
4 changes: 2 additions & 2 deletions docs/sphinx_doc/source/tutorial/trinity_programming_guide.md
@@ -447,7 +447,7 @@ class OPMDPolicyLossFn(PolicyLossFn):
 
 The above steps implement the components needed for the algorithm, but these components are scattered and need to be configured in multiple places to take effect.
 
-To simplify configuration, Trinity-RFT provides {class}`trinity.algorithm.AlgorithmType` to describe a complete algorithm and registers it in {object}`trinity.algorithm.ALGORITHM_TYPE`, enabling one-click configuration.
+To simplify configuration, Trinity-RFT provides {class}`trinity.algorithm.AlgorithmType` to describe a complete algorithm and registers it in {class}`trinity.algorithm.ALGORITHM_TYPE`, enabling one-click configuration.
 
 The `AlgorithmType` class includes the following attributes and methods:
 
@@ -473,7 +473,7 @@ class OPMDAlgorithm(AlgorithmType):
     use_reference: bool = True
     compute_advantage_in_trainer: bool = False
     can_balance_batch: bool = True
-    schema: type = ExperienceModel
+    schema: str = "experience"
 
     @classmethod
     def default_config(cls) -> Dict:
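The registration mechanism touched by this hunk can be mimicked in a self-contained sketch; the `Registry` class, the `register_module` decorator name, and the attribute set below are simplified stand-ins for the real `trinity.algorithm` classes, not their actual implementation:

```python
from typing import Dict


class Registry:
    """Minimal stand-in for a registry like trinity.algorithm.ALGORITHM_TYPE."""

    def __init__(self) -> None:
        self._modules: Dict[str, type] = {}

    def register_module(self, name: str):
        # Decorator that records the class under the given name.
        def decorator(cls: type) -> type:
            self._modules[name] = cls
            return cls
        return decorator

    def get(self, name: str) -> type:
        return self._modules[name]


ALGORITHM_TYPE = Registry()


@ALGORITHM_TYPE.register_module("opmd")
class OPMDAlgorithm:
    """Sketch of an AlgorithmType-style description of a complete algorithm."""

    use_critic: bool = False
    use_reference: bool = True
    compute_advantage_in_trainer: bool = False
    can_balance_batch: bool = True
    schema: str = "experience"  # the string key this PR switches to

    @classmethod
    def default_config(cls) -> Dict:
        return {"policy_loss_fn": "opmd"}


# One-click lookup by name, as a framework would do when reading a YAML config.
algo_cls = ALGORITHM_TYPE.get("opmd")
print(algo_cls.schema)  # experience
```

The design point of the `schema: str = "experience"` change is visible here: a plain string key decouples the algorithm description from any concrete model class, so the registry entry no longer needs to import `ExperienceModel`.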
2 changes: 0 additions & 2 deletions examples/RAFT_alfworld/RAFT_alfworld_7B.yaml
@@ -15,8 +15,6 @@ cluster:
 buffer:
   total_epochs: 30
   batch_size: 80
-  max_retry_times: 1
-  max_retry_interval: 1
   explorer_input:
     taskset:
       name: alfworld-train
2 changes: 0 additions & 2 deletions examples/RAFT_alfworld/RAFT_reflect_alfworld_7B.yaml
@@ -15,8 +15,6 @@ cluster:
 buffer:
   total_epochs: 30
   batch_size: 80
-  max_retry_times: 1
-  max_retry_interval: 1
   explorer_input:
     taskset:
       name: alfworld-train
@@ -16,8 +16,6 @@ buffer:
   total_epochs: 1
   batch_size: 32
   train_batch_size: 512
-  max_retry_times: 3
-  max_retry_interval: 1
   explorer_input:
     taskset:
       name: dapo
@@ -16,8 +16,6 @@ buffer:
   total_epochs: 1
   batch_size: 32
   train_batch_size: 256
-  max_retry_times: 3
-  max_retry_interval: 1
   explorer_input:
     taskset:
       name: gsm8k
2 changes: 0 additions & 2 deletions examples/asymre_gsm8k/gsm8k.yaml
@@ -18,8 +18,6 @@ cluster:
 buffer:
   total_steps: 80
   batch_size: 96
-  max_retry_times: 3
-  max_retry_interval: 1
   explorer_input:
     taskset:
       name: gsm8k
2 changes: 0 additions & 2 deletions examples/asymre_math/math.yaml
@@ -22,8 +22,6 @@ cluster:
 buffer:
   total_steps: 2000 # Exactly 2000 training steps as desired
   batch_size: 16 # 128 trajectories per gradient step, batch_size is the number of tasks per batch
-  max_retry_times: 3
-  max_retry_interval: 1
   explorer_input:
     taskset:
       name: math
2 changes: 0 additions & 2 deletions examples/async_gsm8k/explorer.yaml
@@ -15,8 +15,6 @@ cluster:
 buffer:
   total_epochs: 1
   batch_size: 96
-  max_retry_times: 3
-  max_retry_interval: 1
   explorer_input:
     taskset:
       name: gsm8k
2 changes: 0 additions & 2 deletions examples/async_gsm8k/trainer.yaml
@@ -15,8 +15,6 @@ cluster:
 buffer:
   total_epochs: 1
   train_batch_size: 768
-  max_retry_times: 3
-  max_retry_interval: 1
   explorer_input:
     taskset:
       name: gsm8k
2 changes: 0 additions & 2 deletions examples/dapo_math/dapo.yaml
@@ -17,8 +17,6 @@ cluster:
 buffer:
   total_epochs: 1
   batch_size: 32
-  max_retry_times: 3
-  max_retry_interval: 1
   explorer_input:
     taskset:
       name: dapo-math
2 changes: 0 additions & 2 deletions examples/dpo_humanlike/dpo.yaml
@@ -17,8 +17,6 @@ cluster:
 buffer:
   total_epochs: 2
   train_batch_size: 64
-  max_retry_times: 3
-  max_retry_interval: 1
   trainer_input:
     experience_buffer:
       name: dpo_buffer
2 changes: 0 additions & 2 deletions examples/grpo_alfworld/alfworld.yaml
@@ -14,8 +14,6 @@ cluster:
 buffer:
   total_epochs: 20
   batch_size: 32
-  max_retry_times: 3
-  max_retry_interval: 1
   explorer_input:
     taskset:
       name: alfworld
2 changes: 0 additions & 2 deletions examples/grpo_alfworld_general_multi_step/alfworld.yaml
@@ -16,8 +16,6 @@ buffer:
   total_epochs: 20
   batch_size: 16
   train_batch_size: 7680 # 16 * 16 * 30
-  max_retry_times: 3
-  max_retry_interval: 1
   explorer_input:
     taskset:
       name: alfworld
2 changes: 0 additions & 2 deletions examples/grpo_email_search/email_search.yaml
@@ -16,8 +16,6 @@ buffer:
   total_epochs: 1
   batch_size: 16
   train_batch_size: 640 # 16*8*5
-  max_retry_times: 3
-  max_retry_interval: 1
   explorer_input:
     taskset:
       name: enron_train
2 changes: 0 additions & 2 deletions examples/grpo_gsm8k/gsm8k.yaml
@@ -14,8 +14,6 @@ cluster:
 buffer:
   total_epochs: 1
   batch_size: 96
-  max_retry_times: 3
-  max_retry_interval: 1
   explorer_input:
     taskset:
       name: gsm8k
2 changes: 0 additions & 2 deletions examples/grpo_gsm8k_experience_pipeline/gsm8k.yaml
@@ -35,8 +35,6 @@ cluster:
 buffer:
   total_epochs: 1
   batch_size: 96
-  max_retry_times: 3
-  max_retry_interval: 1
   explorer_input:
     taskset:
       name: gsm8k
2 changes: 0 additions & 2 deletions examples/grpo_gsm8k_task_pipeline/gsm8k.yaml
@@ -33,8 +33,6 @@ cluster:
 buffer:
   total_epochs: 1
   batch_size: 96
-  max_retry_times: 3
-  max_retry_interval: 1
   explorer_input:
     taskset:
       name: gsm8k
2 changes: 0 additions & 2 deletions examples/grpo_math/math.yaml
@@ -14,8 +14,6 @@ cluster:
 buffer:
   total_epochs: 20
   batch_size: 288
-  max_retry_times: 3
-  max_retry_interval: 1
   explorer_input:
     taskset:
       name: math
2 changes: 0 additions & 2 deletions examples/grpo_sciworld/sciworld.yaml
@@ -14,8 +14,6 @@ cluster:
 buffer:
   total_epochs: 20
   batch_size: 4
-  max_retry_times: 3
-  max_retry_interval: 1
   explorer_input:
     taskset:
       name: sciworld