
Add GQA support to MPT (and GPT) models #205

Closed
wants to merge 5 commits

Conversation

bheilbrun
Contributor

@bheilbrun bheilbrun commented Oct 31, 2023

Why

TensorRT-LLM currently supports MPT models with MHA and MQA, but not GQA. However, there is at least one MPT-based model in the wild that uses GQA (replit-code-v1.5). It's my understanding that others may exist in the future.
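
For context, MHA, MQA, and GQA differ only in how many KV heads the model keeps relative to its query heads. A minimal illustrative sketch (the function name and example head counts are mine, not from this repo):

# Illustrative only: how MHA, MQA, and GQA relate via head counts.
def kv_sharing(n_head: int, n_kv_head: int) -> str:
    assert n_head % n_kv_head == 0, "query heads must group evenly over KV heads"
    if n_kv_head == n_head:
        return "MHA: every query head has its own K/V head"
    if n_kv_head == 1:
        return "MQA: all query heads share a single K/V head"
    group_size = n_head // n_kv_head
    return f"GQA: groups of {group_size} query heads share each of {n_kv_head} K/V heads"

print(kv_sharing(n_head=24, n_kv_head=8))  # GQA (example head counts)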

What

TensorRT-LLM already supports GQA, so the delta in this PR is mostly about plumbing 'num KV heads' through a few layers, including the generic GPT model implementation. As such, GPT models should also support GQA but I didn't deeply test it (beyond the pre-existing unit and e2e tests).

Additionally, this PR improves support for the MPT no_bias option by no longer writing empty bias tensors (in most cases) when the model has no bias.

I also removed the unused examples.mpt.weights.load_from_hf_gpt function. The existing example scripts use only load_from_ft in the same file.

Testing

  • converted and tested replit-code-v1.5 from HuggingFace checkpoints. (commands below)
  • converted and tested mosaicml/mpt-7b with --world_size set to 1 and 2.
  • locally ran all pre-existing tests under testing/

I'm not sure how much we need to maintain backwards compatibility with existing FasterTransformer configs or implementations, so let me know if you see any problems in this area.

Similarly, if there are any other models I should test, let me know.

cd examples/mpt
python convert_hf_mpt_to_ft.py \
    -i /models/replit-v1-5/hf_ckpt \
    -o /models/replit-v1-5/ft_ckpts \
    -t bfloat16

python3 build.py \
    --model_dir /models/replit-v1-5/ft_ckpts/1-gpu \
    --output_dir /models/replit-v1-5/trt_engines/1-gpu \
    --max_batch_size 16 \
    --max_input_len 2048 \
    --max_output_len 256 \
    --use_gpt_attention_plugin \
    --use_gemm_plugin \
    --use_inflight_batching \
    --remove_input_padding \
    --paged_kv_cache \
    --enable_context_fmha \
    --parallel_build

python run.py \
    --engine_dir /models/replit-v1-5/trt_engines/1-gpu \
    --tokenizer /models/replit-v1-5/hf_ckpt \
    --input_text "def fibonacci" \
    --max_output_len 64

      n_embd // tensor_parallel +
-     (n_embd // n_head) * 2)
+     (head_dim * n_kv_head * 2) // tensor_parallel)
Contributor Author

This might be wrong for the MQA case. I'll need to find a model to verify this.
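
For readers following the diff, here is a rough paraphrase of the per-rank fused QKV width being computed above (my own sketch, not the PR's exact code; as the comment above notes, the MQA case where n_kv_head == 1 may need separate handling, since a single KV head cannot be split across ranks):

# Sketch of the fused QKV weight width per tensor-parallel rank (illustrative).
def qkv_width_per_rank(n_embd: int, n_head: int, n_kv_head: int, tensor_parallel: int) -> int:
    head_dim = n_embd // n_head
    q_width = n_embd // tensor_parallel                       # Q heads are split across ranks
    kv_width = (head_dim * n_kv_head * 2) // tensor_parallel  # K and V heads are split across ranks
    return q_width + kv_width

print(qkv_width_per_rank(n_embd=3072, n_head=24, n_kv_head=8, tensor_parallel=1))  # 3072 + 2048 = 5120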

@dskhudia

Looks great :-)

Contributor

@megha95 megha95 left a comment


LGTM, just one comment: there are a few restrictions on the number of KV heads and the TP size for GQA/MQA; n_kv_heads must be divisible by tp_size, and num_heads must be divisible by n_kv_heads. I'd suggest we add an assertion to ensure this is satisfied.
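
A minimal sketch of the checks being suggested, with illustrative parameter names rather than the PR's actual code:

# Illustrative GQA/MQA head-count constraints.
def check_gqa_config(num_heads: int, num_kv_heads: int, tp_size: int) -> None:
    assert num_heads % num_kv_heads == 0, (
        f"num_heads ({num_heads}) must be divisible by num_kv_heads ({num_kv_heads})")
    assert num_kv_heads % tp_size == 0, (
        f"num_kv_heads ({num_kv_heads}) must be divisible by tp_size ({tp_size})")

check_gqa_config(num_heads=24, num_kv_heads=8, tp_size=2)  # passes; tp_size=3 would fail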

@megha95
Contributor

megha95 commented Nov 1, 2023

Also, I'll remove the dependency on FT conversion for MPT models in a new PR if it helps. Basically directly convert from HF.

@jdemouth-nvidia
Collaborator

Hi @bheilbrun ,

Thanks a lot for the pull request. Can you rebase the PR against the main branch, please? We are not going to do updates to the release/0.5.0 branch.

@nv-guomingz , can you take a look at this PR, please?

Thanks,
Julien

@nv-guomingz nv-guomingz self-assigned this Nov 3, 2023
@nv-guomingz
Collaborator

@nv-guomingz , can you take a look at this PR, please?

Sure, I'll take a look at this PR today.

@nv-guomingz
Collaborator

nv-guomingz commented Nov 9, 2023

Hi @bheilbrun , could you please give me full step-by-step instructions for building an engine with the replit-code-v1.5 model?

I managed to convert the weights with the command below:

python convert_hf_mpt_to_ft.py -i ./replit-code-v1_5-3b  -o ./ft_ckpts/replit/bf16-gqa/ -t bfloat16

However, I got a failure when I tried to build the engine with the command below:

python3 build.py --model_dir=./ft_ckpts/replit/bf16-gqa/1-gpu \
    --max_batch_size 64 \
    --use_gpt_attention_plugin \
    --use_gemm_plugin \
    --output_dir ./trt_engines/replit/bf16/1-gpu

Error msg

[11/09/2023-11:31:46] [TRT-LLM] [I] Loading weights from FT...
[11/09/2023-11:31:46] [TRT-LLM] [I] Loading weights from FT...
Traceback (most recent call last):
  File "/home/proj/mpt-gqa/examples/mpt/build.py", line 612, in <module>
    run_build()
  File "/home/proj/mpt-gqa/examples/mpt/build.py", line 604, in run_build
    build(0, args)
  File "/home/proj/mpt-gqa/examples/mpt/build.py", line 570, in build
    engine = build_rank_engine(builder, builder_config, engine_name,
  File "/home/proj/mpt-gqa/examples/mpt/build.py", line 444, in build_rank_engine
    load_from_ft(tensorrt_llm_gpt,
  File "/home/scratch.work_sw_1/proj/mpt-gqa/examples/mpt/weight.py", line 205, in load_from_ft
    tensorrt_llm_gpt.ln_f.bias.value = (fromfile(
  File "/home/proj/mpt-gqa/tensorrt_llm/parameter.py", line 79, in value
    assert isinstance(v, np.ndarray)

@nv-guomingz
Collaborator

Another issue: this PR breaks the original MPT weight conversion support.
Could you please update the PR to fix the regression?

@nv-guomingz nv-guomingz added the feature request label Nov 9, 2023
@bheilbrun
Contributor Author

However, I got a failure when I tried to build the engine with below cmd
...
tensorrt_llm_gpt.ln_f.bias.value = (fromfile(

Ohhh, I think I know what happened. In this PR, I tried to finish support for bias=False but I probably missed a code path. This didn't fail on my side because my ft_ckpt directory still had bias files left over from a previous build. So, my manual tests were non-hermetic, apologies!

I'll see if this has a quick fix or if I should back out the bias=False part of the PR.
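
For illustration only, one possible shape of the missing guard, inferred from the traceback above rather than from the actual fix (fromfile returning None for a missing file, and the bias file name, are assumptions; the real fix could instead keep writing layernorm bias files at conversion time):

# Hypothetical fragment inside load_from_ft, not the PR's actual change:
# skip assigning the final layernorm bias when the checkpoint has no bias tensors.
ln_f_bias = fromfile(dir_path, 'model.final_layernorm.bias.bin')  # assumed to return None if the file is absent
if ln_f_bias is not None:
    tensorrt_llm_gpt.ln_f.bias.value = ln_f_bias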

@nv-guomingz
Collaborator

Hi @bheilbrun, it seems you pushed a new commit and it fixed the original MPT model building issue.

However, I still hit an issue when I tried to verify the replit-v1.5 model.

Specifically, when I tried to build the engine with the command below

python3 build.py --model_dir=./ft_ckpts/replit/bf16-gqa/1-gpu \
    --max_batch_size 64 \
    --use_gpt_attention_plugin \
    --use_gemm_plugin \
    --output_dir ./trt_engines/replit/bf16/1-gpu --n_kv_head 8

I get an error message like

TypeError: GPTLMHeadModel.__init__() got an unexpected keyword argument 'num_kv_heads'

I think the root cause is that we may need to apply a similar change here.

Could you please take a look at this issue?

BTW, would you please provide the full commands to reproduce your local results, in case we're using your PR differently?

@bheilbrun
Contributor Author

bheilbrun commented Nov 13, 2023

Heya @nv-guomingz, thanks again for looking.

I added my test commands to the PR description. Hope that helps. I also tested mpt-7b with 1 and 2 GPUs. The latter required a small fix to mpt/run.py, assuming I diagnosed it right.

TypeError: GPTLMHeadModel.__init__() got an unexpected keyword argument 'num_kv_heads'

This error surprises me because I added that kwarg in this PR here, https://github.com/NVIDIA/TensorRT-LLM/pull/205/files#diff-1767dd0367b35551b6031983a93a636d50efca440e69bbdc17f8e0ac3d147151R341 .

Could you double check your local checkout of tensorrt_llm/models/gpt/model.py:341 ? Maybe it became unhappy when I rebased onto main.

Thanks for testing.

@nv-guomingz
Collaborator

Hi @bheilbrun, thanks for updating; the issue is gone with a clean build.

I've verified correctness for both the tp1 and tp2 cases on H100, A100, and L40S platforms.

We're going to merge your PR into the internal repo first and credit your great work in the next weekly release if everything goes well.

Thanks,
Guoming

@bheilbrun
Contributor Author

@nv-guomingz great news, appreciate the help!

@bheilbrun
Contributor Author

Also, I'll remove the dependency on FT conversion for MPT models in a new PR if it helps. Basically directly convert from HF.

@megha95 that'd be a great improvement. Hopping through the "old" FasterTransformer format is definitely a pain. It's working now but is also a maintenance headache. Let me know if I can help.

@nv-guomingz
Collaborator

Hi @bheilbrun, I saw you've updated the commit to 78b1b03. Checking the git history, I guess you wanted to update this branch with the latest main code.

I think it's not necessary if there are no feature changes, since we've successfully rebased 416eee2 onto the internal main branch 😄

Thanks,
Guoming

@bheilbrun
Contributor Author

Thanks! Out of convenience, I was using this branch to share code between a few different machines. :) I'll do this on a different branch if I need to update again, to avoid the notification noise for y'all.

@@ -90,10 +90,6 @@ def convert_weight_to_ft_each(out_dir: str, tensor_parallelism: int,
     for j in range(tensor_parallelism):
         save_path = os.path.join(out_dir, f'model.{tensor_name}.{j}.bin')
         split_vals[j].tofile(save_path)
-    if config['no_bias']:
Collaborator

Hi @bheilbrun, may I know why we need to remove lines 93 to 96?

Contributor Author

This is related to the no_bias change I mentioned in the PR description. I translate MPT's no_bias=True option to GPT's bias=False. When this is set, GPT doesn't load bias tensors for many layers.

However, there is one implementation difference between MPT and GPT. MPT has no bias for all layers. GPT by contrast still expects biases for layernorm layers, based on my reading and experimentation.
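
For illustration, here is a rough sketch of the conversion-side rule described here, under the reading that layernorm biases are the one exception; the function and flag names are made up for the example:

from typing import Optional
import numpy as np

# Illustrative sketch, not the PR's actual code: with MPT's no_bias=True, skip writing
# bias tensors except for layernorm layers, which the GPT implementation still expects
# (a zero bias is mathematically equivalent to no bias for layernorm).
def maybe_write_bias(tensor_name: str, bias: Optional[np.ndarray], out_path: str,
                     no_bias: bool, hidden_size: int) -> None:
    is_layernorm = 'layernorm' in tensor_name
    if no_bias and not is_layernorm:
        return  # GPT with bias=False will not look for these files
    if bias is None:
        bias = np.zeros(hidden_size, dtype=np.float32)
    bias.astype(np.float32).tofile(out_path)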

Hope that clears it up and that it's not causing problems.

@kaiyux kaiyux mentioned this pull request Nov 17, 2023
@kaiyux
Member

kaiyux commented Nov 18, 2023

Hi @bheilbrun , we pushed an update to the main branch, and we added you as co-author, which is also mentioned in the announcement.

We're going to close this PR, please let us know if you have any questions. Thanks again for the great contribution.

@kaiyux kaiyux closed this Nov 18, 2023