diff --git a/docs/sphinx_doc/source/index.rst b/docs/sphinx_doc/source/index.rst index 5b89ea2147..21794b4138 100644 --- a/docs/sphinx_doc/source/index.rst +++ b/docs/sphinx_doc/source/index.rst @@ -21,6 +21,7 @@ Welcome to Trinity-RFT's documentation! tutorial/develop_algorithm.md tutorial/example_mix_algo.md tutorial/develop_operator.md + tutorial/develop_selector.md tutorial/trinity_configs.md tutorial/synchronizer.md diff --git a/docs/sphinx_doc/source/tutorial/develop_operator.md b/docs/sphinx_doc/source/tutorial/develop_operator.md index f658e5bff6..f05b1dc62d 100644 --- a/docs/sphinx_doc/source/tutorial/develop_operator.md +++ b/docs/sphinx_doc/source/tutorial/develop_operator.md @@ -1,3 +1,4 @@ +(Operators)= ## Operator Development Guide ### Step 0: Basic Concepts of Operator Module diff --git a/docs/sphinx_doc/source/tutorial/develop_overview.md b/docs/sphinx_doc/source/tutorial/develop_overview.md index f47cb967ff..91eecc2d60 100644 --- a/docs/sphinx_doc/source/tutorial/develop_overview.md +++ b/docs/sphinx_doc/source/tutorial/develop_overview.md @@ -11,6 +11,7 @@ The table below lists the main functions of each extension interface, its target | `Workflow` | Agent Application Developers | Enhance agent's ability to complete tasks in a specified environment | [🔗](./develop_workflow.md) | | `Algorithm` | RL Algorithm Researchers | Design new RL algorithms | [🔗](./develop_algorithm.md) | | `Operator` | Data Engineers | Design new data cleaning and augmentation strategies | [🔗](./develop_operator.md) | +| `Selector` | Data Engineers | Design new task selection strategies | [🔗](./develop_selector.md) | ```{tip} Trinity-RFT provides a modular development approach, allowing you to flexibly add custom modules without modifying the framework code. diff --git a/docs/sphinx_doc/source/tutorial/develop_selector.md b/docs/sphinx_doc/source/tutorial/develop_selector.md new file mode 100644 index 0000000000..d7593ce036 --- /dev/null +++ b/docs/sphinx_doc/source/tutorial/develop_selector.md @@ -0,0 +1,264 @@ + +# 🧪 Experimental: Task Selection & Scheduling System + +```{note} +This module is currently in **experimental status**. Interfaces may change in future versions. +This document describes the functionality and intended usage of the system. +``` + + + +## Overview + +This system enables **intelligent, adaptive task sampling** from multiple datasets (called *tasksets*) during exploration. It consists of two core components: + +1. **`Selector`** – Controls how individual samples are selected *within* each taskset. +2. **`TasksetScheduler`** – Manages *which* tasksets contribute to each batch and coordinates their sampling. + +Together, they support advanced training strategies such as: +- Curriculum learning (easy → hard) +- Multi-task interleaving or mixing +- Difficulty-aware sampling +- Adaptive data selection based on model performance + +These capabilities allow you to train models more efficiently by focusing on informative or challenging examples. + + + +## Module 1: Selector – Customizable Data Selection + +A `Selector` determines **which tasks (samples) to select** from its associated dataset (`Taskset`). Beyond basic strategies like sequential or random access, it supports **adaptive algorithms** that adjust sampling based on feedback—such as sample difficulty, model confidence, or reward signals. + +### Built-in Selectors + +| Selector Type | Description | +|---------------|-------------| +| `sequential` | Returns samples in fixed order (0, 1, ..., N). 
| +| `shuffle` | Shuffles the dataset once per epoch; then iterates sequentially. | +| `random` | Randomly samples without replacement within each batch. Independent across batches. | +| `offline_easy2hard` | Sorts samples by pre-defined features (e.g., loss, length), serving easier ones first, progressing to harder ones. | +| `difficulty_based` *(custom example)* | Dynamically selects samples near a target difficulty level using probabilistic modeling. | + +You can also **implement your own custom selector** to enable adaptive or curriculum-based learning. + + + +### ✅ Step 1: Implement a Custom Selector + +To create a new selector, inherit from `BaseSelector` and implement the following methods: + +#### Required Methods + +| Method | Purpose | +|-------|--------| +| `get_indices(batch_size: int, return_extra_info=False) -> List[int]` | Return a list of sample indices to read next. | +| `update(indices: List[int], values: List[float])` | Update internal state using feedback (e.g., rewards, losses). | +| `state_dict() -> Dict` | Serialize current state for checkpointing. | +| `load_state_dict(state_dict: Dict)` | Restore state from a saved dictionary. | + +#### Example: `DifficultyBasedSelector` + +This selector focuses on samples whose predicted performance is closest to a target (e.g., 90% success rate), effectively choosing "just right" difficulty tasks. + +```python +@SELECTORS.register_module("difficulty_based") +class DifficultyBasedSelector(BaseSelector): + def __init__(self, data_source, config: TaskSelectorConfig) -> None: + super().__init__(data_source, config) + self.logger = get_logger("difficulty_based_selector") + + # Build difficulty estimator using two input features (e.g., correctness, uncertainty) + self.diff_estimator = self.build_diff_estimator( + data_source.dataset, config.feature_keys, config.kwargs + ) + self.current_index = 0 + self.seed = config.seed + + # Configuration parameters + self.do_sample = config.kwargs.get("do_sample", False) + self.target_reward = config.kwargs.get("target_reward", 1.0) + self.tau = config.kwargs.get("tau", 1.0) + + # ... detailed implementation + + def get_indices(self, batch_size, return_extra_info=False): + # Compute scores based on proximity to target reward + sampling_scores = self.get_scores() + sampling_scores = torch.from_numpy(sampling_scores) + + if self.tau == 0: + # Greedy: take top-k highest scoring samples + selected_indices = torch.topk(sampling_scores, batch_size).indices + else: + # Stochastic: sample via softmax with temperature scaling + sampling_logits = sampling_scores / self.tau + sampling_logits -= sampling_logits.max() # Stability + sampling_probabilities = torch.softmax(sampling_logits, dim=0) + rng = torch.Generator().manual_seed(self.seed + self.current_index) + selected_indices = torch.multinomial( + sampling_probabilities, + batch_size, + replacement=False, + generator=rng, + ) + + self.current_index += batch_size + + if return_extra_info: + # Optional debugging info + extra_info = { + "indices": selected_indices.tolist(), + "scores": sampling_scores[selected_indices].tolist(), + # ... 
other metadata + } + return selected_indices, extra_info + else: + return selected_indices + + def update(self, indices: List[int], values: List[float]) -> None: + # Update difficulty model with observed rewards + self.diff_estimator.update(indices, values) + + def state_dict(self) -> Dict: + return {"current_index": self.current_index} + + def load_state_dict(self, state_dict: Dict) -> None: + self.current_index = state_dict.get("current_index", 0) +``` + +> 🔁 After defining your class, use `@SELECTORS.register_module("your_name")` so it can be referenced by name in configs. + + + +### ✅ Step 2: Implement a Feedback Operator + +For adaptive selectors like `DifficultyBasedSelector`, you need to provide runtime feedback (e.g., task rewards). This is done via an **Experience Operator** that processes rollouts and computes metrics. + +> 📚 See the {ref}`Operator Development Guide` for more on building custom experience processors. + +The operator must output a metric under the key `trinity.common.constants.SELECTOR_METRIC`, structured as: + +```python +{ + SELECTOR_METRIC: { + 0: { # taskset_id + "indices": [10, 25, 43], + "values": [0.8, 0.6, 0.9] # e.g., average reward + }, + 1: { ... } + } +} +``` + +#### Example: Pass Rate Calculator + +```python +@EXPERIENCE_OPERATORS.register_module("pass_rate_calculator") +class PassRateCalculator(ExperienceOperator): + def __init__(self, **kwargs): + pass + + def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]: + raw_metric = defaultdict(lambda: defaultdict(list)) + + for exp in exps: + task_index = exp.info["task_index"] + assert "taskset_id" in task_index and "index" in task_index + raw_metric[task_index["taskset_id"]][task_index["index"]].append(exp.reward) + + metric = {} + for taskset_id, task_metrics in raw_metric.items(): + indices = [] + reward_means = [] + for idx, rewards in task_metrics.items(): + indices.append(idx) + reward_means.append(float(np.mean(rewards))) + metric[taskset_id] = { + "indices": indices, + "values": reward_means, + } + + return exps, {SELECTOR_METRIC: metric} +``` + +This operator calculates the average reward per task and passes it back to the corresponding selector for updating difficulty estimates. + + + +### ✅ Step 3: Update Configuration + +After implementing your selector and operator, register them in the config file. + +#### Add the Operator to the Pipeline + +```yaml +data_processor: + experience_pipeline: + operators: + - name: pass_rate_calculator # Must match @register_module name +``` + +#### Configure the Taskset with Your Selector + +```yaml +buffer: + explorer_input: + tasksets: + - name: my_taskset + storage_type: file + path: ./path/to/tasks + task_selector: + selector_type: difficulty_based # Matches @register_module name + feature_keys: ["correct", "uncertainty"] + kwargs: + m: 16 + lamb: 0.2 + rho: 0.2 + target_reward: 0.9 + tau: 0.5 + do_sample: true +``` + +> 💡 You can define multiple tasksets, each with its own selector type and configuration. + + + +## Module 2: TasksetScheduler – Multi-Taskset Orchestration + +The `TasksetScheduler` manages **how different tasksets are interleaved or mixed** during training. + +### Key Features + +- Supports **multiple tasksets** simultaneously. +- Balances sampling proportionally to dataset sizes. +- **Shuffles taskset access order** at the start of each epoch. +- Enables **curriculum-style** or **interleaved multi-task training**. +- Fully **checkpointable**: resumes exactly where it left off. 
+- Integrates with any registered `Selector`. + +### How It Works + +At each training step: +1. Determines which tasksets should contribute to the current batch. +2. Queries each taskset’s selector to get specific sample indices. +3. Reads the actual data asynchronously. +4. Tags each task with `"taskset_id"` for downstream routing or analysis. + +Epochs are defined based on total data volume and batch size: +```python +steps_per_epoch = total_samples // batch_size +``` + +At the beginning of each epoch, the scheduler reshuffles the sequence of taskset accesses to introduce variability. + + + +## Summary + +With these components, you can: +- Use simple strategies like random or sequential sampling. +- Design **adaptive curricula** using custom selectors. +- Combine multiple datasets intelligently. +- Optimize training efficiency by focusing on high-value samples. + +By combining smart `Selectors` with the flexible `TasksetScheduler`, you gain fine-grained control over what your model sees—and when. diff --git a/docs/sphinx_doc/source_zh/tutorial/develop_overview.md b/docs/sphinx_doc/source_zh/tutorial/develop_overview.md index 9e79780758..2ce151198f 100644 --- a/docs/sphinx_doc/source_zh/tutorial/develop_overview.md +++ b/docs/sphinx_doc/source_zh/tutorial/develop_overview.md @@ -11,6 +11,7 @@ Trinity-RFT 将 RL 训练过程拆分为了三个模块:**Explorer**、**Train | `Workflow` | 智能体应用开发者 | 提升 Agent 在指定环境中完成任务的能力 | [🔗](./develop_workflow.md) | | `Algorithm` | RL 算法研究者 | 设计新的 RL 算法 | [🔗](./develop_algorithm.md) | | `Operator` | 数据工程师 | 设计新的数据清洗、增强策略 | [🔗](./develop_operator.md) | +| `Selector` | 数据工程师 | 设计新的数据选择策略 | [🔗](./develop_selector.md) | ```{tip} Trinity-RFT 提供了插件化的开发方式,可以在不修改框架代码的前提下,灵活地添加自定义模块。 diff --git a/docs/sphinx_doc/source_zh/tutorial/develop_selector.md b/docs/sphinx_doc/source_zh/tutorial/develop_selector.md new file mode 100644 index 0000000000..872e3819c4 --- /dev/null +++ b/docs/sphinx_doc/source_zh/tutorial/develop_selector.md @@ -0,0 +1,261 @@ +# 🧪 实验性功能:任务选择与调度系统 + +```{note} +该模块目前处于 **实验阶段**,接口可能在后续版本中发生变化。 +本文档描述了系统的功能及预期使用方式。 +``` + +## 概述 + +本系统支持在探索过程中,从多个数据集/任务集(称为 *tasksets*)中进行**智能、自适应的任务采样**。它包含两个核心组件: + +1. **`Selector`(选择器)** —— 控制每个任务集中**如何选择单个样本**。 +2. 
**`TasksetScheduler`(任务集调度器)** —— 管理**哪些任务集参与当前批次的训练**,并协调它们的采样过程。 + +二者结合,支持以下高级训练策略: +- 课程学习(由易到难) +- 多任务交替/混合训练 +- 基于难度的采样 +- 根据模型表现动态调整数据选择 + +这些能力使你能够更高效地训练模型,聚焦于信息量大或具有挑战性的样本。 + + + +## 模块 1:Selector —— 可定制的数据选择机制 + +`Selector` 决定从其对应的数据集(`Taskset`)中选择哪些**任务(样本)**。除了基本的顺序或随机访问策略外,它还支持**基于反馈信号(如样本难度、模型置信度、奖励等)动态调整采样行为的自适应算法**。 + +### 内置的选择器类型 + +| 选择器类型 | 说明 | +|-----------|------| +| `sequential` | 按固定顺序返回样本(0, 1, ..., N)。 | +| `shuffle` | 每个 epoch 开始时对数据集整体打乱一次,之后按顺序遍历。 | +| `random` | 在每个 batch 中无放回地随机采样,不同 batch 之间相互独立。 | +| `offline_easy2hard` | 根据预定义特征(如损失值、长度)对样本排序,先提供简单样本,逐步过渡到困难样本。 | +| `difficulty_based` *(自定义示例)* | 使用概率建模动态选择接近目标难度水平的样本。 | + +你也可以实现自己的**自定义选择器**,以支持自适应或课程式学习。 + + + +### ✅ 步骤 1:实现一个自定义选择器 + +要创建新的选择器,需继承 `BaseSelector` 类,并实现以下方法: + +#### 必须实现的方法 + +| 方法 | 功能说明 | +|------|---------| +| `get_indices(batch_size: int, return_extra_info=False) -> List[int]` | 返回接下来要读取的样本索引列表。 | +| `update(indices: List[int], values: List[float])` | 使用反馈信息(如奖励、损失)更新内部状态,用于自适应调整。 | +| `state_dict() -> Dict` | 序列化当前状态,用于保存检查点。 | +| `load_state_dict(state_dict: Dict)` | 从保存的状态字典中恢复选择器状态。 | + +#### 示例:`DifficultyBasedSelector` + +该选择器聚焦于模型预测表现最接近目标值的样本(例如 90% 成功率),从而挑选出“难度适中”的任务。 + +```python +@SELECTORS.register_module("difficulty_based") +class DifficultyBasedSelector(BaseSelector): + def __init__(self, data_source, config: TaskSelectorConfig) -> None: + super().__init__(data_source, config) + self.logger = get_logger("difficulty_based_selector") + + # 使用两个输入特征(如正确性、不确定性)构建难度估计器 + self.diff_estimator = self.build_diff_estimator( + data_source.dataset, config.feature_keys, config.kwargs + ) + self.current_index = 0 + self.seed = config.seed + + # 配置参数 + self.do_sample = config.kwargs.get("do_sample", False) + self.target_reward = config.kwargs.get("target_reward", 1.0) + self.tau = config.kwargs.get("tau", 1.0) + + # ... 具体实现省略 + + def get_indices(self, batch_size, return_extra_info=False): + # 计算得分:越接近目标奖励得分越高 + sampling_scores = self.get_scores() + sampling_scores = torch.from_numpy(sampling_scores) + + if self.tau == 0: + # 贪心策略:选择得分最高的 top-k 样本 + selected_indices = torch.topk(sampling_scores, batch_size).indices + else: + # 随机采样:通过带温度的 softmax 进行采样 + sampling_logits = sampling_scores / self.tau + sampling_logits -= sampling_logits.max() # 数值稳定性处理 + sampling_probabilities = torch.softmax(sampling_logits, dim=0) + rng = torch.Generator().manual_seed(self.seed + self.current_index) + selected_indices = torch.multinomial( + sampling_probabilities, + batch_size, + replacement=False, + generator=rng, + ) + + self.current_index += batch_size + + if return_extra_info: + # 可选:返回调试信息 + extra_info = { + "indices": selected_indices.tolist(), + "scores": sampling_scores[selected_indices].tolist(), + # ... 
其他元数据 + } + return selected_indices, extra_info + else: + return selected_indices + + def update(self, indices: List[int], values: List[float]) -> None: + # 使用观测到的奖励更新难度模型 + self.diff_estimator.update(indices, values) + + def state_dict(self) -> Dict: + return {"current_index": self.current_index} + + def load_state_dict(self, state_dict: Dict) -> None: + self.current_index = state_dict.get("current_index", 0) +``` + +> 🔁 定义完类后,请使用 `@SELECTORS.register_module("your_name")` 注册,以便在配置文件中通过名称引用。 + + + +### ✅ 步骤 2:实现反馈操作器(Feedback Operator) + +对于像 `DifficultyBasedSelector` 这样的自适应选择器,你需要提供运行时反馈(例如任务奖励)。这通过一个 **Experience Operator(经验操作器)** 实现,它处理 rollout 数据并计算相关指标。 + +> 📚 更多关于自定义经验处理器的内容,请参见 {ref}`Operator 开发指南`。 + +操作器必须输出一个键为 `trinity.common.constants.SELECTOR_METRIC` 的指标,结构如下: + +```python +{ + SELECTOR_METRIC: { + 0: { # taskset_id + "indices": [10, 25, 43], + "values": [0.8, 0.6, 0.9] # 例如:平均奖励值 + }, + 1: { ... } + } +} +``` + +#### 示例:通过率计算器(Pass Rate Calculator) + +```python +@EXPERIENCE_OPERATORS.register_module("pass_rate_calculator") +class PassRateCalculator(ExperienceOperator): + def __init__(self, **kwargs): + pass + + def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]: + raw_metric = defaultdict(lambda: defaultdict(list)) + + for exp in exps: + task_index = exp.info["task_index"] + assert "taskset_id" in task_index and "index" in task_index + raw_metric[task_index["taskset_id"]][task_index["index"]].append(exp.reward) + + metric = {} + for taskset_id, task_metrics in raw_metric.items(): + indices = [] + reward_means = [] + for idx, rewards in task_metrics.items(): + indices.append(idx) + reward_means.append(float(np.mean(rewards))) + metric[taskset_id] = { + "indices": indices, + "values": reward_means, + } + + return exps, {SELECTOR_METRIC: metric} +``` + +该操作器计算每个任务的平均奖励,并将其传回对应的 `Selector`,用于更新难度估计。 + + + +### ✅ 步骤 3:更新配置文件 + +完成选择器和操作器的实现后,需要在配置文件中注册它们。 + +#### 将操作器加入处理流程 + +```yaml +data_processor: + experience_pipeline: + operators: + - name: pass_rate_calculator # 必须与 @register_module 名称一致 +``` + +#### 为任务集配置你的选择器 + +```yaml +buffer: + explorer_input: + tasksets: + - name: my_taskset + storage_type: file + path: ./path/to/tasks + task_selector: + selector_type: difficulty_based # 必须与 @register_module 名称匹配 + feature_keys: ["correct", "uncertainty"] + kwargs: + m: 16 + lamb: 0.2 + rho: 0.2 + target_reward: 0.9 + tau: 0.5 + do_sample: true +``` + +> 💡 你可以定义多个任务集,每个都可以使用不同类型和配置的选择器。 + + + +## 模块 2:TasksetScheduler —— 多任务集协调调度 + +`TasksetScheduler` 负责管理训练过程中**不同任务集之间的交错方式**。 + +### 主要特性 + +- 支持**同时加载多个任务集**。 +- 按数据集大小比例**平衡采样权重**。 +- 每个 epoch 开始时**打乱任务集的访问顺序**。 +- 支持**课程式学习**或**多任务交替/混合训练**。 +- 完全**可恢复断点**:能精确从中断处继续训练。 +- 与任意已注册的 `Selector` 无缝集成。 + +### 工作原理 + +在每一步训练中: +1. 确定哪些任务集应参与当前 batch; +2. 向各任务集的选择器请求具体的样本索引; +3. 异步读取实际数据; +4. 
为每个任务打上 `"taskset_id"` 标签,便于下游路由或分析。 + +每个 epoch 的步数由总样本数和 batch size 决定: +```python +steps_per_epoch = total_samples // batch_size +``` + +每个 epoch 开始时,调度器会重新打乱任务集的访问顺序,以增加多样性。 + + + +## 总结 + +通过这两个组件,你可以: +- 使用简单的策略(如随机或顺序采样); +- 利用自定义选择器设计**自适应课程学习策略**; +- 智能地融合多个数据集; +- 通过聚焦高价值样本提升训练效率。 + +将智能的 `Selector` 与灵活的 `TasksetScheduler` 结合,你将获得对模型所见内容及其出现时机的精细控制能力。 diff --git a/tests/buffer/experience_pipeline_test.py b/tests/buffer/experience_pipeline_test.py index 3adcbccab1..62d4df4eba 100644 --- a/tests/buffer/experience_pipeline_test.py +++ b/tests/buffer/experience_pipeline_test.py @@ -8,6 +8,7 @@ from trinity.buffer import get_buffer_reader from trinity.buffer.pipelines.experience_pipeline import ExperiencePipeline from trinity.common.config import ExperiencePipelineConfig, OperatorConfig +from trinity.common.constants import SELECTOR_METRIC from trinity.common.experience import EID, Experience BUFFER_FILE_PATH = os.path.join(os.path.dirname(__file__), "test_pipeline_buffer.jsonl") @@ -79,3 +80,39 @@ async def test_experience_pipeline(self): with open(config.data_processor.experience_pipeline.input_save_path, "r") as f: input_data = f.readlines() self.assertEqual(len(input_data), len(experiences)) + + async def test_pass_rate_calculation(self) -> None: + config = get_template_config() + config.data_processor.experience_pipeline = ExperiencePipelineConfig( + save_input=True, + input_save_path=BUFFER_FILE_PATH, + operators=[ + OperatorConfig( + name="pass_rate_calculator", + ) + ], + ) + config.check_and_update() + config.buffer.trainer_input.experience_buffer.name = "pipeline_test_experience_buffer" + config.buffer.trainer_input.experience_buffer.max_read_timeout = 3 + + pipeline = ( + ray.remote(ExperiencePipeline) + .options(name=f"{config.explorer.name}_pipeline") + .remote(config) + ) + await pipeline.prepare.remote() + task_num = 8 + repeat_times = 4 + experiences = get_experiences(task_num=task_num, repeat_times=repeat_times) + for exp in experiences: + exp.info["task_index"] = { + "taskset_id": 0, + "index": exp.eid.task, + } + metrics = await pipeline.process.remote(experiences) + self.assertIn(SELECTOR_METRIC, metrics) + selector_metrics = metrics[SELECTOR_METRIC] + self.assertEqual(len(selector_metrics), 1) + self.assertEqual(set(selector_metrics[0]["indices"]), set(range(task_num))) + self.assertEqual(selector_metrics[0]["values"], [(repeat_times - 1.0) / 2] * task_num) diff --git a/tests/buffer/file_test.py b/tests/buffer/file_test.py index 00aec40744..2de910d31c 100644 --- a/tests/buffer/file_test.py +++ b/tests/buffer/file_test.py @@ -1,4 +1,5 @@ import os +import shutil import unittest import ray @@ -26,7 +27,7 @@ def setUpClass(cls): def tearDownClass(cls): super().tearDownClass() if os.path.exists(cls.temp_output_path): - os.system(f"rm -rf {cls.temp_output_path}") + shutil.rmtree(cls.temp_output_path) def test_file_reader(self): # noqa: C901 """Test file reader.""" diff --git a/tests/buffer/task_scheduler_test.py b/tests/buffer/task_scheduler_test.py new file mode 100644 index 0000000000..ebaf6ed697 --- /dev/null +++ b/tests/buffer/task_scheduler_test.py @@ -0,0 +1,260 @@ +import os +import shutil +import unittest +from typing import Dict, List + +from parameterized import parameterized + +from tests.tools import get_template_config +from trinity.buffer.task_scheduler import TasksetScheduler +from trinity.common.config import FormatConfig, StorageConfig, TaskSelectorConfig +from trinity.common.workflows.workflow import Task + + +class 
TestTaskScheduler(unittest.IsolatedAsyncioTestCase): + temp_output_path = "tmp/test_task_scheduler/" + + @classmethod + def setUpClass(cls): + super().setUpClass() + os.makedirs(cls.temp_output_path, exist_ok=True) + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + if os.path.exists(cls.temp_output_path): + shutil.rmtree(cls.temp_output_path) + + def _check_batch_tasks(self, batch_tasks: List[Task], indices: List[Dict[str, int]]) -> None: + for task, index in zip(batch_tasks, indices): + self.assertEqual(task.index["taskset_id"], index["taskset_id"]) + self.assertEqual(task.index["index"], index["index"]) + self.assertEqual( + task.raw_task["question"], # type: ignore + f"Question {index['index'] + 1} in subset {index['taskset_id'] + 1}.", + ) + self.assertEqual( + task.raw_task["answer"], # type: ignore + f"Answer {index['index'] + 1} in subset {index['taskset_id'] + 1}.", + ) + + @parameterized.expand( + [ + ( + {"selector_type": "sequential"}, + [ + {"index": 0, "taskset_id": 1}, + {"index": 0, "taskset_id": 0}, + {"index": 1, "taskset_id": 1}, + {"index": 1, "taskset_id": 0}, + {"index": 2, "taskset_id": 1}, + {"index": 3, "taskset_id": 1}, + {"index": 2, "taskset_id": 0}, + {"index": 3, "taskset_id": 0}, + {"index": 4, "taskset_id": 1}, + {"index": 5, "taskset_id": 1}, + {"index": 6, "taskset_id": 1}, + {"index": 4, "taskset_id": 0}, + {"index": 0, "taskset_id": 1}, + {"index": 1, "taskset_id": 1}, + {"index": 0, "taskset_id": 0}, + {"index": 2, "taskset_id": 1}, + {"index": 3, "taskset_id": 1}, + {"index": 1, "taskset_id": 0}, + {"index": 2, "taskset_id": 0}, + {"index": 4, "taskset_id": 1}, + {"index": 3, "taskset_id": 0}, + {"index": 5, "taskset_id": 1}, + {"index": 6, "taskset_id": 1}, + {"index": 4, "taskset_id": 0}, + ], + ), + ( + {"selector_type": "shuffle", "seed": 42}, + [ + {"index": 3, "taskset_id": 1}, + {"index": 4, "taskset_id": 0}, + {"index": 2, "taskset_id": 1}, + {"index": 2, "taskset_id": 0}, + {"index": 6, "taskset_id": 1}, + {"index": 4, "taskset_id": 1}, + {"index": 3, "taskset_id": 0}, + {"index": 1, "taskset_id": 0}, + {"index": 1, "taskset_id": 1}, + {"index": 5, "taskset_id": 1}, + {"index": 0, "taskset_id": 1}, + {"index": 0, "taskset_id": 0}, + {"index": 2, "taskset_id": 1}, + {"index": 6, "taskset_id": 1}, + {"index": 4, "taskset_id": 0}, + {"index": 5, "taskset_id": 1}, + {"index": 1, "taskset_id": 1}, + {"index": 1, "taskset_id": 0}, + {"index": 2, "taskset_id": 0}, + {"index": 4, "taskset_id": 1}, + {"index": 0, "taskset_id": 0}, + {"index": 0, "taskset_id": 1}, + {"index": 3, "taskset_id": 1}, + {"index": 3, "taskset_id": 0}, + ], + ), + ( + {"selector_type": "random", "seed": 42}, + [ + {"index": 0, "taskset_id": 1}, + {"index": 0, "taskset_id": 0}, + {"index": 3, "taskset_id": 1}, + {"index": 2, "taskset_id": 0}, + {"index": 4, "taskset_id": 1}, + {"index": 0, "taskset_id": 1}, + {"index": 2, "taskset_id": 0}, + {"index": 0, "taskset_id": 0}, + {"index": 6, "taskset_id": 1}, + {"index": 3, "taskset_id": 1}, + {"index": 0, "taskset_id": 1}, + {"index": 2, "taskset_id": 0}, + {"index": 0, "taskset_id": 1}, + {"index": 2, "taskset_id": 1}, + {"index": 0, "taskset_id": 0}, + {"index": 2, "taskset_id": 1}, + {"index": 6, "taskset_id": 1}, + {"index": 0, "taskset_id": 0}, + {"index": 0, "taskset_id": 0}, + {"index": 5, "taskset_id": 1}, + {"index": 3, "taskset_id": 0}, + {"index": 2, "taskset_id": 1}, + {"index": 6, "taskset_id": 1}, + {"index": 1, "taskset_id": 0}, + ], + ), + ( + {"selector_type": "offline_easy2hard", "feature_keys": 
["feature_offline"]}, + [ + {"index": 3, "taskset_id": 1}, + {"index": 3, "taskset_id": 0}, + {"index": 4, "taskset_id": 1}, + {"index": 0, "taskset_id": 0}, + {"index": 1, "taskset_id": 1}, + {"index": 0, "taskset_id": 1}, + {"index": 2, "taskset_id": 0}, + {"index": 4, "taskset_id": 0}, + {"index": 6, "taskset_id": 1}, + {"index": 5, "taskset_id": 1}, + {"index": 2, "taskset_id": 1}, + {"index": 1, "taskset_id": 0}, + {"index": 3, "taskset_id": 1}, + {"index": 4, "taskset_id": 1}, + {"index": 3, "taskset_id": 0}, + {"index": 1, "taskset_id": 1}, + {"index": 0, "taskset_id": 1}, + {"index": 0, "taskset_id": 0}, + {"index": 2, "taskset_id": 0}, + {"index": 6, "taskset_id": 1}, + {"index": 4, "taskset_id": 0}, + {"index": 5, "taskset_id": 1}, + {"index": 2, "taskset_id": 1}, + {"index": 1, "taskset_id": 0}, + ], + ), + ( + {"selector_type": "difficulty_based", "feature_keys": ["feat_1", "feat_2"]}, + [ + {"index": 3, "taskset_id": 1}, + {"index": 3, "taskset_id": 0}, + {"index": 6, "taskset_id": 1}, + {"index": 2, "taskset_id": 0}, + {"index": 2, "taskset_id": 1}, + {"index": 3, "taskset_id": 1}, + {"index": 2, "taskset_id": 0}, + {"index": 3, "taskset_id": 0}, + {"index": 2, "taskset_id": 1}, + {"index": 1, "taskset_id": 1}, + {"index": 4, "taskset_id": 1}, + {"index": 2, "taskset_id": 0}, + {"index": 3, "taskset_id": 1}, + {"index": 2, "taskset_id": 1}, + {"index": 4, "taskset_id": 0}, + {"index": 4, "taskset_id": 1}, + {"index": 5, "taskset_id": 1}, + {"index": 4, "taskset_id": 0}, + {"index": 3, "taskset_id": 0}, + {"index": 5, "taskset_id": 1}, + {"index": 1, "taskset_id": 0}, + {"index": 6, "taskset_id": 1}, + {"index": 6, "taskset_id": 1}, + {"index": 4, "taskset_id": 0}, + ], + ), + ] + ) + async def test_task_scheduler(self, task_selector_kwargs, batch_tasks_orders) -> None: + config = get_template_config() + config.buffer.batch_size = 2 + config.buffer.total_epochs = 2 + config.buffer.explorer_input.taskset = None + config.buffer.explorer_input.tasksets = [ + StorageConfig( + name="subset_1", + path=os.path.join( + os.path.dirname(__file__), + "..", + "template", + "data", + "task_scheduler", + "subset_1", + ), + split="train", + enable_progress_bar=False, + format=FormatConfig( + prompt_key="question", + response_key="answer", + ), + default_workflow_type="math_workflow", + default_reward_fn_type="math_reward", + task_selector=TaskSelectorConfig( + **task_selector_kwargs, + ), + ), + StorageConfig( + name="subset_2", + path=os.path.join( + os.path.dirname(__file__), + "..", + "template", + "data", + "task_scheduler", + "subset_2", + ), + split="train", + enable_progress_bar=False, + format=FormatConfig( + prompt_key="question", + response_key="answer", + ), + default_workflow_type="math_workflow", + default_reward_fn_type="math_reward", + task_selector=TaskSelectorConfig( + **task_selector_kwargs, + ), + ), + ] + config.check_and_update() + + task_scheduler = TasksetScheduler({}, config) + self.assertEqual(len(batch_tasks_orders) % config.buffer.batch_size, 0) + for i, start_id in enumerate(range(0, len(batch_tasks_orders), config.buffer.batch_size)): + batch_tasks_indices = batch_tasks_orders[start_id : start_id + config.buffer.batch_size] + batch_tasks = await task_scheduler.read_async() + # for task in batch_tasks: # used for debug + # print(f"{task.index},") + self._check_batch_tasks(batch_tasks, batch_tasks_indices) + if i % 3 == 2: + # test resume + state_dict = { + "latest_iteration": task_scheduler.step, + "taskset_states": task_scheduler.state_dict(), + } + 
task_scheduler = TasksetScheduler(state_dict, config) + + with self.assertRaises(StopAsyncIteration): + batch_tasks = await task_scheduler.read_async() diff --git a/tests/cli/launcher_test.py b/tests/cli/launcher_test.py index 1b8ab142e8..99c3d77077 100644 --- a/tests/cli/launcher_test.py +++ b/tests/cli/launcher_test.py @@ -263,7 +263,7 @@ def test_debug_mode(self, mock_load): except Exception: time.sleep(3) output_file = os.path.join(self.config.checkpoint_job_dir, "debug.html") - self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k") + self.config.buffer.explorer_input.tasksets = [get_unittest_dataset_config("gsm8k")] mock_load.return_value = self.config with mock.patch( "argparse.ArgumentParser.parse_args", diff --git a/tests/common/config_test.py b/tests/common/config_test.py index 0a6a5557b0..07e9a383e2 100644 --- a/tests/common/config_test.py +++ b/tests/common/config_test.py @@ -32,7 +32,7 @@ def test_load_default_config(self): self.assertEqual(config.trainer.trainer_config.trainer.project_name, config.project) self.assertEqual(config.trainer.trainer_config.trainer.experiment_name, config.name) self.assertEqual( - config.buffer.explorer_input.taskset.repeat_times, config.algorithm.repeat_times + config.buffer.explorer_input.tasksets[0].repeat_times, config.algorithm.repeat_times ) self.assertEqual(config.model.model_path, config.model.critic_model_path) self.assertEqual(config.model.model_path, config.explorer.rollout_model.model_path) @@ -110,7 +110,7 @@ def test_default_workflow(self): "math_boxed_workflow", ) self.assertEqual( - config.buffer.explorer_input.taskset.default_workflow_type, + config.buffer.explorer_input.tasksets[0].default_workflow_type, "simple_workflow", ) diff --git a/tests/template/data/task_scheduler/subset_1/train.jsonl b/tests/template/data/task_scheduler/subset_1/train.jsonl new file mode 100644 index 0000000000..13fd52ee51 --- /dev/null +++ b/tests/template/data/task_scheduler/subset_1/train.jsonl @@ -0,0 +1,5 @@ +{"question": "Question 1 in subset 1.", "answer": "Answer 1 in subset 1.", "feature_offline": 0.5, "feat_1": 0.4, "feat_2": 0.3} +{"question": "Question 2 in subset 1.", "answer": "Answer 2 in subset 1.", "feature_offline": 0.1, "feat_1": 0.1, "feat_2": 0.1} +{"question": "Question 3 in subset 1.", "answer": "Answer 3 in subset 1.", "feature_offline": 0.4, "feat_1": 0.5, "feat_2": 0.3} +{"question": "Question 4 in subset 1.", "answer": "Answer 4 in subset 1.", "feature_offline": 0.5, "feat_1": 0.3, "feat_2": 0.5} +{"question": "Question 5 in subset 1.", "answer": "Answer 5 in subset 1.", "feature_offline": 0.2, "feat_1": 0.1, "feat_2": 0.5} diff --git a/tests/template/data/task_scheduler/subset_2/train.jsonl b/tests/template/data/task_scheduler/subset_2/train.jsonl new file mode 100644 index 0000000000..dc93a82c7b --- /dev/null +++ b/tests/template/data/task_scheduler/subset_2/train.jsonl @@ -0,0 +1,7 @@ +{"question": "Question 1 in subset 2.", "answer": "Answer 1 in subset 2.", "feature_offline": 0.2, "feat_1": 0.5, "feat_2": 0.2} +{"question": "Question 2 in subset 2.", "answer": "Answer 2 in subset 2.", "feature_offline": 0.3, "feat_1": 0.6, "feat_2": 0.2} +{"question": "Question 3 in subset 2.", "answer": "Answer 3 in subset 2.", "feature_offline": 0.1, "feat_1": 0.7, "feat_2": 0.4} +{"question": "Question 4 in subset 2.", "answer": "Answer 4 in subset 2.", "feature_offline": 0.5, "feat_1": 0.1, "feat_2": 0.4} +{"question": "Question 5 in subset 2.", "answer": "Answer 5 in subset 2.", "feature_offline": 0.3, 
"feat_1": 0.1, "feat_2": 0.7} +{"question": "Question 6 in subset 2.", "answer": "Answer 6 in subset 2.", "feature_offline": 0.1, "feat_1": 0.7, "feat_2": 0.4} +{"question": "Question 7 in subset 2.", "answer": "Answer 7 in subset 2.", "feature_offline": 0.1, "feat_1": 0.7, "feat_2": 0.6} diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index acb4325fce..8815c3a132 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -30,6 +30,7 @@ ExplorerInput, StageConfig, StorageConfig, + TaskSelectorConfig, TrainerInput, ) from trinity.common.constants import ( @@ -73,6 +74,9 @@ def test_trainer(self): """Test the both and bench mode.""" # test both mode self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown") + self.config.buffer.explorer_input.taskset.task_selector = TaskSelectorConfig( + selector_type="shuffle", seed=42 + ) self.config.buffer.explorer_input.eval_tasksets.append( get_unittest_dataset_config("countdown", "test") ) @@ -778,8 +782,7 @@ def test_trainer(self): def tearDown(self): # remove dir only when the test passed - # shutil.rmtree(self.config.checkpoint_job_dir) - pass + shutil.rmtree(self.config.checkpoint_job_dir) class TestTrainerMIX(BaseTrainerCase): diff --git a/trinity/buffer/buffer_reader.py b/trinity/buffer/buffer_reader.py index dc605a0277..5315bf7ecf 100644 --- a/trinity/buffer/buffer_reader.py +++ b/trinity/buffer/buffer_reader.py @@ -1,6 +1,6 @@ """Reader of the buffer.""" from abc import ABC, abstractmethod -from typing import List, Optional +from typing import Dict, List, Optional class BufferReader(ABC): @@ -13,3 +13,13 @@ def read(self, batch_size: Optional[int] = None) -> List: @abstractmethod async def read_async(self, batch_size: Optional[int] = None) -> List: """Read from buffer asynchronously.""" + + def __len__(self) -> int: + """Get the number of samples in buffer.""" + raise NotImplementedError + + def state_dict(self) -> Dict: + return {} + + def load_state_dict(self, state_dict: Dict) -> None: + pass diff --git a/trinity/buffer/operators/__init__.py b/trinity/buffer/operators/__init__.py index 8048c31fa3..4153c049b2 100644 --- a/trinity/buffer/operators/__init__.py +++ b/trinity/buffer/operators/__init__.py @@ -4,6 +4,7 @@ ExperienceOperator, ) from trinity.buffer.operators.filters.reward_filter import RewardFilter, RewardSTDFilter +from trinity.buffer.operators.mappers.pass_rate_calculator import PassRateCalculator from trinity.buffer.operators.mappers.reward_shaping_mapper import RewardShapingMapper __all__ = [ @@ -12,5 +13,6 @@ "RewardFilter", "RewardSTDFilter", "RewardShapingMapper", + "PassRateCalculator", "DataJuicerOperator", ] diff --git a/trinity/buffer/operators/mappers/pass_rate_calculator.py b/trinity/buffer/operators/mappers/pass_rate_calculator.py new file mode 100644 index 0000000000..38ff5627c5 --- /dev/null +++ b/trinity/buffer/operators/mappers/pass_rate_calculator.py @@ -0,0 +1,37 @@ +from collections import defaultdict +from typing import Dict, List, Tuple + +import numpy as np + +from trinity.buffer.operators.experience_operator import ( + EXPERIENCE_OPERATORS, + ExperienceOperator, +) +from trinity.common.constants import SELECTOR_METRIC +from trinity.common.experience import Experience + + +@EXPERIENCE_OPERATORS.register_module("pass_rate_calculator") +class PassRateCalculator(ExperienceOperator): + def __init__(self, **kwargs): + pass + + def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]: + raw_metric = defaultdict(lambda: 
defaultdict(list)) + for exp in exps: + task_index = exp.info["task_index"] + assert "taskset_id" in task_index + assert "index" in task_index + raw_metric[task_index["taskset_id"]][task_index["index"]].append(exp.reward) + metric = {} + for taskset_id, taskset_metric in raw_metric.items(): + indices = [] + reward_means = [] + for index, rewards in taskset_metric.items(): + indices.append(index) + reward_means.append(float(np.mean(rewards))) + metric[taskset_id] = { + "indices": indices, + "values": reward_means, + } + return exps, {SELECTOR_METRIC: metric} diff --git a/trinity/buffer/pipelines/experience_pipeline.py b/trinity/buffer/pipelines/experience_pipeline.py index f92b4c638e..ca08b758c5 100644 --- a/trinity/buffer/pipelines/experience_pipeline.py +++ b/trinity/buffer/pipelines/experience_pipeline.py @@ -11,7 +11,7 @@ ExperiencePipelineConfig, StorageConfig, ) -from trinity.common.constants import StorageType +from trinity.common.constants import SELECTOR_METRIC, StorageType from trinity.common.experience import Experience from trinity.utils.log import get_logger from trinity.utils.plugin_loader import load_plugins @@ -132,6 +132,8 @@ async def process(self, exps: List[Experience]) -> Dict: for key, value in metrics.items(): if isinstance(value, (int, float)): result_metrics[f"pipeline/{key}"] = float(value) + if SELECTOR_METRIC in metrics: + result_metrics[SELECTOR_METRIC] = metrics[SELECTOR_METRIC] return result_metrics diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py index b79e87285d..6f3bfe1c31 100644 --- a/trinity/buffer/reader/file_reader.py +++ b/trinity/buffer/reader/file_reader.py @@ -1,6 +1,6 @@ """Filed based buffer reader.""" -from typing import List, Optional +from typing import List, Optional, Tuple import datasets from datasets import Dataset, load_dataset @@ -40,10 +40,6 @@ def __init__( self.drop_last = drop_last self.current_offset = offset - self.iter = iter(self.dataset) - - for _ in range(self.current_offset % self.dataset_size): - next(self.iter) # convert epochs/steps to sample number if total_steps: @@ -63,34 +59,34 @@ def __init__( self.progress_bar.update(self.current_offset) - def read_batch(self, batch_size: int) -> List: - if self.current_offset >= self.total_samples: - self.progress_bar.close() - raise StopIteration - batch = [] - + def read_batch(self, batch_size: int) -> Tuple[List, List]: + batch, indices = [], [] while len(batch) < batch_size: - try: - item = next(self.iter) - batch.append(item) - self.current_offset += 1 - except StopIteration: - if self.current_offset >= self.total_samples: - # No more data to read - if not self.drop_last and len(batch) > 0: - # return last batch - self.progress_bar.update(len(batch)) - return batch - else: - self.progress_bar.close() - raise StopIteration - # Step to the next epoch - self.iter = iter(self.dataset) - self.progress_bar.update(batch_size) + if self.current_offset >= self.total_samples: + if not self.drop_last and len(batch) > 0: + break + self.progress_bar.close() + raise StopIteration + index = self.current_offset % self.dataset_size + batch.append(self.dataset[index]) + indices.append(index) + self.current_offset += 1 + + self.progress_bar.update(len(batch)) + return batch, indices + + def select_batch(self, indices: List[int]) -> List: + batch = [] + for i in indices: + assert 0 <= i < self.dataset_size + batch.append(self.dataset[int(i)]) return batch class BaseFileReader(BufferReader): + def __len__(self): + return self.dataset.dataset_size + async def 
read_async(self, batch_size: Optional[int] = None): try: return self.read(batch_size) @@ -117,7 +113,7 @@ def __init__(self, meta: StorageConfig, config: BufferConfig): ) def read(self, batch_size: Optional[int] = None) -> List: - samples = self.dataset.read_batch(batch_size or self.read_batch_size) + samples, _ = self.dataset.read_batch(batch_size or self.read_batch_size) exp_list = [] for sample in samples: experience = self.formatter.format(sample) @@ -147,11 +143,24 @@ def __init__(self, meta: StorageConfig, config: BufferConfig): ) self.formatter = FORMATTER.get("task")(meta) - def read(self, batch_size: Optional[int] = None) -> List: - batch_size = batch_size or self.read_batch_size + def _get_tasks(self, samples: List, indices: List) -> List: tasks = [] - samples = self.dataset.read_batch(batch_size) - for sample in samples: + for sample, index in zip(samples, indices): task = self.formatter.format(sample) + task.index["index"] = int(index) tasks.append(task) return tasks + + def read(self, batch_size: Optional[int] = None) -> List: + batch_size = batch_size or self.read_batch_size + samples, indices = self.dataset.read_batch(batch_size) + return self._get_tasks(samples, indices) + + def read_with_indices(self, indices: List[int]) -> List: + """Read tasks with indices.""" + samples = self.dataset.select_batch(indices) + return self._get_tasks(samples, indices) + + async def read_with_indices_async(self, indices: List[int]) -> List: + """Read tasks with indices asynchronously.""" + return self.read_with_indices(indices) diff --git a/trinity/buffer/selector/__init__.py b/trinity/buffer/selector/__init__.py new file mode 100644 index 0000000000..1b84348a3b --- /dev/null +++ b/trinity/buffer/selector/__init__.py @@ -0,0 +1,5 @@ +from trinity.buffer.selector.selector import SELECTORS + +__all__ = [ + "SELECTORS", +] diff --git a/trinity/buffer/selector/difficulty_estimator.py b/trinity/buffer/selector/difficulty_estimator.py new file mode 100644 index 0000000000..97c4ab61e1 --- /dev/null +++ b/trinity/buffer/selector/difficulty_estimator.py @@ -0,0 +1,119 @@ +from typing import List + +import numpy as np + +from trinity.utils.log import get_logger + + +class BaseBetaPREstimator: + n: int + m: int + lamb: float + rho: float + alphas: np.ndarray + betas: np.ndarray + + def __init__(self, n: int, m: int = 16, lamb: float = 0.2, rho: float = 0.2): + """ + alpha_{t+1} = (1 - lamb) * alpha_t + (1 - rho) * bar{s} + rho * tilde{s} + beta_{t+1} = (1 - lamb) beta_t + (1 - rho) * bar{f} + rho * tilde{f} + + Args: + n (int): number of tasks. + m (int): repeat times per tasks. + timeout (lamb): discount factor of historical estimation. + rho (float): weight of pseudo counts. 
+ """ + self.n = n + self.m = m + self.lamb = lamb + self.rho = rho + self.alphas = np.ones(n, dtype=float) + self.betas = np.ones(n, dtype=float) + self.logger = get_logger("BetaPREstimator") + self.logger.debug( + f"{self.n=}, {self.m=}, {self.lamb=}, {self.rho=}, {self.alphas=}, {self.betas=}" + ) + + def set(self, alphas, betas): + self.alphas = alphas + self.betas = betas + + def _update(self, s_bar, f_bar, p_tilde): + self.alphas = ( + (1 - self.lamb) * self.alphas + + self.lamb + + (1 - self.rho) * s_bar + + self.rho * p_tilde * self.m + ) + self.betas = ( + (1 - self.lamb) * self.betas + + self.lamb + + (1 - self.rho) * f_bar + + self.rho * (1 - p_tilde) * self.m + ) + + def update(self, ref_indices: List[int], ref_pass_rates: List[float]): + raise NotImplementedError + + def predict_pr(self, rng=None, indices=None, do_sample=False): + if rng is None: + rng = np.random.default_rng() + if indices is None: + indices = np.arange(self.n) + if not do_sample: + return self.alphas[indices] / (self.alphas[indices] + self.betas[indices]) + else: + return rng.beta(self.alphas[indices], self.betas[indices]) + + def equivalent_count(self, indices=None): + if indices is None: + indices = np.arange(self.n) + return self.alphas[indices] + self.betas[indices] + + +class InterpolationBetaPREstimator(BaseBetaPREstimator): + def __init__( + self, + features: np.ndarray, + m: int, + lamb, + rho, + cap_coef_update_discount=0.9, + adaptive_rho=False, + ): + super(InterpolationBetaPREstimator, self).__init__(len(features), m, lamb, rho) + self.features = features # [D, 2] + self.cap_coef = None + self.cap_coef_update_discount = cap_coef_update_discount + self.adaptive_rho = adaptive_rho + + def update(self, ref_indices: List[int], ref_pass_rates: List[float]): + ref_pass_rate = np.mean(ref_pass_rates) + ref_anchor_pass_rates = np.mean(self.features[ref_indices], axis=0) + cap_estimate = (ref_pass_rate - ref_anchor_pass_rates[0]) / ( + ref_anchor_pass_rates[1] - ref_anchor_pass_rates[0] + 1e-6 + ) + if self.cap_coef is None: + self.cap_coef = cap_estimate + else: + self.cap_coef = ( + self.cap_coef_update_discount * self.cap_coef + + (1 - self.cap_coef_update_discount) * cap_estimate + ) + s_bar = np.zeros(self.n, dtype=float) + s_bar[ref_indices] = np.array(ref_pass_rates) * self.m + f_bar = np.zeros(self.n, dtype=float) + f_bar[ref_indices] = (1 - np.array(ref_pass_rates)) * self.m + p_tilde = np.clip( + (self.features[:, 1] - self.features[:, 0]) * self.cap_coef + self.features[:, 0], 0, 1 + ) + + predicted_pass_rates = p_tilde[ref_indices] + mean_abs_error = np.mean(np.abs(np.array(predicted_pass_rates) - np.array(ref_pass_rates))) + if self.adaptive_rho and mean_abs_error >= 0.25: + self.rho = self.rho * 0.5 + self.logger.debug(f"{mean_abs_error=}, {self.rho=}") + p_tilde[ref_indices] = np.array(ref_pass_rates) + + self._update(s_bar, f_bar, p_tilde) diff --git a/trinity/buffer/selector/selector.py b/trinity/buffer/selector/selector.py new file mode 100644 index 0000000000..cc04a573ae --- /dev/null +++ b/trinity/buffer/selector/selector.py @@ -0,0 +1,430 @@ +"""Data selectors.""" +from typing import Dict, List + +import numpy as np +import torch + +from trinity.buffer.reader.file_reader import _HFBatchReader +from trinity.buffer.selector.difficulty_estimator import InterpolationBetaPREstimator +from trinity.common.config import TaskSelectorConfig +from trinity.utils.annotations import Experimental +from trinity.utils.log import get_logger +from trinity.utils.registry import Registry + +SELECTORS = 
Registry("selectors") + + +@Experimental +class BaseSelector: + """ + Abstract base class defining the interface for custom data selection strategies. + + A selector determines which samples (by index) are selected from the dataset + during training. It enables flexible sampling beyond simple + sequential or random access, supporting active learning, curriculum learning, + or difficulty-based sampling in the future. + + Subclasses must implement: + - get_indices: returns list of indices for next batch + - update: updates internal state using feedback (e.g., loss values, mean rewards, etc.) + - state_dict / load_state_dict: for saving/loading selector state (checkpointing) + """ + + def __init__(self, data_source: _HFBatchReader, config: TaskSelectorConfig): + self.data_source = data_source + self.config = config + + def get_indices(self, batch_size: int, return_extra_info: bool = False) -> List[int]: + """ + Select a batch of sample indices from the dataset. + + Args: + batch_size (int): Number of indices to return + return_extra_info (bool): If True, may return additional metadata (future use) + + Returns: + List[int]: Selected indices into the dataset + """ + raise NotImplementedError + + def update(self, indices: List[int], values: List[float]) -> None: + """ + Update internal state based on feedback (e.g., model loss, accuracy). + + This allows adaptive selectors (like hard example mining) to learn over time. + + Args: + indices (List[int]): Previously selected indices + values (List[float]): Feedback values corresponding to those indices + """ + raise NotImplementedError + + def state_dict(self) -> Dict: + """ + Return serializable state of the selector for checkpointing. + + Returns: + Dict: State information (e.g., current position, etc.) + """ + raise NotImplementedError + + def load_state_dict(self, state_dict: Dict) -> None: + """ + Restore selector state from a saved dictionary. + + Args: + state_dict (Dict): Output from state_dict() + """ + raise NotImplementedError + + +@SELECTORS.register_module("sequential") +class SequentialSelector(BaseSelector): + """ + Selects data sequentially in fixed order across epochs. + + Example: [0,1,2,...,B-1], then [B,B+1,...,2B-1], etc. + """ + + def __init__(self, data_source: _HFBatchReader, config: TaskSelectorConfig): + super().__init__(data_source, config) + self.dataset_size = data_source.dataset_size + self.current_index = 0 + + def get_indices(self, batch_size: int, return_extra_info: bool = False) -> List[int]: + start = self.current_index % self.dataset_size + end = start + batch_size + self.current_index += batch_size + if end <= self.dataset_size: + return list(range(start, end)) + return list(range(start, self.dataset_size)) + list(range(0, end - self.dataset_size)) + + def update(self, indices: List[int], values: List[float]) -> None: + # No-op: sequential selection doesn't adapt based on feedback + pass + + def state_dict(self) -> Dict: + return { + "current_index": self.current_index, + } + + def load_state_dict(self, state_dict): + self.current_index = state_dict.get("current_index", 0) + + +@SELECTORS.register_module("shuffle") +class ShuffleSelector(BaseSelector): + """ + Shuffles dataset once per epoch and iterates through it sequentially. + + Each epoch uses a different permutation of a subset of the full dataset. + When one epoch ends, a new shuffle is triggered. + Mimics standard PyTorch DataLoader with shuffle=True. 
+ """ + + def __init__(self, data_source: _HFBatchReader, config: TaskSelectorConfig): + super().__init__(data_source, config) + self.dataset_size = data_source.dataset_size # Total available samples + self.current_index = 0 # Progress tracker + self.seed = config.seed # For reproducible shuffling + self.orders = self._get_orders() # Current shuffled index order + + def _get_orders(self) -> List[int]: + """ + Generate a new shuffled order for the current epoch. + + Uses NumPy's PCG64 random generator seeded by epoch number for reproducibility. + Ensures different shuffle per epoch while being deterministic if seed is fixed. + """ + rng = np.random.default_rng(self.seed + self.current_index // self.dataset_size) + return rng.permutation(self.dataset_size).tolist() + + def get_indices(self, batch_size: int, return_extra_info: bool = False) -> List[int]: + start = self.current_index % self.dataset_size + end = start + batch_size + if end <= self.dataset_size: + ret = self.orders[start:end] + # At end of epoch, reshuffle for next epoch + if end == self.dataset_size: + self.orders = self._get_orders() + else: + ret = self.orders[start:] + # At end of epoch, reshuffle for next epoch + self.orders = self._get_orders() + ret += self.orders[: (end - self.dataset_size)] + self.current_index += batch_size + return ret + + def update(self, indices: List[int], values: List[float]) -> None: + # No-op: static shuffling does not adapt + pass + + def state_dict(self) -> Dict: + return { + "current_index": self.current_index, + } + + def load_state_dict(self, state_dict): + self.current_index = state_dict.get("current_index", 0) + self.orders = self._get_orders() + + +@SELECTORS.register_module("random") +class RandomSelector(BaseSelector): + """ + Uniformly samples batches randomly with replacement *per batch*. + + Unlike ShuffleSelector, there is no concept of an epoch — every batch is independently sampled. + Can result in repeated samples within an epoch. Suitable for online or stochastic training regimes. + """ + + def __init__(self, data_source: _HFBatchReader, config: TaskSelectorConfig): + super().__init__(data_source, config) + self.dataset_size = data_source.dataset_size + self.current_index = 0 + self.seed = config.seed + + def get_indices(self, batch_size, return_extra_info=False): + # Seed varies per batch to ensure repeatability across runs + rng = np.random.default_rng(self.seed + self.current_index) + selected_indices = rng.choice(self.dataset_size, batch_size, replace=False) + self.current_index += batch_size + if return_extra_info: + return selected_indices, {} + else: + return selected_indices + + def update(self, indices: List[int], values: List[float]) -> None: + # No-op: basic random selection doesn't adapt + pass + + def state_dict(self) -> Dict: + return { + "current_index": self.current_index, + } + + def load_state_dict(self, state_dict): + self.current_index = state_dict.get("current_index", 0) + + +@SELECTORS.register_module("offline_easy2hard") +class OfflineEasy2HardSelector(BaseSelector): + """ + Selects samples in an 'easy-to-hard' curriculum based on pre-defined difficulty features. + + This selector assumes that higher feature values indicate easier examples. + It sorts all data once at initialization by descending feature value(s), then sequentially + serves batches from easy → hard over epochs. The sorting is fixed (offline), so no online + adaptation occurs during training. 
+ + Useful for curriculum learning where sample difficulty is estimated ahead of time + (e.g., via teacher model confidence, length, BLEU score, etc.). + """ + + def __init__(self, data_source, config: TaskSelectorConfig): + super().__init__(data_source, config) + self.logger = get_logger("offline_easy2hard_selector") + + # Extract specified feature columns (e.g., 'loss', 'confidence') used to estimate difficulty + feature_keys = config.feature_keys + self.features = np.concatenate( + [np.array(list(data_source.dataset[k]))[:, None] for k in feature_keys], axis=1 + ) + # Shape: (N, len(feature_keys)) — one row per sample, one column per feature + + # Append index to each feature vector for tracking original positions after sorting + features_with_index = [list(self.features[i]) + [i] for i in range(len(self.features))] + + # Sort by feature values in descending order → highest (easiest) first + features_with_index = sorted(features_with_index)[::-1] + self.logger.debug(f"OfflineEasy2HardSelector, sorted {features_with_index[:20]}") + + # Store the sorted order of indices (from easiest to hardest) + self.sorted_index = np.array([i[-1] for i in features_with_index]) + + # Number of samples per epoch (may be less than full dataset size) + self.dataset_size = data_source.dataset_size + self.current_index = 0 + + def update(self, indices: List[int], values: List[float]) -> None: + # No-op: this selector does not adapt based on runtime feedback + pass + + def get_indices(self, batch_size, return_extra_info=False): + """ + Returns next batch of indices in curriculum order (easy → hard). + + Batches are taken sequentially from the pre-sorted list. When epoch ends, + it wraps around to the beginning (i.e., restarts curriculum). + """ + start = self.current_index % self.dataset_size + end = start + batch_size + if end <= self.dataset_size: + selected_indices = self.sorted_index[start:end] + else: + selected_indices = np.concatenate( + [self.sorted_index[start:], self.sorted_index[: (end - self.dataset_size)]] + ) + self.current_index += batch_size + if not return_extra_info: + return selected_indices + else: + extra_info = { + "indices": selected_indices.tolist(), + "feat1": self.features[selected_indices, 0].tolist(), + "feat2": self.features[selected_indices, 1].tolist(), + } + return selected_indices, extra_info + + def state_dict(self) -> Dict: + """ + Save current position in the curriculum for checkpointing. + Allows resuming from same point in the easy→hard progression. + """ + return { + "current_index": self.current_index, + } + + def load_state_dict(self, state_dict): + """ + Restore progress through the curriculum from saved state. + """ + self.current_index = state_dict.get("current_index", 0) + + +@SELECTORS.register_module("difficulty_based") +class DifficultyBasedSelector(BaseSelector): + """ + Adaptive difficulty-based selector using probabilistic modeling of sample difficulty. + + Uses `InterpolationBetaPREstimator` to model each sample's probability of success (PR), + updated with observed feedback (e.g., loss, accuracy). Then selects samples close to + a target reward (e.g., 1.0 for perfect performance), implementing a form of + *targeted difficulty sampling* — focusing on items near the edge of model capability. + + Supports both greedy selection (`tau=0`) and stochastic sampling (`tau>0`). 
+ """ + + def __init__(self, data_source, config: TaskSelectorConfig) -> None: + super().__init__(data_source, config) + self.logger = get_logger("difficulty_based_selector") + + # Initialize difficulty estimator using two features (assumed: e.g., correctness & uncertainty) + self.diff_estimator = self.build_diff_estimator( + data_source.dataset, config.feature_keys, config.kwargs + ) + self.current_index = 0 + self.seed = config.seed + + self.do_sample = config.kwargs.get( + "do_sample", False + ) # Whether to sample PR during estimation + self.target_reward = config.kwargs.get("target_reward", 1.0) # Desired performance level + self.tau = config.kwargs.get("tau", 1.0) # Temperature for sampling distribution + + def build_diff_estimator(self, dataset, feature_keys: List[str], config: dict): + """ + Constructs a Beta-distribution-based difficulty estimator from features. + + Expects exactly two feature keys (e.g., ['correct', 'uncertainty']), which are concatenated + into a feature matrix and passed to InterpolationBetaPREstimator for modeling P(success). + """ + self.logger.debug(f"{config=}") + if len(feature_keys) != 2: + raise ValueError( + f"DifficultyBasedSelector requires exactly 2 feature keys, but got {len(feature_keys)}." + ) + features = np.concatenate( + [np.array(list(dataset[k]))[:, None] for k in feature_keys], axis=1 + ) + self.logger.debug(f"{features.shape=}") + self.logger.debug(f"{features[:5]=}") + adaptive_rho = config.get("adaptive_rho", False) + return InterpolationBetaPREstimator( + features=features, + m=config.get("m", 16), + lamb=config.get("lamb", 0.2), + rho=config.get("rho", 0.2), + adaptive_rho=adaptive_rho, + ) + + def update(self, indices: List[int], values: List[float]) -> None: + """ + Updates the difficulty estimator with observed performance on selected samples. + + Args: + indices (List[int]): Previously selected sample indices + values (List[float]): Observed rewards/scores (e.g., accuracy, BLEU) for those samples + """ + self.diff_estimator.update(indices, values) + + def get_scores(self) -> List[float]: + """ + Computes selection scores: negative distance between predicted PR and target reward. + + Samples whose predicted performance is closest to `target_reward` receive highest scores. + Encourages selection of "just right" difficulty samples (neither too easy nor too hard). + """ + rng = np.random.default_rng(self.seed + self.current_index) + predicted_pr = self.diff_estimator.predict_pr(rng=rng, do_sample=self.do_sample) + scores = -np.abs(self.target_reward - predicted_pr) + return scores + + def get_indices(self, batch_size, return_extra_info=False): + """ + Selects batch of indices based on difficulty proximity to target. + + If tau == 0: take top-k highest scoring samples (greedy). + Else: sample stochastically using softmax(logits / tau). 
+ """ + sampling_scores = self.get_scores() + sampling_scores = torch.from_numpy(sampling_scores) + if self.tau == 0: + selected_indices = torch.topk(sampling_scores, batch_size).indices + else: + sampling_logits = sampling_scores / self.tau + sampling_logits -= sampling_logits.max() + sampling_probabilities = torch.softmax(sampling_logits, dim=0) + rng = torch.Generator() + rng.manual_seed(self.seed + self.current_index) + selected_indices = torch.multinomial( + sampling_probabilities, + batch_size, + replacement=False, + generator=rng, + ) + self.logger.debug(f"{selected_indices=}") + self.logger.debug(f"{sampling_scores=}") + self.logger.debug(f"{sampling_scores[selected_indices]=}") + self.current_index += batch_size + + if return_extra_info: + selected_indices_list = selected_indices.tolist() + alphas = self.diff_estimator.alphas[selected_indices_list] + betas = self.diff_estimator.betas[selected_indices_list] + point_est = alphas / (alphas + betas) + extra_info = { + "indices": selected_indices_list, + "scores": sampling_scores[selected_indices].tolist(), + "alphas": alphas.tolist(), + "betas": betas.tolist(), + "point": point_est.tolist(), + } + return selected_indices, extra_info + else: + return selected_indices + + def state_dict(self) -> Dict: + """ + Save current state for checkpointing. + Only tracks sampling progress; actual difficulty estimates are in diff_estimator. + """ + return { + "current_index": self.current_index, + } + + def load_state_dict(self, state_dict): + """ + Restore selector state from checkpoint. + """ + self.current_index = state_dict.get("current_index", 0) diff --git a/trinity/buffer/storage/sql.py b/trinity/buffer/storage/sql.py index 3254cd663e..f7fd3fa643 100644 --- a/trinity/buffer/storage/sql.py +++ b/trinity/buffer/storage/sql.py @@ -223,7 +223,6 @@ def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None: self.default_workflow_cls = WORKFLOWS.get(storage_config.default_workflow_type) # type: ignore self.default_reward_fn_cls = REWARD_FUNCTIONS.get(storage_config.default_reward_fn_type) # type: ignore self.formatter = TaskFormatter(storage_config) - self.offset = storage_config.index if storage_config.total_steps: self.total_samples = self.batch_size * storage_config.total_steps else: diff --git a/trinity/buffer/task_scheduler.py b/trinity/buffer/task_scheduler.py new file mode 100644 index 0000000000..35a4eff2ce --- /dev/null +++ b/trinity/buffer/task_scheduler.py @@ -0,0 +1,196 @@ +# -*- coding: utf-8 -*- +"""The taskset scheduler.""" + +from collections import Counter +from typing import Dict, List + +import numpy as np + +from trinity.buffer.buffer import get_buffer_reader +from trinity.buffer.selector import SELECTORS +from trinity.common.config import Config +from trinity.common.constants import SELECTOR_METRIC +from trinity.utils.annotations import Experimental + + +@Experimental +class TasksetScheduler: + """ + Coordinates multiple datasets (tasksets) with customizable task selection strategies per taskset. + + The scheduler: + - Manages multiple data sources (tasksets) + - Uses a selector per taskset to determine which samples to read + - Shuffles the order of taskset access across epochs + - Supports adaptive selectors via feedback (e.g., difficulty-based sampling) + - Enables curriculum-like or interleaved multi-task training + + It assumes that each call to `read_async()` corresponds to one training step, + and batches are built by aggregating samples from different tasksets based on + a shuffled global schedule. 
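+
+    Illustrative example (sizes are hypothetical): with taskset 0 holding 6 tasks,
+    taskset 1 holding 2 tasks, and `batch_size=4`, one epoch covers 8 tasks in 2
+    steps. Each epoch, the flat id list [0, 0, 0, 0, 0, 0, 1, 1] is shuffled with a
+    seed derived from the epoch number and consumed `batch_size` ids per step; each
+    id is then resolved to a concrete sample index by that taskset's selector.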
+    """
+
+    def __init__(self, explorer_state: Dict, config: Config):
+        """
+        Initialize the scheduler from configuration and previous state (for resume support).
+
+        Args:
+            explorer_state (Dict): Restoration state from checkpoint (may include progress info)
+            config (Config): Full system configuration containing buffer and taskset settings
+        """
+        self.config = config
+
+        # Backward compatibility: old format stored 'latest_task_index' directly
+        if "latest_task_index" in explorer_state:
+            assert len(config.buffer.explorer_input.tasksets) == 1  # old format
+            explorer_state["taskset_states"] = [
+                {
+                    "current_index": explorer_state["latest_task_index"],
+                }
+            ]
+
+        self.read_batch_size = config.buffer.batch_size
+        taskset_configs = config.buffer.explorer_input.tasksets
+
+        from trinity.buffer.reader.file_reader import TaskFileReader
+
+        taskset_states = explorer_state.get(
+            "taskset_states", [{"current_index": 0}] * len(taskset_configs)
+        )
+        self.tasksets = []
+        self.selectors = []
+        for taskset_config, taskset_state in zip(taskset_configs, taskset_states):
+            assert not taskset_config.is_eval  # assume drop last
+            taskset = get_buffer_reader(taskset_config, config.buffer)
+            if not isinstance(taskset, TaskFileReader):
+                raise TypeError(
+                    f"Taskset '{taskset_config.name}' has an unsupported type '{type(taskset).__name__}'. "
+                    f"Currently, only 'TaskFileReader' is supported by TasksetScheduler."
+                )
+
+            # Create selector based on type specified in config (e.g., 'sequential', 'shuffle')
+            selector = SELECTORS.get(taskset_config.task_selector.selector_type)(
+                taskset.dataset, taskset_config.task_selector
+            )
+            selector.load_state_dict(taskset_state)  # Restore any prior state
+
+            self.tasksets.append(taskset)
+            self.selectors.append(selector)
+
+        # Each explorer step calls read_async once → track step globally
+        self.step = explorer_state.get("latest_iteration", 0)
+
+        # Build flat list indicating how often each taskset should appear per epoch
+        self.base_taskset_ids = []
+        for i, taskset in enumerate(self.tasksets):
+            self.base_taskset_ids.extend([i] * len(taskset))
+        if len(self.base_taskset_ids) == 0:
+            raise ValueError("Empty tasksets provided!")
+
+        self.epoch = self.step * self.read_batch_size // len(self.base_taskset_ids)
+        self.orders = self.build_orders(self.epoch)
+
+    def build_orders(self, epoch: int):
+        """
+        Creates a shuffled sequence of taskset IDs to control sampling priority per step.
+
+        At the start of each epoch, the taskset ids (each repeated in proportion to
+        its taskset's size) are shuffled, ensuring balanced exposure while randomizing
+        the selection order.
+
+        Args:
+            epoch (int): Epoch ID used as seed for deterministic shuffling
+
+        Returns:
+            List[int]: Sequence of taskset IDs whose length equals the total number
+            of tasks across all tasksets (i.e., one epoch)
+        """
+        taskset_ids = self.base_taskset_ids.copy()
+        rng = np.random.default_rng(epoch)
+        rng.shuffle(taskset_ids)
+        return taskset_ids
+
+    async def read_async(self) -> List:
+        """
+        Asynchronously reads a batch of tasks according to the current schedule.
+ + For each step: + - Checks if a new epoch has started; rebuilds order if so + - Determines which tasksets contribute to this batch + - Uses each taskset's selector to pick specific samples + - Annotates each task with its source taskset_id + - Returns combined list of tasks + + Raises: + StopAsyncIteration: When total_epochs is reached + + Returns: + List[Task]: A batch of tasks from potentially multiple tasksets + """ + if self.config.buffer.total_steps: + if self.step >= self.config.buffer.total_steps: + raise StopAsyncIteration + else: + if self.epoch >= self.config.buffer.total_epochs: + raise StopAsyncIteration + + batch_size = self.read_batch_size + start = self.step * batch_size % len(self.base_taskset_ids) + end = start + batch_size + if end <= len(self.base_taskset_ids): + taskset_ids = self.orders[start:end] + if end == len(self.base_taskset_ids): + self.epoch += 1 + self.orders = self.build_orders(self.epoch) + else: + taskset_ids = self.orders[start:] + self.epoch += 1 + if self.epoch >= self.config.buffer.total_epochs: + raise StopAsyncIteration + self.orders = self.build_orders(self.epoch) + taskset_ids += self.orders[: (end - len(self.base_taskset_ids))] + + counter = Counter(taskset_ids) + batch = [] + for taskset_id, count in counter.items(): + indices = self.selectors[taskset_id].get_indices(batch_size=count) + tasks = await self.tasksets[taskset_id].read_with_indices_async(indices) + # Annotate each task with its origin + for task in tasks: + task.index["taskset_id"] = taskset_id + batch.extend(tasks) + + self.step += 1 + return batch + + def state_dict(self) -> List[Dict]: + """ + Save persistent state for checkpointing. + + Returns: + List[Dict]: State dicts for all selectors (one per taskset) + """ + return [selector.state_dict() for selector in self.selectors] + + def update(self, pipeline_metrics: Dict) -> None: + """ + Update selectors using feedback from the training pipeline. + + Expected format: + pipeline_metrics = { + SELECTOR_METRIC: { + 0: {"indices": [...], "values": [...]}, + 1: {"indices": [...], "values": [...]} + }, + ... # other metrics + } + + This allows adaptive selectors (like `DifficultyBasedSelector`) to refine difficulty estimates. + + Args: + pipeline_metrics (Dict): Metrics dictionary passed from explorer. 
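+
+        For example (hypothetical values), a payload of
+        `{SELECTOR_METRIC: {0: {"indices": [3, 7], "values": [1.0, 0.0]}}}` would
+        forward `indices=[3, 7]` and `values=[1.0, 0.0]` to the selector of
+        taskset 0 via `selector.update(**kwargs)`.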
+ """ + if SELECTOR_METRIC not in pipeline_metrics: + return + selector_metric = pipeline_metrics[SELECTOR_METRIC] + for taskset_id, taskset_kwargs in selector_metric.items(): + selector = self.selectors[taskset_id] + selector.update(**taskset_kwargs) diff --git a/trinity/common/config.py b/trinity/common/config.py index ba66634539..fc57afa41f 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -112,6 +112,21 @@ class LoRAConfig: target_modules: str = "all-linear" +@Experimental +@dataclass +class TaskSelectorConfig: + """Data selector config.""" + + selector_type: Optional[str] = "sequential" + + # For shuffle + seed: int = 42 + + # Estimator Config + feature_keys: List[str] = field(default_factory=lambda: []) + kwargs: dict = field(default_factory=dict) + + @dataclass class StorageConfig: """Storage config.""" @@ -148,6 +163,7 @@ class StorageConfig: rollout_args: GenerationConfig = field(default_factory=GenerationConfig) workflow_args: dict = field(default_factory=dict) reward_fn_args: dict = field(default_factory=dict) + task_selector: TaskSelectorConfig = field(default_factory=TaskSelectorConfig) # enable progress bar (tqdm) for _HFBatchReader enable_progress_bar: Optional[bool] = False @@ -381,7 +397,8 @@ class ClusterConfig: class ExplorerInput: """Config for explorer input.""" - taskset: StorageConfig = field(default_factory=StorageConfig) + taskset: Optional[StorageConfig] = None + tasksets: List[StorageConfig] = field(default_factory=list) eval_tasksets: List[StorageConfig] = field(default_factory=list) # The following args provide default values for the corresponding args in `taskset` and `eval_tasksets` default_workflow_type: Optional[str] = None @@ -669,34 +686,44 @@ def _check_buffer(self) -> None: # noqa: C901 trainer_input = self.buffer.trainer_input experience_buffer = trainer_input.experience_buffer explorer_input = self.buffer.explorer_input - taskset = explorer_input.taskset - if self.mode != "train" and not taskset.path: - raise ValueError( - "`buffer.explorer_input.taskset.path` is required, please set it to the path of the taskset." - ) - if not taskset.name: - taskset.name = "taskset" - if taskset.repeat_times is None or taskset.repeat_times != self.algorithm.repeat_times: - taskset.repeat_times = self.algorithm.repeat_times - logger.info( - "`buffer.explorer_input.taskset.repeat_times` is set to `algorithm.repeat_times`" - f" (={self.algorithm.repeat_times})." - ) + if explorer_input.taskset: + if len(explorer_input.tasksets) > 0: + raise ValueError("Do not support setting `taskset` and `tasksets` simultaneously!") + explorer_input.tasksets = [explorer_input.taskset] + explorer_input.taskset = None + else: + if len(explorer_input.tasksets) == 0: + explorer_input.tasksets = [StorageConfig()] + tasksets = explorer_input.tasksets + if self.mode == "train": assert ( experience_buffer is not None ), "`buffer.trainer_input.experience_buffer` is required when `mode` is `train`." experience_buffer.total_epochs = self.buffer.total_epochs experience_buffer.total_steps = self.buffer.total_steps - else: + + for taskset in tasksets: + if self.mode != "train" and not taskset.path: + raise ValueError( + "`buffer.explorer_input.taskset.path` is required, please set it to the path of the taskset." 
+ ) + if not taskset.name: + taskset.name = "taskset" + if taskset.repeat_times is None or taskset.repeat_times != self.algorithm.repeat_times: + taskset.repeat_times = self.algorithm.repeat_times + logger.info( + "`buffer.explorer_input.taskset.repeat_times` is set to `algorithm.repeat_times`" + f" (={self.algorithm.repeat_times})." + ) taskset.total_epochs = self.buffer.total_epochs taskset.total_steps = self.buffer.total_steps - set_if_none(taskset, "default_workflow_type", explorer_input.default_workflow_type) - set_if_none(taskset, "default_reward_fn_type", explorer_input.default_reward_fn_type) - set_if_none(taskset, "ray_namespace", self.ray_namespace) - set_if_none(taskset.rollout_args, "max_tokens", self.model.max_response_tokens) + set_if_none(taskset, "default_workflow_type", explorer_input.default_workflow_type) + set_if_none(taskset, "default_reward_fn_type", explorer_input.default_reward_fn_type) + set_if_none(taskset, "ray_namespace", self.ray_namespace) + set_if_none(taskset.rollout_args, "max_tokens", self.model.max_response_tokens) remained_tasksets = [] for idx, dataset in enumerate(explorer_input.eval_tasksets): @@ -760,8 +787,8 @@ def _check_buffer(self) -> None: # noqa: C901 task_pipeline = self.data_processor.task_pipeline if task_pipeline is not None: if task_pipeline.output is None: - if taskset.path is not None: - task_pipeline.output = taskset + if tasksets and tasksets[0].path is not None: + task_pipeline.output = tasksets[0] elif ( experience_buffer.schema_type in {"dpo", "sft"} and experience_buffer.path is not None @@ -770,7 +797,7 @@ def _check_buffer(self) -> None: # noqa: C901 else: raise ValueError( "`data_processor.task_pipeline.output` is required when both " - "`buffer.explorer_input.taskset.path` and `buffer.trainer_input.experience_buffer.path` are " + "`buffer.explorer_input.tasksets[0].path` and `buffer.trainer_input.experience_buffer.path` are " "None" ) if task_pipeline.output.path and os.path.exists(task_pipeline.output.path): diff --git a/trinity/common/constants.py b/trinity/common/constants.py index ad092603d2..183702927b 100644 --- a/trinity/common/constants.py +++ b/trinity/common/constants.py @@ -7,6 +7,8 @@ EXPLORER_NAME = "explorer" TRAINER_NAME = "trainer" +SELECTOR_METRIC = "selector_metric" + ROLLOUT_WEIGHT_SYNC_GROUP_NAME = "rollout_weight_sync" DEBUG_NAMESPACE = "TRINITY_DEBUG_NAMESPACE" diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py index 05c26bfea4..1954833191 100644 --- a/trinity/common/verl_config.py +++ b/trinity/common/verl_config.py @@ -413,9 +413,9 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901 self.actor_rollout_ref.model.custom_chat_template = config.model.custom_chat_template self.actor_rollout_ref.actor.optim.total_training_steps = self.trainer.total_training_steps self.actor_rollout_ref.actor.ppo_mini_batch_size = config.buffer.train_batch_size - self.actor_rollout_ref.rollout.temperature = ( - config.buffer.explorer_input.taskset.rollout_args.temperature - ) + self.actor_rollout_ref.rollout.temperature = config.buffer.explorer_input.tasksets[ + 0 + ].rollout_args.temperature self.actor_rollout_ref.rollout.n = config.algorithm.repeat_times if self.actor_rollout_ref.actor.grad_clip is None: self.actor_rollout_ref.actor.grad_clip = config.trainer.grad_clip diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py index 26a555d7cc..8a493e161f 100644 --- a/trinity/common/workflows/workflow.py +++ b/trinity/common/workflows/workflow.py @@ -37,6 
+37,8 @@ class Task(dict): batch_id: Union[int, str] = "" task_id: Union[int, str] = "" + index: dict = field(default_factory=dict) + def to_workflow( self, model: Any, auxiliary_models: Optional[List[openai.OpenAI]] = None ) -> Workflow: diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index a10b523af2..b80807213f 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -15,6 +15,7 @@ from trinity.buffer.buffer import get_buffer_reader from trinity.buffer.pipelines.experience_pipeline import ExperiencePipeline +from trinity.buffer.task_scheduler import TasksetScheduler from trinity.common.config import Config from trinity.common.constants import ( ROLLOUT_WEIGHT_SYNC_GROUP_NAME, @@ -49,11 +50,8 @@ def __init__(self, config: Config): self.config = config self.models, self.auxiliary_models = create_inference_models(config) self.experience_pipeline = self._init_experience_pipeline() - self.config.buffer.explorer_input.taskset.index = explorer_state.get("latest_task_index", 0) self.taskset = ( - get_buffer_reader(self.config.buffer.explorer_input.taskset, self.config.buffer) - if self.config.mode != "serve" - else None + TasksetScheduler(explorer_state, config) if self.config.mode != "serve" else None ) self.scheduler = None self.monitor = MONITOR.get(self.config.monitor.monitor_type)( @@ -324,7 +322,7 @@ async def save_checkpoint(self, sync_weight: bool = False) -> None: # save explore checkpoint self.state.save_explorer( current_step=self.explore_step_num, - current_task_index=self.explore_step_num * self.config.buffer.batch_size, + taskset_states=self.taskset.state_dict() if self.taskset else [], ) async def sync_weight(self) -> None: @@ -342,6 +340,7 @@ async def _finish_explore_step(self, step: int, model_version: int) -> None: statuses, exps = await self.scheduler.get_results(batch_id=step) metric = {"rollout/model_version": model_version} pipeline_metrics = await self.experience_pipeline.process.remote(exps) + self.taskset.update(pipeline_metrics) metric.update(pipeline_metrics) if statuses: metric.update(gather_metrics([status.metric for status in statuses], "rollout")) diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py index 187d0d5adf..92ca79563f 100644 --- a/trinity/explorer/workflow_runner.py +++ b/trinity/explorer/workflow_runner.py @@ -137,6 +137,7 @@ async def run_task( exp.info = {} exp.info["model_version"] = model_version exp.info["use_count"] = 0 + exp.info["task_index"] = task.index if not hasattr(exp, "metrics") or exp.metrics is None: exp.metrics = {} @@ -171,7 +172,7 @@ def __init__( ) -> None: model, auxiliary_models = get_debug_inference_model(config) super().__init__(config, model, auxiliary_models, 0) - self.taskset = get_buffer_reader(config.buffer.explorer_input.taskset, config.buffer) + self.taskset = get_buffer_reader(config.buffer.explorer_input.tasksets[0], config.buffer) self.output_file = output_file async def debug(self) -> None: diff --git a/trinity/manager/state_manager.py b/trinity/manager/state_manager.py index e47566d839..eee97e49a5 100644 --- a/trinity/manager/state_manager.py +++ b/trinity/manager/state_manager.py @@ -2,7 +2,7 @@ """State manager.""" import json import os -from typing import Optional +from typing import Dict, List, Optional from trinity.common.config import Config, load_config from trinity.utils.log import get_logger @@ -48,14 +48,14 @@ def _check_config_consistency(self, config: Config) -> None: def save_explorer( self, - current_task_index: int, 
current_step: int, + taskset_states: List[Dict], ) -> None: with open(self.explorer_state_path, "w", encoding="utf-8") as f: json.dump( { - "latest_task_index": current_task_index, "latest_iteration": current_step, + "taskset_states": taskset_states, }, f, indent=2,