[Performance]: empirical measurement of object serialization for input/output of worker #6241

Open
youkaichao opened this issue Jul 9, 2024 · 4 comments
Labels: performance (Performance-related issues), stale

Comments

@youkaichao (Member)

Proposal to improve performance

Currently, the LLMEngine (driver) lives in the same process as the tensor parallel rank 0 worker, which causes a lot of trouble for us, e.g. we cannot easily create two instances of vLLM on different GPUs. Spec decode has to hack around this a lot.

Basically, the function we care about is LLMEngine.step, and the core line of code is:

output = self.model_executor.execute_model(
    execute_model_req=execute_model_req)

When we use tensor parallelism of size N, this line will:

  1. process execute_model_req into tensors and broadcast those tensors to the other N - 1 workers (a rough sketch of this step follows this list)
  2. execute the model in this worker process together with the other N - 1 workers, and gather the output back in this worker process
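
For reference, step 1 amounts to something like the following sketch built on plain torch.distributed primitives (illustrative only, not the actual vLLM broadcast code):

import torch.distributed as dist

# Illustrative sketch (not vLLM code): rank 0 splits the request into a small
# Python-level metadata object plus tensors, then broadcasts both to the
# other N - 1 workers.
def broadcast_request(metadata, tensors, src: int = 0):
    obj = [metadata]  # on non-source ranks this is a placeholder that gets overwritten
    dist.broadcast_object_list(obj, src=src)  # pickles the metadata under the hood
    for t in tensors:
        dist.broadcast(t, src=src)            # raw tensor broadcast, no pickling
    return obj[0], tensors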

If we want to separate the TP rank 0 process from the engine process, as in #6032, there will be two serializations (a toy sketch of this split follows the list):

  1. execute_model_req will be serialized and sent to the TP processes; even with advanced techniques that let us send once and have every process receive it, we still need to serialize it
  2. output will live in the TP rank 0 process at first, and then be passed to the engine process
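
A toy sketch of that split, using a multiprocessing pipe (hypothetical names; pickle stands in for whatever serialization format is eventually chosen):

from multiprocessing import Pipe, Process

# Hypothetical engine/worker split: every request is serialized on the way
# out and every output is serialized on the way back (Pipe pickles objects
# internally on send and unpickles them on recv).
def worker_loop(conn):
    while True:
        request = conn.recv()          # request deserialized in the worker
        if request is None:
            break
        conn.send({"echo": request})   # output serialized back to the engine

if __name__ == "__main__":
    engine_end, worker_end = Pipe()
    p = Process(target=worker_loop, args=(worker_end,))
    p.start()
    engine_end.send({"step": 1, "prompt_token_ids": [1, 2, 3]})  # request serialized here
    print(engine_end.recv())                                     # output deserialized here
    engine_end.send(None)
    p.join()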

Therefore, we need to measure how large these objects are and what the cost of serializing them is.

Here is a simple script:

from vllm import LLM, SamplingParams
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(model="facebook/opt-125m", tensor_parallel_size=2)
outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

We can then use the branch https://github.com/youkaichao/vllm/tree/measure_serialization to measure the serialization overhead (remember to `pip install levenshtein`):

len(this_input)=2616, len(this_output)=610
len(this_input)=2640, len(this_output)=601
input_distance=93, output_distance=49
len(this_input)=2660, len(this_output)=603
input_distance=44, output_distance=33
len(this_input)=2678, len(this_output)=603
input_distance=39, output_distance=35
len(this_input)=2696, len(this_output)=603
input_distance=39, output_distance=34
len(this_input)=2714, len(this_output)=605
input_distance=40, output_distance=45
len(this_input)=2734, len(this_output)=605
input_distance=38, output_distance=38
len(this_input)=2754, len(this_output)=605
input_distance=42, output_distance=37
len(this_input)=2774, len(this_output)=605
input_distance=41, output_distance=34
len(this_input)=2800, len(this_output)=603
input_distance=43, output_distance=37
len(this_input)=2818, len(this_output)=601
input_distance=39, output_distance=35
len(this_input)=2852, len(this_output)=605
input_distance=55, output_distance=37
len(this_input)=2872, len(this_output)=607
input_distance=41, output_distance=44
len(this_input)=2894, len(this_output)=605
input_distance=43, output_distance=45
len(this_input)=2914, len(this_output)=601
input_distance=42, output_distance=36
len(this_input)=2930, len(this_output)=609
input_distance=37, output_distance=42
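
For reference, the measurement in that branch is roughly equivalent to the following sketch (the exact hook points inside vLLM are assumptions; pickle stands in for the serializer being measured):

import pickle
import Levenshtein  # pip install levenshtein

_last_input = b""
_last_output = b""

def measure(execute_model_req, output):
    """Log serialized sizes and the edit distance to the previous step."""
    global _last_input, _last_output
    this_input = pickle.dumps(execute_model_req)
    this_output = pickle.dumps(output)
    print(f"{len(this_input)=}, {len(this_output)=}")
    if _last_input:
        # decode to str so Levenshtein.distance works regardless of package version
        input_distance = Levenshtein.distance(
            this_input.decode("latin-1"), _last_input.decode("latin-1"))
        output_distance = Levenshtein.distance(
            this_output.decode("latin-1"), _last_output.decode("latin-1"))
        print(f"{input_distance=}, {output_distance=}")
    _last_input, _last_output = this_input, this_output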

As we can see, the actual information we pass every step (the difference, i.e. the edit distance between consecutive messages) is quite small, only a few dozen bytes. However, the serialized data is 10x~100x larger. Why?

For the output, this is because we have a very inefficient serialization format: pickle applied to dataclasses. It stores field names and class names again and again.

For the input, besides the above limitation (e.g. the serialization of SamplingParams is terribly long but carries little information), we have another problem: it sends the prompt again and again.
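
To make the first point concrete, here is a small standalone example (not vLLM code; the class below is made up) showing how pickle embeds the class path and every field name next to the actual values:

import pickle
from dataclasses import dataclass

@dataclass
class TinyOutput:          # stand-in for a small per-step output object
    token_id: int = 0
    logprob: float = 0.0

blob = pickle.dumps(TinyOutput())
print(len(blob))                                 # far more than the few bytes of actual payload
print(b"TinyOutput" in blob)                     # True: the class name is embedded
print(b"token_id" in blob, b"logprob" in blob)   # True True: so is every field name

# Pickling only the values avoids re-sending the schema every step.
print(len(pickle.dumps((0, 0.0))))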

What's next?

  1. We need to design the interaction between the engine and the workers to minimize data transfer, potentially sending only a diff. This will make the worker stateful, remembering the last input, but I think this should be fine (see the sketch after this list).
  2. We need to design the serialization format (msgpack does not work here; it cannot serialize ExecuteModelRequest).
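
A minimal sketch of what item 1 could look like (all names here are hypothetical, not an existing vLLM API): the engine sends only the fields that changed, and the worker keeps the last full request around and patches it.

# Hypothetical diff-based protocol between engine and worker.
class DiffingEngineSide:
    def __init__(self):
        self._last_sent: dict = {}

    def make_diff(self, request: dict) -> dict:
        """Return only the keys whose values changed since the last step."""
        diff = {k: v for k, v in request.items() if self._last_sent.get(k) != v}
        self._last_sent = dict(request)
        return diff

class StatefulWorkerSide:
    def __init__(self):
        self._current: dict = {}

    def apply_diff(self, diff: dict) -> dict:
        """Patch the cached request with the received diff and return it."""
        self._current.update(diff)
        return self._current

engine, worker = DiffingEngineSide(), StatefulWorkerSide()
step1 = {"prompt_token_ids": [1, 2, 3], "max_tokens": 16}
step2 = {"prompt_token_ids": [1, 2, 3], "max_tokens": 15}
assert worker.apply_diff(engine.make_diff(step1)) == step1
assert worker.apply_diff(engine.make_diff(step2)) == step2   # only the changed key crossed the wire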

What if ExecuteModelRequest or SamplerOutput contains GPU data?

GPU data is expensive to move across processes, so it should be transferred as rarely as possible. In most cases, we should leave GPU data in the worker; ideally, the engine will not own any GPU data.
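
As a rough illustration of that principle (not existing vLLM code), the worker can do all GPU work locally and hand only small CPU-side results back to the engine:

import torch

# Illustrative only: keep GPU tensors inside the worker process and return
# plain Python data (sampled token ids) to the engine.
def worker_sample_step(logits: torch.Tensor) -> list[int]:
    next_tokens = torch.argmax(logits, dim=-1)   # computed on the GPU, stays in this process
    return next_tokens.cpu().tolist()            # only small CPU data crosses the process boundary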

cc @WoosukKwon @zhuohan123 @simon-mo @comaniac @cadedaniel @stephanie-wang @ruisearch42

youkaichao added the performance label on Jul 9, 2024
@youkaichao (Member, Author)

Okay, SamplingParams seems to be the main source of trouble:

from vllm.sampling_params import SamplingParams
import pickle
print(len(pickle.dumps(SamplingParams()))) # 611

Even though it does not contain any request-specific information (all fields are at their defaults), it takes 611 bytes.
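
A quick way to see where those bytes go (the exact field names depend on the vLLM version, so treat this as an illustrative check):

import pickle
from vllm.sampling_params import SamplingParams

blob = pickle.dumps(SamplingParams())
print(len(blob))                       # ~611 bytes in the measurement above
print(b"SamplingParams" in blob)       # the class name is embedded in the payload
print(b"temperature" in blob)          # and so is every field name (version-dependent)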

@rkooo567 (Collaborator)

  • We need to design the interaction between the engine and the workers to minimize data transfer, potentially sending only a diff. This will make the worker stateful, remembering the last input, but I think this should be fine.
  • We need to design the serialization format (msgpack does not work here; it cannot serialize ExecuteModelRequest).

Btw, both of these have been done inside Anyscale before, and the last time I benchmarked (this January), this approach achieved nearly the same result as the NCCL-broadcast-based solution.

@youkaichao (Member, Author)

this approach achieved nearly the same result as the NCCL-broadcast-based solution.

Glad to know that. Then I think we should work on this to replace the NCCL broadcast.

@github-actions (bot)

This issue has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this issue should remain open. Thank you!

github-actions bot added the stale label on Oct 25, 2024