Skip to content

Commit

Permalink
fix(export): update API for disabling device reassignment in TRTLLM f…
Browse files Browse the repository at this point in the history
…or Aligner

[feat] Upgrade nemo-export path for aligner to TRTLLM-v12 and use python runtime

Signed-off-by: Terry Kong <terryk@nvidia.com>

fix: forgot to always set _disable_torch_cuda_device_set

Signed-off-by: Terry Kong <terryk@nvidia.com>

Signed-off-by: Terry Kong <terryk@nvidia.com>

Apply isort and black reformatting

Signed-off-by: terrykong <terrykong@users.noreply.github.com>

invert torch device set

Signed-off-by: Terry Kong <terryk@nvidia.com>
  • Loading branch information
terrykong committed Nov 1, 2024
1 parent 19668e5 commit b8bf39f
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 40 deletions.
14 changes: 12 additions & 2 deletions nemo/export/tensorrt_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,14 @@
from nemo.export.trt_llm.qnemo.tokenizer_utils import get_nmt_tokenizer
from nemo.export.trt_llm.qnemo.utils import is_qnemo_checkpoint
from nemo.export.trt_llm.tensorrt_llm_build import build_and_save_engine
from nemo.export.trt_llm.tensorrt_llm_run import generate, generate_streaming, load, load_distributed, refit
from nemo.export.trt_llm.tensorrt_llm_run import (
generate,
generate_streaming,
load,
load_distributed,
refit,
unload_engine,
)

use_deploy = True
try:
Expand Down Expand Up @@ -490,12 +497,12 @@ def build(
engine = build_and_save_engine(
max_input_len=max_input_len,
max_output_len=max_output_len,
max_seq_len=max_input_len + max_output_len,
max_batch_size=max_batch_size,
model_config=model_config[0],
model_weights=weights[0],
model_dir=self.model_dir,
model_type=model_type,
custom_all_reduce=False,
use_refit=use_refit,
)
torch.distributed.barrier()
Expand Down Expand Up @@ -968,3 +975,6 @@ def _load(self):
"model needs to be exported again. "
"Error message: " + repr(error)
) from error

def unload_engine(self):
unload_engine()
9 changes: 8 additions & 1 deletion nemo/export/trt_llm/converter/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import numpy as np
import tensorrt_llm
import torch
from tensorrt_llm._utils import torch_to_numpy
from tensorrt_llm._utils import mpi_comm, torch_to_numpy

# A global dicts to store exported weights.
# This is set to be a global variable to avoid extra code modification from tensorrt_llm.
Expand Down Expand Up @@ -492,6 +492,13 @@ def init_model_parallel_from_nemo(reshard_model):
pp_size = 1

mp_rank = tp_size * pp_rank + tp_rank
# Need to split cpp MPI World Comm because TensorRT-LLM NCCL plugins refer to the locally split comm.
# High level call structure is: MpiComm::split -> MpiComm::setSession -> LOCAL_COMM_SESSION (used in allReducePlugin.cpp)
tensorrt_llm.bindings.MpiComm.split(dp_rank, mp_rank)
# Also split the python mpi communicator and set the global world one to the local split one
new_comm = mpi_comm().Split(color=dp_rank, key=mp_rank)
from mpi4py import MPI

MPI.COMM_WORLD = new_comm

return mp_rank, dp_rank, tp_size, pp_size, dp_size
2 changes: 2 additions & 0 deletions nemo/export/trt_llm/tensorrt_llm_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def build_and_save_engine(
multiple_profiles: bool = False,
gpt_attention_plugin: str = "auto",
gemm_plugin: str = "auto",
reduce_fusion: bool = False,
):
architecture = "LLaMAForCausalLM" if model_config.architecture == "LlamaForCausalLM" else model_config.architecture
try:
Expand All @@ -71,6 +72,7 @@ def build_and_save_engine(
plugin_config.remove_input_padding = remove_input_padding
plugin_config.use_paged_context_fmha = paged_context_fmha
plugin_config.multiple_profiles = multiple_profiles
plugin_config.reduce_fusion = reduce_fusion

max_num_tokens, opt_num_tokens = check_max_num_tokens(
max_num_tokens=max_num_tokens,
Expand Down
138 changes: 101 additions & 37 deletions nemo/export/trt_llm/tensorrt_llm_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,28 +23,30 @@
from typing import List, Optional

import numpy as np
import tensorrt as trt
import tensorrt_llm
import torch
from mpi4py.futures import MPIPoolExecutor
from tensorrt_llm._utils import mpi_comm
from tensorrt_llm.builder import Engine
from tensorrt_llm.lora_manager import LoraManager
from tensorrt_llm.quantization import QuantMode
from tensorrt_llm.runtime import ModelConfig, ModelRunner, ModelRunnerCpp, SamplingConfig

from transformers import PreTrainedTokenizer

LOGGER = logging.getLogger("NeMo")

use_trtllm_bindings = True
try:
from tensorrt_llm.bindings import GptJsonConfig, GptSession, GptSessionConfig, KvCacheConfig, WorldConfig
from tensorrt_llm.bindings import GptJsonConfig
except Exception as e:
use_trtllm_bindings = False

use_cpp_gpt_session = True
TRTLLM_SUPPORTS_DEVICE_DISABLE = True
try:
from tensorrt_llm.runtime.model_runner_cpp import ModelRunnerCppGptSession
except Exception as e:
use_cpp_gpt_session = False
from tensorrt_llm.runtime.generation import DISABLE_TORCH_DEVICE_SET
except (ImportError, ModuleNotFoundError):
TRTLLM_SUPPORTS_DEVICE_DISABLE = False


@dataclass
Expand All @@ -63,7 +65,7 @@ class TensorrtLLMHostContext:
class TensorrtLLMWorkerContext:
"""The MPI worker side context for TRT LLM inference."""

decoder: ModelRunner = None
decoder: ModelRunner | ModelRunnerCpp = None
sampling_config: SamplingConfig = None
lora_manager: LoraManager = None
max_batch_size: int = 0
Expand Down Expand Up @@ -123,7 +125,6 @@ def _read_config(config_path: Path):
lora_plugin=config["plugin_config"]["lora_plugin"],
lora_target_modules=config["builder_config"]["lora_target_modules"],
quant_mode=quant_mode,
use_custom_all_reduce=config["plugin_config"]["use_custom_all_reduce"],
use_context_fmha_for_generation=config["plugin_config"]["use_context_fmha_for_generation"],
gather_context_logits=config["builder_config"]["gather_context_logits"],
gather_generation_logits=config["builder_config"]["gather_generation_logits"],
Expand Down Expand Up @@ -456,7 +457,7 @@ def load_distributed(engine_dir, model_parallel_rank, gpus_per_node):
this function creates a custom mapping of device_id to WorldConfig
"""
global tensorrt_llm_worker_context
if isinstance(tensorrt_llm_worker_context.decoder, ModelRunnerCppGptSession):
if isinstance(tensorrt_llm_worker_context.decoder, ModelRunner):
return

config_path = Path(engine_dir) / f"config_{torch.distributed.get_rank()}.json"
Expand All @@ -480,46 +481,109 @@ def load_distributed(engine_dir, model_parallel_rank, gpus_per_node):
device_ids = [i for i in range(gpus_per_node)]
for _ in range(offset):
device_ids.append(device_ids.pop(0))
world_config = WorldConfig.mpi(
gpus_per_node=gpus_per_node, tensor_parallelism=tp_size, pipeline_parallelism=pp_size, device_ids=device_ids
)
engine_filename = json_config.engine_filename(world_config)
engine_index = model_parallel_rank
mpi_rank = mpi_comm().Get_rank()
# TODO: copied from worldConfig.h (getDevice())
mpi_device = mpi_rank % gpus_per_node
# TODO: check if API exists (copied from gptJsonConfig.cpp)
# https://github.com/terrykong/TensorRT-LLM/blob/05316d3313360012536ace46c781518f5afae75e/cpp/tensorrt_llm/runtime/gptJsonConfig.cpp#L478
engine_filename = f"rank{engine_index}.engine"
serialize_path = Path(engine_dir) / engine_filename
assert torch.cuda.current_device() == world_config.device

session_config = GptSessionConfig(
max_batch_size=max_batch_size, max_beam_width=max_beam_width, max_sequence_length=max_seq_len
)
session_config.gen_micro_batch_size = max_batch_size
session_config.ctx_micro_batch_size = max_batch_size
session_config.kv_cache_config = KvCacheConfig(
max_tokens=max_seq_len * max_batch_size, max_attention_window=max_seq_len
)

# $#$#$assert torch.cuda.current_device() == mpi_device
with open(serialize_path, "rb") as f:
engine_data = bytearray(f.read())

session = GptSession(session_config, model_config, world_config, engine_data)
decoder = ModelRunnerCppGptSession(
session,
lora_manager=None,
max_batch_size=max_batch_size,
max_input_len=max_input_len,
max_seq_len=max_seq_len,
max_beam_width=max_beam_width,
with open(config_path) as f:
json_config_str = f.read()

engine = Engine.from_buffer(engine_buffer=engine_data, json_config_str=json_config_str, rank=model_parallel_rank)

if not TRTLLM_SUPPORTS_DEVICE_DISABLE:
raise RuntimeError(
f"TensorRT-LLM does not support torch device disabling. Please upgrade TensorRT-LLM to make use of this feature."
)
elif not DISABLE_TORCH_DEVICE_SET:
raise RuntimeError(
f"To use TensorRT-LLM's python ModelRunner API in load_distributed(...) you must set the env var DISABLE_TORCH_DEVICE_SET=1"
)
decoder = ModelRunner.from_engine(
engine=engine,
# rank=world_config.rank,
# We want the engine to have the mp_rank, but the python runtime to not resassign the device of the current process
# So we will set it to the current
rank=torch.cuda.current_device(),
)

tensorrt_llm_worker_context.decoder = decoder
tensorrt_llm_worker_context.max_batch_size = max_batch_size
tensorrt_llm_worker_context.max_input_len = max_input_len
# Save the model config in case for refit
tensorrt_llm_worker_context.model_config = model_config


def refit(weights_dict):
def maybe_cast_to_trt_dtype(dtype):
if isinstance(dtype, trt.DataType):
return dtype
elif isinstance(dtype, torch.dtype):
return tensorrt_llm._utils.torch_dtype_to_trt(dtype)
else:
raise NotImplementedError(f"Expects the type to be a tensorrt.DataType or torch.dtype, but got {type(dtype)=}")


def refit(weights_dict: dict):
global tensorrt_llm_worker_context
decoder = tensorrt_llm_worker_context.decoder
if not isinstance(decoder, ModelRunner):
raise ValueError(
f"Refit is only supported with ModelRunner, but export has been configured with {type(decoder)=}"
)

engine = decoder.session.runtime.engine
# The session dtype plumbs the model_config's dtype
model_dtype = maybe_cast_to_trt_dtype(decoder.session.dtype)
assert engine.refittable, "Tried refitting engine without refit enabled"

refitter = trt.Refitter(engine=engine, logger=trt.Logger(trt.Logger.ERROR))
remaining_refit_weights = set(refitter.get_all_weights())
skipped_weights = []
for trt_name, weight in weights_dict.items():
if trt_name not in remaining_refit_weights:
skipped_weights.append(trt_name)
continue
trt_weight = trt.Weights(model_dtype, weight.data_ptr(), torch.numel(weight))
trt_wt_location = trt.TensorLocation.DEVICE if weight.is_cuda else trt.TensorLocation.HOST
assert (
model_dtype == refitter.get_weights_prototype(trt_name).dtype == maybe_cast_to_trt_dtype(weight.dtype)
), f"Expected all three of these dtypes to be the same {model_dtype=} {refitter.get_weights_prototype(trt_name).dtype=} weight.dtype={maybe_cast_to_trt_dtype(weight.dtype)}"

refitter.set_named_weights(
trt_name, trt_weight, trt_wt_location
), f"Unable to set {trt_name=} {trt_weight=} {trt_wt_location=}"
remaining_refit_weights.remove(trt_name)
if skipped_weights:
logging.warning(
f"These weights were ignored during refit since they are not present in engine: {skipped_weights}"
)
if remaining_refit_weights:
logging.warning(f"Weights dict did not contain weights for these named TRT weights: {remaining_refit_weights}")

if not refitter.refit_cuda_engine():
raise ValueError(f"Refit failed!")


def unload_engine():
"""
Deletes the ModelRunner which should free up device memory
"""
global tensorrt_llm_worker_context
dtype = tensorrt_llm_worker_context.model_config.data_type
tensorrt_llm_worker_context.decoder.session.refit_engine(weights_dict, dtype)
decoder = tensorrt_llm_worker_context.decoder
if not isinstance(decoder, ModelRunner):
raise ValueError(
f"unload_engine is only supported with ModelRunner, but export has been configured with {type(decoder)=}"
)

logging.info("Unloading engine...")
del tensorrt_llm_worker_context.decoder
tensorrt_llm_worker_context.decoder = None
logging.info("Engine unloaded!")


def prepare_input_tensors(
Expand Down

0 comments on commit b8bf39f

Please sign in to comment.