
[Bugfix] V1 Memory Profiling: V0 Sampler Integration without Rejection Sampler #13594

Merged: 11 commits into vllm-project:main on Feb 22, 2025

Conversation

@JenZhao (Contributor) commented Feb 20, 2025

This PR integrates the sampling logic from the V0 version of profile_run into the V1 memory profiling tool. Note that the rejection sampler is not included.

Tested on an H100 GPU with CUDA 12.4, running VLLM_USE_V1=1 python3 examples/offline_inference/basic/basic.py.

Before the change:

INFO 02-21 06:27:03 kv_cache_utils.py:524] GPU KV cache size: 2,039,760 tokens
INFO 02-21 06:27:03 kv_cache_utils.py:527] Maximum concurrency for 2,048 tokens per request: 995.98x

After the change, with no_penalties=True and all other sampling parameters set:

                    dummy_tensors = lambda v: torch.full(
                    (num_reqs, ), v, device=self.device)

                    no_penalties=True,
                    prompt_token_ids=torch.ones_like(logits, dtype=torch.int64),
                    frequency_penalties=dummy_tensors(0.1),
                    presence_penalties=dummy_tensors(0.1),
                    repetition_penalties=dummy_tensors(0.1),
INFO 02-22 04:09:52 kv_cache_utils.py:524] GPU KV cache size: 1,980,832 tokens
INFO 02-22 04:09:52 kv_cache_utils.py:527] Maximum concurrency for 2,048 tokens per request: 967.20x

After the change, with no_penalties=False and all other sampling parameters set:

                    dummy_tensors = lambda v: torch.full(
                    (num_reqs, ), v, device=self.device)

                    no_penalties=False,
                    prompt_token_ids=torch.ones_like(logits, dtype=torch.long),
                    frequency_penalties=dummy_tensors(0.1),
                    presence_penalties=dummy_tensors(0.1),
                    repetition_penalties=dummy_tensors(0.1),
INFO 02-21 06:36:47 kv_cache_utils.py:524] GPU KV cache size: 1,954,272 tokens
INFO 02-21 06:36:47 kv_cache_utils.py:527] Maximum concurrency for 2,048 tokens per request: 954.23x
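
The concurrency figures in these logs are simply the KV cache size divided by the per-request token budget (2,048 tokens here):

    # Sanity check on the "Maximum concurrency" lines above (plain division):
    print(2_039_760 / 2_048)   # ~995.98x, before the change
    print(1_980_832 / 2_048)   # ~967.20x, no_penalties=True
    print(1_954_272 / 2_048)   # ~954.23x, no_penalties=False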

Tested on an H100 GPU with CUDA 12.4, running VLLM_USE_V1=1 python3 benchmarks/benchmark_throughput.py --model NousResearch/Hermes-3-Llama-3.1-8B --dataset /home/jovyan/vllm/ShareGPT_V3_unfiltered_cleaned_split.json.

No sampler (original main code):

INFO 02-22 02:49:21 monitor.py:33] torch.compile takes 7.43 s in total
INFO 02-22 02:49:22 kv_cache_utils.py:524] GPU KV cache size: 438,384 tokens
INFO 02-22 02:49:22 kv_cache_utils.py:527] Maximum concurrency for 131,072 tokens per request: 3.34x
INFO 02-22 02:49:53 gpu_model_runner.py:1339] Graph capturing finished in 31 secs, took 0.64 GiB
INFO 02-22 02:49:53 core.py:116] init engine (profile, create kv cache, warmup model) took 45.61 seconds
Processed prompts: 100%|███████████████████████| 1000/1000 [00:19<00:00, 51.05it/s, est. speed input: 10985.47 toks/s, output: 10125.14 toks/s]
Throughput: 48.45 requests/s, 20035.76 total tokens/s, 9609.62 output tokens/s

Sampler enabled, no penalties, prompt token IDs set to None:

            if get_pp_group().is_last_rank:
                hidden_states = hidden_states[logit_indices]
                logits = self.model.compute_logits(hidden_states, None)
                dummy_tensors = lambda v: torch.full(
                    (num_reqs, ), v, device=self.device)
                dummy_metadata = SamplingMetadata(
                    temperature=dummy_tensors(0.5),
                    all_greedy=False,
                    all_random=False,
                    spec_token_ids=None,
                    top_p=dummy_tensors(0.9),
                    top_k=dummy_tensors(logits.size(1) - 1),
                    min_p=None,
                    generators={},
                    max_num_logprobs=None,
                    no_penalties=True,
                    prompt_token_ids=None,
                    frequency_penalties=None,
                    presence_penalties=None,
                    repetition_penalties=None,
                    output_token_ids=[[] for _ in range(num_reqs)],
                    min_tokens={},
                    logit_bias=[None for _ in range(num_reqs)])
                sampler_output = self.model.sample(
                    logits=logits, sampling_metadata=dummy_metadata)
            else:
                logits = None
                sampler_output = None
                dummy_metadata = None
            torch.cuda.synchronize()
            del hidden_states, logits, sampler_output, dummy_metadata
INFO 02-22 02:46:20 monitor.py:33] torch.compile takes 7.57 s in total
INFO 02-22 02:46:21 kv_cache_utils.py:524] GPU KV cache size: 415,072 tokens
INFO 02-22 02:46:21 kv_cache_utils.py:527] Maximum concurrency for 131,072 tokens per request: 3.17x
INFO 02-22 02:46:51 gpu_model_runner.py:1363] Graph capturing finished in 30 secs, took 0.64 GiB
INFO 02-22 02:46:51 core.py:116] init engine (profile, create kv cache, warmup model) took 45.31 seconds
Processed prompts: 100%|███████████████████████| 1000/1000 [00:19<00:00, 50.95it/s, est. speed input: 10965.14 toks/s, output: 10106.39 toks/s]
Throughput: 48.45 requests/s, 20037.99 total tokens/s, 9610.69 output tokens/s

Sampler enabled, no penalties, prompt token IDs provided:

            if get_pp_group().is_last_rank:
                hidden_states = hidden_states[logit_indices]
                logits = self.model.compute_logits(hidden_states, None)
                dummy_tensors = lambda v: torch.full(
                    (num_reqs, ), v, device=self.device)
                dummy_metadata = SamplingMetadata(
                    temperature=dummy_tensors(0.5),
                    all_greedy=False,
                    all_random=False,
                    spec_token_ids=None,
                    top_p=dummy_tensors(0.9),
                    top_k=dummy_tensors(logits.size(1) - 1),
                    min_p=None,
                    generators={},
                    max_num_logprobs=None,
                    no_penalties=True,
                    prompt_token_ids=torch.ones_like(logits, dtype=torch.long),
                    frequency_penalties=None,
                    presence_penalties=None,
                    repetition_penalties=None,
                    output_token_ids=[[] for _ in range(num_reqs)],
                    min_tokens={},
                    logit_bias=[None for _ in range(num_reqs)])
                sampler_output = self.model.sample(
                    logits=logits, sampling_metadata=dummy_metadata)
            else:
                logits = None
                sampler_output = None
                dummy_metadata = None
            torch.cuda.synchronize()
            del hidden_states, logits, sampler_output, dummy_metadata
INFO 02-22 02:40:52 monitor.py:33] torch.compile takes 7.38 s in total
INFO 02-22 02:40:52 kv_cache_utils.py:524] GPU KV cache size: 407,056 tokens
INFO 02-22 02:40:52 kv_cache_utils.py:527] Maximum concurrency for 131,072 tokens per request: 3.11x
INFO 02-22 02:41:14 gpu_model_runner.py:1363] Graph capturing finished in 22 secs, took 0.64 GiB
INFO 02-22 02:41:14 core.py:116] init engine (profile, create kv cache, warmup model) took 36.18 seconds
Processed prompts: 100%|███████████████████████| 1000/1000 [00:19<00:00, 51.26it/s, est. speed input: 11030.14 toks/s, output: 10166.30 toks/s]
Throughput: 48.69 requests/s, 20135.43 total tokens/s, 9657.43 output tokens/s

Sampler enabled, penalties applied, prompt token IDs provided:

            if get_pp_group().is_last_rank:
                hidden_states = hidden_states[logit_indices]
                logits = self.model.compute_logits(hidden_states, None)
                dummy_tensors = lambda v: torch.full(
                    (num_reqs, ), v, device=self.device)
                dummy_metadata = SamplingMetadata(
                    temperature=dummy_tensors(0.5),
                    all_greedy=False,
                    all_random=False,
                    spec_token_ids=None,
                    top_p=dummy_tensors(0.9),
                    top_k=dummy_tensors(logits.size(1) - 1),
                    min_p=None,
                    generators={},
                    max_num_logprobs=None,
                    no_penalties=False,
                    prompt_token_ids=torch.ones_like(logits, dtype=torch.long),
                    frequency_penalties=dummy_tensors(0.1),
                    presence_penalties=dummy_tensors(0.1),
                    repetition_penalties=dummy_tensors(0.1),
                    output_token_ids=[[] for _ in range(num_reqs)],
                    min_tokens={},
                    logit_bias=[None for _ in range(num_reqs)])
                sampler_output = self.model.sample(
                    logits=logits, sampling_metadata=dummy_metadata)
            else:
                logits = None
                sampler_output = None
                dummy_metadata = None
            torch.cuda.synchronize()
            del hidden_states, logits, sampler_output, dummy_metadata
INFO 02-22 02:37:14 monitor.py:33] torch.compile takes 7.50 s in total
INFO 02-22 02:37:15 kv_cache_utils.py:524] GPU KV cache size: 394,640 tokens
INFO 02-22 02:37:15 kv_cache_utils.py:527] Maximum concurrency for 131,072 tokens per request: 3.01x
INFO 02-22 02:37:45 gpu_model_runner.py:1363] Graph capturing finished in 30 secs, took 0.64 GiB
INFO 02-22 02:37:45 core.py:116] init engine (profile, create kv cache, warmup model) took 45.18 seconds
Processed prompts: 100%|███████████████████████| 1000/1000 [00:19<00:00, 50.94it/s, est. speed input: 10962.21 toks/s, output: 10103.69 toks/s]
Throughput: 48.44 requests/s, 20031.49 total tokens/s, 9607.57 output tokens/s
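
A rough conversion of these token-count differences into GiB, assuming Llama-3.1-8B geometry (32 layers, 8 KV heads, head dim 128) and a bfloat16 KV cache (neither is printed in the logs above), shows how much KV cache memory the sampler profiling gives up in each scenario:

    # bytes per cached token: layers * (K and V) * kv_heads * head_dim * 2 bytes (bf16)
    bytes_per_token = 32 * 2 * 8 * 128 * 2

    baseline = 438_384  # tokens, main without the sampler
    for label, tokens in [("sampler, no penalties, prompt ids None", 415_072),
                          ("sampler, no penalties, prompt ids set", 407_056),
                          ("sampler, with penalties", 394_640)]:
        delta_gib = (baseline - tokens) * bytes_per_token / 2**30
        print(f"{label}: about {delta_gib:.1f} GiB less KV cache than main")
    # roughly 2.8, 3.8, and 5.3 GiB respectively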

FIX #13507

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which starts with a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label Feb 20, 2025
@ywang96 (Member) left a comment

Thanks for the contribution! It seems that this PR adds the sampling logic from V0's profile_run into V1, but the rejection sampler is not considered here (which is fine for this PR, imo).

Which script did you run to get the # blocks difference in the PR description?

@JenZhao JenZhao changed the title [Bugfix] Add sampler in memory profiling [Bugfix] V1 Memory Profiling: V0 Sampler Integration without Rejection Sampler Feb 20, 2025
@ywang96 (Member) commented Feb 20, 2025

cc @njhill In case you want to review this PR too

Comment on lines 1324 to 1326
frequency_penalties=penalties,
presence_penalties=penalties,
repetition_penalties=penalties,
@ywang96 (Member) commented Feb 20, 2025

Let's update these since we shouldn't be using the same tensor object for all three of them.
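
To illustrate the concern with a hypothetical snippet (not code from the diff): binding all three fields to one tensor means a single allocation instead of three, so profiling under-counts memory, and an in-place update to one penalty would silently change the others.

    import torch

    penalties = torch.full((4,), 0.1)
    freq = pres = rep = penalties                          # one storage shared by three fields
    separate = [torch.full((4,), 0.1) for _ in range(3)]   # three distinct tensors, as in the fix

    print(freq.data_ptr() == rep.data_ptr())                  # True  -> same memory
    print(separate[0].data_ptr() == separate[1].data_ptr())   # False -> independent memory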

JenZhao and others added 2 commits February 20, 2025 11:59
@njhill (Member) commented Feb 21, 2025

Thanks @JenZhao, this looks good.

My only hesitation is that the penalties sampler implementation still needs further optimization, and including penalties here may mean more memory ends up being reserved than necessary for a typical case, harming throughput. For a max batch size of 1k, this could be an overhead of more than 2GB just for penalties sampling.

I'm not sure there's a good answer for this, though, since in theory you could hit OOM in some cases otherwise. It may be better to merge this and then address the inefficiency as soon as possible afterwards. Also cc @WoosukKwon for thoughts.
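
A back-of-envelope estimate of where a figure like that comes from, assuming a Llama-3.1-sized vocabulary of about 128k and int64 token IDs (neither stated in the comment above):

    # The penalties path materializes prompt_token_ids as a (num_reqs, vocab_size)
    # int64 tensor, plus intermediates of similar size while applying penalties.
    num_reqs = 1024          # max batch size used in the estimate
    vocab_size = 128_256     # assumed Llama-3.1-style vocabulary
    int64_bytes = 8

    per_tensor_gib = num_reqs * vocab_size * int64_bytes / 2**30
    print(f"{per_tensor_gib:.2f} GiB per (num_reqs, vocab_size) int64 tensor")
    # ~0.98 GiB each, so two or three such tensors already exceeds 2GB.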

@mergify mergify bot added the documentation, ci/build, frontend, and structured-output labels Feb 21, 2025
@JenZhao (Contributor, Author) commented Feb 21, 2025

Oh no, I did an incorrect rebase; fixing it now.

JenZhao and others added 3 commits February 21, 2025 07:03
@ywang96 (Member) left a comment

@JenZhao Thanks again for your contribution. I'm giving this PR a green light, but please take a look at my comment!

min_p=None,
generators={},
max_num_logprobs=None,
no_penalties=False,
@ywang96 (Member) commented Feb 22, 2025

I did some peak memory analysis with / without this PR when running VLLM_USE_V1=1 python3 benchmarks/benchmark_throughput.py --model meta-llama/Llama-3.1-8B-Instruct --dataset ShareGPT_V3_unfiltered_cleaned_split.json on 1xH100.

Main (no sampler): 75998MiB / 81559MiB
This PR (no penalties, prompt_token_ids=None): 73536MiB / 81559MiB
This PR (no penalties, prompt_token_ids=torch.ones_like(logits, dtype=torch.long)): 72560MiB / 81559MiB
This PR (with penalties): 70046MiB / 81559MiB

Given that we set the default GPU memory utilization (GMU) to 0.9, which already leaves some room for inference with penalties, I think it's OK if we go with the third option (profiling with no_penalties=True and prompt_token_ids provided).
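
Quick arithmetic on that 0.9 budget against the peaks listed above:

    total_mib = 81_559
    print(round(0.9 * total_mib))   # ~73,403 MiB usable at the default gpu_memory_utilization
    # Main without the sampler peaked at 75,998 MiB, above that budget (the
    # under-profiling this PR addresses), while the no-penalties profiling in
    # this PR keeps the peak at 72,560 MiB, back under it.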

It is indeed odd, though, that we see this memory bump in the actual inference workload when no penalties were applied (scenario 2 vs. scenario 3), so perhaps prompt_token_ids (or a tensor of the same size) is used somewhere but not captured during profile_run.

@ywang96 ywang96 added the ready label and removed the documentation, structured-output, frontend, and ci/build labels Feb 22, 2025
@ywang96 ywang96 merged commit da31b53 into vllm-project:main Feb 22, 2025
54 of 56 checks passed
@WoosukKwon (Collaborator)

@JenZhao Congrats on your first PR!

@JenZhao @ywang96 I've observed that the peak memory goes well beyond 90%, even after this PR:

[4] NVIDIA H100 80GB HBM3 | 64'C,  94 % | 80773 / 81559 MB | woosuk.kwon(80296M)

I encountered this error while running benchmark_serving.py. The only modification I made was changing the temperature parameter from 0 to 0.1.

@WoosukKwon (Collaborator)

BTW, oddly, the memory used by the vLLM instance dropped once the server finished handling all the requests and became idle.

[4] NVIDIA H100 80GB HBM3 | 33'C,   0 % | 71009 / 81559 MB | woosuk.kwon(70532M)

Since we do not call torch.cuda.empty_cache ourselves, this is quite strange.
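
One way to narrow this down, sketched under the assumption that the drop is observable from inside the process: compare what the PyTorch caching allocator holds against the device-level number, since memory held outside the allocator (e.g., NCCL communicator buffers) is not affected by torch.cuda.empty_cache anyway.

    import torch

    # Allocator view vs. device view; a drop visible in nvidia-smi but not in
    # memory_reserved() would point at memory held outside the PyTorch allocator.
    free_b, total_b = torch.cuda.mem_get_info()
    print("allocated :", torch.cuda.memory_allocated() // 2**20, "MiB")
    print("reserved  :", torch.cuda.memory_reserved() // 2**20, "MiB")
    print("device use:", (total_b - free_b) // 2**20, "MiB")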

Labels
ready (ONLY add when PR is ready to merge/full CI is needed), v1
Development

Successfully merging this pull request may close these issues:

[V1][Bug]: Consider sampler in memory profiling (#13507)
4 participants