Skip to content

Commit

Permalink
Avoid unpacking NeMo checkpoints before exporting to TRT-LLM (#8866)
Browse files Browse the repository at this point in the history
* Replaced unpacking of nemo checkpoints on export with a VFS-like TarPath object.

Signed-off-by: Alexey Panteleev <alpanteleev@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixed the signature of ZarrPathStore.__delitem__

Signed-off-by: Alexey Panteleev <alpanteleev@nvidia.com>

---------

Signed-off-by: Alexey Panteleev <alpanteleev@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Onur Yilmaz <35306097+oyilmaz-nvidia@users.noreply.github.com>
  • Loading branch information
3 people authored Apr 19, 2024
1 parent c687a69 commit 24ac02a
Show file tree
Hide file tree
Showing 6 changed files with 256 additions and 144 deletions.
204 changes: 204 additions & 0 deletions nemo/export/tarutils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import fnmatch
import os
import tarfile
from typing import Union

import zarr.storage


class TarPath:
"""
A class that represents a path inside a TAR archive and behaves like pathlib.Path.
Expected use is to create a TarPath for the root of the archive first, and then derive
paths to other files or directories inside the archive like so:
with TarPath('/path/to/archive.tar') as archive:
myfile = archive / 'filename.txt'
if myfile.exists():
data = myfile.read()
...
Only read and enumeration operations are supported.
"""

def __init__(self, tar: Union[str, tarfile.TarFile, 'TarPath'], *parts):
self._relpath = ''
if isinstance(tar, TarPath):
self._tar = tar._tar
self._relpath = os.path.join(tar._relpath, *parts)
elif isinstance(tar, tarfile.TarFile):
self._tar = tar
if parts:
self._relpath = os.path.join(*parts)
elif isinstance(tar, str):
self._tar = tarfile.open(tar, 'r')
if parts:
self._relpath = os.path.join(*parts)
else:
raise ValueError(f"Unexpected argument type for TarPath: {type(tar).__name__}")

def __truediv__(self, key) -> 'TarPath':
return TarPath(self._tar, os.path.join(self._relpath, key))

def __str__(self) -> str:
return os.path.join(self._tar.name, self._relpath)

@property
def tarobject(self):
return self._tar

@property
def relpath(self):
return self._relpath

@property
def name(self):
return os.path.split(self._relpath)[1]

@property
def suffix(self):
name = self.name
i = name.rfind('.')
if 0 < i < len(name) - 1:
return name[i:]
else:
return ''

def __enter__(self):
self._tar.__enter__()
return self

def __exit__(self, *args):
return self._tar.__exit__(*args)

def exists(self):
try:
self._tar.getmember(self._relpath)
return True
except KeyError:
try:
self._tar.getmember(os.path.join('.', self._relpath))
return True
except KeyError:
return False

def is_file(self):
try:
self._tar.getmember(self._relpath).isreg()
return True
except KeyError:
try:
self._tar.getmember(os.path.join('.', self._relpath)).isreg()
return True
except KeyError:
return False

def is_dir(self):
try:
self._tar.getmember(self._relpath).isdir()
return True
except KeyError:
try:
self._tar.getmember(os.path.join('.', self._relpath)).isdir()
return True
except KeyError:
return False

def open(self, mode: str):
if mode != 'r' and mode != 'rb':
raise NotImplementedError()
try:
# Try the relative path as-is first
return self._tar.extractfile(self._relpath)
except KeyError:
try:
# Try the relative path with "./" prefix
return self._tar.extractfile(os.path.join('.', self._relpath))
except KeyError:
raise FileNotFoundError()

def glob(self, pattern):
for member in self._tar.getmembers():
# Remove the "./" prefix, if any
name = member.name[2:] if member.name.startswith('./') else member.name

# If we're in a subdirectory, make sure the file is too, and remove that subdir component
if self._relpath:
if not name.startswith(self._relpath + '/'):
continue
name = name[len(self._relpath) + 1 :]

# See if the name matches the pattern
if fnmatch.fnmatch(name, pattern):
yield TarPath(self._tar, os.path.join(self._relpath, name))

def rglob(self, pattern):
for member in self._tar.getmembers():
# Remove the "./" prefix, if any
name = member.name[2:] if member.name.startswith('./') else member.name

# If we're in a subdirectory, make sure the file is too, and remove that subdir component
if self._relpath:
if not name.startswith(self._relpath + '/'):
continue
name = name[len(self._relpath) + 1 :]

# See if any tail of the path matches the pattern, return full path if that's true
parts = name.split('/')
for i in range(len(parts)):
subname = '/'.join(parts[i:])
if fnmatch.fnmatch(subname, pattern):
yield TarPath(self._tar, os.path.join(self._relpath, name))
break

def iterdir(self):
return self.glob('*')


class ZarrPathStore(zarr.storage.BaseStore):
"""
An implementation of read-only Store for zarr library
that works with pathlib.Path or TarPath objects.
"""

def __init__(self, tarpath: TarPath):
self._path = tarpath
self._writable = False
self._erasable = False

def __getitem__(self, key):
with (self._path / key).open('rb') as file:
return file.read()

def __contains__(self, key):
return (self._path / key).is_file()

def __iter__(self):
return self.keys()

def __len__(self):
return sum(1 for _ in self.keys())

def __setitem__(self, key, value):
raise NotImplementedError()

def __delitem__(self, key):
raise NotImplementedError()

def keys(self):
return self._path.iterdir()
18 changes: 10 additions & 8 deletions nemo/export/tensorrt_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,12 @@
import wrapt

from nemo.deploy import ITritonDeployable
from nemo.export.tarutils import TarPath
from nemo.export.trt_llm.model_config_trt import model_config_to_tensorrt_llm
from nemo.export.trt_llm.nemo.nemo_ckpt_convert import build_tokenizer
from nemo.export.trt_llm.nemo_utils import get_tokenzier, nemo_llm_model_to_model_config, nemo_llm_to_model_config
from nemo.export.trt_llm.tensorrt_llm_run import generate, generate_streaming, load, load_refit
from nemo.export.trt_llm.utils import is_nemo_file, unpack_nemo_ckpt
from nemo.export.trt_llm.utils import is_nemo_file

use_deploy = True
try:
Expand Down Expand Up @@ -584,18 +585,19 @@ def _load_prompt_tables(self):
self.ptuning_tables = []

def _get_prompt_embedding_table_ckpt(self, prompt_embeddings_checkpoint_path):
with tempfile.TemporaryDirectory() as temp_dir:
unpack_nemo_ckpt(prompt_embeddings_checkpoint_path, temp_dir)
mw_path = os.path.join(temp_dir, "model_weights.ckpt")
if not Path(mw_path).exists():
mw_path = os.path.join(temp_dir, "mp_rank_00", "model_weights.ckpt")
if not Path(mw_path).exists():
with TarPath(prompt_embeddings_checkpoint_path) as checkpoint_archive:
mw_path = checkpoint_archive / "model_weights.ckpt"
if not mw_path.exists():
mw_path = checkpoint_archive / "mp_rank_00/model_weights.ckpt"
if not mw_path.exists():
raise FileNotFoundError(
"File: {0} could not be found in the nemo checkpoint. "
"Please check the nemo checkpoint format for the prompt "
"embedding table.".format(mw_path)
)
weights = torch.load(mw_path)

with mw_path.open('rb') as mw_file:
weights = torch.load(mw_file)

weights_found = True
if "model.embedding.adapter_layer.ptuning_adapter.inference_table" in weights:
Expand Down
48 changes: 4 additions & 44 deletions nemo/export/trt_llm/nemo/nemo.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,16 @@

import functools
import logging
import os
import pathlib
import tarfile
import typing

import torch
import yaml
from transformers import FalconConfig, GPT2Config, LlamaConfig

from nemo.export.tarutils import TarPath
from nemo.export.trt_llm.nemo.convert import cpu_map_location, gpu_map_location


LOGGER = logging.getLogger("NeMo")


Expand Down Expand Up @@ -100,45 +98,6 @@ def add_special_tokens_to_tokenizer(tokenizer):
tokenizer.add_special_tokens({"eos_token": "</s>"})


def unpack_nemo_ckpt(
nemo_archive_path: typing.Union[str, pathlib.Path], out_dir_path: typing.Union[str, pathlib.Path],
):
nemo_archive_path = pathlib.Path(nemo_archive_path)
if not nemo_archive_path.exists():
raise FileNotFoundError(f"{nemo_archive_path} does not exist")

for tar_mode in ["r:", "r:gz"]:
try:
with tarfile.open(nemo_archive_path, mode=tar_mode) as tar_file:

def is_within_directory(directory, target):
abs_directory = os.path.abspath(directory)
abs_target = os.path.abspath(target)

prefix = os.path.commonprefix([abs_directory, abs_target])

return prefix == abs_directory

def safe_members(tar_file):
members = []
for member in tar_file.getmembers():
member_path = os.path.join(out_dir_path, member.name)
if not is_within_directory(out_dir_path, member_path):
raise Exception("Attempted Path Traversal in Tar File")
members.append(member)
return members

tar_file.extractall(
out_dir_path, members=safe_members(tar_file), numeric_owner=False
) # nosec - tar path has been validated.

return out_dir_path
except tarfile.ReadError:
pass

raise RuntimeError(f"Could not unpack {nemo_archive_path}")


def extract_layers_with_prefix(model_, prefix):
length_to_trim = len(prefix)
model_state = model_.get("state_dict", model_)
Expand All @@ -147,9 +106,10 @@ def extract_layers_with_prefix(model_, prefix):

class UnpackedNemoCheckpointDir:
def __init__(
self, checkpoints_dir: typing.Union[str, pathlib.Path], load_checkpoints_to_cpu: bool = False,
self, checkpoints_dir: typing.Union[pathlib.Path, TarPath], load_checkpoints_to_cpu: bool = False,
):
self._checkpoints_dir = pathlib.Path(checkpoints_dir)
assert isinstance(checkpoints_dir, (pathlib.Path, TarPath))
self._checkpoints_dir = checkpoints_dir
self._load_checkpoints_to_cpu = load_checkpoints_to_cpu

@property
Expand Down
Loading

0 comments on commit 24ac02a

Please sign in to comment.