Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Re-org export code #9353

Merged
merged 27 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
d0952e1
reorg the export code
oyilmaz-nvidia May 31, 2024
7911c52
Apply isort and black reformatting
oyilmaz-nvidia May 31, 2024
5eb797c
replaced log with raise
oyilmaz-nvidia Jun 3, 2024
f5be3e4
Merge branch 'onur/reorg_export' of https://github.com/oyilmaz-nvidia…
oyilmaz-nvidia Jun 3, 2024
2198471
add converter and loader folders
oyilmaz-nvidia Jun 4, 2024
db706c9
move nemo_ckpt_convert into the converter folder
oyilmaz-nvidia Jun 4, 2024
04637f5
move nemo_file into loader folder
oyilmaz-nvidia Jun 4, 2024
2716670
reorg converter
oyilmaz-nvidia Jun 4, 2024
1c8a54c
Apply isort and black reformatting
oyilmaz-nvidia Jun 4, 2024
ad08ca8
continue to reorg converter
oyilmaz-nvidia Jun 4, 2024
0402bd6
Merge branch 'onur/reorg_export' of https://github.com/oyilmaz-nvidia…
oyilmaz-nvidia Jun 4, 2024
58fbdcb
Apply isort and black reformatting
oyilmaz-nvidia Jun 4, 2024
70e17b9
continue to reorg
oyilmaz-nvidia Jun 4, 2024
12d4dd2
Merge branch 'onur/reorg_export' of https://github.com/oyilmaz-nvidia…
oyilmaz-nvidia Jun 4, 2024
2050be0
move nemo file back into nemo folder
oyilmaz-nvidia Jun 5, 2024
6f32d0f
renamed nemo folder to nemo_ckpt_loader
oyilmaz-nvidia Jun 5, 2024
b09dccb
remove unused function
oyilmaz-nvidia Jun 5, 2024
753b654
removed nemo file
oyilmaz-nvidia Jun 5, 2024
387e36e
Apply isort and black reformatting
oyilmaz-nvidia Jun 5, 2024
55cb60c
moved a function to tensorrt_llm_run file
oyilmaz-nvidia Jun 5, 2024
3536130
Merge branch 'onur/reorg_export' of https://github.com/oyilmaz-nvidia…
oyilmaz-nvidia Jun 5, 2024
5aabced
Apply isort and black reformatting
oyilmaz-nvidia Jun 5, 2024
410daef
Remove unused imports
oyilmaz-nvidia Jun 5, 2024
e2655a8
Merge branch 'main' into onur/reorg_export
oyilmaz-nvidia Jun 5, 2024
dc9dced
Apply isort and black reformatting
oyilmaz-nvidia Jun 5, 2024
653dab9
import csv added
oyilmaz-nvidia Jun 6, 2024
ad286ff
Merge branch 'main' into onur/reorg_export
oyilmaz-nvidia Jun 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions nemo/export/tensorrt_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@

from nemo.deploy import ITritonDeployable
from nemo.export.tarutils import TarPath, unpack_tarball
from nemo.export.trt_llm.nemo_utils import get_tokenzier, is_nemo_file, nemo_to_trtllm_config
from nemo.export.trt_llm.converter.model_converter import model_to_trtllm_ckpt
from nemo.export.trt_llm.nemo_ckpt_loader.nemo_file import get_tokenzier, is_nemo_file, load_nemo_model
from nemo.export.trt_llm.qnemo import qnemo_to_tensorrt_llm
from nemo.export.trt_llm.qnemo.tokenizer_utils import get_nmt_tokenizer
from nemo.export.trt_llm.tensorrt_llm_build import build_and_save_engine
Expand Down Expand Up @@ -225,15 +226,16 @@ def export(
lora_target_modules=lora_target_modules,
)
else:
weights_dicts, model_configs, self.tokenizer = nemo_to_trtllm_config(
in_file=nemo_checkpoint_path,
model, model_configs, self.tokenizer = load_nemo_model(nemo_checkpoint_path, nemo_export_dir)
weights_dicts, model_configs = model_to_trtllm_ckpt(
model=model,
nemo_model_config=model_configs,
nemo_export_dir=nemo_export_dir,
decoder_type=model_type,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
pipeline_parallel_size=pipeline_parallel_size,
use_parallel_embedding=use_parallel_embedding,
nemo_export_dir=nemo_export_dir,
save_nemo_model_config=save_nemo_model_config,
)

for weight_dict, model_config in zip(weights_dicts, model_configs):
Expand Down
13 changes: 13 additions & 0 deletions nemo/export/trt_llm/converter/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Original file line number Diff line number Diff line change
Expand Up @@ -13,35 +13,19 @@
# limitations under the License.


import argparse
import csv
import datetime
import logging
import os
import sys
from pathlib import Path
from typing import Dict, List, Tuple, Union
from typing import Dict, List, Tuple

import numpy as np
import tensorrt_llm
from tensorrt_llm._utils import pad_vocab_size
from tensorrt_llm.functional import non_gated_version
from tensorrt_llm.layers import MoeConfig
from tensorrt_llm.models.modeling_utils import PretrainedConfig
from transformers import AutoTokenizer, LlamaConfig, PreTrainedTokenizer

from nemo.export.tarutils import TarPath
from nemo.export.trt_llm.nemo.nemo import UnpackedNemoCheckpointDir
from nemo.export.trt_llm.nemo.nemo_ckpt_convert import build_tokenizer, convert_dist_checkpoint


DECODER_MODEL_TYPE = {
"gptj": 'GPTForCausalLM',
"gptnext": 'GPTForCausalLM',
"llama": 'LLaMAForCausalLM',
"gemma": 'GemmaForCausalLM',
"falcon": 'FalconForCausalLM',
}
from nemo.export.trt_llm.converter.model_to_trt_llm_ckpt import convert_model_to_trt_llm_ckpt
from nemo.export.trt_llm.converter.utils import DECODER_MODEL_TYPE, split

LOGGER = logging.getLogger("NeMo")

Expand Down Expand Up @@ -80,181 +64,26 @@ def prompt_convert(prompt_config, prompt_weights):
return vtokens_embeddings


def is_nemo_file(path):
flag = False

if path is not None:
if len(path) > 5:
pc = pathlib.Path(path)
if pc.exists():
if pc.is_file():
if path[-5 : len(path)] == ".nemo":
flag = True

return flag


def split(v, tp_size, idx, dim=0):
"""Splits the np tensor v on dim and return the idx's slice."""
if tp_size == 1:
return v
if len(v.shape) == 1:
return np.ascontiguousarray(np.split(v, tp_size)[idx])
else:
return np.ascontiguousarray(np.split(v, tp_size, axis=dim)[idx])


def _nemo_llm_decode(
in_file: str,
out_dir: str,
tensor_parallelism: int = 1,
processes: int = 1,
storage_type: str = "bfloat16",
load_checkpoints_on_gpu: bool = False,
decoder_type: str = "gptnext",
use_parallel_embedding: bool = False,
save_nemo_model_config: bool = False,
) -> Tuple[Dict[str, np.ndarray], PretrainedConfig, PreTrainedTokenizer]:
"""Decodes the NEMO file and returns the weights dict, llm config and tokenizer."""
args = argparse.Namespace()
args.out_dir = out_dir
args.tensor_parallelism = tensor_parallelism
args.processes = processes
args.storage_type = storage_type
args.load_checkpoints_on_gpu = load_checkpoints_on_gpu
args.verbose = False
args.decoder_type = decoder_type
args.use_parallel_embedding = use_parallel_embedding

if not os.path.exists(in_file):
LOGGER.error("%s does not exist", in_file)
sys.exit(1)

if os.path.isdir(in_file):
nemo_dir = Path(in_file)
else:
nemo_dir = TarPath(in_file)

try:
unpacked_checkpoint_dir = UnpackedNemoCheckpointDir(
nemo_dir, load_checkpoints_to_cpu=not args.load_checkpoints_on_gpu
)

start_time = datetime.datetime.now()
dist_ckpt_folder = nemo_dir / "model_weights"

if dist_ckpt_folder.exists():
weights_dict, llm_config, tokenizer = convert_dist_checkpoint(unpacked_checkpoint_dir, args)
else:
raise Exception(
"Not a supported nemo file format. " "Only distributed mcore nemo checkpoints are support."
)

LOGGER.info("Spent %s (h:m:s) to convert the model", datetime.datetime.now() - start_time)

if save_nemo_model_config:
# Copy the config file without using shutil.copy(...) because input may be a TarPath
with (unpacked_checkpoint_dir._checkpoints_dir / "model_config.yaml").open("rb") as infile:
with open(os.path.join(args.out_dir, "model_config.yaml"), "wb") as outfile:
outfile.write(infile.read())
finally:
if isinstance(nemo_dir, TarPath):
nemo_dir.tarobject.close()

return weights_dict, llm_config, tokenizer


def get_tokenzier(tokenizer_dir_or_path: Path) -> PreTrainedTokenizer:
"""Loads the tokenizer from the decoded NEMO weights dir."""
if os.path.isdir(os.path.join(tokenizer_dir_or_path, "huggingface_tokenizer")):
return AutoTokenizer.from_pretrained(os.path.join(tokenizer_dir_or_path, "huggingface_tokenizer"))

model_path = tokenizer_dir_or_path / "tokenizer.model" if tokenizer_dir_or_path.is_dir() else tokenizer_dir_or_path
tokenizer_config = {"library": "sentencepiece", "model": str(model_path)}
return build_tokenizer(tokenizer_config)


def to_word_list_format(
word_dict: List[List[str]],
tokenizer=None,
ref_str="<extra_id_1>",
):
'''
format of word_dict
len(word_dict) should be same to batch_size
word_dict[i] means the words for batch i
len(word_dict[i]) must be 1, which means it only contains 1 string
This string can contains several sentences and split by ",".
For example, if word_dict[2] = " I am happy, I am sad", then this function will return
the ids for two short sentences " I am happy" and " I am sad".
'''
assert tokenizer is not None, "need to set tokenizer"

flat_ids = []
offsets = []
# The encoding of a single word can't always be trusted. See
# https://github.com/NVIDIA/NeMo/blob/bb575b72fd0be51ae10cc77d9f89ddb9e9d3b96d/nemo/collections/nlp/modules/common/text_generation_strategy.py#L229
ids_ref = tokenizer.encode(ref_str)
for word_dict_item in word_dict:
item_flat_ids = []
item_offsets = []

if isinstance(word_dict_item[0], bytes):
word_dict_item = [word_dict_item[0].decode()]

words = list(csv.reader(word_dict_item))[0]
for word in words:
ids = tokenizer.encode(f"{ref_str}{word}")
if ids[0 : len(ids_ref)] == ids_ref:
# It worked! We can obtain the token(s) associated to `word` by stripping the prefix tokens.
ids = ids[len(ids_ref) :]
else:
# Unfortunately the prefix was merged with `word`. We could try with a different prefix, but
# for now we just use the basic encoding since this should be a very rare edge case.
ids = tokenizer.encode(word)
logging.warning(f"The encoding of word '{word}' into tokens {ids} might be incorrect")

if len(ids) == 0:
continue

item_flat_ids += ids
item_offsets.append(len(ids))

flat_ids.append(np.array(item_flat_ids))
offsets.append(np.cumsum(np.array(item_offsets)))

pad_to = max(1, max(len(ids) for ids in flat_ids))

for i, (ids, offs) in enumerate(zip(flat_ids, offsets)):
flat_ids[i] = np.pad(ids, (0, pad_to - len(ids)), constant_values=0)
offsets[i] = np.pad(offs, (0, pad_to - len(offs)), constant_values=-1)

return np.array([flat_ids, offsets], dtype="int32").transpose((1, 0, 2))


def nemo_to_trtllm_config(
in_file: str,
def model_to_trtllm_ckpt(
model,
nemo_model_config,
nemo_export_dir,
decoder_type: str,
nemo_export_dir: Union[str, Path],
dtype: str = "bfloat16",
tensor_parallel_size: int = 1,
pipeline_parallel_size: int = 1,
use_parallel_embedding: bool = False,
save_nemo_model_config: bool = False,
) -> Tuple[List[Dict], List[PretrainedConfig], PreTrainedTokenizer]:
"""Converts the NEMO file and construct the `PretrainedConfig` before tensorrt_llm deployment."""
dtype_str = dtype

weights_dict, nemo_model_config, tokenizer = _nemo_llm_decode(
in_file=in_file,
out_dir=nemo_export_dir,
tensor_parallelism=tensor_parallel_size,
) -> Tuple[List[Dict], List[PretrainedConfig]]:

weights_dict = convert_model_to_trt_llm_ckpt(
model=model,
nemo_model_config=nemo_model_config,
nemo_export_dir=nemo_export_dir,
inference_tp_size=tensor_parallel_size,
processes=1,
storage_type=dtype_str,
storage_type=dtype,
use_parallel_embedding=use_parallel_embedding,
load_checkpoints_on_gpu=False,
decoder_type=decoder_type,
save_nemo_model_config=save_nemo_model_config,
)

world_size = tensor_parallel_size * pipeline_parallel_size
Expand All @@ -275,7 +104,7 @@ def nemo_to_trtllm_config(

config = {
'architecture': DECODER_MODEL_TYPE[decoder_type],
'dtype': dtype_str,
'dtype': dtype,
'num_hidden_layers': nemo_model_config.get('num_layers'),
'num_attention_heads': nemo_model_config.get('num_attention_heads'),
'num_key_value_heads': nemo_model_config.get('num_query_groups', nemo_model_config['num_attention_heads']),
Expand Down Expand Up @@ -387,4 +216,4 @@ def nemo_to_trtllm_config(
model_configs.append(model_config)
weights_dicts.append(weights_dict_local)

return weights_dicts, model_configs, tokenizer
return weights_dicts, model_configs
Loading
Loading