
[Misc]: Throughput/Latency for guided_json with ~100% GPU cache utilization #3567

Open
jens-create opened this issue Mar 22, 2024 · 65 comments

Comments

@jens-create

Anything you want to discuss about vllm.

Hi,

I am running some benchmarks against vllm.entrypoints.openai.api_server, measuring latency and throughput with different numbers of concurrent requests.

Specs:

  • H100 80GB
  • qwen-1.5-14B-chat

I am sending 1000 requests with random prompts of token length 512. These are the results I get (see attached image):

Guided_json

  • ~100 running requests
  • ~70 generation tokens per second
  • ~1700 ms median token time

Non-guided_json

  • ~100 running requests
  • ~800 generation tokens per second
  • ~75 ms median token time (TPOT)

At 10 concurrent requests (GPU utilization << 100%):

Non-guided_json: ~20 ms median token time
guided_json: ~160 ms median token time

Currently the application I am building relies heavily on guided_json; however, before putting it in an online setting I would like to ask 1) are the numbers I am seeing sensible, and 2) what can be done to improve performance in the guided_json paradigm?

I am debating whether I should instead try to prompt my way to structured outputs and thus avoid constrained decoding.

[Screenshot 2024-03-22 at 10.10.14: benchmark results referenced above]
@simon-mo
Collaborator

Is the JSON schema complex at all, and is it the same each time? The 70 tok/s number is a bit lower than I expected. If it's the same schema, this can be due to several factors:

  • Currently the logits mask computation is performed on the critical path but it can be moved earlier.
  • We currently don't batch the application of logits processors.
  • Python overhead in general

I'm interested in fixing the performance here.

@simon-mo simon-mo self-assigned this Mar 22, 2024
@jens-create
Author

Hi Simon,

The JSON schema is the same at all times, and it is as follows:

"guided_json": {"$defs": {"SearchQuery": {"description": "Search query for the retrieval task.", "properties": {"query_type": {"description": "The type of query most effective for handling the retrieval task.", "title": "Query Type", "type": "string"}, "query": {"description": "A random user's search query.", "title": "Query", "type": "string"}}, "required": ["query_type", "query"], "title": "SearchQuery", "type": "object"}}, "description": "A list of search queries anticipating a user looking for information from a given web page.", "properties": {"queries": {"description": "Brainstormed search queries for the given web page.", "items": {"$ref": "#/$defs/SearchQuery"}, "title": "Queries", "type": "array"}}, "required": ["queries"], "title": "Brainstorm", "type": "object"}

Thanks for looking into this 🫶

@jens-create
Author

@simon-mo any update on this? 😊

@taoisu

taoisu commented Apr 24, 2024

Facing a similar issue here. I have a JSON schema with 14 fields, and the request gets stuck forever.

@lithafnium

lithafnium commented May 7, 2024

My schema only has 2 fields and also shows significantly higher latency than when running without guided_json. Would love to have this fixed, as model output quality degrades severely without it.

@simon-mo
Collaborator

simon-mo commented May 7, 2024

I would suggest trying out --guided-decoding-backend lm-format-enforcer (as a server argument) or "guided_decoding_backend": "lm-format-enforcer" as part of the request to see whether it helps. See the original PR here: #3868 (cc @noamgat)
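(For reference, a minimal sketch of the per-request override against the OpenAI-compatible server; the endpoint, model name, and schema below are placeholders, not taken from this thread:)

```python
# Sketch: per-request guided decoding backend override via vLLM's OpenAI-compatible API.
import requests

payload = {
    "model": "Qwen/Qwen1.5-14B-Chat",  # placeholder model name
    "messages": [{"role": "user", "content": "Give me one search query as JSON."}],
    "max_tokens": 128,
    "guided_json": {
        "type": "object",
        "properties": {"query": {"type": "string"}},
        "required": ["query"],
    },
    "guided_decoding_backend": "lm-format-enforcer",  # or "outlines"
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"]["content"])
```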

@noamgat
Contributor

noamgat commented May 7, 2024

If testing lm-format-enforcer, I highly recommend adding the latest version of it to the image, as there have been performance improvements to the JsonSchemaParser. The next version of vLLM will include them, but until then, do pip install lm-format-enforcer==0.10.1 in the image before testing.

@jarrelscy

What speeds are you getting @noamgat vs. the outlines backend?

@noamgat
Contributor

noamgat commented May 7, 2024

I didn't test on A100/H100s, but on my dev setup (RTX 3090, Mistral-7B), for simple schemas, I was seeing less than a 2x reduction in tokens/s.

@nullpointer0xffff

+1, and it seems not to be GPU related; I tested with A100 and V100 GPUs and both show a similar issue.

Using a line profiler, I found that the get_guided_decoding_logits_processor call takes 93% of the time.

@nullpointer0xffff

> If testing lm-format-enforcer, I highly recommend adding the latest version of it to the image, as there have been performance improvements to the JsonSchemaParser. The next version of vLLM will include them, but until then, do pip install lm-format-enforcer==0.10.1 in the image before testing.

Installing this runs into the pin "vllm 0.4.1+cu118 requires lm-format-enforcer==0.9.8", so it needs --no-deps.

Just tested: the speed-up is not dramatic (25 tok/s -> 32 tok/s on a V100); probably the main bottleneck is still get_guided_decoding_logits_processor.

@nullpointer0xffff

@noamgat here's a profiling run using lm-format-enforcer 0.10.1.

/lib/python3.10/site-packages/lmformatenforcer/integrations/transformers.py

Function: _build_regular_tokens_list at line 58

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    58                                           @profile
    59                                           def _build_regular_tokens_list(tokenizer: PreTrainedTokenizerBase) -> List[Tuple[int, str, bool]]:
    60         1  912794903.0    9e+08      9.5      token_0 = tokenizer.encode("0")[-1]
    61         1       8025.0   8025.0      0.0      regular_tokens = []
    62    128257   28050361.0    218.7      0.3      for token_idx in range(len(tokenizer)):
    63    128256   78294452.0    610.5      0.8          if token_idx in tokenizer.all_special_ids:
    64         2        450.0    225.0      0.0              continue
    65                                                   # We prepend token 0 and skip the first letter of the result to get a space if the token is a start word.
    66    128254 5319568501.0  41476.8     55.3          decoded_after_0 = tokenizer.decode([token_0, token_idx])[1:]
    67    128254 3162992335.0  24661.9     32.9          decoded_regular = tokenizer.decode([token_idx])
    68    128254   56427009.0    440.0      0.6          is_word_start_token = len(decoded_after_0) > len(decoded_regular)
    69    128254   61975079.0    483.2      0.6          regular_tokens.append((token_idx, decoded_after_0, is_word_start_token))
    70         1        240.0    240.0      0.0      return regular_tokens

The two decode calls in the for loop seem to take most of the time. Happy to run further tests if needed.

@noamgat
Contributor

noamgat commented May 7, 2024 via email

@noamgat
Contributor

noamgat commented May 8, 2024

Just clarifying - if possible, start the tokens/s measuring and/or profiling from the second request onwards. While the warm-up time is also something that can be optimized, the post-warmup performance matters much more for real-world use cases. This is true for all guided decoding backends.

@Qubitium
Contributor

Qubitium commented May 8, 2024

@nullpointer0xffff @jens-create I just confirmed that caching of the LMFE tokenizer init (which is very slow) via @lru_cache is working, so _build_regular_tokens_list should never be called after the first request.
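(For anyone following along, the caching pattern being described is essentially the following — an illustrative sketch, not the exact vLLM code:)

```python
# Illustrative sketch of the lru_cache pattern described above: the slow LMFE
# token-list build runs once per tokenizer object and is reused afterwards.
from functools import lru_cache

from lmformatenforcer.integrations.transformers import _build_regular_tokens_list


@lru_cache(maxsize=None)
def cached_regular_tokens(tokenizer):
    # The first call pays the full cost profiled earlier in this thread;
    # subsequent calls with the same tokenizer object return the cached list.
    return _build_regular_tokens_list(tokenizer)
```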

@SaloniGandhi

SaloniGandhi commented May 15, 2024

Maybe we can modify the __call__ method to separate the mask computation from the logits adjustment. This would allow the mask to be computed once and reused. Let me know if this makes sense @simon-mo
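(A rough sketch of that idea — not working vLLM code — caching the mask per FSM state so it is built once and then reused across calls:)

```python
# Rough sketch: compute the -inf mask once per FSM state and reuse it afterwards.
import math
from typing import Dict, List

import torch


class MaskCache:
    def __init__(self) -> None:
        self._masks: Dict[int, torch.Tensor] = {}

    def mask_for(self, fsm_state: int, allowed_tokens: List[int],
                 scores: torch.Tensor) -> torch.Tensor:
        mask = self._masks.get(fsm_state)
        if mask is None:
            # Expensive part: list -> tensor conversion and fill, done once per state.
            mask = torch.full_like(scores, -math.inf)
            idx = torch.tensor(allowed_tokens, dtype=torch.int64, device=scores.device)
            mask.index_fill_(0, idx, 0)
            self._masks[fsm_state] = mask
        return mask

# In __call__ the adjustment then reduces to: scores.add_(cache.mask_for(state, allowed, scores))
```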

@lynkz-matt-psaltis

lynkz-matt-psaltis commented May 22, 2024

Just sharing my experience with this issue - it seems to align with the OP's experience.

Summary: CPU-constrained guidance means that batching can't scale correctly.

Vllm 0.4.2
Outlines: 0.0.34
lm_format_enforcer: 0.10.2
Model: Llama 3 8b instruct
Hardware:

  • Single A100 (80G)
  • AMD EPYC 7V13 (24 cores)

Single request:

Outlines: ~70 tps - CPU 100%
lm_format_enforcer: ~45 tps - CPU 100%
No guidance: ~140 tps

Batched requests:

Outlines: ~70 tps - CPU 100%
lm_format_enforcer: ~45 tps - CPU 100%
No guidance: ~1000 tps

Guided regex and JSON are both affected.

Example guidance:

regex
~~~response\n# Content\\n([.\\W\\w]+)\\n{2}~{3}
json
{"type":"object","properties":{"test":{"type":"string"}},"required":["test"]}

@lynkz-matt-psaltis

Here's line timings for model_executor/guided_decoding/outlines_logits_processors.py

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    41                                               @line_profiler.profile
    42                                               def __call__(self, input_ids: List[int],
    43                                                            scores: torch.Tensor) -> torch.Tensor:
    44                                                   """Use the FSM to bias the logits before sampling the next token."""
    45      2686      22898.6      8.5      0.1          seq_id = hash(tuple(input_ids))
    46
    47      2686       1994.5      0.7      0.0          if len(input_ids) == 0:
    48         3         13.6      4.5      0.0              self.init_state()
    49                                                   else:
    50      2683        953.2      0.4      0.0              last_token = input_ids[-1]
    51      2683      12007.7      4.5      0.0              last_seq_id = hash(tuple(input_ids[:-1]))
    52      5366      15540.9      2.9      0.0              self.fsm_state[seq_id] = self.fsm.next_state(
    53      2683       2226.4      0.8      0.0                  self.fsm_state[last_seq_id], last_token)
    54
    55      2686    2022417.3    752.9      5.2          allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])
    56
    57      5372      83383.0     15.5      0.2          mask = torch.full((scores.shape[-1], ),
    58      2686       3307.1      1.2      0.0                            -math.inf,
    59      2686       1901.2      0.7      0.0                            device=scores.device)
    60      2686   36566141.1  13613.6     94.3          mask[allowed_tokens] = 0
    61      2686      36379.1     13.5      0.1          scores.add_(mask)
    62      2686        794.1      0.3      0.0          return scores

@felixzhu555
Contributor

Based on that timing breakdown, can you try replacing mask[allowed_tokens] = 0 with torch index_fill_, e.g. mask.index_fill_(0, allowed_tokens, 0)?
This might be faster than manually indexing the mask tensor.

@lynkz-matt-psaltis

I've been doing some further perf analysis and breaking things out a bit to try to understand the bottleneck. It doesn't seem to be related to the indexing, but rather to moving the allowed_tokens array around.

cpu first, move to gpu

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    43                                               @line_profiler.profile
    44                                               def __call__(self, input_ids: List[int],
    45                                                            scores: torch.Tensor) -> torch.Tensor:
    46                                                   """Use the FSM to bias the logits before sampling the next token."""
    47      2529      18368.3      7.3      0.1          seq_id = hash(tuple(input_ids))
    48
    49      2529       2418.0      1.0      0.0          if len(input_ids) == 0:
    50         3         22.6      7.5      0.0              self.init_state()
    51                                                   else:
    52      2526        886.1      0.4      0.0              last_token = input_ids[-1]
    53      2526      11457.4      4.5      0.1              last_seq_id = hash(tuple(input_ids[:-1]))
    54      5052      14539.8      2.9      0.1              self.fsm_state[seq_id] = self.fsm.next_state(
    55      2526       1931.4      0.8      0.0                  self.fsm_state[last_seq_id], last_token)
    56
    57      2529    1903376.3    752.6     10.5          allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])
    58      2529   15524244.9   6138.5     85.5          allowed_tokens_tensor = torch.tensor(allowed_tokens, dtype=torch.int32, device='cpu')
    59
    60      2529       3262.6      1.3      0.0          if self.mask is None or self.allowed_tokens_tensor is None:
    61      2529      82721.9     32.7      0.5              self.mask = torch.full_like(scores, -math.inf)
    62                                                   else:
    63                                                       self.mask.fill_(-math.inf)
    64                                                   
    65      2529       4009.6      1.6      0.0          if (allowed_tokens_tensor.device != scores.device):
    66      2529     489064.9    193.4      2.7              allowed_tokens_tensor = allowed_tokens_tensor.to(scores.device)
    67                                                       
    68      2529      39004.4     15.4      0.2          allowed_tokens_tensor = allowed_tokens_tensor.to(torch.int64)
    69                                                   
    70      2529      35650.6     14.1      0.2          self.mask.index_fill_(0, allowed_tokens_tensor, 0)
    71      2529      23729.9      9.4      0.1          scores.add_(self.mask)
    72
    73      2529        630.8      0.2      0.0          return scores

straight to gpu:

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    43                                               @line_profiler.profile
    44                                               def __call__(self, input_ids: List[int],
    45                                                            scores: torch.Tensor) -> torch.Tensor:
    46                                                   """Use the FSM to bias the logits before sampling the next token."""
    47      2252      14057.0      6.2      0.1          seq_id = hash(tuple(input_ids))
    48
    49      2252       1943.8      0.9      0.0          if len(input_ids) == 0:
    50         3         12.1      4.0      0.0              self.init_state()
    51                                                   else:
    52      2249        696.3      0.3      0.0              last_token = input_ids[-1]
    53      2249       9021.9      4.0      0.1              last_seq_id = hash(tuple(input_ids[:-1]))
    54      4498      14201.8      3.2      0.1              self.fsm_state[seq_id] = self.fsm.next_state(
    55      2249       1836.6      0.8      0.0                  self.fsm_state[last_seq_id], last_token)
    56
    57      2252    1692571.2    751.6     10.5          allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])
    58      2252   14277941.7   6340.1     88.3          allowed_tokens_tensor = torch.tensor(allowed_tokens, dtype=torch.int32, device=scores.device)
    59
    60      2252       2582.4      1.1      0.0          if self.mask is None or self.allowed_tokens_tensor is None:
    61      2252      55524.5     24.7      0.3              self.mask = torch.full_like(scores, -math.inf)
    62                                                   else:
    63                                                       self.mask.fill_(-math.inf)
    64                                                   
    65      2252       3560.3      1.6      0.0          if (allowed_tokens_tensor.device != scores.device):
    66                                                       allowed_tokens_tensor = allowed_tokens_tensor.to(scores.device)
    67                                                       
    68      2252      34986.8     15.5      0.2          allowed_tokens_tensor = allowed_tokens_tensor.to(torch.int64)
    69                                                   
    70      2252      32876.8     14.6      0.2          self.mask.index_fill_(0, allowed_tokens_tensor, 0)
    71      2252      22152.8      9.8      0.1          scores.add_(self.mask)
    72
    73      2252        633.6      0.3      0.0          return scores

@lynkz-matt-psaltis

58     12693    9401753.9    740.7     17.2          allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])
59     12693   42707835.5   3364.7     78.1          np_allowed_tokens = np.array(allowed_tokens, dtype=np.int32)
60     12693      73736.7      5.8      0.1          allowed_tokens_tensor = torch.from_numpy(np_allowed_tokens)

Halved the cost by converting to a numpy array first; also tried torch.as_tensor but saw no significant change.

@lynkz-matt-psaltis

Beyond this, I'm not sure I see a way forward without changes to outlines and lm-format-enforcer to provide the information in a more efficient structure than a Python list. Does anyone see any memoization opportunities here to at least reduce the iteration counts?

@jarrelscy

jarrelscy commented May 23, 2024

One thing I think we could do to make it faster is to use the fact that allowed_tokens is usually either almost all of the tokens or very few of them. Currently the mask is always created at -math.inf; when the length of allowed_tokens is more than scores.shape[0]/2, we could instead create the mask at 0 and fill the (much smaller) set of disallowed positions with -math.inf. A sketch of this is below.
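(Sketch of that suggestion, assuming 1-D scores as in the processor profiled above — illustrative only:)

```python
# Sketch of the inverted-mask idea: mask whichever side (allowed vs. banned) is smaller.
import math
from typing import List

import torch


def apply_guided_mask(scores: torch.Tensor, allowed_tokens: List[int]) -> torch.Tensor:
    vocab_size = scores.shape[-1]
    allowed = torch.tensor(allowed_tokens, dtype=torch.int64, device=scores.device)
    if allowed.numel() < vocab_size // 2:
        # Few tokens allowed: start fully banned, open up only the allowed ones.
        mask = torch.full_like(scores, -math.inf)
        mask.index_fill_(0, allowed, 0)
    else:
        # Most tokens allowed: start open, ban only the complement.
        banned = torch.ones(vocab_size, dtype=torch.bool, device=scores.device)
        banned[allowed] = False
        mask = torch.zeros_like(scores)
        mask[banned] = -math.inf
    return scores.add_(mask)
```

Note that this still pays the Python-list-to-tensor conversion cost, which is exactly the concern raised in the next comment.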

@lynkz-matt-psaltis

I went down that same line of thinking - I don't think the timings above support it, however. It's getting the Python list into a Tensor that seems to be 80%+ of the cost per iteration. So, short of data-structure changes upstream, my current thinking is we're left with iteration optimisations - can we avoid going back to fsm.allowed_token_ids in certain situations? Not sure on that yet; still learning how this all fits together.

@JGSweets
Contributor

Are the PRs for this issue currently stalled due to competing priorities?

@lapp0

lapp0 commented Aug 27, 2024

Hi, I would like to ask whether caching the outlines allowed tokens might cause the GPU to OOM under high load, since the cache is unbounded? Or is there another process within vLLM that manages VRAM usage?

For each state in the automaton, outlines stores a tensor with the list of legal token IDs. However, we don't store these tensors on the GPU, so it shouldn't result in a CUDA OOM.

@Jason-CKY
Contributor

I believe this issue has been fixed upstream in outlines in their v0.1.0 release.
Relevant PR: dottxt-ai/outlines#1013

We just need to update requirements-common.txt to pull in the latest version of outlines; the current pin is:

outlines >= 0.0.43, < 0.1

@robcaulk
Contributor

@Jason-CKY, unfortunately, that is not likely the root of the issue.

The current problem is present whether lm-format-enforcer or outlines is used for constrained output in vLLM (#3567 (comment)). This indicates that the problem is in how vLLM is handling the logits processing. Most likely, it is due to the fact that it uses threads instead of processes.

@Jason-CKY
Contributor

Jason-CKY commented Oct 18, 2024

I see. Is there any work on this issue right now? I see from this thread that there is a draft PR #6900 that should help when using outlines, but it seems there hasn't been any progress after the initial PR and the test failure https://buildkite.com/vllm/fastcheck/builds/1264#0190ff17-7643-4664-8108-9d2abc4bf589/192-1046.

Just wondering if anybody is working on it; if not, I'll be happy to help out on this issue with some guidance :)

@robcaulk
Contributor

@Jason-CKY No one appears to be working on it, so your help would be welcomed.

Personally, if I had the time, I would start by using processes instead of threads here

global_thread_pool = concurrent.futures.ThreadPoolExecutor(

My suspicion is that this is the bottleneck.

It may be easier said than done, though, depending on which objects need to be serialized and whether they can be serialized. Further, I am not sure if any objects need to be shared, in which case you would need to introduce a lock. A minimal sketch of the swap is below.
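```python
# Sketch only of the suggestion above: swapping the thread pool for a process pool.
# ProcessPoolExecutor gives true CPU parallelism for the mask computation, but
# everything submitted to it (callable, arguments, return value) must be picklable,
# which is exactly the open question raised here.
import concurrent.futures

global_thread_pool = concurrent.futures.ProcessPoolExecutor(max_workers=2)
```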

@Jason-CKY
Contributor

Jason-CKY commented Oct 20, 2024

Hi, I've done a latency test for JSON guided generation, and the fix from @lapp0 in dottxt-ai/outlines#1013 seems to bring the latency/throughput figures on par with un-guided generation.

From my preliminary investigation there don't seem to be any other bottlenecks slowing down guided generation that would warrant further optimizations (threads vs. processes).

The test script as well as the results can be found here. For information, these results were run on an RTX A5000 GPU, on Llama 3.2-3B-Instruct unquantized.

The latency numbers are averaged over 10 runs; the averages are shown in the table below.

|  | without guided generation | with guided generation |
| --- | --- | --- |
| vllm-openai:latest | batch size 1: 2.0754<br>batch size 10: 3.5338<br>batch size 30: 4.9459 | batch size 1: 1.8324<br>batch size 10: 8.7346<br>batch size 30: 23.5144 |
| with outlines cache patch in vllm | batch size 1: 2.0584<br>batch size 10: 3.8100<br>batch size 30: 5.0310 | batch size 1: 2.4749<br>batch size 10: 1.9299<br>batch size 30: 2.5931 |

The results were run using the vllm/vllm-openai:latest Docker image, with the vllm serve meta-llama/Llama-3.2-3B-Instruct --max-model-len 4096 entrypoint.

@bannsec

bannsec commented Oct 20, 2024

Confused, is that patch in vllm currently?

@Jason-CKY
Contributor

No, that fix was in the outlines library; I patched the outlines library with the fix and compared the results.

The fix was released in outlines v0.1.0, as I mentioned in a previous reply.

@bannsec

bannsec commented Oct 20, 2024

The patch mentioned was merged 4 months ago into outlines, hence my confusion. If we're using the current version of the outlines library, then what's left to patch?

@Jason-CKY
Contributor

The current version of vllm is not using the latest version of outlines. The merge from 4 months ago was not included in a proper release until 2 weeks ago (0.1.0), and the latest released version of vllm is pinned to outlines 0.0.46.

@lynkz-matt-psaltis

My understanding is there is still work on the vllm side to consume the upstream changes.

Notably:

  1. Removal of the duplicate Outlines Logits Processors within the vllm repo.
  2. Calls from outlines_decoding.py should be made directly to the upstream outlines package instead of the duplicate.

These changes were ready in: #6900 - I'd need to refresh that PR against latest main.
The work I haven't gotten to yet is adding a proper regression test for this performance issue against vLLM, specifically against the OpenAI API endpoints.

@Jason-CKY In the test harness you've linked above, while the batch sizes change for each test scenario, how many concurrent requests/threads are being tested at each batch size? From the quick glance I gave it, it wasn't obvious what the request concurrency was. The current performance issue doesn't really manifest until you have multiple concurrent inference requests running within the batch. On our A100 test scenarios, we start to see request duration increase at 3 concurrent requests on a 48-core machine.

@Jason-CKY
Contributor

I was changing the batch size by calling the /chat/completions endpoint and varying the n request parameter. Each run is a single client-side call with n=batch_size. I didn't try separate concurrent requests, so maybe that's why I didn't run into the increase in latencies.

@bannsec

bannsec commented Oct 21, 2024

Yes, concurrency is the issue, not n parameter.

@robcaulk
Contributor

Indeed, concurrency is the problem. Here is a script to reproduce the bug if you want to understand @Jason-CKY :

dottxt-ai/outlines#1011

@Jason-CKY
Contributor

Jason-CKY commented Oct 21, 2024

Thanks for the example! I have written another script that awaits multiple calls at the same time using asyncio.gather, which results in multiple concurrent calls to the vllm server.
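(For reference, the rough shape of such a concurrency test is sketched below; this is a simplified illustration, not the actual script — the URL, model, and schema are placeholders.)

```python
# Minimal sketch of a concurrency benchmark: fire N simultaneous chat-completion
# requests against a vLLM OpenAI-compatible server and time each response.
import asyncio
import time

import httpx

URL = "http://localhost:8000/v1/chat/completions"  # placeholder endpoint
PAYLOAD = {
    "model": "meta-llama/Llama-3.2-3B-Instruct",
    "messages": [{"role": "user", "content": "Return a JSON object with a 'test' field."}],
    "guided_json": {"type": "object", "properties": {"test": {"type": "string"}}, "required": ["test"]},
}


async def one_request(client: httpx.AsyncClient) -> float:
    start = time.perf_counter()
    resp = await client.post(URL, json=PAYLOAD, timeout=300)
    resp.raise_for_status()
    return time.perf_counter() - start


async def main(concurrency: int = 30) -> None:
    async with httpx.AsyncClient() as client:
        times = await asyncio.gather(*(one_request(client) for _ in range(concurrency)))
    print(f"average response time: {sum(times) / len(times):.2f}s")


if __name__ == "__main__":
    asyncio.run(main())
```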

Running the test with 30 concurrent calls, with and without the patch, I got the following results. I ran each permutation over 10 runs and list the response times and average times below:

|  | without guided generation | with guided generation |
| --- | --- | --- |
| vllm-openai:latest | response_times: 3.35, 3.27, 3.48, 3.27, 3.29, 3.29, 3.29, 3.30, 3.29, 3.32<br>average_response_time: 3.31 | response_times: 66.38, 33.01, 35.22, 36.38, 33.52, 34.01, 36.55, 33.05, 35.48, 33.36<br>average_response_time: 37.70 |
| with outlines cache patch in vllm |  | response_times: 42.94, 9.00, 9.66, 9.66, 9.77, 8.68, 9.30, 9.63, 10.14, 9.63<br>average_response_time: 12.84 |

I referenced https://github.com/dottxt-ai/outlines/pull/1013/files#diff-202c3676a40bf3fd70a140e8e4fa2959cb88548cf134a7f809ad50e0f6b4176d to implement the outlines patch. Specifically, I replaced the contents of /usr/local/lib/python3.12/dist-packages/outlines with https://raw.githubusercontent.com/lapp0/outlines/refs/heads/bench-logits-processors/outlines/fsm/guide.py.

Again, I'm running all of this on the vllm/vllm-openai:v0.6.3.post1 image, on a single A5000 GPU.

There is a big improvement in latency, but it is still significantly slower than without guidance. The high latency on the first run is due to the initial caching of FSM states for guidance.

@bannsec

bannsec commented Oct 21, 2024

I would love to have even this level of improvement in the base vllm.

@robcaulk
Contributor

If we are going to bump the library, it might be worth considering using outlines-core instead (https://github.com/dottxt-ai/outlines-core), since it is written in Rust and is more efficient.

@sfc-gh-zhyao

face the same issue

@bodybreaker

Same here

@lynkz-matt-psaltis

As always, a thumbs-up is preferred over "me too" comments :)

@robcaulk do you know if outlines-core is stable and ready for production use?

Just trying to work out, if I find a free minute this weekend, where I should focus my attention :)

@robcaulk
Contributor

@lynkz-matt-psaltis it's a good question. The creators of outlines publicly pinged vLLM when they announced this release (https://www.linkedin.com/posts/activity-7254738693855363072-PHIb).

I suppose we can ping them here to ask: @lapp0 @rlouf, is outlines-core recommended for production use now?

@rlouf

rlouf commented Oct 25, 2024

outlines-core is almost ready for production use; we are waiting for one refactoring PR to be merged. But that's not where the performance hit comes from. outlines is significantly faster than lm-format-enforcer at inference time on every benchmark that we've run, with added latency on the order of a few milliseconds.

My best guess is that vllm uses Outlines' cache function, which uses diskcache to cache indexes on disk. It is possible that the call to read the cache is blocking, and that would explain why the slowdown is only observed in a concurrent setting.

@francescov1

Is there a simple way to try this out with a custom-built vllm image?

@WoutDeRijck

I'm interested in contributing to improving the throughput/latency of guided decoding. Any pointers to specific areas where investigation would be most valuable, or to what the open challenges are now?

@cruzanstx

Any update on this?

@stodoran

Wondering if there has been any progress on this? I'm still observing a roughly 50% increase in latency when using structured outputs with outlines compared to regular inference (which is better than the 10x latency increase I observe using structured outputs with lm-format-enforcer, but that isn't much consolation).

@mgoin
Member

mgoin commented Dec 19, 2024

We have now integrated xgrammar in v0.6.5 as the default backend in supported cases. We also updated outlines to the latest version, which uses its Rust core (outlines-core). Please give your benchmarks another try!
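(A quick way to re-check — sketch only; the explicit "guided_decoding_backend": "xgrammar" override below is an assumption based on this integration, and the model/endpoint are placeholders:)

```python
# Sketch: re-running a guided request and explicitly selecting the new backend.
import requests

payload = {
    "model": "meta-llama/Llama-3.2-3B-Instruct",  # placeholder model
    "messages": [{"role": "user", "content": "Return a JSON object with a 'test' field."}],
    "guided_json": {"type": "object", "properties": {"test": {"type": "string"}}, "required": ["test"]},
    "guided_decoding_backend": "xgrammar",  # assumed backend name for vLLM >= 0.6.5
}
print(requests.post("http://localhost:8000/v1/chat/completions", json=payload).json())
```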
