Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from collections.abc import Iterator
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from importlib import metadata
from importlib.metadata import PackageNotFoundError
from typing import TYPE_CHECKING, Any, Optional, Union

import msgspec
Expand Down Expand Up @@ -78,6 +80,14 @@
_NIXL_SUPPORTED_DEVICE.update(current_platform.get_nixl_supported_devices())


def get_nixl_version(package_name: str = 'nixl') -> str:
"""Gets the version of an installed Python package."""
try:
return metadata.version(package_name)
except PackageNotFoundError:
return "0.0.0"


class NixlAgentMetadata(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
Expand Down Expand Up @@ -480,8 +490,11 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
if nixl_agent_config is None:
config = None
else:
ucx_args = {
'num_threads': 8,
} if get_nixl_version() >= '0.5.1' else {}
Comment on lines +493 to +495
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Comparing versions as strings is error-prone because it uses lexicographical ordering. For example, '0.10.0' would be considered less than '0.5.1'. To ensure correct version comparison, you should parse the version string into a comparable format, like a tuple of integers. This approach is more robust than string comparison for simple version numbers. For full compliance with PEP 440 versioning, consider using the packaging library if nixl adopts more complex versioning schemes in the future.

Suggested change
ucx_args = {
'num_threads': 8,
} if get_nixl_version() >= '0.5.1' else {}
ucx_args = {
'num_threads': 8,
} if tuple(map(int, get_nixl_version().split('.'))) >= (0, 5, 1) else {}

config = nixl_agent_config(backends=self.nixl_backends) if len(
non_ucx_backends) > 0 else nixl_agent_config(num_threads=8)
non_ucx_backends) > 0 else nixl_agent_config(**ucx_args)

self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), config)
# Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
Expand Down