diff --git a/benchmark/config/countdown-template.yaml b/benchmark/config/countdown-template.yaml index c213e024e2..4df0573975 100644 --- a/benchmark/config/countdown-template.yaml +++ b/benchmark/config/countdown-template.yaml @@ -46,8 +46,6 @@ buffer: priority_fn: linear_decay decay: 0.1 sft_warmup_steps: 0 - max_retry_times: 3 - max_retry_interval: 1 explorer: runner_num: 32 max_timeout: 900 diff --git a/benchmark/config/gsm8k-template.yaml b/benchmark/config/gsm8k-template.yaml index 5e2ee99b77..9d6483e507 100644 --- a/benchmark/config/gsm8k-template.yaml +++ b/benchmark/config/gsm8k-template.yaml @@ -51,8 +51,6 @@ buffer: priority_fn: linear_decay decay: 0.1 sft_warmup_steps: 0 - max_retry_times: 3 - max_retry_interval: 1 explorer: runner_per_model: 8 max_timeout: 900 diff --git a/docs/sphinx_doc/source/_templates/versions.html b/docs/sphinx_doc/source/_templates/versions.html index a99867840c..00502f4fdd 100644 --- a/docs/sphinx_doc/source/_templates/versions.html +++ b/docs/sphinx_doc/source/_templates/versions.html @@ -2,7 +2,7 @@
Other Versions - v: {{ current_version.name }} + {{ current_version.name }}
@@ -18,7 +18,7 @@
Branches
{%- for item in versions.branches %} -
{{ item.name }} (latest)
+
{{ item.name }} (latest)
{%- endfor %}
{%- endif %} diff --git a/docs/sphinx_doc/source/tutorial/example_mix_algo.md b/docs/sphinx_doc/source/tutorial/example_mix_algo.md index fd58b75144..5df890cd96 100644 --- a/docs/sphinx_doc/source/tutorial/example_mix_algo.md +++ b/docs/sphinx_doc/source/tutorial/example_mix_algo.md @@ -54,7 +54,7 @@ class MIXAlgorithm(AlgorithmType): use_reference: bool = True compute_advantage_in_trainer: bool = False can_balance_batch: bool = True - schema: type = ExperienceModel + schema: str = "experience" @classmethod def default_config(cls) -> Dict: diff --git a/docs/sphinx_doc/source/tutorial/example_step_wise.md b/docs/sphinx_doc/source/tutorial/example_step_wise.md index be9df5880b..2d7e8bec5a 100644 --- a/docs/sphinx_doc/source/tutorial/example_step_wise.md +++ b/docs/sphinx_doc/source/tutorial/example_step_wise.md @@ -107,8 +107,6 @@ buffer: total_epochs: 20 batch_size: 16 train_batch_size: 7680 # here: batch_size * repeat_times * max_env_steps - max_retry_times: 3 - max_retry_interval: 1 explorer_input: taskset: name: alfworld diff --git a/docs/sphinx_doc/source/tutorial/faq.md b/docs/sphinx_doc/source/tutorial/faq.md index 63b0fe5b1f..c408c69180 100644 --- a/docs/sphinx_doc/source/tutorial/faq.md +++ b/docs/sphinx_doc/source/tutorial/faq.md @@ -120,7 +120,7 @@ from sqlalchemy import create_engine from sqlalchemy.exc import OperationalError from sqlalchemy.orm import sessionmaker from sqlalchemy.pool import NullPool -from trinity.common.schema import ExperienceModel +from trinity.common.schema.sql_schema import ExperienceModel engine = create_engine(buffer.trainer_input.experience_buffer.path) session = sessionmaker(bind=engine) @@ -129,7 +129,6 @@ sess = session() MAX_EXPERIENCES = 4 experiences = ( sess.query(ExperienceModel) - .with_for_update() .limit(MAX_EXPERIENCES) .all() ) diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md index 567ef3f626..6b109f261a 100644 --- 
a/docs/sphinx_doc/source/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md @@ -35,7 +35,7 @@ synchronizer: # Model weight synchronization settings ... monitor: - # Monitoring configurations (e.g., WandB or TensorBoard) + # Monitoring configurations (e.g., WandB, TensorBoard or MLFlow) ... service: # Services to use @@ -48,10 +48,12 @@ log: ... ``` -Each of these sections will be explained in detail below. +Each of these sections will be explained in detail below. For additional details about specific parameters not covered here, please refer to the [source code](https://github.com/modelscope/Trinity-RFT/blob/main/trinity/common/config.py). -```{note} -For additional details about specific parameters not covered here, please refer to the [source code](https://github.com/modelscope/Trinity-RFT/blob/main/trinity/common/config.py). +```{tip} +Trinity-RFT uses [OmegaConf](https://omegaconf.readthedocs.io/en/latest/) to load YAML configuration files. +It supports some advanced features like [variable interpolation](https://omegaconf.readthedocs.io/en/latest/usage.html#variable-interpolation) and [environment variable substitution](https://omegaconf.readthedocs.io/en/latest/custom_resolvers.html#oc-env). +Users can use these features to simplify configuration. ``` --- @@ -64,7 +66,7 @@ These are general settings that apply to the entire experiment. project: Trinity-RFT name: example mode: both -checkpoint_root_dir: /PATH/TO/CHECKPOINT +checkpoint_root_dir: ${oc.env:CHECKPOINT_ROOT_DIR} # CHECKPOINT_ROOT_DIR is an environment variable set in advance ``` - `project`: The name of the project. @@ -115,13 +117,25 @@ Used to log training metrics during execution. ```yaml monitor: monitor_type: wandb + monitor_args: + base_url: http://localhost:8080 + api_key: your_api_key enable_ray_timeline: False ``` - `monitor_type`: Type of monitoring system. Options: - `wandb`: Logs to [Weights & Biases](https://docs.wandb.ai/quickstart/). 
Requires logging in and setting `WANDB_API_KEY`. Project and run names match the `project` and `name` fields in global configs. - `tensorboard`: Logs to [TensorBoard](https://www.tensorflow.org/tensorboard). Files are saved under `<checkpoint_root_dir>/<project>/<name>/monitor/tensorboard`. -- `enable_ray_timeline`: Whether to export the ray timeline. If set to `True`, a `timeline.json` file will be exported to `<checkpoint_root_dir>/<project>/<name>/monitor`. You can view the timeline file in Chrome at [chrome://tracing](chrome://tracing). + - `mlflow`: Logs to [MLFlow](https://mlflow.org/). If [MLFlow authentication](https://mlflow.org/docs/latest/ml/auth/) is set up, set `MLFLOW_TRACKING_USERNAME` and `MLFLOW_TRACKING_PASSWORD` as environment variables before running. +- `monitor_args`: Dictionary of arguments for monitor initialization. + - For `wandb`: + - `base_url`: Overrides `WANDB_BASE_URL` if set. + - `api_key`: Overrides `WANDB_API_KEY` if set. + - For `mlflow`: + - `uri`: The URI of your MLFlow instance. Strongly recommended to set; defaults to `http://localhost:5000`. + - `username`: Overrides `MLFLOW_TRACKING_USERNAME` if set. + - `password`: Overrides `MLFLOW_TRACKING_PASSWORD` if set. +- `enable_ray_timeline`: If `True`, exports a `timeline.json` file to `<checkpoint_root_dir>/<project>/<name>/monitor`. Viewable in Chrome at [chrome://tracing](chrome://tracing). --- @@ -131,8 +145,8 @@ Defines the model paths and token limits. ```yaml model: - model_path: /PATH/TO/MODEL/ - critic_model_path: '' + model_path: ${oc.env:MODEL_PATH} # MODEL_PATH is an environment variable set in advance + critic_model_path: ${model.model_path} # use the value of model.model_path max_response_tokens: 16384 max_model_len: 20480 ``` @@ -174,10 +188,6 @@ buffer: ... eval_tasksets: ... - - explorer_output: - ... - trainer_input: experience_buffer: ... @@ -255,41 +265,6 @@ The configuration for each task dataset is defined as follows: - `default_reward_fn_type`: Reward function used during exploration. If not specified, the `buffer.default_reward_fn_type` is used.
- `workflow_args`: A dictionary of arguments used to supplement dataset-level parameters. - -### Explorer Output - -In [`explore` mode](#global-configuration), since there is no trainer, users can configure an experience buffer via `buffer.explorer_output`, rather than using `buffer.trainer_input`, which will be introduced in the next section. - -```{note} -For `both` and `train` modes, users should use `buffer.trainer_input.experience_buffer` instead of `buffer.explorer_output`. -``` - -```yaml -buffer: - ... - explorer_output: - name: countdown_buffer - storage_type: queue - path: sqlite:///countdown_buffer.db - wrap_in_ray: True - max_read_timeout: 1800 -``` - -- `name`: The name of the experience buffer. This name will be used as the Ray actor's name, so it must be unique. -- `storage_type`: The storage type for the experience buffer. - - `queue`: Experience data is stored in a queue. This storage type is recommended for most use cases. - - `sql`: Experience data is stored in a SQL database. If your database only supports local access (e.g., SQLite), set `wrap_in_ray` to `True` to wrap the database in a Ray actor, enabling remote access from other nodes. - - `file`: Experience data is stored in a JSON file. This storage type should be used only for debugging purposes in `explore` mode. -- `path`: The path to the experience buffer. - - For `queue` storage type, this field is optional. You can specify a SQLite database or JSON file path here to back up the queue data. - - For `file` storage type, the path points to the directory containing the dataset files. - - For `sql` storage type, the path points to the SQLite database file. -- `wrap_in_ray`: Whether to wrap the experience buffer in a Ray actor. Only take effect when `storage_type` is `sql` or `file`. The `queue` storage always uses a Ray actor. -- `max_read_timeout`: The maximum waiting time (in seconds) to read new experience data. If exceeded, an incomplete batch will be returned directly. 
Only take effect when `storage_type` is `queue`. Default is 1800 seconds (30 minutes). -- `use_priority_queue`: Only take effect when `storage_type` is `queue`. If set to `True`, the queue will be a priority queue, which allows for prioritizing certain experiences over others. Default is `False`. -- `reuse_cooldown_time`: Only take effect when `storage_type` is `queue` and `use_priority_queue` is `True`. If set, it specifies the cooldown time (in seconds) for reusing experiences. If not specified, the default value is `None`, meaning experiences can not be reused. - - ### Trainer Input Defines the experience buffer and optional SFT warm-up dataset. @@ -314,7 +289,19 @@ buffer: sft_warmup_steps: 0 ``` -- `experience_buffer`: Experience buffer used by the trainer, which is logically equivalent to `buffer.explorer_output`. +- `experience_buffer`: It is the input of Trainer and also the output of Explorer. This field is required even in explore mode. + - `name`: The name of the experience buffer. This name will be used as the Ray actor's name, so it must be unique. + - `storage_type`: The storage type for the experience buffer. + - `queue`: Experience data is stored in a queue. This storage type is recommended for most use cases. + - `sql`: Experience data is stored in a SQL database. + - `file`: Experience data is stored in a JSON file. This storage type should be used only for debugging purposes in `explore` mode. + - `path`: The path to the experience buffer. + - For `queue` storage type, this field is optional. You can specify a SQLite database or JSON file path here to back up the queue data. + - For `file` storage type, the path points to the directory containing the dataset files. + - For `sql` storage type, the path points to the SQLite database file. + - `max_read_timeout`: The maximum waiting time (in seconds) to read new experience data. If exceeded, an incomplete batch will be returned directly. Only takes effect when `storage_type` is `queue`.
Default is 1800 seconds (30 minutes). + - `use_priority_queue`: Only takes effect when `storage_type` is `queue`. If set to `True`, the queue will be a priority queue, which allows for prioritizing certain experiences over others. Default is `False`. + - `reuse_cooldown_time`: Only takes effect when `storage_type` is `queue` and `use_priority_queue` is `True`. If set, it specifies the cooldown time (in seconds) for reusing experiences. If not specified, the default value is `None`, meaning experiences cannot be reused. - `sft_warmup_dataset`: Optional dataset used for pre-training (SFT warmup). - `sft_warmup_steps`: Number of steps to use SFT warm-up before RL begins. diff --git a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md index 8346bb0cf9..c395880a55 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md +++ b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md @@ -447,7 +447,7 @@ class OPMDPolicyLossFn(PolicyLossFn): The above steps implement the components needed for the algorithm, but these components are scattered and need to be configured in multiple places to take effect. -To simplify configuration, Trinity-RFT provides {class}`trinity.algorithm.AlgorithmType` to describe a complete algorithm and registers it in {object}`trinity.algorithm.ALGORITHM_TYPE`, enabling one-click configuration. +To simplify configuration, Trinity-RFT provides {class}`trinity.algorithm.AlgorithmType` to describe a complete algorithm and registers it in {class}`trinity.algorithm.ALGORITHM_TYPE`, enabling one-click configuration.
The `AlgorithmType` class includes the following attributes and methods: @@ -473,7 +473,7 @@ class OPMDAlgorithm(AlgorithmType): use_reference: bool = True compute_advantage_in_trainer: bool = False can_balance_batch: bool = True - schema: type = ExperienceModel + schema: str = "experience" @classmethod def default_config(cls) -> Dict: diff --git a/examples/RAFT_alfworld/RAFT_alfworld_7B.yaml b/examples/RAFT_alfworld/RAFT_alfworld_7B.yaml index c195b255bc..e79f1ce11c 100644 --- a/examples/RAFT_alfworld/RAFT_alfworld_7B.yaml +++ b/examples/RAFT_alfworld/RAFT_alfworld_7B.yaml @@ -15,8 +15,6 @@ cluster: buffer: total_epochs: 30 batch_size: 80 - max_retry_times: 1 - max_retry_interval: 1 explorer_input: taskset: name: alfworld-train diff --git a/examples/RAFT_alfworld/RAFT_reflect_alfworld_7B.yaml b/examples/RAFT_alfworld/RAFT_reflect_alfworld_7B.yaml index 6dc17ae32c..9ecea9e923 100644 --- a/examples/RAFT_alfworld/RAFT_reflect_alfworld_7B.yaml +++ b/examples/RAFT_alfworld/RAFT_reflect_alfworld_7B.yaml @@ -15,8 +15,6 @@ cluster: buffer: total_epochs: 30 batch_size: 80 - max_retry_times: 1 - max_retry_interval: 1 explorer_input: taskset: name: alfworld-train diff --git a/examples/agentscope_tool_react/agentscope_tool_react_dapo.yaml b/examples/agentscope_tool_react/agentscope_tool_react_dapo.yaml index b16d9a90f2..55699cfbd3 100644 --- a/examples/agentscope_tool_react/agentscope_tool_react_dapo.yaml +++ b/examples/agentscope_tool_react/agentscope_tool_react_dapo.yaml @@ -16,8 +16,6 @@ buffer: total_epochs: 1 batch_size: 32 train_batch_size: 512 - max_retry_times: 3 - max_retry_interval: 1 explorer_input: taskset: name: dapo diff --git a/examples/agentscope_tool_react/agentscope_tool_react_gsm8k.yaml b/examples/agentscope_tool_react/agentscope_tool_react_gsm8k.yaml index 1b8b92bdbc..0e8c2559e5 100644 --- a/examples/agentscope_tool_react/agentscope_tool_react_gsm8k.yaml +++ b/examples/agentscope_tool_react/agentscope_tool_react_gsm8k.yaml @@ -16,8 +16,6 @@ buffer: 
total_epochs: 1 batch_size: 32 train_batch_size: 256 - max_retry_times: 3 - max_retry_interval: 1 explorer_input: taskset: name: gsm8k diff --git a/examples/asymre_gsm8k/gsm8k.yaml b/examples/asymre_gsm8k/gsm8k.yaml index d825d8f89f..3eeab4d3f4 100644 --- a/examples/asymre_gsm8k/gsm8k.yaml +++ b/examples/asymre_gsm8k/gsm8k.yaml @@ -18,8 +18,6 @@ cluster: buffer: total_steps: 80 batch_size: 96 - max_retry_times: 3 - max_retry_interval: 1 explorer_input: taskset: name: gsm8k diff --git a/examples/asymre_math/math.yaml b/examples/asymre_math/math.yaml index b8472f9e2d..4d93cdf972 100644 --- a/examples/asymre_math/math.yaml +++ b/examples/asymre_math/math.yaml @@ -22,8 +22,6 @@ cluster: buffer: total_steps: 2000 # Exactly 2000 training steps as desired batch_size: 16 # 128 trajectories per gradient step, batch_size is the number of tasks per batch - max_retry_times: 3 - max_retry_interval: 1 explorer_input: taskset: name: math diff --git a/examples/async_gsm8k/explorer.yaml b/examples/async_gsm8k/explorer.yaml index 38da6a5c27..452536a5ce 100644 --- a/examples/async_gsm8k/explorer.yaml +++ b/examples/async_gsm8k/explorer.yaml @@ -15,8 +15,6 @@ cluster: buffer: total_epochs: 1 batch_size: 96 - max_retry_times: 3 - max_retry_interval: 1 explorer_input: taskset: name: gsm8k diff --git a/examples/async_gsm8k/trainer.yaml b/examples/async_gsm8k/trainer.yaml index f1f6b08b86..baffdf88c1 100644 --- a/examples/async_gsm8k/trainer.yaml +++ b/examples/async_gsm8k/trainer.yaml @@ -15,8 +15,6 @@ cluster: buffer: total_epochs: 1 train_batch_size: 768 - max_retry_times: 3 - max_retry_interval: 1 explorer_input: taskset: name: gsm8k diff --git a/examples/dapo_math/dapo.yaml b/examples/dapo_math/dapo.yaml index d5fdc73dd0..929ea18a25 100644 --- a/examples/dapo_math/dapo.yaml +++ b/examples/dapo_math/dapo.yaml @@ -17,8 +17,6 @@ cluster: buffer: total_epochs: 1 batch_size: 32 - max_retry_times: 3 - max_retry_interval: 1 explorer_input: taskset: name: dapo-math diff --git 
a/examples/dpo_humanlike/dpo.yaml b/examples/dpo_humanlike/dpo.yaml index 4ae424c162..0de0b49941 100644 --- a/examples/dpo_humanlike/dpo.yaml +++ b/examples/dpo_humanlike/dpo.yaml @@ -17,8 +17,6 @@ cluster: buffer: total_epochs: 2 train_batch_size: 64 - max_retry_times: 3 - max_retry_interval: 1 trainer_input: experience_buffer: name: dpo_buffer diff --git a/examples/grpo_alfworld/alfworld.yaml b/examples/grpo_alfworld/alfworld.yaml index 122d90061b..c8fe7863ac 100644 --- a/examples/grpo_alfworld/alfworld.yaml +++ b/examples/grpo_alfworld/alfworld.yaml @@ -14,8 +14,6 @@ cluster: buffer: total_epochs: 20 batch_size: 32 - max_retry_times: 3 - max_retry_interval: 1 explorer_input: taskset: name: alfworld diff --git a/examples/grpo_alfworld_general_multi_step/alfworld.yaml b/examples/grpo_alfworld_general_multi_step/alfworld.yaml index 4421578ea6..9f0d79cba5 100644 --- a/examples/grpo_alfworld_general_multi_step/alfworld.yaml +++ b/examples/grpo_alfworld_general_multi_step/alfworld.yaml @@ -16,8 +16,6 @@ buffer: total_epochs: 20 batch_size: 16 train_batch_size: 7680 # 16 * 16 * 30 - max_retry_times: 3 - max_retry_interval: 1 explorer_input: taskset: name: alfworld diff --git a/examples/grpo_email_search/email_search.yaml b/examples/grpo_email_search/email_search.yaml index 381ec6fb3b..d7722f4ebc 100644 --- a/examples/grpo_email_search/email_search.yaml +++ b/examples/grpo_email_search/email_search.yaml @@ -16,8 +16,6 @@ buffer: total_epochs: 1 batch_size: 16 train_batch_size: 640 # 16*8*5 - max_retry_times: 3 - max_retry_interval: 1 explorer_input: taskset: name: enron_train diff --git a/examples/grpo_gsm8k/gsm8k.yaml b/examples/grpo_gsm8k/gsm8k.yaml index 4339472466..7673c47c44 100644 --- a/examples/grpo_gsm8k/gsm8k.yaml +++ b/examples/grpo_gsm8k/gsm8k.yaml @@ -14,8 +14,6 @@ cluster: buffer: total_epochs: 1 batch_size: 96 - max_retry_times: 3 - max_retry_interval: 1 explorer_input: taskset: name: gsm8k diff --git a/examples/grpo_gsm8k_experience_pipeline/gsm8k.yaml 
b/examples/grpo_gsm8k_experience_pipeline/gsm8k.yaml index bda28ccbd2..70d881dfbb 100644 --- a/examples/grpo_gsm8k_experience_pipeline/gsm8k.yaml +++ b/examples/grpo_gsm8k_experience_pipeline/gsm8k.yaml @@ -35,8 +35,6 @@ cluster: buffer: total_epochs: 1 batch_size: 96 - max_retry_times: 3 - max_retry_interval: 1 explorer_input: taskset: name: gsm8k diff --git a/examples/grpo_gsm8k_task_pipeline/gsm8k.yaml b/examples/grpo_gsm8k_task_pipeline/gsm8k.yaml index 86762fc10d..9a6c5358e0 100644 --- a/examples/grpo_gsm8k_task_pipeline/gsm8k.yaml +++ b/examples/grpo_gsm8k_task_pipeline/gsm8k.yaml @@ -33,8 +33,6 @@ cluster: buffer: total_epochs: 1 batch_size: 96 - max_retry_times: 3 - max_retry_interval: 1 explorer_input: taskset: name: gsm8k diff --git a/examples/grpo_math/math.yaml b/examples/grpo_math/math.yaml index 168ea3cd3e..16e73ff043 100644 --- a/examples/grpo_math/math.yaml +++ b/examples/grpo_math/math.yaml @@ -14,8 +14,6 @@ cluster: buffer: total_epochs: 20 batch_size: 288 - max_retry_times: 3 - max_retry_interval: 1 explorer_input: taskset: name: math diff --git a/examples/grpo_sciworld/sciworld.yaml b/examples/grpo_sciworld/sciworld.yaml index c469423af1..743d6e6e2e 100644 --- a/examples/grpo_sciworld/sciworld.yaml +++ b/examples/grpo_sciworld/sciworld.yaml @@ -14,8 +14,6 @@ cluster: buffer: total_epochs: 20 batch_size: 4 - max_retry_times: 3 - max_retry_interval: 1 explorer_input: taskset: name: sciworld diff --git a/examples/grpo_toolcall/toolace.yaml b/examples/grpo_toolcall/toolace.yaml index 0581603b6b..6ffe975973 100644 --- a/examples/grpo_toolcall/toolace.yaml +++ b/examples/grpo_toolcall/toolace.yaml @@ -14,8 +14,6 @@ cluster: buffer: total_epochs: 1 batch_size: 128 - max_retry_times: 3 - max_retry_interval: 1 explorer_input: taskset: name: toolace_data diff --git a/examples/grpo_webshop/webshop.yaml b/examples/grpo_webshop/webshop.yaml index 218f2d9360..82fff99e20 100644 --- a/examples/grpo_webshop/webshop.yaml +++ b/examples/grpo_webshop/webshop.yaml 
@@ -14,8 +14,6 @@ cluster: buffer: total_epochs: 20 batch_size: 4 - max_retry_times: 3 - max_retry_interval: 1 explorer_input: taskset: name: webshop diff --git a/examples/mix_chord/mix_chord.yaml b/examples/mix_chord/mix_chord.yaml index 26d04852e1..debdf515ac 100644 --- a/examples/mix_chord/mix_chord.yaml +++ b/examples/mix_chord/mix_chord.yaml @@ -33,8 +33,6 @@ buffer: total_epochs: 4 batch_size: 32 train_batch_size: 320 - max_retry_times: 3 - max_retry_interval: 1 explorer_input: taskset: name: openr1_data_filtered_int @@ -59,7 +57,6 @@ buffer: total_epochs: 25 name: SFT_data storage_type: file - algorithm_type: sft path: /PATH/TO/SFT_DATASET split: 'train' format: diff --git a/examples/mix_math/mix_math.yaml b/examples/mix_math/mix_math.yaml index 11eafb6cc9..94a05c71ad 100644 --- a/examples/mix_math/mix_math.yaml +++ b/examples/mix_math/mix_math.yaml @@ -27,8 +27,6 @@ buffer: total_epochs: 10 batch_size: 32 train_batch_size: 320 - max_retry_times: 3 - max_retry_interval: 1 explorer_input: taskset: name: math_train @@ -59,7 +57,6 @@ buffer: total_epochs: 10 name: math_sft storage_type: file - algorithm_type: sft path: /PATH/TO/EXPERT_DATA/ split: 'train' format: diff --git a/examples/opmd_gsm8k/opmd_gsm8k.yaml b/examples/opmd_gsm8k/opmd_gsm8k.yaml index 6b12aa0e34..18cdcdb10b 100644 --- a/examples/opmd_gsm8k/opmd_gsm8k.yaml +++ b/examples/opmd_gsm8k/opmd_gsm8k.yaml @@ -14,8 +14,6 @@ cluster: buffer: total_epochs: 1 batch_size: 96 - max_retry_times: 3 - max_retry_interval: 1 explorer_input: taskset: name: gsm8k diff --git a/examples/ppo_countdown/countdown.yaml b/examples/ppo_countdown/countdown.yaml index 35b0f217a8..f3b342fa7d 100644 --- a/examples/ppo_countdown/countdown.yaml +++ b/examples/ppo_countdown/countdown.yaml @@ -14,8 +14,6 @@ cluster: buffer: total_epochs: 20 batch_size: 96 - max_retry_times: 3 - max_retry_interval: 1 explorer_input: taskset: name: countdown diff --git a/examples/sft_mot/sft.yaml b/examples/sft_mot/sft.yaml index 
a309e75c8b..603adcea7c 100644 --- a/examples/sft_mot/sft.yaml +++ b/examples/sft_mot/sft.yaml @@ -14,8 +14,6 @@ cluster: buffer: total_epochs: 1 train_batch_size: 64 - max_retry_times: 3 - max_retry_interval: 1 trainer_input: experience_buffer: name: MoT diff --git a/tests/buffer/experience_storage_test.py b/tests/buffer/experience_storage_test.py new file mode 100644 index 0000000000..b3d225e122 --- /dev/null +++ b/tests/buffer/experience_storage_test.py @@ -0,0 +1,90 @@ +import os +import queue +import sqlite3 +import threading +import time + +import torch + +from tests.tools import RayUnittestBaseAysnc +from trinity.buffer.reader.sql_reader import SQLReader +from trinity.buffer.writer.sql_writer import SQLWriter +from trinity.common.config import BufferConfig, StorageConfig +from trinity.common.constants import StorageType +from trinity.common.experience import Experience + +DB_PATH = os.path.join(os.path.dirname(__file__), "test.db") + + +class ExperienceStorageTest(RayUnittestBaseAysnc): + def setUp(self): + self.total_num = 8 + self.put_batch_size = 2 + self.train_batch_size = 4 + + self.config = BufferConfig( + train_batch_size=self.train_batch_size, + ) + if os.path.exists(DB_PATH): + os.remove(DB_PATH) + + async def test_sql_storage(self): + meta = StorageConfig( + name="test_storage", + schema_type="experience", + storage_type=StorageType.SQL, + max_read_timeout=3, + path=f"sqlite:///{DB_PATH}", + ) + + writer = SQLWriter(meta, self.config) + reader = SQLReader(meta, self.config) + self.assertEqual(await writer.acquire(), 1) + exps = [ + Experience( + tokens=torch.tensor([float(j) for j in range(i + 1)]), + prompt_length=i, + reward=float(i), + logprobs=torch.tensor([0.1]), + ) + for i in range(1, self.put_batch_size + 1) + ] + for exp in exps: + exp.info = {"model_version": 0, "use_count": 0} + for _ in range(self.total_num // self.put_batch_size): + await writer.write_async(exps) + for _ in range(self.total_num // self.train_batch_size): + exps = 
reader.read() + self.assertEqual(len(exps), self.train_batch_size) + exps = [ + Experience( + tokens=torch.tensor([float(j) for j in range(i + 1)]), + reward=float(i), + logprobs=torch.tensor([0.1]), + action_mask=torch.tensor([j % 2 for j in range(i + 1)]), + ) + for i in range(1, self.put_batch_size * 2 + 1) + ] + writer.write(exps) + exps = reader.read(batch_size=self.put_batch_size * 2) + self.assertEqual(len(exps), self.put_batch_size * 2) + + def thread_read(reader, result_queue): + try: + batch = reader.read() + result_queue.put(batch) + except StopIteration as e: + result_queue.put(e) + + result_queue = queue.Queue() + t = threading.Thread(target=thread_read, args=(reader, result_queue)) + t.start() + time.sleep(2) # make sure the thread is waiting for data + self.assertEqual(await writer.release(), 0) + t.join(timeout=1) + self.assertIsInstance(result_queue.get(), StopIteration) + conn = sqlite3.connect(DB_PATH) + cursor = conn.cursor() + value = cursor.execute("SELECT COUNT(*) FROM test_storage;").fetchall() + self.assertEqual(value[0][0], self.total_num + self.put_batch_size * 2) + self.assertRaises(StopIteration, reader.read, batch_size=1) diff --git a/tests/buffer/file_test.py b/tests/buffer/file_test.py index fb38e43660..00aec40744 100644 --- a/tests/buffer/file_test.py +++ b/tests/buffer/file_test.py @@ -9,9 +9,7 @@ get_unittest_dataset_config, ) from trinity.buffer.buffer import get_buffer_reader, get_buffer_writer -from trinity.buffer.reader.file_reader import RawDataReader from trinity.buffer.utils import default_storage_path -from trinity.buffer.writer.file_writer import JSONWriter from trinity.common.config import StorageConfig from trinity.common.constants import StorageType @@ -30,34 +28,6 @@ def tearDownClass(cls): if os.path.exists(cls.temp_output_path): os.system(f"rm -rf {cls.temp_output_path}") - async def test_file_buffer(self): - meta = StorageConfig( - name="test_buffer", - path=os.path.join(self.temp_output_path, "buffer.jsonl"), - 
storage_type=StorageType.FILE, - raw=True, - ) - data = [ - {"key1": 1, "key2": 2}, - {"key1": 3, "key2": 4}, - {"key1": 5, "key2": 6}, - {"key1": 7, "key2": 8}, - ] - - # test writer - writer = JSONWriter(meta, None) - await writer.acquire() - writer.write(data) - await writer.release() - - # test reader - meta.path = self.temp_output_path - reader = RawDataReader(meta, None) - loaded_data = reader.read() - self.assertEqual(len(loaded_data), 4) - self.assertEqual(loaded_data, data) - self.assertRaises(StopIteration, reader.read) - def test_file_reader(self): # noqa: C901 """Test file reader.""" reader = get_buffer_reader(self.config.buffer.explorer_input.taskset, self.config.buffer) @@ -90,7 +60,6 @@ def test_file_reader(self): # noqa: C901 while True: try: tasks.extend(reader.read()) - print(f"read from buffer, current len {len(tasks)}.") except StopIteration: break self.assertEqual(len(tasks), 20 - 8) diff --git a/tests/buffer/formatter_test.py b/tests/buffer/formatter_test.py index c1cef3de46..9da05e65af 100644 --- a/tests/buffer/formatter_test.py +++ b/tests/buffer/formatter_test.py @@ -3,13 +3,8 @@ from transformers import AutoTokenizer from tests.tools import get_model_path -from trinity.buffer.schema.formatter import ( - DPOMessagesFormatter, - DPOPlaintextFormatter, - SFTMessagesFormatter, - SFTPlaintextFormatter, -) -from trinity.common.config import FormatConfig +from trinity.buffer.schema.formatter import FORMATTER +from trinity.common.config import FormatConfig, StorageConfig from trinity.common.constants import PromptType from trinity.common.experience import Experience @@ -23,7 +18,7 @@ def test_sft_messages_formatter(self): prompt_type=PromptType.MESSAGES, messages_key="message_list", ) - formatter = SFTMessagesFormatter(tokenizer=self.tokenizer, format_config=config) + formatter = FORMATTER.get("sft")(tokenizer=self.tokenizer, format_config=config) sample = { "message_list": [ {"role": "user", "content": "Hi"}, @@ -47,7 +42,7 @@ def 
test_sft_messages_formatter(self): messages_key="messages", tools_key="tools", ) - formatter = SFTMessagesFormatter(tokenizer=self.tokenizer, format_config=config) + formatter = FORMATTER.get("sft")(tokenizer=self.tokenizer, format_config=config) sample = { "messages": [ { @@ -126,7 +121,7 @@ def test_sft_plaintext_formatter(self): prompt_key="prompt", response_key="response", ) - formatter = SFTPlaintextFormatter(tokenizer=self.tokenizer, format_config=config) + formatter = FORMATTER.get("sft")(tokenizer=self.tokenizer, format_config=config) sample = { "system": "You are a helpful assistant.", "prompt": "What is 2+2?", @@ -150,7 +145,7 @@ def test_sft_plaintext_formatter(self): prompt_key="prompt", response_key="response", ) - formatter = SFTPlaintextFormatter(tokenizer=self.tokenizer, format_config=config) + formatter = FORMATTER.get("sft")(tokenizer=self.tokenizer, format_config=config) exp = formatter.format(sample) self.assertIsInstance(exp, Experience) @@ -170,7 +165,7 @@ def test_dpo_plaintext_formatter(self): chosen_key="chosen", rejected_key="rejected", ) - formatter = DPOPlaintextFormatter(tokenizer=self.tokenizer, format_config=config) + formatter = FORMATTER.get("dpo")(tokenizer=self.tokenizer, format_config=config) sample = {"prompt": "What is 2+2?", "chosen": "2+2=4", "rejected": "2+2=5"} exp = formatter.format(sample) self.assertIsInstance(exp, Experience) @@ -196,7 +191,7 @@ def test_dpo_messages_formatter(self): chosen_key="chosen", rejected_key="rejected", ) - formatter = DPOMessagesFormatter(tokenizer=self.tokenizer, format_config=config) + formatter = FORMATTER.get("dpo")(tokenizer=self.tokenizer, format_config=config) sample = { "messages": [ {"role": "user", "content": "What is your name?"}, @@ -217,3 +212,63 @@ def test_dpo_messages_formatter(self): self.assertIn("What is your name?", prompt) self.assertIn("My name is Assistant.", chosen) self.assertIn("I don't have a favorite color.", rejected) + + def test_task_formatter(self): + sample = { 
+ "question": "1+1=", + "answer": "2", + "workflow": "math_rm_workflow", + "reward": "math_boxed_reward", + } + config = StorageConfig( + is_eval=True, + default_workflow_type="math_workflow", + default_eval_workflow_type="math_boxed_workflow", + workflow_args={"use_base": True, "with_think": True}, + ) + formatter = FORMATTER.get("task")(config=config) + task = formatter.format(sample) + from trinity.common.workflows.customized_math_workflows import MathBoxedWorkflow + + self.assertEqual(task.workflow, MathBoxedWorkflow) + self.assertTrue(task.workflow_args.get("use_base")) + self.assertTrue(task.workflow_args.get("with_think")) + self.assertEqual(task.raw_task, sample) + + config = StorageConfig( + is_eval=False, + default_workflow_type="math_workflow", + default_eval_workflow_type="math_boxed_workflow", + default_reward_fn_type="math_reward", + workflow_args={"use_base": False, "with_think": True}, + ) + formatter = FORMATTER.get("task")(config=config) + task = formatter.format(sample) + from trinity.common.rewards.math_reward import MathRewardFn + from trinity.common.workflows.workflow import MathWorkflow + + self.assertEqual(task.workflow, MathWorkflow) + self.assertEqual(task.reward_fn, MathRewardFn) + self.assertFalse(task.workflow_args.get("use_base")) + self.assertTrue(task.workflow_args.get("with_think")) + self.assertEqual(task.raw_task, sample) + + config = StorageConfig( + is_eval=False, + default_eval_workflow_type="math_workflow", + workflow_args={"use_base": True, "with_think": False}, + format=FormatConfig( + workflow_key="workflow", + reward_fn_key="reward", + ), + ) + formatter = FORMATTER.get("task")(config=config) + task = formatter.format(sample) + from trinity.common.rewards.math_reward import MathBoxedRewardFn + from trinity.common.workflows.math_rm_workflow import MathRMWorkflow + + self.assertEqual(task.workflow, MathRMWorkflow) + self.assertEqual(task.reward_fn, MathBoxedRewardFn) + self.assertTrue(task.workflow_args.get("use_base")) + 
self.assertFalse(task.workflow_args.get("with_think")) + self.assertEqual(task.raw_task, sample) diff --git a/tests/buffer/queue_test.py b/tests/buffer/queue_test.py index 27ee3cb0de..d4c5ea997b 100644 --- a/tests/buffer/queue_test.py +++ b/tests/buffer/queue_test.py @@ -33,7 +33,7 @@ class TestQueueBuffer(RayUnittestBaseAysnc): async def test_queue_buffer(self, name, use_priority_queue): meta = StorageConfig( name=name, - algorithm_type="ppo", + schema_type="experience", storage_type=StorageType.QUEUE, max_read_timeout=3, path=BUFFER_FILE_PATH, @@ -97,7 +97,7 @@ async def test_priority_queue_capacity(self): self.config.train_batch_size = 4 meta = StorageConfig( name="test_buffer_small", - algorithm_type="ppo", + schema_type="experience", storage_type=StorageType.QUEUE, max_read_timeout=1, capacity=100, # priority will use 2 * train_batch_size as capacity (8) @@ -151,7 +151,7 @@ async def test_queue_buffer_capacity(self): # test queue capacity meta = StorageConfig( name="test_buffer_small", - algorithm_type="ppo", + schema_type="experience", storage_type=StorageType.QUEUE, max_read_timeout=3, capacity=4, @@ -180,7 +180,7 @@ async def test_priority_queue_buffer_reuse(self): # test queue reuse meta = StorageConfig( name="test_buffer_small", - algorithm_type="ppo", + schema_type="experience", storage_type=StorageType.QUEUE, max_read_timeout=3, capacity=4, @@ -306,8 +306,6 @@ def setUp(self): self.train_batch_size = 4 self.config = BufferConfig( - max_retry_times=3, - max_retry_interval=1, train_batch_size=self.train_batch_size, ) if os.path.exists(BUFFER_FILE_PATH): diff --git a/tests/buffer/sql_test.py b/tests/buffer/sql_test.py index 33e8a24462..1934b8fa6c 100644 --- a/tests/buffer/sql_test.py +++ b/tests/buffer/sql_test.py @@ -14,20 +14,17 @@ class TestSQLBuffer(RayUnittestBaseAysnc): - async def test_create_sql_buffer(self) -> None: + async def test_sql_buffer_read_write(self) -> None: total_num = 8 put_batch_size = 2 read_batch_size = 4 meta = StorageConfig( 
name="test_buffer", - algorithm_type="ppo", + schema_type="experience", path=f"sqlite:///{db_path}", storage_type=StorageType.SQL, - wrap_in_ray=True, ) config = BufferConfig( - max_retry_times=3, - max_retry_interval=1, train_batch_size=read_batch_size, ) sql_writer = SQLWriter(meta, config) @@ -66,3 +63,7 @@ async def test_create_sql_buffer(self) -> None: self.assertIsNotNone(db_wrapper) self.assertEqual(await sql_writer.release(), 0) self.assertRaises(StopIteration, sql_reader.read) + + def setUp(self) -> None: + if os.path.exists(db_path): + os.remove(db_path) diff --git a/tests/buffer/task_storage_test.py b/tests/buffer/task_storage_test.py new file mode 100644 index 0000000000..b0adda2aff --- /dev/null +++ b/tests/buffer/task_storage_test.py @@ -0,0 +1,63 @@ +import os + +import datasets +from parameterized import parameterized + +from tests.tools import ( + RayUnittestBase, + get_template_config, + get_unittest_dataset_config, +) +from trinity.buffer import get_buffer_reader +from trinity.buffer.storage.sql import SQLTaskStorage +from trinity.common.constants import StorageType + +db_path = os.path.join(os.path.dirname(__file__), "test.db") + + +class TaskStorageTest(RayUnittestBase): + @parameterized.expand( + [ + (StorageType.FILE, True, 2), + (StorageType.SQL, True, 2), + (StorageType.FILE, False, 0), + (StorageType.SQL, False, 0), + (StorageType.FILE, False, 2), + (StorageType.SQL, False, 2), + ] + ) + def test_read_task(self, storage_type, is_eval, offset): + config = get_template_config() + total_samples = 17 + batch_size = 4 + config.buffer.explorer_input.taskset = get_unittest_dataset_config( + "countdown" + ) # 17 samples + config.buffer.batch_size = batch_size + config.buffer.explorer_input.taskset.storage_type = storage_type + config.buffer.explorer_input.taskset.is_eval = is_eval + config.buffer.explorer_input.taskset.index = offset + if storage_type == StorageType.SQL: + dataset = datasets.load_dataset( + 
config.buffer.explorer_input.taskset.path, split="train" + ) + config.buffer.explorer_input.taskset.path = f"sqlite:///{db_path}" + SQLTaskStorage.load_from_dataset( + dataset, config.buffer.explorer_input.taskset, config.buffer + ) + reader = get_buffer_reader(config.buffer.explorer_input.taskset, config.buffer) + tasks = [] + try: + while True: + cur_tasks = reader.read() + tasks.extend(cur_tasks) + except StopIteration: + pass + if is_eval: + self.assertEqual(len(tasks), total_samples - offset) + else: + self.assertEqual(len(tasks), (total_samples - offset) // batch_size * batch_size) + + def setUp(self): + if os.path.exists(db_path): + os.remove(db_path) diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py index f805778568..236d33d72f 100644 --- a/tests/explorer/scheduler_test.py +++ b/tests/explorer/scheduler_test.py @@ -213,7 +213,7 @@ def setUp(self): ) = StorageConfig( name="test", storage_type=StorageType.QUEUE, - algorithm_type="ppo", + schema_type="experience", path="", ) self.config.buffer.trainer_input.experience_buffer.max_read_timeout = 1 diff --git a/tests/manager/synchronizer_test.py b/tests/manager/synchronizer_test.py index a77bc62097..cd96986c10 100644 --- a/tests/manager/synchronizer_test.py +++ b/tests/manager/synchronizer_test.py @@ -120,7 +120,6 @@ def test_synchronizer(self): config.buffer.trainer_input.experience_buffer = StorageConfig( name="exp_buffer", storage_type=StorageType.QUEUE, - wrap_in_ray=True, ) config.synchronizer.sync_method = SyncMethod.CHECKPOINT config.synchronizer.sync_style = SyncStyle.DYNAMIC_BY_EXPLORER @@ -140,7 +139,6 @@ def test_synchronizer(self): explorer1_config.buffer.explorer_output = StorageConfig( name="exp_buffer", storage_type=StorageType.QUEUE, - wrap_in_ray=True, ) explorer1_config.check_and_update() @@ -245,7 +243,6 @@ def test_synchronizer(self): config.buffer.trainer_input.experience_buffer = StorageConfig( name="exp_buffer", storage_type=StorageType.QUEUE, - 
wrap_in_ray=True, ) config.synchronizer.sync_method = self.sync_method config.synchronizer.sync_style = self.sync_style @@ -265,7 +262,6 @@ def test_synchronizer(self): explorer1_config.buffer.explorer_output = StorageConfig( name="exp_buffer", storage_type=StorageType.QUEUE, - wrap_in_ray=True, ) explorer2_config = deepcopy(explorer1_config) explorer2_config.explorer.name = "explorer2" @@ -346,7 +342,6 @@ def test_synchronizer(self): config.buffer.trainer_input.experience_buffer = StorageConfig( name="exp_buffer", storage_type=StorageType.QUEUE, - wrap_in_ray=True, ) config.synchronizer.sync_method = SyncMethod.NCCL config.synchronizer.sync_style = self.sync_style diff --git a/tests/template/config.yaml b/tests/template/config.yaml index bed27420c8..fd2bc06194 100644 --- a/tests/template/config.yaml +++ b/tests/template/config.yaml @@ -21,8 +21,6 @@ cluster: # 2 for explorer, 2 for trainer buffer: total_epochs: 1 batch_size: 4 - max_retry_times: 3 - max_retry_interval: 1 explorer_input: taskset: name: taskset diff --git a/tests/test_configs/active_iterator_test_cfg.yaml b/tests/test_configs/active_iterator_test_cfg.yaml deleted file mode 100644 index 3e6008b7cf..0000000000 --- a/tests/test_configs/active_iterator_test_cfg.yaml +++ /dev/null @@ -1,18 +0,0 @@ -data_processor: - # basic info - task_pipeline: - input_buffers: - - name: 'raw_input' - path: 'tests/test_data/test_10/' - storage_type: 'file' - raw: true - output_buffer: - name: 'raw_output' - path: './outputs/task_pipeline_output/processed.jsonl' - storage_type: 'file' - format: - prompt_key: 'problem' - response_key: 'solution' - # cleaner related - dj_config_path: 'tests/test_configs/active_iterator_test_dj_cfg.yaml' - clean_strategy: 'iterative' diff --git a/tests/test_configs/active_iterator_test_dj_cfg.yaml b/tests/test_configs/active_iterator_test_dj_cfg.yaml deleted file mode 100644 index 367709968f..0000000000 --- a/tests/test_configs/active_iterator_test_dj_cfg.yaml +++ /dev/null @@ -1,11 +0,0 @@ 
-project_name: 'demo-process' - -text_keys: 'solution' - -process: - - alphanumeric_filter: - max_ratio: 0.9 - - language_id_score_filter: - min_score: 0.5 - - character_repetition_filter: - max_ratio: 0.5 diff --git a/tests/test_configs/cleaner_test_dj_cfg.yaml b/tests/test_configs/cleaner_test_dj_cfg.yaml deleted file mode 100644 index cf11488963..0000000000 --- a/tests/test_configs/cleaner_test_dj_cfg.yaml +++ /dev/null @@ -1,7 +0,0 @@ -project_name: 'demo-process' - -export_path: './outputs/demo-process/demo-processed.jsonl' - -process: - - alphanumeric_filter: - - clean_email_mapper: diff --git a/tests/test_configs/cleaner_test_rft_cfg.yaml b/tests/test_configs/cleaner_test_rft_cfg.yaml deleted file mode 100644 index c78e3a1ac8..0000000000 --- a/tests/test_configs/cleaner_test_rft_cfg.yaml +++ /dev/null @@ -1,7 +0,0 @@ -data_processor: - task_pipeline: - input_buffers: - - path: './tests/test_data/test_cleaner' - raw: true - dj_config_path: './tests/test_configs/cleaner_test_dj_cfg.yaml' - clean_strategy: 'iterative' diff --git a/tests/test_configs/human_annotator_test_dj_cfg.yaml b/tests/test_configs/human_annotator_test_dj_cfg.yaml deleted file mode 100644 index ae53229062..0000000000 --- a/tests/test_configs/human_annotator_test_dj_cfg.yaml +++ /dev/null @@ -1,27 +0,0 @@ -project_name: 'demo-human-annotator' - -np: 1 - -export_path: './outputs/demo-human-annotator/annotated-data.jsonl' - -process: - - human_preference_annotation_mapper: - # general annotation project settings - project_name_prefix: "Human_Preference_Annotation_Demo" - wait_for_annotations: true # Whether to wait for annotations to complete - timeout: 3600 # Maximum time to wait for annotations in seconds (1 hour) - poll_interval: 10 # Time between annotation status checks in seconds - max_tasks_per_batch: 10 # Maximum number of tasks in a single batch - notification_config: - enabled: false - - # label studio connection settings - api_url: "http://localhost:7070" # Default Label Studio URL 
- api_key: "05409236-67a5-4169-af96-a52a818d0e81" # Your API key for label studuio authentication. Now it's the default api-key for default starting - - # human preference annotation settings - prompt_key: "prompt" # Prompt field - answer1_key: "answer1" # First answer option - answer2_key: "answer2" # Second answer option - chosen_key: "chosen" # Chosen field - rejected_key: "rejected" # Rejected field diff --git a/tests/test_configs/human_annotator_test_rft_cfg.yaml b/tests/test_configs/human_annotator_test_rft_cfg.yaml deleted file mode 100644 index b20f015182..0000000000 --- a/tests/test_configs/human_annotator_test_rft_cfg.yaml +++ /dev/null @@ -1,10 +0,0 @@ -data_processor: - task_pipeline: - input_buffers: - - path: './tests/test_data/test_human_annotator' - raw: true - dj_config_path: './tests/test_configs/human_annotator_test_dj_cfg.yaml' - format: - prompt_key: 'prompt' - chosen_key: 'chosen' - rejected_key: 'rejected' diff --git a/tests/test_data/template.yaml b/tests/test_data/template.yaml deleted file mode 100644 index bf474d3748..0000000000 --- a/tests/test_data/template.yaml +++ /dev/null @@ -1,25 +0,0 @@ -cluster: - node_num: 1 - gpu_per_node: 8 -buffer: - batch_size: 32 - max_retry_times: 3 - max_retry_interval: 1 - explorer_input: - taskset: - name: taskset - storage_type: file - path: '' - default_workflow_type: '' - default_eval_workflow_type: '' - default_reward_fn_type: '' -explorer: - runner_num: 8 - rollout_model: - engine_type: vllm - engine_num: 2 - tensor_parallel_size: 2 - enable_prefix_caching: false - enforce_eager: true - dtype: bfloat16 - seed: 42 diff --git a/tests/test_data/test_10/data.parquet b/tests/test_data/test_10/data.parquet deleted file mode 100644 index 9528828659..0000000000 Binary files a/tests/test_data/test_10/data.parquet and /dev/null differ diff --git a/tests/test_data/test_10_with_rewfn_workflow/data.parquet b/tests/test_data/test_10_with_rewfn_workflow/data.parquet deleted file mode 100644 index 
52c14aff13..0000000000 Binary files a/tests/test_data/test_10_with_rewfn_workflow/data.parquet and /dev/null differ diff --git a/tests/test_data/test_cleaner/data.parquet b/tests/test_data/test_cleaner/data.parquet deleted file mode 100644 index 857b6334f4..0000000000 Binary files a/tests/test_data/test_cleaner/data.parquet and /dev/null differ diff --git a/tests/test_data/test_human_annotator/data.jsonl b/tests/test_data/test_human_annotator/data.jsonl deleted file mode 100644 index 1e46385914..0000000000 --- a/tests/test_data/test_human_annotator/data.jsonl +++ /dev/null @@ -1,10 +0,0 @@ -{"prompt": "What is the capital of France?", "answer1": "Paris", "answer2": "Lyon"} -{"prompt": "Which planet is known as the Red Planet?", "answer1": "Mars", "answer2": "Venus"} -{"prompt": "What is the chemical symbol for gold?", "answer1": "Au", "answer2": "Ag"} -{"prompt": "Who wrote 'Romeo and Juliet'?", "answer1": "William Shakespeare", "answer2": "Christopher Marlowe"} -{"prompt": "What is the largest mammal on Earth?", "answer1": "Blue Whale", "answer2": "African Elephant"} -{"prompt": "In which year did World War II end?", "answer1": "1945", "answer2": "1944"} -{"prompt": "What is the square root of 64?", "answer1": "8", "answer2": "6"} -{"prompt": "Who painted the Mona Lisa?", "answer1": "Leonardo da Vinci", "answer2": "Michelangelo"} -{"prompt": "What is the main component of the Sun?", "answer1": "Hydrogen", "answer2": "Helium"} -{"prompt": "Which programming language was created by Guido van Rossum?", "answer1": "Python", "answer2": "Java"} diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 80e8322d28..24c0588ffa 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -250,8 +250,13 @@ def test_trainer(self): "sft_for_gsm8k" ) self.config.buffer.trainer_input.sft_warmup_steps = 3 + self.config.buffer.trainer_input.experience_buffer = StorageConfig( + name="test_sql_storage", + max_read_timeout=20, + 
storage_type=StorageType.SQL, + max_retry_times=10, + ) self.config.check_and_update() - self.config.buffer.trainer_input.experience_buffer.max_read_timeout = 20 self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2 self.config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr = 1e-5 both(self.config) @@ -428,7 +433,6 @@ def test_fully_async_mode(self, name, use_priority_queue): config.buffer.trainer_input.experience_buffer = StorageConfig( name="exp_buffer", storage_type=StorageType.QUEUE, - wrap_in_ray=True, use_priority_queue=use_priority_queue, ) config.synchronizer.sync_method = SyncMethod.CHECKPOINT @@ -450,7 +454,6 @@ def test_fully_async_mode(self, name, use_priority_queue): explorer1_config.buffer.trainer_input.experience_buffer = StorageConfig( name="exp_buffer", storage_type=StorageType.QUEUE, - wrap_in_ray=True, ) explorer2_config = deepcopy(explorer1_config) explorer1_config.check_and_update() diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py index fcedf20267..e3fac8d1e2 100644 --- a/trinity/algorithm/algorithm.py +++ b/trinity/algorithm/algorithm.py @@ -4,7 +4,6 @@ from abc import ABC, ABCMeta, abstractmethod from typing import Dict -from trinity.buffer.schema.sql_schema import DPODataModel, ExperienceModel, SFTDataModel from trinity.common.config import Config from trinity.common.constants import SyncMethod from trinity.utils.log import get_logger @@ -23,11 +22,17 @@ def __setattr__(cls, name, value): class AlgorithmType(ABC, metaclass=ConstantMeta): - use_critic: bool - use_reference: bool - compute_advantage_in_trainer: bool - can_balance_batch: bool - schema: type + use_critic: bool # whether to use critic model + + use_reference: bool # whether to use reference model + + compute_advantage_in_trainer: bool # whether to compute advantage in trainer + # For algorithms that rely on experience grouping, + # we recommend set this value to False + + can_balance_batch: bool # balance batch in trainer + + 
schema: str # schema of training data @classmethod @abstractmethod @@ -51,7 +56,7 @@ class SFTAlgorithm(AlgorithmType): use_reference: bool = False compute_advantage_in_trainer: bool = False can_balance_batch: bool = True - schema: type = SFTDataModel + schema: str = "sft" @classmethod def default_config(cls) -> Dict: @@ -71,7 +76,7 @@ class PPOAlgorithm(AlgorithmType): use_reference: bool = True compute_advantage_in_trainer: bool = True can_balance_batch: bool = True - schema: type = ExperienceModel + schema: str = "experience" @classmethod def default_config(cls) -> Dict: @@ -94,7 +99,7 @@ class GRPOAlgorithm(AlgorithmType): use_reference: bool = True compute_advantage_in_trainer: bool = False can_balance_batch: bool = True - schema: type = ExperienceModel + schema: str = "experience" @classmethod def default_config(cls) -> Dict: @@ -117,7 +122,7 @@ class OPMDAlgorithm(AlgorithmType): use_reference: bool = True compute_advantage_in_trainer: bool = False can_balance_batch: bool = True - schema: type = ExperienceModel + schema: str = "experience" @classmethod def default_config(cls) -> Dict: @@ -140,7 +145,7 @@ class AsymREAlgorithm(AlgorithmType): use_reference: bool = False compute_advantage_in_trainer: bool = False can_balance_batch: bool = True - schema: type = ExperienceModel + schema: str = "experience" @classmethod def default_config(cls) -> Dict: @@ -163,7 +168,7 @@ class DPOAlgorithm(AlgorithmType): use_reference: bool = True compute_advantage_in_trainer: bool = False can_balance_batch: bool = False - schema: type = DPODataModel + schema: str = "dpo" @classmethod def default_config(cls) -> Dict: @@ -208,7 +213,7 @@ class MIXAlgorithm(AlgorithmType): compute_advantage_in_trainer: bool = False use_rollout: bool = True can_balance_batch: bool = True - schema: type = ExperienceModel + schema: str = "experience" @classmethod def default_config(cls) -> Dict: @@ -230,7 +235,7 @@ class MIXCHORDAlgorithm(AlgorithmType): compute_advantage_in_trainer: bool = False 
use_rollout: bool = True can_balance_batch: bool = True - schema: type = ExperienceModel + schema: str = "experience" @classmethod def default_config(cls) -> Dict: @@ -247,14 +252,14 @@ def default_config(cls) -> Dict: class RAFTAlgorithm(AlgorithmType): """RAFT Algorithm. This algorithm is conceptually similar to Supervised Fine-Tuning (SFT) - but is designed to work with `ExperienceModel` schema from rollouts. + but is designed to work with `experience` schema from rollouts. """ use_critic: bool = False use_reference: bool = False compute_advantage_in_trainer: bool = False can_balance_batch: bool = True - schema: type = ExperienceModel + schema: str = "experience" @classmethod def default_config(cls) -> Dict: diff --git a/trinity/buffer/buffer.py b/trinity/buffer/buffer.py index ac5798c82d..a7d52e60a7 100644 --- a/trinity/buffer/buffer.py +++ b/trinity/buffer/buffer.py @@ -17,16 +17,17 @@ def get_buffer_reader(storage_config: StorageConfig, buffer_config: BufferConfig return QueueReader(storage_config, buffer_config) elif storage_config.storage_type == StorageType.FILE: - from trinity.buffer.reader.file_reader import FILE_READERS - - algorithm_type = storage_config.algorithm_type - if storage_config.raw: - file_read_type = "raw" - elif algorithm_type is not None: - file_read_type = algorithm_type + from trinity.buffer.reader.file_reader import ( + ExperienceFileReader, + TaskFileReader, + ) + + schema_type = storage_config.schema_type + if schema_type: + # only trainer input has schema type + return ExperienceFileReader(storage_config, buffer_config) else: - file_read_type = "rollout" - return FILE_READERS.get(file_read_type)(storage_config, buffer_config) # type: ignore + return TaskFileReader(storage_config, buffer_config) else: raise ValueError(f"{storage_config.storage_type} not supported.") diff --git a/trinity/buffer/pipelines/experience_pipeline.py b/trinity/buffer/pipelines/experience_pipeline.py index 6250b46703..efc1f16ad9 100644 --- 
a/trinity/buffer/pipelines/experience_pipeline.py +++ b/trinity/buffer/pipelines/experience_pipeline.py @@ -3,7 +3,7 @@ from trinity.buffer.buffer import BufferWriter, get_buffer_reader, get_buffer_writer from trinity.buffer.operators.experience_operator import ExperienceOperator -from trinity.buffer.ray_wrapper import is_database_url, is_json_file +from trinity.buffer.storage.queue import is_database_url, is_json_file from trinity.common.config import ( AlgorithmConfig, BufferConfig, diff --git a/trinity/buffer/ray_wrapper.py b/trinity/buffer/ray_wrapper.py deleted file mode 100644 index f946f78495..0000000000 --- a/trinity/buffer/ray_wrapper.py +++ /dev/null @@ -1,293 +0,0 @@ -"""Ray actor wrapper for different buffers.""" -import asyncio -import json -import os -import time -from collections import deque -from copy import deepcopy -from typing import List, Optional - -import ray -from sqlalchemy import asc, create_engine, desc -from sqlalchemy.exc import OperationalError -from sqlalchemy.orm import sessionmaker -from sqlalchemy.pool import NullPool - -from trinity.buffer.queue import QueueBuffer -from trinity.buffer.schema import Base, create_dynamic_table -from trinity.buffer.utils import default_storage_path, retry_session -from trinity.common.config import BufferConfig, StorageConfig -from trinity.common.constants import StorageType -from trinity.common.experience import EID, Experience -from trinity.common.workflows import Task -from trinity.utils.log import get_logger - - -class DBWrapper: - """ - A wrapper of a SQL database. - - If `wrap_in_ray` in `StorageConfig` is `True`, this class will be run as a Ray Actor, - and provide a remote interface to the local database. - - For databases that do not support multi-processing read/write (e.g. 
sqlite, duckdb), we - recommend setting `wrap_in_ray` to `True` - """ - - def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None: - self.logger = get_logger(f"sql_{storage_config.name}") - if storage_config.path is None: - storage_config.path = default_storage_path(storage_config, config) - self.engine = create_engine(storage_config.path, poolclass=NullPool) - self.table_model_cls = create_dynamic_table( - storage_config.algorithm_type, storage_config.name - ) - - try: - Base.metadata.create_all(self.engine, checkfirst=True) - except OperationalError: - self.logger.warning("Failed to create database, assuming it already exists.") - - self.session = sessionmaker(bind=self.engine) - self.batch_size = config.train_batch_size - self.max_retry_times = config.max_retry_times - self.max_retry_interval = config.max_retry_interval - self.ref_count = 0 - self.stopped = False - - @classmethod - def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig): - if storage_config.wrap_in_ray: - return ( - ray.remote(cls) - .options( - name=f"sql-{storage_config.name}", - namespace=storage_config.ray_namespace or ray.get_runtime_context().namespace, - get_if_exists=True, - ) - .remote(storage_config, config) - ) - else: - return cls(storage_config, config) - - def write(self, data: list) -> None: - with retry_session(self.session, self.max_retry_times, self.max_retry_interval) as session: - experience_models = [self.table_model_cls.from_experience(exp) for exp in data] - session.add_all(experience_models) - - def read(self, batch_size: Optional[int] = None) -> List: - if self.stopped: - raise StopIteration() - - sortOrder = (asc(self.table_model_cls.consumed), desc(self.table_model_cls.id)) - - exp_list = [] - batch_size = batch_size or self.batch_size # type: ignore - while len(exp_list) < batch_size: - if len(exp_list): - self.logger.info("waiting for experiences...") - time.sleep(1) - with retry_session( - self.session, self.max_retry_times, 
self.max_retry_interval - ) as session: - # get a batch of experiences from the database - experiences = ( - session.query(self.table_model_cls) - .filter(self.table_model_cls.reward.isnot(None)) - .order_by(*sortOrder) # TODO: very slow - .limit(batch_size - len(exp_list)) - .with_for_update() - .all() - ) - # update the consumed field - for exp in experiences: - exp.consumed += 1 - exp_list.extend([self.table_model_cls.to_experience(exp) for exp in experiences]) - self.logger.info(f"get {len(exp_list)} experiences:") - self.logger.info(f"reward = {[exp.reward for exp in exp_list]}") - self.logger.info(f"first prompt_text = {exp_list[0].prompt_text}") - self.logger.info(f"first response_text = {exp_list[0].response_text}") - return exp_list - - def acquire(self) -> int: - self.ref_count += 1 - return self.ref_count - - def release(self) -> int: - self.ref_count -= 1 - if self.ref_count <= 0: - self.stopped = True - return self.ref_count - - -class _Encoder(json.JSONEncoder): - def default(self, o): - if isinstance(o, Experience): - return o.to_dict() - if isinstance(o, Task): - return o.to_dict() - if isinstance(o, EID): - return o.to_dict() - return super().default(o) - - -class FileWrapper: - """ - A wrapper of a local jsonl file. - - If `wrap_in_ray` in `StorageConfig` is `True`, this class will be run as - a Ray Actor, and provide a remote interface to the local file. - - This wrapper is only for writing, if you want to read from the file, use - StorageType.QUEUE instead. 
- """ - - def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None: - if storage_config.path is None: - storage_config.path = default_storage_path(storage_config, config) - ext = os.path.splitext(storage_config.path)[-1] - if ext != ".jsonl" and ext != ".json": - raise ValueError( - f"File path must end with '.json' or '.jsonl', got {storage_config.path}" - ) - path_dir = os.path.dirname(storage_config.path) - os.makedirs(path_dir, exist_ok=True) - self.file = open(storage_config.path, "a", encoding="utf-8") - self.encoder = _Encoder(ensure_ascii=False) - self.ref_count = 0 - - @classmethod - def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig): - if storage_config.wrap_in_ray: - return ( - ray.remote(cls) - .options( - name=f"json-{storage_config.name}", - namespace=storage_config.ray_namespace or ray.get_runtime_context().namespace, - get_if_exists=True, - ) - .remote(storage_config, config) - ) - else: - return cls(storage_config, config) - - def write(self, data: List) -> None: - for item in data: - json_str = self.encoder.encode(item) - self.file.write(json_str + "\n") - self.file.flush() - - def read(self) -> List: - raise NotImplementedError( - "read() is not implemented for FileWrapper, please use QUEUE instead" - ) - - def acquire(self) -> int: - self.ref_count += 1 - return self.ref_count - - def release(self) -> int: - self.ref_count -= 1 - if self.ref_count <= 0: - self.file.close() - return self.ref_count - - -def is_database_url(path: str) -> bool: - return any(path.startswith(prefix) for prefix in ["sqlite:///", "postgresql://", "mysql://"]) - - -def is_json_file(path: str) -> bool: - return path.endswith(".json") or path.endswith(".jsonl") - - -class QueueWrapper: - """An wrapper of a async queue.""" - - def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None: - self.logger = get_logger(f"queue_{storage_config.name}") - self.config = config - self.capacity = storage_config.capacity 
- self.queue = QueueBuffer.get_queue(storage_config, config) - st_config = deepcopy(storage_config) - st_config.wrap_in_ray = False - if st_config.path is not None: - if is_database_url(st_config.path): - from trinity.buffer.writer.sql_writer import SQLWriter - - st_config.storage_type = StorageType.SQL - self.writer = SQLWriter(st_config, self.config) - elif is_json_file(st_config.path): - from trinity.buffer.writer.file_writer import JSONWriter - - st_config.storage_type = StorageType.FILE - self.writer = JSONWriter(st_config, self.config) - else: - self.logger.warning("Unknown supported storage path: %s", st_config.path) - self.writer = None - else: - from trinity.buffer.writer.file_writer import JSONWriter - - st_config.storage_type = StorageType.FILE - self.writer = JSONWriter(st_config, self.config) - self.logger.warning(f"Save experiences in {st_config.path}.") - self.ref_count = 0 - self.exp_pool = deque() # A pool to store experiences - self.closed = False - - async def acquire(self) -> int: - self.ref_count += 1 - return self.ref_count - - async def release(self) -> int: - """Release the queue.""" - self.ref_count -= 1 - if self.ref_count <= 0: - await self.queue.close() - await self.writer.release() - return self.ref_count - - def length(self) -> int: - """The length of the queue.""" - return self.queue.qsize() - - async def put_batch(self, exp_list: List) -> None: - """Put batch of experience.""" - await self.queue.put(exp_list) - if self.writer is not None: - self.writer.write(exp_list) - - async def get_batch(self, batch_size: int, timeout: float) -> List: - """Get batch of experience.""" - start_time = time.time() - while len(self.exp_pool) < batch_size: - if self.queue.stopped(): - # If the queue is stopped, ignore the rest of the experiences in the pool - raise StopAsyncIteration("Queue is closed and no more items to get.") - try: - exp_list = await asyncio.wait_for(self.queue.get(), timeout=1.0) - self.exp_pool.extend(exp_list) - except 
asyncio.TimeoutError: - if time.time() - start_time > timeout: - self.logger.error( - f"Timeout when waiting for experience, only get {len(self.exp_pool)} experiences.\n" - "This phenomenon is usually caused by the workflow not returning enough " - "experiences or running timeout. Please check your workflow implementation." - ) - batch = list(self.exp_pool) - self.exp_pool.clear() - return batch - return [self.exp_pool.popleft() for _ in range(batch_size)] - - @classmethod - def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig): - """Get the queue actor.""" - return ( - ray.remote(cls) - .options( - name=f"queue-{storage_config.name}", - namespace=storage_config.ray_namespace or ray.get_runtime_context().namespace, - get_if_exists=True, - ) - .remote(storage_config, config) - ) diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py index 5c7c558b47..b111d2de91 100644 --- a/trinity/buffer/reader/file_reader.py +++ b/trinity/buffer/reader/file_reader.py @@ -1,27 +1,14 @@ """Filed based buffer reader.""" -import copy from typing import List, Optional import datasets import transformers from datasets import Dataset, load_dataset -from trinity.algorithm.algorithm import DPOAlgorithm, SFTAlgorithm from trinity.buffer.buffer_reader import BufferReader -from trinity.buffer.schema.formatter import ( - DPOMessagesFormatter, - DPOPlaintextFormatter, - SFTMessagesFormatter, - SFTPlaintextFormatter, -) +from trinity.buffer.schema.formatter import FORMATTER from trinity.common.config import BufferConfig, StorageConfig -from trinity.common.constants import PromptType, TaskType -from trinity.common.rewards import REWARD_FUNCTIONS -from trinity.common.workflows import WORKFLOWS, Task -from trinity.utils.registry import Registry - -FILE_READERS = Registry("file_readers") class DummyProgressBar: @@ -112,22 +99,14 @@ async def read_async(self, batch_size: Optional[int] = None): raise StopAsyncIteration from e 
-@FILE_READERS.register_module(SFTAlgorithm.name()) -class SFTDataReader(BaseFileReader): +class ExperienceFileReader(BaseFileReader): """Reader for SFT file data.""" def __init__(self, meta: StorageConfig, config: BufferConfig): self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path) - if meta.format.prompt_type == PromptType.MESSAGES: - self.formatter = SFTMessagesFormatter( - tokenizer=self.tokenizer, format_config=meta.format - ) - elif meta.format.prompt_type == PromptType.PLAINTEXT: - self.formatter = SFTPlaintextFormatter( - tokenizer=self.tokenizer, format_config=meta.format - ) - else: - raise ValueError(f"Unknown prompt type: {self.prompt_type}") + self.formatter = FORMATTER.get(meta.schema_type)( + tokenizer=self.tokenizer, format_config=meta.format + ) self.read_batch_size = config.train_batch_size self.dataset = _HFBatchReader( load_dataset(meta.path, name=meta.subset_name, split=meta.split), @@ -148,48 +127,7 @@ def read(self, batch_size: Optional[int] = None) -> List: return exp_list -@FILE_READERS.register_module(DPOAlgorithm.name()) -class DPODataReader(BaseFileReader): - def __init__(self, meta: StorageConfig, config: BufferConfig): - self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path) - if meta.format.prompt_type == PromptType.MESSAGES: - self.formatter = DPOMessagesFormatter( - tokenizer=self.tokenizer, format_config=meta.format - ) - elif meta.format.prompt_type == PromptType.PLAINTEXT: - self.formatter = DPOPlaintextFormatter( - tokenizer=self.tokenizer, format_config=meta.format - ) - self.read_batch_size = config.train_batch_size - self.dataset = _HFBatchReader( - load_dataset(meta.path, name=meta.subset_name, split=meta.split), - name=meta.name, - default_batch_size=self.read_batch_size, - total_epochs=meta.total_epochs, - drop_last=True, - total_steps=meta.total_steps, - enable_progress_bar=meta.enable_progress_bar, - ) # TODO: support resume - - def _get_assistant_message(self, 
item) -> dict: - if isinstance(item, List): - item = item[0] - if isinstance(item, str): - return {"role": "assistant", "content": item} - else: - return item - - def read(self, batch_size: Optional[int] = None) -> List: - batch_data = self.dataset.read_batch(batch_size or self.read_batch_size) - exp_list = [] - for sample in batch_data: - experience = self.formatter.format(sample) - exp_list.append(experience) - return exp_list - - -@FILE_READERS.register_module("rollout") -class RolloutDataReader(BaseFileReader): +class TaskFileReader(BaseFileReader): def __init__(self, meta: StorageConfig, config: BufferConfig): self.meta = meta self.name = meta.name @@ -203,61 +141,24 @@ def __init__(self, meta: StorageConfig, config: BufferConfig): load_dataset(meta.path, name=subset_name, split=self.split), name=meta.name, default_batch_size=self.read_batch_size, - total_epochs=self.meta.total_epochs if meta.task_type == TaskType.EXPLORE else 1, + total_epochs=self.meta.total_epochs if not self.meta.is_eval else 1, offset=self.meta.index, - drop_last=self.meta.task_type == TaskType.EXPLORE, + drop_last=not self.meta.is_eval, total_steps=meta.total_steps, enable_progress_bar=meta.enable_progress_bar, ) - self.prompt_key = meta.format.prompt_key - self.response_key = meta.format.response_key - self.workflow_key = meta.format.workflow_key - self.reward_fn_key = meta.format.reward_fn_key - - self.task_type = meta.task_type - self.default_workflow_cls = WORKFLOWS.get(meta.default_workflow_type) # type: ignore - self.default_eval_workflow_cls = None - if getattr(meta, "default_eval_workflow_type", None): - self.default_eval_workflow_cls = WORKFLOWS.get(meta.default_eval_workflow_type) - self.default_reward_fn_cls = REWARD_FUNCTIONS.get(meta.default_reward_fn_type) # type: ignore + self.formatter = FORMATTER.get("task")(meta) def read(self, batch_size: Optional[int] = None) -> List: batch_size = batch_size or self.read_batch_size tasks = [] samples = 
self.dataset.read_batch(batch_size) for sample in samples: - if self.task_type == TaskType.EVAL and self.default_eval_workflow_cls: - workflow_class = self.default_eval_workflow_cls - else: - workflow_class = ( - WORKFLOWS.get(sample[self.workflow_key]) - if self.workflow_key in sample - else self.default_workflow_cls - ) - reward_fn = ( - REWARD_FUNCTIONS.get(sample[self.reward_fn_key]) - if self.reward_fn_key in sample - else self.default_reward_fn_cls - ) - assert ( - workflow_class is not None - ), "`default_workflow_type` or `workflow_key` is required" - task = Task( - workflow=workflow_class, - repeat_times=self.meta.repeat_times, - format_args=copy.deepcopy(self.meta.format), - rollout_args=copy.deepcopy(self.meta.rollout_args), - workflow_args=copy.deepcopy(self.meta.workflow_args), - reward_fn_args=copy.deepcopy(self.meta.reward_fn_args), - is_eval=self.meta.task_type == TaskType.EVAL, - reward_fn=reward_fn, - raw_task=sample, - ) + task = self.formatter.format(sample) tasks.append(task) return tasks -@FILE_READERS.register_module("raw") class RawDataReader(BaseFileReader): def __init__(self, meta: StorageConfig, config: Optional[BufferConfig]): self.returned = False diff --git a/trinity/buffer/reader/queue_reader.py b/trinity/buffer/reader/queue_reader.py index ad5ad1165f..e4ea695a21 100644 --- a/trinity/buffer/reader/queue_reader.py +++ b/trinity/buffer/reader/queue_reader.py @@ -5,7 +5,7 @@ import ray from trinity.buffer.buffer_reader import BufferReader -from trinity.buffer.ray_wrapper import QueueWrapper +from trinity.buffer.storage.queue import QueueStorage from trinity.common.config import BufferConfig, StorageConfig from trinity.common.constants import StorageType @@ -17,7 +17,7 @@ def __init__(self, storage_config: StorageConfig, config: BufferConfig): assert storage_config.storage_type == StorageType.QUEUE self.timeout = storage_config.max_read_timeout self.read_batch_size = config.train_batch_size - self.queue = 
QueueWrapper.get_wrapper(storage_config, config) + self.queue = QueueStorage.get_wrapper(storage_config, config) def read(self, batch_size: Optional[int] = None) -> List: try: diff --git a/trinity/buffer/reader/sql_reader.py b/trinity/buffer/reader/sql_reader.py index e715557957..7e45b842d2 100644 --- a/trinity/buffer/reader/sql_reader.py +++ b/trinity/buffer/reader/sql_reader.py @@ -5,7 +5,7 @@ import ray from trinity.buffer.buffer_reader import BufferReader -from trinity.buffer.ray_wrapper import DBWrapper +from trinity.buffer.storage.sql import SQLStorage from trinity.common.config import BufferConfig, StorageConfig from trinity.common.constants import StorageType @@ -16,16 +16,19 @@ class SQLReader(BufferReader): def __init__(self, meta: StorageConfig, config: BufferConfig) -> None: assert meta.storage_type == StorageType.SQL self.wrap_in_ray = meta.wrap_in_ray - self.db_wrapper = DBWrapper.get_wrapper(meta, config) + self.storage = SQLStorage.get_wrapper(meta, config) def read(self, batch_size: Optional[int] = None) -> List: if self.wrap_in_ray: - return ray.get(self.db_wrapper.read.remote(batch_size)) + return ray.get(self.storage.read.remote(batch_size)) else: - return self.db_wrapper.read(batch_size) + return self.storage.read(batch_size) async def read_async(self, batch_size: Optional[int] = None) -> List: if self.wrap_in_ray: - return await self.db_wrapper.read.remote(batch_size) + try: + return ray.get(self.storage.read.remote(batch_size)) + except StopIteration: + raise StopAsyncIteration else: - return self.db_wrapper.read(batch_size) + return self.storage.read(batch_size) diff --git a/trinity/buffer/schema/__init__.py b/trinity/buffer/schema/__init__.py index 269c524bdf..5fdf581dbc 100644 --- a/trinity/buffer/schema/__init__.py +++ b/trinity/buffer/schema/__init__.py @@ -1,3 +1,4 @@ -from .sql_schema import Base, create_dynamic_table +from trinity.buffer.schema.formatter import FORMATTER +from trinity.buffer.schema.sql_schema import init_engine 
-__all__ = ["create_dynamic_table", "Base"] +__all__ = ["init_engine", "FORMATTER"] diff --git a/trinity/buffer/schema/formatter.py b/trinity/buffer/schema/formatter.py index 7285909e07..4304da82ed 100644 --- a/trinity/buffer/schema/formatter.py +++ b/trinity/buffer/schema/formatter.py @@ -1,39 +1,79 @@ from abc import ABC, abstractmethod from typing import Dict, List, Optional -from trinity.common.config import FormatConfig +from trinity.common.config import FormatConfig, StorageConfig +from trinity.common.constants import PromptType from trinity.common.experience import Experience +from trinity.common.rewards import REWARD_FUNCTIONS +from trinity.common.workflows import WORKFLOWS, Task +from trinity.utils.registry import Registry +FORMATTER = Registry("formatter") -class Formatter(ABC): - @abstractmethod - def __init__(self, format_config: FormatConfig): - """Initialize the formatter with the given configuration.""" +class ExperienceFormatter(ABC): @abstractmethod def format(self, sample: Dict) -> Experience: """Format a raw sample dict into an experience.""" -def format_messages( - tokenizer, messages: List[Dict], tools: Optional[List[Dict]] = None -) -> Experience: - tokens = tokenizer.apply_chat_template( - messages, tools=tools, add_generation_prompt=False, return_tensors="pt" - )[0] - prompt_tokens_ids = tokenizer.apply_chat_template( - messages[:-1], tools=tools, add_generation_prompt=True, return_tensors="pt" - )[0] - return Experience( - tokens=tokens, - prompt_length=len(prompt_tokens_ids), - ) +@FORMATTER.register_module("task") +class TaskFormatter: + """Formatter for task data. + Example Input: -class SFTMessagesFormatter(Formatter): - """Formatter for SFT message list data. + .. 
code-block:: python - Example Input: + { + "input": "Hello", + "output": "Hi" + } + """ + + def __init__(self, config: StorageConfig): + self.config = config + self.is_eval = config.is_eval + self.default_workflow_cls = WORKFLOWS.get(config.default_workflow_type) # type: ignore + if self.is_eval and config.default_eval_workflow_type: + self.default_workflow_cls = WORKFLOWS.get(config.default_eval_workflow_type) + self.default_reward_fn_cls = REWARD_FUNCTIONS.get(config.default_reward_fn_type) # type: ignore + self.workflow_key = config.format.workflow_key + self.reward_fn_key = config.format.reward_fn_key + + def format(self, sample: Dict) -> Task: + """Format a raw sample dict into a Task.""" + + workflow_name = sample.get(self.workflow_key, None) if self.workflow_key else None + reward_fn_name = sample.get(self.reward_fn_key, None) if self.reward_fn_key else None + + workflow_cls = ( + WORKFLOWS.get(workflow_name) if workflow_name else None + ) or self.default_workflow_cls + reward_fn_cls = ( + REWARD_FUNCTIONS.get(reward_fn_name) if reward_fn_name else None + ) or self.default_reward_fn_cls + assert workflow_cls is not None, "`default_workflow_type` or `workflow_key` is required" + return Task( + workflow=workflow_cls, + reward_fn=reward_fn_cls, + format_args=self.config.format, + repeat_times=self.config.repeat_times, + rollout_args=self.config.rollout_args, + workflow_args=self.config.workflow_args, + reward_fn_args=self.config.reward_fn_args, + is_eval=self.is_eval, + raw_task=sample, + ) + + +@FORMATTER.register_module("sft") +class SFTFormatter(ExperienceFormatter): + """Formatter for SFT data, supporting both message list and plaintext formats. + + Uses format_config.prompt_type to distinguish between 'messages' and 'plaintext'. + + Example input of MESSAGES: .. 
code-block:: python @@ -43,25 +83,9 @@ class SFTMessagesFormatter(Formatter): {"role": "assistant", "content": "I'm fine, thank you!"} ] } - """ - - def __init__(self, tokenizer, format_config: FormatConfig): - self.tokenizer = tokenizer - self.messages_key = format_config.messages_key - self.tools_key = format_config.tools_key - - def format(self, sample: Dict) -> Experience: - """Format a raw sample dict into an experience.""" - messages = sample[self.messages_key] - tools = sample.get(self.tools_key, None) - return format_messages(self.tokenizer, messages, tools) - -class SFTPlaintextFormatter(Formatter): - """Formatter for SFT plaintext data. - - Example Input: + Example input of PLAINTEXT: .. code-block:: python @@ -74,54 +98,60 @@ class SFTPlaintextFormatter(Formatter): def __init__(self, tokenizer, format_config: FormatConfig): self.tokenizer = tokenizer - self.prompt_key = format_config.prompt_key - self.response_key = format_config.response_key - self.system_prompt_key = format_config.system_prompt_key - self.system_prompt = format_config.system_prompt - self.tools_key = format_config.tools_key + self.prompt_type = format_config.prompt_type + # For messages type + if self.prompt_type == PromptType.MESSAGES: + self.messages_key = format_config.messages_key + self.tools_key = format_config.tools_key + # For plaintext type + elif self.prompt_type == PromptType.PLAINTEXT: + self.prompt_key = format_config.prompt_key + self.response_key = format_config.response_key + self.system_prompt_key = format_config.system_prompt_key + self.system_prompt = format_config.system_prompt + self.tools_key = format_config.tools_key + else: + raise ValueError(f"Unsupported prompt_type: {self.prompt_type}") + + def _messages_to_experience( + self, messages: List[Dict], tools: Optional[List[Dict]] = None + ) -> Experience: + tokens = self.tokenizer.apply_chat_template( + messages, tools=tools, add_generation_prompt=False, return_tensors="pt" + )[0] + prompt_tokens_ids = 
self.tokenizer.apply_chat_template( + messages[:-1], tools=tools, add_generation_prompt=True, return_tensors="pt" + )[0] + return Experience( + tokens=tokens, + prompt_length=len(prompt_tokens_ids), + messages=messages, + ) def format(self, sample: Dict) -> Experience: - """Format a raw sample dict into an experience.""" - messages = [] - if self.system_prompt_key is not None: - system_message = {"role": "system", "content": sample[self.system_prompt_key]} - messages.append(system_message) - elif self.system_prompt is not None: - system_message = {"role": "system", "content": self.system_prompt} - messages.append(system_message) - messages.append({"role": "user", "content": sample[self.prompt_key]}) - messages.append({"role": "assistant", "content": sample[self.response_key]}) - - return format_messages(self.tokenizer, messages, sample.get(self.tools_key, None)) - - -def format_dpo_messages( - tokenizer, - prompt_messages: List[Dict], - chosen_messages: List[Dict], - reject_messages: List[Dict], -): - prompt_tokens = tokenizer.apply_chat_template( - prompt_messages, add_generation_prompt=True, return_tensors="pt" - )[0] - chosen_tokens = tokenizer.apply_chat_template( - prompt_messages + chosen_messages, add_generation_prompt=False, return_tensors="pt" - )[0][len(prompt_tokens) :] - reject_tokens = tokenizer.apply_chat_template( - prompt_messages + reject_messages, add_generation_prompt=False, return_tensors="pt" - )[0][len(prompt_tokens) :] - return Experience( - tokens=prompt_tokens, - prompt_length=len(prompt_tokens), - chosen=chosen_tokens, - rejected=reject_tokens, - ) - - -class DPOPlaintextFormatter(Formatter): + if self.prompt_type == PromptType.MESSAGES: + messages = sample[self.messages_key] + elif self.prompt_type == PromptType.PLAINTEXT: + messages = [] + if self.system_prompt_key is not None: + system_message = {"role": "system", "content": sample[self.system_prompt_key]} + messages.append(system_message) + elif self.system_prompt is not None: + 
system_message = {"role": "system", "content": self.system_prompt} + messages.append(system_message) + messages.append({"role": "user", "content": sample[self.prompt_key]}) + messages.append({"role": "assistant", "content": sample[self.response_key]}) + else: + raise ValueError(f"Unsupported prompt_type: {self.prompt_type}") + tools = sample.get(self.tools_key, None) + return self._messages_to_experience(messages, tools) + + +@FORMATTER.register_module("dpo") +class DPOFormatter(ExperienceFormatter): """Formatter for DPO plaintext data. - Example Input: + Example Input for PLAINTEXT: .. code-block:: python @@ -130,38 +160,8 @@ class DPOPlaintextFormatter(Formatter): "chosen": "My name is Assistant.", "rejected": "I don't have a name." } - """ - def __init__(self, tokenizer, format_config: FormatConfig): - self.tokenizer = tokenizer - self.prompt_key = format_config.prompt_key - self.chosen_key = format_config.chosen_key - self.rejected_key = format_config.rejected_key - self.system_prompt_key = format_config.system_prompt_key - self.system_prompt = format_config.system_prompt - # currently DPO not support tools - - def format(self, sample: Dict) -> Experience: - messages = [] - if self.system_prompt_key is not None: - system_message = {"role": "system", "content": sample[self.system_prompt_key]} - messages.append(system_message) - elif self.system_prompt is not None: - system_message = {"role": "system", "content": self.system_prompt} - messages.append(system_message) - messages.append({"role": "user", "content": sample[self.prompt_key]}) - return format_dpo_messages( - tokenizer=self.tokenizer, - prompt_messages=messages, - chosen_messages=[{"role": "assistant", "content": sample[self.chosen_key]}], - reject_messages=[{"role": "assistant", "content": sample[self.rejected_key]}], - ) - - -class DPOMessagesFormatter(Formatter): - """Formatter for DPO message list data. - - Example Input: + Example Input for MESSAGES: .. 
code-block:: python @@ -180,12 +180,62 @@ class DPOMessagesFormatter(Formatter): def __init__(self, tokenizer, format_config: FormatConfig): self.tokenizer = tokenizer - self.messages_key = format_config.messages_key - self.chosen_key = format_config.chosen_key - self.rejected_key = format_config.rejected_key + self.prompt_type = format_config.prompt_type + if self.prompt_type == PromptType.PLAINTEXT: + self.prompt_key = format_config.prompt_key + self.chosen_key = format_config.chosen_key + self.rejected_key = format_config.rejected_key + self.system_prompt_key = format_config.system_prompt_key + self.system_prompt = format_config.system_prompt + elif self.prompt_type == PromptType.MESSAGES: + self.messages_key = format_config.messages_key + self.chosen_key = format_config.chosen_key + self.rejected_key = format_config.rejected_key + else: + raise ValueError(f"Unsupported prompt_type: {self.prompt_type}") + # currently DPO not support tools + + def _messages_to_experience( + self, prompt_messages, chosen_messages, rejected_messages + ) -> Experience: + prompt_tokens = self.tokenizer.apply_chat_template( + prompt_messages, add_generation_prompt=True, return_tensors="pt" + )[0] + chosen_tokens = self.tokenizer.apply_chat_template( + prompt_messages + chosen_messages, add_generation_prompt=False, return_tensors="pt" + )[0][len(prompt_tokens) :] + rejected_tokens = self.tokenizer.apply_chat_template( + prompt_messages + rejected_messages, add_generation_prompt=False, return_tensors="pt" + )[0][len(prompt_tokens) :] + return Experience( + tokens=prompt_tokens, + prompt_length=len(prompt_tokens), + chosen=chosen_tokens, + rejected=rejected_tokens, + chosen_messages=prompt_messages + chosen_messages, + rejected_messages=prompt_messages + rejected_messages, + ) def format(self, sample: Dict) -> Experience: - messages = sample[self.messages_key] - chosen = sample[self.chosen_key] - rejected = sample[self.rejected_key] - return format_dpo_messages(self.tokenizer, messages, 
chosen, rejected) + if self.prompt_type == PromptType.PLAINTEXT: + messages = [] + if self.system_prompt_key is not None: + system_message = {"role": "system", "content": sample[self.system_prompt_key]} + messages.append(system_message) + elif self.system_prompt is not None: + system_message = {"role": "system", "content": self.system_prompt} + messages.append(system_message) + messages.append({"role": "user", "content": sample[self.prompt_key]}) + chosen = [{"role": "assistant", "content": sample[self.chosen_key]}] + rejected = [{"role": "assistant", "content": sample[self.rejected_key]}] + elif self.prompt_type == PromptType.MESSAGES: + messages = sample[self.messages_key] + chosen = sample[self.chosen_key] + rejected = sample[self.rejected_key] + else: + raise ValueError(f"Unsupported prompt_type: {self.prompt_type}") + return self._messages_to_experience( + prompt_messages=messages, + chosen_messages=chosen, + rejected_messages=rejected, + ) diff --git a/trinity/buffer/schema/sql_schema.py b/trinity/buffer/schema/sql_schema.py index 61149b8b25..74a9135b7f 100644 --- a/trinity/buffer/schema/sql_schema.py +++ b/trinity/buffer/schema/sql_schema.py @@ -1,142 +1,135 @@ -"""Schema for SQLAlchemy models.""" +"""SQLAlchemy models for different data.""" -from typing import Any, Optional, Union +from typing import Dict, Optional, Tuple -from sqlalchemy import Column, Float, Integer, LargeBinary, String -from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy import JSON, Column, Float, Integer, LargeBinary, Text, create_engine +from sqlalchemy.exc import OperationalError +from sqlalchemy.orm import declarative_base +from sqlalchemy.pool import NullPool from trinity.common.experience import Experience +from trinity.utils.log import get_logger +from trinity.utils.registry import Registry + +SQL_SCHEMA = Registry("sql_schema") Base = declarative_base() +@SQL_SCHEMA.register_module("task") class TaskModel(Base): # type: ignore """Model for storing tasks in 
SQLAlchemy.""" __abstract__ = True - __table_args__ = { - "keep_existing": True, - } - id = Column(Integer, primary_key=True, autoincrement=True) - task_desc = Column(String, nullable=True) - workflow_type = Column(String, nullable=True) - reward_type = Column(String, nullable=True) + raw_task = Column(JSON, nullable=False) + + @classmethod + def from_dict(cls, dict: Dict): + return cls(raw_task=dict) +@SQL_SCHEMA.register_module("experience") class ExperienceModel(Base): # type: ignore """SQLAlchemy model for Experience.""" __abstract__ = True - __table_args__ = { - "keep_existing": True, - } - id = Column(Integer, primary_key=True, autoincrement=True) - serialized_exp = Column(LargeBinary, nullable=True) - prompt = Column(String, nullable=True) - response = Column(String, nullable=True) + # for single turn + prompt = Column(Text, nullable=True) + response = Column(Text, nullable=True) + # for multi turn + message_list = Column(JSON, nullable=True) reward = Column(Float, nullable=True) - consumed = Column(Integer, default=0) - priority = Column(Float, default=0.0) + # serialized experience object + experience_bytes = Column(LargeBinary, nullable=True) def to_experience(self) -> Experience: """Load the experience from the database.""" - return Experience.deserialize(self.serialized_exp) + return Experience.deserialize(self.experience_bytes) @classmethod def from_experience(cls, experience: Experience): """Save the experience to database.""" return cls( - serialized_exp=experience.serialize(), + experience_bytes=experience.serialize(), reward=experience.reward, prompt=experience.prompt_text, response=experience.response_text, + message_list=experience.messages, ) +@SQL_SCHEMA.register_module("sft") class SFTDataModel(Base): # type: ignore """SQLAlchemy model for SFT data.""" __abstract__ = True - __table_args__ = { - "keep_existing": True, - } - id = Column(Integer, primary_key=True, autoincrement=True) - serialized_exp = Column(LargeBinary, nullable=True) - 
messages = Column(String, nullable=True) - consumed = Column(Integer, default=0) + message_list = Column(JSON, nullable=True) + experience_bytes = Column(LargeBinary, nullable=True) def to_experience(self) -> Experience: """Load the experience from the database.""" - return Experience.deserialize(self.serialized_exp) + return Experience.deserialize(self.experience_bytes) @classmethod - def from_messages( - cls, - messages: list[dict], - tokenizer: Any, - chat_template: Optional[str] = None, - ) -> "SFTDataModel": - """Convert a list of messages into a single instance of SFT data.""" - from trinity.common.models.utils import tokenize_and_mask_messages_hf - - tokens, action_mask, prompt_length = tokenize_and_mask_messages_hf( - tokenizer=tokenizer, - messages=messages, - chat_template=chat_template, - ) - exp = Experience( - tokens=tokens, - action_mask=action_mask, - prompt_length=prompt_length, - info={"response_num": sum([1 if m["role"] == "assistant" else 0 for m in messages])}, - ) + def from_experience(cls, experience: Experience): + """Save the experience to database.""" return cls( - serialized_exp=exp.serialize(), - messages=messages, + experience_bytes=experience.serialize(), + message_list=experience.messages, ) +@SQL_SCHEMA.register_module("dpo") class DPODataModel(Base): # type: ignore """SQLAlchemy model for DPO data.""" __abstract__ = True - __table_args__ = { - "keep_existing": True, - } - id = Column(Integer, primary_key=True, autoincrement=True) - serialized_exp = Column(LargeBinary, nullable=True) - chosen = Column(LargeBinary, nullable=True) - rejected = Column(LargeBinary, nullable=True) - consumed = Column(Integer, default=0) + chosen_message_list = Column(JSON, nullable=True) + rejected_message_list = Column(JSON, nullable=True) + experience_bytes = Column(LargeBinary, nullable=True) def to_experience(self) -> Experience: """Load the experience from the database.""" - exp = Experience.deserialize(self.serialized_exp) - exp.chosen = 
Experience.deserialize(self.chosen) - exp.rejected = Experience.deserialize(self.rejected) - return exp + return Experience.deserialize(self.experience_bytes) + + @classmethod + def from_experience(cls, experience: Experience): + """Save the experience to database.""" + return cls( + experience_bytes=experience.serialize(), + chosen_message_list=experience.chosen_messages, + rejected_message_list=experience.rejected_messages, + ) -def create_dynamic_table(algorithm_type: Union[str | None], table_name: str) -> Any: - """Create a dynamic table based on the provided algorithm type and table name.""" - if algorithm_type is None: - base_class = TaskModel - else: - from trinity.algorithm.algorithm import ALGORITHM_TYPE +def init_engine(db_url: str, table_name, schema_type: Optional[str]) -> Tuple: + """Get the sqlalchemy engine.""" + logger = get_logger(__name__) + engine = create_engine(db_url, poolclass=NullPool) - algorithm = ALGORITHM_TYPE.get(algorithm_type) - base_class = algorithm.schema + if schema_type is None: + schema_type = "task" + + base_class = SQL_SCHEMA.get(schema_type) table_attrs = { "__tablename__": table_name, + "__abstract__": False, + "__table_args__": {"keep_existing": True}, } + table_cls = type(table_name, (base_class,), table_attrs) + + try: + Base.metadata.create_all(engine, checkfirst=True) + except OperationalError: + logger.warning(f"Failed to create table {table_name}, assuming it already exists.") - return type(table_name, (base_class,), table_attrs) + return engine, table_cls diff --git a/trinity/buffer/storage/__init__.py b/trinity/buffer/storage/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/trinity/buffer/storage/file.py b/trinity/buffer/storage/file.py new file mode 100644 index 0000000000..6f393cc48f --- /dev/null +++ b/trinity/buffer/storage/file.py @@ -0,0 +1,84 @@ +"""File Storage""" +import json +import os +from typing import List + +import ray + +from trinity.buffer.utils import default_storage_path 
+from trinity.common.config import BufferConfig, StorageConfig +from trinity.common.experience import EID, Experience +from trinity.common.workflows import Task + + +class _Encoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, Experience): + return o.to_dict() + if isinstance(o, Task): + return o.to_dict() + if isinstance(o, EID): + return o.to_dict() + return super().default(o) + + +class FileStorage: + """ + A wrapper of a local jsonl file. + + If `wrap_in_ray` in `StorageConfig` is `True`, this class will be run as + a Ray Actor, and provide a remote interface to the local file. + + This wrapper is only for writing, if you want to read from the file, use + StorageType.QUEUE instead. + """ + + def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None: + if storage_config.path is None: + storage_config.path = default_storage_path(storage_config, config) + ext = os.path.splitext(storage_config.path)[-1] + if ext != ".jsonl" and ext != ".json": + raise ValueError( + f"File path must end with '.json' or '.jsonl', got {storage_config.path}" + ) + path_dir = os.path.dirname(os.path.abspath(storage_config.path)) + os.makedirs(path_dir, exist_ok=True) + self.file = open(storage_config.path, "a", encoding="utf-8") + self.encoder = _Encoder(ensure_ascii=False) + self.ref_count = 0 + + @classmethod + def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig): + if storage_config.wrap_in_ray: + return ( + ray.remote(cls) + .options( + name=f"json-{storage_config.name}", + namespace=storage_config.ray_namespace or ray.get_runtime_context().namespace, + get_if_exists=True, + ) + .remote(storage_config, config) + ) + else: + return cls(storage_config, config) + + def write(self, data: List) -> None: + for item in data: + json_str = self.encoder.encode(item) + self.file.write(json_str + "\n") + self.file.flush() + + def read(self) -> List: + raise NotImplementedError( + "read() is not implemented for FILE Storage, please use 
QUEUE instead" + ) + + def acquire(self) -> int: + self.ref_count += 1 + return self.ref_count + + def release(self) -> int: + self.ref_count -= 1 + if self.ref_count <= 0: + self.file.close() + return self.ref_count diff --git a/trinity/buffer/queue.py b/trinity/buffer/storage/queue.py similarity index 63% rename from trinity/buffer/queue.py rename to trinity/buffer/storage/queue.py index a43c4e82ca..71cebd9d27 100644 --- a/trinity/buffer/queue.py +++ b/trinity/buffer/storage/queue.py @@ -1,17 +1,30 @@ -"""Implementation of async queue buffers.""" +"""Ray Queue storage""" import asyncio +import time from abc import ABC, abstractmethod from collections import deque +from copy import deepcopy from functools import partial from typing import List, Optional +import ray from sortedcontainers import SortedDict from trinity.common.config import BufferConfig, StorageConfig +from trinity.common.constants import StorageType from trinity.common.experience import Experience from trinity.utils.log import get_logger from trinity.utils.registry import Registry + +def is_database_url(path: str) -> bool: + return any(path.startswith(prefix) for prefix in ["sqlite:///", "postgresql://", "mysql://"]) + + +def is_json_file(path: str) -> bool: + return path.endswith(".json") or path.endswith(".jsonl") + + PRIORITY_FUNC = Registry("priority_fn") @@ -194,3 +207,95 @@ async def close(self) -> None: def stopped(self) -> bool: return self._closed and len(self.priority_groups) == 0 + + +class QueueStorage: + """An wrapper of a async queue.""" + + def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None: + self.logger = get_logger(f"queue_{storage_config.name}", in_ray_actor=True) + self.config = config + self.capacity = storage_config.capacity + self.queue = QueueBuffer.get_queue(storage_config, config) + st_config = deepcopy(storage_config) + st_config.wrap_in_ray = False + if st_config.path is not None: + if is_database_url(st_config.path): + from 
trinity.buffer.writer.sql_writer import SQLWriter + + st_config.storage_type = StorageType.SQL + self.writer = SQLWriter(st_config, self.config) + elif is_json_file(st_config.path): + from trinity.buffer.writer.file_writer import JSONWriter + + st_config.storage_type = StorageType.FILE + self.writer = JSONWriter(st_config, self.config) + else: + self.logger.warning("Unknown supported storage path: %s", st_config.path) + self.writer = None + else: + from trinity.buffer.writer.file_writer import JSONWriter + + st_config.storage_type = StorageType.FILE + self.writer = JSONWriter(st_config, self.config) + self.logger.warning(f"Save experiences in {st_config.path}.") + self.ref_count = 0 + self.exp_pool = deque() # A pool to store experiences + self.closed = False + + async def acquire(self) -> int: + self.ref_count += 1 + return self.ref_count + + async def release(self) -> int: + """Release the queue.""" + self.ref_count -= 1 + if self.ref_count <= 0: + await self.queue.close() + await self.writer.release() + return self.ref_count + + def length(self) -> int: + """The length of the queue.""" + return self.queue.qsize() + + async def put_batch(self, exp_list: List) -> None: + """Put batch of experience.""" + await self.queue.put(exp_list) + if self.writer is not None: + self.writer.write(exp_list) + + async def get_batch(self, batch_size: int, timeout: float) -> List: + """Get batch of experience.""" + start_time = time.time() + while len(self.exp_pool) < batch_size: + if self.queue.stopped(): + # If the queue is stopped, ignore the rest of the experiences in the pool + raise StopAsyncIteration("Queue is closed and no more items to get.") + try: + exp_list = await asyncio.wait_for(self.queue.get(), timeout=1.0) + self.exp_pool.extend(exp_list) + except asyncio.TimeoutError: + if time.time() - start_time > timeout: + self.logger.error( + f"Timeout when waiting for experience, only get {len(self.exp_pool)} experiences.\n" + "This phenomenon is usually caused by the 
workflow not returning enough " + "experiences or running timeout. Please check your workflow implementation." + ) + batch = list(self.exp_pool) + self.exp_pool.clear() + return batch + return [self.exp_pool.popleft() for _ in range(batch_size)] + + @classmethod + def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig): + """Get the queue actor.""" + return ( + ray.remote(cls) + .options( + name=f"queue-{storage_config.name}", + namespace=storage_config.ray_namespace or ray.get_runtime_context().namespace, + get_if_exists=True, + ) + .remote(storage_config, config) + ) diff --git a/trinity/buffer/storage/sql.py b/trinity/buffer/storage/sql.py new file mode 100644 index 0000000000..cbbdce4c53 --- /dev/null +++ b/trinity/buffer/storage/sql.py @@ -0,0 +1,223 @@ +"""SQL database storage""" + +import time +from abc import abstractmethod +from typing import Dict, List, Optional + +import ray +from datasets import Dataset +from sqlalchemy import asc +from sqlalchemy.orm import sessionmaker + +from trinity.buffer.schema import init_engine +from trinity.buffer.schema.formatter import FORMATTER, TaskFormatter +from trinity.buffer.utils import default_storage_path, retry_session +from trinity.common.config import BufferConfig, StorageConfig +from trinity.common.experience import Experience +from trinity.common.rewards import REWARD_FUNCTIONS +from trinity.common.workflows import WORKFLOWS, Task +from trinity.utils.log import get_logger + + +class SQLStorage: + """ + An Storage based on SQL Database. + + If `wrap_in_ray` in `StorageConfig` is `True`, this class will be run as a Ray Actor, + and provide a remote interface to the local database. + + For databases that do not support multi-processing read/write (e.g. sqlite, duckdb), please + set `wrap_in_ray` to `True`. 
+ """ + + def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None: + self.logger = get_logger(f"sql_{storage_config.name}", in_ray_actor=True) + if storage_config.path is None: + storage_config.path = default_storage_path(storage_config, config) + self.engine, self.table_model_cls = init_engine( + db_url=storage_config.path, + table_name=storage_config.name, + schema_type=storage_config.schema_type, + ) + self.logger.info(f"Init SQL storage at {storage_config.path}") + self.session = sessionmaker(bind=self.engine) + self.max_retry_times = storage_config.max_retry_times + self.max_retry_interval = storage_config.max_retry_interval + self.ref_count = 0 + self.stopped = False + # Assume that the auto-increment ID starts counting from 1, so the default offset should be 0. + self.offset = storage_config.index + + @classmethod + def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig): + if storage_config.schema_type is None: + storage_cls = SQLTaskStorage + else: + storage_cls = SQLExperienceStorage + if storage_config.wrap_in_ray: + return ( + ray.remote(storage_cls) + .options( + name=f"sql-{storage_config.name}", + namespace=storage_config.ray_namespace or ray.get_runtime_context().namespace, + get_if_exists=True, + max_concurrency=5, + ) + .remote(storage_config, config) + ) + else: + return storage_cls(storage_config, config) + + @abstractmethod + def write(self, data: List) -> None: + """Write a batch of data.""" + + @abstractmethod + def read(self, batch_size: Optional[int] = None) -> List: + """Read a batch of data.""" + + def acquire(self) -> int: + self.ref_count += 1 + return self.ref_count + + def release(self) -> int: + self.ref_count -= 1 + if self.ref_count <= 0: + self.stopped = True + return self.ref_count + + +class SQLExperienceStorage(SQLStorage): + def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None: + super().__init__(storage_config, config) + self.batch_size = 
config.train_batch_size + self.max_timeout = storage_config.max_read_timeout + + def write(self, data: List[Experience]) -> None: + with retry_session(self.session, self.max_retry_times, self.max_retry_interval) as session: + experience_models = [self.table_model_cls.from_experience(exp) for exp in data] + session.add_all(experience_models) + + def read(self, batch_size: Optional[int] = None) -> List[Experience]: + if self.stopped: + raise StopIteration() + + exp_list = [] + batch_size = batch_size or self.batch_size # type: ignore + start_time = time.time() + while len(exp_list) < batch_size: + if self.stopped: + raise StopIteration() + if len(exp_list): + self.logger.info(f"Waiting for {batch_size - len(exp_list)} more experiences...") + time.sleep(1) + if time.time() - start_time > self.max_timeout: + self.logger.warning( + f"Max read timeout reached ({self.max_timeout} s), only get {len(exp_list)} experiences, stopping..." + ) + raise StopIteration() + with retry_session( + self.session, self.max_retry_times, self.max_retry_interval + ) as session: + # get a batch of experiences from the database + experiences = ( + session.query(self.table_model_cls) + .filter(self.table_model_cls.id > self.offset) + .order_by(asc(self.table_model_cls.id)) + .limit(batch_size - len(exp_list)) + .all() + ) + if experiences: + self.offset = experiences[-1].id + start_time = time.time() + exp_list.extend([self.table_model_cls.to_experience(exp) for exp in experiences]) + return exp_list + + @classmethod + def load_from_dataset( + cls, dataset: Dataset, storage_config: StorageConfig, config: BufferConfig + ) -> "SQLExperienceStorage": + import transformers + + tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path) + storage = cls( + storage_config=storage_config, + config=config, + ) + formatter = FORMATTER.get(storage_config.schema_type)(tokenizer, storage_config.format) + batch_size = storage.batch_size + batch = [] + for item in dataset: + 
batch.append(formatter.format(item)) + if len(batch) >= batch_size: + storage.write(batch) + batch.clear() + if batch: + storage.write(batch) + return storage + + +class SQLTaskStorage(SQLStorage): + def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None: + super().__init__(storage_config, config) + self.batch_size = config.batch_size + self.is_eval = storage_config.is_eval + self.default_workflow_cls = WORKFLOWS.get(storage_config.default_workflow_type) # type: ignore + if self.is_eval and storage_config.default_eval_workflow_type: + self.default_workflow_cls = WORKFLOWS.get(storage_config.default_eval_workflow_type) + self.default_reward_fn_cls = REWARD_FUNCTIONS.get(storage_config.default_reward_fn_type) # type: ignore + self.formatter = TaskFormatter(storage_config) + self.offset = storage_config.index + if storage_config.total_steps: + self.total_samples = self.batch_size * storage_config.total_steps + else: + if storage_config.total_epochs > 1: + self.logger.warning( + f"SQL Storage do not support total_epochs, the value {storage_config.total_epochs} will be ignored" + ) + self.total_samples = float("inf") + + def write(self, data: List[Dict]) -> None: + with retry_session(self.session, self.max_retry_times, self.max_retry_interval) as session: + tasks = [self.table_model_cls.from_dict(item) for item in data] + session.add_all(tasks) + + def read(self, batch_size: Optional[int] = None) -> List[Task]: + if self.stopped: + raise StopIteration() + if self.offset > self.total_samples: + raise StopIteration() + batch_size = batch_size or self.batch_size + with retry_session(self.session, self.max_retry_times, self.max_retry_interval) as session: + query = ( + session.query(self.table_model_cls) + .filter(self.table_model_cls.id > self.offset) + .order_by(asc(self.table_model_cls.id)) + .limit(batch_size) + ) + results = query.all() + if len(results) == 0: + raise StopIteration() + if not self.is_eval and len(results) < batch_size: + raise 
StopIteration() + self.offset = results[-1].id + return [self.formatter.format(item.raw_task) for item in results] + + @classmethod + def load_from_dataset( + cls, dataset: Dataset, storage_config: StorageConfig, config: BufferConfig + ) -> "SQLTaskStorage": + storage = cls( + storage_config=storage_config, + config=config, + ) + batch_size = config.batch_size + batch = [] + for item in dataset: + batch.append(item) + if len(batch) >= batch_size: + storage.write(batch) + batch.clear() + if batch: + storage.write(batch) + return storage diff --git a/trinity/buffer/utils.py b/trinity/buffer/utils.py index aa2c0e4849..7db8dde867 100644 --- a/trinity/buffer/utils.py +++ b/trinity/buffer/utils.py @@ -6,18 +6,19 @@ from trinity.common.constants import StorageType from trinity.utils.log import get_logger -logger = get_logger(__name__) - @contextmanager def retry_session(session_maker, max_retry_times: int, max_retry_interval: float): """A Context manager for retrying session.""" + logger = get_logger(__name__) for attempt in range(max_retry_times): try: session = session_maker() yield session session.commit() break + except StopIteration as e: + raise e except Exception as e: import traceback diff --git a/trinity/buffer/writer/file_writer.py b/trinity/buffer/writer/file_writer.py index 93c10479ca..5a579bf59c 100644 --- a/trinity/buffer/writer/file_writer.py +++ b/trinity/buffer/writer/file_writer.py @@ -3,7 +3,7 @@ import ray from trinity.buffer.buffer_writer import BufferWriter -from trinity.buffer.ray_wrapper import FileWrapper +from trinity.buffer.storage.file import FileStorage from trinity.common.config import BufferConfig, StorageConfig from trinity.common.constants import StorageType @@ -11,7 +11,7 @@ class JSONWriter(BufferWriter): def __init__(self, meta: StorageConfig, config: BufferConfig): assert meta.storage_type == StorageType.FILE - self.writer = FileWrapper.get_wrapper(meta, config) + self.writer = FileStorage.get_wrapper(meta, config) self.wrap_in_ray = 
meta.wrap_in_ray def write(self, data: List) -> None: diff --git a/trinity/buffer/writer/queue_writer.py b/trinity/buffer/writer/queue_writer.py index 27dc6df28a..7e4f4a9ca1 100644 --- a/trinity/buffer/writer/queue_writer.py +++ b/trinity/buffer/writer/queue_writer.py @@ -4,7 +4,7 @@ import ray from trinity.buffer.buffer_writer import BufferWriter -from trinity.buffer.ray_wrapper import QueueWrapper +from trinity.buffer.storage.queue import QueueStorage from trinity.common.config import BufferConfig, StorageConfig from trinity.common.constants import StorageType @@ -15,7 +15,7 @@ class QueueWriter(BufferWriter): def __init__(self, meta: StorageConfig, config: BufferConfig): assert meta.storage_type == StorageType.QUEUE self.config = config - self.queue = QueueWrapper.get_wrapper(meta, config) + self.queue = QueueStorage.get_wrapper(meta, config) def write(self, data: List) -> None: ray.get(self.queue.put_batch.remote(data)) diff --git a/trinity/buffer/writer/sql_writer.py b/trinity/buffer/writer/sql_writer.py index a951201b80..9ffaa13adb 100644 --- a/trinity/buffer/writer/sql_writer.py +++ b/trinity/buffer/writer/sql_writer.py @@ -3,7 +3,7 @@ import ray from trinity.buffer.buffer_writer import BufferWriter -from trinity.buffer.ray_wrapper import DBWrapper +from trinity.buffer.storage.sql import SQLStorage from trinity.common.config import BufferConfig, StorageConfig from trinity.common.constants import StorageType @@ -15,7 +15,7 @@ def __init__(self, meta: StorageConfig, config: BufferConfig) -> None: assert meta.storage_type == StorageType.SQL # we only support write RFT algorithm buffer for now self.wrap_in_ray = meta.wrap_in_ray - self.db_wrapper = DBWrapper.get_wrapper(meta, config) + self.db_wrapper = SQLStorage.get_wrapper(meta, config) def write(self, data: list) -> None: if self.wrap_in_ray: diff --git a/trinity/common/config.py b/trinity/common/config.py index 911885fe50..1382a561dc 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ 
-16,7 +16,6 @@ StorageType, SyncMethod, SyncStyle, - TaskType, ) from trinity.utils.annotations import Experimental from trinity.utils.log import get_logger @@ -28,6 +27,7 @@ class FormatConfig: """Configuration for data formatting""" + # for sft / dpo prompt_type: PromptType = PromptType.MESSAGES # for plaintext input @@ -73,18 +73,12 @@ class StorageConfig: path: Optional[str] = None repeat_times: Optional[int] = None - # only available for StorageType.FILE. When requiring data processing on raw data, set the raw to True. - raw: bool = False - # used for StorageType.FILE split: str = "train" subset_name: Optional[str] = None format: FormatConfig = field(default_factory=FormatConfig) index: int = 0 - # used for StorageType.SQL/FILE - wrap_in_ray: bool = True - # used for StorageType.QUEUE capacity: int = 10000 max_read_timeout: float = 1800 @@ -94,6 +88,10 @@ class StorageConfig: default_factory=lambda: {"priority_fn": "linear_decay", "decay": 0.1} ) + # used for StorageType.SQL + max_retry_times: int = 3 + max_retry_interval: int = 1 + # used for rollout tasks default_workflow_type: Optional[str] = None default_eval_workflow_type: Optional[str] = None @@ -108,8 +106,11 @@ class StorageConfig: # get storage from existing experiment ray_namespace: Optional[str] = None - # ! DO NOT SET, automatically set from algorithm.algorithm_type - algorithm_type: Optional[str] = None + # ! DO NOT SET except you know what you are doing + wrap_in_ray: bool = True + + # ! DO NOT SET, automatically set + schema_type: Optional[str] = None # ! DO NOT SET, automatically set from buffer.total_epochs total_epochs: int = 1 # automatically set @@ -118,7 +119,7 @@ class StorageConfig: total_steps: Optional[int] = None # automatically set # ! 
DO NOT SET, automatically set corresponding to train/eval - task_type: TaskType = TaskType.EXPLORE + is_eval: bool = False @dataclass @@ -349,10 +350,6 @@ class BufferConfig: # for trainer trainer_input: TrainerInput = field(default_factory=TrainerInput) - # for storage connection - max_retry_times: int = 3 - max_retry_interval: int = 1 - # ! DO NOT SET FOLLOWING FIELDS explorer_output: Optional[StorageConfig] = None # automatically set tokenizer_path: Optional[str] = None # automatically set @@ -551,7 +548,7 @@ def _check_buffer(self) -> None: # noqa: C901 self.buffer.trainer_input.experience_buffer.total_epochs = self.buffer.total_epochs self.buffer.trainer_input.experience_buffer.total_steps = self.buffer.total_steps else: - self.buffer.explorer_input.taskset.task_type = TaskType.EXPLORE + self.buffer.explorer_input.taskset.is_eval = False self.buffer.explorer_input.taskset.total_epochs = self.buffer.total_epochs self.buffer.explorer_input.taskset.total_steps = self.buffer.total_steps if self.buffer.explorer_input.taskset.default_workflow_type is None: @@ -582,7 +579,7 @@ def _check_buffer(self) -> None: # noqa: C901 if not dataset.path: logger.warning(f"Eval dataset [{dataset}]'s path is not configured. 
Skip.") continue - dataset.task_type = TaskType.EVAL + dataset.is_eval = True if not dataset.name: dataset.name = f"eval_taskset_{idx}" if dataset.repeat_times is None: @@ -623,9 +620,11 @@ def _check_buffer(self) -> None: # noqa: C901 self.buffer.trainer_input.experience_buffer.storage_type = StorageType.QUEUE if self.buffer.trainer_input.experience_buffer is not None: - self.buffer.trainer_input.experience_buffer.algorithm_type = ( + from trinity.algorithm.algorithm import ALGORITHM_TYPE + + self.buffer.trainer_input.experience_buffer.schema_type = ALGORITHM_TYPE.get( self.algorithm.algorithm_type - ) + ).schema if self.buffer.trainer_input.experience_buffer.ray_namespace is None: self.buffer.trainer_input.experience_buffer.ray_namespace = self.ray_namespace @@ -648,7 +647,7 @@ def _check_buffer(self) -> None: # noqa: C901 "`buffer.trainer_input.sft_warmup_dataset` is required when `buffer.trainer_input.sft_warmup_steps` > 0" ) if self.buffer.trainer_input.sft_warmup_dataset is not None: - self.buffer.trainer_input.sft_warmup_dataset.algorithm_type = "sft" # TODO + self.buffer.trainer_input.sft_warmup_dataset.schema_type = "sft" self.buffer.trainer_input.sft_warmup_dataset.total_steps = ( self.buffer.trainer_input.sft_warmup_steps ) diff --git a/trinity/common/constants.py b/trinity/common/constants.py index c5a0d76af0..531d113965 100644 --- a/trinity/common/constants.py +++ b/trinity/common/constants.py @@ -48,13 +48,6 @@ class PromptType(CaseInsensitiveEnum): PLAINTEXT = "plaintext" # user prompt text and assistant response text -class TaskType(Enum): - """Task Type.""" - - EXPLORE = 0 - EVAL = 1 - - class StorageType(CaseInsensitiveEnum): """Storage Type.""" @@ -68,6 +61,7 @@ class MonitorType(CaseInsensitiveEnum): WANDB = "wandb" TENSORBOARD = "tensorboard" + MLFLOW = "mlflow" class SyncMethodEnumMeta(CaseInsensitiveEnumMeta): diff --git a/trinity/common/experience.py b/trinity/common/experience.py index d4d9aa6cc2..3fbcd6df72 100644 --- 
a/trinity/common/experience.py +++ b/trinity/common/experience.py @@ -98,6 +98,7 @@ class CustomField: class Experience: eid: EID = field(default_factory=EID) # Unique identifier for the experience tokens: Optional[Tensor] = None # [seq_length] + prompt_length: int = 1 # Length of the prompt in tokens, used for generating attention masks logprobs: Optional[Tensor] = None # [resp_length] reward: Optional[float] = None advantages: Optional[Tensor] = None # [resp_length] @@ -110,12 +111,11 @@ class Experience: ) # Metrics associated with the experience, directly used by the monitor # for single-turn experiences - prompt_length: int = 1 # Length of the prompt in tokens, used for generating attention masks response_text: Optional[str] = None # Text of the response prompt_text: Optional[str] = None # Text of the prompt # for multi-turn experiences - # Action mask which indicates which tokens are generated by the model + # Action mask indicates which tokens are generated by the model action_mask: Optional[Tensor] = None # [resp_length] messages: Optional[List[dict]] = None # List of messages tools: Optional[List[dict]] = None @@ -123,8 +123,8 @@ class Experience: # for dpo experiences chosen: Optional[Tensor] = None # Token ids of the chosen response [resp_length] rejected: Optional[Tensor] = None # Token ids of the rejected response [resp_length] - chosen_text: Optional[str] = None # Text of the chosen response - rejected_text: Optional[str] = None # Text of the rejected response + chosen_messages: Optional[List[dict]] = None # Chosen message list (Include prompt message) + rejected_messages: Optional[List[dict]] = None # Rejected message list (Include prompt message) def __init__( # noqa: C901 self, @@ -145,8 +145,8 @@ def __init__( # noqa: C901 tools=None, chosen=None, rejected=None, - chosen_text=None, - rejected_text=None, + chosen_messages=None, + rejected_messages=None, ): if action_mask is not None: experience_type = "multi_turn" @@ -201,8 +201,8 @@ def __init__( 
# noqa: C901 if isinstance(rejected, list): rejected = torch.tensor(rejected, dtype=torch.int32) self.rejected = rejected - self.chosen_text = chosen_text - self.rejected_text = rejected_text + self.chosen_messages = chosen_messages + self.rejected_messages = rejected_messages if not isinstance(self.tokens, Tensor): self.tokens = torch.tensor(self.tokens) @@ -241,10 +241,10 @@ def to_dict(self) -> dict: res["messages"] = self.messages if self.tools is not None: res["tools"] = self.tools - if self.chosen_text is not None: - res["chosen_text"] = self.chosen_text - if self.rejected_text is not None: - res["rejected_text"] = self.rejected_text + if self.chosen_messages is not None: + res["chosen_messages"] = self.chosen_messages + if self.rejected_messages is not None: + res["rejected_messages"] = self.rejected_messages if self.reward is not None: res["reward"] = float(self.reward) return res @@ -343,7 +343,7 @@ def split_dpo_experience_to_single_turn(experiences: List[Experience]) -> List[E metrics=exp.metrics, prompt_length=len(exp.tokens), # type: ignore [arg-type] prompt_text=exp.prompt_text, - response_text=exp.chosen_text, + messages=exp.chosen_messages, ) ) single_turn_experiences.append( @@ -360,7 +360,7 @@ def split_dpo_experience_to_single_turn(experiences: List[Experience]) -> List[E metrics=exp.metrics, prompt_length=len(exp.tokens), # type: ignore [arg-type] prompt_text=exp.prompt_text, - response_text=exp.rejected_text, + messages=exp.rejected_messages, ) ) return single_turn_experiences diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py index 9a1da4ca49..c7808ae732 100644 --- a/trinity/common/workflows/workflow.py +++ b/trinity/common/workflows/workflow.py @@ -24,7 +24,7 @@ class Task(dict): """A Task class that defines a task and its associated reward function / workflow.""" - workflow: Type[Workflow] + workflow: Type[Workflow] = None repeat_times: Optional[int] = None format_args: FormatConfig = 
field(default_factory=FormatConfig) rollout_args: GenerationConfig = field(default_factory=GenerationConfig) diff --git a/trinity/manager/config_manager.py b/trinity/manager/config_manager.py index f241482dca..fa3c4da524 100644 --- a/trinity/manager/config_manager.py +++ b/trinity/manager/config_manager.py @@ -519,11 +519,11 @@ def _gen_buffer_config(self): "name": "experience_buffer", "storage_type": st.session_state["storage_type"], "path": experience_buffer_path, + "max_retry_interval": st.session_state["max_retry_interval"], + "max_retry_times": st.session_state["buffer_max_retry_times"], }, "sft_warmup_steps": st.session_state["sft_warmup_steps"], }, - "max_retry_times": st.session_state["buffer_max_retry_times"], - "max_retry_interval": st.session_state["max_retry_interval"], } if st.session_state["algorithm_type"] != "dpo": experience_buffer = buffer_config["trainer_input"]["experience_buffer"]