currently, LLMEngine (the driver) lives in the same process as the tensor parallel rank 0 process, which has caused a lot of trouble for us; e.g. we cannot easily create two instances of vLLM on different GPUs. Spec decode hacks around this a lot.
basically, the function we care about is LLMEngine.step, and the core line of code is the call that hands the batched request to the model executor.
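roughly, that line looks like the following (paraphrased from memory, so treat the exact names and signature as an assumption rather than a verbatim quote of the source):

```python
# paraphrased from memory, not a verbatim quote of LLMEngine.step:
# the driver hands the batched request for this step to the model executor.
outputs = self.model_executor.execute_model(execute_model_req=execute_model_req)
```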
when we use tensor parallel of size N, this line will (a sketch of the pattern follows this list):

- process execute_model_req into tensors and broadcast those tensors to the other N - 1 workers
- execute the model in the tp rank 0 worker process, together with the other N - 1 workers, and gather the output in that worker process
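to make the pattern concrete, here is an illustrative sketch using plain torch.distributed; build_model_inputs is a hypothetical helper and none of this is vLLM's actual code:

```python
# illustrative sketch, not vLLM's actual code: the driver (tp rank 0) turns
# execute_model_req into a payload, broadcasts it to the other N - 1 ranks,
# every rank runs the forward pass, and only the driver keeps the output.
import torch.distributed as dist

def driver_step(model, execute_model_req):
    payload = build_model_inputs(execute_model_req)  # hypothetical helper
    dist.broadcast_object_list([payload], src=0)     # ship inputs to the other ranks
    return model(payload)                            # driver joins the TP forward pass

def worker_step(model):
    box = [None]
    dist.broadcast_object_list(box, src=0)           # receive the driver's payload
    model(box[0])                                    # output is discarded on non-driver ranks
```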
if we want to separate the tp rank 0 process from the engine process, as in #6032, there will be two serializations:
- execute_model_req will be serialized and sent to the tp processes; even with advanced techniques where we serialize once and all processes receive the same message, we still need to serialize it.
- output will live in the tp rank 0 process at first, and then be passed to the engine process.
Therefore, we need to measure how large these objects are and what the cost of serializing them is.
Here is a simple script:
```python
from vllm import LLM, SamplingParams

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM(model="facebook/opt-125m", tensor_parallel_size=2)
outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
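a minimal standalone sketch of the kind of measurement used here (the branch linked below does the real measurement inside vLLM; the payload and helper names in this sketch are placeholders):

```python
# minimal sketch of the measurement idea (placeholder names, not the branch's code):
# compare the size of the pickled payload each step with the edit distance between
# consecutive payloads, i.e. how many bytes actually change from step to step.
import pickle
import Levenshtein  # pip install levenshtein

def report_step(prev_req, cur_req) -> None:
    prev_bytes = pickle.dumps(prev_req)
    cur_bytes = pickle.dumps(cur_req)
    # compare as latin-1 strings so the bytes map 1:1 onto characters
    changed = Levenshtein.distance(prev_bytes.decode("latin-1"),
                                   cur_bytes.decode("latin-1"))
    print(f"serialized size: {len(cur_bytes)} bytes, "
          f"changed vs. previous step: ~{changed} bytes")
```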
And we can use the branch https://github.com/youkaichao/vllm/tree/measure_serialization to measure the serialization overhead (remember to `pip install levenshtein`).

as we can see, the actual message we pass every step (the difference, or edit distance, between consecutive messages) is actually quite small, just several dozen bytes. however, the serialized data are 10x~100x larger. Why?
- for the output, this is because we use a very inefficient serialization format: pickle for dataclasses. it stores field names and class names again and again (see the illustration after this list).
- for the input, besides the above limitation (e.g. the serialization of SamplingParams is terribly long but not informative), we have another limitation: it sends the prompt again and again.
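to see the first point concretely, here is a toy illustration with a stand-in dataclass (ToySamplingParams is made up, not vLLM's SamplingParams):

```python
# toy illustration of why pickled dataclasses are bloated: the payload embeds the
# module path, the class name, and every field name, for every single object.
import pickle
from dataclasses import dataclass

@dataclass
class ToySamplingParams:      # stand-in, not vLLM's real SamplingParams
    temperature: float = 0.8
    top_p: float = 0.95
    max_tokens: int = 16

blob = pickle.dumps(ToySamplingParams())
print(len(blob))              # on the order of 100 bytes for just three numbers
print(blob)                   # dominated by b'ToySamplingParams', b'temperature', ...
```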
What's next?
- we need to design the interaction between the engine and the workers to minimize data transfer, potentially sending only a diff. this will make the worker stateful, since it has to remember the last input, but I think this should be fine (see the sketch after this list).
- we need to design the serialization format (msgpack does not work here out of the box: it cannot serialize ExecuteModelRequest).
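a rough sketch of what the diff-based, stateful protocol could look like, with msgpack carrying a plain dict so its inability to serialize ExecuteModelRequest directly is not a problem (all class and method names here are hypothetical, not a proposed vLLM API):

```python
# rough sketch of a diff-based, stateful protocol (all names hypothetical).
# the engine side sends only the fields that changed since the last step; the
# worker side remembers the last full request and patches it with each diff.
# (deleted keys are ignored here to keep the sketch short.)
import msgpack

class EngineSideSender:
    def __init__(self):
        self._last: dict = {}

    def encode_diff(self, request: dict) -> bytes:
        diff = {k: v for k, v in request.items() if self._last.get(k) != v}
        self._last = request
        return msgpack.packb(diff, use_bin_type=True)   # plain dicts, so msgpack is enough

class StatefulWorkerReceiver:
    def __init__(self):
        self._request: dict = {}

    def apply_diff(self, payload: bytes) -> dict:
        self._request.update(msgpack.unpackb(payload, raw=False))
        return self._request                            # the full, patched request
```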
What if ExecuteModelRequest or SamplerOutput contains GPU data?
GPU data is expensive to move across processes, so it should be used as little as possible. in most cases, we should leave GPU data in the worker. ideally, the engine will not own any GPU data.
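one illustrative pattern for this (not a concrete proposal, and none of these names exist in vLLM): the worker keeps its large tensors on the GPU and returns only the small CPU-side result the engine needs, e.g. sampled token ids.

```python
# illustrative only: large tensors stay on the GPU inside the worker process;
# only a handful of ints cross the process boundary back to the engine.
import torch

class WorkerSide:
    def __init__(self):
        self.gpu_state: dict[str, torch.Tensor] = {}   # e.g. KV-cache-like state, never shipped

    def execute(self, request_id: str, logits: torch.Tensor) -> list[int]:
        self.gpu_state[request_id] = logits            # stays in this process, on the GPU
        next_tokens = torch.argmax(logits, dim=-1)     # greedy sampling as a stand-in
        return next_tokens.cpu().tolist()              # small CPU-side result for the engine
```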
> we need to design the interaction between engine and worker to minimize the data transfer. potentially sending diff only. this will make the worker stateful, remembering the last input, but i think this should be fine.
>
> we need to design the serialization format (msgpack does not work here, it cannot serialize ExecuteModelRequest)
Btw, both of these have been done inside Anyscale before, and the last time I benchmarked (this January), they achieved nearly the same result as the nccl-broadcast-based solution.
cc @WoosukKwon @zhuohan123 @simon-mo @comaniac @cadedaniel @stephanie-wang @ruisearch42