7 changes: 7 additions & 0 deletions docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -390,6 +390,8 @@ Controls the rollout models and workflow execution.
explorer:
name: explorer
runner_per_model: 8
concurrent_mode: sequential
max_repeat_times_per_runner: null
max_timeout: 900
max_retry_times: 2
env_vars: {}
@@ -414,6 +416,11 @@ explorer:

- `name`: Name of the explorer. This name will be used as the Ray actor's name, so it must be unique.
- `runner_per_model`: Number of parallel workflow runners for each rollout model.
- `concurrent_mode`: Concurrency mode for executing a group of tasks (e.g., multiple repeats of a task in the GRPO algorithm). Supported options (see the sketch after this list):
- `sequential`: Executes tasks sequentially (default). One task is executed at a time, waiting for its completion before starting the next. This places the fewest requirements on the workflow implementation but has the lowest throughput.
- `asynchronous`: Executes tasks asynchronously. All tasks are submitted at once, and results are collected as they complete. Requires the workflow to correctly implement asynchronous call interfaces and to share no state across workflow instances, to avoid race conditions. Throughput is better than sequential execution, but the actual performance depends on the workflow implementation.
- `multi-threading`: Executes tasks in multiple threads simultaneously. Requires the workflow implementation to be thread-safe to avoid race conditions. Throughput is usually better than sequential execution but may be lower than asynchronous execution, depending on the workflow implementation and system resources.
- `max_repeat_times_per_runner`: Splits a task that would otherwise be repeated `algorithm.repeat_times` times into multiple subtasks, where each subtask's `repeat_times` does not exceed this value. This parameter only applies to GRPO-like algorithms. If not set, the number of repeats is unlimited. For example, with `algorithm.repeat_times: 8` and `max_repeat_times_per_runner: 2`, each task is dispatched as 4 subtasks with 2 repeats each. Using this parameter is recommended when `concurrent_mode` is `sequential` to reduce end-to-end latency and improve throughput.
- `max_timeout`: Maximum time (in seconds) for a workflow to complete.
- `max_retry_times`: Maximum number of retries for a workflow.
- `env_vars`: Environment variables to be set for every workflow runner.
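
The following minimal sketch is illustrative only: `run_one_task`, `run_sequential`, `run_asynchronous`, and `run_multi_threading` are placeholder names, not part of the Trinity API. It shows the execution pattern behind each mode for a task repeated `repeat_times` times:

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor


async def run_one_task(task: str, repeat_index: int) -> str:
    # Placeholder for a single workflow execution (e.g. one rollout).
    await asyncio.sleep(0.1)
    return f"{task}#{repeat_index}"


async def run_sequential(task: str, repeat_times: int) -> list:
    # "sequential": one repeat at a time; fewest demands on the workflow,
    # lowest throughput.
    return [await run_one_task(task, i) for i in range(repeat_times)]


async def run_asynchronous(task: str, repeat_times: int) -> list:
    # "asynchronous": submit all repeats at once and gather the results;
    # the workflow must expose async interfaces and share no mutable state.
    return await asyncio.gather(*(run_one_task(task, i) for i in range(repeat_times)))


def run_multi_threading(task: str, repeat_times: int) -> list:
    # "multi-threading": each repeat runs in its own thread;
    # the workflow must be thread-safe.
    with ThreadPoolExecutor(max_workers=repeat_times) as pool:
        return list(pool.map(lambda i: asyncio.run(run_one_task(task, i)), range(repeat_times)))
```

With `repeat_times = 4`, the sequential variant takes roughly four times the single-task latency, while the other two finish in roughly one task's latency, which is the intuition behind the runtime assertions in `tests/explorer/workflow_test.py`.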
7 changes: 7 additions & 0 deletions docs/sphinx_doc/source_zh/tutorial/trinity_configs.md
@@ -387,6 +387,8 @@ buffer:
explorer:
name: explorer
runner_per_model: 8
concurrent_mode: sequential
max_repeat_times_per_runner: null
max_timeout: 900
max_retry_times: 2
env_vars: {}
@@ -411,6 +413,11 @@ explorer:

- `name`: Name of the explorer. This name will be used as the Ray actor's name, so it must be unique.
- `runner_per_model`: Number of WorkflowRunners served by each inference engine instance.
- `concurrent_mode`: Concurrency mode for executing a group of tasks (e.g., multiple repeated executions of a task in the GRPO algorithm). Supported options:
- `sequential`: Sequential execution (default). One task is executed at a time, waiting for it to finish before starting the next. This places the fewest requirements on the workflow implementation but has the lowest throughput.
- `asynchronous`: Asynchronous execution. All tasks are submitted at once and results are collected as they complete. Requires the workflow to correctly implement asynchronous call interfaces and to share no state across workflows, to avoid race conditions. Throughput is better than sequential execution, but the actual performance depends on the workflow implementation.
- `multi-threading`: Multi-threaded execution. Multiple tasks are executed simultaneously in threads. The workflow implementation must be thread-safe to avoid race conditions. Throughput is usually better than sequential execution but may be lower than asynchronous execution, depending on the workflow implementation and system resources.
- `max_repeat_times_per_runner`: Splits a task that would otherwise be repeated `algorithm.repeat_times` times into multiple subtasks, each with `repeat_times` no larger than this value; only applicable to GRPO-like algorithms. If not set, the number of repeats is unlimited. This parameter is recommended when `concurrent_mode` is `sequential`, to avoid a single WorkflowRunner occupying resources for a long time.
- `max_timeout`: Maximum time (in seconds) to wait for a Workflow to complete.
- `max_retry_times`: Maximum number of retries when a Workflow fails or times out.
- `env_vars`: Environment variables set for each WorkflowRunner.
43 changes: 42 additions & 1 deletion tests/explorer/scheduler_test.py
@@ -2,7 +2,7 @@
import time
import unittest
from collections import defaultdict
from typing import List, Optional
from typing import Dict, List, Optional, Sequence

import ray
import torch
@@ -169,6 +169,27 @@ async def run_async(self) -> List[Experience]:
return exps


@WORKFLOWS.register_module("dummy_concurrent_workflow")
class DummyConcurrentWorkflow(Workflow):
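"""Minimal async workflow for the scheduler tests: sleeps for one second and returns a single dummy experience."""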
can_repeat: bool = False
is_async: bool = True

def __init__(self, *, task, model, auxiliary_models):
super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)

async def run_async(self) -> List[Experience]:
await asyncio.sleep(1)

return [
Experience(
eid=EID(run=self.run_id_base, step=0),
tokens=torch.zeros(5),
prompt_length=2,
prompt_text="success",
)
]


@ray.remote
class DummyModel(InferenceModel):
def __init__(self):
@@ -200,6 +221,26 @@ def init_process_group(
def get_api_server_url(self) -> Optional[str]:
return None

async def chat(self, messages: List[Dict], lora_request=None, **kwargs) -> Sequence[Experience]:
prompt_length = sum(len(msg["content"]) for msg in messages)
return [
Experience(
tokens=torch.zeros(prompt_length + 10),
prompt_length=prompt_length,
logprobs=torch.zeros(10),
)
]

async def generate(self, prompt: str, lora_request=None, **kwargs) -> Sequence[Experience]:
prompt_length = len(prompt)
return [
Experience(
tokens=torch.zeros(prompt_length + 5),
prompt_length=prompt_length,
logprobs=torch.zeros(5),
)
]


@ray.remote
class DummyAuxiliaryModel(InferenceModel):
106 changes: 105 additions & 1 deletion tests/explorer/workflow_test.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
"""Test for the workflow module"""
import asyncio
import time
import unittest
from collections import defaultdict
from dataclasses import dataclass, field
@@ -14,7 +15,12 @@
from torch import Tensor

from tests.common.vllm_test import CHAT_TEMPLATE
from tests.tools import get_model_path, get_template_config, get_unittest_dataset_config
from tests.tools import (
RayUnittestBaseAsync,
get_model_path,
get_template_config,
get_unittest_dataset_config,
)
from trinity.common.config import InferenceModelConfig
from trinity.common.experience import EID, Experience
from trinity.common.models import create_inference_models
@@ -789,3 +795,101 @@ async def test_workflow_with_openai(self):

def tearDown(self):
ray.shutdown(_exiting_interpreter=True)


class ConcurrentTestWorkflow(Workflow):
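"""Async test workflow that issues one chat_async call and one OpenAI-client call, then checks that the two recorded history experiences have matching prompt lengths and prompt token shapes."""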
is_async: bool = True

def __init__(self, model: ModelWrapper, task: Task, auxiliary_models=None):
super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
self.client = self.model.get_openai_async_client()

async def run_async(self):
assert self.task.raw_task is not None
_ = await self.model.chat_async([{"role": "user", "content": self.task.raw_task["text"]}])
await asyncio.sleep(1.0)
_ = await self.client.chat.completions.create(
model=self.client.model_path,
messages=[{"role": "user", "content": self.task.raw_task["text"]}],
)
history_exps = self.model.extract_experience_from_history()
assert len(history_exps) == 2
assert history_exps[0].prompt_length == history_exps[1].prompt_length
prompt_length = history_exps[0].prompt_length
assert (
history_exps[0].tokens[:prompt_length].shape
== history_exps[1].tokens[:prompt_length].shape
)
return history_exps


class TestConcurrentWorkflowRunner(RayUnittestBaseAsync):
async def test_concurrent_workflow_runner(self):
config = get_template_config()
config.mode = "explore"
config.model.model_path = get_model_path()
config.explorer.rollout_model.engine_num = 1
config.explorer.rollout_model.enable_history = True
config.explorer.rollout_model.enable_openai_api = True
config.check_and_update()
engines, auxiliary_engines = create_inference_models(config)
await asyncio.gather(*[engine.prepare.remote() for engine in engines])

config.explorer.concurrent_mode = "sequential"
sequential_runner = WorkflowRunner(
config,
model=engines[0],
auxiliary_models=[],
runner_id=0,
)
config.explorer.concurrent_mode = "asynchronous"
async_runner = WorkflowRunner(
config,
model=engines[0],
auxiliary_models=[],
runner_id=1,
)
thread_runner = WorkflowRunner(
config,
model=engines[0],
auxiliary_models=[],
runner_id=2,
)
await asyncio.gather(
sequential_runner.prepare(),
async_runner.prepare(),
thread_runner.prepare(),
)

task = Task(
workflow=ConcurrentTestWorkflow,
repeat_times=4,
raw_task={"text": "Hello, world!"},
)
# warmup
async_status, async_exps = await async_runner.run_task(task, repeat_times=2, run_id_base=0)

st = time.time()
async_status, async_exps = await async_runner.run_task(task, repeat_times=4, run_id_base=0)
async_runtime = time.time() - st
st = time.time()
thread_status, thread_exps = await thread_runner.run_task(
task, repeat_times=4, run_id_base=0
)
thread_runtime = time.time() - st
st = time.time()
sequential_status, sequential_exps = await sequential_runner.run_task(
task, repeat_times=4, run_id_base=0
)
sequential_runtime = time.time() - st

self.assertTrue(async_status.ok)
self.assertTrue(thread_status.ok)
self.assertTrue(sequential_status.ok)

self.assertEqual(len(async_exps), 8)
self.assertEqual(len(thread_exps), 8)
self.assertEqual(len(sequential_exps), 8)

self.assertLessEqual(async_runtime * 2, sequential_runtime)
self.assertLessEqual(thread_runtime * 2, sequential_runtime)
15 changes: 12 additions & 3 deletions trinity/common/config.py
@@ -676,9 +676,18 @@ class ExplorerConfig:
max_timeout: int = 1800 # wait each task for 30 minutes at most
max_retry_times: int = 2 # retry each task for 2 times if it fails or timeout
env_vars: dict = field(default_factory=dict) # environment variables for workflow runner
max_repeat_times_per_runner: Optional[
int
] = None # the number of time to repeat each task in a single workflow runner (for GRPO-like algorithms)

# Workflow Runner Configs for tasks requiring group execution
# how to run a group of tasks in a single workflow runner
# "sequential": run tasks one by one, no requirements on workflow design, but have lower throughput
# "asynchronous": run tasks asynchronously, requires the workflow to be designed with async/await
# syntax, and no sharing of state between tasks
# "multi-threading": run tasks using multi-threading, requires the workflow to be thread-safe,
# and no sharing of state between tasks
concurrent_mode: str = "sequential"
# the number of times to repeat each task in a single workflow runner
# we recommend setting this only when using "sequential" concurrent_mode
max_repeat_times_per_runner: Optional[int] = None

runner_num: Optional[int] = None # ! Deprecated

15 changes: 15 additions & 0 deletions trinity/common/config_validator.py
@@ -579,6 +579,21 @@ def validate(self, config: Config) -> None:

self._validate_lora(config)

# check concurrent mode
if config.explorer.concurrent_mode not in ["sequential", "asynchronous", "multi-threading"]:
raise ValueError(f"Invalid explorer.concurrent_mode: {config.explorer.concurrent_mode}")
if config.explorer.concurrent_mode in ["asynchronous", "multi-threading"]:
batch_size = config.buffer.batch_size
max_runner_per_model = math.ceil(batch_size / config.explorer.rollout_model.engine_num)
if config.explorer.runner_per_model > max_runner_per_model:
self.logger.warning(
f"explorer.runner_per_model ({config.explorer.runner_per_model}) is too large "
f"for concurrent_mode '{config.explorer.concurrent_mode}' with batch_size "
f"({batch_size}) and rollout_model.engine_num ({config.explorer.rollout_model.engine_num}). "
f"It is set to {max_runner_per_model}."
)
config.explorer.runner_per_model = max_runner_per_model
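# Example: with buffer.batch_size = 96 and rollout_model.engine_num = 4,
# max_runner_per_model = ceil(96 / 4) = 24, so a configured runner_per_model
# of 32 would be reduced to 24 in "asynchronous" or "multi-threading" mode.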

def _validate_lora(self, config: Config) -> None:
"""Process and validate LoRA configuration settings.

23 changes: 17 additions & 6 deletions trinity/common/models/model.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
"""Base Model Class"""
import asyncio
import copy
import socket
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
@@ -229,7 +230,6 @@ def __init__(
engine_type: str = "vllm",
enable_lora: bool = False,
enable_history: bool = False,
enable_thinking: Optional[bool] = None,
):
"""Initialize the ModelWrapper.

@@ -254,7 +254,6 @@
self.logger = get_logger(__name__)
self.enable_lora = enable_lora
self.enable_history = enable_history
self.enable_thinking = enable_thinking
self.history = []
self.status = RunningStatus.RUNNING
self.workflow_state: Dict = {}
@@ -523,10 +522,12 @@ def chat_completions(*args, **kwargs):
def record_chat_completions(*args, **kwargs):
logprobs = kwargs.pop("logprobs", True)
extra_body = kwargs.pop("extra_body", {})
if self.enable_thinking is not None:
if self.config.enable_thinking is not None:
if "chat_template_kwargs" not in extra_body:
extra_body["chat_template_kwargs"] = {}
extra_body["chat_template_kwargs"]["enable_thinking"] = self.enable_thinking
extra_body["chat_template_kwargs"][
"enable_thinking"
] = self.config.enable_thinking
extra_body["return_token_ids"] = True
response = ori_create(*args, extra_body=extra_body, logprobs=logprobs, **kwargs)
self.history.extend(convert_api_output_to_experience(response))
@@ -580,10 +581,12 @@ async def chat_completions(*args, **kwargs):
async def record_chat_completions(*args, **kwargs):
logprobs = kwargs.pop("logprobs", True)
extra_body = kwargs.pop("extra_body", {})
if self.enable_thinking is not None:
if self.config.enable_thinking is not None:
if "chat_template_kwargs" not in extra_body:
extra_body["chat_template_kwargs"] = {}
extra_body["chat_template_kwargs"]["enable_thinking"] = self.enable_thinking
extra_body["chat_template_kwargs"][
"enable_thinking"
] = self.config.enable_thinking
extra_body["return_token_ids"] = True
response = await ori_create(
*args, extra_body=extra_body, logprobs=logprobs, **kwargs
@@ -637,6 +640,14 @@ async def get_workflow_state(self) -> Dict:
async with self.state_lock:
return self.workflow_state.copy()

def clone_with_isolated_history(self) -> "ModelWrapper":
"""Clone the current ModelWrapper with isolated history."""
new_wrapper = copy.copy(self)
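# Reset the cached OpenAI clients and the history so experiences recorded
# through the clone stay separate from the original wrapper's.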
new_wrapper.openai_async_client = None
new_wrapper.openai_client = None
new_wrapper.history = []
return new_wrapper


def convert_api_output_to_experience(
output,