4 changes: 2 additions & 2 deletions .github/workflows/docker/docker-compose.yaml
@@ -1,6 +1,6 @@
services:
trinity-node-1:
image: trinity-rft-unittest:20260115
image: trinity-rft-unittest:20260126
cap_add:
- SYS_PTRACE
pull_policy: never
@@ -32,7 +32,7 @@ services:
capabilities: [gpu]

trinity-node-2:
image: trinity-rft-unittest:20260115
image: trinity-rft-unittest:20260126
cap_add:
- SYS_PTRACE
pull_policy: never
13 changes: 13 additions & 0 deletions .github/workflows/unittest.yaml
@@ -59,6 +59,10 @@ jobs:
MODULE=$(echo "$COMMENT" | sed -n 's/\/unittest-module-\(.*\)/\1/p')
echo "type=module" >> $GITHUB_OUTPUT
echo "module=$MODULE" >> $GITHUB_OUTPUT
elif [[ "$COMMENT" =~ ^/unittest-pattern-(.+)$ ]]; then
PATTERN=$(echo "$COMMENT" | sed -n 's/\/unittest-pattern-\(.*\)/\1/p')
echo "type=pattern" >> $GITHUB_OUTPUT
echo "pattern=$PATTERN" >> $GITHUB_OUTPUT
else
echo "type=all" >> $GITHUB_OUTPUT
fi
@@ -98,6 +102,15 @@ jobs:
echo "No module specified, skipping tests."
echo "tests_run=false" >> $GITHUB_ENV
fi
elif [ "$TYPE" = "pattern" ]; then
PATTERN="${{ steps.test_type.outputs.pattern }}"
if [ -n "$PATTERN" ]; then
echo "tests_run=true" >> $GITHUB_ENV
docker compose exec trinity-node-1 bash -c "source /opt/venv/bin/activate && pytest tests -v -s -k '$PATTERN' --ctrf report.json"
else
echo "No pattern specified, skipping tests."
echo "tests_run=false" >> $GITHUB_ENV
fi
fi

- name: Convert report.json time to ms
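As a quick sanity check of the new trigger, here is a minimal sketch of how a hypothetical PR comment resolves into the `pytest -k` expression; the extraction mirrors the `sed` call added above, and the comment body shown is only an example:

```shell
# Hypothetical PR comment body (anything after the prefix becomes the pytest -k expression)
COMMENT="/unittest-pattern-vllm and not slow"
PATTERN=$(echo "$COMMENT" | sed -n 's/\/unittest-pattern-\(.*\)/\1/p')
echo "$PATTERN"  # -> vllm and not slow

# Inside the container the workflow would then run:
#   pytest tests -v -s -k 'vllm and not slow' --ctrf report.json
```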
@@ -111,10 +111,10 @@ Thus you can prepare a split environment for it and start the server manually using

```shell
# prepare split environments, including the one for the data processor
python scripts/install.py
python scripts/data/install.py

# start all split servers
python scripts/start_servers.py
python scripts/data/start_servers.py
```

These scripts will create split environments for Trinity-RFT and the Data-Juicer-based data processor.
@@ -107,10 +107,10 @@ trinity run --config <Trinity-RFT_config_path>

```shell
# prepare split environments, including the data processor environment
python scripts/install.py
python scripts/data/install.py

# start all split servers
python scripts/start_servers.py
python scripts/data/start_servers.py
```

These scripts will create split environments for Trinity-RFT and the Data-Juicer-based data processor, and will automatically start the data processor service in the Data-Juicer environment.
10 changes: 6 additions & 4 deletions pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "trinity-rft"
version = "0.4.1"
version = "0.5.0.dev0"
authors = [
{name="Trinity-RFT Team", email="trinity-rft@outlook.com"},
]
@@ -50,7 +50,9 @@ trinity = "trinity.cli.launcher:main"

[project.optional-dependencies]
vllm = [
"vllm>=0.10.2,<=0.11.0",
"vllm>=0.10.2,<=0.14.1,!=0.12.0",
# v0.12.0 has a huge performance regression so we exclude it
# v0.10.2 is the most stable version, but we allow up to 0.14.1 for new features
]
data = [
"py-data-juicer>=1.4.3"
@@ -75,10 +77,10 @@ dev = [
"viztracer",
]
megatron = [
"megatron-core[mlm]==0.13.1",
"megatron-core[mlm]==0.15.0",
# if you encounter an "undefined symbol" error in transformer engine,
# reinstall it with the `--no-build-isolation` and `--no-cache-dir` flags
# "transformer_engine[pytorch]==2.8.0",
# "transformer_engine[pytorch]==2.10.0",

# Install mbridge from main branch (unreleased version)
# "mbridge @ git+https://github.com/ISEEKYAN/mbridge.git@20e9ffbbe72ae7b1df83bfe1bc3c11f7382f2612",
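For reference, a minimal sketch of the workaround the comments above describe; the exact pin follows the Dockerfiles updated in this PR and may need adjusting for your environment:

```shell
# Reinstall transformer_engine without build isolation if you hit "undefined symbol" errors
pip install "transformer_engine[pytorch]==2.10.0" --no-build-isolation --no-cache-dir
```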
File renamed without changes.
2 changes: 1 addition & 1 deletion scripts/start_servers.py → scripts/data/start_servers.py
@@ -22,7 +22,7 @@ def main():

os.makedirs(args.log_dir, exist_ok=True)
env_mapping_file = os.path.join(
os.path.dirname(__file__), "..", "environments", "env_mapping.json"
os.path.dirname(__file__), "..", "..", "environments", "env_mapping.json"
)
with open(env_mapping_file, "r") as f:
env_mapping = json.load(f)
2 changes: 1 addition & 1 deletion scripts/docker/Dockerfile.megatron
@@ -31,7 +31,7 @@ RUN pip install --upgrade pip \
&& pip install -e .[vllm,mm,dev] \
&& pip install flash_attn==2.8.1 --no-build-isolation \
&& pip install -e .[megatron] \
&& pip install transformer_engine[pytorch]==2.8.0 --no-build-isolation --no-cache-dir \
&& pip install transformer_engine[pytorch]==2.10.0 --no-build-isolation --no-cache-dir \
&& pip install git+https://github.com/ISEEKYAN/mbridge.git@20e9ffbbe72ae7b1df83bfe1bc3c11f7382f2612 \
&& NVCC_APPEND_FLAGS="--threads 4" APEX_PARALLEL_BUILD=8 pip install -v \
--disable-pip-version-check --no-cache-dir --no-build-isolation \
2 changes: 1 addition & 1 deletion scripts/docker/Dockerfile.uv
@@ -43,7 +43,7 @@ RUN . /opt/venv/bin/activate && \
uv pip install -e .[megatron] && \
uv pip install flash_attn==2.8.1 --no-build-isolation && \
uv pip install git+https://github.com/ISEEKYAN/mbridge.git@20e9ffbbe72ae7b1df83bfe1bc3c11f7382f2612 && \
uv pip install transformer_engine[pytorch]==2.8.0 --no-build-isolation --no-cache-dir && \
uv pip install transformer_engine[pytorch]==2.10.0 --no-build-isolation --no-cache-dir && \
NVCC_APPEND_FLAGS="--threads 4" APEX_PARALLEL_BUILD=8 \
uv pip install -v --no-build-isolation \
--config-settings="--build-option=--cpp_ext" \
33 changes: 0 additions & 33 deletions tests/explorer/workflow_test.py
@@ -19,11 +19,9 @@
from trinity.common.experience import EID, Experience
from trinity.common.models import create_inference_models
from trinity.common.models.model import ModelWrapper
from trinity.common.rewards.reward_fn import RMGalleryFn
from trinity.common.workflows import WORKFLOWS, Workflow
from trinity.common.workflows.customized_math_workflows import MathBoxedWorkflow
from trinity.common.workflows.eval_workflow import MathEvalWorkflow
from trinity.common.workflows.math_rm_workflow import MathRMWorkflow
from trinity.common.workflows.workflow import MathWorkflow, MultiTurnWorkflow, Task
from trinity.explorer.workflow_runner import WorkflowRunner

@@ -358,37 +356,6 @@ def test_gsm8k_workflow(self) -> None:
self.assertEqual(experiences[2].reward, -0.1)
self.assertEqual(experiences[3].reward, 1.1)

def test_rm_gallery_workflow(self) -> None:
model = MagicMock()
model.chat.return_value = [
MockResponse("<think> balabalabala 99 </think>\n \\boxed{36}"),
MockResponse("answer is \\boxed{36 }"),
MockResponse("Kim's total points are 6 + 30 =\\boxed{36}"),
MockResponse("<think> balalaba </think> \\boxed{35.00}"),
]
taskset_config = get_unittest_dataset_config("countdown")
task = Task(
workflow=MathRMWorkflow,
reward_fn=RMGalleryFn,
repeat_times=taskset_config.repeat_times,
format_args=taskset_config.format,
rollout_args=taskset_config.rollout_args,
reward_fn_args={
"reward_name": "math_verify_reward",
},
is_eval=False,
raw_task={
taskset_config.format.prompt_key: "",
taskset_config.format.response_key: r"36",
},
)
workflow = task.to_workflow(model=model)
experiences = workflow.run()
self.assertEqual(experiences[0].reward, 1.0)
self.assertEqual(experiences[1].reward, 1.0)
self.assertEqual(experiences[2].reward, 1.0)
self.assertEqual(experiences[3].reward, 0.0)

def test_math_eval_workflow(self) -> None:
model = MagicMock()
model.chat.return_value = [
2 changes: 1 addition & 1 deletion trinity/__init__.py
@@ -1,4 +1,4 @@
# -*- coding: utf-8 -*-
"""Trinity-RFT (Reinforcement Fine-Tuning)"""

__version__ = "0.4.1"
__version__ = "0.5.0.dev0"
66 changes: 25 additions & 41 deletions trinity/common/models/vllm_model.py
@@ -93,12 +93,14 @@ def __init__(
rope_kwargs = {"hf_overrides": rope_params}
else:
rope_kwargs = {}
self.logprobs_no_prefix_cache = True
else:
rope_kwargs = {
key: getattr(config, key)
for key in ["rope_scaling", "rope_theta"]
if getattr(config, key) is not None
}
self.logprobs_no_prefix_cache = False
engine_args = vllm.AsyncEngineArgs(
model=config.model_path,
enforce_eager=config.enforce_eager,
@@ -111,7 +113,6 @@ def __init__(
enable_chunked_prefill=config.enable_chunked_prefill,
dtype=config.dtype,
trust_remote_code=True,
task="generate",
gpu_memory_utilization=config.gpu_memory_utilization,
override_generation_config={ # TODO: find a way to unittest this
"temperature": config.temperature,
@@ -132,7 +133,8 @@ def __init__(
engine_args.disable_log_requests = not config.enable_log_requests
if self.vllm_version >= parse_version("0.11.0"):
engine_args.reasoning_parser = config.reasoning_parser

if self.vllm_version >= parse_version("0.13.0"):
engine_args.async_scheduling = False
self.async_llm = vllm.AsyncLLMEngine.from_engine_args(engine_args)
self.processor = None
self.state_dict_meta = None
@@ -326,13 +328,19 @@ async def logprobs( # type: ignore [override]
temperature = temperature if temperature is not None else self.config.temperature
if temperature is None:
temperature = 1.0
kwargs = {
"n": 1,
"max_tokens": 1,
"prompt_logprobs": 0, # vLLM return `prompt_logprobs + 1` logrpobs for each token
"temperature": temperature,
}
# avoid using prefix cache when calculating logprobs, only for vLLM >= 0.12.0
if self.logprobs_no_prefix_cache:
kwargs["skip_reading_prefix_cache"] = True
output = await self._generate_internal(
prompt={"prompt_token_ids": token_ids},
lora_request=lora_request,
n=1,
max_tokens=1,
prompt_logprobs=0, # vLLM returns `prompt_logprobs + 1` logprobs for each token
temperature=temperature,
**kwargs,
)
return torch.tensor(
[list(logprob_dict.values())[0].logprob for logprob_dict in output.prompt_logprobs[1:]],
@@ -404,6 +412,8 @@ async def sample(
# in vLLM, 0 means only return the chosen token's logprob
"logprobs": 0,
}
if include_prompt_logprobs and self.logprobs_no_prefix_cache:
params["skip_reading_prefix_cache"] = True
if sampling_params.stop is not None:
params["stop"] = sampling_params.stop
req_output = await self._generate_internal(
@@ -579,41 +589,15 @@ async def run_api_server(self) -> bool:
return True # already running

api_server_host, api_server_port = self.get_available_address()
if self.vllm_version <= parse_version("0.11.0"):
from trinity.common.models.vllm_patch.api_patch import (
run_api_server_in_ray_actor,
)

self.api_server = asyncio.create_task(
run_api_server_in_ray_actor(
self.async_llm,
api_server_host,
api_server_port,
self.config.model_path, # type: ignore [arg-type]
self.config.enable_auto_tool_choice,
self.config.tool_call_parser,
self.config.reasoning_parser,
self.config.enable_log_requests,
)
)
else:
from trinity.common.models.vllm_patch.api_patch_v12 import (
run_api_server_in_ray_actor_v12,
)

self.api_server = asyncio.create_task(
run_api_server_in_ray_actor_v12(
self.async_llm,
api_server_host,
api_server_port,
self.config.model_path, # type: ignore [arg-type]
logger=self.logger,
enable_auto_tool_choice=self.config.enable_auto_tool_choice,
tool_call_parser=self.config.tool_call_parser,
reasoning_parser=self.config.reasoning_parser,
enable_log_requests=self.config.enable_log_requests,
)
)
from trinity.common.models.vllm_patch import get_api_server

self.api_server = get_api_server(
self.async_llm,
host=api_server_host,
port=api_server_port,
config=self.config,
logger=self.logger,
)
self.api_server_host = api_server_host
self.api_server_port = api_server_port
return True
69 changes: 69 additions & 0 deletions trinity/common/models/vllm_patch/__init__.py
@@ -1,7 +1,12 @@
import asyncio
from logging import Logger

import vllm
from packaging.version import InvalidVersion
from packaging.version import parse as parse_version

from trinity.common.config import InferenceModelConfig


def get_vllm_version():
try:
@@ -11,3 +16,67 @@ def get_vllm_version():
# we cannot parse the version, so treat it as the lowest version we support
vllm_version = parse_version("0.8.5")
return vllm_version


def get_api_server(
async_llm,
host: str,
port: int,
config: InferenceModelConfig,
logger: Logger,
):
vllm_version = get_vllm_version()
if vllm_version <= parse_version("0.11.0"):
from trinity.common.models.vllm_patch.api_patch import (
run_api_server_in_ray_actor,
)

return asyncio.create_task(
run_api_server_in_ray_actor(
async_llm,
host=host,
port=port,
model_path=config.model_path, # type: ignore [arg-type]
enable_auto_tool_choice=config.enable_auto_tool_choice,
tool_call_parser=config.tool_call_parser,
reasoning_parser=config.reasoning_parser,
enable_log_requests=config.enable_log_requests,
)
)
elif vllm_version == parse_version("0.12.0"):
from trinity.common.models.vllm_patch.api_patch_v12 import (
run_api_server_in_ray_actor_v12,
)

return asyncio.create_task(
run_api_server_in_ray_actor_v12(
async_llm,
host=host,
port=port,
model_path=config.model_path, # type: ignore [arg-type]
logger=logger,
enable_auto_tool_choice=config.enable_auto_tool_choice,
tool_call_parser=config.tool_call_parser,
reasoning_parser=config.reasoning_parser,
enable_log_requests=config.enable_log_requests,
)
)
else:
from trinity.common.models.vllm_patch.api_patch_v13 import (
run_api_server_in_ray_actor_v13,
)

logger.info(f"Using vLLM API patch for version {vllm.__version__}")
return asyncio.create_task(
run_api_server_in_ray_actor_v13(
async_llm,
host=host,
port=port,
model_path=config.model_path, # type: ignore [arg-type]
logger=logger,
enable_auto_tool_choice=config.enable_auto_tool_choice,
tool_call_parser=config.tool_call_parser,
reasoning_parser=config.reasoning_parser,
enable_log_requests=config.enable_log_requests,
)
)
5 changes: 1 addition & 4 deletions trinity/common/models/vllm_patch/api_patch_v12.py
@@ -1,7 +1,4 @@
"""Patch for vllm OpenAI API server. Only for vllm versions >0.11.0.

1. Mocks the `add_signal_handler` method to do nothing.
2. Adds `token_ids` and `prompt_token_ids` to the `ChatCompletionResponse`.
"""Patch for vllm OpenAI API server. Only for vllm versions == 0.12.0.
"""
import logging
from typing import Optional