[Bug]: guided generation can't always finish generating the requested structure #8350
Comments
If I manipulate the schema to swap the order of
it then works much better:
6/10 success, as compared to 2/10 success with the original schema. So even the order of items in the schema can make a difference. And as I mentioned in the OP, if I make the schema very strict by adding
I get 100% success:
which works fine for a contrived example or a test, but it won't work well in the general case.
Another thought - the model has no clue how many tokens it can use to build the output other than from training experience. If it was taught to produce json in 512 tokens and max_new_tokens is 512, it might have a predictable "canvas" size, but since the prompt is of varying length, how could the model know when to wrap up the requested structure? So perhaps structure-friendly models need some signal during training that tells them how much space they have for each particular prompt. Or perhaps it could be instructed accordingly in the prompt?
@stas00 one idea that might help - you could additionally pass in a custom LogitsProcessor that increasingly boosts the scores of tokens that end in a terminating json character. You would probably want to generate this list of tokens up-front based on the vocab (there may be other tokens whose string representation ends in one of those characters).
That's a great idea, Nick! So I tried a simple logit processor that promotes a select few tokens to the top towards the end of the context window, and it works! The POC logit processor:
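(The POC snippet itself isn't reproduced here; a minimal sketch of the idea, assuming an HF transformers-style LogitsProcessor - the class name, boost value, and ramp length are illustrative, not the original code:)

```python
from transformers import LogitsProcessor


class JsonClosingBooster(LogitsProcessor):
    """Boost the scores of tokens that can close a json structure once the
    generation budget is nearly exhausted. Illustrative sketch, not the POC."""

    def __init__(self, closing_token_ids, prompt_len, max_new_tokens,
                 boost=100.0, ramp=3):
        self.closing_token_ids = closing_token_ids  # e.g. ids for '"' and '}'
        self.prompt_len = prompt_len
        self.max_new_tokens = max_new_tokens
        self.boost = boost
        self.ramp = ramp  # start boosting this many tokens before the limit

    def __call__(self, input_ids, scores):
        generated = input_ids.shape[1] - self.prompt_len
        remaining = self.max_new_tokens - generated
        if remaining <= self.ramp:
            scores[:, self.closing_token_ids] += self.boost
        return scores
```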
end-to-end code:
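(The end-to-end script isn't shown either; roughly, wiring such a processor into generate() could look like the following - the model name, prompt, and closing strings are placeholders, and it reuses the JsonClosingBooster sketched above:)

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList

model_name = "gpt2"  # placeholder; the original model isn't named here
tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

prompt = 'Return a json object of the form {"key": "some value"}:\n'
inputs = tok(prompt, return_tensors="pt")
max_new_tokens = 30

# ids of tokens whose text is exactly one of the closing strings
closing_ids = []
for s in ['"', '}', '"}']:
    ids = tok.encode(s, add_special_tokens=False)
    if len(ids) == 1:
        closing_ids.append(ids[0])

# JsonClosingBooster is the sketch class from the previous snippet
booster = JsonClosingBooster(
    closing_ids, prompt_len=inputs.input_ids.shape[1], max_new_tokens=max_new_tokens
)
out = model.generate(
    **inputs,
    max_new_tokens=max_new_tokens,
    logits_processor=LogitsProcessorList([booster]),
)
print(tok.decode(out[0][inputs.input_ids.shape[1]:]))
```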
running it:
all valid json endings! albeit the algorithm needs some polishing not to have
ok, this seems to be quite automated:
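(The automated version isn't reproduced above; one plausible ingredient, purely illustrative, is a scan over the vocab for tokens whose decoded text ends with a character that can close a json value:)

```python
def find_closing_token_ids(tokenizer, closing_chars='"}]'):
    """Return ids of all vocab tokens whose decoded text ends with one of
    the characters that can terminate a json string/object/array."""
    ids = []
    for token_id in range(len(tokenizer)):
        text = tokenizer.decode([token_id])
        if text and text[-1] in closing_chars:
            ids.append(token_id)
    return ids
```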
full code:
output:
Do you think I still need to support a multi-char ending token? Anything else I'm missing to generalize this solution? I think the input from the user will be the ending of the schema, as in this example.
hmm, running bigger batches, it's still failing at times, so it's not foolproof. e.g.:
so I suppose I have to be even more precise and not promote all ending tokens at once, but one at a time.
@stas00 that's great that it "worked"! My hunch is that you could improve the quality of the outputs and have it work better for more general json cases by doing some of the other things I mentioned. What if it happens to be in a json array or more nested objects/arrays rather than a string? Rather than "forcing" it to produce these end chars when there are only a couple of slots left, have a longer "ramp down" (say 10-20 tokens, but that is a very arbitrary guess), where you increasingly boost the score of ending-type tokens over this range. I also think that generating a larger list of such tokens would help too... rather than encoding those chars, scan the entire vocab for tokens which end with them. And it probably wouldn't harm to include EOS in that list (though maybe that won't make much difference in this case). But generalizing this approach to arbitrary non-json schemas/regexes may need a bit more thought :)
I'm a bit surprised that this was generated since it's not valid json. But using a more complete list of valid token ids again might help.
OK, so I switched to making my own guided generation for the last few tokens, where I prescribe the one exact character to choose:
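(That code isn't included above; a sketch of what "prescribe the one exact character" could look like as a processor that overrides the last few steps - the forced id sequence is an assumption, e.g. the ids for '"' then '}':)

```python
import torch
from transformers import LogitsProcessor


class ForcedClosingProcessor(LogitsProcessor):
    """Force one exact token per step over the last len(forced_ids) steps of
    the budget, so the structure always gets closed. Illustrative sketch only."""

    def __init__(self, forced_ids, prompt_len, max_new_tokens):
        self.forced_ids = forced_ids          # e.g. [id_of('"'), id_of('}')], in order
        self.prompt_len = prompt_len
        self.max_new_tokens = max_new_tokens

    def __call__(self, input_ids, scores):
        generated = input_ids.shape[1] - self.prompt_len
        remaining = self.max_new_tokens - generated
        if remaining <= len(self.forced_ids):
            forced = self.forced_ids[-remaining]
            # mask out everything except the prescribed token for this step
            new_scores = torch.full_like(scores, float("-inf"))
            new_scores[:, forced] = 0.0
            return new_scores
        return scores
```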
seems to work well now. I added a json validation at the end, as I was missing things when manually inspecting the output
output:
It's a valid JSON in the sense that it's inside a string. But if that's the way to go then I don't know how to apply your suggestion of having multiple tokens - e.g. tokens whose text merely ends with the closing characters. So my solution isn't general to other json schemas and will require the user to input the exact ending characters - and ideally we would want to derive this automatically from the schema.
The main problem with this solution is that the generation is still chopped off w.r.t. the contents of the strings - it'd be nice to be able to signal to the model to wrap the sentences up. You can see what I mean:
Ah, my bad - I was thinking that braces need escaping, but that's not the case.
This is where I thought that giving it a longer period to wrap up, and increasing the boosting factor over that time, may help it finish more gracefully - e.g. boosted enough that it would choose the end string when a sentence comes to a natural end rather than starting a new sentence. And having this be an increasing multiplicative factor rather than just adding 100 as you're doing now. I think this may work for a general json schema, because by boosting these kinds of tokens you're encouraging it to close the current string/list/object, and that will happen repeatedly. I'm sure that having this integrated into the guiding logic itself would be best, but I'm guessing that would be quite a bit more involved.
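(To make the ramped, multiplicative variant concrete, a sketch - the ramp length and maximum factor are arbitrary guesses, as noted above; adding log(factor) to a logit is equivalent to multiplying that token's unnormalized probability by factor:)

```python
import math

from transformers import LogitsProcessor


class RampedClosingBooster(LogitsProcessor):
    """Increasingly boost closing-type tokens over the last `ramp` steps of
    the budget instead of adding a flat constant. Illustrative sketch only."""

    def __init__(self, closing_token_ids, prompt_len, max_new_tokens,
                 ramp=20, max_factor=8.0):
        self.closing_token_ids = closing_token_ids
        self.prompt_len = prompt_len
        self.max_new_tokens = max_new_tokens
        self.ramp = ramp
        self.max_factor = max_factor

    def __call__(self, input_ids, scores):
        generated = input_ids.shape[1] - self.prompt_len
        remaining = self.max_new_tokens - generated
        if remaining <= self.ramp:
            # factor grows from ~1 at the start of the ramp to max_factor at the end
            progress = 1.0 - remaining / self.ramp
            factor = 1.0 + progress * (self.max_factor - 1.0)
            scores[:, self.closing_token_ids] += math.log(factor)
        return scores
```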
I agree. Following your suggestion I wrote a hack that enforces a valid json at the cost of an abrupt ending. Surely the situation is extreme because I'm using an extremely short seqlen in my POC, and in the general case it probably won't be an issue 99% of the time. Additionally, making the schema more strict would naturally aid the model in doing the right thing w/o coercion. As you're saying, all this work should be done either by vllm or, even better, by the backends - clearly this calls for some smart algorithm based on many use cases. Should I take this next to, say, the "outlines" Issues and see if they would see it as a problem they would want to solve? After all, they promise
it fails to produce valid json for many of the tries. But the problem is that
As Mihai Balint mentioned on twitter,
but it's slower than
I was using vllm v0.5.0.post1 and guided generation was working great. I upgraded to vllm v0.6.2 and the only response I get is { " - sometimes, when adding truncate_prompt_tokens=30, I get the initial character of my template: { "r. I tried changing max_tokens and many other parameters. Nothing works. I think around v0.5.2, outlines was updated to 0.0.46; v0.5.0.post1 used outlines 0.0.38. I don't know if that update broke it or some other change in vllm.
so we ended up using
So it appears that guided generation returns the requested structure (like json) only if the model has an unlimited number of tokens it can generate; otherwise it very often fails to close the structure. E.g. with a simple
{ "key": "value" }
json schema, if the new tokens are limited, it'll often return
{ "key": "value
I understand why this is happening - the guiding can only guide when it has a subset of legal tokens to be used next, but if any token is legal it can exhaust the full max_length, and by the time it discovers the structure is unfinished it's too late to wrap it up.
Here is how to reproduce this problem:
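(The repro script itself is omitted above; as a rough sketch, a repro against a vLLM OpenAI-compatible server using the guided_json extra parameter could look like this - the server URL, model name, prompt, and schema are placeholders:)

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

schema = {
    "type": "object",
    "properties": {"key": {"type": "string"}},
    "required": ["key"],
}

for i in range(10):
    resp = client.chat.completions.create(
        model="meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder
        messages=[{"role": "user", "content": "Describe your day as json."}],
        max_tokens=20,  # deliberately tight, so the structure may not get closed
        extra_body={"guided_json": schema},
    )
    print(repr(resp.choices[0].message.content))
```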
gives:
It's quite obvious what the problem is here.
With the lm-format-enforcer backend things are even worse.
I then went to the origin and re-wrote the same code using the outlines json generator and the problem is the same. It gives:
as you can see, it succeeds on the first request and fails to generate valid json on a second request.
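(The re-written outlines code isn't included above; a minimal sketch of a direct outlines repro - model name, prompt, and schema are placeholders - could be along these lines:)

```python
import json

import outlines

schema = json.dumps({
    "type": "object",
    "properties": {"key": {"type": "string"}},
    "required": ["key"],
})

model = outlines.models.transformers("mistralai/Mistral-7B-Instruct-v0.2")  # placeholder
generator = outlines.generate.json(model, schema)

for i in range(10):
    # with a tight token budget the json is often cut off before it is closed
    result = generator("Describe your day as json.", max_tokens=20)
    print(result)
```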
Making max_tokens=50 gives more completed jsons, but some still fail.
While this example is contrived to be simple to understand, this problem can occur at any max_tokens if the greedy any-character regex stage catches the model at the boundary of max_tokens.
Currently the workaround I have been using is to make the schema more strict and define minLength and maxItems - but obviously this makes the generation far from ideal.
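(For illustration, "more strict" means constraints along these lines - the exact values are made up:)

```json
{
  "type": "object",
  "properties": {
    "key": {"type": "string", "minLength": 3, "maxLength": 40},
    "items": {"type": "array", "items": {"type": "string"}, "maxItems": 3}
  },
  "required": ["key"]
}
```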
I also added a retry mechanism, which is very inefficient at the moment since it re-tries the whole thing. The efficient retry would be to chop off some of the ending and feed the now-longer prompt, which includes most of the generated output, back to the model, and only have it generate a different ending which would hopefully finish the structure.
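(A rough sketch of that more efficient retry - generate_fn is a placeholder for whatever backend call produces the completion, and the chop size and attempt count are arbitrary:)

```python
import json


def retry_with_truncation(generate_fn, prompt, max_tokens, chop=10, attempts=3):
    """Instead of regenerating everything, chop off the tail of an invalid
    completion and ask the model to continue from the kept prefix, so only
    the ending is regenerated. Illustrative sketch only."""
    text = generate_fn(prompt, max_tokens)
    for _ in range(attempts):
        try:
            json.loads(text)
            return text                      # already valid json
        except json.JSONDecodeError:
            kept = text[:-chop] if len(text) > chop else ""
            # the prompt now includes most of the previous output; only a
            # short ending needs to be generated
            text = kept + generate_fn(prompt + kept, chop + 5)
    return text
```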
Since the problem appears to be closing the structure, my feeling is that the correct solution is to identify the closing pattern, generate with max_tokens minus the length of the closing structure, and then somehow append the closure by changing that regex to fast-forward to it.
I realize the problem comes from the back-ends but perhaps we have more interested people here to discuss the correct algorithm to resolve this and then we could share the solution with the back-ends. At the very least it'd be good to document that currently guided generation in vllm is not guaranteed to work.
Using latest vllm==0.6.0 and outlines==0.0.46 here.