diff --git a/.github/workflows/docker/docker-compose.yaml b/.github/workflows/docker/docker-compose.yaml
index f2294d85cdd..a923b461eb9 100644
--- a/.github/workflows/docker/docker-compose.yaml
+++ b/.github/workflows/docker/docker-compose.yaml
@@ -1,6 +1,6 @@
 services:
   trinity-node-1:
-    image: trinity-rft-unittest:20260115
+    image: trinity-rft-unittest:20260126
     cap_add:
       - SYS_PTRACE
     pull_policy: never
@@ -32,7 +32,7 @@ services:
               capabilities: [gpu]

   trinity-node-2:
-    image: trinity-rft-unittest:20260115
+    image: trinity-rft-unittest:20260126
     cap_add:
       - SYS_PTRACE
     pull_policy: never
diff --git a/.github/workflows/unittest.yaml b/.github/workflows/unittest.yaml
index 53b486ae1c6..6a3f5211fd4 100644
--- a/.github/workflows/unittest.yaml
+++ b/.github/workflows/unittest.yaml
@@ -59,6 +59,10 @@
             MODULE=$(echo "$COMMENT" | sed -n 's/\/unittest-module-\(.*\)/\1/p')
             echo "type=module" >> $GITHUB_OUTPUT
             echo "module=$MODULE" >> $GITHUB_OUTPUT
+          elif [[ "$COMMENT" =~ ^/unittest-pattern-(.+)$ ]]; then
+            PATTERN=$(echo "$COMMENT" | sed -n 's/\/unittest-pattern-\(.*\)/\1/p')
+            echo "type=pattern" >> $GITHUB_OUTPUT
+            echo "pattern=$PATTERN" >> $GITHUB_OUTPUT
           else
             echo "type=all" >> $GITHUB_OUTPUT
           fi
@@ -98,6 +102,15 @@
               echo "No module specified, skipping tests."
               echo "tests_run=false" >> $GITHUB_ENV
             fi
+          elif [ "$TYPE" = "pattern" ]; then
+            PATTERN="${{ steps.test_type.outputs.pattern }}"
+            if [ -n "$PATTERN" ]; then
+              echo "tests_run=true" >> $GITHUB_ENV
+              docker compose exec trinity-node-1 bash -c "source /opt/venv/bin/activate && pytest tests -v -s -k '$PATTERN' --ctrf report.json"
+            else
+              echo "No pattern specified, skipping tests."
+              echo "tests_run=false" >> $GITHUB_ENV
+            fi
           fi

       - name: Convert report.json time to ms
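For reference, the `pattern` branch added above lets a PR comment of the form `/unittest-pattern-<keyword>` forward `<keyword>` to pytest's `-k` filter. A minimal sketch of the resulting invocation, assuming a hypothetical comment value of `vllm`:

```shell
# PR comment (hypothetical example):
#   /unittest-pattern-vllm
# The workflow then runs, inside the trinity-node-1 container:
source /opt/venv/bin/activate && pytest tests -v -s -k 'vllm' --ctrf report.json
```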
diff --git a/docs/sphinx_doc/source/tutorial/example_data_functionalities.md b/docs/sphinx_doc/source/tutorial/example_data_functionalities.md
index 7663bcd41a1..5911d911b62 100644
--- a/docs/sphinx_doc/source/tutorial/example_data_functionalities.md
+++ b/docs/sphinx_doc/source/tutorial/example_data_functionalities.md
@@ -111,10 +111,10 @@ Thus you can prepare a split environment for it and start the server manually us

 ```shell
 # prepare split environments, including the one of data processor
-python scripts/install.py
+python scripts/data/install.py

 # start all split servers
-python scripts/start_servers.py
+python scripts/data/start_servers.py
 ```

 These scripts will create split environments for Trinity-RFT and Data-Juicer-based data processor.
diff --git a/docs/sphinx_doc/source_zh/tutorial/example_data_functionalities.md b/docs/sphinx_doc/source_zh/tutorial/example_data_functionalities.md
index c1a9a42b930..af31cd5b96e 100644
--- a/docs/sphinx_doc/source_zh/tutorial/example_data_functionalities.md
+++ b/docs/sphinx_doc/source_zh/tutorial/example_data_functionalities.md
@@ -107,10 +107,10 @@ trinity run --config

 ```shell
 # 准备独立环境，包括数据处理器环境
-python scripts/install.py
+python scripts/data/install.py

 # 启动所有独立服务
-python scripts/start_servers.py
+python scripts/data/start_servers.py
 ```

 这些脚本将为 Trinity-RFT 和基于 Data-Juicer 的数据处理器创建独立环境，并在 Data-Juicer 环境中自动启动数据处理器服务。
diff --git a/pyproject.toml b/pyproject.toml
index 738041d1e94..411864a793d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "trinity-rft"
-version = "0.4.1"
+version = "0.5.0.dev0"
 authors = [
     {name="Trinity-RFT Team", email="trinity-rft@outlook.com"},
 ]
@@ -50,7 +50,9 @@ trinity = "trinity.cli.launcher:main"

 [project.optional-dependencies]
 vllm = [
-    "vllm>=0.10.2,<=0.11.0",
+    "vllm>=0.10.2,<=0.14.1,!=0.12.0",
+    # v0.12.0 has a huge performance regression so we exclude it
+    # v0.10.2 is the most stable version, but we allow up to 0.14.1 for new features
 ]
 data = [
     "py-data-juicer>=1.4.3"
@@ -75,10 +77,10 @@ dev = [
     "viztracer",
 ]
 megatron = [
-    "megatron-core[mlm]==0.13.1",
+    "megatron-core[mlm]==0.15.0",
     # if you found "undefined symbol" error in transformer engine
     # reinstall it with --no-build-isolation and `--no-cache-dir` flag
-    # "transformer_engine[pytorch]==2.8.0",
+    # "transformer_engine[pytorch]==2.10.0",
     # Install mbridge from main branch (unreleased version)
     # "mbridge @ git+https://github.com/ISEEKYAN/mbridge.git@20e9ffbbe72ae7b1df83bfe1bc3c11f7382f2612",
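For reference, the relaxed constraint above can be exercised on its own with pip; this standalone command is only an illustration of the allowed range, while the project itself pulls vLLM in through the `vllm` extra:

```shell
# Resolves to the newest vLLM within [0.10.2, 0.14.1] while skipping the excluded 0.12.0
pip install "vllm>=0.10.2,<=0.14.1,!=0.12.0"
# Equivalent when installing from a checkout via the extra defined above:
pip install -e .[vllm]
```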
diff --git a/scripts/install.py b/scripts/data/install.py
similarity index 100%
rename from scripts/install.py
rename to scripts/data/install.py
diff --git a/scripts/start_servers.py b/scripts/data/start_servers.py
similarity index 95%
rename from scripts/start_servers.py
rename to scripts/data/start_servers.py
index 2e6e74961f3..239e1f5fbad 100644
--- a/scripts/start_servers.py
+++ b/scripts/data/start_servers.py
@@ -22,7 +22,7 @@ def main():
     os.makedirs(args.log_dir, exist_ok=True)

     env_mapping_file = os.path.join(
-        os.path.dirname(__file__), "..", "environments", "env_mapping.json"
+        os.path.dirname(__file__), "..", "..", "environments", "env_mapping.json"
     )
     with open(env_mapping_file, "r") as f:
         env_mapping = json.load(f)
diff --git a/scripts/docker/Dockerfile.megatron b/scripts/docker/Dockerfile.megatron
index f4e1ae29c32..681dd1c9f5e 100644
--- a/scripts/docker/Dockerfile.megatron
+++ b/scripts/docker/Dockerfile.megatron
@@ -31,7 +31,7 @@ RUN pip install --upgrade pip \
     && pip install -e .[vllm,mm,dev] \
     && pip install flash_attn==2.8.1 --no-build-isolation \
     && pip install -e .[megatron] \
-    && pip install transformer_engine[pytorch]==2.8.0 --no-build-isolation --no-cache-dir \
+    && pip install transformer_engine[pytorch]==2.10.0 --no-build-isolation --no-cache-dir \
     && pip install git+https://github.com/ISEEKYAN/mbridge.git@20e9ffbbe72ae7b1df83bfe1bc3c11f7382f2612 \
     && NVCC_APPEND_FLAGS="--threads 4" APEX_PARALLEL_BUILD=8 pip install -v \
         --disable-pip-version-check --no-cache-dir --no-build-isolation \
diff --git a/scripts/docker/Dockerfile.uv b/scripts/docker/Dockerfile.uv
index 3b47393147b..a33026d98de 100644
--- a/scripts/docker/Dockerfile.uv
+++ b/scripts/docker/Dockerfile.uv
@@ -43,7 +43,7 @@ RUN . /opt/venv/bin/activate && \
     uv pip install -e .[megatron] && \
     uv pip install flash_attn==2.8.1 --no-build-isolation && \
     uv pip install git+https://github.com/ISEEKYAN/mbridge.git@20e9ffbbe72ae7b1df83bfe1bc3c11f7382f2612 && \
-    uv pip install transformer_engine[pytorch]==2.8.0 --no-build-isolation --no-cache-dir && \
+    uv pip install transformer_engine[pytorch]==2.10.0 --no-build-isolation --no-cache-dir && \
     NVCC_APPEND_FLAGS="--threads 4" APEX_PARALLEL_BUILD=8 \
     uv pip install -v --no-build-isolation \
         --config-settings="--build-option=--cpp_ext" \
diff --git a/tests/explorer/workflow_test.py b/tests/explorer/workflow_test.py
index 982d75a5ec2..6dbee05f8fd 100644
--- a/tests/explorer/workflow_test.py
+++ b/tests/explorer/workflow_test.py
@@ -19,11 +19,9 @@
 from trinity.common.experience import EID, Experience
 from trinity.common.models import create_inference_models
 from trinity.common.models.model import ModelWrapper
-from trinity.common.rewards.reward_fn import RMGalleryFn
 from trinity.common.workflows import WORKFLOWS, Workflow
 from trinity.common.workflows.customized_math_workflows import MathBoxedWorkflow
 from trinity.common.workflows.eval_workflow import MathEvalWorkflow
-from trinity.common.workflows.math_rm_workflow import MathRMWorkflow
 from trinity.common.workflows.workflow import MathWorkflow, MultiTurnWorkflow, Task
 from trinity.explorer.workflow_runner import WorkflowRunner

@@ -358,37 +356,6 @@ def test_gsm8k_workflow(self) -> None:
         self.assertEqual(experiences[2].reward, -0.1)
         self.assertEqual(experiences[3].reward, 1.1)

-    def test_rm_gallery_workflow(self) -> None:
-        model = MagicMock()
-        model.chat.return_value = [
-            MockResponse(" balabalabala 99 \n \\boxed{36}"),
-            MockResponse("answer is \\boxed{36 }"),
-            MockResponse("Kim's total points are 6 + 30 =\\boxed{36}"),
-            MockResponse(" balalaba \\boxed{35.00}"),
-        ]
-        taskset_config = get_unittest_dataset_config("countdown")
-        task = Task(
-            workflow=MathRMWorkflow,
-            reward_fn=RMGalleryFn,
-            repeat_times=taskset_config.repeat_times,
-            format_args=taskset_config.format,
-            rollout_args=taskset_config.rollout_args,
-            reward_fn_args={
-                "reward_name": "math_verify_reward",
-            },
-            is_eval=False,
-            raw_task={
-                taskset_config.format.prompt_key: "",
-                taskset_config.format.response_key: r"36",
-            },
-        )
-        workflow = task.to_workflow(model=model)
-        experiences = workflow.run()
-        self.assertEqual(experiences[0].reward, 1.0)
-        self.assertEqual(experiences[1].reward, 1.0)
-        self.assertEqual(experiences[2].reward, 1.0)
-        self.assertEqual(experiences[3].reward, 0.0)
-
     def test_math_eval_workflow(self) -> None:
         model = MagicMock()
         model.chat.return_value = [
diff --git a/trinity/__init__.py b/trinity/__init__.py
index ef49c839306..b29fbc48919 100644
--- a/trinity/__init__.py
+++ b/trinity/__init__.py
@@ -1,4 +1,4 @@
 # -*- coding: utf-8 -*-
 """Trinity-RFT (Reinforcement Fine-Tuning)"""

-__version__ = "0.4.1"
+__version__ = "0.5.0.dev0"
diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py
index 59369f5e18c..f2c732b1b54 100644
--- a/trinity/common/models/vllm_model.py
+++ b/trinity/common/models/vllm_model.py
@@ -93,12 +93,14 @@ def __init__(
                 rope_kwargs = {"hf_overrides": rope_params}
             else:
                 rope_kwargs = {}
+            self.logprobs_no_prefix_cache = True
         else:
             rope_kwargs = {
                 key: getattr(config, key)
                 for key in ["rope_scaling", "rope_theta"]
                 if getattr(config, key) is not None
             }
+            self.logprobs_no_prefix_cache = False
         engine_args = vllm.AsyncEngineArgs(
             model=config.model_path,
             enforce_eager=config.enforce_eager,
@@ -111,7 +113,6 @@ def __init__(
             enable_chunked_prefill=config.enable_chunked_prefill,
             dtype=config.dtype,
             trust_remote_code=True,
-            task="generate",
             gpu_memory_utilization=config.gpu_memory_utilization,
             override_generation_config={  # TODO: find a way to unittest this
                 "temperature": config.temperature,
@@ -132,7 +133,8 @@ def __init__(
             engine_args.disable_log_requests = not config.enable_log_requests
         if self.vllm_version >= parse_version("0.11.0"):
             engine_args.reasoning_parser = config.reasoning_parser
-
+        if self.vllm_version >= parse_version("0.13.0"):
+            engine_args.async_scheduling = False
         self.async_llm = vllm.AsyncLLMEngine.from_engine_args(engine_args)
         self.processor = None
         self.state_dict_meta = None
@@ -326,13 +328,19 @@ async def logprobs(  # type: ignore [override]
         temperature = temperature if temperature is not None else self.config.temperature
         if temperature is None:
             temperature = 1.0
+        kwargs = {
+            "n": 1,
+            "max_tokens": 1,
+            "prompt_logprobs": 0,  # vLLM returns `prompt_logprobs + 1` logprobs for each token
+            "temperature": temperature,
+        }
+        # avoid using prefix cache when calculating logprobs, only for vLLM >= 0.12.0
+        if self.logprobs_no_prefix_cache:
+            kwargs["skip_reading_prefix_cache"] = True
         output = await self._generate_internal(
             prompt={"prompt_token_ids": token_ids},
             lora_request=lora_request,
-            n=1,
-            max_tokens=1,
-            prompt_logprobs=0,  # vLLM return `prompt_logprobs + 1` logrpobs for each token
-            temperature=temperature,
+            **kwargs,
         )
         return torch.tensor(
             [list(logprob_dict.values())[0].logprob for logprob_dict in output.prompt_logprobs[1:]],
@@ -404,6 +412,8 @@ async def sample(
                 # in vLLM, 0 means only return the chosen token's logprob
                 "logprobs": 0,
             }
+            if include_prompt_logprobs and self.logprobs_no_prefix_cache:
+                params["skip_reading_prefix_cache"] = True
             if sampling_params.stop is not None:
                 params["stop"] = sampling_params.stop
             req_output = await self._generate_internal(
@@ -579,41 +589,15 @@ async def run_api_server(self) -> bool:
             return True  # already running

         api_server_host, api_server_port = self.get_available_address()
-        if self.vllm_version <= parse_version("0.11.0"):
-            from trinity.common.models.vllm_patch.api_patch import (
-                run_api_server_in_ray_actor,
-            )
-
-            self.api_server = asyncio.create_task(
-                run_api_server_in_ray_actor(
-                    self.async_llm,
-                    api_server_host,
-                    api_server_port,
-                    self.config.model_path,  # type: ignore [arg-type]
-                    self.config.enable_auto_tool_choice,
-                    self.config.tool_call_parser,
-                    self.config.reasoning_parser,
-                    self.config.enable_log_requests,
-                )
-            )
-        else:
-            from trinity.common.models.vllm_patch.api_patch_v12 import (
-                run_api_server_in_ray_actor_v12,
-            )
-
-            self.api_server = asyncio.create_task(
-                run_api_server_in_ray_actor_v12(
-                    self.async_llm,
-                    api_server_host,
-                    api_server_port,
-                    self.config.model_path,  # type: ignore [arg-type]
-                    logger=self.logger,
-                    enable_auto_tool_choice=self.config.enable_auto_tool_choice,
-                    tool_call_parser=self.config.tool_call_parser,
-                    reasoning_parser=self.config.reasoning_parser,
-                    enable_log_requests=self.config.enable_log_requests,
-                )
-            )
+        from trinity.common.models.vllm_patch import get_api_server
+
+        self.api_server = get_api_server(
+            self.async_llm,
+            host=api_server_host,
+            port=api_server_port,
+            config=self.config,
+            logger=self.logger,
+        )
         self.api_server_host = api_server_host
         self.api_server_port = api_server_port
         return True
diff --git a/trinity/common/models/vllm_patch/__init__.py b/trinity/common/models/vllm_patch/__init__.py
index 294bb68eb4e..4253127a68a 100644
--- a/trinity/common/models/vllm_patch/__init__.py
+++ b/trinity/common/models/vllm_patch/__init__.py
@@ -1,7 +1,12 @@
+import asyncio
+from logging import Logger
+
 import vllm
 from packaging.version import InvalidVersion
 from packaging.version import parse as parse_version

+from trinity.common.config import InferenceModelConfig
+

 def get_vllm_version():
     try:
@@ -11,3 +16,67 @@ def get_vllm_version():
         # we cannot parse the version, trait it as the lowest version we support
         vllm_version = parse_version("0.8.5")
     return vllm_version
+
+
+def get_api_server(
+    async_llm,
+    host: str,
+    port: int,
+    config: InferenceModelConfig,
+    logger: Logger,
+):
+    vllm_version = get_vllm_version()
+    if vllm_version <= parse_version("0.11.0"):
+        from trinity.common.models.vllm_patch.api_patch import (
+            run_api_server_in_ray_actor,
+        )
+
+        return asyncio.create_task(
+            run_api_server_in_ray_actor(
+                async_llm,
+                host=host,
+                port=port,
+                model_path=config.model_path,  # type: ignore [arg-type]
+                enable_auto_tool_choice=config.enable_auto_tool_choice,
+                tool_call_parser=config.tool_call_parser,
+                reasoning_parser=config.reasoning_parser,
+                enable_log_requests=config.enable_log_requests,
+            )
+        )
+    elif vllm_version == parse_version("0.12.0"):
+        from trinity.common.models.vllm_patch.api_patch_v12 import (
+            run_api_server_in_ray_actor_v12,
+        )
+
+        return asyncio.create_task(
+            run_api_server_in_ray_actor_v12(
+                async_llm,
+                host=host,
+                port=port,
+                model_path=config.model_path,  # type: ignore [arg-type]
+                logger=logger,
+                enable_auto_tool_choice=config.enable_auto_tool_choice,
+                tool_call_parser=config.tool_call_parser,
+                reasoning_parser=config.reasoning_parser,
+                enable_log_requests=config.enable_log_requests,
+            )
+        )
+    else:
+        from trinity.common.models.vllm_patch.api_patch_v13 import (
+            run_api_server_in_ray_actor_v13,
+        )
+
+        logger.info(f"Using vLLM API patch for version {vllm.__version__}")
+        return asyncio.create_task(
+            run_api_server_in_ray_actor_v13(
+                async_llm,
+                host=host,
+                port=port,
+                model_path=config.model_path,  # type: ignore [arg-type]
+                logger=logger,
+                enable_auto_tool_choice=config.enable_auto_tool_choice,
+                tool_call_parser=config.tool_call_parser,
+                reasoning_parser=config.reasoning_parser,
+                enable_log_requests=config.enable_log_requests,
+            )
+        )
diff --git a/trinity/common/models/vllm_patch/api_patch_v12.py b/trinity/common/models/vllm_patch/api_patch_v12.py
index b926b158a18..1419184a907 100644
--- a/trinity/common/models/vllm_patch/api_patch_v12.py
+++ b/trinity/common/models/vllm_patch/api_patch_v12.py
@@ -1,7 +1,4 @@
-"""Patch for vllm OpenAI API server. Only for vllm versions >0.11.0.
-
-1. Mocks the `add_signal_handler` method to do nothing.
-2. Adds `token_ids` and `prompt_token_ids` to the `ChatCompletionResponse`.
+"""Patch for vllm OpenAI API server. Only for vllm versions == 0.12.0.
 """
 import logging
 from typing import Optional
+""" +import asyncio +import functools +import logging +from typing import Optional + +import vllm +import vllm.envs as envs +from packaging.version import parse as parse_version +from vllm.entrypoints.launcher import serve_http +from vllm.entrypoints.openai.api_server import ( + build_app, + create_server_socket, + create_server_unix_socket, + init_app_state, + validate_api_server_args, +) +from vllm.entrypoints.openai.cli_args import make_arg_parser +from vllm.entrypoints.utils import log_non_default_args +from vllm.reasoning import ReasoningParserManager +from vllm.tool_parsers import ToolParserManager +from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm.utils.network_utils import is_valid_ipv6_address +from vllm.utils.system_utils import set_ulimit +from vllm.version import __version__ as VLLM_VERSION + +from trinity.common.models.vllm_patch import get_vllm_version + + +def setup_server_in_ray(args, logger): + """Validate API server args, set up signal handler, create socket + ready to serve.""" + + logger.info("vLLM API server version %s", VLLM_VERSION) + log_non_default_args(args) + + if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: + ToolParserManager.import_tool_parser(args.tool_parser_plugin) + + if args.reasoning_parser_plugin and len(args.reasoning_parser_plugin) > 3: + ReasoningParserManager.import_reasoning_parser(args.reasoning_parser_plugin) + + validate_api_server_args(args) + + # workaround to make sure that we bind the port before the engine is set up. + # This avoids race conditions with ray. + # see https://github.com/vllm-project/vllm/issues/8204 + if args.uds: + sock = create_server_unix_socket(args.uds) + else: + sock_addr = (args.host or "", args.port) + sock = create_server_socket(sock_addr) + + # workaround to avoid footguns where uvicorn drops requests with too + # many concurrent requests active + set_ulimit() + + if args.uds: + listen_address = f"unix:{args.uds}" + else: + addr, port = sock_addr + is_ssl = args.ssl_keyfile and args.ssl_certfile + host_part = f"[{addr}]" if is_valid_ipv6_address(addr) else addr or "0.0.0.0" + listen_address = f"http{'s' if is_ssl else ''}://{host_part}:{port}" + return listen_address, sock + + +def dummy_add_signal_handler(self, *args, **kwargs): + # DO NOTHING HERE + pass + + +async def run_server_worker_in_ray( + listen_address, + sock, + args, + engine_client, + logger, +) -> None: + # Modified from vllm.entrypoints.openai.api_server.run_server_worker + if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: + ToolParserManager.import_tool_parser(args.tool_parser_plugin) + + if args.reasoning_parser_plugin and len(args.reasoning_parser_plugin) > 3: + ReasoningParserManager.import_reasoning_parser(args.reasoning_parser_plugin) + + app = build_app(args) + + await init_app_state(engine_client, app.state, args) + + loop = asyncio.get_event_loop() + loop.add_signal_handler = functools.partial(dummy_add_signal_handler, loop) + + logger.info( + "Starting vLLM API server %d on %s", + engine_client.vllm_config.parallel_config._api_process_rank, + listen_address, + ) + + shutdown_task = await serve_http( + app, + sock=sock, + enable_ssl_refresh=args.enable_ssl_refresh, + host=args.host, + port=args.port, + log_level=args.uvicorn_log_level, + # NOTE: When the 'disable_uvicorn_access_log' value is True, + # no access log will be output. 
diff --git a/trinity/common/models/vllm_patch/worker_patch.py b/trinity/common/models/vllm_patch/worker_patch.py
index c58decbf7c3..a0d6f37647c 100644
--- a/trinity/common/models/vllm_patch/worker_patch.py
+++ b/trinity/common/models/vllm_patch/worker_patch.py
@@ -13,10 +13,10 @@
 def patch_vllm_prompt_logprobs(model_runner: GPUModelRunner):  # noqa: C901
     """Patch vLLM model runner to support prompt logprobs extraction."""
     version = get_vllm_version()
-    if version < parse_version("0.10.2") or version > parse_version("0.12.0"):
+    if version < parse_version("0.10.2") or version > parse_version("0.14.1"):
         raise ValueError(
             f"Unsupported vllm version: {vllm.__version__}. "
-            "This patch requires vllm version >= 0.10.2, <= 0.12.0."
+            "This patch requires vllm version >= 0.10.2, <= 0.14.1."
         )

     is_v0102 = version == parse_version("0.10.2")
@@ -150,7 +150,7 @@ def _get_prompt_logprobs_dict_v12(

     This is a monkey-patched version of `_get_prompt_logprobs_dict` from
     `vllm.v1.worker.gpu_model_runner.GPUModelRunner` (vLLM versions
-    0.10.2 to 0.11.0).
+    0.12.0 to 0.14.1).

     The original function does not apply temperature scaling to logits when
     calculating prompt logprobs, which can lead to incorrect logprob values
diff --git a/trinity/common/rewards/reward_fn.py b/trinity/common/rewards/reward_fn.py
index 5eacc98914c..90acfa340a1 100644
--- a/trinity/common/rewards/reward_fn.py
+++ b/trinity/common/rewards/reward_fn.py
@@ -22,6 +22,8 @@ def __call__(self, **kwargs) -> Dict[str, float]:
 class RMGalleryFn(RewardFn):
     """Reward Function from RMGallery.
     https://github.com/modelscope/RM-Gallery
+
+    TODO: Update to OpenJudgeFn
     """

     def __init__(