4 changes: 2 additions & 2 deletions docs/guides/use-custom-vllm.md
@@ -28,8 +28,8 @@ dependencies = [
 
 [project.optional-dependencies]
 vllm = [
-    #"vllm==0.9.0", # <-- BEFORE
-    "vllm", # <-- AFTER
+    #"vllm==0.10.0", # <-- BEFORE
+    "vllm", # <-- AFTER
 ]
 
 # ...<OMITTED>
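Since the guide now tells users to leave the `vllm` extra unpinned when pointing at a custom build, a quick import-time check is an easy way to confirm which build an environment actually resolved. A minimal sketch, not part of the guide itself; the "0.10" expectation is an assumption:

```python
# Hypothetical sanity check: confirm which vLLM build is active after following
# the custom-vLLM guide. The expected "0.10" prefix is an assumed value.
import vllm

print(f"vllm {vllm.__version__} loaded from {vllm.__file__}")
if not vllm.__version__.startswith("0.10"):
    print("warning: this environment is not running a 0.10.x-based build")
```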
5 changes: 4 additions & 1 deletion nemo_rl/models/generation/vllm.py
@@ -325,7 +325,9 @@ def _patch_vllm_init_workers_ray():
         llm_kwargs = dict(
             model=self.model_name,
             load_format=load_format,
-            skip_tokenizer_init=self.cfg["vllm_cfg"]["skip_tokenizer_init"],
+            # vllm==0.10.0 breaks skip_tokenizer_init=True.
+            # This will be reverted to `self.cfg["vllm_cfg"]["skip_tokenizer_init"]` once https://github.com/NVIDIA-NeMo/RL/issues/818 is resolved.
+            skip_tokenizer_init=False,
             tensor_parallel_size=self.tensor_parallel_size,
             pipeline_parallel_size=self.pipeline_parallel_size,
             gpu_memory_utilization=self.gpu_memory_utilization,
@@ -338,6 +340,7 @@ def _patch_vllm_init_workers_ray():
             worker_extension_cls="nemo_rl.models.generation.vllm_backend.VllmInternalWorkerExtension",
             enable_sleep_mode=True,
             disable_log_stats=True,
+            logprobs_mode="raw_logprobs",
             **vllm_kwargs,
         )
 
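For reviewers who want to see the two new engine kwargs in isolation, here is a standalone sketch; it is not NeMo RL code, the model name is an arbitrary placeholder, and it assumes a local vllm==0.10.0 install with a GPU available:

```python
# Standalone sketch (assumptions: vllm==0.10.0, a small model that fits locally).
# It mirrors the two engine kwargs touched in this diff.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # placeholder model, chosen for illustration
    skip_tokenizer_init=False,           # forced off until issue #818 is resolved
    logprobs_mode="raw_logprobs",        # report logprobs from the raw (unprocessed) logits
)

params = SamplingParams(max_tokens=8, logprobs=1)
output = llm.generate(["The capital of France is"], params)[0].outputs[0]
print(output.text)
print(output.logprobs)  # per-token logprob entries, taken before sampling adjustments
```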
50 changes: 1 addition & 49 deletions nemo_rl/models/generation/vllm_backend.py
@@ -11,9 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 from collections import defaultdict
-from typing import Any, Iterable, Optional
+from typing import Any, Optional
 
 import torch
 from torch.multiprocessing.reductions import rebuild_cuda_tensor
@@ -28,49 +27,6 @@
 )
 
 
-def _patch_gemma3_mm():
-    """Patch gemma3_mm.py to support new HF multimodal format (post transformers v4.52).
-
-    Patch taken from:https://github.com/vllm-project/vllm/pull/19151/files#diff-5890909300e4e6c3160444e4587ec3fd80498bb83f598b22ce81337f75992b06
-    """
-    from packaging.version import Version as PkgVersion
-
-    assert PkgVersion(vllm.__version__) < PkgVersion("0.9.2"), (
-        f"You are using vllm version {vllm.__version__}. "
-        "Please remove this patch (_patch_gemma3_mm in nemo_rl/models/generation/vllm_backend.py) "
-        "since it is included in vllm>=0.9.2."
-    )
-
-    from vllm.logger import init_logger
-    from vllm.model_executor.models import gemma3_mm
-    from vllm.model_executor.models.utils import (
-        AutoWeightsLoader,
-        WeightsMapper,
-    )
-
-    logger = init_logger("gemma3_mm_patch")
-
-    gemma3_mm.Gemma3ForConditionalGeneration.hf_to_vllm_mapper = WeightsMapper(
-        orig_to_new_prefix={
-            # mapping for new names in checkpoint saved after transformers v4.52
-            "model.language_model.": "language_model.model.",
-            "model.vision_tower.": "vision_tower.",
-            "model.multi_modal_projector.": "multi_modal_projector.",
-            "lm_head.": "language_model.lm_head.",
-        }
-    )
-
-    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(self)
-        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
-
-    gemma3_mm.Gemma3ForConditionalGeneration.load_weights = load_weights
-    logger.info("Successfully patched gemma3_mm.py in vllm_backend.")
-
-
-_patch_gemma3_mm()
-
-
 class VllmInternalWorkerExtension:
     def init_collective(
         self, rank_prefix: int, ip: str, port: int, world_size: int
@@ -82,10 +38,6 @@ def init_collective(
         local_rank = torch.distributed.get_rank()
         rank = rank_prefix + local_rank + 1  # 1 is the head node of the train cluster
 
-        # Temporary fix for vllm==0.9.0 which overrides the NCCL_CUMEM_ENABLE to 0 and causes
-        # https://github.com/NVIDIA-NeMo/RL/issues/564. This can be removed after it is upgraded to vllm>=0.9.1rc1.
-        os.environ["NCCL_CUMEM_ENABLE"] = "1"
-
         pg = StatelessProcessGroup.create(
             host=ip, port=port, rank=rank, world_size=world_size
         )
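After this cleanup, the part of `init_collective` worth keeping in mind is the rank layout: rank 0 belongs to the head node of the training cluster, and each vLLM worker joins the stateless group at `rank_prefix + local_rank + 1`. A small pure-Python sketch of that layout; the group sizes below are made up for illustration:

```python
# Illustrative rank layout for the weight-sync group (values are made up):
# rank 0 is the training head node; each vLLM worker takes rank_prefix + local_rank + 1.
def weight_sync_rank(rank_prefix: int, local_rank: int) -> int:
    return rank_prefix + local_rank + 1


num_workers_per_group = 2  # e.g. tensor_parallel_size=2
world_size = 1 + 2 * num_workers_per_group  # head node + two generation groups
for rank_prefix in (0, num_workers_per_group):
    for local_rank in range(num_workers_per_group):
        rank = weight_sync_rank(rank_prefix, local_rank)
        print(f"prefix={rank_prefix} local={local_rank} -> rank {rank}")
print(f"world_size={world_size}")
```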
10 changes: 5 additions & 5 deletions pyproject.toml
@@ -21,12 +21,12 @@ license = {text = "Apache 2.0"}
 dependencies = [
     "setuptools",
     "ninja", # for flash-attn parallel build
-    "torch==2.7.0",
+    "torch==2.7.1",
     "triton",
     "colored==2.2.3",
     "ray[default]==2.46.0",
-    # vllm<0.10.0 and transformers>=4.54.0 has a conflict
-    # Remove this once we upgrade vllm to >=0.10.0
+    # transformers==4.54.0/4.54.1 both fail on rm models
+    # Remove this once https://github.com/NVIDIA-NeMo/RL/issues/811 resolved
     "transformers>=4.51.0,<4.54.0",
     "wandb",
     "numpy",
@@ -57,7 +57,7 @@ automodel = [
     "flash-attn==2.7.4.post1",
 ]
 vllm = [
-    "vllm==0.9.0",
+    "vllm==0.10.0",
     # Remove this once https://github.com/NVIDIA-NeMo/RL/issues/501 resolved
     "flash-attn==2.7.4.post1",
 ]
@@ -81,7 +81,7 @@ mcore = [
 # This is a default group so that we install these even with bare `uv sync`
 build = [
     # Build requirement for TE
-    "torch==2.7.0",
+    "torch==2.7.1",
     # Build requirement for TE
     "setuptools",
     "packaging",
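A quick way to confirm an environment actually picked up the new pins before launching jobs is to inspect the installed distributions. A minimal sketch, not part of the PR, with the expected versions copied from the pins above:

```python
# Sanity check: verify the installed packages match the updated pyproject.toml pins.
from importlib.metadata import version

from packaging.version import Version

assert version("torch") == "2.7.1", version("torch")
assert version("vllm") == "0.10.0", version("vllm")
assert Version("4.51.0") <= Version(version("transformers")) < Version("4.54.0"), version("transformers")
print("environment matches pyproject.toml pins")
```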