Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] set OMP_NUM_THREADS to 1 by default when using the multiproc_gpu_executor #6109

Merged
merged 3 commits into from
Jul 3, 2024

Conversation

tjohnson31415
Copy link
Contributor

@tjohnson31415 tjohnson31415 commented Jul 3, 2024

Set OMP_NUM_THREADS to 1 by default (if unset) when using the MultiprocessingGPUExecutor to prevent CPU contention amongst the sharded processes.

Without this change, use of the MP backend has caused some regressions because of this CPU contention. From what I can tell, the Ray GPU executor uses num_cpus=0 which then sets OMP_NUM_THREADS=1 for the Ray worker, regardless of the outside environment.

#5230 changed the distributed backend to default to MP instead of Ray in certain conditions. This was released in version 0.5.0, which explains the regression noted in #5564 between versions 0.4.3 and 0.5.0.

For reference, here's how Ray states that it handles OMP_NUM_THREADS generally:

Ray sets the environment variable OMP_NUM_THREADS=<num_cpus> if num_cpus is set on the task/actor via ray.remote() and task.options()/actor.options(). Ray sets OMP_NUM_THREADS=1 if num_cpus is not specified; this is done to avoid performance degradation with many workers (issue #6998). You can also override this by explicitly setting OMP_NUM_THREADS to override anything Ray sets by default.

REF

FIX #6072
FIX #5564
FIX #5532

BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
@youkaichao
Copy link
Member

which code uses OMP_NUM_THREADS?

Copy link
Member

@njhill njhill left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @tjohnson31415!!

vllm/executor/multiproc_gpu_executor.py Outdated Show resolved Hide resolved
@njhill njhill mentioned this pull request Jul 3, 2024
@tjohnson31415
Copy link
Contributor Author

tjohnson31415 commented Jul 3, 2024

which code uses OMP_NUM_THREADS?

Libraries like Pytorch and numpy that wrap lower level BLAS libraries often use OpenMP which uses OMP_NUM_THREADS to configure the number of CPU threads to spawn for parallel computations.

In my testing, I was looking at the base model load time and the lora adapter load/optmize times, both of which are loading data to the CPU. Any CPU operation on tensors will be affected by this setting. So some more performance testing may be good, but it is clear to me that not setting OMP_NUM_THREADS is more likely to cause problems than setting it to 1 and this change aligns the default behavior w.r.t. threading for the Ray and MP backends.

tjohnson31415 and others added 2 commits July 3, 2024 11:16
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
@tjohnson31415
Copy link
Contributor Author

Adding some notes/timing measurements from my testing:

My test env was in a Pod in Kubernetes with cpu requests set to 8 and cpu limits set to 16 running on a node with 80 total cores. Without OMP_NUM_THREADS, I think each shard would spawn 80 threads and they'd fight for the CPU. Watching the load with top I noticed CPU time exceeding 16, which caused throttling on top of the context switching overhead. Data recorded below was measured in a Jupyter notebook with %%time using offline vLLM (single run results, but in line with what I saw from running multiple times).

Code used for testing.
%%time

import os
# Change for testing
#os.environ["OMP_NUM_THREADS"] = "1"

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(
    model="meta-llama/Llama-2-70b-hf",
    tensor_parallel_size=4,
    enable_lora=True,
    enforce_eager=True,
    # change for testing
    distributed_executor_backend="ray",
)
sampling_params = SamplingParams(
    temperature=0,
    max_tokens=4
)

prompts = [
     "### Text: I have no water and the bill is current and paid. Can you do something about this?\n\n### Label:",
]

lora_path = '/tmp/my_lora'

lora_request = LoRARequest("my_adapter", 1, lora_path)
%%time
outputs = llm.generate(
    prompts,
    sampling_params,
    lora_request=lora_request
)

For loading Llama 2 70B with TP=4:

No OMP_NUM_THREADS set:

CPU times: user 5min 50s, sys: 48.8 s, total: 6min 39s
Wall time: 6min

OMP_NUM_THREADS = 1:

CPU times: user 1min 24s, sys: 38.1 s, total: 2min 2s
Wall time: 4min 26s

Using Ray backend (one thread per worker):

CPU times: user 1min 25s, sys: 42 s, total: 2min 7s
Wall time: 4min 29s

For loading and running generate() with a LoRA adapter model

No OMP_NUM_THREADS set:

First call, loading the model

CPU times: user 7min 43s, sys: 1.97 s, total: 7min 45s
Wall time: 2min 2s

Second call

CPU times: user 198 ms, sys: 4.8 ms, total: 203 ms
Wall time: 210 ms

OMP_NUM_THREADS = 1:

First call, loading the model

CPU times: user 15.9 s, sys: 238 ms, total: 16.1 s
Wall time: 2.58 s

Second call

CPU times: user 205 ms, sys: 6.32 ms, total: 212 ms
Wall time: 213 ms

Using Ray backend (one thread per worker):

First call, loading the model

CPU times: user 6.88 s, sys: 33.3 ms, total: 6.91 s
Wall time: 4.79 s

Second call

CPU times: user 223 ms, sys: 0 ns, total: 223 ms
Wall time: 240 ms

@youkaichao youkaichao merged commit 1dab9bc into vllm-project:main Jul 3, 2024
70 checks passed
@tjohnson31415 tjohnson31415 deleted the set-omp-num-threads branch July 5, 2024 20:39
robertgshaw2-redhat pushed a commit to neuralmagic/nm-vllm that referenced this pull request Jul 7, 2024
…m-project#6109)

Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
xjpang pushed a commit to xjpang/vllm that referenced this pull request Jul 8, 2024
…m-project#6109)

Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
xjpang pushed a commit to xjpang/vllm that referenced this pull request Jul 24, 2024
…m-project#6109)

Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
…m-project#6109)

Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Signed-off-by: Alvant <alvasian@yandex.ru>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
3 participants