
Commit f2f5e8d

TRT-LLM installation tool (#3829)

1 parent 7e51f49 commit f2f5e8d

File tree

12 files changed (+435 −204 lines)

12 files changed

+435
-204
lines changed

.github/workflows/build-test-linux-x86_64.yml

Lines changed: 33 additions & 0 deletions
```diff
@@ -461,6 +461,39 @@ jobs:
         python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/l2_ts_integrations_tests_results.xml integrations/
         popd
 
+  L2-dynamo-distributed-tests:
+    name: L2 dynamo distributed tests
+    needs: [filter-matrix, build, L1-dynamo-core-tests, L1-dynamo-compile-tests, L1-torch-compile-tests, L1-torchscript-tests]
+    strategy:
+      fail-fast: false
+      matrix:
+        include:
+          - repository: pytorch/tensorrt
+            package-name: torch_tensorrt
+            pre-script: packaging/pre_build_script.sh
+            post-script: packaging/post_build_script.sh
+            smoke-test-script: packaging/smoke_test_script.sh
+    uses: ./.github/workflows/linux-test.yml
+    with:
+      job-name: L2-dynamo-distributed-tests
+      repository: "pytorch/tensorrt"
+      ref: ""
+      test-infra-repository: pytorch/test-infra
+      test-infra-ref: main
+      build-matrix: ${{ needs.filter-matrix.outputs.matrix }}
+      pre-script: ${{ matrix.pre-script }}
+      script: |
+        set -euo pipefail
+        export USE_HOST_DEPS=1
+        export CI_BUILD=1
+        export USE_TRTLLM_PLUGINS=1
+        dnf install -y mpich mpich-devel openmpi openmpi-devel
+        pushd .
+        cd tests/py
+        cd dynamo
+        python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/l2_dynamo_distributed_test_results.xml distributed/test_nccl_ops.py
+        popd
+
 concurrency:
   group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-tensorrt-${{ inputs.repository }}-${{ github.event_name == 'workflow_dispatch' }}-${{ inputs.job-name }}
   cancel-in-progress: true
```
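The new job only sets three environment variables, installs MPI, and runs one pytest target. For local debugging, the same invocation can be approximated outside CI; a minimal sketch, assuming a working torch_tensorrt build on a CUDA 12.x Linux x86_64 machine with MPI libraries already installed (`RESULTS_DIR` is a stand-in for the CI-provided `RUNNER_TEST_RESULTS_DIR`):

```python
# Sketch: local approximation of the CI job above, run from the repo root.
import os
import subprocess

env = dict(os.environ)
env["USE_HOST_DEPS"] = "1"
env["CI_BUILD"] = "1"
env["USE_TRTLLM_PLUGINS"] = "1"  # lets torch_tensorrt download the TRT-LLM wheel at import time

results_dir = env.get("RESULTS_DIR", "/tmp/test_results")  # stand-in for RUNNER_TEST_RESULTS_DIR
os.makedirs(results_dir, exist_ok=True)

subprocess.run(
    [
        "python", "-m", "pytest", "-ra",
        f"--junitxml={results_dir}/l2_dynamo_distributed_test_results.xml",
        "distributed/test_nccl_ops.py",
    ],
    cwd="tests/py/dynamo",
    env=env,
    check=True,
)
```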

dev_dep_versions.yml

Lines changed: 1 addition & 0 deletions
```diff
@@ -1,3 +1,4 @@
 __cuda_version__: "12.8"
 __tensorrt_version__: "10.13.3"
 __tensorrt_rtx_version__: "1.0.0"
+__tensorrt_llm_version__: "0.17.0.post1"
```
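For reference, the new pin can be read back like any other entry in this file; a small sketch assuming PyYAML is available (illustrative only, not necessarily the mechanism the build tooling uses):

```python
# Sketch: read the TRT-LLM version pin from dev_dep_versions.yml (assumes PyYAML).
import yaml

with open("dev_dep_versions.yml") as f:
    versions = yaml.safe_load(f)

print(versions["__tensorrt_llm_version__"])  # -> "0.17.0.post1"
```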

py/torch_tensorrt/_features.py

Lines changed: 20 additions & 0 deletions
```diff
@@ -7,6 +7,7 @@
 import tensorrt
 from torch_tensorrt._utils import (
     check_cross_compile_trt_win_lib,
+    load_tensorrt_llm_for_nccl,
     sanitized_torch_version,
 )
 
@@ -23,6 +24,7 @@
         "qdp_plugin",
         "windows_cross_compile",
         "tensorrt_rtx",
+        "trtllm_for_nccl",
     ],
 )
 
@@ -48,6 +50,7 @@
 _FX_FE_AVAIL = False if _TENSORRT_RTX else True
 _REFIT_AVAIL = True
 _WINDOWS_CROSS_COMPILE = check_cross_compile_trt_win_lib()
+_TRTLLM_AVAIL = load_tensorrt_llm_for_nccl()
 
 if importlib.util.find_spec("tensorrt.plugin"):
     _QDP_PLUGIN_AVAIL = True
@@ -63,6 +66,7 @@
     _QDP_PLUGIN_AVAIL,
     _WINDOWS_CROSS_COMPILE,
     _TENSORRT_RTX,
+    _TRTLLM_AVAIL,
 )
 
 T = TypeVar("T")
@@ -158,6 +162,22 @@ def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
     return wrapper
 
 
+def needs_trtllm_for_nccl(f: Callable[..., Any]) -> Callable[..., Any]:
+    def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
+        if ENABLED_FEATURES.trtllm_for_nccl:
+            return f(*args, **kwargs)
+        else:
+
+            def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
+                raise NotImplementedError(
+                    "TRT-LLM plugins for the NCCL backend are currently not available"
+                )
+
+            return not_implemented(*args, **kwargs)
+
+    return wrapper
+
+
 def for_all_methods(
     decorator: Callable[..., Any], exclude: Optional[List[str]] = None
 ) -> Callable[..., Any]:
```
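The decorator mirrors the existing feature guards in `_features.py`: the wrapped function runs only when `ENABLED_FEATURES.trtllm_for_nccl` is true, and raises `NotImplementedError` otherwise. A usage sketch; `insert_nccl_gather_op` is a hypothetical name, not part of this commit:

```python
# Sketch: gating a function on TRT-LLM plugin availability.
from typing import Any

from torch_tensorrt._features import ENABLED_FEATURES, needs_trtllm_for_nccl


@needs_trtllm_for_nccl
def insert_nccl_gather_op(*args: Any, **kwargs: Any) -> Any:
    """Hypothetical op that would rely on the TRT-LLM NCCL plugins."""
    ...


print(ENABLED_FEATURES.trtllm_for_nccl)  # True only if the plugin loaded at import time
insert_nccl_gather_op()  # raises NotImplementedError when the feature is unavailable
```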

py/torch_tensorrt/_utils.py

Lines changed: 265 additions & 1 deletion
```diff
@@ -1,9 +1,22 @@
+import ctypes
+import getpass
+import logging
+import os
+import platform
 import sys
-from typing import Any
+import tempfile
+import urllib.request
+from pathlib import Path
+from typing import Any, Optional
 
 import tensorrt as trt
 import torch
 
+logger = logging.getLogger(__name__)
+
+_WHL_CPYTHON_VERSION = "cp310"
+_TENSORRT_LLM_VERSION_ = "0.17.0.post1"
+
 
 def sanitized_torch_version() -> Any:
     return (
@@ -50,3 +63,254 @@ def is_tensorrt_version_supported(min_version: str) -> bool:
     except (ImportError, ValueError):
         # If tensorrt is not installed or version cannot be determined
         return False
+
+
+def is_thor() -> bool:
+    if torch.cuda.get_device_capability() in [(11, 0)]:
+        return True
+    return False
+
+
+def is_platform_supported_for_trtllm() -> bool:
+    """
+    Checks if the current platform supports TensorRT-LLM plugins for the NCCL backend.
+
+    Returns:
+        bool: True if supported, False otherwise.
+
+    Unsupported:
+        - Windows platforms
+        - Jetson/Orin/Xavier (aarch64 architecture + 'tegra' in platform release)
+        - Thor devices
+        - CUDA 13 (only CUDA 12.x is supported)
+    """
+    system = platform.system().lower()
+    machine = platform.machine().lower()
+    release = platform.release().lower()
+
+    if "windows" in system:
+        logger.info(
+            "TensorRT-LLM plugins for NCCL backend are not supported on Windows."
+        )
+        return False
+
+    if (machine == "aarch64" and "tegra" in release) or is_thor():
+        logger.info(
+            "TensorRT-LLM plugins for NCCL backend are not supported on Jetson/Orin/Xavier (Tegra) or Thor devices."
+        )
+        return False
+
+    try:
+        cuda_version = torch.version.cuda  # e.g., "12.4" or "13.0"
+        if cuda_version is None:
+            logger.error(
+                "This PyTorch build does not support CUDA; please reinstall PyTorch with CUDA support"
+            )
+            return False
+
+        major, minor = map(int, cuda_version.split("."))
+        if major != 12:
+            logger.error(
+                "Only CUDA 12.x is currently supported for TRT-LLM plugins. Please install PyTorch with CUDA 12.x support"
+            )
+            return False
+
+        return True
+
+    except Exception as e:
+        logger.warning(f"Failed to detect CUDA version: {e}")
+        return False
+
+    return True
+
+
+def _cache_root() -> Path:
+    username = getpass.getuser()
+    return Path(tempfile.gettempdir()) / f"torch_tensorrt_{username}"
+
+
+def _extracted_dir_trtllm(platform_system: str, platform_machine: str) -> Path:
+    return (
+        _cache_root()
+        / "trtllm"
+        / f"{_TENSORRT_LLM_VERSION_}_{platform_system}_{platform_machine}"
+    )
+
+
+def download_and_get_plugin_lib_path() -> Optional[str]:
+    """
+    Returns the path to the TensorRT-LLM shared library, downloading and extracting the wheel if necessary.
+
+    Note:
+        The wheel for the current platform (e.g., 'linux_x86_64') is fetched from the NVIDIA PyPI index.
+
+    Returns:
+        Optional[str]: Path to shared library or None if operation fails.
+    """
+    platform_system = platform.system().lower()
+    platform_machine = platform.machine().lower()
+    wheel_filename = (
+        f"tensorrt_llm-{_TENSORRT_LLM_VERSION_}-{_WHL_CPYTHON_VERSION}-"
+        f"{_WHL_CPYTHON_VERSION}-{platform_system}_{platform_machine}.whl"
+    )
+    wheel_path = _cache_root() / wheel_filename
+    extract_dir = _extracted_dir_trtllm(platform_system, platform_machine)
+    # the .dll branch is unreachable in practice: Windows is already rejected by is_platform_supported_for_trtllm()
+    lib_filename = (
+        "libnvinfer_plugin_tensorrt_llm.so"
+        if "linux" in platform_system
+        else "libnvinfer_plugin_tensorrt_llm.dll"
+    )
+    # e.g. /tmp/torch_tensorrt_<username>/trtllm/0.17.0.post1_linux_x86_64/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so
+    plugin_lib_path = extract_dir / "tensorrt_llm" / "libs" / lib_filename
+
+    if plugin_lib_path.exists():
+        return str(plugin_lib_path)
+
+    wheel_path.parent.mkdir(parents=True, exist_ok=True)
+    extract_dir.mkdir(parents=True, exist_ok=True)
+
+    if not wheel_path.exists():
+        base_url = "https://pypi.nvidia.com/tensorrt-llm/"
+        download_url = base_url + wheel_filename
+        try:
+            logger.debug(f"Downloading {download_url} ...")
+            urllib.request.urlretrieve(download_url, wheel_path)
+            logger.debug("Download succeeded and TRT-LLM wheel is now present")
+        except urllib.error.HTTPError as e:
+            logger.error(
+                f"HTTP error {e.code} when trying to download {download_url}: {e.reason}"
+            )
+        except urllib.error.URLError as e:
+            logger.error(
+                f"URL error when trying to download {download_url}: {e.reason}"
+            )
+        except OSError as e:
+            logger.error(f"Local file write error: {e}")
+
+    try:
+        import zipfile
+    except ImportError as e:
+        raise ImportError(
+            "zipfile module is required but not found; your Python installation may be incomplete"
+        )
+    try:
+        with zipfile.ZipFile(wheel_path) as zip_ref:
+            zip_ref.extractall(extract_dir)
+            logger.debug(f"Extracted wheel to {extract_dir}")
+    except FileNotFoundError as e:
+        # This should capture the errors in the download failure above
+        logger.error(f"Wheel file not found at {wheel_path}: {e}")
+        raise RuntimeError(
+            f"Failed to find downloaded wheel file at {wheel_path}"
+        ) from e
+    except zipfile.BadZipFile as e:
+        logger.error(f"Invalid or corrupted wheel file: {e}")
+        raise RuntimeError(
+            "Downloaded wheel file is corrupted or not a valid zip archive"
+        ) from e
+    except Exception as e:
+        logger.error(f"Unexpected error while extracting wheel: {e}")
+        raise RuntimeError(
+            "Unexpected error during extraction of TensorRT-LLM wheel"
+        ) from e
+
+    try:
+        wheel_path.unlink(missing_ok=True)
+        logger.debug(f"Deleted wheel file: {wheel_path}")
+    except Exception as e:
+        logger.warning(f"Could not delete wheel file {wheel_path}: {e}")
+    if not plugin_lib_path.exists():
+        logger.error(
+            f"Plugin library not found at expected location: {plugin_lib_path}"
+        )
+        return None
+
+    return str(plugin_lib_path)
+
+
+def load_and_initialize_trtllm_plugin(plugin_lib_path: str) -> bool:
+    """
+    Loads and initializes the TensorRT-LLM plugin from the given shared library path.
+
+    Args:
+        plugin_lib_path (str): Path to the shared TensorRT-LLM plugin library.
+
+    Returns:
+        bool: True if successful, False otherwise.
+    """
+    try:
+        handle = ctypes.CDLL(plugin_lib_path)
+        logger.info(f"Successfully loaded plugin library: {plugin_lib_path}")
+    except OSError as e_os_error:
+        if "libmpi" in str(e_os_error):
+            logger.warning(
+                f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}, got error {e_os_error} (hint: libmpi.so is a necessary dependency; ensure that OpenMPI or MPICH is installed on your system)",
+                exc_info=e_os_error,
+            )
+        else:
+            logger.warning(
+                f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}. "
+                f"Ensure the path is correct and the library is compatible.",
+                exc_info=e_os_error,
+            )
+        return False
+
+    try:
+        handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
+        handle.initTrtLlmPlugins.restype = ctypes.c_bool
+    except AttributeError as e_plugin_unavailable:
+        logger.warning(
+            "Unable to initialize the TensorRT-LLM plugin library",
+            exc_info=e_plugin_unavailable,
+        )
+        return False
+
+    try:
+        if handle.initTrtLlmPlugins(None, b"tensorrt_llm"):
+            logger.info("TensorRT-LLM plugin successfully initialized")
+            return True
+        else:
+            logger.warning("TensorRT-LLM plugin library failed in initialization")
+            return False
+    except Exception as e_initialization_error:
+        logger.warning(
+            "Exception occurred during TensorRT-LLM plugin library initialization",
+            exc_info=e_initialization_error,
+        )
+        return False
+    return False
+
+
+def load_tensorrt_llm_for_nccl() -> bool:
+    """
+    Attempts to load and initialize the TensorRT-LLM plugin.
+    The env variable TRTLLM_PLUGINS_PATH can point to an existing shared library,
+    or USE_TRTLLM_PLUGINS can be set to one of (1, true, yes, on) to download the TRT-LLM distribution and load it.
+
+    Returns:
+        bool: True if the plugin was successfully loaded and initialized, False otherwise.
+    """
+    if not is_platform_supported_for_trtllm():
+        return False
+    plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH")
+
+    if plugin_lib_path:
+        return load_and_initialize_trtllm_plugin(plugin_lib_path)
+    else:
+        # fall back to downloading the wheel when TRTLLM_PLUGINS_PATH is not set
+        use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in (
+            "1",
+            "true",
+            "yes",
+            "on",
+        )
+        if not use_trtllm_plugin:
+            logger.warning(
+                "Neither TRTLLM_PLUGINS_PATH nor USE_TRTLLM_PLUGINS is set. Please set one of the two to use TRT-LLM libraries in Torch-TensorRT"
+            )
+            return False
+
+        plugin_lib_path = download_and_get_plugin_lib_path()
+        return load_and_initialize_trtllm_plugin(plugin_lib_path)  # type: ignore[arg-type]
+    return False
```
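Putting it together, `load_tensorrt_llm_for_nccl()` supports two entry points, selected by environment variable; a minimal sketch of both, assuming Linux x86_64 with CUDA 12.x (the library path in option 1 is hypothetical):

```python
# Sketch: the two ways load_tensorrt_llm_for_nccl() can find the plugin.
import os

from torch_tensorrt._utils import load_tensorrt_llm_for_nccl

# Option 1: point at an already-installed shared library (hypothetical path).
os.environ["TRTLLM_PLUGINS_PATH"] = "/opt/trtllm/libnvinfer_plugin_tensorrt_llm.so"
print("plugin loaded:", load_tensorrt_llm_for_nccl())

# Option 2: let torch_tensorrt download and cache the wheel itself
# (under /tmp/torch_tensorrt_<username>/trtllm/<version>_<system>_<machine>/).
del os.environ["TRTLLM_PLUGINS_PATH"]
os.environ["USE_TRTLLM_PLUGINS"] = "true"
print("plugin loaded:", load_tensorrt_llm_for_nccl())
```

Note that `_features.py` evaluates `load_tensorrt_llm_for_nccl()` once at module import to populate `ENABLED_FEATURES.trtllm_for_nccl`, so in practice these variables should be exported before `torch_tensorrt` is first imported.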
