-
-
Notifications
You must be signed in to change notification settings - Fork 4.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Model] Support Nemotron models (Nemotron-3, Nemotron-4, Minitron) #6611
Conversation
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge). To run full CI, you can do one of these:
🚀 |
Able to load Nemotron-4-340B-Instruct (had to make a lot of edits to this checkpoint, will upload) with cpu offloading: >>> from vllm import LLM, SamplingParams
>>> model = LLM("/home/mgoin/models/Nemotron-4-340B-Instruct", tensor_parallel_size=8, distributed_executor_backend="ray", cpu_offload_gb=20, enforce_eager=True)
2024-07-21 02:32:18,002 INFO worker.py:1749 -- Started a local Ray instance.
INFO 07-21 02:32:19 llm_engine.py:175] Initializing an LLM engine (v0.5.2) with config: model='/home/mgoin/models/Nemotron-4-340B-Instruct', speculative_config=None, tokenizer='/home/mgoin/models/Nemotron-4-340B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=4096, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=8, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None), seed=0, served_model_name=/home/mgoin/models/Nemotron-4-340B-Instruct, use_v2_block_manager=False, enable_prefix_caching=False)
INFO 07-21 02:32:36 utils.py:779] Found nccl from library libnccl.so.2
INFO 07-21 02:32:36 pynccl.py:63] vLLM is using nccl==2.20.5
(RayWorkerWrapper pid=1139686) INFO 07-21 02:32:36 utils.py:779] Found nccl from library libnccl.so.2
(RayWorkerWrapper pid=1139686) INFO 07-21 02:32:36 pynccl.py:63] vLLM is using nccl==2.20.5
INFO 07-21 02:32:39 custom_all_reduce_utils.py:232] reading GPU P2P access cache from /home/mgoin/.cache/vllm/gpu_p2p_access_cache_for_0,1,2,3,4,5,6,7.json
INFO 07-21 02:32:39 shm_broadcast.py:240] vLLM message queue communication handle: Handle(connect_ip='127.0.0.1', local_reader_ranks=[1, 2, 3, 4, 5, 6, 7], buffer=<vllm.distributed.device_communicators.shm_broadcast.ShmRingBuffer object at 0x7fda1b3b33d0>, local_subscribe_port=55417, local_sync_port=53103, remote_subscribe_port=None, remote_sync_port=None)
(RayWorkerWrapper pid=1139686) INFO 07-21 02:32:39 custom_all_reduce_utils.py:232] reading GPU P2P access cache from /home/mgoin/.cache/vllm/gpu_p2p_access_cache_for_0,1,2,3,4,5,6,7.json
INFO 07-21 02:33:14 model_runner.py:563] Loading model weights took 59.3774 GB
(RayWorkerWrapper pid=1140131) INFO 07-21 02:33:34 model_runner.py:563] Loading model weights took 59.3774 GB
(RayWorkerWrapper pid=1140410) INFO 07-21 02:32:36 utils.py:779] Found nccl from library libnccl.so.2 [repeated 6x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)
(RayWorkerWrapper pid=1140410) INFO 07-21 02:32:36 pynccl.py:63] vLLM is using nccl==2.20.5 [repeated 6x across cluster]
(RayWorkerWrapper pid=1140410) INFO 07-21 02:32:39 custom_all_reduce_utils.py:232] reading GPU P2P access cache from /home/mgoin/.cache/vllm/gpu_p2p_access_cache_for_0,1,2,3,4,5,6,7.json [repeated 6x across cluster]
INFO 07-21 02:33:42 distributed_gpu_executor.py:56] # GPU blocks: 4867, # CPU blocks: 3640
>>> model.generate("Hello!")
Processed prompts: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:16<00:00, 16.29s/it, est. speed input: 0.12 toks/s, output: 0.98 toks/s]
[RequestOutput(request_id=0, prompt='Hello!', prompt_token_ids=[14716, 252397], prompt_logprobs=None, outputs=[CompletionOutput(index=0, text='8493045638869919', token_ids=(252377, 252372, 252370, 252365, 252334, 252372, 252366, 252376, 252365, 252377, 252377, 252376, 252370, 252370, 252338, 252370), cumulative_logprob=-38.362853050231934, logprobs=None, finish_reason=length, stop_reason=None)], finished=True, metrics=RequestMetrics(arrival_time=1721529255.4465594, last_token_time=1721529255.4465594, first_scheduled_time=1721529255.4538574, first_token_time=1721529256.6029613, time_in_queue=0.007297992706298828, finished_time=1721529271.7377276), lora_request=None)] |
I made an FP8 W8 quantized checkpoint based on the above and it produces the same output. The tokenizer is not right for this model so that is the next step. >>> from vllm import LLM, SamplingParams
>>> model = LLM("/home/mgoin/models/Nemotron-4-340B-Instruct-FP8", tensor_parallel_size=8, distributed_executor_backend="ray", enforce_eager=True)
2024-07-21 03:09:56,222 INFO worker.py:1749 -- Started a local Ray instance.
INFO 07-21 03:09:57 llm_engine.py:175] Initializing an LLM engine (v0.5.2) with config: model='/home/mgoin/models/Nemotron-4-340B-Instruct-FP8', speculative_config=None, tokenizer='/home/mgoin/models/Nemotron-4-340B-Instruct-FP8', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=4096, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=8, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=fp8, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None), seed=0, served_model_name=/home/mgoin/models/Nemotron-4-340B-Instruct-FP8, use_v2_block_manager=False, enable_prefix_caching=False)
INFO 07-21 03:10:14 utils.py:779] Found nccl from library libnccl.so.2
INFO 07-21 03:10:14 pynccl.py:63] vLLM is using nccl==2.20.5
(RayWorkerWrapper pid=1233230) INFO 07-21 03:10:14 utils.py:779] Found nccl from library libnccl.so.2
(RayWorkerWrapper pid=1233230) INFO 07-21 03:10:14 pynccl.py:63] vLLM is using nccl==2.20.5
INFO 07-21 03:10:17 custom_all_reduce_utils.py:232] reading GPU P2P access cache from /home/mgoin/.cache/vllm/gpu_p2p_access_cache_for_0,1,2,3,4,5,6,7.json
INFO 07-21 03:10:17 shm_broadcast.py:240] vLLM message queue communication handle: Handle(connect_ip='127.0.0.1', local_reader_ranks=[1, 2, 3, 4, 5, 6, 7], buffer=<vllm.distributed.device_communicators.shm_broadcast.ShmRingBuffer object at 0x7fdf89ac73d0>, local_subscribe_port=56965, local_sync_port=40185, remote_subscribe_port=None, remote_sync_port=None)
(RayWorkerWrapper pid=1233230) INFO 07-21 03:10:17 custom_all_reduce_utils.py:232] reading GPU P2P access cache from /home/mgoin/.cache/vllm/gpu_p2p_access_cache_for_0,1,2,3,4,5,6,7.json
WARNING 07-21 03:10:17 fp8.py:38] Detected fp8 checkpoint. Please note that the format is experimental and subject to change.
(RayWorkerWrapper pid=1233230) WARNING 07-21 03:10:17 fp8.py:38] Detected fp8 checkpoint. Please note that the format is experimental and subject to change.
WARNING 07-21 03:10:35 utils.py:569] Your GPU does not have native support for FP8 computation but FP8 quantization is being used. Weight-only FP8 compression will be used leveraging the Marlin kernel. This may degrade performance for compute-heavy workloads.
INFO 07-21 03:10:36 model_runner.py:563] Loading model weights took 40.8975 GB
(RayWorkerWrapper pid=1233718) WARNING 07-21 03:10:45 utils.py:569] Your GPU does not have native support for FP8 computation but FP8 quantization is being used. Weight-only FP8 compression will be used leveraging the Marlin kernel. This may degrade performance for compute-heavy workloads.
(RayWorkerWrapper pid=1234002) INFO 07-21 03:10:14 utils.py:779] Found nccl from library libnccl.so.2 [repeated 6x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)
(RayWorkerWrapper pid=1234002) INFO 07-21 03:10:14 pynccl.py:63] vLLM is using nccl==2.20.5 [repeated 6x across cluster]
(RayWorkerWrapper pid=1234002) INFO 07-21 03:10:17 custom_all_reduce_utils.py:232] reading GPU P2P access cache from /home/mgoin/.cache/vllm/gpu_p2p_access_cache_for_0,1,2,3,4,5,6,7.json [repeated 6x across cluster]
(RayWorkerWrapper pid=1234002) WARNING 07-21 03:10:17 fp8.py:38] Detected fp8 checkpoint. Please note that the format is experimental and subject to change. [repeated 6x across cluster]
(RayWorkerWrapper pid=1233718) INFO 07-21 03:10:46 model_runner.py:563] Loading model weights took 40.8975 GB
INFO 07-21 03:10:54 distributed_gpu_executor.py:56] # GPU blocks: 22421, # CPU blocks: 3640
>>> model.generate("Hello!")
Processed prompts: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00, 1.19s/it, est. speed input: 1.69 toks/s, output: 13.49 toks/s]
[RequestOutput(request_id=0, prompt='Hello!', prompt_token_ids=[14716, 252397], prompt_logprobs=None, outputs=[CompletionOutput(index=0, text='8493045638869919', token_ids=(252377, 252372, 252370, 252365, 252334, 252372, 252366, 252376, 252365, 252377, 252377, 252376, 252370, 252370, 252338, 252370), cumulative_logprob=-38.44283938407898, logprobs=None, finish_reason=length, stop_reason=None)], finished=True, metrics=RequestMetrics(arrival_time=1721531506.653993, last_token_time=1721531506.653993, first_scheduled_time=1721531506.6587515, first_token_time=1721531506.8336384, time_in_queue=0.004758596420288086, finished_time=1721531507.8438315), lora_request=None)] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just curious: how is the architecture different from Llama?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm fairly certain it could be implemented inside our Llama implementation. I'm not sure how to deal with the absence of gate_proj though.
The key differences are:
- There is no gate_proj, just up_proj
- Normal LayerNorm (with a +1 to the weights) instead of RMSNorm
- Squared ReLU instead of SwiGLU
- Adds a rotary_percentage to RoPE
This is a good overview of main changes: https://twitter.com/danielhanchen/status/1801671106266599770
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Update documentation as well?
class ReLUSquaredActivation(CustomOp): | ||
""" | ||
Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2 | ||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you might also need to add this op to CPU similar to what I did here previously with QuickGELU
.
Per offline discussion, no CPU op should need to be added since it's just calling torch API. @bigPYJ1151 It would be great if you can confirm that, thanks!
This looks great! Just wondering has anyone tried serving the bf16 version on 2 8xA100 nodes, with ray? |
…llm-project#6611) Signed-off-by: Alvant <alvasian@yandex.ru>
FIX #5722
Based off huggingface/transformers#31699 - Nemotron-3 loads and produces reasonable output. Nemotron-4 and the most recently released Minitron works and evals can be reproduced.
For CI, a Minitron-4B-Base GSM8k eval has been added to the lm-eval test suite.
The architecture is pretty similar to Llama, with these changes:
gate_proj
, justup_proj
rotary_percent
to RoPECollection of checkpoints (Nemotron-3, Nemotron-4, Minitron): https://huggingface.co/collections/mgoin/nemotron-in-vllm-66a151b4240bcd9c28735ec5
Loading
nvidia/Minitron-4B-Base
:Loading
nemotron3-8b-base
:Loading
mgoin/Nemotron-4-340B-Instruct-FP8-Dynamic
on 8xA100: