@@ -18,7 +18,7 @@
- Branches
{%- for item in versions.branches %}
- - {{ item.name }} (latest)
+ - {{ item.name }} (latest)
{%- endfor %}
{%- endif %}
diff --git a/docs/sphinx_doc/source/tutorial/example_mix_algo.md b/docs/sphinx_doc/source/tutorial/example_mix_algo.md
index fd58b75144..5df890cd96 100644
--- a/docs/sphinx_doc/source/tutorial/example_mix_algo.md
+++ b/docs/sphinx_doc/source/tutorial/example_mix_algo.md
@@ -54,7 +54,7 @@ class MIXAlgorithm(AlgorithmType):
use_reference: bool = True
compute_advantage_in_trainer: bool = False
can_balance_batch: bool = True
- schema: type = ExperienceModel
+ schema: str = "experience"
@classmethod
def default_config(cls) -> Dict:
diff --git a/docs/sphinx_doc/source/tutorial/example_step_wise.md b/docs/sphinx_doc/source/tutorial/example_step_wise.md
index be9df5880b..2d7e8bec5a 100644
--- a/docs/sphinx_doc/source/tutorial/example_step_wise.md
+++ b/docs/sphinx_doc/source/tutorial/example_step_wise.md
@@ -107,8 +107,6 @@ buffer:
total_epochs: 20
batch_size: 16
train_batch_size: 7680 # here: batch_size * repeat_times * max_env_steps
- max_retry_times: 3
- max_retry_interval: 1
explorer_input:
taskset:
name: alfworld
diff --git a/docs/sphinx_doc/source/tutorial/faq.md b/docs/sphinx_doc/source/tutorial/faq.md
index 63b0fe5b1f..c408c69180 100644
--- a/docs/sphinx_doc/source/tutorial/faq.md
+++ b/docs/sphinx_doc/source/tutorial/faq.md
@@ -120,7 +120,7 @@ from sqlalchemy import create_engine
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool
-from trinity.common.schema import ExperienceModel
+from trinity.common.schema.sql_schema import ExperienceModel
engine = create_engine(buffer.trainer_input.experience_buffer.path)
session = sessionmaker(bind=engine)
@@ -129,7 +129,6 @@ sess = session()
MAX_EXPERIENCES = 4
experiences = (
sess.query(ExperienceModel)
- .with_for_update()
.limit(MAX_EXPERIENCES)
.all()
)
diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md
index 567ef3f626..6b109f261a 100644
--- a/docs/sphinx_doc/source/tutorial/trinity_configs.md
+++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -35,7 +35,7 @@ synchronizer:
# Model weight synchronization settings
...
monitor:
- # Monitoring configurations (e.g., WandB or TensorBoard)
+ # Monitoring configurations (e.g., WandB, TensorBoard or MLFlow)
...
service:
# Services to use
@@ -48,10 +48,12 @@ log:
...
```
-Each of these sections will be explained in detail below.
+Each of these sections will be explained in detail below. For additional details about specific parameters not covered here, please refer to the [source code](https://github.com/modelscope/Trinity-RFT/blob/main/trinity/common/config.py).
-```{note}
-For additional details about specific parameters not covered here, please refer to the [source code](https://github.com/modelscope/Trinity-RFT/blob/main/trinity/common/config.py).
+```{tip}
+Trinity-RFT uses [OmegaConf](https://omegaconf.readthedocs.io/en/latest/) to load YAML configuration files.
+It supports some advanced features like [variable interpolation](https://omegaconf.readthedocs.io/en/latest/usage.html#variable-interpolation) and [environment variable substitution](https://omegaconf.readthedocs.io/en/latest/custom_resolvers.html#oc-env).
+Users can use these features to simplify configuration.
```
---
@@ -64,7 +66,7 @@ These are general settings that apply to the entire experiment.
project: Trinity-RFT
name: example
mode: both
-checkpoint_root_dir: /PATH/TO/CHECKPOINT
+checkpoint_root_dir: ${oc.env:CHECKPOINT_ROOT_DIR} # CHECKPOINT_ROOT_DIR is an environment variable set in advance
```
- `project`: The name of the project.
@@ -115,13 +117,25 @@ Used to log training metrics during execution.
```yaml
monitor:
monitor_type: wandb
+ monitor_args:
+ base_url: http://localhost:8080
+ api_key: your_api_key
enable_ray_timeline: False
```
- `monitor_type`: Type of monitoring system. Options:
- `wandb`: Logs to [Weights & Biases](https://docs.wandb.ai/quickstart/). Requires logging in and setting `WANDB_API_KEY`. Project and run names match the `project` and `name` fields in global configs.
- `tensorboard`: Logs to [TensorBoard](https://www.tensorflow.org/tensorboard). Files are saved under `
<checkpoint_root_dir>/<project>/<name>/monitor/tensorboard`.
-- `enable_ray_timeline`: Whether to export the ray timeline. If set to `True`, a `timeline.json` file will be exported to `<checkpoint_root_dir>/<project>/<name>/monitor`. You can view the timeline file in Chrome at [chrome://tracing](chrome://tracing).
+ - `mlflow`: Logs to [MLFlow](https://mlflow.org/). If [MLFlow authentication](https://mlflow.org/docs/latest/ml/auth/) is set up, set `MLFLOW_TRACKING_USERNAME` and `MLFLOW_TRACKING_PASSWORD` as environment variables before running.
+- `monitor_args`: Dictionary of arguments for monitor initialization.
+ - For `wandb`:
+ - `base_url`: Overrides `WANDB_BASE_URL` if set.
+ - `api_key`: Overrides `WANDB_API_KEY` if set.
+ - For `mlflow`:
+ - `uri`: The URI of your MLFlow instance. Strongly recommended to set; defaults to `http://localhost:5000`.
+ - `username`: Overrides `MLFLOW_TRACKING_USERNAME` if set.
+ - `password`: Overrides `MLFLOW_TRACKING_PASSWORD` if set.
+- `enable_ray_timeline`: If `True`, exports a `timeline.json` file to `<checkpoint_root_dir>/<project>/<name>/monitor`. Viewable in Chrome at [chrome://tracing](chrome://tracing).
---
@@ -131,8 +145,8 @@ Defines the model paths and token limits.
```yaml
model:
- model_path: /PATH/TO/MODEL/
- critic_model_path: ''
+ model_path: ${oc.env:MODEL_PATH} # MODEL_PATH is an environment variable set in advance
+ critic_model_path: ${model.model_path} # use the value of model.model_path
max_response_tokens: 16384
max_model_len: 20480
```
@@ -174,10 +188,6 @@ buffer:
...
eval_tasksets:
...
-
- explorer_output:
- ...
-
trainer_input:
experience_buffer:
...
@@ -255,41 +265,6 @@ The configuration for each task dataset is defined as follows:
- `default_reward_fn_type`: Reward function used during exploration. If not specified, the `buffer.default_reward_fn_type` is used.
- `workflow_args`: A dictionary of arguments used to supplement dataset-level parameters.
-
-### Explorer Output
-
-In [`explore` mode](#global-configuration), since there is no trainer, users can configure an experience buffer via `buffer.explorer_output`, rather than using `buffer.trainer_input`, which will be introduced in the next section.
-
-```{note}
-For `both` and `train` modes, users should use `buffer.trainer_input.experience_buffer` instead of `buffer.explorer_output`.
-```
-
-```yaml
-buffer:
- ...
- explorer_output:
- name: countdown_buffer
- storage_type: queue
- path: sqlite:///countdown_buffer.db
- wrap_in_ray: True
- max_read_timeout: 1800
-```
-
-- `name`: The name of the experience buffer. This name will be used as the Ray actor's name, so it must be unique.
-- `storage_type`: The storage type for the experience buffer.
- - `queue`: Experience data is stored in a queue. This storage type is recommended for most use cases.
- - `sql`: Experience data is stored in a SQL database. If your database only supports local access (e.g., SQLite), set `wrap_in_ray` to `True` to wrap the database in a Ray actor, enabling remote access from other nodes.
- - `file`: Experience data is stored in a JSON file. This storage type should be used only for debugging purposes in `explore` mode.
-- `path`: The path to the experience buffer.
- - For `queue` storage type, this field is optional. You can specify a SQLite database or JSON file path here to back up the queue data.
- - For `file` storage type, the path points to the directory containing the dataset files.
- - For `sql` storage type, the path points to the SQLite database file.
-- `wrap_in_ray`: Whether to wrap the experience buffer in a Ray actor. Only take effect when `storage_type` is `sql` or `file`. The `queue` storage always uses a Ray actor.
-- `max_read_timeout`: The maximum waiting time (in seconds) to read new experience data. If exceeded, an incomplete batch will be returned directly. Only take effect when `storage_type` is `queue`. Default is 1800 seconds (30 minutes).
-- `use_priority_queue`: Only take effect when `storage_type` is `queue`. If set to `True`, the queue will be a priority queue, which allows for prioritizing certain experiences over others. Default is `False`.
-- `reuse_cooldown_time`: Only take effect when `storage_type` is `queue` and `use_priority_queue` is `True`. If set, it specifies the cooldown time (in seconds) for reusing experiences. If not specified, the default value is `None`, meaning experiences can not be reused.
-
-
### Trainer Input
Defines the experience buffer and optional SFT warm-up dataset.
@@ -314,7 +289,19 @@ buffer:
sft_warmup_steps: 0
```
-- `experience_buffer`: Experience buffer used by the trainer, which is logically equivalent to `buffer.explorer_output`.
+- `experience_buffer`: The input of the Trainer and also the output of the Explorer. This field is required even in `explore` mode.
+ - `name`: The name of the experience buffer. This name will be used as the Ray actor's name, so it must be unique.
+ - `storage_type`: The storage type for the experience buffer.
+ - `queue`: Experience data is stored in a queue. This storage type is recommended for most use cases.
+ - `sql`: Experience data is stored in a SQL database.
+ - `file`: Experience data is stored in a JSON file. This storage type should be used only for debugging purposes in `explore` mode.
+ - `path`: The path to the experience buffer.
+ - For `queue` storage type, this field is optional. You can specify a SQLite database or JSON file path here to back up the queue data.
+ - For `file` storage type, the path points to the directory containing the dataset files.
+ - For `sql` storage type, the path points to the SQLite database file.
+ - `max_read_timeout`: The maximum waiting time (in seconds) to read new experience data. If exceeded, an incomplete batch will be returned directly. Only takes effect when `storage_type` is `queue`. Default is 1800 seconds (30 minutes).
+ - `use_priority_queue`: Only takes effect when `storage_type` is `queue`. If set to `True`, the queue will be a priority queue, which allows for prioritizing certain experiences over others. Default is `False`.
+ - `reuse_cooldown_time`: Only takes effect when `storage_type` is `queue` and `use_priority_queue` is `True`. If set, it specifies the cooldown time (in seconds) for reusing experiences. If not specified, the default value is `None`, meaning experiences cannot be reused.
- `sft_warmup_dataset`: Optional dataset used for pre-training (SFT warmup).
- `sft_warmup_steps`: Number of steps to use SFT warm-up before RL begins.
diff --git a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md
index 8346bb0cf9..c395880a55 100644
--- a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md
+++ b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md
@@ -447,7 +447,7 @@ class OPMDPolicyLossFn(PolicyLossFn):
The above steps implement the components needed for the algorithm, but these components are scattered and need to be configured in multiple places to take effect.
-To simplify configuration, Trinity-RFT provides {class}`trinity.algorithm.AlgorithmType` to describe a complete algorithm and registers it in {object}`trinity.algorithm.ALGORITHM_TYPE`, enabling one-click configuration.
+To simplify configuration, Trinity-RFT provides {class}`trinity.algorithm.AlgorithmType` to describe a complete algorithm and registers it in {class}`trinity.algorithm.ALGORITHM_TYPE`, enabling one-click configuration.
The `AlgorithmType` class includes the following attributes and methods:
@@ -473,7 +473,7 @@ class OPMDAlgorithm(AlgorithmType):
use_reference: bool = True
compute_advantage_in_trainer: bool = False
can_balance_batch: bool = True
- schema: type = ExperienceModel
+ schema: str = "experience"
@classmethod
def default_config(cls) -> Dict:
diff --git a/examples/RAFT_alfworld/RAFT_alfworld_7B.yaml b/examples/RAFT_alfworld/RAFT_alfworld_7B.yaml
index c195b255bc..e79f1ce11c 100644
--- a/examples/RAFT_alfworld/RAFT_alfworld_7B.yaml
+++ b/examples/RAFT_alfworld/RAFT_alfworld_7B.yaml
@@ -15,8 +15,6 @@ cluster:
buffer:
total_epochs: 30
batch_size: 80
- max_retry_times: 1
- max_retry_interval: 1
explorer_input:
taskset:
name: alfworld-train
diff --git a/examples/RAFT_alfworld/RAFT_reflect_alfworld_7B.yaml b/examples/RAFT_alfworld/RAFT_reflect_alfworld_7B.yaml
index 6dc17ae32c..9ecea9e923 100644
--- a/examples/RAFT_alfworld/RAFT_reflect_alfworld_7B.yaml
+++ b/examples/RAFT_alfworld/RAFT_reflect_alfworld_7B.yaml
@@ -15,8 +15,6 @@ cluster:
buffer:
total_epochs: 30
batch_size: 80
- max_retry_times: 1
- max_retry_interval: 1
explorer_input:
taskset:
name: alfworld-train
diff --git a/examples/agentscope_tool_react/agentscope_tool_react_dapo.yaml b/examples/agentscope_tool_react/agentscope_tool_react_dapo.yaml
index b16d9a90f2..55699cfbd3 100644
--- a/examples/agentscope_tool_react/agentscope_tool_react_dapo.yaml
+++ b/examples/agentscope_tool_react/agentscope_tool_react_dapo.yaml
@@ -16,8 +16,6 @@ buffer:
total_epochs: 1
batch_size: 32
train_batch_size: 512
- max_retry_times: 3
- max_retry_interval: 1
explorer_input:
taskset:
name: dapo
diff --git a/examples/agentscope_tool_react/agentscope_tool_react_gsm8k.yaml b/examples/agentscope_tool_react/agentscope_tool_react_gsm8k.yaml
index 1b8b92bdbc..0e8c2559e5 100644
--- a/examples/agentscope_tool_react/agentscope_tool_react_gsm8k.yaml
+++ b/examples/agentscope_tool_react/agentscope_tool_react_gsm8k.yaml
@@ -16,8 +16,6 @@ buffer:
total_epochs: 1
batch_size: 32
train_batch_size: 256
- max_retry_times: 3
- max_retry_interval: 1
explorer_input:
taskset:
name: gsm8k
diff --git a/examples/asymre_gsm8k/gsm8k.yaml b/examples/asymre_gsm8k/gsm8k.yaml
index d825d8f89f..3eeab4d3f4 100644
--- a/examples/asymre_gsm8k/gsm8k.yaml
+++ b/examples/asymre_gsm8k/gsm8k.yaml
@@ -18,8 +18,6 @@ cluster:
buffer:
total_steps: 80
batch_size: 96
- max_retry_times: 3
- max_retry_interval: 1
explorer_input:
taskset:
name: gsm8k
diff --git a/examples/asymre_math/math.yaml b/examples/asymre_math/math.yaml
index b8472f9e2d..4d93cdf972 100644
--- a/examples/asymre_math/math.yaml
+++ b/examples/asymre_math/math.yaml
@@ -22,8 +22,6 @@ cluster:
buffer:
total_steps: 2000 # Exactly 2000 training steps as desired
batch_size: 16 # 128 trajectories per gradient step, batch_size is the number of tasks per batch
- max_retry_times: 3
- max_retry_interval: 1
explorer_input:
taskset:
name: math
diff --git a/examples/async_gsm8k/explorer.yaml b/examples/async_gsm8k/explorer.yaml
index 38da6a5c27..452536a5ce 100644
--- a/examples/async_gsm8k/explorer.yaml
+++ b/examples/async_gsm8k/explorer.yaml
@@ -15,8 +15,6 @@ cluster:
buffer:
total_epochs: 1
batch_size: 96
- max_retry_times: 3
- max_retry_interval: 1
explorer_input:
taskset:
name: gsm8k
diff --git a/examples/async_gsm8k/trainer.yaml b/examples/async_gsm8k/trainer.yaml
index f1f6b08b86..baffdf88c1 100644
--- a/examples/async_gsm8k/trainer.yaml
+++ b/examples/async_gsm8k/trainer.yaml
@@ -15,8 +15,6 @@ cluster:
buffer:
total_epochs: 1
train_batch_size: 768
- max_retry_times: 3
- max_retry_interval: 1
explorer_input:
taskset:
name: gsm8k
diff --git a/examples/dapo_math/dapo.yaml b/examples/dapo_math/dapo.yaml
index d5fdc73dd0..929ea18a25 100644
--- a/examples/dapo_math/dapo.yaml
+++ b/examples/dapo_math/dapo.yaml
@@ -17,8 +17,6 @@ cluster:
buffer:
total_epochs: 1
batch_size: 32
- max_retry_times: 3
- max_retry_interval: 1
explorer_input:
taskset:
name: dapo-math
diff --git a/examples/dpo_humanlike/dpo.yaml b/examples/dpo_humanlike/dpo.yaml
index 4ae424c162..0de0b49941 100644
--- a/examples/dpo_humanlike/dpo.yaml
+++ b/examples/dpo_humanlike/dpo.yaml
@@ -17,8 +17,6 @@ cluster:
buffer:
total_epochs: 2
train_batch_size: 64
- max_retry_times: 3
- max_retry_interval: 1
trainer_input:
experience_buffer:
name: dpo_buffer
diff --git a/examples/grpo_alfworld/alfworld.yaml b/examples/grpo_alfworld/alfworld.yaml
index 122d90061b..c8fe7863ac 100644
--- a/examples/grpo_alfworld/alfworld.yaml
+++ b/examples/grpo_alfworld/alfworld.yaml
@@ -14,8 +14,6 @@ cluster:
buffer:
total_epochs: 20
batch_size: 32
- max_retry_times: 3
- max_retry_interval: 1
explorer_input:
taskset:
name: alfworld
diff --git a/examples/grpo_alfworld_general_multi_step/alfworld.yaml b/examples/grpo_alfworld_general_multi_step/alfworld.yaml
index 4421578ea6..9f0d79cba5 100644
--- a/examples/grpo_alfworld_general_multi_step/alfworld.yaml
+++ b/examples/grpo_alfworld_general_multi_step/alfworld.yaml
@@ -16,8 +16,6 @@ buffer:
total_epochs: 20
batch_size: 16
train_batch_size: 7680 # 16 * 16 * 30
- max_retry_times: 3
- max_retry_interval: 1
explorer_input:
taskset:
name: alfworld
diff --git a/examples/grpo_email_search/email_search.yaml b/examples/grpo_email_search/email_search.yaml
index 381ec6fb3b..d7722f4ebc 100644
--- a/examples/grpo_email_search/email_search.yaml
+++ b/examples/grpo_email_search/email_search.yaml
@@ -16,8 +16,6 @@ buffer:
total_epochs: 1
batch_size: 16
train_batch_size: 640 # 16*8*5
- max_retry_times: 3
- max_retry_interval: 1
explorer_input:
taskset:
name: enron_train
diff --git a/examples/grpo_gsm8k/gsm8k.yaml b/examples/grpo_gsm8k/gsm8k.yaml
index 4339472466..7673c47c44 100644
--- a/examples/grpo_gsm8k/gsm8k.yaml
+++ b/examples/grpo_gsm8k/gsm8k.yaml
@@ -14,8 +14,6 @@ cluster:
buffer:
total_epochs: 1
batch_size: 96
- max_retry_times: 3
- max_retry_interval: 1
explorer_input:
taskset:
name: gsm8k
diff --git a/examples/grpo_gsm8k_experience_pipeline/gsm8k.yaml b/examples/grpo_gsm8k_experience_pipeline/gsm8k.yaml
index bda28ccbd2..70d881dfbb 100644
--- a/examples/grpo_gsm8k_experience_pipeline/gsm8k.yaml
+++ b/examples/grpo_gsm8k_experience_pipeline/gsm8k.yaml
@@ -35,8 +35,6 @@ cluster:
buffer:
total_epochs: 1
batch_size: 96
- max_retry_times: 3
- max_retry_interval: 1
explorer_input:
taskset:
name: gsm8k
diff --git a/examples/grpo_gsm8k_task_pipeline/gsm8k.yaml b/examples/grpo_gsm8k_task_pipeline/gsm8k.yaml
index 86762fc10d..9a6c5358e0 100644
--- a/examples/grpo_gsm8k_task_pipeline/gsm8k.yaml
+++ b/examples/grpo_gsm8k_task_pipeline/gsm8k.yaml
@@ -33,8 +33,6 @@ cluster:
buffer:
total_epochs: 1
batch_size: 96
- max_retry_times: 3
- max_retry_interval: 1
explorer_input:
taskset:
name: gsm8k
diff --git a/examples/grpo_math/math.yaml b/examples/grpo_math/math.yaml
index 168ea3cd3e..16e73ff043 100644
--- a/examples/grpo_math/math.yaml
+++ b/examples/grpo_math/math.yaml
@@ -14,8 +14,6 @@ cluster:
buffer:
total_epochs: 20
batch_size: 288
- max_retry_times: 3
- max_retry_interval: 1
explorer_input:
taskset:
name: math
diff --git a/examples/grpo_sciworld/sciworld.yaml b/examples/grpo_sciworld/sciworld.yaml
index c469423af1..743d6e6e2e 100644
--- a/examples/grpo_sciworld/sciworld.yaml
+++ b/examples/grpo_sciworld/sciworld.yaml
@@ -14,8 +14,6 @@ cluster:
buffer:
total_epochs: 20
batch_size: 4
- max_retry_times: 3
- max_retry_interval: 1
explorer_input:
taskset:
name: sciworld
diff --git a/examples/grpo_toolcall/toolace.yaml b/examples/grpo_toolcall/toolace.yaml
index 0581603b6b..6ffe975973 100644
--- a/examples/grpo_toolcall/toolace.yaml
+++ b/examples/grpo_toolcall/toolace.yaml
@@ -14,8 +14,6 @@ cluster:
buffer:
total_epochs: 1
batch_size: 128
- max_retry_times: 3
- max_retry_interval: 1
explorer_input:
taskset:
name: toolace_data
diff --git a/examples/grpo_webshop/webshop.yaml b/examples/grpo_webshop/webshop.yaml
index 218f2d9360..82fff99e20 100644
--- a/examples/grpo_webshop/webshop.yaml
+++ b/examples/grpo_webshop/webshop.yaml
@@ -14,8 +14,6 @@ cluster:
buffer:
total_epochs: 20
batch_size: 4
- max_retry_times: 3
- max_retry_interval: 1
explorer_input:
taskset:
name: webshop
diff --git a/examples/mix_chord/mix_chord.yaml b/examples/mix_chord/mix_chord.yaml
index 26d04852e1..debdf515ac 100644
--- a/examples/mix_chord/mix_chord.yaml
+++ b/examples/mix_chord/mix_chord.yaml
@@ -33,8 +33,6 @@ buffer:
total_epochs: 4
batch_size: 32
train_batch_size: 320
- max_retry_times: 3
- max_retry_interval: 1
explorer_input:
taskset:
name: openr1_data_filtered_int
@@ -59,7 +57,6 @@ buffer:
total_epochs: 25
name: SFT_data
storage_type: file
- algorithm_type: sft
path: /PATH/TO/SFT_DATASET
split: 'train'
format:
diff --git a/examples/mix_math/mix_math.yaml b/examples/mix_math/mix_math.yaml
index 11eafb6cc9..94a05c71ad 100644
--- a/examples/mix_math/mix_math.yaml
+++ b/examples/mix_math/mix_math.yaml
@@ -27,8 +27,6 @@ buffer:
total_epochs: 10
batch_size: 32
train_batch_size: 320
- max_retry_times: 3
- max_retry_interval: 1
explorer_input:
taskset:
name: math_train
@@ -59,7 +57,6 @@ buffer:
total_epochs: 10
name: math_sft
storage_type: file
- algorithm_type: sft
path: /PATH/TO/EXPERT_DATA/
split: 'train'
format:
diff --git a/examples/opmd_gsm8k/opmd_gsm8k.yaml b/examples/opmd_gsm8k/opmd_gsm8k.yaml
index 6b12aa0e34..18cdcdb10b 100644
--- a/examples/opmd_gsm8k/opmd_gsm8k.yaml
+++ b/examples/opmd_gsm8k/opmd_gsm8k.yaml
@@ -14,8 +14,6 @@ cluster:
buffer:
total_epochs: 1
batch_size: 96
- max_retry_times: 3
- max_retry_interval: 1
explorer_input:
taskset:
name: gsm8k
diff --git a/examples/ppo_countdown/countdown.yaml b/examples/ppo_countdown/countdown.yaml
index 35b0f217a8..f3b342fa7d 100644
--- a/examples/ppo_countdown/countdown.yaml
+++ b/examples/ppo_countdown/countdown.yaml
@@ -14,8 +14,6 @@ cluster:
buffer:
total_epochs: 20
batch_size: 96
- max_retry_times: 3
- max_retry_interval: 1
explorer_input:
taskset:
name: countdown
diff --git a/examples/sft_mot/sft.yaml b/examples/sft_mot/sft.yaml
index a309e75c8b..603adcea7c 100644
--- a/examples/sft_mot/sft.yaml
+++ b/examples/sft_mot/sft.yaml
@@ -14,8 +14,6 @@ cluster:
buffer:
total_epochs: 1
train_batch_size: 64
- max_retry_times: 3
- max_retry_interval: 1
trainer_input:
experience_buffer:
name: MoT
diff --git a/tests/buffer/experience_storage_test.py b/tests/buffer/experience_storage_test.py
new file mode 100644
index 0000000000..b3d225e122
--- /dev/null
+++ b/tests/buffer/experience_storage_test.py
@@ -0,0 +1,90 @@
+import os
+import queue
+import sqlite3
+import threading
+import time
+
+import torch
+
+from tests.tools import RayUnittestBaseAysnc
+from trinity.buffer.reader.sql_reader import SQLReader
+from trinity.buffer.writer.sql_writer import SQLWriter
+from trinity.common.config import BufferConfig, StorageConfig
+from trinity.common.constants import StorageType
+from trinity.common.experience import Experience
+
+DB_PATH = os.path.join(os.path.dirname(__file__), "test.db")
+
+
+class ExperienceStorageTest(RayUnittestBaseAysnc):
+ def setUp(self):
+ self.total_num = 8
+ self.put_batch_size = 2
+ self.train_batch_size = 4
+
+ self.config = BufferConfig(
+ train_batch_size=self.train_batch_size,
+ )
+ if os.path.exists(DB_PATH):
+ os.remove(DB_PATH)
+
+ async def test_sql_storage(self):
+ meta = StorageConfig(
+ name="test_storage",
+ schema_type="experience",
+ storage_type=StorageType.SQL,
+ max_read_timeout=3,
+ path=f"sqlite:///{DB_PATH}",
+ )
+
+ writer = SQLWriter(meta, self.config)
+ reader = SQLReader(meta, self.config)
+ self.assertEqual(await writer.acquire(), 1)
+ exps = [
+ Experience(
+ tokens=torch.tensor([float(j) for j in range(i + 1)]),
+ prompt_length=i,
+ reward=float(i),
+ logprobs=torch.tensor([0.1]),
+ )
+ for i in range(1, self.put_batch_size + 1)
+ ]
+ for exp in exps:
+ exp.info = {"model_version": 0, "use_count": 0}
+ for _ in range(self.total_num // self.put_batch_size):
+ await writer.write_async(exps)
+ for _ in range(self.total_num // self.train_batch_size):
+ exps = reader.read()
+ self.assertEqual(len(exps), self.train_batch_size)
+ exps = [
+ Experience(
+ tokens=torch.tensor([float(j) for j in range(i + 1)]),
+ reward=float(i),
+ logprobs=torch.tensor([0.1]),
+ action_mask=torch.tensor([j % 2 for j in range(i + 1)]),
+ )
+ for i in range(1, self.put_batch_size * 2 + 1)
+ ]
+ writer.write(exps)
+ exps = reader.read(batch_size=self.put_batch_size * 2)
+ self.assertEqual(len(exps), self.put_batch_size * 2)
+
+ def thread_read(reader, result_queue):
+ try:
+ batch = reader.read()
+ result_queue.put(batch)
+ except StopIteration as e:
+ result_queue.put(e)
+
+ result_queue = queue.Queue()
+ t = threading.Thread(target=thread_read, args=(reader, result_queue))
+ t.start()
+ time.sleep(2) # make sure the thread is waiting for data
+ self.assertEqual(await writer.release(), 0)
+ t.join(timeout=1)
+ self.assertIsInstance(result_queue.get(), StopIteration)
+ conn = sqlite3.connect(DB_PATH)
+ cursor = conn.cursor()
+ value = cursor.execute("SELECT COUNT(*) FROM test_storage;").fetchall()
+ self.assertEqual(value[0][0], self.total_num + self.put_batch_size * 2)
+ self.assertRaises(StopIteration, reader.read, batch_size=1)
diff --git a/tests/buffer/file_test.py b/tests/buffer/file_test.py
index fb38e43660..00aec40744 100644
--- a/tests/buffer/file_test.py
+++ b/tests/buffer/file_test.py
@@ -9,9 +9,7 @@
get_unittest_dataset_config,
)
from trinity.buffer.buffer import get_buffer_reader, get_buffer_writer
-from trinity.buffer.reader.file_reader import RawDataReader
from trinity.buffer.utils import default_storage_path
-from trinity.buffer.writer.file_writer import JSONWriter
from trinity.common.config import StorageConfig
from trinity.common.constants import StorageType
@@ -30,34 +28,6 @@ def tearDownClass(cls):
if os.path.exists(cls.temp_output_path):
os.system(f"rm -rf {cls.temp_output_path}")
- async def test_file_buffer(self):
- meta = StorageConfig(
- name="test_buffer",
- path=os.path.join(self.temp_output_path, "buffer.jsonl"),
- storage_type=StorageType.FILE,
- raw=True,
- )
- data = [
- {"key1": 1, "key2": 2},
- {"key1": 3, "key2": 4},
- {"key1": 5, "key2": 6},
- {"key1": 7, "key2": 8},
- ]
-
- # test writer
- writer = JSONWriter(meta, None)
- await writer.acquire()
- writer.write(data)
- await writer.release()
-
- # test reader
- meta.path = self.temp_output_path
- reader = RawDataReader(meta, None)
- loaded_data = reader.read()
- self.assertEqual(len(loaded_data), 4)
- self.assertEqual(loaded_data, data)
- self.assertRaises(StopIteration, reader.read)
-
def test_file_reader(self): # noqa: C901
"""Test file reader."""
reader = get_buffer_reader(self.config.buffer.explorer_input.taskset, self.config.buffer)
@@ -90,7 +60,6 @@ def test_file_reader(self): # noqa: C901
while True:
try:
tasks.extend(reader.read())
- print(f"read from buffer, current len {len(tasks)}.")
except StopIteration:
break
self.assertEqual(len(tasks), 20 - 8)
diff --git a/tests/buffer/formatter_test.py b/tests/buffer/formatter_test.py
index c1cef3de46..9da05e65af 100644
--- a/tests/buffer/formatter_test.py
+++ b/tests/buffer/formatter_test.py
@@ -3,13 +3,8 @@
from transformers import AutoTokenizer
from tests.tools import get_model_path
-from trinity.buffer.schema.formatter import (
- DPOMessagesFormatter,
- DPOPlaintextFormatter,
- SFTMessagesFormatter,
- SFTPlaintextFormatter,
-)
-from trinity.common.config import FormatConfig
+from trinity.buffer.schema.formatter import FORMATTER
+from trinity.common.config import FormatConfig, StorageConfig
from trinity.common.constants import PromptType
from trinity.common.experience import Experience
@@ -23,7 +18,7 @@ def test_sft_messages_formatter(self):
prompt_type=PromptType.MESSAGES,
messages_key="message_list",
)
- formatter = SFTMessagesFormatter(tokenizer=self.tokenizer, format_config=config)
+ formatter = FORMATTER.get("sft")(tokenizer=self.tokenizer, format_config=config)
sample = {
"message_list": [
{"role": "user", "content": "Hi"},
@@ -47,7 +42,7 @@ def test_sft_messages_formatter(self):
messages_key="messages",
tools_key="tools",
)
- formatter = SFTMessagesFormatter(tokenizer=self.tokenizer, format_config=config)
+ formatter = FORMATTER.get("sft")(tokenizer=self.tokenizer, format_config=config)
sample = {
"messages": [
{
@@ -126,7 +121,7 @@ def test_sft_plaintext_formatter(self):
prompt_key="prompt",
response_key="response",
)
- formatter = SFTPlaintextFormatter(tokenizer=self.tokenizer, format_config=config)
+ formatter = FORMATTER.get("sft")(tokenizer=self.tokenizer, format_config=config)
sample = {
"system": "You are a helpful assistant.",
"prompt": "What is 2+2?",
@@ -150,7 +145,7 @@ def test_sft_plaintext_formatter(self):
prompt_key="prompt",
response_key="response",
)
- formatter = SFTPlaintextFormatter(tokenizer=self.tokenizer, format_config=config)
+ formatter = FORMATTER.get("sft")(tokenizer=self.tokenizer, format_config=config)
exp = formatter.format(sample)
self.assertIsInstance(exp, Experience)
@@ -170,7 +165,7 @@ def test_dpo_plaintext_formatter(self):
chosen_key="chosen",
rejected_key="rejected",
)
- formatter = DPOPlaintextFormatter(tokenizer=self.tokenizer, format_config=config)
+ formatter = FORMATTER.get("dpo")(tokenizer=self.tokenizer, format_config=config)
sample = {"prompt": "What is 2+2?", "chosen": "2+2=4", "rejected": "2+2=5"}
exp = formatter.format(sample)
self.assertIsInstance(exp, Experience)
@@ -196,7 +191,7 @@ def test_dpo_messages_formatter(self):
chosen_key="chosen",
rejected_key="rejected",
)
- formatter = DPOMessagesFormatter(tokenizer=self.tokenizer, format_config=config)
+ formatter = FORMATTER.get("dpo")(tokenizer=self.tokenizer, format_config=config)
sample = {
"messages": [
{"role": "user", "content": "What is your name?"},
@@ -217,3 +212,63 @@ def test_dpo_messages_formatter(self):
self.assertIn("What is your name?", prompt)
self.assertIn("My name is Assistant.", chosen)
self.assertIn("I don't have a favorite color.", rejected)
+
+ def test_task_formatter(self):
+ sample = {
+ "question": "1+1=",
+ "answer": "2",
+ "workflow": "math_rm_workflow",
+ "reward": "math_boxed_reward",
+ }
+ config = StorageConfig(
+ is_eval=True,
+ default_workflow_type="math_workflow",
+ default_eval_workflow_type="math_boxed_workflow",
+ workflow_args={"use_base": True, "with_think": True},
+ )
+ formatter = FORMATTER.get("task")(config=config)
+ task = formatter.format(sample)
+ from trinity.common.workflows.customized_math_workflows import MathBoxedWorkflow
+
+ self.assertEqual(task.workflow, MathBoxedWorkflow)
+ self.assertTrue(task.workflow_args.get("use_base"))
+ self.assertTrue(task.workflow_args.get("with_think"))
+ self.assertEqual(task.raw_task, sample)
+
+ config = StorageConfig(
+ is_eval=False,
+ default_workflow_type="math_workflow",
+ default_eval_workflow_type="math_boxed_workflow",
+ default_reward_fn_type="math_reward",
+ workflow_args={"use_base": False, "with_think": True},
+ )
+ formatter = FORMATTER.get("task")(config=config)
+ task = formatter.format(sample)
+ from trinity.common.rewards.math_reward import MathRewardFn
+ from trinity.common.workflows.workflow import MathWorkflow
+
+ self.assertEqual(task.workflow, MathWorkflow)
+ self.assertEqual(task.reward_fn, MathRewardFn)
+ self.assertFalse(task.workflow_args.get("use_base"))
+ self.assertTrue(task.workflow_args.get("with_think"))
+ self.assertEqual(task.raw_task, sample)
+
+ config = StorageConfig(
+ is_eval=False,
+ default_eval_workflow_type="math_workflow",
+ workflow_args={"use_base": True, "with_think": False},
+ format=FormatConfig(
+ workflow_key="workflow",
+ reward_fn_key="reward",
+ ),
+ )
+ formatter = FORMATTER.get("task")(config=config)
+ task = formatter.format(sample)
+ from trinity.common.rewards.math_reward import MathBoxedRewardFn
+ from trinity.common.workflows.math_rm_workflow import MathRMWorkflow
+
+ self.assertEqual(task.workflow, MathRMWorkflow)
+ self.assertEqual(task.reward_fn, MathBoxedRewardFn)
+ self.assertTrue(task.workflow_args.get("use_base"))
+ self.assertFalse(task.workflow_args.get("with_think"))
+ self.assertEqual(task.raw_task, sample)
diff --git a/tests/buffer/queue_test.py b/tests/buffer/queue_test.py
index 27ee3cb0de..d4c5ea997b 100644
--- a/tests/buffer/queue_test.py
+++ b/tests/buffer/queue_test.py
@@ -33,7 +33,7 @@ class TestQueueBuffer(RayUnittestBaseAysnc):
async def test_queue_buffer(self, name, use_priority_queue):
meta = StorageConfig(
name=name,
- algorithm_type="ppo",
+ schema_type="experience",
storage_type=StorageType.QUEUE,
max_read_timeout=3,
path=BUFFER_FILE_PATH,
@@ -97,7 +97,7 @@ async def test_priority_queue_capacity(self):
self.config.train_batch_size = 4
meta = StorageConfig(
name="test_buffer_small",
- algorithm_type="ppo",
+ schema_type="experience",
storage_type=StorageType.QUEUE,
max_read_timeout=1,
capacity=100, # priority will use 2 * train_batch_size as capacity (8)
@@ -151,7 +151,7 @@ async def test_queue_buffer_capacity(self):
# test queue capacity
meta = StorageConfig(
name="test_buffer_small",
- algorithm_type="ppo",
+ schema_type="experience",
storage_type=StorageType.QUEUE,
max_read_timeout=3,
capacity=4,
@@ -180,7 +180,7 @@ async def test_priority_queue_buffer_reuse(self):
# test queue reuse
meta = StorageConfig(
name="test_buffer_small",
- algorithm_type="ppo",
+ schema_type="experience",
storage_type=StorageType.QUEUE,
max_read_timeout=3,
capacity=4,
@@ -306,8 +306,6 @@ def setUp(self):
self.train_batch_size = 4
self.config = BufferConfig(
- max_retry_times=3,
- max_retry_interval=1,
train_batch_size=self.train_batch_size,
)
if os.path.exists(BUFFER_FILE_PATH):
diff --git a/tests/buffer/sql_test.py b/tests/buffer/sql_test.py
index 33e8a24462..1934b8fa6c 100644
--- a/tests/buffer/sql_test.py
+++ b/tests/buffer/sql_test.py
@@ -14,20 +14,17 @@
class TestSQLBuffer(RayUnittestBaseAysnc):
- async def test_create_sql_buffer(self) -> None:
+ async def test_sql_buffer_read_write(self) -> None:
total_num = 8
put_batch_size = 2
read_batch_size = 4
meta = StorageConfig(
name="test_buffer",
- algorithm_type="ppo",
+ schema_type="experience",
path=f"sqlite:///{db_path}",
storage_type=StorageType.SQL,
- wrap_in_ray=True,
)
config = BufferConfig(
- max_retry_times=3,
- max_retry_interval=1,
train_batch_size=read_batch_size,
)
sql_writer = SQLWriter(meta, config)
@@ -66,3 +63,7 @@ async def test_create_sql_buffer(self) -> None:
self.assertIsNotNone(db_wrapper)
self.assertEqual(await sql_writer.release(), 0)
self.assertRaises(StopIteration, sql_reader.read)
+
+ def setUp(self) -> None:
+ if os.path.exists(db_path):
+ os.remove(db_path)
diff --git a/tests/buffer/task_storage_test.py b/tests/buffer/task_storage_test.py
new file mode 100644
index 0000000000..b0adda2aff
--- /dev/null
+++ b/tests/buffer/task_storage_test.py
@@ -0,0 +1,63 @@
+import os
+
+import datasets
+from parameterized import parameterized
+
+from tests.tools import (
+ RayUnittestBase,
+ get_template_config,
+ get_unittest_dataset_config,
+)
+from trinity.buffer import get_buffer_reader
+from trinity.buffer.storage.sql import SQLTaskStorage
+from trinity.common.constants import StorageType
+
+db_path = os.path.join(os.path.dirname(__file__), "test.db")
+
+
+class TaskStorageTest(RayUnittestBase):
+ @parameterized.expand(
+ [
+ (StorageType.FILE, True, 2),
+ (StorageType.SQL, True, 2),
+ (StorageType.FILE, False, 0),
+ (StorageType.SQL, False, 0),
+ (StorageType.FILE, False, 2),
+ (StorageType.SQL, False, 2),
+ ]
+ )
+ def test_read_task(self, storage_type, is_eval, offset):
+ config = get_template_config()
+ total_samples = 17
+ batch_size = 4
+ config.buffer.explorer_input.taskset = get_unittest_dataset_config(
+ "countdown"
+ ) # 17 samples
+ config.buffer.batch_size = batch_size
+ config.buffer.explorer_input.taskset.storage_type = storage_type
+ config.buffer.explorer_input.taskset.is_eval = is_eval
+ config.buffer.explorer_input.taskset.index = offset
+ if storage_type == StorageType.SQL:
+ dataset = datasets.load_dataset(
+ config.buffer.explorer_input.taskset.path, split="train"
+ )
+ config.buffer.explorer_input.taskset.path = f"sqlite:///{db_path}"
+ SQLTaskStorage.load_from_dataset(
+ dataset, config.buffer.explorer_input.taskset, config.buffer
+ )
+ reader = get_buffer_reader(config.buffer.explorer_input.taskset, config.buffer)
+ tasks = []
+ try:
+ while True:
+ cur_tasks = reader.read()
+ tasks.extend(cur_tasks)
+ except StopIteration:
+ pass
+ if is_eval:
+ self.assertEqual(len(tasks), total_samples - offset)
+ else:
+ self.assertEqual(len(tasks), (total_samples - offset) // batch_size * batch_size)
+
+ def setUp(self):
+ if os.path.exists(db_path):
+ os.remove(db_path)
diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py
index f805778568..236d33d72f 100644
--- a/tests/explorer/scheduler_test.py
+++ b/tests/explorer/scheduler_test.py
@@ -213,7 +213,7 @@ def setUp(self):
) = StorageConfig(
name="test",
storage_type=StorageType.QUEUE,
- algorithm_type="ppo",
+ schema_type="experience",
path="",
)
self.config.buffer.trainer_input.experience_buffer.max_read_timeout = 1
diff --git a/tests/manager/synchronizer_test.py b/tests/manager/synchronizer_test.py
index a77bc62097..cd96986c10 100644
--- a/tests/manager/synchronizer_test.py
+++ b/tests/manager/synchronizer_test.py
@@ -120,7 +120,6 @@ def test_synchronizer(self):
config.buffer.trainer_input.experience_buffer = StorageConfig(
name="exp_buffer",
storage_type=StorageType.QUEUE,
- wrap_in_ray=True,
)
config.synchronizer.sync_method = SyncMethod.CHECKPOINT
config.synchronizer.sync_style = SyncStyle.DYNAMIC_BY_EXPLORER
@@ -140,7 +139,6 @@ def test_synchronizer(self):
explorer1_config.buffer.explorer_output = StorageConfig(
name="exp_buffer",
storage_type=StorageType.QUEUE,
- wrap_in_ray=True,
)
explorer1_config.check_and_update()
@@ -245,7 +243,6 @@ def test_synchronizer(self):
config.buffer.trainer_input.experience_buffer = StorageConfig(
name="exp_buffer",
storage_type=StorageType.QUEUE,
- wrap_in_ray=True,
)
config.synchronizer.sync_method = self.sync_method
config.synchronizer.sync_style = self.sync_style
@@ -265,7 +262,6 @@ def test_synchronizer(self):
explorer1_config.buffer.explorer_output = StorageConfig(
name="exp_buffer",
storage_type=StorageType.QUEUE,
- wrap_in_ray=True,
)
explorer2_config = deepcopy(explorer1_config)
explorer2_config.explorer.name = "explorer2"
@@ -346,7 +342,6 @@ def test_synchronizer(self):
config.buffer.trainer_input.experience_buffer = StorageConfig(
name="exp_buffer",
storage_type=StorageType.QUEUE,
- wrap_in_ray=True,
)
config.synchronizer.sync_method = SyncMethod.NCCL
config.synchronizer.sync_style = self.sync_style
diff --git a/tests/template/config.yaml b/tests/template/config.yaml
index bed27420c8..fd2bc06194 100644
--- a/tests/template/config.yaml
+++ b/tests/template/config.yaml
@@ -21,8 +21,6 @@ cluster: # 2 for explorer, 2 for trainer
buffer:
total_epochs: 1
batch_size: 4
- max_retry_times: 3
- max_retry_interval: 1
explorer_input:
taskset:
name: taskset
diff --git a/tests/test_configs/active_iterator_test_cfg.yaml b/tests/test_configs/active_iterator_test_cfg.yaml
deleted file mode 100644
index 3e6008b7cf..0000000000
--- a/tests/test_configs/active_iterator_test_cfg.yaml
+++ /dev/null
@@ -1,18 +0,0 @@
-data_processor:
- # basic info
- task_pipeline:
- input_buffers:
- - name: 'raw_input'
- path: 'tests/test_data/test_10/'
- storage_type: 'file'
- raw: true
- output_buffer:
- name: 'raw_output'
- path: './outputs/task_pipeline_output/processed.jsonl'
- storage_type: 'file'
- format:
- prompt_key: 'problem'
- response_key: 'solution'
- # cleaner related
- dj_config_path: 'tests/test_configs/active_iterator_test_dj_cfg.yaml'
- clean_strategy: 'iterative'
diff --git a/tests/test_configs/active_iterator_test_dj_cfg.yaml b/tests/test_configs/active_iterator_test_dj_cfg.yaml
deleted file mode 100644
index 367709968f..0000000000
--- a/tests/test_configs/active_iterator_test_dj_cfg.yaml
+++ /dev/null
@@ -1,11 +0,0 @@
-project_name: 'demo-process'
-
-text_keys: 'solution'
-
-process:
- - alphanumeric_filter:
- max_ratio: 0.9
- - language_id_score_filter:
- min_score: 0.5
- - character_repetition_filter:
- max_ratio: 0.5
diff --git a/tests/test_configs/cleaner_test_dj_cfg.yaml b/tests/test_configs/cleaner_test_dj_cfg.yaml
deleted file mode 100644
index cf11488963..0000000000
--- a/tests/test_configs/cleaner_test_dj_cfg.yaml
+++ /dev/null
@@ -1,7 +0,0 @@
-project_name: 'demo-process'
-
-export_path: './outputs/demo-process/demo-processed.jsonl'
-
-process:
- - alphanumeric_filter:
- - clean_email_mapper:
diff --git a/tests/test_configs/cleaner_test_rft_cfg.yaml b/tests/test_configs/cleaner_test_rft_cfg.yaml
deleted file mode 100644
index c78e3a1ac8..0000000000
--- a/tests/test_configs/cleaner_test_rft_cfg.yaml
+++ /dev/null
@@ -1,7 +0,0 @@
-data_processor:
- task_pipeline:
- input_buffers:
- - path: './tests/test_data/test_cleaner'
- raw: true
- dj_config_path: './tests/test_configs/cleaner_test_dj_cfg.yaml'
- clean_strategy: 'iterative'
diff --git a/tests/test_configs/human_annotator_test_dj_cfg.yaml b/tests/test_configs/human_annotator_test_dj_cfg.yaml
deleted file mode 100644
index ae53229062..0000000000
--- a/tests/test_configs/human_annotator_test_dj_cfg.yaml
+++ /dev/null
@@ -1,27 +0,0 @@
-project_name: 'demo-human-annotator'
-
-np: 1
-
-export_path: './outputs/demo-human-annotator/annotated-data.jsonl'
-
-process:
- - human_preference_annotation_mapper:
- # general annotation project settings
- project_name_prefix: "Human_Preference_Annotation_Demo"
- wait_for_annotations: true # Whether to wait for annotations to complete
- timeout: 3600 # Maximum time to wait for annotations in seconds (1 hour)
- poll_interval: 10 # Time between annotation status checks in seconds
- max_tasks_per_batch: 10 # Maximum number of tasks in a single batch
- notification_config:
- enabled: false
-
- # label studio connection settings
- api_url: "http://localhost:7070" # Default Label Studio URL
- api_key: "05409236-67a5-4169-af96-a52a818d0e81" # Your API key for label studuio authentication. Now it's the default api-key for default starting
-
- # human preference annotation settings
- prompt_key: "prompt" # Prompt field
- answer1_key: "answer1" # First answer option
- answer2_key: "answer2" # Second answer option
- chosen_key: "chosen" # Chosen field
- rejected_key: "rejected" # Rejected field
diff --git a/tests/test_configs/human_annotator_test_rft_cfg.yaml b/tests/test_configs/human_annotator_test_rft_cfg.yaml
deleted file mode 100644
index b20f015182..0000000000
--- a/tests/test_configs/human_annotator_test_rft_cfg.yaml
+++ /dev/null
@@ -1,10 +0,0 @@
-data_processor:
- task_pipeline:
- input_buffers:
- - path: './tests/test_data/test_human_annotator'
- raw: true
- dj_config_path: './tests/test_configs/human_annotator_test_dj_cfg.yaml'
- format:
- prompt_key: 'prompt'
- chosen_key: 'chosen'
- rejected_key: 'rejected'
diff --git a/tests/test_data/template.yaml b/tests/test_data/template.yaml
deleted file mode 100644
index bf474d3748..0000000000
--- a/tests/test_data/template.yaml
+++ /dev/null
@@ -1,25 +0,0 @@
-cluster:
- node_num: 1
- gpu_per_node: 8
-buffer:
- batch_size: 32
- max_retry_times: 3
- max_retry_interval: 1
- explorer_input:
- taskset:
- name: taskset
- storage_type: file
- path: ''
- default_workflow_type: ''
- default_eval_workflow_type: ''
- default_reward_fn_type: ''
-explorer:
- runner_num: 8
- rollout_model:
- engine_type: vllm
- engine_num: 2
- tensor_parallel_size: 2
- enable_prefix_caching: false
- enforce_eager: true
- dtype: bfloat16
- seed: 42
diff --git a/tests/test_data/test_10/data.parquet b/tests/test_data/test_10/data.parquet
deleted file mode 100644
index 9528828659..0000000000
Binary files a/tests/test_data/test_10/data.parquet and /dev/null differ
diff --git a/tests/test_data/test_10_with_rewfn_workflow/data.parquet b/tests/test_data/test_10_with_rewfn_workflow/data.parquet
deleted file mode 100644
index 52c14aff13..0000000000
Binary files a/tests/test_data/test_10_with_rewfn_workflow/data.parquet and /dev/null differ
diff --git a/tests/test_data/test_cleaner/data.parquet b/tests/test_data/test_cleaner/data.parquet
deleted file mode 100644
index 857b6334f4..0000000000
Binary files a/tests/test_data/test_cleaner/data.parquet and /dev/null differ
diff --git a/tests/test_data/test_human_annotator/data.jsonl b/tests/test_data/test_human_annotator/data.jsonl
deleted file mode 100644
index 1e46385914..0000000000
--- a/tests/test_data/test_human_annotator/data.jsonl
+++ /dev/null
@@ -1,10 +0,0 @@
-{"prompt": "What is the capital of France?", "answer1": "Paris", "answer2": "Lyon"}
-{"prompt": "Which planet is known as the Red Planet?", "answer1": "Mars", "answer2": "Venus"}
-{"prompt": "What is the chemical symbol for gold?", "answer1": "Au", "answer2": "Ag"}
-{"prompt": "Who wrote 'Romeo and Juliet'?", "answer1": "William Shakespeare", "answer2": "Christopher Marlowe"}
-{"prompt": "What is the largest mammal on Earth?", "answer1": "Blue Whale", "answer2": "African Elephant"}
-{"prompt": "In which year did World War II end?", "answer1": "1945", "answer2": "1944"}
-{"prompt": "What is the square root of 64?", "answer1": "8", "answer2": "6"}
-{"prompt": "Who painted the Mona Lisa?", "answer1": "Leonardo da Vinci", "answer2": "Michelangelo"}
-{"prompt": "What is the main component of the Sun?", "answer1": "Hydrogen", "answer2": "Helium"}
-{"prompt": "Which programming language was created by Guido van Rossum?", "answer1": "Python", "answer2": "Java"}
diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py
index 80e8322d28..24c0588ffa 100644
--- a/tests/trainer/trainer_test.py
+++ b/tests/trainer/trainer_test.py
@@ -250,8 +250,13 @@ def test_trainer(self):
"sft_for_gsm8k"
)
self.config.buffer.trainer_input.sft_warmup_steps = 3
+ self.config.buffer.trainer_input.experience_buffer = StorageConfig(
+ name="test_sql_storage",
+ max_read_timeout=20,
+ storage_type=StorageType.SQL,
+ max_retry_times=10,
+ )
self.config.check_and_update()
- self.config.buffer.trainer_input.experience_buffer.max_read_timeout = 20
self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2
self.config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr = 1e-5
both(self.config)
@@ -428,7 +433,6 @@ def test_fully_async_mode(self, name, use_priority_queue):
config.buffer.trainer_input.experience_buffer = StorageConfig(
name="exp_buffer",
storage_type=StorageType.QUEUE,
- wrap_in_ray=True,
use_priority_queue=use_priority_queue,
)
config.synchronizer.sync_method = SyncMethod.CHECKPOINT
@@ -450,7 +454,6 @@ def test_fully_async_mode(self, name, use_priority_queue):
explorer1_config.buffer.trainer_input.experience_buffer = StorageConfig(
name="exp_buffer",
storage_type=StorageType.QUEUE,
- wrap_in_ray=True,
)
explorer2_config = deepcopy(explorer1_config)
explorer1_config.check_and_update()
diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py
index fcedf20267..e3fac8d1e2 100644
--- a/trinity/algorithm/algorithm.py
+++ b/trinity/algorithm/algorithm.py
@@ -4,7 +4,6 @@
from abc import ABC, ABCMeta, abstractmethod
from typing import Dict
-from trinity.buffer.schema.sql_schema import DPODataModel, ExperienceModel, SFTDataModel
from trinity.common.config import Config
from trinity.common.constants import SyncMethod
from trinity.utils.log import get_logger
@@ -23,11 +22,17 @@ def __setattr__(cls, name, value):
class AlgorithmType(ABC, metaclass=ConstantMeta):
- use_critic: bool
- use_reference: bool
- compute_advantage_in_trainer: bool
- can_balance_batch: bool
- schema: type
+ use_critic: bool # whether to use critic model
+
+ use_reference: bool # whether to use reference model
+
+ compute_advantage_in_trainer: bool # whether to compute advantage in trainer
+ # For algorithms that rely on experience grouping,
+ # we recommend set this value to False
+
+ can_balance_batch: bool # balance batch in trainer
+
+ schema: str # schema of training data
@classmethod
@abstractmethod
@@ -51,7 +56,7 @@ class SFTAlgorithm(AlgorithmType):
use_reference: bool = False
compute_advantage_in_trainer: bool = False
can_balance_batch: bool = True
- schema: type = SFTDataModel
+ schema: str = "sft"
@classmethod
def default_config(cls) -> Dict:
@@ -71,7 +76,7 @@ class PPOAlgorithm(AlgorithmType):
use_reference: bool = True
compute_advantage_in_trainer: bool = True
can_balance_batch: bool = True
- schema: type = ExperienceModel
+ schema: str = "experience"
@classmethod
def default_config(cls) -> Dict:
@@ -94,7 +99,7 @@ class GRPOAlgorithm(AlgorithmType):
use_reference: bool = True
compute_advantage_in_trainer: bool = False
can_balance_batch: bool = True
- schema: type = ExperienceModel
+ schema: str = "experience"
@classmethod
def default_config(cls) -> Dict:
@@ -117,7 +122,7 @@ class OPMDAlgorithm(AlgorithmType):
use_reference: bool = True
compute_advantage_in_trainer: bool = False
can_balance_batch: bool = True
- schema: type = ExperienceModel
+ schema: str = "experience"
@classmethod
def default_config(cls) -> Dict:
@@ -140,7 +145,7 @@ class AsymREAlgorithm(AlgorithmType):
use_reference: bool = False
compute_advantage_in_trainer: bool = False
can_balance_batch: bool = True
- schema: type = ExperienceModel
+ schema: str = "experience"
@classmethod
def default_config(cls) -> Dict:
@@ -163,7 +168,7 @@ class DPOAlgorithm(AlgorithmType):
use_reference: bool = True
compute_advantage_in_trainer: bool = False
can_balance_batch: bool = False
- schema: type = DPODataModel
+ schema: str = "dpo"
@classmethod
def default_config(cls) -> Dict:
@@ -208,7 +213,7 @@ class MIXAlgorithm(AlgorithmType):
compute_advantage_in_trainer: bool = False
use_rollout: bool = True
can_balance_batch: bool = True
- schema: type = ExperienceModel
+ schema: str = "experience"
@classmethod
def default_config(cls) -> Dict:
@@ -230,7 +235,7 @@ class MIXCHORDAlgorithm(AlgorithmType):
compute_advantage_in_trainer: bool = False
use_rollout: bool = True
can_balance_batch: bool = True
- schema: type = ExperienceModel
+ schema: str = "experience"
@classmethod
def default_config(cls) -> Dict:
@@ -247,14 +252,14 @@ def default_config(cls) -> Dict:
class RAFTAlgorithm(AlgorithmType):
"""RAFT Algorithm.
This algorithm is conceptually similar to Supervised Fine-Tuning (SFT)
- but is designed to work with `ExperienceModel` schema from rollouts.
+ but is designed to work with `experience` schema from rollouts.
"""
use_critic: bool = False
use_reference: bool = False
compute_advantage_in_trainer: bool = False
can_balance_batch: bool = True
- schema: type = ExperienceModel
+ schema: str = "experience"
@classmethod
def default_config(cls) -> Dict:
diff --git a/trinity/buffer/buffer.py b/trinity/buffer/buffer.py
index ac5798c82d..a7d52e60a7 100644
--- a/trinity/buffer/buffer.py
+++ b/trinity/buffer/buffer.py
@@ -17,16 +17,17 @@ def get_buffer_reader(storage_config: StorageConfig, buffer_config: BufferConfig
return QueueReader(storage_config, buffer_config)
elif storage_config.storage_type == StorageType.FILE:
- from trinity.buffer.reader.file_reader import FILE_READERS
-
- algorithm_type = storage_config.algorithm_type
- if storage_config.raw:
- file_read_type = "raw"
- elif algorithm_type is not None:
- file_read_type = algorithm_type
+ from trinity.buffer.reader.file_reader import (
+ ExperienceFileReader,
+ TaskFileReader,
+ )
+
+ schema_type = storage_config.schema_type
+ if schema_type:
+ # only trainer input has schema type
+ return ExperienceFileReader(storage_config, buffer_config)
else:
- file_read_type = "rollout"
- return FILE_READERS.get(file_read_type)(storage_config, buffer_config) # type: ignore
+ return TaskFileReader(storage_config, buffer_config)
else:
raise ValueError(f"{storage_config.storage_type} not supported.")
diff --git a/trinity/buffer/pipelines/experience_pipeline.py b/trinity/buffer/pipelines/experience_pipeline.py
index 6250b46703..efc1f16ad9 100644
--- a/trinity/buffer/pipelines/experience_pipeline.py
+++ b/trinity/buffer/pipelines/experience_pipeline.py
@@ -3,7 +3,7 @@
from trinity.buffer.buffer import BufferWriter, get_buffer_reader, get_buffer_writer
from trinity.buffer.operators.experience_operator import ExperienceOperator
-from trinity.buffer.ray_wrapper import is_database_url, is_json_file
+from trinity.buffer.storage.queue import is_database_url, is_json_file
from trinity.common.config import (
AlgorithmConfig,
BufferConfig,
diff --git a/trinity/buffer/ray_wrapper.py b/trinity/buffer/ray_wrapper.py
deleted file mode 100644
index f946f78495..0000000000
--- a/trinity/buffer/ray_wrapper.py
+++ /dev/null
@@ -1,293 +0,0 @@
-"""Ray actor wrapper for different buffers."""
-import asyncio
-import json
-import os
-import time
-from collections import deque
-from copy import deepcopy
-from typing import List, Optional
-
-import ray
-from sqlalchemy import asc, create_engine, desc
-from sqlalchemy.exc import OperationalError
-from sqlalchemy.orm import sessionmaker
-from sqlalchemy.pool import NullPool
-
-from trinity.buffer.queue import QueueBuffer
-from trinity.buffer.schema import Base, create_dynamic_table
-from trinity.buffer.utils import default_storage_path, retry_session
-from trinity.common.config import BufferConfig, StorageConfig
-from trinity.common.constants import StorageType
-from trinity.common.experience import EID, Experience
-from trinity.common.workflows import Task
-from trinity.utils.log import get_logger
-
-
-class DBWrapper:
- """
- A wrapper of a SQL database.
-
- If `wrap_in_ray` in `StorageConfig` is `True`, this class will be run as a Ray Actor,
- and provide a remote interface to the local database.
-
- For databases that do not support multi-processing read/write (e.g. sqlite, duckdb), we
- recommend setting `wrap_in_ray` to `True`
- """
-
- def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None:
- self.logger = get_logger(f"sql_{storage_config.name}")
- if storage_config.path is None:
- storage_config.path = default_storage_path(storage_config, config)
- self.engine = create_engine(storage_config.path, poolclass=NullPool)
- self.table_model_cls = create_dynamic_table(
- storage_config.algorithm_type, storage_config.name
- )
-
- try:
- Base.metadata.create_all(self.engine, checkfirst=True)
- except OperationalError:
- self.logger.warning("Failed to create database, assuming it already exists.")
-
- self.session = sessionmaker(bind=self.engine)
- self.batch_size = config.train_batch_size
- self.max_retry_times = config.max_retry_times
- self.max_retry_interval = config.max_retry_interval
- self.ref_count = 0
- self.stopped = False
-
- @classmethod
- def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig):
- if storage_config.wrap_in_ray:
- return (
- ray.remote(cls)
- .options(
- name=f"sql-{storage_config.name}",
- namespace=storage_config.ray_namespace or ray.get_runtime_context().namespace,
- get_if_exists=True,
- )
- .remote(storage_config, config)
- )
- else:
- return cls(storage_config, config)
-
- def write(self, data: list) -> None:
- with retry_session(self.session, self.max_retry_times, self.max_retry_interval) as session:
- experience_models = [self.table_model_cls.from_experience(exp) for exp in data]
- session.add_all(experience_models)
-
- def read(self, batch_size: Optional[int] = None) -> List:
- if self.stopped:
- raise StopIteration()
-
- sortOrder = (asc(self.table_model_cls.consumed), desc(self.table_model_cls.id))
-
- exp_list = []
- batch_size = batch_size or self.batch_size # type: ignore
- while len(exp_list) < batch_size:
- if len(exp_list):
- self.logger.info("waiting for experiences...")
- time.sleep(1)
- with retry_session(
- self.session, self.max_retry_times, self.max_retry_interval
- ) as session:
- # get a batch of experiences from the database
- experiences = (
- session.query(self.table_model_cls)
- .filter(self.table_model_cls.reward.isnot(None))
- .order_by(*sortOrder) # TODO: very slow
- .limit(batch_size - len(exp_list))
- .with_for_update()
- .all()
- )
- # update the consumed field
- for exp in experiences:
- exp.consumed += 1
- exp_list.extend([self.table_model_cls.to_experience(exp) for exp in experiences])
- self.logger.info(f"get {len(exp_list)} experiences:")
- self.logger.info(f"reward = {[exp.reward for exp in exp_list]}")
- self.logger.info(f"first prompt_text = {exp_list[0].prompt_text}")
- self.logger.info(f"first response_text = {exp_list[0].response_text}")
- return exp_list
-
- def acquire(self) -> int:
- self.ref_count += 1
- return self.ref_count
-
- def release(self) -> int:
- self.ref_count -= 1
- if self.ref_count <= 0:
- self.stopped = True
- return self.ref_count
-
-
-class _Encoder(json.JSONEncoder):
- def default(self, o):
- if isinstance(o, Experience):
- return o.to_dict()
- if isinstance(o, Task):
- return o.to_dict()
- if isinstance(o, EID):
- return o.to_dict()
- return super().default(o)
-
-
-class FileWrapper:
- """
- A wrapper of a local jsonl file.
-
- If `wrap_in_ray` in `StorageConfig` is `True`, this class will be run as
- a Ray Actor, and provide a remote interface to the local file.
-
- This wrapper is only for writing, if you want to read from the file, use
- StorageType.QUEUE instead.
- """
-
- def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None:
- if storage_config.path is None:
- storage_config.path = default_storage_path(storage_config, config)
- ext = os.path.splitext(storage_config.path)[-1]
- if ext != ".jsonl" and ext != ".json":
- raise ValueError(
- f"File path must end with '.json' or '.jsonl', got {storage_config.path}"
- )
- path_dir = os.path.dirname(storage_config.path)
- os.makedirs(path_dir, exist_ok=True)
- self.file = open(storage_config.path, "a", encoding="utf-8")
- self.encoder = _Encoder(ensure_ascii=False)
- self.ref_count = 0
-
- @classmethod
- def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig):
- if storage_config.wrap_in_ray:
- return (
- ray.remote(cls)
- .options(
- name=f"json-{storage_config.name}",
- namespace=storage_config.ray_namespace or ray.get_runtime_context().namespace,
- get_if_exists=True,
- )
- .remote(storage_config, config)
- )
- else:
- return cls(storage_config, config)
-
- def write(self, data: List) -> None:
- for item in data:
- json_str = self.encoder.encode(item)
- self.file.write(json_str + "\n")
- self.file.flush()
-
- def read(self) -> List:
- raise NotImplementedError(
- "read() is not implemented for FileWrapper, please use QUEUE instead"
- )
-
- def acquire(self) -> int:
- self.ref_count += 1
- return self.ref_count
-
- def release(self) -> int:
- self.ref_count -= 1
- if self.ref_count <= 0:
- self.file.close()
- return self.ref_count
-
-
-def is_database_url(path: str) -> bool:
- return any(path.startswith(prefix) for prefix in ["sqlite:///", "postgresql://", "mysql://"])
-
-
-def is_json_file(path: str) -> bool:
- return path.endswith(".json") or path.endswith(".jsonl")
-
-
-class QueueWrapper:
- """An wrapper of a async queue."""
-
- def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None:
- self.logger = get_logger(f"queue_{storage_config.name}")
- self.config = config
- self.capacity = storage_config.capacity
- self.queue = QueueBuffer.get_queue(storage_config, config)
- st_config = deepcopy(storage_config)
- st_config.wrap_in_ray = False
- if st_config.path is not None:
- if is_database_url(st_config.path):
- from trinity.buffer.writer.sql_writer import SQLWriter
-
- st_config.storage_type = StorageType.SQL
- self.writer = SQLWriter(st_config, self.config)
- elif is_json_file(st_config.path):
- from trinity.buffer.writer.file_writer import JSONWriter
-
- st_config.storage_type = StorageType.FILE
- self.writer = JSONWriter(st_config, self.config)
- else:
- self.logger.warning("Unknown supported storage path: %s", st_config.path)
- self.writer = None
- else:
- from trinity.buffer.writer.file_writer import JSONWriter
-
- st_config.storage_type = StorageType.FILE
- self.writer = JSONWriter(st_config, self.config)
- self.logger.warning(f"Save experiences in {st_config.path}.")
- self.ref_count = 0
- self.exp_pool = deque() # A pool to store experiences
- self.closed = False
-
- async def acquire(self) -> int:
- self.ref_count += 1
- return self.ref_count
-
- async def release(self) -> int:
- """Release the queue."""
- self.ref_count -= 1
- if self.ref_count <= 0:
- await self.queue.close()
- await self.writer.release()
- return self.ref_count
-
- def length(self) -> int:
- """The length of the queue."""
- return self.queue.qsize()
-
- async def put_batch(self, exp_list: List) -> None:
- """Put batch of experience."""
- await self.queue.put(exp_list)
- if self.writer is not None:
- self.writer.write(exp_list)
-
- async def get_batch(self, batch_size: int, timeout: float) -> List:
- """Get batch of experience."""
- start_time = time.time()
- while len(self.exp_pool) < batch_size:
- if self.queue.stopped():
- # If the queue is stopped, ignore the rest of the experiences in the pool
- raise StopAsyncIteration("Queue is closed and no more items to get.")
- try:
- exp_list = await asyncio.wait_for(self.queue.get(), timeout=1.0)
- self.exp_pool.extend(exp_list)
- except asyncio.TimeoutError:
- if time.time() - start_time > timeout:
- self.logger.error(
- f"Timeout when waiting for experience, only get {len(self.exp_pool)} experiences.\n"
- "This phenomenon is usually caused by the workflow not returning enough "
- "experiences or running timeout. Please check your workflow implementation."
- )
- batch = list(self.exp_pool)
- self.exp_pool.clear()
- return batch
- return [self.exp_pool.popleft() for _ in range(batch_size)]
-
- @classmethod
- def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig):
- """Get the queue actor."""
- return (
- ray.remote(cls)
- .options(
- name=f"queue-{storage_config.name}",
- namespace=storage_config.ray_namespace or ray.get_runtime_context().namespace,
- get_if_exists=True,
- )
- .remote(storage_config, config)
- )
diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py
index 5c7c558b47..b111d2de91 100644
--- a/trinity/buffer/reader/file_reader.py
+++ b/trinity/buffer/reader/file_reader.py
@@ -1,27 +1,14 @@
"""Filed based buffer reader."""
-import copy
from typing import List, Optional
import datasets
import transformers
from datasets import Dataset, load_dataset
-from trinity.algorithm.algorithm import DPOAlgorithm, SFTAlgorithm
from trinity.buffer.buffer_reader import BufferReader
-from trinity.buffer.schema.formatter import (
- DPOMessagesFormatter,
- DPOPlaintextFormatter,
- SFTMessagesFormatter,
- SFTPlaintextFormatter,
-)
+from trinity.buffer.schema.formatter import FORMATTER
from trinity.common.config import BufferConfig, StorageConfig
-from trinity.common.constants import PromptType, TaskType
-from trinity.common.rewards import REWARD_FUNCTIONS
-from trinity.common.workflows import WORKFLOWS, Task
-from trinity.utils.registry import Registry
-
-FILE_READERS = Registry("file_readers")
class DummyProgressBar:
@@ -112,22 +99,14 @@ async def read_async(self, batch_size: Optional[int] = None):
raise StopAsyncIteration from e
-@FILE_READERS.register_module(SFTAlgorithm.name())
-class SFTDataReader(BaseFileReader):
+class ExperienceFileReader(BaseFileReader):
"""Reader for SFT file data."""
def __init__(self, meta: StorageConfig, config: BufferConfig):
self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path)
- if meta.format.prompt_type == PromptType.MESSAGES:
- self.formatter = SFTMessagesFormatter(
- tokenizer=self.tokenizer, format_config=meta.format
- )
- elif meta.format.prompt_type == PromptType.PLAINTEXT:
- self.formatter = SFTPlaintextFormatter(
- tokenizer=self.tokenizer, format_config=meta.format
- )
- else:
- raise ValueError(f"Unknown prompt type: {self.prompt_type}")
+ self.formatter = FORMATTER.get(meta.schema_type)(
+ tokenizer=self.tokenizer, format_config=meta.format
+ )
self.read_batch_size = config.train_batch_size
self.dataset = _HFBatchReader(
load_dataset(meta.path, name=meta.subset_name, split=meta.split),
@@ -148,48 +127,7 @@ def read(self, batch_size: Optional[int] = None) -> List:
return exp_list
-@FILE_READERS.register_module(DPOAlgorithm.name())
-class DPODataReader(BaseFileReader):
- def __init__(self, meta: StorageConfig, config: BufferConfig):
- self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path)
- if meta.format.prompt_type == PromptType.MESSAGES:
- self.formatter = DPOMessagesFormatter(
- tokenizer=self.tokenizer, format_config=meta.format
- )
- elif meta.format.prompt_type == PromptType.PLAINTEXT:
- self.formatter = DPOPlaintextFormatter(
- tokenizer=self.tokenizer, format_config=meta.format
- )
- self.read_batch_size = config.train_batch_size
- self.dataset = _HFBatchReader(
- load_dataset(meta.path, name=meta.subset_name, split=meta.split),
- name=meta.name,
- default_batch_size=self.read_batch_size,
- total_epochs=meta.total_epochs,
- drop_last=True,
- total_steps=meta.total_steps,
- enable_progress_bar=meta.enable_progress_bar,
- ) # TODO: support resume
-
- def _get_assistant_message(self, item) -> dict:
- if isinstance(item, List):
- item = item[0]
- if isinstance(item, str):
- return {"role": "assistant", "content": item}
- else:
- return item
-
- def read(self, batch_size: Optional[int] = None) -> List:
- batch_data = self.dataset.read_batch(batch_size or self.read_batch_size)
- exp_list = []
- for sample in batch_data:
- experience = self.formatter.format(sample)
- exp_list.append(experience)
- return exp_list
-
-
-@FILE_READERS.register_module("rollout")
-class RolloutDataReader(BaseFileReader):
+class TaskFileReader(BaseFileReader):
def __init__(self, meta: StorageConfig, config: BufferConfig):
self.meta = meta
self.name = meta.name
@@ -203,61 +141,24 @@ def __init__(self, meta: StorageConfig, config: BufferConfig):
load_dataset(meta.path, name=subset_name, split=self.split),
name=meta.name,
default_batch_size=self.read_batch_size,
- total_epochs=self.meta.total_epochs if meta.task_type == TaskType.EXPLORE else 1,
+ total_epochs=self.meta.total_epochs if not self.meta.is_eval else 1,
offset=self.meta.index,
- drop_last=self.meta.task_type == TaskType.EXPLORE,
+ drop_last=not self.meta.is_eval,
total_steps=meta.total_steps,
enable_progress_bar=meta.enable_progress_bar,
)
- self.prompt_key = meta.format.prompt_key
- self.response_key = meta.format.response_key
- self.workflow_key = meta.format.workflow_key
- self.reward_fn_key = meta.format.reward_fn_key
-
- self.task_type = meta.task_type
- self.default_workflow_cls = WORKFLOWS.get(meta.default_workflow_type) # type: ignore
- self.default_eval_workflow_cls = None
- if getattr(meta, "default_eval_workflow_type", None):
- self.default_eval_workflow_cls = WORKFLOWS.get(meta.default_eval_workflow_type)
- self.default_reward_fn_cls = REWARD_FUNCTIONS.get(meta.default_reward_fn_type) # type: ignore
+ self.formatter = FORMATTER.get("task")(meta)
def read(self, batch_size: Optional[int] = None) -> List:
batch_size = batch_size or self.read_batch_size
tasks = []
samples = self.dataset.read_batch(batch_size)
for sample in samples:
- if self.task_type == TaskType.EVAL and self.default_eval_workflow_cls:
- workflow_class = self.default_eval_workflow_cls
- else:
- workflow_class = (
- WORKFLOWS.get(sample[self.workflow_key])
- if self.workflow_key in sample
- else self.default_workflow_cls
- )
- reward_fn = (
- REWARD_FUNCTIONS.get(sample[self.reward_fn_key])
- if self.reward_fn_key in sample
- else self.default_reward_fn_cls
- )
- assert (
- workflow_class is not None
- ), "`default_workflow_type` or `workflow_key` is required"
- task = Task(
- workflow=workflow_class,
- repeat_times=self.meta.repeat_times,
- format_args=copy.deepcopy(self.meta.format),
- rollout_args=copy.deepcopy(self.meta.rollout_args),
- workflow_args=copy.deepcopy(self.meta.workflow_args),
- reward_fn_args=copy.deepcopy(self.meta.reward_fn_args),
- is_eval=self.meta.task_type == TaskType.EVAL,
- reward_fn=reward_fn,
- raw_task=sample,
- )
+ task = self.formatter.format(sample)
tasks.append(task)
return tasks
-@FILE_READERS.register_module("raw")
class RawDataReader(BaseFileReader):
def __init__(self, meta: StorageConfig, config: Optional[BufferConfig]):
self.returned = False
diff --git a/trinity/buffer/reader/queue_reader.py b/trinity/buffer/reader/queue_reader.py
index ad5ad1165f..e4ea695a21 100644
--- a/trinity/buffer/reader/queue_reader.py
+++ b/trinity/buffer/reader/queue_reader.py
@@ -5,7 +5,7 @@
import ray
from trinity.buffer.buffer_reader import BufferReader
-from trinity.buffer.ray_wrapper import QueueWrapper
+from trinity.buffer.storage.queue import QueueStorage
from trinity.common.config import BufferConfig, StorageConfig
from trinity.common.constants import StorageType
@@ -17,7 +17,7 @@ def __init__(self, storage_config: StorageConfig, config: BufferConfig):
assert storage_config.storage_type == StorageType.QUEUE
self.timeout = storage_config.max_read_timeout
self.read_batch_size = config.train_batch_size
- self.queue = QueueWrapper.get_wrapper(storage_config, config)
+ self.queue = QueueStorage.get_wrapper(storage_config, config)
def read(self, batch_size: Optional[int] = None) -> List:
try:
diff --git a/trinity/buffer/reader/sql_reader.py b/trinity/buffer/reader/sql_reader.py
index e715557957..7e45b842d2 100644
--- a/trinity/buffer/reader/sql_reader.py
+++ b/trinity/buffer/reader/sql_reader.py
@@ -5,7 +5,7 @@
import ray
from trinity.buffer.buffer_reader import BufferReader
-from trinity.buffer.ray_wrapper import DBWrapper
+from trinity.buffer.storage.sql import SQLStorage
from trinity.common.config import BufferConfig, StorageConfig
from trinity.common.constants import StorageType
@@ -16,16 +16,19 @@ class SQLReader(BufferReader):
def __init__(self, meta: StorageConfig, config: BufferConfig) -> None:
assert meta.storage_type == StorageType.SQL
self.wrap_in_ray = meta.wrap_in_ray
- self.db_wrapper = DBWrapper.get_wrapper(meta, config)
+ self.storage = SQLStorage.get_wrapper(meta, config)
def read(self, batch_size: Optional[int] = None) -> List:
if self.wrap_in_ray:
- return ray.get(self.db_wrapper.read.remote(batch_size))
+ return ray.get(self.storage.read.remote(batch_size))
else:
- return self.db_wrapper.read(batch_size)
+ return self.storage.read(batch_size)
async def read_async(self, batch_size: Optional[int] = None) -> List:
if self.wrap_in_ray:
- return await self.db_wrapper.read.remote(batch_size)
+ try:
+ return ray.get(self.storage.read.remote(batch_size))
+ except StopIteration:
+ raise StopAsyncIteration
else:
- return self.db_wrapper.read(batch_size)
+ return self.storage.read(batch_size)
diff --git a/trinity/buffer/schema/__init__.py b/trinity/buffer/schema/__init__.py
index 269c524bdf..5fdf581dbc 100644
--- a/trinity/buffer/schema/__init__.py
+++ b/trinity/buffer/schema/__init__.py
@@ -1,3 +1,4 @@
-from .sql_schema import Base, create_dynamic_table
+from trinity.buffer.schema.formatter import FORMATTER
+from trinity.buffer.schema.sql_schema import init_engine
-__all__ = ["create_dynamic_table", "Base"]
+__all__ = ["init_engine", "FORMATTER"]
diff --git a/trinity/buffer/schema/formatter.py b/trinity/buffer/schema/formatter.py
index 7285909e07..4304da82ed 100644
--- a/trinity/buffer/schema/formatter.py
+++ b/trinity/buffer/schema/formatter.py
@@ -1,39 +1,79 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Optional
-from trinity.common.config import FormatConfig
+from trinity.common.config import FormatConfig, StorageConfig
+from trinity.common.constants import PromptType
from trinity.common.experience import Experience
+from trinity.common.rewards import REWARD_FUNCTIONS
+from trinity.common.workflows import WORKFLOWS, Task
+from trinity.utils.registry import Registry
+FORMATTER = Registry("formatter")
-class Formatter(ABC):
- @abstractmethod
- def __init__(self, format_config: FormatConfig):
- """Initialize the formatter with the given configuration."""
+class ExperienceFormatter(ABC):
@abstractmethod
def format(self, sample: Dict) -> Experience:
"""Format a raw sample dict into an experience."""
-def format_messages(
- tokenizer, messages: List[Dict], tools: Optional[List[Dict]] = None
-) -> Experience:
- tokens = tokenizer.apply_chat_template(
- messages, tools=tools, add_generation_prompt=False, return_tensors="pt"
- )[0]
- prompt_tokens_ids = tokenizer.apply_chat_template(
- messages[:-1], tools=tools, add_generation_prompt=True, return_tensors="pt"
- )[0]
- return Experience(
- tokens=tokens,
- prompt_length=len(prompt_tokens_ids),
- )
+@FORMATTER.register_module("task")
+class TaskFormatter:
+ """Formatter for task data.
+ Example Input:
-class SFTMessagesFormatter(Formatter):
- """Formatter for SFT message list data.
+ .. code-block:: python
- Example Input:
+ {
+ "input": "Hello",
+ "output": "Hi"
+ }
+ """
+
+ def __init__(self, config: StorageConfig):
+ self.config = config
+ self.is_eval = config.is_eval
+ self.default_workflow_cls = WORKFLOWS.get(config.default_workflow_type) # type: ignore
+ if self.is_eval and config.default_eval_workflow_type:
+ self.default_workflow_cls = WORKFLOWS.get(config.default_eval_workflow_type)
+ self.default_reward_fn_cls = REWARD_FUNCTIONS.get(config.default_reward_fn_type) # type: ignore
+ self.workflow_key = config.format.workflow_key
+ self.reward_fn_key = config.format.reward_fn_key
+
+ def format(self, sample: Dict) -> Task:
+ """Format a raw sample dict into a Task."""
+
+ workflow_name = sample.get(self.workflow_key, None) if self.workflow_key else None
+ reward_fn_name = sample.get(self.reward_fn_key, None) if self.reward_fn_key else None
+
+ workflow_cls = (
+ WORKFLOWS.get(workflow_name) if workflow_name else None
+ ) or self.default_workflow_cls
+ reward_fn_cls = (
+ REWARD_FUNCTIONS.get(reward_fn_name) if reward_fn_name else None
+ ) or self.default_reward_fn_cls
+ assert workflow_cls is not None, "`default_workflow_type` or `workflow_key` is required"
+ return Task(
+ workflow=workflow_cls,
+ reward_fn=reward_fn_cls,
+ format_args=self.config.format,
+ repeat_times=self.config.repeat_times,
+ rollout_args=self.config.rollout_args,
+ workflow_args=self.config.workflow_args,
+ reward_fn_args=self.config.reward_fn_args,
+ is_eval=self.is_eval,
+ raw_task=sample,
+ )
+
+
+@FORMATTER.register_module("sft")
+class SFTFormatter(ExperienceFormatter):
+ """Formatter for SFT data, supporting both message list and plaintext formats.
+
+ Uses format_config.prompt_type to distinguish between 'messages' and 'plaintext'.
+
+ Example input of MESSAGES:
.. code-block:: python
@@ -43,25 +83,9 @@ class SFTMessagesFormatter(Formatter):
{"role": "assistant", "content": "I'm fine, thank you!"}
]
}
- """
-
- def __init__(self, tokenizer, format_config: FormatConfig):
- self.tokenizer = tokenizer
- self.messages_key = format_config.messages_key
- self.tools_key = format_config.tools_key
-
- def format(self, sample: Dict) -> Experience:
- """Format a raw sample dict into an experience."""
- messages = sample[self.messages_key]
- tools = sample.get(self.tools_key, None)
- return format_messages(self.tokenizer, messages, tools)
-
-class SFTPlaintextFormatter(Formatter):
- """Formatter for SFT plaintext data.
-
- Example Input:
+ Example input of PLAINTEXT:
.. code-block:: python
@@ -74,54 +98,60 @@ class SFTPlaintextFormatter(Formatter):
def __init__(self, tokenizer, format_config: FormatConfig):
self.tokenizer = tokenizer
- self.prompt_key = format_config.prompt_key
- self.response_key = format_config.response_key
- self.system_prompt_key = format_config.system_prompt_key
- self.system_prompt = format_config.system_prompt
- self.tools_key = format_config.tools_key
+ self.prompt_type = format_config.prompt_type
+ # For messages type
+ if self.prompt_type == PromptType.MESSAGES:
+ self.messages_key = format_config.messages_key
+ self.tools_key = format_config.tools_key
+ # For plaintext type
+ elif self.prompt_type == PromptType.PLAINTEXT:
+ self.prompt_key = format_config.prompt_key
+ self.response_key = format_config.response_key
+ self.system_prompt_key = format_config.system_prompt_key
+ self.system_prompt = format_config.system_prompt
+ self.tools_key = format_config.tools_key
+ else:
+ raise ValueError(f"Unsupported prompt_type: {self.prompt_type}")
+
+ def _messages_to_experience(
+ self, messages: List[Dict], tools: Optional[List[Dict]] = None
+ ) -> Experience:
+ tokens = self.tokenizer.apply_chat_template(
+ messages, tools=tools, add_generation_prompt=False, return_tensors="pt"
+ )[0]
+ prompt_tokens_ids = self.tokenizer.apply_chat_template(
+ messages[:-1], tools=tools, add_generation_prompt=True, return_tensors="pt"
+ )[0]
+ return Experience(
+ tokens=tokens,
+ prompt_length=len(prompt_tokens_ids),
+ messages=messages,
+ )
def format(self, sample: Dict) -> Experience:
- """Format a raw sample dict into an experience."""
- messages = []
- if self.system_prompt_key is not None:
- system_message = {"role": "system", "content": sample[self.system_prompt_key]}
- messages.append(system_message)
- elif self.system_prompt is not None:
- system_message = {"role": "system", "content": self.system_prompt}
- messages.append(system_message)
- messages.append({"role": "user", "content": sample[self.prompt_key]})
- messages.append({"role": "assistant", "content": sample[self.response_key]})
-
- return format_messages(self.tokenizer, messages, sample.get(self.tools_key, None))
-
-
-def format_dpo_messages(
- tokenizer,
- prompt_messages: List[Dict],
- chosen_messages: List[Dict],
- reject_messages: List[Dict],
-):
- prompt_tokens = tokenizer.apply_chat_template(
- prompt_messages, add_generation_prompt=True, return_tensors="pt"
- )[0]
- chosen_tokens = tokenizer.apply_chat_template(
- prompt_messages + chosen_messages, add_generation_prompt=False, return_tensors="pt"
- )[0][len(prompt_tokens) :]
- reject_tokens = tokenizer.apply_chat_template(
- prompt_messages + reject_messages, add_generation_prompt=False, return_tensors="pt"
- )[0][len(prompt_tokens) :]
- return Experience(
- tokens=prompt_tokens,
- prompt_length=len(prompt_tokens),
- chosen=chosen_tokens,
- rejected=reject_tokens,
- )
-
-
-class DPOPlaintextFormatter(Formatter):
+ if self.prompt_type == PromptType.MESSAGES:
+ messages = sample[self.messages_key]
+ elif self.prompt_type == PromptType.PLAINTEXT:
+ messages = []
+ if self.system_prompt_key is not None:
+ system_message = {"role": "system", "content": sample[self.system_prompt_key]}
+ messages.append(system_message)
+ elif self.system_prompt is not None:
+ system_message = {"role": "system", "content": self.system_prompt}
+ messages.append(system_message)
+ messages.append({"role": "user", "content": sample[self.prompt_key]})
+ messages.append({"role": "assistant", "content": sample[self.response_key]})
+ else:
+ raise ValueError(f"Unsupported prompt_type: {self.prompt_type}")
+ tools = sample.get(self.tools_key, None)
+ return self._messages_to_experience(messages, tools)
+
+
+@FORMATTER.register_module("dpo")
+class DPOFormatter(ExperienceFormatter):
"""Formatter for DPO plaintext data.
- Example Input:
+ Example Input for PLAINTEXT:
.. code-block:: python
@@ -130,38 +160,8 @@ class DPOPlaintextFormatter(Formatter):
"chosen": "My name is Assistant.",
"rejected": "I don't have a name."
}
- """
- def __init__(self, tokenizer, format_config: FormatConfig):
- self.tokenizer = tokenizer
- self.prompt_key = format_config.prompt_key
- self.chosen_key = format_config.chosen_key
- self.rejected_key = format_config.rejected_key
- self.system_prompt_key = format_config.system_prompt_key
- self.system_prompt = format_config.system_prompt
- # currently DPO not support tools
-
- def format(self, sample: Dict) -> Experience:
- messages = []
- if self.system_prompt_key is not None:
- system_message = {"role": "system", "content": sample[self.system_prompt_key]}
- messages.append(system_message)
- elif self.system_prompt is not None:
- system_message = {"role": "system", "content": self.system_prompt}
- messages.append(system_message)
- messages.append({"role": "user", "content": sample[self.prompt_key]})
- return format_dpo_messages(
- tokenizer=self.tokenizer,
- prompt_messages=messages,
- chosen_messages=[{"role": "assistant", "content": sample[self.chosen_key]}],
- reject_messages=[{"role": "assistant", "content": sample[self.rejected_key]}],
- )
-
-
-class DPOMessagesFormatter(Formatter):
- """Formatter for DPO message list data.
-
- Example Input:
+ Example Input for MESSAGES:
.. code-block:: python
@@ -180,12 +180,62 @@ class DPOMessagesFormatter(Formatter):
def __init__(self, tokenizer, format_config: FormatConfig):
self.tokenizer = tokenizer
- self.messages_key = format_config.messages_key
- self.chosen_key = format_config.chosen_key
- self.rejected_key = format_config.rejected_key
+ self.prompt_type = format_config.prompt_type
+ if self.prompt_type == PromptType.PLAINTEXT:
+ self.prompt_key = format_config.prompt_key
+ self.chosen_key = format_config.chosen_key
+ self.rejected_key = format_config.rejected_key
+ self.system_prompt_key = format_config.system_prompt_key
+ self.system_prompt = format_config.system_prompt
+ elif self.prompt_type == PromptType.MESSAGES:
+ self.messages_key = format_config.messages_key
+ self.chosen_key = format_config.chosen_key
+ self.rejected_key = format_config.rejected_key
+ else:
+ raise ValueError(f"Unsupported prompt_type: {self.prompt_type}")
+ # currently DPO not support tools
+
+ def _messages_to_experience(
+ self, prompt_messages, chosen_messages, rejected_messages
+ ) -> Experience:
+ prompt_tokens = self.tokenizer.apply_chat_template(
+ prompt_messages, add_generation_prompt=True, return_tensors="pt"
+ )[0]
+ chosen_tokens = self.tokenizer.apply_chat_template(
+ prompt_messages + chosen_messages, add_generation_prompt=False, return_tensors="pt"
+ )[0][len(prompt_tokens) :]
+ rejected_tokens = self.tokenizer.apply_chat_template(
+ prompt_messages + rejected_messages, add_generation_prompt=False, return_tensors="pt"
+ )[0][len(prompt_tokens) :]
+ return Experience(
+ tokens=prompt_tokens,
+ prompt_length=len(prompt_tokens),
+ chosen=chosen_tokens,
+ rejected=rejected_tokens,
+ chosen_messages=prompt_messages + chosen_messages,
+ rejected_messages=prompt_messages + rejected_messages,
+ )
def format(self, sample: Dict) -> Experience:
- messages = sample[self.messages_key]
- chosen = sample[self.chosen_key]
- rejected = sample[self.rejected_key]
- return format_dpo_messages(self.tokenizer, messages, chosen, rejected)
+ if self.prompt_type == PromptType.PLAINTEXT:
+ messages = []
+ if self.system_prompt_key is not None:
+ system_message = {"role": "system", "content": sample[self.system_prompt_key]}
+ messages.append(system_message)
+ elif self.system_prompt is not None:
+ system_message = {"role": "system", "content": self.system_prompt}
+ messages.append(system_message)
+ messages.append({"role": "user", "content": sample[self.prompt_key]})
+ chosen = [{"role": "assistant", "content": sample[self.chosen_key]}]
+ rejected = [{"role": "assistant", "content": sample[self.rejected_key]}]
+ elif self.prompt_type == PromptType.MESSAGES:
+ messages = sample[self.messages_key]
+ chosen = sample[self.chosen_key]
+ rejected = sample[self.rejected_key]
+ else:
+ raise ValueError(f"Unsupported prompt_type: {self.prompt_type}")
+ return self._messages_to_experience(
+ prompt_messages=messages,
+ chosen_messages=chosen,
+ rejected_messages=rejected,
+ )
diff --git a/trinity/buffer/schema/sql_schema.py b/trinity/buffer/schema/sql_schema.py
index 61149b8b25..74a9135b7f 100644
--- a/trinity/buffer/schema/sql_schema.py
+++ b/trinity/buffer/schema/sql_schema.py
@@ -1,142 +1,135 @@
-"""Schema for SQLAlchemy models."""
+"""SQLAlchemy models for different data."""
-from typing import Any, Optional, Union
+from typing import Dict, Optional, Tuple
-from sqlalchemy import Column, Float, Integer, LargeBinary, String
-from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy import JSON, Column, Float, Integer, LargeBinary, Text, create_engine
+from sqlalchemy.exc import OperationalError
+from sqlalchemy.orm import declarative_base
+from sqlalchemy.pool import NullPool
from trinity.common.experience import Experience
+from trinity.utils.log import get_logger
+from trinity.utils.registry import Registry
+
+SQL_SCHEMA = Registry("sql_schema")
Base = declarative_base()
+@SQL_SCHEMA.register_module("task")
class TaskModel(Base): # type: ignore
"""Model for storing tasks in SQLAlchemy."""
__abstract__ = True
- __table_args__ = {
- "keep_existing": True,
- }
-
id = Column(Integer, primary_key=True, autoincrement=True)
- task_desc = Column(String, nullable=True)
- workflow_type = Column(String, nullable=True)
- reward_type = Column(String, nullable=True)
+ raw_task = Column(JSON, nullable=False)
+
+ @classmethod
+ def from_dict(cls, dict: Dict):
+ return cls(raw_task=dict)
+@SQL_SCHEMA.register_module("experience")
class ExperienceModel(Base): # type: ignore
"""SQLAlchemy model for Experience."""
__abstract__ = True
- __table_args__ = {
- "keep_existing": True,
- }
-
id = Column(Integer, primary_key=True, autoincrement=True)
- serialized_exp = Column(LargeBinary, nullable=True)
- prompt = Column(String, nullable=True)
- response = Column(String, nullable=True)
+ # for single turn
+ prompt = Column(Text, nullable=True)
+ response = Column(Text, nullable=True)
+ # for multi turn
+ message_list = Column(JSON, nullable=True)
reward = Column(Float, nullable=True)
- consumed = Column(Integer, default=0)
- priority = Column(Float, default=0.0)
+ # serialized experience object
+ experience_bytes = Column(LargeBinary, nullable=True)
def to_experience(self) -> Experience:
"""Load the experience from the database."""
- return Experience.deserialize(self.serialized_exp)
+ return Experience.deserialize(self.experience_bytes)
@classmethod
def from_experience(cls, experience: Experience):
"""Save the experience to database."""
return cls(
- serialized_exp=experience.serialize(),
+ experience_bytes=experience.serialize(),
reward=experience.reward,
prompt=experience.prompt_text,
response=experience.response_text,
+ message_list=experience.messages,
)
+@SQL_SCHEMA.register_module("sft")
class SFTDataModel(Base): # type: ignore
"""SQLAlchemy model for SFT data."""
__abstract__ = True
- __table_args__ = {
- "keep_existing": True,
- }
-
id = Column(Integer, primary_key=True, autoincrement=True)
- serialized_exp = Column(LargeBinary, nullable=True)
- messages = Column(String, nullable=True)
- consumed = Column(Integer, default=0)
+ message_list = Column(JSON, nullable=True)
+ experience_bytes = Column(LargeBinary, nullable=True)
def to_experience(self) -> Experience:
"""Load the experience from the database."""
- return Experience.deserialize(self.serialized_exp)
+ return Experience.deserialize(self.experience_bytes)
@classmethod
- def from_messages(
- cls,
- messages: list[dict],
- tokenizer: Any,
- chat_template: Optional[str] = None,
- ) -> "SFTDataModel":
- """Convert a list of messages into a single instance of SFT data."""
- from trinity.common.models.utils import tokenize_and_mask_messages_hf
-
- tokens, action_mask, prompt_length = tokenize_and_mask_messages_hf(
- tokenizer=tokenizer,
- messages=messages,
- chat_template=chat_template,
- )
- exp = Experience(
- tokens=tokens,
- action_mask=action_mask,
- prompt_length=prompt_length,
- info={"response_num": sum([1 if m["role"] == "assistant" else 0 for m in messages])},
- )
+ def from_experience(cls, experience: Experience):
+ """Save the experience to database."""
return cls(
- serialized_exp=exp.serialize(),
- messages=messages,
+ experience_bytes=experience.serialize(),
+ message_list=experience.messages,
)
+@SQL_SCHEMA.register_module("dpo")
class DPODataModel(Base): # type: ignore
"""SQLAlchemy model for DPO data."""
__abstract__ = True
- __table_args__ = {
- "keep_existing": True,
- }
-
id = Column(Integer, primary_key=True, autoincrement=True)
- serialized_exp = Column(LargeBinary, nullable=True)
- chosen = Column(LargeBinary, nullable=True)
- rejected = Column(LargeBinary, nullable=True)
- consumed = Column(Integer, default=0)
+ chosen_message_list = Column(JSON, nullable=True)
+ rejected_message_list = Column(JSON, nullable=True)
+ experience_bytes = Column(LargeBinary, nullable=True)
def to_experience(self) -> Experience:
"""Load the experience from the database."""
- exp = Experience.deserialize(self.serialized_exp)
- exp.chosen = Experience.deserialize(self.chosen)
- exp.rejected = Experience.deserialize(self.rejected)
- return exp
+ return Experience.deserialize(self.experience_bytes)
+
+ @classmethod
+ def from_experience(cls, experience: Experience):
+ """Save the experience to database."""
+ return cls(
+ experience_bytes=experience.serialize(),
+ chosen_message_list=experience.chosen_messages,
+ rejected_message_list=experience.rejected_messages,
+ )
-def create_dynamic_table(algorithm_type: Union[str | None], table_name: str) -> Any:
- """Create a dynamic table based on the provided algorithm type and table name."""
- if algorithm_type is None:
- base_class = TaskModel
- else:
- from trinity.algorithm.algorithm import ALGORITHM_TYPE
+def init_engine(db_url: str, table_name, schema_type: Optional[str]) -> Tuple:
+ """Get the sqlalchemy engine."""
+ logger = get_logger(__name__)
+ engine = create_engine(db_url, poolclass=NullPool)
- algorithm = ALGORITHM_TYPE.get(algorithm_type)
- base_class = algorithm.schema
+ if schema_type is None:
+ schema_type = "task"
+
+ base_class = SQL_SCHEMA.get(schema_type)
table_attrs = {
"__tablename__": table_name,
+ "__abstract__": False,
+ "__table_args__": {"keep_existing": True},
}
+ table_cls = type(table_name, (base_class,), table_attrs)
+
+ try:
+ Base.metadata.create_all(engine, checkfirst=True)
+ except OperationalError:
+ logger.warning(f"Failed to create table {table_name}, assuming it already exists.")
- return type(table_name, (base_class,), table_attrs)
+ return engine, table_cls
diff --git a/trinity/buffer/storage/__init__.py b/trinity/buffer/storage/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/trinity/buffer/storage/file.py b/trinity/buffer/storage/file.py
new file mode 100644
index 0000000000..6f393cc48f
--- /dev/null
+++ b/trinity/buffer/storage/file.py
@@ -0,0 +1,84 @@
+"""File Storage"""
+import json
+import os
+from typing import List
+
+import ray
+
+from trinity.buffer.utils import default_storage_path
+from trinity.common.config import BufferConfig, StorageConfig
+from trinity.common.experience import EID, Experience
+from trinity.common.workflows import Task
+
+
+class _Encoder(json.JSONEncoder):
+ def default(self, o):
+ if isinstance(o, Experience):
+ return o.to_dict()
+ if isinstance(o, Task):
+ return o.to_dict()
+ if isinstance(o, EID):
+ return o.to_dict()
+ return super().default(o)
+
+
+class FileStorage:
+ """
+ A wrapper of a local jsonl file.
+
+ If `wrap_in_ray` in `StorageConfig` is `True`, this class will be run as
+ a Ray Actor, and provide a remote interface to the local file.
+
+    This wrapper is only for writing; if you want to read from the file, use
+    StorageType.QUEUE instead.
+ """
+
+ def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None:
+ if storage_config.path is None:
+ storage_config.path = default_storage_path(storage_config, config)
+ ext = os.path.splitext(storage_config.path)[-1]
+ if ext != ".jsonl" and ext != ".json":
+ raise ValueError(
+ f"File path must end with '.json' or '.jsonl', got {storage_config.path}"
+ )
+ path_dir = os.path.dirname(os.path.abspath(storage_config.path))
+ os.makedirs(path_dir, exist_ok=True)
+ self.file = open(storage_config.path, "a", encoding="utf-8")
+ self.encoder = _Encoder(ensure_ascii=False)
+ self.ref_count = 0
+
+ @classmethod
+ def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig):
+ if storage_config.wrap_in_ray:
+ return (
+ ray.remote(cls)
+ .options(
+ name=f"json-{storage_config.name}",
+ namespace=storage_config.ray_namespace or ray.get_runtime_context().namespace,
+ get_if_exists=True,
+ )
+ .remote(storage_config, config)
+ )
+ else:
+ return cls(storage_config, config)
+
+ def write(self, data: List) -> None:
+ for item in data:
+ json_str = self.encoder.encode(item)
+ self.file.write(json_str + "\n")
+ self.file.flush()
+
+ def read(self) -> List:
+ raise NotImplementedError(
+ "read() is not implemented for FILE Storage, please use QUEUE instead"
+ )
+
+ def acquire(self) -> int:
+ self.ref_count += 1
+ return self.ref_count
+
+ def release(self) -> int:
+ self.ref_count -= 1
+ if self.ref_count <= 0:
+ self.file.close()
+ return self.ref_count
diff --git a/trinity/buffer/queue.py b/trinity/buffer/storage/queue.py
similarity index 63%
rename from trinity/buffer/queue.py
rename to trinity/buffer/storage/queue.py
index a43c4e82ca..71cebd9d27 100644
--- a/trinity/buffer/queue.py
+++ b/trinity/buffer/storage/queue.py
@@ -1,17 +1,30 @@
-"""Implementation of async queue buffers."""
+"""Ray Queue storage"""
import asyncio
+import time
from abc import ABC, abstractmethod
from collections import deque
+from copy import deepcopy
from functools import partial
from typing import List, Optional
+import ray
from sortedcontainers import SortedDict
from trinity.common.config import BufferConfig, StorageConfig
+from trinity.common.constants import StorageType
from trinity.common.experience import Experience
from trinity.utils.log import get_logger
from trinity.utils.registry import Registry
+
+def is_database_url(path: str) -> bool:
+ return any(path.startswith(prefix) for prefix in ["sqlite:///", "postgresql://", "mysql://"])
+
+
+def is_json_file(path: str) -> bool:
+ return path.endswith(".json") or path.endswith(".jsonl")
+
+
PRIORITY_FUNC = Registry("priority_fn")
@@ -194,3 +207,95 @@ async def close(self) -> None:
def stopped(self) -> bool:
return self._closed and len(self.priority_groups) == 0
+
+
+class QueueStorage:
+    """A wrapper of an async queue."""
+
+ def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None:
+ self.logger = get_logger(f"queue_{storage_config.name}", in_ray_actor=True)
+ self.config = config
+ self.capacity = storage_config.capacity
+ self.queue = QueueBuffer.get_queue(storage_config, config)
+ st_config = deepcopy(storage_config)
+ st_config.wrap_in_ray = False
+ if st_config.path is not None:
+ if is_database_url(st_config.path):
+ from trinity.buffer.writer.sql_writer import SQLWriter
+
+ st_config.storage_type = StorageType.SQL
+ self.writer = SQLWriter(st_config, self.config)
+ elif is_json_file(st_config.path):
+ from trinity.buffer.writer.file_writer import JSONWriter
+
+ st_config.storage_type = StorageType.FILE
+ self.writer = JSONWriter(st_config, self.config)
+ else:
+ self.logger.warning("Unknown supported storage path: %s", st_config.path)
+ self.writer = None
+ else:
+ from trinity.buffer.writer.file_writer import JSONWriter
+
+ st_config.storage_type = StorageType.FILE
+ self.writer = JSONWriter(st_config, self.config)
+ self.logger.warning(f"Save experiences in {st_config.path}.")
+ self.ref_count = 0
+ self.exp_pool = deque() # A pool to store experiences
+ self.closed = False
+
+ async def acquire(self) -> int:
+ self.ref_count += 1
+ return self.ref_count
+
+ async def release(self) -> int:
+ """Release the queue."""
+ self.ref_count -= 1
+ if self.ref_count <= 0:
+ await self.queue.close()
+ await self.writer.release()
+ return self.ref_count
+
+ def length(self) -> int:
+ """The length of the queue."""
+ return self.queue.qsize()
+
+ async def put_batch(self, exp_list: List) -> None:
+ """Put batch of experience."""
+ await self.queue.put(exp_list)
+ if self.writer is not None:
+ self.writer.write(exp_list)
+
+ async def get_batch(self, batch_size: int, timeout: float) -> List:
+ """Get batch of experience."""
+ start_time = time.time()
+ while len(self.exp_pool) < batch_size:
+ if self.queue.stopped():
+ # If the queue is stopped, ignore the rest of the experiences in the pool
+ raise StopAsyncIteration("Queue is closed and no more items to get.")
+ try:
+ exp_list = await asyncio.wait_for(self.queue.get(), timeout=1.0)
+ self.exp_pool.extend(exp_list)
+ except asyncio.TimeoutError:
+ if time.time() - start_time > timeout:
+ self.logger.error(
+ f"Timeout when waiting for experience, only get {len(self.exp_pool)} experiences.\n"
+ "This phenomenon is usually caused by the workflow not returning enough "
+ "experiences or running timeout. Please check your workflow implementation."
+ )
+ batch = list(self.exp_pool)
+ self.exp_pool.clear()
+ return batch
+ return [self.exp_pool.popleft() for _ in range(batch_size)]
+
+ @classmethod
+ def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig):
+ """Get the queue actor."""
+ return (
+ ray.remote(cls)
+ .options(
+ name=f"queue-{storage_config.name}",
+ namespace=storage_config.ray_namespace or ray.get_runtime_context().namespace,
+ get_if_exists=True,
+ )
+ .remote(storage_config, config)
+ )
diff --git a/trinity/buffer/storage/sql.py b/trinity/buffer/storage/sql.py
new file mode 100644
index 0000000000..cbbdce4c53
--- /dev/null
+++ b/trinity/buffer/storage/sql.py
@@ -0,0 +1,223 @@
+"""SQL database storage"""
+
+import time
+from abc import abstractmethod
+from typing import Dict, List, Optional
+
+import ray
+from datasets import Dataset
+from sqlalchemy import asc
+from sqlalchemy.orm import sessionmaker
+
+from trinity.buffer.schema import init_engine
+from trinity.buffer.schema.formatter import FORMATTER, TaskFormatter
+from trinity.buffer.utils import default_storage_path, retry_session
+from trinity.common.config import BufferConfig, StorageConfig
+from trinity.common.experience import Experience
+from trinity.common.rewards import REWARD_FUNCTIONS
+from trinity.common.workflows import WORKFLOWS, Task
+from trinity.utils.log import get_logger
+
+
+class SQLStorage:
+ """
+    A Storage based on SQL Database.
+
+ If `wrap_in_ray` in `StorageConfig` is `True`, this class will be run as a Ray Actor,
+ and provide a remote interface to the local database.
+
+ For databases that do not support multi-processing read/write (e.g. sqlite, duckdb), please
+ set `wrap_in_ray` to `True`.
+ """
+
+ def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None:
+ self.logger = get_logger(f"sql_{storage_config.name}", in_ray_actor=True)
+ if storage_config.path is None:
+ storage_config.path = default_storage_path(storage_config, config)
+ self.engine, self.table_model_cls = init_engine(
+ db_url=storage_config.path,
+ table_name=storage_config.name,
+ schema_type=storage_config.schema_type,
+ )
+ self.logger.info(f"Init SQL storage at {storage_config.path}")
+ self.session = sessionmaker(bind=self.engine)
+ self.max_retry_times = storage_config.max_retry_times
+ self.max_retry_interval = storage_config.max_retry_interval
+ self.ref_count = 0
+ self.stopped = False
+ # Assume that the auto-increment ID starts counting from 1, so the default offset should be 0.
+ self.offset = storage_config.index
+
+ @classmethod
+ def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig):
+ if storage_config.schema_type is None:
+ storage_cls = SQLTaskStorage
+ else:
+ storage_cls = SQLExperienceStorage
+ if storage_config.wrap_in_ray:
+ return (
+ ray.remote(storage_cls)
+ .options(
+ name=f"sql-{storage_config.name}",
+ namespace=storage_config.ray_namespace or ray.get_runtime_context().namespace,
+ get_if_exists=True,
+ max_concurrency=5,
+ )
+ .remote(storage_config, config)
+ )
+ else:
+ return storage_cls(storage_config, config)
+
+ @abstractmethod
+ def write(self, data: List) -> None:
+ """Write a batch of data."""
+
+ @abstractmethod
+ def read(self, batch_size: Optional[int] = None) -> List:
+ """Read a batch of data."""
+
+ def acquire(self) -> int:
+ self.ref_count += 1
+ return self.ref_count
+
+ def release(self) -> int:
+ self.ref_count -= 1
+ if self.ref_count <= 0:
+ self.stopped = True
+ return self.ref_count
+
+
+class SQLExperienceStorage(SQLStorage):
+ def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None:
+ super().__init__(storage_config, config)
+ self.batch_size = config.train_batch_size
+ self.max_timeout = storage_config.max_read_timeout
+
+ def write(self, data: List[Experience]) -> None:
+ with retry_session(self.session, self.max_retry_times, self.max_retry_interval) as session:
+ experience_models = [self.table_model_cls.from_experience(exp) for exp in data]
+ session.add_all(experience_models)
+
+ def read(self, batch_size: Optional[int] = None) -> List[Experience]:
+ if self.stopped:
+ raise StopIteration()
+
+ exp_list = []
+ batch_size = batch_size or self.batch_size # type: ignore
+ start_time = time.time()
+ while len(exp_list) < batch_size:
+ if self.stopped:
+ raise StopIteration()
+ if len(exp_list):
+ self.logger.info(f"Waiting for {batch_size - len(exp_list)} more experiences...")
+ time.sleep(1)
+ if time.time() - start_time > self.max_timeout:
+ self.logger.warning(
+ f"Max read timeout reached ({self.max_timeout} s), only get {len(exp_list)} experiences, stopping..."
+ )
+ raise StopIteration()
+ with retry_session(
+ self.session, self.max_retry_times, self.max_retry_interval
+ ) as session:
+ # get a batch of experiences from the database
+ experiences = (
+ session.query(self.table_model_cls)
+ .filter(self.table_model_cls.id > self.offset)
+ .order_by(asc(self.table_model_cls.id))
+ .limit(batch_size - len(exp_list))
+ .all()
+ )
+ if experiences:
+ self.offset = experiences[-1].id
+ start_time = time.time()
+ exp_list.extend([self.table_model_cls.to_experience(exp) for exp in experiences])
+ return exp_list
+
+ @classmethod
+ def load_from_dataset(
+ cls, dataset: Dataset, storage_config: StorageConfig, config: BufferConfig
+ ) -> "SQLExperienceStorage":
+ import transformers
+
+ tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path)
+ storage = cls(
+ storage_config=storage_config,
+ config=config,
+ )
+ formatter = FORMATTER.get(storage_config.schema_type)(tokenizer, storage_config.format)
+ batch_size = storage.batch_size
+ batch = []
+ for item in dataset:
+ batch.append(formatter.format(item))
+ if len(batch) >= batch_size:
+ storage.write(batch)
+ batch.clear()
+ if batch:
+ storage.write(batch)
+ return storage
+
+
+class SQLTaskStorage(SQLStorage):
+ def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None:
+ super().__init__(storage_config, config)
+ self.batch_size = config.batch_size
+ self.is_eval = storage_config.is_eval
+ self.default_workflow_cls = WORKFLOWS.get(storage_config.default_workflow_type) # type: ignore
+ if self.is_eval and storage_config.default_eval_workflow_type:
+ self.default_workflow_cls = WORKFLOWS.get(storage_config.default_eval_workflow_type)
+ self.default_reward_fn_cls = REWARD_FUNCTIONS.get(storage_config.default_reward_fn_type) # type: ignore
+ self.formatter = TaskFormatter(storage_config)
+ self.offset = storage_config.index
+ if storage_config.total_steps:
+ self.total_samples = self.batch_size * storage_config.total_steps
+ else:
+ if storage_config.total_epochs > 1:
+ self.logger.warning(
+ f"SQL Storage do not support total_epochs, the value {storage_config.total_epochs} will be ignored"
+ )
+ self.total_samples = float("inf")
+
+ def write(self, data: List[Dict]) -> None:
+ with retry_session(self.session, self.max_retry_times, self.max_retry_interval) as session:
+ tasks = [self.table_model_cls.from_dict(item) for item in data]
+ session.add_all(tasks)
+
+ def read(self, batch_size: Optional[int] = None) -> List[Task]:
+ if self.stopped:
+ raise StopIteration()
+ if self.offset > self.total_samples:
+ raise StopIteration()
+ batch_size = batch_size or self.batch_size
+ with retry_session(self.session, self.max_retry_times, self.max_retry_interval) as session:
+ query = (
+ session.query(self.table_model_cls)
+ .filter(self.table_model_cls.id > self.offset)
+ .order_by(asc(self.table_model_cls.id))
+ .limit(batch_size)
+ )
+ results = query.all()
+ if len(results) == 0:
+ raise StopIteration()
+ if not self.is_eval and len(results) < batch_size:
+ raise StopIteration()
+ self.offset = results[-1].id
+ return [self.formatter.format(item.raw_task) for item in results]
+
+ @classmethod
+ def load_from_dataset(
+ cls, dataset: Dataset, storage_config: StorageConfig, config: BufferConfig
+ ) -> "SQLTaskStorage":
+ storage = cls(
+ storage_config=storage_config,
+ config=config,
+ )
+ batch_size = config.batch_size
+ batch = []
+ for item in dataset:
+ batch.append(item)
+ if len(batch) >= batch_size:
+ storage.write(batch)
+ batch.clear()
+ if batch:
+ storage.write(batch)
+ return storage
diff --git a/trinity/buffer/utils.py b/trinity/buffer/utils.py
index aa2c0e4849..7db8dde867 100644
--- a/trinity/buffer/utils.py
+++ b/trinity/buffer/utils.py
@@ -6,18 +6,19 @@
from trinity.common.constants import StorageType
from trinity.utils.log import get_logger
-logger = get_logger(__name__)
-
@contextmanager
def retry_session(session_maker, max_retry_times: int, max_retry_interval: float):
"""A Context manager for retrying session."""
+ logger = get_logger(__name__)
for attempt in range(max_retry_times):
try:
session = session_maker()
yield session
session.commit()
break
+ except StopIteration as e:
+ raise e
except Exception as e:
import traceback
diff --git a/trinity/buffer/writer/file_writer.py b/trinity/buffer/writer/file_writer.py
index 93c10479ca..5a579bf59c 100644
--- a/trinity/buffer/writer/file_writer.py
+++ b/trinity/buffer/writer/file_writer.py
@@ -3,7 +3,7 @@
import ray
from trinity.buffer.buffer_writer import BufferWriter
-from trinity.buffer.ray_wrapper import FileWrapper
+from trinity.buffer.storage.file import FileStorage
from trinity.common.config import BufferConfig, StorageConfig
from trinity.common.constants import StorageType
@@ -11,7 +11,7 @@
class JSONWriter(BufferWriter):
def __init__(self, meta: StorageConfig, config: BufferConfig):
assert meta.storage_type == StorageType.FILE
- self.writer = FileWrapper.get_wrapper(meta, config)
+ self.writer = FileStorage.get_wrapper(meta, config)
self.wrap_in_ray = meta.wrap_in_ray
def write(self, data: List) -> None:
diff --git a/trinity/buffer/writer/queue_writer.py b/trinity/buffer/writer/queue_writer.py
index 27dc6df28a..7e4f4a9ca1 100644
--- a/trinity/buffer/writer/queue_writer.py
+++ b/trinity/buffer/writer/queue_writer.py
@@ -4,7 +4,7 @@
import ray
from trinity.buffer.buffer_writer import BufferWriter
-from trinity.buffer.ray_wrapper import QueueWrapper
+from trinity.buffer.storage.queue import QueueStorage
from trinity.common.config import BufferConfig, StorageConfig
from trinity.common.constants import StorageType
@@ -15,7 +15,7 @@ class QueueWriter(BufferWriter):
def __init__(self, meta: StorageConfig, config: BufferConfig):
assert meta.storage_type == StorageType.QUEUE
self.config = config
- self.queue = QueueWrapper.get_wrapper(meta, config)
+ self.queue = QueueStorage.get_wrapper(meta, config)
def write(self, data: List) -> None:
ray.get(self.queue.put_batch.remote(data))
diff --git a/trinity/buffer/writer/sql_writer.py b/trinity/buffer/writer/sql_writer.py
index a951201b80..9ffaa13adb 100644
--- a/trinity/buffer/writer/sql_writer.py
+++ b/trinity/buffer/writer/sql_writer.py
@@ -3,7 +3,7 @@
import ray
from trinity.buffer.buffer_writer import BufferWriter
-from trinity.buffer.ray_wrapper import DBWrapper
+from trinity.buffer.storage.sql import SQLStorage
from trinity.common.config import BufferConfig, StorageConfig
from trinity.common.constants import StorageType
@@ -15,7 +15,7 @@ def __init__(self, meta: StorageConfig, config: BufferConfig) -> None:
assert meta.storage_type == StorageType.SQL
# we only support write RFT algorithm buffer for now
self.wrap_in_ray = meta.wrap_in_ray
- self.db_wrapper = DBWrapper.get_wrapper(meta, config)
+ self.db_wrapper = SQLStorage.get_wrapper(meta, config)
def write(self, data: list) -> None:
if self.wrap_in_ray:
diff --git a/trinity/common/config.py b/trinity/common/config.py
index 911885fe50..1382a561dc 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -16,7 +16,6 @@
StorageType,
SyncMethod,
SyncStyle,
- TaskType,
)
from trinity.utils.annotations import Experimental
from trinity.utils.log import get_logger
@@ -28,6 +27,7 @@
class FormatConfig:
"""Configuration for data formatting"""
+ # for sft / dpo
prompt_type: PromptType = PromptType.MESSAGES
# for plaintext input
@@ -73,18 +73,12 @@ class StorageConfig:
path: Optional[str] = None
repeat_times: Optional[int] = None
- # only available for StorageType.FILE. When requiring data processing on raw data, set the raw to True.
- raw: bool = False
-
# used for StorageType.FILE
split: str = "train"
subset_name: Optional[str] = None
format: FormatConfig = field(default_factory=FormatConfig)
index: int = 0
- # used for StorageType.SQL/FILE
- wrap_in_ray: bool = True
-
# used for StorageType.QUEUE
capacity: int = 10000
max_read_timeout: float = 1800
@@ -94,6 +88,10 @@ class StorageConfig:
default_factory=lambda: {"priority_fn": "linear_decay", "decay": 0.1}
)
+ # used for StorageType.SQL
+ max_retry_times: int = 3
+ max_retry_interval: int = 1
+
# used for rollout tasks
default_workflow_type: Optional[str] = None
default_eval_workflow_type: Optional[str] = None
@@ -108,8 +106,11 @@ class StorageConfig:
# get storage from existing experiment
ray_namespace: Optional[str] = None
- # ! DO NOT SET, automatically set from algorithm.algorithm_type
- algorithm_type: Optional[str] = None
+    # ! DO NOT SET unless you know what you are doing
+ wrap_in_ray: bool = True
+
+ # ! DO NOT SET, automatically set
+ schema_type: Optional[str] = None
# ! DO NOT SET, automatically set from buffer.total_epochs
total_epochs: int = 1 # automatically set
@@ -118,7 +119,7 @@ class StorageConfig:
total_steps: Optional[int] = None # automatically set
# ! DO NOT SET, automatically set corresponding to train/eval
- task_type: TaskType = TaskType.EXPLORE
+ is_eval: bool = False
@dataclass
@@ -349,10 +350,6 @@ class BufferConfig:
# for trainer
trainer_input: TrainerInput = field(default_factory=TrainerInput)
- # for storage connection
- max_retry_times: int = 3
- max_retry_interval: int = 1
-
# ! DO NOT SET FOLLOWING FIELDS
explorer_output: Optional[StorageConfig] = None # automatically set
tokenizer_path: Optional[str] = None # automatically set
@@ -551,7 +548,7 @@ def _check_buffer(self) -> None: # noqa: C901
self.buffer.trainer_input.experience_buffer.total_epochs = self.buffer.total_epochs
self.buffer.trainer_input.experience_buffer.total_steps = self.buffer.total_steps
else:
- self.buffer.explorer_input.taskset.task_type = TaskType.EXPLORE
+ self.buffer.explorer_input.taskset.is_eval = False
self.buffer.explorer_input.taskset.total_epochs = self.buffer.total_epochs
self.buffer.explorer_input.taskset.total_steps = self.buffer.total_steps
if self.buffer.explorer_input.taskset.default_workflow_type is None:
@@ -582,7 +579,7 @@ def _check_buffer(self) -> None: # noqa: C901
if not dataset.path:
logger.warning(f"Eval dataset [{dataset}]'s path is not configured. Skip.")
continue
- dataset.task_type = TaskType.EVAL
+ dataset.is_eval = True
if not dataset.name:
dataset.name = f"eval_taskset_{idx}"
if dataset.repeat_times is None:
@@ -623,9 +620,11 @@ def _check_buffer(self) -> None: # noqa: C901
self.buffer.trainer_input.experience_buffer.storage_type = StorageType.QUEUE
if self.buffer.trainer_input.experience_buffer is not None:
- self.buffer.trainer_input.experience_buffer.algorithm_type = (
+ from trinity.algorithm.algorithm import ALGORITHM_TYPE
+
+ self.buffer.trainer_input.experience_buffer.schema_type = ALGORITHM_TYPE.get(
self.algorithm.algorithm_type
- )
+ ).schema
if self.buffer.trainer_input.experience_buffer.ray_namespace is None:
self.buffer.trainer_input.experience_buffer.ray_namespace = self.ray_namespace
@@ -648,7 +647,7 @@ def _check_buffer(self) -> None: # noqa: C901
"`buffer.trainer_input.sft_warmup_dataset` is required when `buffer.trainer_input.sft_warmup_steps` > 0"
)
if self.buffer.trainer_input.sft_warmup_dataset is not None:
- self.buffer.trainer_input.sft_warmup_dataset.algorithm_type = "sft" # TODO
+ self.buffer.trainer_input.sft_warmup_dataset.schema_type = "sft"
self.buffer.trainer_input.sft_warmup_dataset.total_steps = (
self.buffer.trainer_input.sft_warmup_steps
)
diff --git a/trinity/common/constants.py b/trinity/common/constants.py
index c5a0d76af0..531d113965 100644
--- a/trinity/common/constants.py
+++ b/trinity/common/constants.py
@@ -48,13 +48,6 @@ class PromptType(CaseInsensitiveEnum):
PLAINTEXT = "plaintext" # user prompt text and assistant response text
-class TaskType(Enum):
- """Task Type."""
-
- EXPLORE = 0
- EVAL = 1
-
-
class StorageType(CaseInsensitiveEnum):
"""Storage Type."""
@@ -68,6 +61,7 @@ class MonitorType(CaseInsensitiveEnum):
WANDB = "wandb"
TENSORBOARD = "tensorboard"
+ MLFLOW = "mlflow"
class SyncMethodEnumMeta(CaseInsensitiveEnumMeta):
diff --git a/trinity/common/experience.py b/trinity/common/experience.py
index d4d9aa6cc2..3fbcd6df72 100644
--- a/trinity/common/experience.py
+++ b/trinity/common/experience.py
@@ -98,6 +98,7 @@ class CustomField:
class Experience:
eid: EID = field(default_factory=EID) # Unique identifier for the experience
tokens: Optional[Tensor] = None # [seq_length]
+ prompt_length: int = 1 # Length of the prompt in tokens, used for generating attention masks
logprobs: Optional[Tensor] = None # [resp_length]
reward: Optional[float] = None
advantages: Optional[Tensor] = None # [resp_length]
@@ -110,12 +111,11 @@ class Experience:
) # Metrics associated with the experience, directly used by the monitor
# for single-turn experiences
- prompt_length: int = 1 # Length of the prompt in tokens, used for generating attention masks
response_text: Optional[str] = None # Text of the response
prompt_text: Optional[str] = None # Text of the prompt
# for multi-turn experiences
- # Action mask which indicates which tokens are generated by the model
+ # Action mask indicates which tokens are generated by the model
action_mask: Optional[Tensor] = None # [resp_length]
messages: Optional[List[dict]] = None # List of messages
tools: Optional[List[dict]] = None
@@ -123,8 +123,8 @@ class Experience:
# for dpo experiences
chosen: Optional[Tensor] = None # Token ids of the chosen response [resp_length]
rejected: Optional[Tensor] = None # Token ids of the rejected response [resp_length]
- chosen_text: Optional[str] = None # Text of the chosen response
- rejected_text: Optional[str] = None # Text of the rejected response
+ chosen_messages: Optional[List[dict]] = None # Chosen message list (Include prompt message)
+ rejected_messages: Optional[List[dict]] = None # Rejected message list (Include prompt message)
def __init__( # noqa: C901
self,
@@ -145,8 +145,8 @@ def __init__( # noqa: C901
tools=None,
chosen=None,
rejected=None,
- chosen_text=None,
- rejected_text=None,
+ chosen_messages=None,
+ rejected_messages=None,
):
if action_mask is not None:
experience_type = "multi_turn"
@@ -201,8 +201,8 @@ def __init__( # noqa: C901
if isinstance(rejected, list):
rejected = torch.tensor(rejected, dtype=torch.int32)
self.rejected = rejected
- self.chosen_text = chosen_text
- self.rejected_text = rejected_text
+ self.chosen_messages = chosen_messages
+ self.rejected_messages = rejected_messages
if not isinstance(self.tokens, Tensor):
self.tokens = torch.tensor(self.tokens)
@@ -241,10 +241,10 @@ def to_dict(self) -> dict:
res["messages"] = self.messages
if self.tools is not None:
res["tools"] = self.tools
- if self.chosen_text is not None:
- res["chosen_text"] = self.chosen_text
- if self.rejected_text is not None:
- res["rejected_text"] = self.rejected_text
+ if self.chosen_messages is not None:
+ res["chosen_messages"] = self.chosen_messages
+ if self.rejected_messages is not None:
+ res["rejected_messages"] = self.rejected_messages
if self.reward is not None:
res["reward"] = float(self.reward)
return res
@@ -343,7 +343,7 @@ def split_dpo_experience_to_single_turn(experiences: List[Experience]) -> List[E
metrics=exp.metrics,
prompt_length=len(exp.tokens), # type: ignore [arg-type]
prompt_text=exp.prompt_text,
- response_text=exp.chosen_text,
+ messages=exp.chosen_messages,
)
)
single_turn_experiences.append(
@@ -360,7 +360,7 @@ def split_dpo_experience_to_single_turn(experiences: List[Experience]) -> List[E
metrics=exp.metrics,
prompt_length=len(exp.tokens), # type: ignore [arg-type]
prompt_text=exp.prompt_text,
- response_text=exp.rejected_text,
+ messages=exp.rejected_messages,
)
)
return single_turn_experiences
diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py
index 9a1da4ca49..c7808ae732 100644
--- a/trinity/common/workflows/workflow.py
+++ b/trinity/common/workflows/workflow.py
@@ -24,7 +24,7 @@
class Task(dict):
"""A Task class that defines a task and its associated reward function / workflow."""
- workflow: Type[Workflow]
+ workflow: Type[Workflow] = None
repeat_times: Optional[int] = None
format_args: FormatConfig = field(default_factory=FormatConfig)
rollout_args: GenerationConfig = field(default_factory=GenerationConfig)
diff --git a/trinity/manager/config_manager.py b/trinity/manager/config_manager.py
index f241482dca..fa3c4da524 100644
--- a/trinity/manager/config_manager.py
+++ b/trinity/manager/config_manager.py
@@ -519,11 +519,11 @@ def _gen_buffer_config(self):
"name": "experience_buffer",
"storage_type": st.session_state["storage_type"],
"path": experience_buffer_path,
+ "max_retry_interval": st.session_state["max_retry_interval"],
+ "max_retry_times": st.session_state["buffer_max_retry_times"],
},
"sft_warmup_steps": st.session_state["sft_warmup_steps"],
},
- "max_retry_times": st.session_state["buffer_max_retry_times"],
- "max_retry_interval": st.session_state["max_retry_interval"],
}
if st.session_state["algorithm_type"] != "dpo":
experience_buffer = buffer_config["trainer_input"]["experience_buffer"]