[Performance]: guided generation is very slow in offline mode #8313

Open
stas00 opened this issue Sep 10, 2024 · 19 comments
Labels: performance (Performance-related issues)

Comments

@stas00
Contributor

stas00 commented Sep 10, 2024

Proposal to improve performance

With a single request in online mode I'm getting:

  • no guided generation: 300 tok/sec
  • outlines: 150 tok/sec (2x slower)
  • lm-format-enforcer: 90 tok/sec (~3x slower)

With offline mode I get:

  • outlines is about 10-20x slower than no guided generation
  • lm-format-enforcer is about 4x faster than outlines (note that it is slower than outlines online)

For online mode I was using this schema:

json_template = {
    "type": "object",
    "properties": {
        "criteria": {"type": "array", "items": {"type": "string"}, "minItems": 1},
        "response": { "type": "string" }
    },
    "required": ["criteria", "response"]
}
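
For reference, a minimal sketch (not my actual benchmark client) of how such a schema can be passed in online mode, assuming vLLM's OpenAI-compatible server and the guided_json field in extra_body; the server URL, API key, model name and prompt below are placeholders:

# Minimal sketch, assuming a vLLM OpenAI-compatible server is running locally.
# The schema is passed per request via the guided_json field in extra_body;
# URL, API key, model name and prompt are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.chat.completions.create(
    model="TinyLlama/TinyLlama-1.1B-Chat-v0.6",
    messages=[{"role": "user", "content": "Evaluate the answer: list criteria and a response."}],
    extra_body={"guided_json": json_template},  # the schema dict defined above
)
print(completion.choices[0].message.content)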

For offline mode I was using an even simpler schema:

{
   "type":"object",
   "properties":{
      "name":{
         "type":"string", "minLength":2, "maxLength":5
      },
      "age":{
         "type":"integer"
      }
   },
   "required":[ "name", "age"]
}

The huge performance hit in offline mode is very strange for both backends.

The 2x slowdown in online mode is already a big impact and pretty bad on its own. Offline mode can tolerate a 2x hit without much trouble since there is no human in the loop, but 10-20x is simply impractical.

vllm==0.6.0 and outlines==0.0.46

stas00 added the performance (Performance-related issues) label on Sep 10, 2024
@Quang-elec44

@stas00 In my experience, guided generation is always slower than normal. I recommend you try sglang instead. Sglang achieves better throughput than vLLM, but the guided generation is still slower.

@stas00
Contributor Author

stas00 commented Sep 10, 2024

Thank you for this suggestion, @Quang-elec44 - I understand that it'll be slower, but it should be marginally slower, not 20x slower. Possibly some problem in the integration?

@robertgshaw2-neuralmagic
Collaborator

> Thank you for this suggestion, @Quang-elec44 - I understand that it'll be slower, but it should be marginally slower, not 20x slower. Possibly some problem in the integration?

@stas00 is this a new issue on v0.6.0?

@robertgshaw2-neuralmagic
Collaborator

@stas00 - looks like someone has just fixed the issue:

@robertgshaw2-neuralmagic
Collaborator

> @stas00 In my experience, guided generation is always slower than normal. I recommend you try sglang instead. Sglang achieves better throughput than vLLM, but the guided generation is still slower.

Hey @Quang-elec44 - this is actually not true as of vllm's current release: https://blog.vllm.ai/2024/09/05/perf-update.html

We also have more performance optimizations that will come out in v0.6.0:

@stas00
Contributor Author

stas00 commented Sep 10, 2024

> @stas00 is this a new issue on v0.6.0?

No, it's the same with older versions, e.g. 0.5.5.

> looks like someone has just fixed the issue:

Robert, I have already tried it to no avail.

I'm working on a reproducible test case - will share as soon as I have it.

@robertgshaw2-neuralmagic
Collaborator

Thanks @stas00 - @simon-mo can you copy in the folks who implemented the guided decoding?

@Lap1n

Lap1n commented Sep 10, 2024

@stas00 Have you tried (or could you try) vllm v0.5.2 to see if you observe the same performance issue?

In my investigation, I found a regression in performance from v0.5.2 -> 0.5.3 which this PR fixes.

However, in my benchmarks, there seems to be another significant performance regression in guided generation with Outlines from 0.5.3.post1 to 0.5.4 that I have not investigated yet.

Note that I only tested in online mode.

@stas00
Contributor Author

stas00 commented Sep 10, 2024

@Lap1n, I can't try v0.5.2 since it didn't support guided generation in offline mode.

But I think the problem here is something entirely different; I'm trying to dig to the root of it.

As I write the offline repro scripts with TinyLlama-1.1B-Chat-v0.6, vllm is only about 10% slower with outlines than without it once the FSM has been compiled, which is absolutely normal and expected. So something else is going on; I will update once I have a clearer picture.

It's possible that there are multiple issues here, and that somehow the big overhead is model-dependent (even though in theory the type or size of the model shouldn't matter at all overhead-wise).

@stas00
Contributor Author

stas00 commented Sep 10, 2024

@robertgshaw2-neuralmagic, do you know why the outlines cache doesn't get used and it instead rebuilds the same FSM? E.g. here is one of the repro scripts I created:

import time

from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

model_name_or_path = "TinyLlama/TinyLlama-1.1B-Chat-v0.6"

schema = '{ "type": "object", "properties": { "name": { "type": "string", "minLength": 2,  "maxLength": 5 }, "age": { "type": "integer"} }, "required": ["name", "age"] }'

model = LLM(
    model=model_name_or_path,
    tokenizer=model_name_or_path,
    tokenizer_mode="auto",
    tensor_parallel_size=1,
    trust_remote_code=True,
    dtype="bfloat16",
    gpu_memory_utilization=0.8,
    guided_decoding_backend="outlines",
)

prompt = f"Give an example of a person's profile that fits this JSON schema: {schema}"

sampling_params = SamplingParams(
    temperature=0.0,
    seed=42,
    max_tokens=2048,
)
kwargs = dict(
    sampling_params=sampling_params,
    guided_options_request=dict(
        guided_json=schema,
    ),
)

# warmup
output = model.generate(prompt, **kwargs)

response = ''
start_time = time.time()
for i in range(5):
    output = model.generate(prompt, **kwargs)
    response += output[0].outputs[0].text
end_time = time.time()

print(response)

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

response_tokens = tokenizer.encode(response)
response_tokens_len = len(response_tokens)
#print(response_tokens)

duration = end_time - start_time
print(f"decode had {response_tokens_len} tokens")
print(f"decode latency: {duration:0.2f}secs")
print(f"decode throughput: {response_tokens_len/duration:0.2f} tok/sec")

Without warmup we get about 6 tok/sec due to the FSM setup overhead; with warmup it's 166 tok/sec.

Re-running the script a second time without warmup repeats the same compilation - shouldn't the FSM be pulled from the cache and reused, rather than recompiled?

@Lap1n

Lap1n commented Sep 10, 2024

@stas00 I also hit the same issue; it seems to be a problem with the format of the key that Outlines uses to store cached data. I opened a PR a few days ago on Outlines' repo to fix this.
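
To illustrate the general idea (a sketch of the failure mode, not Outlines' actual code): if the cache key embeds an object whose repr changes between processes, a persistent on-disk cache can never hit across restarts, while a key derived only from strings stays stable:

# Hedged illustration of the cache-key problem; names are made up and this is
# not Outlines' actual implementation.
import hashlib

class FakeTokenizer:
    """Stand-in object; its default repr embeds a memory address."""

def unstable_key(schema: str, tokenizer: object) -> str:
    # Bad: repr(tokenizer) looks like "<FakeTokenizer object at 0x7f...>",
    # which differs on every startup, so a persistent cache never hits.
    return f"{schema}-{tokenizer!r}"

def stable_key(schema: str, tokenizer_name: str) -> str:
    # Good: built only from strings, identical across processes.
    return hashlib.sha256(f"{schema}-{tokenizer_name}".encode()).hexdigest()

print(unstable_key('{"type": "object"}', FakeTokenizer()))
print(stable_key('{"type": "object"}', "TinyLlama/TinyLlama-1.1B-Chat-v0.6"))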

@stas00
Contributor Author

stas00 commented Sep 10, 2024

So I'm waiting for my colleague to give me a repro case so that I can narrow it down for him.

Meanwhile, why is this pure outlines script so slow?

import outlines
import time
from transformers import AutoTokenizer
import json

model_name_or_path = "TinyLlama/TinyLlama-1.1B-Chat-v0.6"

schema = '{ "type": "object", "properties": { "name": { "type": "string", "minLength": 2,  "maxLength": 5 }, "age": { "type": "integer"} }, "required": ["name", "age"] }'

model = outlines.models.transformers(model_name_or_path)
generator = outlines.generate.json(model, schema)

prompt = f"Give an example of a person's profile that fits this JSON schema: {schema}"

# warmup
response = generator(prompt)

# real
response = ''
start_time = time.time()
for i in range(5):
    response += json.dumps(generator(prompt))
end_time = time.time()
print(response)

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

response_tokens = tokenizer.encode(response)
response_tokens_len = len(response_tokens)
print(response_tokens)

duration = end_time - start_time
print(f"decode had {response_tokens_len} tokens")
print(f"decode latency: {duration:0.2f}secs")
print(f"decode throughput: {response_tokens_len/duration:0.2f} tok/sec")

I wrote it first to check that standalone outlines was fast and to use it as a baseline before doing the exact same generation with vllm+outlines, as I shared here:
#8313 (comment)

but actually I discovered that it's very slow!

On A100:

$ python outlines-direct.py
{"name": "John", "age": 29}{"name": "John", "age": 30}{"name": "John", "age": 50}{"name": "Dan", "age": 20}{"name": "John", "age": 30}
decode had 70 tokens
decode latency: 15.89secs
decode throughput: 4.41 tok/sec

$ python outlines-vllm.py 
{"name":"John","age":25}{"name":"John","age":25}{"name":"John","age":25}{"name":"John","age":25}{"name":"John","age":25}
decode had 51 tokens
decode latency: 0.28secs
decode throughput: 181.05 tok/sec

That's a 40x difference! Perhaps this is what we are hitting, only in reverse: here vllm+outlines is 40x faster than outlines on its own.

edit: I found the issue with the direct outlines use: I had to put the model on CUDA; it was 40x slower on CPU:

model = outlines.models.transformers(model_name_or_path, device='cuda:0')

now it's 2x slower than vllm integration:

decode had 67 tokens
decode latency: 0.75secs
decode throughput: 89.89 tok/sec

not sure why this time.

@stas00
Contributor Author

stas00 commented Sep 10, 2024

The other mismatch I noticed is that vllm strips the spaces between JSON elements:

# outlines:
{"name": "John", "age": 29}{"name": "John", "age": 30}

# vllm:
{"name":"John","age":25}{"name":"John","age":25}

I guess that saves a few tokens from needing to be generated w/o compromising the output structure - nice!
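
If I read outlines right, the knob that controls this is the whitespace_pattern argument of build_regex_from_schema (the same helper vllm calls, see the next comment); a quick sketch to compare the generated regexes, assuming outlines 0.0.46:

# Hedged sketch, assuming outlines 0.0.46: whitespace_pattern controls how much
# whitespace the guided regex allows between JSON tokens; an empty pattern
# should force the compact form.
from outlines.fsm.json_schema import build_regex_from_schema

schema = '{ "type": "object", "properties": { "name": { "type": "string" }, "age": { "type": "integer" } }, "required": ["name", "age"] }'

regex_default = build_regex_from_schema(schema)                         # library default whitespace
regex_compact = build_regex_from_schema(schema, whitespace_pattern="")  # no whitespace between tokens

print(regex_default)
print(regex_compact)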

@stas00
Contributor Author

stas00 commented Sep 10, 2024

OK, I see that vllm isn't using the outlines.generate.json generator but the regex engine instead, which has multiple issues:

  1. Why does it recompile the regex on every request?

regex_string = build_regex_from_schema(schema_str, whitespace_pattern)

This should be cached for subsequent requests within the same run! (A minimal caching sketch is at the end of this comment.)

To see that it recompiles, run my script here: #8313 (comment)
and add a print on that line 142; you will see it gets called 6 times for 6 requests.

  2. This leads to vllm+outlines failing to complete the last characters of the JSON when max_tokens isn't long enough, whereas outlines.generate.json always seems to return the correct (complete) JSON.

The problem here is that if the generator runs out of max_tokens before the JSON structure is closed - i.e. missing, say, the very last } token - vllm will just return broken JSON. So if a regex is used, vllm has to know to back off and force the JSON closure before it runs out of max_tokens. Currently I have to retry multiple times before all prompts generate the correct structure, which sort of defeats the purpose of guided generation since it takes much, much longer.

edit: the problem with repro code is now defined here: #8350
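
The kind of per-process caching item 1 asks for could look roughly like this (a minimal sketch keyed on the schema string and whitespace pattern, both plain strings; an illustration of the idea, not vllm's actual fix):

# Minimal caching sketch, not vllm's actual fix: memoize the schema -> regex
# conversion so repeated requests with the same schema skip the rebuild.
from functools import lru_cache
from typing import Optional

from outlines.fsm.json_schema import build_regex_from_schema

@lru_cache(maxsize=128)
def cached_regex_from_schema(schema_str: str, whitespace_pattern: Optional[str] = None) -> str:
    # Both arguments are hashable strings (or None), so they form a stable key.
    return build_regex_from_schema(schema_str, whitespace_pattern)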

@robertgshaw2-neuralmagic
Collaborator

@stas00 - Thank you for the detailed analysis here!

QQ - are you planning to open up a PR to fix this? We would definitely appreciate a contribution if you have the bandwidth

@stas00
Contributor Author

stas00 commented Sep 11, 2024

> QQ - are you planning to open up a PR to fix this? We would definitely appreciate a contribution if you have the bandwidth

The summary of things so far:

  1. One problem proved to be repeated recomputation of the FSM in outlines because the cache was being ignored in the vllm integration. @Lap1n has proposed a PR to fix it here: Make Outlines' cache reusable across startup by making the cache's key as string (dottxt-ai/outlines#1129)

  2. The other issue is the re-compilation of the regex on each request - I proposed adding it to this other PR by @Lap1n: [Bugfix] Reenable LRU cache on Outlines' guide getters #8308 (comment) and opened an issue: [Performance]: JSONLogitsProcessor repeats the same build_regex_from_schema again and again #8383

  3. The truncated JSON is being actively discussed in this issue: [Bug]: guided generation can't always finish generating the requested structure #8350

For the bigger issue that started this report, I'm blocked on my colleague providing a larger repro test case so that I can verify it and, if confirmed, reduce it to a small repro; please bear with me until that comes through.

@Quang-elec44

> > @stas00 In my experience, guided generation is always slower than normal. I recommend you try sglang instead. Sglang achieves better throughput than vLLM, but the guided generation is still slower.
>
> Hey @Quang-elec44 - this is actually not true as of vllm's current release: https://blog.vllm.ai/2024/09/05/perf-update.html
>
> We also have more performance optimizations that will come out in v0.6.0:

Thanks for the information. I know that vllm==0.6.0 made strong progress compared to sglang==0.3.0. However, in my experiments, sglang still achieves better throughput (my GPUs are A30). I may spend some time running the benchmark on my machine to take a closer look.

@aabbccddwasd

any progress?

@dbuades

dbuades commented Oct 29, 2024

We are experiencing the same issues on our side; guided generation is a great feature, but it is very slow in offline mode.
