Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/sphinx_doc/source/tutorial/develop_selector.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ To create a new selector, inherit from `BaseSelector` and implement the followin
| Method | Purpose |
|-------|--------|
| `get_indices(batch_size: int, return_extra_info=False) -> List[int]` | Return a list of sample indices to read next. |
| `update(indices: List[int], values: List[float])` | Update internal state using feedback (e.g., rewards, losses). |
| `feedback(indices: List[int], values: List[float])` | Update internal state using feedback (e.g., rewards, losses). |
| `state_dict() -> Dict` | Serialize current state for checkpointing. |
| `load_state_dict(state_dict: Dict)` | Restore state from a saved dictionary. |

Expand Down Expand Up @@ -113,7 +113,7 @@ class DifficultyBasedSelector(BaseSelector):
else:
return selected_indices

def update(self, indices: List[int], values: List[float]) -> None:
def feedback(self, indices: List[int], values: List[float]) -> None:
# Update difficulty model with observed rewards
self.diff_estimator.update(indices, values)

Expand Down
4 changes: 2 additions & 2 deletions docs/sphinx_doc/source_zh/tutorial/develop_selector.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
| 方法 | 功能说明 |
|------|---------|
| `get_indices(batch_size: int, return_extra_info=False) -> List[int]` | 返回接下来要读取的样本索引列表。 |
| `update(indices: List[int], values: List[float])` | 使用反馈信息(如奖励、损失)更新内部状态,用于自适应调整。 |
| `feedback(indices: List[int], values: List[float])` | 使用反馈信息(如奖励、损失)更新内部状态,用于自适应调整。 |
| `state_dict() -> Dict` | 序列化当前状态,用于保存检查点。 |
| `load_state_dict(state_dict: Dict)` | 从保存的状态字典中恢复选择器状态。 |

Expand Down Expand Up @@ -111,7 +111,7 @@ class DifficultyBasedSelector(BaseSelector):
else:
return selected_indices

def update(self, indices: List[int], values: List[float]) -> None:
def feedback(self, indices: List[int], values: List[float]) -> None:
# 使用观测到的奖励更新难度模型
self.diff_estimator.update(indices, values)

Expand Down
2 changes: 1 addition & 1 deletion tests/buffer/task_scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ async def test_task_scheduler_simple(self):
self.assertEqual(len(task_scheduler_state), 1)
self.assertEqual(task_scheduler_state[0]["current_index"], 4)
# no effect
task_scheduler.update({"metric1": 0.5})
task_scheduler.feedback({"metric1": 0.5})

task_scheduler = get_taskset_scheduler(
{
Expand Down
48 changes: 30 additions & 18 deletions trinity/buffer/reader/file_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,14 @@ def select_batch(self, indices: List[int]) -> List:
batch = []
for i in indices:
assert 0 <= i < self.dataset_size
if self.current_offset >= self.total_samples:
if not self.drop_last and len(batch) > 0:
break
self.progress_bar.close()
raise StopIteration
batch.append(self.dataset[int(i)])
self.current_offset += 1

self.progress_bar.update(len(batch)) # update progress bar
return batch

Expand All @@ -104,20 +111,16 @@ def __init__(self, config: StorageConfig):
def read(self, batch_size: Optional[int] = None, **kwargs) -> List:
    """Read one batch by delegating to the wrapped reader.

    Args:
        batch_size: Requested batch size, forwarded unchanged to
            ``self.reader.read``.
        **kwargs: Accepted for interface compatibility; not forwarded.

    Returns:
        Whatever ``self.reader.read(batch_size)`` returns.
    """
    delegate = self.reader
    items = delegate.read(batch_size)
    return items

def read_with_indices(self, indices: List[int]) -> List:
"""Read tasks with indices."""
return self.reader.read_with_indices(indices)

async def read_with_indices_async(self, indices: List[int]) -> List:
"""Read tasks with indices asynchronously."""
return await self.reader.read_with_indices_async(indices)

def state_dict(self):
    """Return the wrapped reader's checkpoint state, unmodified."""
    snapshot = self.reader.state_dict()
    return snapshot

def load_state_dict(self, state_dict):
    """Hand a checkpoint dictionary to the wrapped reader.

    Args:
        state_dict: State passed through unchanged to
            ``self.reader.load_state_dict``.

    Returns:
        Whatever the wrapped reader's ``load_state_dict`` returns.
    """
    restore = self.reader.load_state_dict
    return restore(state_dict)

def feedback(self, **pipeline_metrics):
    """Forward feedback keyword arguments to the wrapped reader's selector.

    Nothing happens when the wrapped reader has no selector
    (``self.reader.selector is None``).

    Args:
        **pipeline_metrics: Keyword arguments passed through unchanged to
            ``selector.feedback``.
    """
    selector = self.reader.selector
    if selector is None:
        return
    selector.feedback(**pipeline_metrics)

def __len__(self):
    """Return the length reported by the wrapped reader's ``__len__``."""
    size = self.reader.__len__()
    return size

Expand All @@ -139,6 +142,7 @@ def __init__(self, config: StorageConfig):
total_steps=config.total_steps,
enable_progress_bar=config.enable_progress_bar,
)
self.selector = None

def read(self, batch_size: Optional[int] = None, **kwargs) -> List:
samples, _ = self.dataset.read_batch(batch_size or self.read_batch_size)
Expand Down Expand Up @@ -178,6 +182,15 @@ def __init__(self, config: StorageConfig):
enable_progress_bar=self.config.enable_progress_bar,
)
self.formatter = FORMATTER.get("task")(config)
if self.config.task_selector is not None:
from trinity.buffer.selector import SELECTORS
from trinity.buffer.selector.selector import BaseSelector

self.selector: BaseSelector = SELECTORS.get(self.config.task_selector.selector_type)(
self.dataset, self.config.task_selector
)
else:
self.selector = None

def _get_tasks(self, samples: List, indices: List) -> List:
tasks = []
Expand All @@ -189,22 +202,21 @@ def _get_tasks(self, samples: List, indices: List) -> List:

def read(self, batch_size: Optional[int] = None, **kwargs) -> List:
batch_size = batch_size or self.read_batch_size
samples, indices = self.dataset.read_batch(batch_size)
return self._get_tasks(samples, indices)

def read_with_indices(self, indices: List[int]) -> List:
"""Read tasks with indices."""
samples = self.dataset.select_batch(indices)
if self.selector is not None:
indices = self.selector.get_indices(batch_size)
samples = self.dataset.select_batch(indices)
else:
samples, indices = self.dataset.read_batch(batch_size)
return self._get_tasks(samples, indices)

async def read_with_indices_async(self, indices: List[int]) -> List:
"""Read tasks with indices asynchronously."""
return self.read_with_indices(indices)

def state_dict(self):
    """Return the state needed to resume reading.

    Returns:
        ``{"current_index": self.dataset.current_offset}`` when no selector
        is configured; otherwise the selector's own ``state_dict()``.
    """
    if self.selector is None:
        return {"current_index": self.dataset.current_offset}
    return self.selector.state_dict()

def load_state_dict(self, state_dict):
    """Restore reading progress from a checkpoint dictionary.

    Without a selector, ``state_dict["current_index"]`` is required and is
    written to ``self.dataset.current_offset``.

    With a selector, the dictionary is handed to
    ``self.selector.load_state_dict`` first. The dataset offset is then
    restored only if the dictionary carries ``"current_index"``.

    Args:
        state_dict: State previously produced by ``state_dict()``.

    Raises:
        KeyError: If no selector is configured and ``"current_index"`` is
            missing.
    """
    if self.selector is None:
        # The offset is the only state in this mode, so it must be present.
        self.dataset.current_offset = state_dict["current_index"]
        return

    self.selector.load_state_dict(state_dict)
    # The state came from the selector, whose schema is its own; a missing
    # "current_index" must not abort the resume.
    if "current_index" in state_dict:
        self.dataset.current_offset = state_dict["current_index"]

def __len__(self):
Expand Down
12 changes: 6 additions & 6 deletions trinity/buffer/selector/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def get_indices(self, batch_size: int, return_extra_info: bool = False) -> List[
"""
raise NotImplementedError

def update(self, indices: List[int], values: List[float]) -> None:
def feedback(self, indices: List[int], values: List[float]) -> None:
"""
Update internal state based on feedback (e.g., model loss, accuracy).

Expand Down Expand Up @@ -95,7 +95,7 @@ def get_indices(self, batch_size: int, return_extra_info: bool = False) -> List[
return list(range(start, end))
return list(range(start, self.dataset_size)) + list(range(0, end - self.dataset_size))

def update(self, indices: List[int], values: List[float]) -> None:
def feedback(self, indices: List[int], values: List[float]) -> None:
# No-op: sequential selection doesn't adapt based on feedback
pass

Expand Down Expand Up @@ -150,7 +150,7 @@ def get_indices(self, batch_size: int, return_extra_info: bool = False) -> List[
self.current_index += batch_size
return ret

def update(self, indices: List[int], values: List[float]) -> None:
def feedback(self, indices: List[int], values: List[float]) -> None:
# No-op: static shuffling does not adapt
pass

Expand Down Expand Up @@ -188,7 +188,7 @@ def get_indices(self, batch_size, return_extra_info=False):
else:
return selected_indices

def update(self, indices: List[int], values: List[float]) -> None:
def feedback(self, indices: List[int], values: List[float]) -> None:
# No-op: basic random selection doesn't adapt
pass

Expand Down Expand Up @@ -239,7 +239,7 @@ def __init__(self, data_source, config: TaskSelectorConfig):
self.dataset_size = data_source.dataset_size
self.current_index = 0

def update(self, indices: List[int], values: List[float]) -> None:
def feedback(self, indices: List[int], values: List[float]) -> None:
# No-op: this selector does not adapt based on runtime feedback
pass

Expand Down Expand Up @@ -340,7 +340,7 @@ def build_diff_estimator(self, dataset, feature_keys: List[str], config: dict):
adaptive_rho=adaptive_rho,
)

def update(self, indices: List[int], values: List[float]) -> None:
def feedback(self, indices: List[int], values: List[float]) -> None:
"""
Updates the difficulty estimator with observed performance on selected samples.

Expand Down
37 changes: 15 additions & 22 deletions trinity/buffer/task_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
"""The taskset scheduler."""

from collections import Counter
from copy import deepcopy
from typing import Dict, List

import numpy as np

from trinity.buffer.buffer import get_buffer_reader
from trinity.buffer.selector import SELECTORS
from trinity.common.config import Config
from trinity.common.constants import SELECTOR_METRIC
from trinity.utils.annotations import Experimental
Expand Down Expand Up @@ -47,7 +47,7 @@ def state_dict(self) -> List[Dict]:
"""
raise NotImplementedError

def update(self, pipeline_metrics: Dict) -> None:
def feedback(self, pipeline_metrics: Dict) -> None:
"""Update selectors using feedback from the training pipeline."""
raise NotImplementedError

Expand All @@ -68,16 +68,18 @@ def __init__(self, explorer_state: Dict, config: Config):
index = self.explorer_state.get("taskset_states", [{"current_index": 0}])[0].get(
"current_index", 0
)
self.config.buffer.explorer_input.tasksets[0].index = index
self.reader = get_buffer_reader(config.buffer.explorer_input.tasksets[0])
taskset_config = deepcopy(self.config.buffer.explorer_input.tasksets[0])
taskset_config.index = index
taskset_config.task_selector = None # disable selection
self.reader = get_buffer_reader(taskset_config)

async def read_async(self) -> List:
    """Await and return one batch from the underlying reader."""
    batch = await self.reader.read_async()
    return batch

def state_dict(self) -> List[Dict]:
    """Return the reader's state wrapped in a single-element list."""
    reader_state = self.reader.state_dict()
    return [reader_state]

def update(self, pipeline_metrics: Dict) -> None:
def feedback(self, pipeline_metrics: Dict) -> None:
# do nothing here
return

Expand Down Expand Up @@ -127,7 +129,6 @@ def __init__(self, explorer_state: Dict, config: Config):
"taskset_states", [{"current_index": 0}] * len(taskset_configs)
)
self.tasksets = []
self.selectors = []
for taskset_config, taskset_state in zip(taskset_configs, taskset_states):
assert not taskset_config.is_eval # assume drop last
taskset = get_buffer_reader(taskset_config)
Expand All @@ -136,15 +137,8 @@ def __init__(self, explorer_state: Dict, config: Config):
f"Taskset '{taskset_config.name}' has an unsupported type '{type(taskset).__name__}'."
f"Currently, only 'FileReader' is supported by TasksetScheduler."
)

# Create selector based on type specified in config (e.g., 'sequential', 'shuffle')
selector = SELECTORS.get(taskset_config.task_selector.selector_type)(
taskset.reader.dataset, taskset_config.task_selector
)
selector.load_state_dict(taskset_state) # Restore any prior state

taskset.load_state_dict(taskset_state) # Restore any prior state
self.tasksets.append(taskset)
self.selectors.append(selector)

# Each explorer step calls read_async once → track step globally
self.step = explorer_state.get("latest_iteration", 0)
Expand Down Expand Up @@ -224,8 +218,7 @@ async def read_async(self) -> List:
counter = Counter(taskset_ids)
batch = []
for taskset_id, count in counter.items():
indices = self.selectors[taskset_id].get_indices(batch_size=count)
tasks = await self.tasksets[taskset_id].read_with_indices_async(indices)
tasks = await self.tasksets[taskset_id].read_async(batch_size=count)
# Annotate each task with its origin
for task in tasks:
task.index["taskset_id"] = taskset_id
Expand All @@ -239,13 +232,13 @@ def state_dict(self) -> List[Dict]:
Save persistent state for checkpointing.

Returns:
List[Dict]: State dicts for all selectors (one per taskset)
List[Dict]: State dicts for all tasksets
"""
return [selector.state_dict() for selector in self.selectors]
return [taskset.state_dict() for taskset in self.tasksets]

def update(self, pipeline_metrics: Dict) -> None:
def feedback(self, pipeline_metrics: Dict) -> None:
"""
Update selectors using feedback from the training pipeline.
Update selectors in tasksets using feedback from the training pipeline.

Expected format:
pipeline_metrics = {
Expand All @@ -265,5 +258,5 @@ def update(self, pipeline_metrics: Dict) -> None:
return
selector_metric = pipeline_metrics.pop(SELECTOR_METRIC, {})
for taskset_id, taskset_kwargs in selector_metric.items():
selector = self.selectors[taskset_id]
selector.update(**taskset_kwargs)
taskset = self.tasksets[taskset_id]
taskset.feedback(**taskset_kwargs)
5 changes: 5 additions & 0 deletions trinity/common/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,11 @@ def model_name(self) -> Optional[str]:
"""Get the name of the model."""
return self._model_name

@property
def model_config(self) -> InferenceModelConfig:
    """Return the configuration object stored on ``self.config``."""
    config = self.config
    return config

@property
def generate_kwargs(self) -> Dict[str, Any]:
"""Get the generation kwargs for openai client."""
Expand Down
2 changes: 1 addition & 1 deletion trinity/explorer/explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ async def _finish_explore_step(self, step: int, model_version: int) -> None:
batch_id=step, min_num=self.min_wait_num
)
pipeline_metrics = await self.experience_pipeline.process.remote(exps)
self.taskset.update(pipeline_metrics)
self.taskset.feedback(pipeline_metrics)
metric.update(pipeline_metrics)
if statuses:
metric.update(gather_metrics([status.metrics[0] for status in statuses], "rollout"))
Expand Down