fix(llama): buffer tokens until valid UTF-8 #122
Conversation
Nice, this seems to fix my problems. I can't test with Vicuna because it's GGJT.

我会讲你一个关于狐狸的故事 ("I'll tell you a story about a fox"): Once upon a time, there was an enchanted village of Kitsune. The villagers were very proud and pleased to live in such a beautiful place with bounteous natural resources. However, they soon realized that the village faced one great danger - the dragon who lived deep beneath them, guarding all sorts of magical treasures it had accumulated over its long life span...

The prompt is the bold part. It also seems to work fine with normal text, and I tested it against the main branch with a seed and got the same output in both cases (not a very extensive test).

I wouldn't really worry about performance too much for this, since hardly anyone is generating more than 10 tokens a second and the context limit is 2048, so the effect is going to be pretty insignificant. If you actually cared about allocations, probably the best way would be to just preserve the buffer: have the callback take the completed token by reference once it's ready. That way both buffers only need to be allocated once and live for the length of the session or whatever. A rough sketch of that idea follows.
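A minimal sketch of the preserved-buffer idea, assuming the callback borrows the completed text rather than owning it; the names here are illustrative, not the crate's actual API:

```rust
// Sketch only: both buffers are allocated once and reused for the whole
// session, so no per-token allocation happens.
fn run_session<F: FnMut(&str)>(tokens: impl Iterator<Item = Vec<u8>>, mut callback: F) {
    let mut pending: Vec<u8> = Vec::new(); // raw bytes awaiting a complete codepoint
    let mut completed = String::new();     // reused scratch for finished text
    for token_bytes in tokens {
        pending.extend_from_slice(&token_bytes);
        // Only flush once the pending bytes form valid UTF-8; a codepoint
        // truncated at the end just stays buffered for the next token.
        if let Ok(s) = std::str::from_utf8(&pending) {
            completed.clear();
            completed.push_str(s);
            pending.clear();
            callback(&completed);
        }
    }
}
```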
Eh, I'm not so worried about the allocations as much as I am about cache coherency. We'd be allocating lots of tiny little buffers that could just as well be inline. You might be right, though; we can figure that out later.
Same problem with Vicuna (using #114)
🤦 the model isn't trained to speak Unicode codepoints coherently
Are you saying it's worse than it was originally? The version I tried at least seemed to do a reasonable job with Mandarin; not sure about Japanese. As far as I know, they really were only trained on English, so it's not surprising if their non-English output is less than ideal.
I haven't tried this patch yet.
Not the point. The model was definitely rewarded for partial codepoints during training.
I found a fix. I set the logits of invalid tokens to 0.0. Here's Vicuna speaking fluently.
Better fix:
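A rough sketch of the logit-masking idea under discussion, assuming a hypothetical `vocab_token_bytes` table of each token's raw bytes. One plausible refinement over setting logits to 0.0 is masking with `f32::NEG_INFINITY`, since a logit of 0.0 still leaves the token a nonzero probability after softmax:

```rust
// Sketch only: ban vocabulary entries whose bytes are not valid
// standalone UTF-8, so the sampler can never select them.
// `vocab_token_bytes` is a hypothetical per-token byte table.
fn mask_invalid_utf8_tokens(logits: &mut [f32], vocab_token_bytes: &[Vec<u8>]) {
    for (logit, bytes) in logits.iter_mut().zip(vocab_token_bytes) {
        if std::str::from_utf8(bytes).is_err() {
            // NEG_INFINITY yields probability exactly 0 after softmax;
            // a raw logit of 0.0 does not eliminate the token.
            *logit = f32::NEG_INFINITY;
        }
    }
}
```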
Which ones? A token might be invalid individually but get combined with other tokens to form a valid Unicode character. So if you just set them all to 0.0, you'll prevent it from expressing any Unicode characters whose components aren't all valid individually.
All of them.
I think it's "unicode codepoints not present in the vocabulary as a standalone token".
Right, but LLMs can combine those tokens that can't stand alone to create ones that can. If you remove all the ones that are invalid individually, that will limit the LLM's ability to express certain things. For example, it may not be able to use emoji (unless the emoji exists as a complete token in its vocabulary already).
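For instance, a single emoji spans several bytes, none of which is valid UTF-8 on its own — a self-contained illustration of the byte math, not code from the patch:

```rust
fn main() {
    let emoji = "🤗";
    let bytes = emoji.as_bytes(); // [0xF0, 0x9F, 0xA4, 0x97]
    // Each byte alone is invalid UTF-8...
    for b in bytes {
        assert!(std::str::from_utf8(&[*b]).is_err());
    }
    // ...but together they form a valid codepoint.
    assert_eq!(std::str::from_utf8(bytes).unwrap(), "🤗");
}
```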
That's what this pull is intending to fix. Or do you mean it doesn't work even with this pull? |
Sorry. This patch works for me. I've merged this in my repo as branch |
This looks great! 😄 Only one big comment w.r.t. error recovery but other than that we should be good to go.
As discussed on Discord and in #11.
This switches the internal representation of tokens over to raw bytes, and buffers tokens until they form valid UTF-8 in `inference_with_prompt` (a sketch of the buffering follows below).

Open questions:

- Should we use `smallvec` or similar for tokens? We're going to be making a lot of unnecessary tiny allocations as-is.
- `FnMut` as a bound is OK, right?
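A minimal sketch of the buffering behaviour described above; the names are illustrative, and the real logic lives in `inference_with_prompt`:

```rust
/// Sketch: accumulate raw token bytes and only surface them once they
/// decode as valid UTF-8.
struct TokenUtf8Buffer {
    buf: Vec<u8>,
}

impl TokenUtf8Buffer {
    fn new() -> Self {
        Self { buf: Vec::new() }
    }

    /// Push one token's raw bytes. Returns `Some(text)` once the buffered
    /// bytes form complete UTF-8, `None` while a codepoint is still partial.
    fn push(&mut self, token_bytes: &[u8]) -> Option<String> {
        self.buf.extend_from_slice(token_bytes);
        match std::str::from_utf8(&self.buf) {
            Ok(s) => {
                let text = s.to_owned();
                self.buf.clear();
                Some(text)
            }
            // `error_len() == None` means the bytes are a valid prefix cut
            // off mid-codepoint: keep buffering.
            Err(e) if e.error_len().is_none() => None,
            // Genuinely invalid bytes: recover with replacement characters
            // rather than stalling the stream forever.
            Err(_) => {
                let text = String::from_utf8_lossy(&self.buf).into_owned();
                self.buf.clear();
                Some(text)
            }
        }
    }
}
```

The `FnMut` bound mentioned above fits naturally with this shape: the inference loop calls `push` once per token and invokes the caller's callback only when `Some` text comes back.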