diff --git a/nemo/export/trt_llm/tensorrt_llm_run.py b/nemo/export/trt_llm/tensorrt_llm_run.py
index 1772c071a745..bd7b8abd5f9e 100644
--- a/nemo/export/trt_llm/tensorrt_llm_run.py
+++ b/nemo/export/trt_llm/tensorrt_llm_run.py
@@ -32,17 +32,23 @@
 from tensorrt_llm.lora_manager import LoraManager
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.quantization import QuantMode
-from tensorrt_llm.runtime import GenerationSession, ModelConfig, ModelRunner, ModelRunnerCpp, SamplingConfig
+from tensorrt_llm.runtime import ModelConfig, ModelRunner, ModelRunnerCpp, SamplingConfig
 from transformers import PreTrainedTokenizer
 
 LOGGER = logging.getLogger("NeMo")
 
 use_trtllm_bindings = True
 try:
-    from tensorrt_llm.bindings import GptJsonConfig, KvCacheConfig, WorldConfig
+    from tensorrt_llm.bindings import GptJsonConfig
 except Exception as e:
     use_trtllm_bindings = False
 
+TRTLLM_SUPPORTS_DEVICE_DISABLE = True
+try:
+    from tensorrt_llm.runtime.generation import DISABLE_TORCH_DEVICE_SET
+except (ImportError, ModuleNotFoundError):
+    TRTLLM_SUPPORTS_DEVICE_DISABLE = False
+
 
 @dataclass
 class TensorrtLLMHostContext:
@@ -494,12 +500,20 @@ def load_distributed(engine_dir, model_parallel_rank, gpus_per_node):
         json_config_str = f.read()
 
     engine = Engine.from_buffer(engine_buffer=engine_data, json_config_str=json_config_str, rank=model_parallel_rank)
+
+    if not TRTLLM_SUPPORTS_DEVICE_DISABLE:
+        raise RuntimeError(
+            f"TensorRT-LLM does not support torch device disabling. Please upgrade TensorRT-LLM to make use of this feature."
+        )
+    elif not DISABLE_TORCH_DEVICE_SET:
+        raise RuntimeError(
+            f"To use TensorRT-LLM's python ModelRunner API in load_distributed(...) you must set the env var DISABLE_TORCH_DEVICE_SET=1"
+        )
     decoder = ModelRunner.from_engine(
         engine=engine,
         # We want the engine to have the mp_rank, but the python runtime to not resassign the device of the current process
         # So we will set it to the current device
         rank=torch.cuda.current_device(),
-        _disable_torch_cuda_device_set=True,
     )
 
     tensorrt_llm_worker_context.decoder = decoder