[Feature] Generation Inputs: input_embeds #745

Open
AlekseyKorshuk opened this issue Jul 26, 2024 · 12 comments
Labels
enhancement (New feature or request), good first issue (Good for newcomers), high priority

Comments

@AlekseyKorshuk

AlekseyKorshuk commented Jul 26, 2024

Motivation

I propose to add input_embeds as an optional input to the generation params.

Why is this important

Nowadays there are many Vision Language Models (VLMs), and they all share a similar architecture: vision tower, projector, LLM. In other words, vision_tower + projector just prepares the embeddings for the "image" tokens. So why not let model developers prepare input_embeds for the LLM themselves?
Many new models, such as PaliGemma and Florence, let the user work with bounding boxes and segmentation masks, which makes it quite complicated to add all the different processors and conversation templates to the codebase.
By allowing the user to provide input_embeds instead of a list of messages or text prompts, you reduce your own headache in the future.
Another point is that VLM developers can focus on caching image embeddings while building on top of SGLang, allowing even higher throughput.
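
To make the idea concrete, here is a minimal sketch of how a model developer might assemble input_embeds on their own side before sending them to the server. Everything in it is hypothetical: `embed_text`, `IMAGE_TOKEN_ID`, `splice_image_embeds` and the tiny 4-dimensional vectors are stand-ins for a real embedding table, vision tower and projector, not part of SGLang.

```python
from typing import List

HIDDEN = 4            # toy hidden size (assumption; real models are far larger)
IMAGE_TOKEN_ID = -1   # hypothetical placeholder id marking where the image goes


def embed_text(token_id: int) -> List[float]:
    """Stand-in for the LLM's embedding table lookup."""
    return [float(token_id)] * HIDDEN


def splice_image_embeds(
    token_ids: List[int], image_embeds: List[List[float]]
) -> List[List[float]]:
    """Replace each image placeholder with the projected image embeddings."""
    out: List[List[float]] = []
    for tok in token_ids:
        if tok == IMAGE_TOKEN_ID:
            out.extend(image_embeds)   # one row per image patch
        else:
            out.append(embed_text(tok))
    return out


# Two image "patches" as they might come out of vision_tower + projector.
image_embeds = [[0.1] * HIDDEN, [0.2] * HIDDEN]
input_embeds = splice_image_embeds([5, IMAGE_TOKEN_ID, 7], image_embeds)
print(len(input_embeds), len(input_embeds[0]))
```

The resulting nested list has one row per position in the final sequence, which is the shape the proposed input_embeds field would carry for a single request.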

vLLM users requested this feature a long time ago, and the topic gained a lot of positive attention from the community:

This unique feature will make SGLang the main framework for all VLMs.

I am happy to help implement this if you can point me to the right places in the codebase. Thank you for your time and consideration 🤗

Proposed usages

response = client.chat.completions.create(
    model="default",
    input_embeds=[...],
    temperature=0.8,
    max_tokens=64,
)

backend.run(input_embeds=input_embeds)

from dataclasses import dataclass
from typing import Dict, List, Optional, Union

@dataclass
class GenerateReqInput:
    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[str], str]] = None
    # The token ids for text; one can either specify text or input_ids.
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The embeddings for input_ids; if specified, input_ids should also be provided
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
    # The image input. It can be a file name, a url, or base64 encoded string.
    # See also python/sglang/srt/utils.py:load_image.
    image_data: Optional[Union[List[str], str]] = None
    # The sampling_params.
    sampling_params: Union[List[Dict], Dict] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Whether to return logprobs.
    return_logprob: Optional[Union[List[bool], bool]] = None
    # The start location of the prompt for return_logprob.
    logprob_start_len: Optional[Union[List[int], int]] = None
    # The number of top logprobs to return.
    top_logprobs_num: Optional[Union[List[int], int]] = None
    # Whether to detokenize tokens in logprobs.
    return_text_in_logprobs: bool = False
    # Whether to stream output.
    stream: bool = False
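
Because the proposed field accepts either a single request ([seq][hidden]) or a batch ([batch][seq][hidden]), the server would need some way to tell the two apart and to reject malformed input. The sketch below shows one way this could be done; `is_batched` and `check_input_embeds` are hypothetical helper names, not existing SGLang functions.

```python
from typing import List, Sequence, Union

Embeds = Union[List[List[float]], List[List[List[float]]]]


def is_batched(input_embeds: Embeds) -> bool:
    """True for [batch][seq][hidden], False for [seq][hidden]."""
    first_item = input_embeds[0][0]
    return isinstance(first_item, Sequence)


def check_input_embeds(input_embeds: Embeds, hidden_size: int) -> int:
    """Validate shapes and return the batch size (1 for a single request)."""
    if not input_embeds or not input_embeds[0]:
        raise ValueError("input_embeds must not be empty")
    batch = input_embeds if is_batched(input_embeds) else [input_embeds]
    for seq in batch:
        for vec in seq:
            if len(vec) != hidden_size:
                raise ValueError(
                    f"expected hidden size {hidden_size}, got {len(vec)}"
                )
    return len(batch)


single = [[0.0] * 8 for _ in range(3)]   # one request, 3 positions
batched = [single, single]               # two requests
print(check_input_embeds(single, 8), check_input_embeds(batched, 8))
```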

Related resources

@joshpxyne

+1!

@jsdir

jsdir commented Jul 26, 2024

+1

zhyncs added the backlog label Jul 26, 2024
@tunahfishy

!!!

@ummagumm-a

having this feature would be nice, indeed

merrymercy added the enhancement and high priority labels and removed the backlog label Jul 27, 2024
@merrymercy
Contributor

merrymercy commented Jul 27, 2024

Great suggestions. Let's prioritize this one. I can share some ideas and pointers.

High-level Idea

Since many parts of the existing code rely on the concept of "input_ids: List[int]," it is not easy to fully change all of them, as this will create many problematic "if/else" conditions. I think one possible implementation idea is to create some random fake "input_ids" to make most of the existing code runnable. Then, during the actual forward pass, we can feed input_embeds instead of calling the embedding layer to encode input_ids.

You can learn more about this idea by looking at how the existing Llava implementation directly feeds input_embeds into the underlying Llama:

return self.language_model(
    input_ids, positions, input_metadata, input_embeds=input_embeds
)

if input_embeds is None:
    hidden_states = self.embed_tokens(input_ids)
else:
    hidden_states = input_embeds
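
The effect of that branch can be seen in isolation with a toy model. This is only an illustrative sketch: `ToyLanguageModel` and its fake embedding table are assumptions made for the example and do not mirror the real Llava or Llama classes.

```python
from typing import List, Optional

Vector = List[float]


class ToyLanguageModel:
    """Toy stand-in for the decoder; not SGLang's or Llava's actual class."""

    def __init__(self, vocab_size: int, hidden_size: int) -> None:
        # Deterministic fake embedding table: row i is filled with the value i.
        self.table = [[float(i)] * hidden_size for i in range(vocab_size)]

    def embed_tokens(self, input_ids: List[int]) -> List[Vector]:
        return [self.table[i] for i in input_ids]

    def forward(
        self, input_ids: List[int], input_embeds: Optional[List[Vector]] = None
    ) -> List[Vector]:
        # Same branch as in the snippet above: skip the lookup when the
        # caller already supplies embeddings.
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds
        # A real model would now run its transformer layers on hidden_states.
        return hidden_states


model = ToyLanguageModel(vocab_size=10, hidden_size=2)
from_ids = model.forward([3, 4])
# The ids are ignored here; only their count has to line up with the embeds.
from_embeds = model.forward([0, 0], input_embeds=[[3.0, 3.0], [4.0, 4.0]])
print(from_ids == from_embeds)
```

Since the ids are never read once input_embeds is given, any placeholder values of the right length keep the rest of the pipeline working.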

Implementation

The inference of a request starts with GenerateReqInput from the HTTP server and then goes through several important classes: TokenizerManager, ModelTpServer, ModelRunner, Req, InferBatch. To implement your change, we need to update these places.

  1. Implement your proposed changes to GenerateReqInput
    class GenerateReqInput:
  2. Skip the input tokenization in TokenizerManager
    input_ids = (
        self.tokenizer.encode(input_text)
        if obj.input_ids is None
        else obj.input_ids
    )
    if index is not None and obj.input_ids:
        input_ids = obj.input_ids[index]
  3. When creating the Req, record the input_embeds. Maybe here is also a good place to generate the fake input_ids mentioned above.
    req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
  4. When preparing the inputs of a prefill batch, save input_embeds into InferBatch. In SGLang, "prefill" is also called "extend".
    def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
  5. When running the actual forward pass, feed input_embeds to the model.
    def forward_extend(self, batch: Batch):
        input_metadata = InputMetadata.create(
            self,
            forward_mode=ForwardMode.EXTEND,
            req_pool_indices=batch.req_pool_indices,
            seq_lens=batch.seq_lens,
            prefix_lens=batch.prefix_lens,
            position_ids_offsets=batch.position_ids_offsets,
            out_cache_loc=batch.out_cache_loc,
            top_logprobs_nums=batch.top_logprobs_nums,
            return_logprob=batch.return_logprob,
        )
        return self.model.forward(
            batch.input_ids, input_metadata.positions, input_metadata
        )
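
Steps 2 and 3 can be sketched roughly as follows. This is a simplified illustration under stated assumptions: `ToyRequest`, `resolve_input_ids`, `toy_encode` and `pad_id` are hypothetical names, and the sketch uses a constant placeholder id instead of the random fake ids suggested above, purely to keep the example deterministic.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class ToyRequest:
    """Hypothetical, simplified mirror of GenerateReqInput."""
    text: Optional[str] = None
    input_ids: Optional[List[int]] = None
    input_embeds: Optional[List[List[float]]] = None


def resolve_input_ids(
    req: ToyRequest, encode: Callable[[str], List[int]], pad_id: int = 0
) -> List[int]:
    """Skip tokenization when embeddings are given and fake the ids instead."""
    if req.input_embeds is not None:
        # One placeholder id per embedding row, so length-based bookkeeping
        # (sequence lengths, cache slots) still sees the right size.
        return [pad_id] * len(req.input_embeds)
    if req.input_ids is not None:
        return req.input_ids
    if req.text is None:
        raise ValueError("one of text, input_ids or input_embeds is required")
    return encode(req.text)


def toy_encode(text: str) -> List[int]:
    """Stand-in tokenizer: one id per character."""
    return [ord(c) for c in text]


ids_from_text = resolve_input_ids(ToyRequest(text="hi"), toy_encode)
ids_from_embeds = resolve_input_ids(
    ToyRequest(input_embeds=[[0.1, 0.2]] * 3), toy_encode
)
print(ids_from_text, ids_from_embeds)
```

The real input_embeds would then be stored on the request and carried through the batch to the forward pass, as described in steps 3 to 5.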

This is my rough idea. I haven't implemented it yet, so there may be some mistakes. I hope it is helpful.

@Ying1123
Member

Ying1123 commented Aug 4, 2024

@AlekseyKorshuk any updates?

@AlekseyKorshuk
Author

Last week was quite busy for me, so unfortunately I have not started yet.


github-actions bot commented Oct 4, 2024

This issue has been automatically closed due to inactivity. Please feel free to reopen it if needed.

github-actions bot closed this as completed Oct 4, 2024
merrymercy reopened this Oct 6, 2024
@RinRin-32
Contributor

RinRin-32 commented Oct 15, 2024

> Great suggestions. Let's prioritize this one. I can share some ideas and pointers.
>
> High-level Idea
>
> Since many parts of the existing code rely on the concept of "input_ids: List[int]," it is not easy to fully change all of them, as this will create many problematic "if/else" conditions. I think one possible implementation idea is to create some random fake "input_ids" to make most of the existing code runnable. Then, during the actual forward pass, we can feed input_embeds instead of calling the embedding layer to encode input_ids.
>
> You can learn more about this idea by looking at how the existing Llava implementation directly feeds input_embeds into the underlying Llama:
>
> return self.language_model(
>     input_ids, positions, input_metadata, input_embeds=input_embeds
> )
>
> if input_embeds is None:
>     hidden_states = self.embed_tokens(input_ids)
> else:
>     hidden_states = input_embeds
>
> Implementation
>
> The inference of a request starts with GenerateReqInput from the HTTP server and then goes through several important classes: TokenizerManager, ModelTpServer, ModelRunner, Req, InferBatch. To implement your change, we need to update these places.
>
> 1. Implement your proposed changes to GenerateReqInput https://github.com/sgl-project/sglang/blob/3fdab91912fb271c20642e21c2055df0e23d514e/python/sglang/srt/managers/io_struct.py#L15
>
> 2. Skip the input tokenization in `TokenizerManager` https://github.com/sgl-project/sglang/blob/0736b270202696b8f865e2915aadc36d3d51811b/python/sglang/srt/managers/tokenizer_manager.py#L142-L148
>
> 3. When creating the `Req`, record the `input_embeds`. Maybe here is also a good place to generate the fake input_ids mentioned above. https://github.com/sgl-project/sglang/blob/0736b270202696b8f865e2915aadc36d3d51811b/python/sglang/srt/managers/controller/tp_worker.py#L263
>
> 4. When preparing the inputs of a prefill batch, save input_embeds into `InferBatch`. In SGLang, "prefill" is also called "extend". https://github.com/sgl-project/sglang/blob/0736b270202696b8f865e2915aadc36d3d51811b/python/sglang/srt/managers/controller/infer_batch.py#L313
>
> 5. When running the actual forward pass, feed `input_embeds` to the model. https://github.com/sgl-project/sglang/blob/0736b270202696b8f865e2915aadc36d3d51811b/python/sglang/srt/managers/controller/model_runner.py#L295-L309
>
> This is my rough idea. I haven't implemented it yet, so there may be some mistakes. I hope it is helpful.

Hello, I implemented this according to the high-level overview above and managed to get input_embeds working and generating a response.

My current issue is that I can only generate using input_embeds once; if I use input_embeds to generate again, I get this error:

[08:49:39 TP0] Traceback (most recent call last):
  File "/data/rin_experiements/sglang/python/sglang/srt/managers/scheduler.py", line 994, in run_scheduler_process
    scheduler.event_loop()
  File "/data/rin_experiements/sglang/python/sglang/srt/managers/scheduler.py", line 242, in event_loop
    self.forward_step()
  File "/home/azureuser/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/data/rin_experiements/sglang/python/sglang/srt/managers/scheduler.py", line 292, in forward_step
    self.forward_prefill_batch(new_batch)
  File "/data/rin_experiements/sglang/python/sglang/srt/managers/scheduler.py", line 592, in forward_prefill_batch
    logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
  File "/data/rin_experiements/sglang/python/sglang/srt/managers/tp_worker.py", line 114, in forward_batch_generation
    logits_output = self.model_runner.forward(forward_batch)
  File "/data/rin_experiements/sglang/python/sglang/srt/model_executor/model_runner.py", line 521, in forward
    return self.forward_extend(forward_batch)
  File "/data/rin_experiements/sglang/python/sglang/srt/model_executor/model_runner.py", line 496, in forward_extend
    return self.model.forward(
  File "/home/azureuser/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/data/rin_experiements/sglang/python/sglang/srt/models/qwen2.py", line 290, in forward
    hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
  File "/home/azureuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/azureuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/rin_experiements/sglang/python/sglang/srt/models/qwen2.py", line 256, in forward
    hidden_states, residual = layer(
  File "/home/azureuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/azureuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/rin_experiements/sglang/python/sglang/srt/models/qwen2.py", line 208, in forward
    hidden_states = self.self_attn(
  File "/home/azureuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/azureuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/rin_experiements/sglang/python/sglang/srt/models/qwen2.py", line 157, in forward
    attn_output = self.attn(q, k, v, forward_batch)
  File "/home/azureuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/azureuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/rin_experiements/sglang/python/sglang/srt/layers/radix_attention.py", line 60, in forward
    return forward_batch.attn_backend.forward(q, k, v, self, forward_batch)
  File "/data/rin_experiements/sglang/python/sglang/srt/layers/attention/__init__.py", line 41, in forward
    return self.forward_extend(q, k, v, layer, forward_batch)
  File "/data/rin_experiements/sglang/python/sglang/srt/layers/attention/flashinfer_backend.py", line 222, in forward_extend
    forward_batch.token_to_kv_pool.set_kv_buffer(
  File "/data/rin_experiements/sglang/python/sglang/srt/mem_cache/memory_pool.py", line 200, in set_kv_buffer
    self.k_buffer[layer_id][loc] = cache_k
RuntimeError: shape mismatch: value tensor of shape [6, 4, 128] cannot be broadcast to indexing result of shape [1, 4, 128]

Do you have any recommendations on how to navigate the repository for fixes?

Update

It turns out that using --disable-radix solves my issue.
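
One possible explanation, offered as an assumption rather than a confirmed diagnosis: a prefix cache such as the radix cache recognizes reusable prefixes by comparing input_ids. If every input_embeds request carries the same placeholder ids, a second request can look like a cache hit on almost the whole prompt, so cache slots are reserved only for the "new" tokens while the forward pass still produces keys and values for every embedding row. The toy sketch below reproduces that kind of count mismatch; `shared_prefix_len` and the numbers are illustrative and are not SGLang code.

```python
from typing import List


def shared_prefix_len(cached: List[int], new: List[int]) -> int:
    """Length of the common prefix, the way an id-keyed prefix cache sees it."""
    n = 0
    for a, b in zip(cached, new):
        if a != b:
            break
        n += 1
    return n


# Both requests carry placeholder ids because their real inputs are embeddings.
first_request_ids = [0, 0, 0, 0, 0, 0]
second_request_ids = [0, 0, 0, 0, 0, 0]

# Assumption: at least one token is always recomputed, so a full match is
# capped at len - 1.
matched = min(
    shared_prefix_len(first_request_ids, second_request_ids),
    len(second_request_ids) - 1,
)
slots_reserved = len(second_request_ids) - matched   # what the cache allocates
rows_computed = len(second_request_ids)              # one K/V row per embedding

print(slots_reserved, rows_computed)
```

Under these assumptions the cache reserves 1 slot while 6 rows are produced, which has the same shape as the [1, 4, 128] versus [6, 4, 128] mismatch in the traceback. Disabling the prefix cache, or giving each request distinct placeholder ids, would avoid the false match in this toy model.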

@majunze2001

@RinRin-32 Do you have a commit/branch? I am interested in taking a closer look.

@RinRin-32
Contributor

@majunze2001 Sure thing! My organization worked from a fork of 0.3.2. I was discouraged from opening a pull request after seeing that the structure changed drastically in 0.3.3. Looking at the current main, my implementation would likely work there. I'll make the pull request in a week or two and link it here.

The main changes I worked on are in:

python/sglang/srt/managers/
    io_struct.py
    schedule_batch.py
    scheduler.py
    tokenizer_manager.py

python/sglang/srt/model_executor/
    forward_batch_info.py
    model_runner.py

@RinRin-32
Contributor

RinRin-32 commented Nov 16, 2024

@majunze2001 I've just opened my pull request. Please check it out at #2052.

There are still some flaws, such as the lack of args for serving with input_embeds; I've documented these in the pull request.
I tried to keep the if/else conditions to a minimum, and I hope other contributors can help optimize it!
