diff --git a/docs/sphinx_doc/assets/DYN-NCCL.png b/docs/sphinx_doc/assets/DYN-NCCL.png deleted file mode 100644 index 0818715d910..00000000000 Binary files a/docs/sphinx_doc/assets/DYN-NCCL.png and /dev/null differ diff --git a/docs/sphinx_doc/assets/DYN-STATEDICT.png b/docs/sphinx_doc/assets/DYN-STATEDICT.png deleted file mode 100644 index 5c7fe9a9999..00000000000 Binary files a/docs/sphinx_doc/assets/DYN-STATEDICT.png and /dev/null differ diff --git a/docs/sphinx_doc/assets/FIXED-NCCL.png b/docs/sphinx_doc/assets/FIXED-NCCL.png deleted file mode 100644 index cf4132f7f93..00000000000 Binary files a/docs/sphinx_doc/assets/FIXED-NCCL.png and /dev/null differ diff --git a/docs/sphinx_doc/assets/FIXED-STATEDICT.png b/docs/sphinx_doc/assets/FIXED-STATEDICT.png deleted file mode 100644 index f908c113037..00000000000 Binary files a/docs/sphinx_doc/assets/FIXED-STATEDICT.png and /dev/null differ diff --git a/docs/sphinx_doc/assets/NCCL-en.png b/docs/sphinx_doc/assets/NCCL-en.png new file mode 100644 index 00000000000..d8450e1bab9 Binary files /dev/null and b/docs/sphinx_doc/assets/NCCL-en.png differ diff --git a/docs/sphinx_doc/assets/NCCL-zh.png b/docs/sphinx_doc/assets/NCCL-zh.png new file mode 100644 index 00000000000..996eaa2471c Binary files /dev/null and b/docs/sphinx_doc/assets/NCCL-zh.png differ diff --git a/docs/sphinx_doc/assets/STATEDICT-en.png b/docs/sphinx_doc/assets/STATEDICT-en.png new file mode 100644 index 00000000000..b55047c55df Binary files /dev/null and b/docs/sphinx_doc/assets/STATEDICT-en.png differ diff --git a/docs/sphinx_doc/assets/STATEDICT-zh.png b/docs/sphinx_doc/assets/STATEDICT-zh.png new file mode 100644 index 00000000000..ab414b01db1 Binary files /dev/null and b/docs/sphinx_doc/assets/STATEDICT-zh.png differ diff --git a/docs/sphinx_doc/source/tutorial/develop_operator.md b/docs/sphinx_doc/source/tutorial/develop_operator.md index edcaa324a5c..3a022d1c9ec 100644 --- a/docs/sphinx_doc/source/tutorial/develop_operator.md +++ b/docs/sphinx_doc/source/tutorial/develop_operator.md @@ -71,7 +71,7 @@ data_processor: threshold: 0.1 synchronizer: sync_method: nccl - sync_style: dynamic_by_explorer + sync_style: explorer_driven sync_interval: 2 # some other configs ``` diff --git a/docs/sphinx_doc/source/tutorial/example_react.md b/docs/sphinx_doc/source/tutorial/example_react.md index eeaad84a04f..1d8862a242a 100644 --- a/docs/sphinx_doc/source/tutorial/example_react.md +++ b/docs/sphinx_doc/source/tutorial/example_react.md @@ -155,7 +155,7 @@ Since agent applications may have variable interaction rounds and sample counts, ```yaml synchronizer: - sync_style: dynamic_by_explorer # Trainer starts training immediately when enough data is generated, rather than padding to a fixed size, improving efficiency + sync_style: explorer_driven # Trainer starts training immediately when enough data is generated, rather than padding to a fixed size, improving efficiency sync_interval: 2 # Check for model parameter updates after every two batches ``` diff --git a/docs/sphinx_doc/source/tutorial/example_step_wise.md b/docs/sphinx_doc/source/tutorial/example_step_wise.md index 0fa8fe43735..1eff185c3a5 100644 --- a/docs/sphinx_doc/source/tutorial/example_step_wise.md +++ b/docs/sphinx_doc/source/tutorial/example_step_wise.md @@ -69,7 +69,7 @@ In general multi-step scenarios, each run may generate various number of experie - `buffer.trainer_input.experience_buffer.replay_buffer`: Using `PriorityQueue` allows the model to use the experiences with higher priority, which prefers newly-generated 
experiences by default. -- `synchronizer.sync_style = dynamic_by_explorer`: The explorer determines when to synchronize the model weights with the trainer. +- `synchronizer.sync_style = explorer_driven`: The explorer determines when to synchronize the model weights with the trainer. The example configuration is shown as: @@ -134,7 +134,7 @@ explorer: env_vars: TMPDIR: ${oc.env:TMPDIR,/tmp} synchronizer: - sync_style: dynamic_by_explorer + sync_style: explorer_driven sync_method: 'nccl' sync_interval: 2 sync_timeout: 3600 diff --git a/docs/sphinx_doc/source/tutorial/synchronizer.md b/docs/sphinx_doc/source/tutorial/synchronizer.md index 0023e87da2d..bc56575e1b7 100644 --- a/docs/sphinx_doc/source/tutorial/synchronizer.md +++ b/docs/sphinx_doc/source/tutorial/synchronizer.md @@ -29,21 +29,27 @@ To achieve this, the Synchronizer: async def train(self) -> str: while self.train_step_num < self.total_steps: try: + metrics = {} # sample may be blocked due to explorer does not generate enough data self.logger.info(f"Sample data for step {self.train_step_num + 1} started.") sample_task = asyncio.create_task(self._sample_data()) while not sample_task.done(): # sync weight to make sure the explorer can continue to explore and generate enough data if await self.need_sync(): - # Currently, we do not record the metrics of sync_weight here - await self.sync_weight() + metrics.update(await self.sync_weight()) await asyncio.sleep(1) - exps, metrics, repr_samples = await sample_task + exps, sample_metrics, repr_samples = await sample_task + metrics.update(sample_metrics) self.logger.info(f"Sample data for step {self.train_step_num + 1} finished.") metrics.update(await self.train_step(exps)) if await self.need_sync(): metrics.update(await self.sync_weight()) - # ... + if self.need_save(): + metrics.update( + await self.save_checkpoint(save_as_hf=self.save_hf_checkpoint == "always") + ) + if self.config.trainer.enable_preview: + self._log_experiences(repr_samples) self.monitor.log(metrics, self.train_step_num) except StopAsyncIteration: self.logger.info("No more samples to train. Stopping training.") @@ -145,27 +151,32 @@ There are **two synchronization styles** that define *when* the Explorer request | `interval=10, offset=0` | Sync every 10 steps (both start together) | | `interval=10, offset=5` | Explorer runs 5 steps first, then sync every 10 steps | -✅ **Best for**: Simple, predictable environments where exploration steps are short and rewards are frequent (e.g., mathematical reasoning tasks). - -> 🔁 Think of it as a metronome — steady and regular. +🎯 **Best for**: Simple, predictable environments with short exploration episodes and frequent rewards (e.g., mathematical reasoning tasks). --- -### 2. `SyncStyle.DYNAMIC_BY_EXPLORER` – Demand-Driven Sync +### 2. `SyncStyle.EXPLORER_DRIVEN` – Explorer-Driven Synchronization + +- The Explorer itself decides when it needs a new model. +- Workflow: + 1. After completing `sync_interval` steps, the Explorer sends a request to the Synchronizer to update its parameters. + 2. The Trainer detects this request in its next loop iteration and performs the synchronization. + 3. Once synchronization completes, both the Explorer and Trainer continue running. + 4. If a timeout occurs, the Explorer retries in the next cycle. -- Explorer decides to request a sync after generating a certain amount of data. -- It tells Synchronizer: _"I’m ready for a new model!"_ -- Trainer checks this request during its normal loop and responds accordingly. 
+🎯 **Best for**: Scenarios where the Explorer’s pace is irregular or when on-demand model updates are preferred. -📌 **Process Flow**: -1. Explorer finishes `N` steps → sets state to `REQUIRE_SYNC`. -2. Waits for Trainer to acknowledge and perform sync. -3. Once synced, returns to `RUNNING`. -4. If timeout occurs, retries on next step. +--- -✅ **Best for**: Complex, long-horizon tasks where data generation is expensive or variable (e.g., multi-turn dialogue, game playing). +### 3. `SyncStyle.TRAINER_DRIVEN` – Trainer-Driven Synchronization -> 🔄 More flexible — adapts to actual data throughput. +- The Trainer determines when to release a new model. +- Workflow: + 1. Every `sync_interval` steps, the Trainer decides to request synchronization. + 2. It notifies the Synchronizer to prepare pushing the new model. + 3. The Explorer detects this request during its normal loop and responds by performing synchronization. + +🎯 **Best for**: Cases where the Trainer has a clear, consistent training rhythm, and the Explorer passively receives updates. --- @@ -173,49 +184,33 @@ There are **two synchronization styles** that define *when* the Explorer request The Synchronizer tracks the **state** of both Trainer and Explorer to manage synchronization safely. -### Four Key States +### Three Key States | State | Meaning | |------|--------| | `STOPPED` | Component has stopped working | | `RUNNING` | Actively training or exploring | -| `REQUIRE_SYNC` | Explorer wants new weights | -| `WAITING_SYNC` | Explorer or Trainer is waiting synchronization (used in NCCL mode) | +| `REQUIRE_SYNC` | Explorer / Trainer requests new weights | These states help prevent race conditions and ensure smooth coordination. --- -### State Transitions by Style & Method - -#### 🔹 Fixed Style + NCCL Sync -- Synchronizer schedules sync every `N` steps. -- Both sides pause briefly for direct GPU sync. -- The state of the trainer toggles predictably between `RUNNING` ↔ `WAITING_SYNC`, and the state of the explorer toggles among `RUNNING` → `REQUIRE_SYNC` → `WAITING_SYNC`. +### State Transitions Across Different Modes and Methods -![FIXED_STYLE_NCCL_SYNC](../../assets/FIXED-NCCL.png) +#### 🔹 NCCL Synchronization +- Both Trainer and Explorer toggle states (`RUNNING` ↔ `REQUIRE_SYNC`). +- Synchronization uses a "two-way handshake": data transfer only begins once both sides are ready. +- After synchronization completes, both return to `RUNNING`. -#### 🔹 Fixed Style + CHECKPOINT/MEMORY -- Trainer saves or sends weights periodically. -- Explorer checks at each interval and pulls updates. -- The state of the trainer remains at `RUNNING`, and the state of the explorer toggles between `RUNNING` ↔ `REQUIRE_SYNC`. +![NCCL Synchronization](../../assets/NCCL-en.png) -![FIXED_STYLE_STATEDICT_SYNC](../../assets/FIXED-STATEDICT.png) +#### 🔹 CHECKPOINT/MEMORY Synchronization +- The Trainer typically remains in `RUNNING` state (it only saves weights). +- The Explorer initiates the sync request (switches to `REQUIRE_SYNC`), pulls the weights, then returns to `RUNNING`. +- The Synchronizer acts as an intermediary, delivering model weights to the Explorer. - -#### 🔹 Dynamic Style + NCCL -- Explorer signals `REQUIRE_SYNC` after enough data. -- Trainer sees the signal and initiates NCCL sync. -- The state of the trainer toggles predictably between `RUNNING` ↔ `WAITING_SYNC`, and the state of the explorer toggles between `RUNNING` → `REQUIRE_SYNC` → `WAITING_SYNC`. 
- -![DYN_STYLE_NCCL_SYNC](../../assets/DYN-NCCL.png) - -#### 🔹 Dynamic Style + CHECKPOINT/MEMORY -- Explorer signals `REQUIRE_SYNC` after enough data. -- Trainer sees the signal and pushes weights to synchronizer. -- The state of the trainer remains at `RUNNING`, and the state of the explorer toggles between `RUNNING` ↔ `REQUIRE_SYNC`. - -![DYN_STYLE_STATEDICT_SYNC](../../assets/DYN-STATEDICT.png) +![CHECKPOINT/MEMORY Synchronization](../../assets/STATEDICT-en.png) --- @@ -240,9 +235,7 @@ These states help prevent race conditions and ensure smooth coordination. | Use Case | Recommended Style | |--------|------------------| | Short episodes, quick feedback (e.g., math QA) | `FIXED` | -| Long interactions, delayed rewards (e.g., games, conversations) | `DYNAMIC_BY_EXPLORER` | - -> 💡 `DYNAMIC_BY_EXPLORER` gives more control to the data-generating side, making it better for unbalanced or variable workloads. +| Multi-turn interactive tasks, such as multi-round dialogues, tool usage, or multi-step games | `EXPLORER_DRIVEN` or `TRAINER_DRIVEN` | --- diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md index 12c5a337b1e..6d65dd0a5d6 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md @@ -465,7 +465,8 @@ synchronizer: - `sync_timeout`: Timeout duration for synchronization. - `sync_style`: Style of synchronization. Options: - `fixed`: The explorer and trainer synchronize weights every `sync_interval` steps. - - `dynamic_by_explorer`: The explorer notifies the trainer to synchronize weights after completing `sync_interval` steps, regardless of how many steps the trainer has completed at this point. + - `explorer_driven`: The explorer notifies the trainer to synchronize weights after completing `sync_interval` steps, regardless of how many steps the trainer has completed at this point. + - `trainer_driven`: The trainer notifies the explorer to synchronize weights after completing `sync_interval` steps, regardless of how many steps the explorer has completed at this point. 
--- diff --git a/docs/sphinx_doc/source_zh/tutorial/develop_operator.md b/docs/sphinx_doc/source_zh/tutorial/develop_operator.md index 0932546764f..8a662d84c19 100644 --- a/docs/sphinx_doc/source_zh/tutorial/develop_operator.md +++ b/docs/sphinx_doc/source_zh/tutorial/develop_operator.md @@ -80,12 +80,12 @@ data_processor: threshold: 0.1 synchronizer: sync_method: nccl - sync_style: dynamic_by_explorer + sync_style: explorer_driven sync_interval: 2 # some other configs ``` ```{tip} -`RewardFilter` 会减少 experience 数量,可能导致 Trainer 无法获得足够的 experience 来启动训练流程。为避免此问题,你可以使用 Trinity-RFT 提供的 {ref}`动态同步 ` 功能 (`dynamic_by_explorer`)。 +`RewardFilter` 会减少 experience 数量,可能导致 Trainer 无法获得足够的 experience 来启动训练流程。为避免此问题,你可以使用 Trinity-RFT 提供的 {ref}`动态同步 ` 功能 (`explorer_driven`)。 上述设置意味着 `Explorer` 每运行 2 步就会与 `Trainer` 同步一次,且无论 `Trainer` 当前完成了多少步都会继续运行。这确保了只要 `Explorer` 在运行,`Trainer` 就总能获得足够的 experience 来启动训练步骤。 ``` diff --git a/docs/sphinx_doc/source_zh/tutorial/example_react.md b/docs/sphinx_doc/source_zh/tutorial/example_react.md index ec71d455e21..2b1ee7c8a9d 100644 --- a/docs/sphinx_doc/source_zh/tutorial/example_react.md +++ b/docs/sphinx_doc/source_zh/tutorial/example_react.md @@ -162,7 +162,7 @@ algorithm: ```yaml synchronizer: - sync_style: dynamic_by_explorer # 当产生足够训练数据时,trainer 立即启动训练任务,而不是将生成的数据补齐到一个固定规模,能够有效提升训练效率 + sync_style: explorer_driven # 当产生足够训练数据时,trainer 立即启动训练任务,而不是将生成的数据补齐到一个固定规模,能够有效提升训练效率 sync_interval: 2 # 每执行两个批次的任务后检查是否需要同步更新模型参数 ``` diff --git a/docs/sphinx_doc/source_zh/tutorial/example_step_wise.md b/docs/sphinx_doc/source_zh/tutorial/example_step_wise.md index c7e49fd07d3..035e3ab044c 100644 --- a/docs/sphinx_doc/source_zh/tutorial/example_step_wise.md +++ b/docs/sphinx_doc/source_zh/tutorial/example_step_wise.md @@ -68,7 +68,7 @@ WORKFLOWS = Registry( - `buffer.trainer_input.experience_buffer.replay_buffer`:使用 `PriorityQueue` 可使模型优先使用高优先级的 experience (默认为使用更新产生的 experience)。 -- `synchronizer.sync_style = dynamic_by_explorer`:由 explorer 决定何时与 trainer 同步模型权重。 +- `synchronizer.sync_style = explorer_driven`:由 explorer 决定何时与 trainer 同步模型权重。 示例配置如下所示: @@ -129,7 +129,7 @@ explorer: env_vars: TMPDIR: ${oc.env:TMPDIR,/tmp} synchronizer: - sync_style: dynamic_by_explorer + sync_style: explorer_driven sync_method: 'nccl' sync_interval: 2 sync_timeout: 3600 diff --git a/docs/sphinx_doc/source_zh/tutorial/synchronizer.md b/docs/sphinx_doc/source_zh/tutorial/synchronizer.md index 469a384cbcd..90e114f018f 100644 --- a/docs/sphinx_doc/source_zh/tutorial/synchronizer.md +++ b/docs/sphinx_doc/source_zh/tutorial/synchronizer.md @@ -28,21 +28,27 @@ async def train(self) -> str: while self.train_step_num < self.total_steps: try: + metrics = {} # sample may be blocked due to explorer does not generate enough data self.logger.info(f"Sample data for step {self.train_step_num + 1} started.") sample_task = asyncio.create_task(self._sample_data()) while not sample_task.done(): # sync weight to make sure the explorer can continue to explore and generate enough data if await self.need_sync(): - # Currently, we do not record the metrics of sync_weight here - await self.sync_weight() + metrics.update(await self.sync_weight()) await asyncio.sleep(1) - exps, metrics, repr_samples = await sample_task + exps, sample_metrics, repr_samples = await sample_task + metrics.update(sample_metrics) self.logger.info(f"Sample data for step {self.train_step_num + 1} finished.") metrics.update(await self.train_step(exps)) if await self.need_sync(): metrics.update(await self.sync_weight()) - # ... 
+ if self.need_save(): + metrics.update( + await self.save_checkpoint(save_as_hf=self.save_hf_checkpoint == "always") + ) + if self.config.trainer.enable_preview: + self._log_experiences(repr_samples) self.monitor.log(metrics, self.train_step_num) except StopAsyncIteration: self.logger.info("No more samples to train. Stopping training.") @@ -145,27 +151,30 @@ Explorer 会在以下时机检查是否需要同步: | `interval=10, offset=0` | 每 10 步同步一次(两者同时开始) | | `interval=10, offset=5` | Explorer 先运行 5 步,之后每 10 步同步一次 | -✅ **最适合**:简单、可预测的环境,探索步骤较短且奖励频繁(例如数学推理任务)。 - -> 🔁 可将其类比为节拍器 —— 稳定且规律。 +🎯 **适合**:简单、可预测的环境,探索步骤较短且奖励频繁(例如数学推理任务)。 --- -### 2. `SyncStyle.DYNAMIC_BY_EXPLORER` – 按需动态同步 +### 2. `SyncStyle.EXPLORER_DRIVEN` – Explorer 驱动同步 +- Explorer 自己决定何时需要新模型。 +- 流程: + 1. Explorer 完成 `sync_interval` 步后,向 Synchronizer 发出更新参数的请求。 + 2. Trainer 在下一次循环中发现这个请求,并完成同步。 + 3. 同步完成后,Explorer 和 Trainer 继续运行。 + 4. 若超时,Explorer 会在下一个周期重试。 -- Explorer 在生成一定量数据后决定请求同步。 -- 它会通知 Synchronizer:“我已经准备好获取新模型!” -- Trainer 在正常循环中检测该请求并响应。 +🎯 **适合**:Explorer 节奏不固定,或希望按需更新模型。 -📌 **流程说明**: -1. Explorer 完成 `N` 步后 → 将状态设为 `REQUIRE_SYNC`。 -2. 等待 Trainer 确认并完成同步。 -3. 同步完成后,状态恢复为 `RUNNING`。 -4. 若超时,则在下一步重试。 +--- -✅ **最适合**:复杂、长周期任务,其中数据生成成本高或不规律(例如多轮对话、游戏对战)。 +### 3. `SyncStyle.TRAINER_DRIVEN` – Trainer 驱动同步 +- Trainer 决定何时发布新模型。 +- 流程: + 1. Trainer 每隔 `sync_interval` 步数后决定请求同步。 + 2. 它会通知 Synchronizer 准备推送新模型。 + 3. Explorer 在正常循环中检测该请求并响应同步。 -> 🔄 更加灵活 —— 能根据实际数据产出动态调整。 +🎯 **适合**:Trainer 训练节奏明确,Explorer 被动接收更新。 --- @@ -173,14 +182,13 @@ Explorer 会在以下时机检查是否需要同步: Synchronizer 通过跟踪 Trainer 和 Explorer 的**状态**,确保同步过程安全可控。 -### 四个关键状态 +### 三个关键状态 | 状态 | 含义 | |------|--------| | `STOPPED` | 组件已停止运行 | | `RUNNING` | 正在训练或探索中 | -| `REQUIRE_SYNC` | Explorer 请求新权重 | -| `WAITING_SYNC` | Explorer 或 Trainer 正在等待同步(NCCL 模式下使用) | +| `REQUIRE_SYNC` | Explorer / Trainer 请求新权重 | 这些状态有助于避免竞态条件,保证协调过程平稳。 @@ -188,33 +196,19 @@ Synchronizer 通过跟踪 Trainer 和 Explorer 的**状态**,确保同步过 ### 不同模式与方法下的状态转换 -#### 🔹 固定模式 + NCCL 同步 -- Synchronizer 每 `N` 步安排一次同步。 -- 双方短暂暂停,进行 GPU 直连同步。 -- Trainer 状态在 `RUNNING` ↔ `WAITING_SYNC` 间规律切换,Explorer 状态在 `RUNNING` → `REQUIRE_SYNC` → `WAITING_SYNC` 间切换。 - -![FIXED_STYLE_NCCL_SYNC](../../assets/FIXED-NCCL.png) - -#### 🔹 固定模式 + CHECKPOINT/MEMORY -- Trainer 定期保存或发送权重。 -- Explorer 在每个间隔检查并拉取更新。 -- Trainer 状态保持 `RUNNING`,Explorer 状态在 `RUNNING` ↔ `REQUIRE_SYNC` 间切换。 +#### 🔹 NCCL 同步 +- Trainer 和 Explorer 都会切换状态(`RUNNING` ↔ `REQUIRE_SYNC`)。 +- 同步是“双向握手”:双方都准备好才开始传数据。 +- 同步完成后,双方都回到 `RUNNING`。 -![FIXED_STYLE_STATEDICT_SYNC](../../assets/FIXED-STATEDICT.png) +![NCCL 同步](../../assets/NCCL-zh.png) -#### 🔹 动态模式 + NCCL -- Explorer 在积累足够数据后发出 `REQUIRE_SYNC` 信号。 -- Trainer 检测到信号后启动 NCCL 同步。 -- Trainer 状态在 `RUNNING` ↔ `WAITING_SYNC` 间切换,Explorer 状态在 `RUNNING` → `REQUIRE_SYNC` → `WAITING_SYNC` 间切换。 +#### 🔹 CHECKPOINT/MEMORY 同步 +- Trainer 通常一直保持 `RUNNING`(它只负责存权重)。 +- Explorer 负责发起同步请求(切换到 `REQUIRE_SYNC`),拉取完权重后回到 `RUNNING`。 +- Synchronizer 作为“中介”,负责传递模型权重给 Explorer。 -![DYN_STYLE_NCCL_SYNC](../../assets/DYN-NCCL.png) - -#### 🔹 动态模式 + CHECKPOINT/MEMORY -- Explorer 在积累足够数据后发出 `REQUIRE_SYNC` 信号。 -- Trainer 检测到信号后将权重推送给 Synchronizer。 -- Trainer 状态保持 `RUNNING`,Explorer 状态在 `RUNNING` ↔ `REQUIRE_SYNC` 间切换。 - -![DYN_STYLE_STATEDICT_SYNC](../../assets/DYN-STATEDICT.png) +![CHECKPOINT/MEMORY 同步](../../assets/STATEDICT-zh.png) --- @@ -239,9 +233,7 @@ Synchronizer 通过跟踪 Trainer 和 Explorer 的**状态**,确保同步过 | 使用场景 | 推荐模式 | |--------|------------------| | 短周期任务,反馈迅速(如数学问答) | `FIXED` | -| 长交互任务,奖励延迟(如游戏、对话) | `DYNAMIC_BY_EXPLORER` | - -> 💡 `DYNAMIC_BY_EXPLORER` 
将控制权交给数据生成方,更适合负载不均衡或变化较大的任务。 +| 多轮交互任务,例如多轮对话、工具调用、多步骤游戏 | `EXPLORER_DRIVEN` 或 `TRAINER_DRIVEN` | --- diff --git a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md index 517490b73b9..43feadab382 100644 --- a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md @@ -462,7 +462,8 @@ synchronizer: - `sync_timeout`: 同步超时时间。 - `sync_style`: 同步风格。选项: - `fixed`: explorer 和 trainer 每隔 `sync_interval` 步同步一次权重。 - - `dynamic_by_explorer`: explorer 在完成 `sync_interval` 步后通知 trainer 同步权重,而不管此时 trainer 已完成多少步。 + - `explorer_driven`: explorer 在完成 `sync_interval` 步后通知 trainer 同步权重,而不管此时 trainer 已完成多少步。 + - `trainer_driven`: trainer 在完成 `sync_interval` 步后通知 explorer 同步权重,而不管此时 explorer 已完成多少步。 --- diff --git a/examples/RAFT_alfworld/RAFT_alfworld_7B.yaml b/examples/RAFT_alfworld/RAFT_alfworld_7B.yaml index 7fc2445eaf0..be2e65a9f75 100644 --- a/examples/RAFT_alfworld/RAFT_alfworld_7B.yaml +++ b/examples/RAFT_alfworld/RAFT_alfworld_7B.yaml @@ -62,7 +62,7 @@ explorer: gpu_memory_utilization: 0.86 seed: 42 synchronizer: - sync_style: dynamic_by_explorer + sync_style: explorer_driven sync_method: 'nccl' sync_interval: 4 sync_timeout: 1200 diff --git a/examples/RAFT_alfworld/RAFT_reflect_alfworld_7B.yaml b/examples/RAFT_alfworld/RAFT_reflect_alfworld_7B.yaml index 30a115cadaa..1f0a3d0947f 100644 --- a/examples/RAFT_alfworld/RAFT_reflect_alfworld_7B.yaml +++ b/examples/RAFT_alfworld/RAFT_reflect_alfworld_7B.yaml @@ -62,7 +62,7 @@ explorer: gpu_memory_utilization: 0.86 seed: 42 synchronizer: - sync_style: dynamic_by_explorer + sync_style: explorer_driven sync_method: 'nccl' sync_interval: 4 sync_timeout: 1200 diff --git a/examples/agentscope_frozenlake/frozenlake_agent.yaml b/examples/agentscope_frozenlake/frozenlake_agent.yaml index 5144d9eccac..eabd6244e4e 100644 --- a/examples/agentscope_frozenlake/frozenlake_agent.yaml +++ b/examples/agentscope_frozenlake/frozenlake_agent.yaml @@ -73,6 +73,6 @@ trainer: synchronizer: sync_method: nccl - sync_style: dynamic_by_explorer + sync_style: explorer_driven sync_interval: 1 sync_timeout: 1200 diff --git a/examples/agentscope_react/gsm8k.yaml b/examples/agentscope_react/gsm8k.yaml index 102606c4ba8..d7d510091c9 100644 --- a/examples/agentscope_react/gsm8k.yaml +++ b/examples/agentscope_react/gsm8k.yaml @@ -53,7 +53,7 @@ explorer: dtype: bfloat16 seed: 42 synchronizer: - sync_style: dynamic_by_explorer + sync_style: explorer_driven sync_method: 'nccl' sync_interval: 2 sync_timeout: 1200 diff --git a/examples/agentscope_tool_react/agentscopev0_tool_react_dapo.yaml b/examples/agentscope_tool_react/agentscopev0_tool_react_dapo.yaml index e03664557ef..d9fa085a963 100644 --- a/examples/agentscope_tool_react/agentscopev0_tool_react_dapo.yaml +++ b/examples/agentscope_tool_react/agentscopev0_tool_react_dapo.yaml @@ -54,7 +54,7 @@ explorer: reasoning_parser: deepseek_r1 enable_thinking: true synchronizer: - sync_style: dynamic_by_explorer + sync_style: explorer_driven sync_method: 'nccl' sync_interval: 2 sync_timeout: 1200 diff --git a/examples/agentscope_tool_react/agentscopev0_tool_react_gsm8k.yaml b/examples/agentscope_tool_react/agentscopev0_tool_react_gsm8k.yaml index 4715ee9f58f..3f0b9bbff70 100644 --- a/examples/agentscope_tool_react/agentscopev0_tool_react_gsm8k.yaml +++ b/examples/agentscope_tool_react/agentscopev0_tool_react_gsm8k.yaml @@ -54,7 +54,7 @@ explorer: reasoning_parser: deepseek_r1 enable_thinking: true synchronizer: - sync_style: 
dynamic_by_explorer + sync_style: explorer_driven sync_method: 'nccl' sync_interval: 2 sync_timeout: 1200 diff --git a/examples/agentscope_tool_react/agentscopev1_tool_react_dapo.yaml b/examples/agentscope_tool_react/agentscopev1_tool_react_dapo.yaml index 4dd3ee76c86..c495ee47574 100644 --- a/examples/agentscope_tool_react/agentscopev1_tool_react_dapo.yaml +++ b/examples/agentscope_tool_react/agentscopev1_tool_react_dapo.yaml @@ -52,7 +52,7 @@ explorer: enable_auto_tool_choice: true tool_call_parser: hermes synchronizer: - sync_style: dynamic_by_explorer + sync_style: explorer_driven sync_method: 'nccl' sync_interval: 2 sync_timeout: 1200 diff --git a/examples/agentscope_websearch/agentscopev1_websearch_agent.yaml b/examples/agentscope_websearch/agentscopev1_websearch_agent.yaml index a6ad09cef82..6545f3356a9 100644 --- a/examples/agentscope_websearch/agentscopev1_websearch_agent.yaml +++ b/examples/agentscope_websearch/agentscopev1_websearch_agent.yaml @@ -82,7 +82,7 @@ explorer: max_response_tokens: 10240 max_model_len: 32000 synchronizer: - sync_style: dynamic_by_explorer + sync_style: explorer_driven sync_method: 'nccl' sync_interval: 5 sync_timeout: 3600 diff --git a/examples/grpo_alfworld_general_multi_step/alfworld.yaml b/examples/grpo_alfworld_general_multi_step/alfworld.yaml index 5427b6829ee..591ec78a371 100644 --- a/examples/grpo_alfworld_general_multi_step/alfworld.yaml +++ b/examples/grpo_alfworld_general_multi_step/alfworld.yaml @@ -54,7 +54,7 @@ explorer: env_vars: TMPDIR: ${oc.env:TMPDIR,/tmp} synchronizer: - sync_style: dynamic_by_explorer + sync_style: explorer_driven sync_method: 'nccl' sync_interval: 2 sync_timeout: 3600 diff --git a/examples/grpo_email_search/email_search.yaml b/examples/grpo_email_search/email_search.yaml index 24a58a8c96a..e389e5bb3b0 100644 --- a/examples/grpo_email_search/email_search.yaml +++ b/examples/grpo_email_search/email_search.yaml @@ -103,7 +103,7 @@ explorer: max_response_tokens: 128 max_model_len: 2500 synchronizer: - sync_style: dynamic_by_explorer + sync_style: explorer_driven sync_method: 'nccl' sync_interval: 5 sync_timeout: 3600 diff --git a/examples/grpo_gsm8k_ruler/README.md b/examples/grpo_gsm8k_ruler/README.md index 12eaf6f7746..6d25a278bb3 100644 --- a/examples/grpo_gsm8k_ruler/README.md +++ b/examples/grpo_gsm8k_ruler/README.md @@ -13,7 +13,7 @@ Some key configs in this example are: * `default_workflow_type`: set to `math_ruler_workflow` * `auxiliary_models`: LLM-as-a-judge for RULER; need to set `max_prompt_tokens`, `max_response_tokens`, `max_model_len` appropriately * `std_threshold` for GRPO advantage: set to small value, filter out group of experiences with same rewards (e.g., when RULER fails to return valid scores, they are set to all zero) -* `sync_style`: use `dynamic_by_explorer`, due to filtering of experiences +* `sync_style`: use `explorer_driven`, due to filtering of experiences * `lr`: set to small value (2e-6) for stability, as rewards can be noisy diff --git a/examples/grpo_gsm8k_ruler/gsm8k_ruler.yaml b/examples/grpo_gsm8k_ruler/gsm8k_ruler.yaml index e0297e099b6..fba35885902 100644 --- a/examples/grpo_gsm8k_ruler/gsm8k_ruler.yaml +++ b/examples/grpo_gsm8k_ruler/gsm8k_ruler.yaml @@ -63,7 +63,7 @@ explorer: max_response_tokens: 12288 max_model_len: 16384 synchronizer: - sync_style: dynamic_by_explorer + sync_style: explorer_driven sync_method: 'nccl' sync_interval: 5 sync_timeout: 3600 diff --git a/examples/grpo_gsm8k_trainable_ruler/README.md b/examples/grpo_gsm8k_trainable_ruler/README.md index 
ca52feebd19..7d73f4c592b 100644 --- a/examples/grpo_gsm8k_trainable_ruler/README.md +++ b/examples/grpo_gsm8k_trainable_ruler/README.md @@ -16,7 +16,7 @@ Some key configs in this example are: * `default_workflow_type`: set to `math_trainable_ruler_workflow` * `std_threshold` for GRPO advantage: set to small value, filter out group of experiences with same rewards (e.g., when RULER fails to return valid scores, they are set to all zero) -* `sync_style`: use `dynamic_by_explorer`, due to filtering of experiences +* `sync_style`: use `explorer_driven`, due to filtering of experiences * `train_batch_size`: set to 960; note that one explore step can generate more than 96 * 8 = 768 experiences * `lr`: set to small value (2e-6) for stability, as rewards can be noisy diff --git a/examples/grpo_gsm8k_trainable_ruler/gsm8k_ruler.yaml b/examples/grpo_gsm8k_trainable_ruler/gsm8k_ruler.yaml index 9fc071ba67b..a0fc9adaa97 100644 --- a/examples/grpo_gsm8k_trainable_ruler/gsm8k_ruler.yaml +++ b/examples/grpo_gsm8k_trainable_ruler/gsm8k_ruler.yaml @@ -57,7 +57,7 @@ explorer: dtype: bfloat16 seed: 42 synchronizer: - sync_style: dynamic_by_explorer + sync_style: explorer_driven sync_method: 'nccl' sync_interval: 5 sync_timeout: 3600 diff --git a/examples/grpo_rubric_as_reward/rubric.yaml b/examples/grpo_rubric_as_reward/rubric.yaml index 48e6909ba03..4e5cfd216f1 100644 --- a/examples/grpo_rubric_as_reward/rubric.yaml +++ b/examples/grpo_rubric_as_reward/rubric.yaml @@ -57,7 +57,7 @@ explorer: max_response_tokens: 1024 max_model_len: 20480 synchronizer: - sync_style: dynamic_by_explorer + sync_style: explorer_driven sync_method: 'nccl' sync_interval: 5 sync_timeout: 3600 diff --git a/examples/ppo_countdown_exp_replay/README.md b/examples/ppo_countdown_exp_replay/README.md index aeeefeb9999..bac9af4f4da 100644 --- a/examples/ppo_countdown_exp_replay/README.md +++ b/examples/ppo_countdown_exp_replay/README.md @@ -27,7 +27,7 @@ Important config parameters for experience replay include: * `reuse_cooldown_time`: delay time (in seconds) before putting sample back into the buffer; must be set explicitly * `priority_fn`: name of the priority function * `priority_fn_args`: additional args for the priority function -* `synchronizer.sync_style`: set to `dynamic_by_explorer`, which allows the trainer to run more training steps as long as the priority queue buffer is non-empty +* `synchronizer.sync_style`: set to `explorer_driven`, which allows the trainer to run more training steps as long as the priority queue buffer is non-empty The priority function used in this example is named `decay_limit_randomization`. 
The logic behind it: diff --git a/examples/ppo_countdown_exp_replay/countdown.yaml b/examples/ppo_countdown_exp_replay/countdown.yaml index c3871bd93df..4b3bd3c44ec 100644 --- a/examples/ppo_countdown_exp_replay/countdown.yaml +++ b/examples/ppo_countdown_exp_replay/countdown.yaml @@ -50,7 +50,7 @@ explorer: seed: 42 synchronizer: sync_method: 'nccl' - sync_style: dynamic_by_explorer # set to "fixed" for baseline + sync_style: explorer_driven # set to "fixed" for baseline sync_interval: 10 sync_timeout: 1200 trainer: diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py index 66980634663..9eb996166a7 100644 --- a/tests/explorer/scheduler_test.py +++ b/tests/explorer/scheduler_test.py @@ -811,7 +811,7 @@ async def test_over_rollout_min_wait(self): self.config.explorer.over_rollout.wait_after_min = 3 self.config.explorer.max_repeat_times_per_runner = None self.config.buffer.batch_size = 4 - self.config.synchronizer.sync_style = SyncStyle.DYNAMIC_BY_EXPLORER + self.config.synchronizer.sync_style = SyncStyle.EXPLORER_DRIVEN self.config.check_and_update() scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()]) await scheduler.start() diff --git a/tests/manager/synchronizer_test.py b/tests/manager/synchronizer_test.py index f0d5415863e..642aa0ef603 100644 --- a/tests/manager/synchronizer_test.py +++ b/tests/manager/synchronizer_test.py @@ -9,9 +9,11 @@ import unittest from copy import deepcopy from datetime import datetime +from multiprocessing import Process from typing import Dict, List import ray +import torch from parameterized import parameterized_class from tests.tools import ( @@ -21,83 +23,82 @@ get_template_config, get_unittest_dataset_config, ) -from trinity.algorithm import ALGORITHM_TYPE from trinity.cli.launcher import both, explore, train from trinity.common.config import Config, ExperienceBufferConfig -from trinity.common.constants import StorageType, SyncMethod, SyncStyle +from trinity.common.constants import RunningStatus, StorageType, SyncMethod, SyncStyle +from trinity.common.experience import Experience from trinity.explorer.explorer import Explorer from trinity.trainer.trainer import Trainer from trinity.utils.log import get_logger logger = get_logger(__name__) -CHECKPOINT_ROOT_DIR = os.path.join(os.path.dirname(__file__), "temp_checkpoint_dir") -def trainer_monkey_patch(config: Config, max_steps: int, intervals: List[int]): - async def new_sample_data(self): - self.logger.info(f"Sample data for step {self.train_step_num + 1} started.") - await asyncio.sleep(0.1) - time.sleep(intervals[self.engine.global_steps - 1]) - self.logger.info(f"Sample data for step {self.train_step_num + 1} finished.") - return [], {}, [] - - async def new_train_step(self, exps) -> Dict: - self.engine.algorithm = ALGORITHM_TYPE.get(config.algorithm.algorithm_type) +def trainer_monkey_patch(train_step_time_list: List[int]): + async def new_train_step(self: Trainer, exps) -> Dict: self.engine.global_steps += 1 self.logger.info(f"Training at step {self.engine.global_steps} started.") - await asyncio.sleep(0.1) - time.sleep(intervals[self.engine.global_steps - 1]) + time.sleep(train_step_time_list[self.engine.global_steps - 1]) metrics = {"actor/step": self.engine.global_steps} self.logger.info(f"Training at step {self.engine.global_steps} finished.") return metrics Trainer.train_step = new_train_step - Trainer._sample_data = new_sample_data -def explorer_monkey_patch(config: Config, max_steps: int, intervals: List[int]): - async def 
new_explore_step(self): - if self.explore_step_num == max_steps: - await self.save_checkpoint(sync_weight=False) +def explorer_monkey_patch(explore_step_time_list: List[int]): + async def new_explore_step(self: Explorer): + if self.explore_step_num >= len(explore_step_time_list): + await self.finish_current_steps() + await self.save_checkpoint() + await self.synchronizer.set_explorer_status.remote( + RunningStatus.STOPPED, + old_status=RunningStatus.RUNNING, + ) + await self.shutdown() + return False self.explore_step_num += 1 - return self.explore_step_num <= max_steps - - def wrapper(old_save_checkpoint): - async def new_save_checkpoint(self, sync_weight: bool = False): - await asyncio.sleep(intervals.pop(0)) - await old_save_checkpoint(self, sync_weight) - - return new_save_checkpoint + return True - async def new_finish_explore_step(self, step: int, model_version: int) -> None: + async def new_finish_explore_step(self: Explorer, step: int, model_version: int) -> None: metric = {"rollout/model_version": model_version} + await asyncio.sleep(explore_step_time_list[step - 1]) + dummy_exps = [ + Experience( + tokens=torch.tensor([0, 0, 0]), + info={"model_version": model_version}, + ) + for _ in range(self.config.buffer.train_batch_size) + ] + await self.experience_pipeline.process.remote(dummy_exps) self.monitor.log(metric, step=step) Explorer.explore_step = new_explore_step - Explorer.save_checkpoint = wrapper(Explorer.save_checkpoint) Explorer._finish_explore_step = new_finish_explore_step -def run_trainer(config: Config, max_steps: int, intervals: List[int]) -> None: +def run_trainer(config: Config, train_step_time_list: List[int]) -> None: ray.init(ignore_reinit_error=True, namespace=config.ray_namespace) - trainer_monkey_patch(config, max_steps, intervals) + trainer_monkey_patch(train_step_time_list) train(config) ray.shutdown() -def run_explorer(config: Config, max_steps: int, intervals: List[int]) -> None: +def run_explorer(config: Config, explore_step_time_list: List[int]) -> None: ray.init(ignore_reinit_error=True, namespace=config.ray_namespace) - explorer_monkey_patch(config, max_steps, intervals) + explorer_monkey_patch(explore_step_time_list) explore(config) ray.shutdown() def run_both( - config: Config, max_steps: int, trainer_intervals: List[int], explorer_intervals: List[int] + config: Config, + train_step_time_list: List[int], + explore_step_time_list: List[int], ) -> None: ray.init(ignore_reinit_error=True, namespace=config.ray_namespace) - trainer_monkey_patch(config, max_steps, trainer_intervals) - explorer_monkey_patch(config, max_steps, explorer_intervals) + trainer_monkey_patch(train_step_time_list) + explorer_monkey_patch(explore_step_time_list) both(config) ray.shutdown() @@ -108,10 +109,88 @@ def setUp(self): multiprocessing.set_start_method("spawn", force=True) self.process_list = [] + self.config = get_template_config() + self.config.project = "unittest" + self.config.name = f"test_synchronizer_{datetime.now().strftime('%Y%m%d%H%M%S')}" + self.config.checkpoint_root_dir = get_checkpoint_path() + self.config.buffer.total_epochs = 1 + self.config.buffer.batch_size = 4 + self.config.algorithm.repeat_times = 8 + self.config.cluster.gpu_per_node = 2 + self.config.cluster.node_num = 1 + self.config.model.model_path = get_model_path() + self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown") + experience_buffer = ExperienceBufferConfig( + name="exp_buffer", + storage_type=StorageType.QUEUE.value, + ) + 
self.config.buffer.trainer_input.experience_buffer = deepcopy(experience_buffer) + self.config.synchronizer.sync_method = getattr(self, "sync_method", SyncMethod.NCCL) + self.config.synchronizer.sync_style = self.sync_style + self.config.synchronizer.sync_interval = 2 + self.config.monitor.monitor_type = "tensorboard" + + self.config.trainer.total_steps = len(self.train_step_time_list) + self.config.trainer.save_interval = 100 + self.config.buffer.train_batch_size = ( + self.config.buffer.batch_size * self.config.algorithm.repeat_times + ) + + self.config.explorer.rollout_model.engine_num = 1 + self.config.explorer.rollout_model.tensor_parallel_size = 1 + self.config.buffer.explorer_output = deepcopy(experience_buffer) + + def _start_process(self, target_func, *args) -> Process: + process = Process(target=target_func, args=args) + process.start() + self.process_list.append(process) + return process + + def start_train_process(self, config: Config) -> Process: + return self._start_process(run_trainer, config, self.train_step_time_list) + + def start_explore_process(self, config: Config, explore_step_time_list: List[int]) -> Process: + return self._start_process(run_explorer, config, explore_step_time_list) + + def start_both_process(self, config: Config, explore_step_time_list: List[int]) -> Process: + return self._start_process( + run_both, + config, + self.train_step_time_list, + explore_step_time_list, + ) + + def join_process(self, process: Process, process_name: str, timeout: int = 300): + process.join(timeout=timeout) + if process.is_alive(): + self.fail(f"Process [{process_name}] is still alive after timeout") + + def wait_trainer_started(self, ray_namespace: str): + ray.init(ignore_reinit_error=True) + while True: + try: + ray.get_actor("queue-exp_buffer", namespace=ray_namespace) + break + except ValueError: + print("waiting for trainer to start.") + time.sleep(5) + return ray.get_actor("synchronizer", namespace=ray_namespace) + + def _check_metrics( + self, + config: Config, + module: str, + metric_check_dict: Dict[str, float], + ): + parser = TensorBoardParser(os.path.join(config.monitor.cache_dir, "tensorboard", module)) + for metric_name, metric_value in metric_check_dict.items(): + metric_list = parser.metric_list(metric_name) + self.assertEqual(parser.metric_max_step(metric_list[0]), metric_value) + def tearDown(self): ray.shutdown(_exiting_interpreter=True) - if os.path.exists(CHECKPOINT_ROOT_DIR): - shutil.rmtree(CHECKPOINT_ROOT_DIR, ignore_errors=True) + if os.path.exists(self.config.checkpoint_root_dir): + shutil.rmtree(self.config.checkpoint_root_dir, ignore_errors=True) for process in self.process_list: if process.is_alive(): process.terminate() @@ -121,275 +200,142 @@ def tearDown(self): process.join() +@parameterized_class( + [ + { + "sync_method": SyncMethod.NCCL, # will be converted to CHECKPOINT + }, + { + "sync_method": SyncMethod.MEMORY, + }, + ] +) class TestSynchronizerExit(BaseTestSynchronizer): + def setUp(self): + self.sync_style = SyncStyle.FIXED + self.train_step_time_list = [2, 1, 2, 1, 2, 1, 2, 1] + self.explore_step_time_list = [2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5] + super().setUp() + def test_synchronizer(self): - config = get_template_config() - config.project = "unittest" - config.name = f"test_synchronizer_{datetime.now().strftime('%Y%m%d%H%M%S')}" - config.checkpoint_root_dir = get_checkpoint_path() - config.buffer.total_epochs = 1 - config.buffer.batch_size = 4 - config.cluster.gpu_per_node = 2 - config.cluster.node_num = 1 - 
config.model.model_path = get_model_path() - config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown") - config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig( - name="exp_buffer", - storage_type=StorageType.QUEUE.value, - ) - config.synchronizer.sync_method = SyncMethod.CHECKPOINT - config.synchronizer.sync_style = SyncStyle.DYNAMIC_BY_EXPLORER - config.synchronizer.sync_interval = 2 - config.trainer.save_interval = 100 - config.monitor.monitor_type = "tensorboard" - trainer_config = deepcopy(config) + trainer_config = deepcopy(self.config) + trainer_config.cluster.gpu_per_node = 1 trainer_config.mode = "train" - trainer_config.buffer.train_batch_size = 4 trainer_config.check_and_update() - explorer1_config = deepcopy(config) - explorer1_config.mode = "explore" - explorer1_config.explorer.name = "explorer1" - explorer1_config.explorer.rollout_model.engine_num = 1 - explorer1_config.explorer.rollout_model.tensor_parallel_size = 1 - explorer1_config.buffer.explorer_output = ExperienceBufferConfig( - name="exp_buffer", - storage_type=StorageType.QUEUE.value, - ) - explorer1_config.check_and_update() + explorer_config = deepcopy(self.config) + explorer_config.mode = "explore" + explorer_config.check_and_update() - trainer_process = multiprocessing.Process( - target=run_trainer, args=(trainer_config, 8, [2, 1, 2, 1, 2, 1, 2, 1]) - ) - trainer_process.start() - self.process_list.append(trainer_process) - - ray.init(ignore_reinit_error=True) - while True: - try: - synchronizer = ray.get_actor("synchronizer", namespace=trainer_config.ray_namespace) - break - except ValueError: - print("waiting for trainer to start.") - time.sleep(5) - explorer_process_1 = multiprocessing.Process( - target=run_explorer, - args=(explorer1_config, 8, [0, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5]), - ) - explorer_process_1.start() - self.process_list.append(explorer_process_1) + trainer_process = self.start_train_process(trainer_config) + synchronizer = self.wait_trainer_started(trainer_config.ray_namespace) + explorer_process = self.start_explore_process(explorer_config, self.explore_step_time_list) self.assertEqual( synchronizer, ray.get_actor("synchronizer", namespace=trainer_config.ray_namespace) ) for _ in range(12): # Wait for up to 60 seconds try: - explorer1 = ray.get_actor("explorer1", namespace=trainer_config.ray_namespace) - ray.get(explorer1.is_alive.remote()) + explorer = ray.get_actor("explorer", namespace=trainer_config.ray_namespace) + ray.get(explorer.is_alive.remote()) break except ValueError: - print("waiting for explorer1 to start.") + print("waiting for explorer to start.") time.sleep(5) - trainer_process.join(timeout=200) + self.join_process(trainer_process, "trainer") self.assertEqual( synchronizer, ray.get_actor("synchronizer", namespace=trainer_config.ray_namespace) ) - explorer_process_1.join(timeout=200) + self.join_process(explorer_process, "explorer") time.sleep(6) with self.assertRaises(ValueError): ray.get_actor("synchronizer", namespace=trainer_config.ray_namespace) @parameterized_class( - ( - "sync_method", - "sync_style", - "max_steps", - "trainer_intervals", - "explorer1_intervals", - "explorer2_intervals", - ), [ - ( - SyncMethod.CHECKPOINT, - SyncStyle.FIXED, - 8, - [2, 1, 2, 1, 2, 1, 2, 1], - [0, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5], - [0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], - ), - ( - SyncMethod.CHECKPOINT, - SyncStyle.DYNAMIC_BY_EXPLORER, - 8, - [2, 1, 2, 1, 2, 1, 2, 1], - [0, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5], - [0, 0.5, 0.5, 
0.5, 0.5, 0.5, 0.5, 0.5, 0.5], - ), - ( - SyncMethod.MEMORY, - SyncStyle.FIXED, - 8, - [2, 1, 2, 1, 2, 1, 2, 1], - [0, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5], - [0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], - ), - ( - SyncMethod.MEMORY, - SyncStyle.DYNAMIC_BY_EXPLORER, - 8, - [2, 1, 2, 1, 2, 1, 2, 1], - [0, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5], - [0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], - ), - ], + { + "sync_method": sync_method, + "sync_style": sync_style, + "train_step_time_list": [2, 1, 2, 1, 2, 1, 2, 1], + "explore_step_time_lists": [ + [2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5], + [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], + ], + "batch_size_list": [20, 12], + } + for sync_method in [SyncMethod.CHECKPOINT, SyncMethod.MEMORY] + for sync_style in [SyncStyle.FIXED, SyncStyle.EXPLORER_DRIVEN, SyncStyle.TRAINER_DRIVEN] + ] ) class TestStateDictBasedSynchronizer(BaseTestSynchronizer): def test_synchronizer(self): - config = get_template_config() - config.project = "unittest" - config.name = f"test_synchronizer_{datetime.now().strftime('%Y%m%d%H%M%S')}" - config.checkpoint_root_dir = get_checkpoint_path() - config.buffer.total_epochs = 1 - config.buffer.batch_size = 4 - config.cluster.gpu_per_node = 2 - config.cluster.node_num = 1 - config.model.model_path = get_model_path() - config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown") - config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig( - name="exp_buffer", - storage_type=StorageType.QUEUE.value, - ) - config.synchronizer.sync_method = self.sync_method - config.synchronizer.sync_style = self.sync_style - config.synchronizer.sync_interval = 2 - config.trainer.save_interval = 100 - config.trainer.total_steps = self.max_steps - config.monitor.monitor_type = "tensorboard" - trainer_config = deepcopy(config) + trainer_config = deepcopy(self.config) trainer_config.mode = "train" - trainer_config.buffer.train_batch_size = 4 trainer_config.check_and_update() - explorer1_config = deepcopy(config) - explorer1_config.mode = "explore" - explorer1_config.explorer.name = "explorer1" - explorer1_config.explorer.rollout_model.engine_num = 1 - explorer1_config.explorer.rollout_model.tensor_parallel_size = 1 - explorer1_config.buffer.explorer_output = ExperienceBufferConfig( - name="exp_buffer", - storage_type=StorageType.QUEUE.value, - ) - explorer2_config = deepcopy(explorer1_config) - explorer2_config.explorer.name = "explorer2" - explorer1_config.check_and_update() - explorer2_config.check_and_update() - - trainer_process = multiprocessing.Process( - target=run_trainer, args=(trainer_config, self.max_steps, self.trainer_intervals) - ) - trainer_process.start() - self.process_list.append(trainer_process) + trainer_process = self.start_train_process(trainer_config) + _ = self.wait_trainer_started(trainer_config.ray_namespace) - ray.init(ignore_reinit_error=True) - while True: - try: - ray.get_actor("queue-exp_buffer", namespace=trainer_config.ray_namespace) - break - except ValueError: - print("waiting for trainer to start.") - time.sleep(5) - explorer_process_1 = multiprocessing.Process( - target=run_explorer, - args=(explorer1_config, self.max_steps, self.explorer1_intervals), + assert len(self.batch_size_list) == len(self.explore_step_time_lists), ( + f"{len(self.batch_size_list)=} not equal to {len(self.explore_step_time_lists)=}, " + "please check the test case" ) - explorer_process_1.start() - self.process_list.append(explorer_process_1) - explorer_process_2 = multiprocessing.Process( - 
target=run_explorer, args=(explorer2_config, self.max_steps, self.explorer2_intervals) + assert sum(self.batch_size_list) == trainer_config.buffer.train_batch_size, ( + f"{sum(self.batch_size_list)=} not equal to {trainer_config.buffer.train_batch_size}, " + "please check the test case" ) - explorer_process_2.start() - self.process_list.append(explorer_process_2) - - explorer_process_1.join(timeout=200) - explorer_process_2.join(timeout=200) - trainer_process.join(timeout=200) + explorer_configs, explorer_processes = [], [] + for i, (explore_step_time_list, batch_size) in enumerate( + zip(self.explore_step_time_lists, self.batch_size_list) + ): + explorer_config = deepcopy(self.config) + explorer_config.mode = "explore" + explorer_config.explorer.name = f"explorer_{i}" + explorer_config.explorer.rollout_model.engine_num = 1 + explorer_config.explorer.rollout_model.tensor_parallel_size = 1 + explorer_config.buffer.train_batch_size = batch_size + explorer_config.check_and_update() + explorer_configs.append(explorer_config) + explorer_processes.append( + self.start_explore_process(explorer_config, explore_step_time_list) + ) + + self.join_process(trainer_process, "trainer") + for i, explore_process in enumerate(explorer_processes): + self.join_process(explore_process, f"explorer_{i}") # check the tensorboard - parser = TensorBoardParser( - os.path.join(trainer_config.monitor.cache_dir, "tensorboard", "trainer") - ) - actor_metrics = parser.metric_list("actor") - self.assertEqual(parser.metric_max_step(actor_metrics[0]), 8) - parser = TensorBoardParser( - os.path.join(explorer1_config.monitor.cache_dir, "tensorboard", "explorer1") - ) - rollout_metrics = parser.metric_list("rollout") - self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8) - parser = TensorBoardParser( - os.path.join(explorer2_config.monitor.cache_dir, "tensorboard", "explorer2") - ) - rollout_metrics = parser.metric_list("rollout") - self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8) + self._check_metrics(trainer_config, "trainer", {"actor": len(self.train_step_time_list)}) + for i, (explorer_config, explore_step_time_list) in enumerate( + zip(explorer_configs, self.explore_step_time_lists) + ): + self._check_metrics( + explorer_config, f"explorer_{i}", {"rollout": len(explore_step_time_list)} + ) @parameterized_class( - ("sync_style", "max_steps", "trainer_intervals", "explorer_intervals"), [ - ( - SyncStyle.FIXED, - 8, - [2, 1, 2, 1, 2, 1, 2, 1], - [0, 2.5, 2.5, 2.5, 2.5, 0], - ), - ( - SyncStyle.DYNAMIC_BY_EXPLORER, - 8, - [2, 1, 2, 1, 2, 1, 2, 1], - [0, 0.5, 0.5, 0.5, 0.5, 0], - ), + { + "sync_style": sync_style, + "train_step_time_list": [2, 2, 1, 1, 2, 2, 1, 1], + "explore_step_time_list": [1, 1, 2, 2, 1, 1, 2, 2], + } + for sync_style in [SyncStyle.FIXED, SyncStyle.EXPLORER_DRIVEN, SyncStyle.TRAINER_DRIVEN] ], ) class TestNCCLBasedSynchronizer(BaseTestSynchronizer): def test_synchronizer(self): - config = get_template_config() - config.project = "unittest" - config.name = f"test_synchronizer_{datetime.now().strftime('%Y%m%d%H%M%S')}" - config.checkpoint_root_dir = get_checkpoint_path() - config.buffer.total_epochs = 1 - config.buffer.batch_size = 4 - config.trainer.total_steps = self.max_steps - config.model.model_path = get_model_path() - config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown") - config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig( - name="exp_buffer", - storage_type=StorageType.QUEUE.value, - ) - config.synchronizer.sync_method = 
SyncMethod.NCCL - config.synchronizer.sync_style = self.sync_style - config.synchronizer.sync_interval = 2 - config.trainer.save_interval = 100 - config.monitor.monitor_type = "tensorboard" - config.mode = "both" - config.check_and_update() + self.config.mode = "both" + self.config.check_and_update() # TODO: test more interval cases - both_process = multiprocessing.Process( - target=run_both, - args=(config, self.max_steps, self.trainer_intervals, self.explorer_intervals), - ) - both_process.start() - self.process_list.append(both_process) - both_process.join(timeout=200) + both_process = self.start_both_process(self.config, self.explore_step_time_list) + self.join_process(both_process, "both") # check the tensorboard - parser = TensorBoardParser(os.path.join(config.monitor.cache_dir, "tensorboard", "trainer")) - actor_metrics = parser.metric_list("actor") - self.assertEqual(parser.metric_max_step(actor_metrics[0]), 8) - parser = TensorBoardParser( - os.path.join(config.monitor.cache_dir, "tensorboard", "explorer") - ) - rollout_metrics = parser.metric_list("rollout") - self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8) + self._check_metrics(self.config, "trainer", {"actor": len(self.train_step_time_list)}) + self._check_metrics(self.config, "explorer", {"rollout": len(self.explore_step_time_list)}) diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 2a795e74973..95a689f5bed 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -597,7 +597,7 @@ def test_fully_async_mode(self): ) config.buffer.trainer_input.experience_buffer.replay_buffer.enable = self.use_priority_queue config.synchronizer.sync_method = SyncMethod.CHECKPOINT - config.synchronizer.sync_style = SyncStyle.DYNAMIC_BY_EXPLORER + config.synchronizer.sync_style = SyncStyle.EXPLORER_DRIVEN config.synchronizer.sync_interval = 8 config.monitor.monitor_type = "tensorboard" trainer_config = deepcopy(config) @@ -1378,7 +1378,7 @@ def test_trainer(self): self.config.algorithm.advantage_fn_args = { "epsilon": 1e-6, } - self.config.synchronizer.sync_style = SyncStyle.DYNAMIC_BY_EXPLORER + self.config.synchronizer.sync_style = SyncStyle.EXPLORER_DRIVEN self.config.synchronizer.sync_interval = 1 self.config.check_and_update() both(self.config) diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py index 53add1872a1..3f82520988a 100644 --- a/trinity/cli/launcher.py +++ b/trinity/cli/launcher.py @@ -254,9 +254,9 @@ def debug( load_plugins() config = load_config(config_path) config.mode = "explore" + config.ray_namespace = DEBUG_NAMESPACE config.check_and_update() sys.path.insert(0, os.getcwd()) - config.ray_namespace = DEBUG_NAMESPACE ray.init( namespace=config.ray_namespace, runtime_env={"env_vars": config.get_envs()}, diff --git a/trinity/common/config.py b/trinity/common/config.py index b68f8676636..c392bb60dc8 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -530,7 +530,7 @@ class InferenceModelConfig: chat_template: Optional[str] = None # For Qwen3 - enable_thinking: bool = False + enable_thinking: Optional[bool] = None # For history recording enable_history: bool = False diff --git a/trinity/common/config_validator.py b/trinity/common/config_validator.py index 0bc4e813daf..fdad4d5856a 100644 --- a/trinity/common/config_validator.py +++ b/trinity/common/config_validator.py @@ -630,7 +630,7 @@ def validate(self, config: Config) -> None: if config.synchronizer.sync_style == SyncStyle.FIXED: raise ValueError( "over_rollout_ratio is not 
compatible with fixed sync_style, please set " - "`synchronizer.sync_style` to `dynamic_by_explorer` or `dynamic_by_trainer`." + "`synchronizer.sync_style` to `explorer_driven` or `trainer_driven`." ) self._validate_lora(config) diff --git a/trinity/common/constants.py b/trinity/common/constants.py index 183702927bd..644cef235a7 100644 --- a/trinity/common/constants.py +++ b/trinity/common/constants.py @@ -33,7 +33,10 @@ class CaseInsensitiveEnumMeta(EnumMeta): + name_aliases = {} + def __getitem__(cls, name): + name = cls.name_aliases.get(name.lower(), name) return super().__getitem__(name.upper()) def __getattr__(cls, name): @@ -42,6 +45,7 @@ def __getattr__(cls, name): return super().__getattr__(name) def __call__(cls, value, *args, **kwargs): + value = cls.name_aliases.get(value.lower(), value) return super().__call__(value.lower(), *args, **kwargs) @@ -65,15 +69,10 @@ class StorageType(CaseInsensitiveEnum): class SyncMethodEnumMeta(CaseInsensitiveEnumMeta): - def __call__(cls, value, *args, **kwargs): - if value == "online": - value = "nccl" - elif value == "offline": - value = "checkpoint" - try: - return super().__call__(value, *args, **kwargs) - except Exception: - raise ValueError(f"Invalid SyncMethod: {value}") + name_aliases = { + "online": "nccl", + "offline": "checkpoint", + } class SyncMethod(CaseInsensitiveEnum, metaclass=SyncMethodEnumMeta): @@ -89,7 +88,6 @@ class RunningStatus(Enum): RUNNING = "running" REQUIRE_SYNC = "require_sync" - WAITING_SYNC = "waiting_sync" STOPPED = "stopped" @@ -102,10 +100,17 @@ class OpType(Enum): DIV = "div" -class SyncStyle(CaseInsensitiveEnum): +class SyncStyleEnumMeta(CaseInsensitiveEnumMeta): + name_aliases = { + "dynamic_by_explorer": "explorer_driven", + "dynamic_by_trainer": "trainer_driven", + } + + +class SyncStyle(CaseInsensitiveEnum, metaclass=SyncStyleEnumMeta): FIXED = "fixed" - DYNAMIC_BY_TRAINER = "dynamic_by_trainer" - DYNAMIC_BY_EXPLORER = "dynamic_by_explorer" + TRAINER_DRIVEN = "trainer_driven" + EXPLORER_DRIVEN = "explorer_driven" class SaveStrategy(CaseInsensitiveEnum): diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index 86305b6658d..676df1bf0ce 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ -104,9 +104,6 @@ def apply_chat_template( messages: List[dict], ) -> str: assert tokenizer_or_processor is not None, "tokenizer_or_processor must be provided." - if self.chat_template is None: - assert self.tokenizer is not None, "self.tokenizer must be initialized." 
- self.chat_template = self.tokenizer.get_chat_template() if messages[-1]["role"] == "assistant": prompt = tokenizer_or_processor.apply_chat_template( @@ -173,8 +170,6 @@ async def convert_messages_to_experience( """ if self.tokenizer is None: await self._initialize_tokenizer() - if self.chat_template is None: - self.chat_template = self.tokenizer.get_chat_template() token_ids, action_mask, prompt_length = self.action_mask_method( tokenizer=self.tokenizer, messages=messages, diff --git a/trinity/common/models/utils.py b/trinity/common/models/utils.py index c2c76519c98..00fa8e089bd 100644 --- a/trinity/common/models/utils.py +++ b/trinity/common/models/utils.py @@ -1,11 +1,9 @@ # -*- coding: utf-8 -*- import os import re -from concurrent.futures import ThreadPoolExecutor -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, List, Optional, Tuple, Union import torch -from torch.distributed._tensor import DTensor, Placement, Shard from trinity.common.config import TrainerConfig from trinity.utils.log import get_logger @@ -223,18 +221,6 @@ def load_state_dict(checkpoint_dir: str, config: TrainerConfig) -> Union[dict, T raise NotImplementedError(f"Unsupported trainer type {config.trainer_type}") -# copy from verl/scripts/model_merger.py -def merge_by_placement(tensors: List[torch.Tensor], placement: Placement): - if placement.is_replicate(): - return tensors[0] - elif placement.is_partial(): - raise NotImplementedError("Partial placement is not supported yet") - elif placement.is_shard(): - return torch.cat(tensors, dim=placement.dim).contiguous() - else: - raise ValueError(f"Unsupported placement: {placement}") - - def get_verl_checkpoint_info( checkpoint_path: str, step_num: Optional[int] = None, raise_error: bool = True ) -> Tuple[str, int]: @@ -271,107 +257,36 @@ def get_verl_checkpoint_info( return path, step_num -# copy from verl/scripts/model_merger.py +# modified from verl/model_merger/fsdp_model_merger.py def load_fsdp_state_dict_from_verl_checkpoint(checkpoint_path: str) -> dict: # noqa: C901 """Load state dict from a Verl checkpoint.""" + from verl.model_merger.base_model_merger import ModelMergerConfig + from verl.model_merger.fsdp_model_merger import FSDPModelMerger + logger = get_logger(__name__) - logger.info(f"Loading state dict from {checkpoint_path}") - assert not checkpoint_path.endswith( - "huggingface" - ), "The local_dir should not end with huggingface" - - # copy rank zero to find the shape of (dp, fsdp) - rank = 0 - world_size = 0 - for filename in os.listdir(checkpoint_path): - match = re.match(r"model_world_size_(\d+)_rank_0\.pt", filename) - if match: - world_size = match.group(1) - break - assert world_size, "No model file with the proper format" - - state_dict = torch.load( - os.path.join(checkpoint_path, f"model_world_size_{world_size}_rank_{rank}.pt"), - map_location="cpu", + config = ModelMergerConfig( + operation="merge", + backend="fsdp", + trust_remote_code=False, + is_value_model=False, + local_dir=checkpoint_path, + hf_model_config_path=os.path.join(checkpoint_path, "huggingface"), ) - pivot_key = sorted(list(state_dict.keys()))[0] - weight = state_dict[pivot_key] - assert isinstance(weight, torch.distributed._tensor.DTensor) - # get sharding info - device_mesh = weight.device_mesh - mesh = device_mesh.mesh - mesh_dim_names = device_mesh.mesh_dim_names + merger = FSDPModelMerger(config) - logger.info(f"Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}") - - assert mesh_dim_names in (("fsdp",),), 
f"Unsupported mesh_dim_names {mesh_dim_names}" + world_size = merger._get_world_size() + rank_zero_state_dict = merger._load_rank_zero_state_dict(world_size) - if "tp" in mesh_dim_names: - # fsdp * tp - total_shards = mesh.shape[-1] * mesh.shape[-2] - mesh_shape = (mesh.shape[-2], mesh.shape[-1]) - else: - # fsdp - total_shards = mesh.shape[-1] - mesh_shape = (mesh.shape[-1],) + mesh, mesh_dim_names = merger._extract_device_mesh_info(rank_zero_state_dict, world_size) + logger.info(f"Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}") + total_shards, mesh_shape = merger._calculate_shard_configuration(mesh, mesh_dim_names) logger.info(f"Processing model shards with {total_shards} {mesh_shape} in total") - model_state_dict_lst = [] - model_state_dict_lst.append(state_dict) - model_state_dict_lst.extend([""] * (total_shards - 1)) - - def process_one_shard(rank): - model_path = os.path.join(checkpoint_path, f"model_world_size_{world_size}_rank_{rank}.pt") - state_dict = torch.load(model_path, map_location="cpu", weights_only=False) - model_state_dict_lst[rank] = state_dict # noqa: F821 - return state_dict - - with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor: # type: ignore - for rank in range(1, total_shards): - executor.submit(process_one_shard, rank) - state_dict = {} - param_placements: Dict[str, List[Placement]] = {} - keys = set(model_state_dict_lst[0].keys()) - for key in keys: - state_dict[key] = [] - for model_state_dict in model_state_dict_lst: - try: - tensor = model_state_dict.pop(key) - except: # noqa: E722 - logger.info("-" * 30) - logger.info(model_state_dict.keys()) - if isinstance(tensor, DTensor): - state_dict[key].append(tensor._local_tensor.bfloat16()) - placements = tuple(tensor.placements) - # replicated placement at dp dimension can be discarded - if mesh_dim_names[0] == "dp": - placements = placements[1:] - if key not in param_placements: - param_placements[key] = placements - else: - assert param_placements[key] == placements - else: - state_dict[key] = tensor.bfloat16() - - del model_state_dict_lst - - for key in sorted(state_dict): - if not isinstance(state_dict[key], list): - logger.info(f"No need to merge key {key}") - continue - # merge shards - placements: Tuple[Shard] = param_placements[key] # type: ignore - if len(mesh_shape) == 1: - # 1-D list, FSDP without TP - assert len(placements) == 1 - shards = state_dict[key] - state_dict[key] = merge_by_placement(shards, placements[0]) - else: - # 2-D list, FSDP + TP - raise NotImplementedError("FSDP + TP is not supported yet") - - return state_dict + merged_state_dict = merger._load_and_merge_state_dicts( + world_size, total_shards, mesh_shape, mesh_dim_names + ) + return merged_state_dict def load_huggingface_state_dict(checkpoint_path: str): @@ -382,98 +297,51 @@ def load_huggingface_state_dict(checkpoint_path: str): def get_megatron_converter(checkpoint_path: str): - from megatron.core import mpu - from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed - from transformers import AutoConfig + import builtins + from contextlib import contextmanager + from verl.model_merger.base_model_merger import ModelMergerConfig from verl.model_merger.megatron_model_merger import MegatronModelMerger - from verl.utils.device import get_device_name, get_torch_device + # modified from verl/model_merger/megatron_model_merger.py class MegatronStateDictConverter(MegatronModelMerger): def __init__(self, config: ModelMergerConfig): - self.config = config - self.hf_model_config_path = 
config.hf_model_config_path - self.model_config = AutoConfig.from_pretrained( - self.hf_model_config_path, trust_remote_code=self.config.trust_remote_code - ) - - self.rank = 0 - self.world_size = 1 - local_rank = 0 - get_torch_device().set_device(f"{get_device_name()}:{local_rank}") - - mpu.initialize_model_parallel( - tensor_model_parallel_size=1, - pipeline_model_parallel_size=self.world_size, - virtual_pipeline_model_parallel_size=None, - context_parallel_size=1, - expert_model_parallel_size=1, - ) - model_parallel_cuda_manual_seed(0) - self.hf_config = AutoConfig.from_pretrained( - self.config.hf_model_config_path, trust_remote_code=self.config.trust_remote_code - ) + original_init_process_group = torch.distributed.init_process_group + original_get_rank = torch.distributed.get_rank + original_get_world_size = torch.distributed.get_world_size + torch.distributed.init_process_group = lambda *args, **kwargs: None + torch.distributed.get_rank = lambda: 0 + torch.distributed.get_world_size = lambda: 1 self.logger = get_logger(__name__) - self.logger.debug(self.hf_config) - - self.params_mapping = { - # megatron core gpt model name, huggingface model name - # NOTICE: It's a little bit tricky, when 2 keys have the same prefix, we need to make sure the - # longer key within the containing relationship is processed first. - "embedding.word_embeddings": "model.embed_tokens", - # input layer norm for dpskv3 - "input_layernorm.weight": "input_layernorm.weight", - "input_layernorm.bias": "input_layernorm.bias", - # attn - "self_attention.linear_qkv.layer_norm_weight": "input_layernorm.weight", - "self_attention.linear_qkv.layer_norm_bias": "input_layernorm.bias", - "self_attention.linear_qkv": "self_attn.qkv_proj", - "self_attention.q_layernorm": "self_attn.q_norm", - "self_attention.k_layernorm": "self_attn.k_norm", - "self_attention.linear_proj": "self_attn.o_proj", - # mla - "self_attention.linear_q_proj": "self_attn.q_proj", - "self_attention.linear_q_down_proj": "self_attn.q_a_proj", - "self_attention.linear_q_up_proj.layer_norm_weight": "self_attn.q_a_layernorm.weight", - "self_attention.linear_q_up_proj": "self_attn.q_b_proj", - "self_attention.linear_kv_down_proj": "self_attn.kv_a_proj_with_mqa", - "self_attention.linear_kv_up_proj.layer_norm_weight": "self_attn.kv_a_layernorm.weight", - "self_attention.linear_kv_up_proj": "self_attn.kv_b_proj", - # mlp - "pre_mlp_layernorm": "post_attention_layernorm", - "mlp.linear_fc1.layer_norm_weight": "post_attention_layernorm.weight", - "mlp.linear_fc1.layer_norm_bias": "post_attention_layernorm.bias", - "mlp.linear_fc1": "mlp.gate_up_proj", - "mlp.linear_fc2": "mlp.down_proj", - # moe - "mlp.router.expert_bias": "mlp.gate.e_score_correction_bias", - "mlp.router": "mlp.gate", - "mlp.shared_experts.linear_fc1": "mlp.shared_experts.gate_up_proj", - "mlp.shared_experts.linear_fc2": "mlp.shared_experts.down_proj", - "linear_fc1": "gate_up_proj", - "linear_fc2": "down_proj", - # output - "final_layernorm": "norm", - "output_layer": "lm_head", - } - - if "Qwen2MoeForCausalLM" in self.hf_config.architectures: - self.params_mapping[ - "mlp.shared_experts.linear_fc1" - ] = "mlp.shared_expert.gate_up_proj" - self.params_mapping["mlp.shared_experts.linear_fc2"] = "mlp.shared_expert.down_proj" - self.params_mapping[ - "mlp.shared_experts.gate_weight" - ] = "mlp.shared_expert_gate.weight" + with self._redirect_print_to_logger(): + super().__init__(config) + torch.distributed.init_process_group = original_init_process_group + torch.distributed.get_rank = 
original_get_rank + torch.distributed.get_world_size = original_get_world_size + + @contextmanager + def _redirect_print_to_logger(self): + original_print = builtins.print + + def logger_print(*args, **kwargs): + message = " ".join(str(arg) for arg in args) + self.logger.debug(message) + + builtins.print = logger_print + try: + yield + finally: + builtins.print = original_print def get_state_dict(self, checkpoint_path): self.config.local_dir = checkpoint_path from verl.utils.megatron_utils import get_dist_checkpoint_path - model_ckpt_path = get_dist_checkpoint_path(self.config.local_dir) + with self._redirect_print_to_logger(): + model_ckpt_path = get_dist_checkpoint_path(self.config.local_dir) - model_state_dict = self._load_state_dicts(model_ckpt_path) - merged_state_dict = self._merge_state_dicts(model_state_dict) + model_state_dict = self._load_state_dicts(model_ckpt_path) + merged_state_dict = self._merge_state_dicts(model_state_dict) del model_state_dict return merged_state_dict diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index ef681b6578b..fffc4e17e01 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -638,8 +638,6 @@ def get_lora_request(self, lora_path: Optional[str] = None) -> Any: async def get_message_token_len(self, messages) -> int: if self.tokenizer is None: await self._initialize_tokenizer() - if self.chat_template is None: - self.chat_template = self.tokenizer.get_chat_template() prompt = self.tokenizer.apply_chat_template( messages, tokenize=False, diff --git a/trinity/common/models/vllm_patch/__init__.py b/trinity/common/models/vllm_patch/__init__.py index e97b1da07c7..b9a8ffe0c66 100644 --- a/trinity/common/models/vllm_patch/__init__.py +++ b/trinity/common/models/vllm_patch/__init__.py @@ -31,56 +31,28 @@ def get_api_server( run_api_server_in_ray_actor, ) - return asyncio.create_task( - run_api_server_in_ray_actor( - async_llm, - host=host, - port=port, - logger=logger, - model_path=config.model_path, # type: ignore [arg-type] - enable_auto_tool_choice=config.enable_auto_tool_choice, - tool_call_parser=config.tool_call_parser, - reasoning_parser=config.reasoning_parser, - enable_log_requests=config.enable_log_requests, - chat_template=config.chat_template, - ) - ) elif vllm_version == parse_version("0.12.0"): from trinity.common.models.vllm_patch.api_patch_v12 import ( - run_api_server_in_ray_actor_v12, + run_api_server_in_ray_actor_v12 as run_api_server_in_ray_actor, ) - return asyncio.create_task( - run_api_server_in_ray_actor_v12( - async_llm, - host=host, - port=port, - model_path=config.model_path, # type: ignore [arg-type] - logger=logger, - enable_auto_tool_choice=config.enable_auto_tool_choice, - tool_call_parser=config.tool_call_parser, - reasoning_parser=config.reasoning_parser, - enable_log_requests=config.enable_log_requests, - chat_template=config.chat_template, - ) - ) else: from trinity.common.models.vllm_patch.api_patch_v13 import ( - run_api_server_in_ray_actor_v13, + run_api_server_in_ray_actor_v13 as run_api_server_in_ray_actor, ) - logger.info(f"Using vLLM API patch for version {vllm.__version__}") - return asyncio.create_task( - run_api_server_in_ray_actor_v13( - async_llm, - host=host, - port=port, - model_path=config.model_path, # type: ignore [arg-type] - logger=logger, - enable_auto_tool_choice=config.enable_auto_tool_choice, - tool_call_parser=config.tool_call_parser, - reasoning_parser=config.reasoning_parser, - 
enable_log_requests=config.enable_log_requests, - chat_template=config.chat_template, - ) + logger.info(f"Using vLLM API patch for version {vllm.__version__}") + return asyncio.create_task( + run_api_server_in_ray_actor( + async_llm, + host=host, + port=port, + model_path=config.model_path, # type: ignore [arg-type] + logger=logger, + enable_auto_tool_choice=config.enable_auto_tool_choice, + tool_call_parser=config.tool_call_parser, + reasoning_parser=config.reasoning_parser, + enable_log_requests=config.enable_log_requests, + chat_template=config.chat_template, ) + ) diff --git a/trinity/common/models/vllm_patch/api_patch.py b/trinity/common/models/vllm_patch/api_patch.py index 0a3b02e6542..3b8dec85e7b 100644 --- a/trinity/common/models/vllm_patch/api_patch.py +++ b/trinity/common/models/vllm_patch/api_patch.py @@ -38,6 +38,7 @@ from vllm.outputs import RequestOutput from vllm.transformers_utils.tokenizer import MistralTokenizer from vllm.utils import FlexibleArgumentParser, set_ulimit +from vllm.version import __version__ as VLLM_VERSION from trinity.common.models.vllm_patch import get_vllm_version from trinity.utils.log import get_logger @@ -270,7 +271,9 @@ async def chat_completion_full_generator( # noqa C901 return PatchedChatCompletionResponse(**response_args) -async def run_server_in_ray(args, engine_client): +async def run_server_in_ray(args, engine_client, logger): + logger.info("vLLM API server version %s", VLLM_VERSION) + # workaround to make sure that we bind the port before the engine is set up. # This avoids race conditions with ray. # see https://github.com/vllm-project/vllm/issues/8204 @@ -369,10 +372,10 @@ async def run_api_server_in_ray_actor( cli_args.extend(["--tool-call-parser", tool_call_parser]) if reasoning_parser: cli_args.extend(["--reasoning-parser", reasoning_parser]) + if chat_template: + cli_args.extend(["--chat-template", chat_template]) args = parser.parse_args(cli_args) if vllm_version >= parse_version("0.11.0"): args.structured_outputs_config.reasoning_parser = reasoning_parser - if chat_template: - args.chat_template = chat_template logger.info(f"Starting vLLM OpenAI API server with args: {args}") - await run_server_in_ray(args, async_llm) + await run_server_in_ray(args, async_llm, logger) diff --git a/trinity/common/models/vllm_patch/api_patch_v12.py b/trinity/common/models/vllm_patch/api_patch_v12.py index 14b87108b3e..5ffd4e6cd57 100644 --- a/trinity/common/models/vllm_patch/api_patch_v12.py +++ b/trinity/common/models/vllm_patch/api_patch_v12.py @@ -159,10 +159,10 @@ async def run_api_server_in_ray_actor_v12( cli_args.extend(["--tool-call-parser", tool_call_parser]) if reasoning_parser: cli_args.extend(["--reasoning-parser", reasoning_parser]) + if chat_template: + cli_args.extend(["--chat-template", chat_template]) args = parser.parse_args(cli_args) if vllm_version >= parse_version("0.11.0"): args.structured_outputs_config.reasoning_parser = reasoning_parser - if chat_template: - args.chat_template = chat_template logger.info(f"Starting vLLM OpenAI API server with args: {args}") await run_server_in_ray(args, async_llm, logger) diff --git a/trinity/common/models/vllm_patch/api_patch_v13.py b/trinity/common/models/vllm_patch/api_patch_v13.py index 8f6dea10f4a..d8caa3d47a2 100644 --- a/trinity/common/models/vllm_patch/api_patch_v13.py +++ b/trinity/common/models/vllm_patch/api_patch_v13.py @@ -169,9 +169,9 @@ async def run_api_server_in_ray_actor_v13( cli_args.extend(["--tool-call-parser", tool_call_parser]) if reasoning_parser: 
cli_args.extend(["--reasoning-parser", reasoning_parser]) + if chat_template: + cli_args.extend(["--chat-template", chat_template]) args = parser.parse_args(cli_args) args.structured_outputs_config.reasoning_parser = reasoning_parser - if chat_template: - args.chat_template = chat_template logger.info(f"Starting vLLM OpenAI API server with args: {args}") await run_server_in_ray(args, async_llm, logger) diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index a2c97e362a9..54b953f1e1a 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -46,7 +46,6 @@ def __init__(self, config: Config): ) explorer_state = self.state.load_explorer() self.explore_step_num = explorer_state.get("latest_iteration", 0) - self.last_sync_step = self.explore_step_num self.last_monitored_step = self.explore_step_num self.synchronizer = Synchronizer.get_actor(config) self.config = config @@ -84,7 +83,10 @@ def __init__(self, config: Config): # boradcast to all rollout models self.enable_lora = self.config.explorer.rollout_model.enable_lora self.model_version = -1 - self.last_sync_successful = True + self.sync_offset = config.synchronizer.sync_offset + self.sync_interval = config.synchronizer.sync_interval + self.sync_method = config.synchronizer.sync_method + self.sync_style = config.synchronizer.sync_style self.eval_start_time = None self.explore_start_time = None self.logger.info("Finished initializing Explorer.") @@ -153,7 +155,6 @@ async def _pull_latest_weights(self): self.logger.info("Start to pull latest model weights.") new_version = await self.synchronizer.wait_new_model_state_dict.remote( current_version=self.model_version, - no_wait=(self.config.synchronizer.sync_style != SyncStyle.FIXED), ) if new_version > self.model_version: if self.model_version != -1: @@ -162,13 +163,10 @@ async def _pull_latest_weights(self): *[model.sync_model.remote(new_version) for model in self.models] ) self.model_version = new_version - self.last_sync_step = self.explore_step_num - self.last_sync_successful = True else: self.logger.warning( f"No new model weights found, current version: {self.model_version}" ) - self.last_sync_successful = False async def _nccl_weights_update(self): new_version = await self.synchronizer.ready_to_nccl_sync.remote( @@ -176,14 +174,11 @@ async def _nccl_weights_update(self): ) if new_version is None: self.logger.info("Trainer is not ready to sync weight. Skipping sync weight.") - self.last_sync_successful = False return self.model_version = new_version await asyncio.gather( *[model.sync_model.remote(self.model_version) for model in self.models] ) - self.last_sync_step = self.explore_step_num - self.last_sync_successful = True async def prepare(self) -> None: """Preparation before running.""" @@ -215,7 +210,7 @@ async def prepare(self) -> None: if self.config.explorer.eval_on_startup and self.explore_step_num == 0: await self.eval() - await self.synchronizer.set_explorer_status.remote(RunningStatus.REQUIRE_SYNC) + await self.synchronizer.set_explorer_status.remote(RunningStatus.RUNNING) except Exception as e: self.logger.error(f"Error during explorer preparation: {traceback.format_exc()}") await self.shutdown() @@ -263,42 +258,35 @@ async def explore_step(self) -> bool: tasks = await self.taskset.read_async() except StopAsyncIteration: self.logger.warning("No more tasks to explore. 
Stop exploring.") - await self.save_checkpoint(sync_weight=False) + await self.finish_current_steps() + await self.save_checkpoint() await self.synchronizer.set_explorer_status.remote( RunningStatus.STOPPED, - old_status=( - RunningStatus.RUNNING - if self.last_sync_successful - else RunningStatus.REQUIRE_SYNC - ), + old_status=RunningStatus.RUNNING, ) await self.shutdown() return False - self.scheduler.schedule(tasks, batch_id=self.explore_step_num + 1) self.explore_step_num += 1 + self.scheduler.schedule(tasks, batch_id=self.explore_step_num) return True + async def finish_current_steps(self) -> None: + if self.scheduler: + await self._finish_steps( + self.last_monitored_step + 1, self.explore_step_num, self.model_version + ) + self.last_monitored_step = self.explore_step_num + async def need_sync(self) -> bool: - if self.config.synchronizer.sync_style == SyncStyle.FIXED: - if self.explore_step_num <= self.config.synchronizer.sync_offset: - return False - require_sync = ( - self.explore_step_num - self.config.synchronizer.sync_offset - ) % self.config.synchronizer.sync_interval == 0 - else: - require_sync = False - if self.config.synchronizer.sync_style == SyncStyle.DYNAMIC_BY_EXPLORER: - delta = self.explore_step_num - self.last_sync_step - if delta >= self.config.synchronizer.sync_interval: - require_sync = True + if self.explore_step_num <= self.sync_offset: + return False + require_sync = False + if (self.explore_step_num - self.sync_offset) % self.sync_interval == 0: + await self.finish_current_steps() + if self.sync_style == SyncStyle.TRAINER_DRIVEN and self.sync_method == SyncMethod.NCCL: + require_sync = await self.synchronizer.trainer_requires_sync.remote() else: - require_sync = await ( - self.synchronizer.get_trainer_status.remote() == RunningStatus.REQUIRE_SYNC - ) - if require_sync and self.last_sync_successful: - await self.synchronizer.set_explorer_status.remote( - RunningStatus.REQUIRE_SYNC, old_status=RunningStatus.RUNNING - ) + require_sync = True return require_sync def need_eval(self) -> bool: @@ -361,27 +349,7 @@ async def benchmark(self) -> bool: await self._finish_eval_step(prefix="bench") return True - async def save_checkpoint(self, sync_weight: bool = False) -> None: - if self.scheduler: - if self.explore_step_num == 0: - await self._finish_eval_step(step=0) - else: - await self._finish_steps( - self.last_monitored_step + 1, self.explore_step_num, self.model_version - ) - self.last_monitored_step = self.explore_step_num - - if sync_weight: - # sync weights - self.logger.info(f"Explorer sync_weights at step {self.explore_step_num} started.") - if self.use_nccl_sync: - await self._nccl_weights_update() - else: # pull weights from Synchronizer - await self._pull_latest_weights() - self.logger.info( - f"Explorer sync_weights at step {self.explore_step_num} finished, model version = {self.model_version}." 
- ) - + async def save_checkpoint(self) -> None: # save explore checkpoint self.state.save_explorer( current_step=self.explore_step_num, @@ -391,7 +359,19 @@ async def save_checkpoint(self, sync_weight: bool = False) -> None: async def sync_weight(self) -> None: """Synchronize model weights.""" # call this method before training starts to load the latest model weights - await self.save_checkpoint(sync_weight=True) + if self.scheduler and self.explore_step_num == 0: + await self._finish_eval_step(step=0) + + self.logger.info(f"Explorer sync_weights at step {self.explore_step_num} started.") + if self.use_nccl_sync: + await self._nccl_weights_update() + else: # pull weights from Synchronizer + await self._pull_latest_weights() + self.logger.info( + f"Explorer sync_weights at step {self.explore_step_num} finished, model version = {self.model_version}." + ) + + await self.save_checkpoint() async def _finish_steps(self, start_step: int, end_step: int, model_version: int) -> None: for step in range(start_step, end_step + 1): diff --git a/trinity/explorer/proxy/service.py b/trinity/explorer/proxy/service.py index 868c733c452..214f1a02251 100644 --- a/trinity/explorer/proxy/service.py +++ b/trinity/explorer/proxy/service.py @@ -92,7 +92,6 @@ async def _sync_model_weights(self, index: int) -> None: while time.time() - start_time < self.max_timeout: current_load = await self.models[index].get_current_load() if current_load == 0: - self.models[index].status = RunningStatus.WAITING_SYNC self.logger.info(f"Model {index} begins synchronization.") timeout_flag = False break diff --git a/trinity/manager/config_registry/explorer_config_manager.py b/trinity/manager/config_registry/explorer_config_manager.py index dd23927099f..2808831f8c3 100644 --- a/trinity/manager/config_registry/explorer_config_manager.py +++ b/trinity/manager/config_registry/explorer_config_manager.py @@ -295,7 +295,9 @@ def on_change(): [sync_style.value for sync_style in SyncStyle], help="""`fixed`: The explorer and trainer sync model weights once every `sync_interval` steps. -`dynamic_by_explorer`: The explorer decides to request a sync after `sync_interval` steps.""", +`explorer_driven`: The explorer decides to request a sync after `sync_interval` steps. 
+ +`trainer_driven`: The trainer decides to request a sync after `sync_interval` steps.""", disabled=disabled, on_change=on_change, **kwargs, diff --git a/trinity/manager/synchronizer.py b/trinity/manager/synchronizer.py index 0fcf312b576..71cebdf4b72 100644 --- a/trinity/manager/synchronizer.py +++ b/trinity/manager/synchronizer.py @@ -175,9 +175,9 @@ async def set_trainer_status(self, status: RunningStatus): if status == RunningStatus.STOPPED: self._ready_condition.notify_all() - def get_trainer_status(self) -> RunningStatus: - """Get the current status of the trainer.""" - return self.trainer_status + def trainer_requires_sync(self) -> bool: + """Check if the trainer requires a sync.""" + return self.trainer_status == RunningStatus.REQUIRE_SYNC async def set_explorer_status( self, status: RunningStatus, old_status: Optional[RunningStatus] = None @@ -202,9 +202,9 @@ async def set_explorer_status( self.explorer_status_counts[status] = 0 self.explorer_status_counts[status] += 1 - def get_explorer_status_counts(self) -> Dict[RunningStatus, int]: - """Return the current status counts for all explorers.""" - return self.explorer_status_counts + def explorer_requires_sync(self) -> bool: + """Check if any explorer requires a sync.""" + return self.explorer_status_counts[RunningStatus.REQUIRE_SYNC] > 0 async def set_model_state_dict_with_step_num( self, step_num: Optional[int] = None, world_size: Optional[int] = None @@ -302,7 +302,7 @@ async def setup_weight_sync_group( explorer = ray.get_actor(self.config.explorer.name, namespace=self.config.ray_namespace) await explorer.setup_weight_sync_group.remote(master_address, master_port, state_dict_meta) - async def wait_new_model_state_dict(self, current_version: int, no_wait: bool = False) -> int: + async def wait_new_model_state_dict(self, current_version: int) -> int: """ Wait until a new model state is available. @@ -316,20 +316,28 @@ async def wait_new_model_state_dict(self, current_version: int, no_wait: bool = assert ( self.model_version >= current_version ), f"The model version in Synchronizer ({self.model_version}) should be no smaller than that in Explorer ({current_version})!" + await self.set_explorer_status( + RunningStatus.REQUIRE_SYNC, old_status=RunningStatus.RUNNING ) if self.model_version == current_version: - if not no_wait and self.trainer_status != RunningStatus.STOPPED: - # TODO: explorer need support no wait - # TODO: handle timeout + if self.trainer_status != RunningStatus.STOPPED: await asyncio.wait_for( self._ready_condition.wait(), timeout=self.config.synchronizer.sync_timeout, ) - if self.model_version > current_version: - await self.set_explorer_status( - RunningStatus.RUNNING, old_status=RunningStatus.REQUIRE_SYNC - ) + await self.set_explorer_status( + RunningStatus.RUNNING, old_status=RunningStatus.REQUIRE_SYNC + ) return self.model_version + async def notify_no_new_model_state_dict(self) -> None: + """ + Notify explorers waiting in `wait_new_model_state_dict` that there is no new model state. + """ + async with self._ready_condition: + self._ready_condition.notify_all() + async def get_latest_model_version(self) -> int: """ Get the latest model version available in the synchronizer. 
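[Editor's note, not part of the patch] The synchronizer hunks above replace status polling with a condition-variable handshake: an explorer blocks in `wait_new_model_state_dict` until the trainer either publishes a newer model version or calls `notify_no_new_model_state_dict`, which wakes the waiters without bumping the version. The following self-contained Python sketch shows the same pattern; `MiniSynchronizer` and `publish_new_state_dict` are invented names for illustration, not APIs from the repo.

import asyncio

class MiniSynchronizer:
    def __init__(self, sync_timeout: float = 5.0) -> None:
        self.model_version = 0
        self.sync_timeout = sync_timeout
        self._ready_condition = asyncio.Condition()

    async def publish_new_state_dict(self) -> None:
        # trainer side: bump the version and wake every waiting explorer
        async with self._ready_condition:
            self.model_version += 1
            self._ready_condition.notify_all()

    async def notify_no_new_model_state_dict(self) -> None:
        # trainer side: wake waiters even though the version is unchanged
        async with self._ready_condition:
            self._ready_condition.notify_all()

    async def wait_new_model_state_dict(self, current_version: int) -> int:
        # explorer side: block until notified, then report the latest version
        async with self._ready_condition:
            if self.model_version == current_version:
                await asyncio.wait_for(self._ready_condition.wait(), timeout=self.sync_timeout)
            return self.model_version

async def main() -> None:
    sync = MiniSynchronizer()
    waiter = asyncio.create_task(sync.wait_new_model_state_dict(current_version=0))
    await asyncio.sleep(0.01)  # let the waiter block on the condition first
    await sync.publish_new_state_dict()
    print(await waiter)  # prints 1

asyncio.run(main())

Running the sketch prints 1: the waiter wakes as soon as the publisher bumps the version, which mirrors the path the explorer takes through `_pull_latest_weights`.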
@@ -361,11 +369,11 @@ async def sync_failed(): if module == "explorer": another_module = "Trainer" await self.set_explorer_status( - RunningStatus.REQUIRE_SYNC, old_status=RunningStatus.WAITING_SYNC + RunningStatus.RUNNING, old_status=RunningStatus.REQUIRE_SYNC ) else: another_module = "Explorer" - self.trainer_status = RunningStatus.REQUIRE_SYNC + self.trainer_status = RunningStatus.RUNNING self.logger.error(f"{another_module} is not ready for model weight sync.") return None @@ -381,12 +389,12 @@ async def sync_failed(): try: if module == "trainer": self.model_version = trainer_step - self.trainer_status = RunningStatus.WAITING_SYNC + self.trainer_status = RunningStatus.REQUIRE_SYNC self._ready_condition.notify_all() - if self.explorer_status_counts[RunningStatus.WAITING_SYNC] != 1: + if self.explorer_status_counts[RunningStatus.REQUIRE_SYNC] != 1: await asyncio.wait_for( self._ready_condition.wait_for( - lambda: self.explorer_status_counts[RunningStatus.WAITING_SYNC] + lambda: self.explorer_status_counts[RunningStatus.REQUIRE_SYNC] + self.explorer_status_counts[RunningStatus.STOPPED] == 1, ), @@ -396,18 +404,18 @@ async def sync_failed(): return await sync_failed() await self.set_explorer_status( RunningStatus.RUNNING, - old_status=RunningStatus.WAITING_SYNC, + old_status=RunningStatus.REQUIRE_SYNC, ) elif module == "explorer": await self.set_explorer_status( - RunningStatus.WAITING_SYNC, old_status=RunningStatus.REQUIRE_SYNC + RunningStatus.REQUIRE_SYNC, old_status=RunningStatus.RUNNING ) self._ready_condition.notify_all() - if self.trainer_status != RunningStatus.WAITING_SYNC: + if self.trainer_status != RunningStatus.REQUIRE_SYNC: await asyncio.wait_for( self._ready_condition.wait_for( lambda: self.trainer_status - in {RunningStatus.WAITING_SYNC, RunningStatus.STOPPED}, + in {RunningStatus.REQUIRE_SYNC, RunningStatus.STOPPED}, ), timeout=self.config.synchronizer.sync_timeout, ) diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py index cb16bf7b643..80025f411b7 100644 --- a/trinity/trainer/trainer.py +++ b/trinity/trainer/trainer.py @@ -61,6 +61,9 @@ def __init__(self, config: Config) -> None: self.save_interval = config.trainer.save_interval self.last_sync_step = 0 self.last_sync_time = None + self.sync_interval = config.synchronizer.sync_interval + self.sync_method = config.synchronizer.sync_method + self.sync_style = config.synchronizer.sync_style self.total_steps = config.trainer.total_steps or float("inf") self.save_hf_checkpoint = config.trainer.save_hf_checkpoint @@ -139,24 +142,17 @@ async def _sample_data(self) -> Tuple[List[Experience], Dict, List[Dict]]: async def need_sync(self) -> bool: """Whether to sync the model weights.""" - if self.config.synchronizer.sync_style == SyncStyle.FIXED: + if self.sync_style in {SyncStyle.FIXED, SyncStyle.TRAINER_DRIVEN}: return ( self.last_sync_step != self.train_step_num - and self.train_step_num % self.config.synchronizer.sync_interval == 0 + and self.train_step_num % self.sync_interval == 0 ) - else: - if self.config.synchronizer.sync_style == SyncStyle.DYNAMIC_BY_TRAINER: - delta = self.train_step_num - self.last_sync_step - if delta >= self.config.synchronizer.sync_interval: - await self.synchronizer.set_trainer_status.remote(RunningStatus.REQUIRE_SYNC) - explorer_status_counts = await self.synchronizer.get_explorer_status_counts.remote() - if self.config.synchronizer.sync_method == SyncMethod.NCCL: - return explorer_status_counts[RunningStatus.WAITING_SYNC] > 0 - else: # memory & checkpoint - return ( - 
self.last_sync_step != self.train_step_num - and explorer_status_counts[RunningStatus.REQUIRE_SYNC] > 0 - ) + else: # explorer driven + # for memory & checkpoint; TODO: apply to nccl sync + if self.last_sync_step == self.train_step_num and self.sync_method != SyncMethod.NCCL: + await self.synchronizer.notify_no_new_model_state_dict.remote() + return False + return await self.synchronizer.explorer_requires_sync.remote() def need_save(self) -> bool: """Whether to save the checkpoint.""" @@ -169,7 +165,7 @@ async def sync_weight(self) -> Dict: if self.last_sync_time is not None: metrics["time/trainer_sync_interval"] = time.time() - self.last_sync_time with Timer(metrics, "time/sync_weight"): - if self.config.synchronizer.sync_method == SyncMethod.NCCL: + if self.sync_method == SyncMethod.NCCL: result = await self.synchronizer.ready_to_nccl_sync.remote( "trainer", self.train_step_num ) @@ -177,13 +173,12 @@ async def sync_weight(self) -> Dict: self.logger.error("Trainer sync_weights failed.") else: self.engine.sync_weight() - elif self.config.synchronizer.sync_method == SyncMethod.CHECKPOINT: + elif self.sync_method == SyncMethod.CHECKPOINT: await self.engine.save_state_dict() - elif self.config.synchronizer.sync_method == SyncMethod.MEMORY: + elif self.sync_method == SyncMethod.MEMORY: await self.engine.upload_state_dict() self.last_sync_step = self.train_step_num self.last_sync_time = time.time() - await self.synchronizer.set_trainer_status.remote(RunningStatus.RUNNING) self.logger.info(f"Trainer sync_weights at step {self.train_step_num} finished.") return metrics
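[Editor's note, not part of the patch] The `sync_style` rename stays backward compatible through the `name_aliases` hook added to `CaseInsensitiveEnumMeta` in trinity/common/constants.py: both member lookup and value lookup are first routed through a lower-cased alias table. The sketch below re-implements that pattern outside the repo to show that configs still using `dynamic_by_explorer` or `dynamic_by_trainer` resolve to the renamed members.

from enum import Enum, EnumMeta

class CaseInsensitiveEnumMeta(EnumMeta):
    # lower-cased alias table; subclasses override it per enum
    name_aliases: dict = {}

    def __getitem__(cls, name):
        name = cls.name_aliases.get(name.lower(), name)
        return super().__getitem__(name.upper())

    def __call__(cls, value, *args, **kwargs):
        value = cls.name_aliases.get(value.lower(), value)
        return super().__call__(value.lower(), *args, **kwargs)

class SyncStyleEnumMeta(CaseInsensitiveEnumMeta):
    name_aliases = {
        "dynamic_by_explorer": "explorer_driven",
        "dynamic_by_trainer": "trainer_driven",
    }

class SyncStyle(Enum, metaclass=SyncStyleEnumMeta):
    FIXED = "fixed"
    TRAINER_DRIVEN = "trainer_driven"
    EXPLORER_DRIVEN = "explorer_driven"

# legacy config values keep resolving after the rename
assert SyncStyle("dynamic_by_explorer") is SyncStyle.EXPLORER_DRIVEN
assert SyncStyle["dynamic_by_trainer"] is SyncStyle.TRAINER_DRIVEN
# case-insensitive lookup still works for the new names
assert SyncStyle("Explorer_Driven") is SyncStyle.EXPLORER_DRIVEN

The same hook is what lets `sync_method` keep accepting the older `online`/`offline` spellings as aliases for `nccl`/`checkpoint`.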