I thought of a way to speed up inference by using batches. This assumes that you can run a batch of 2 much faster than you can run 2 separate passes, so it will work best on GPUs with a lot of compute cores or on multi-GPU setups. The algorithm scales: the more computing power (more GPUs) you have, the faster it goes.
First create a dictionary that gives the most common token to follow each particular token.
e.g. the most common token to follow 'there' might be 'was'.
You could probably get this data by just going through every token with a window of 1, recording the most likely next token for each one, and storing the results in a dictionary.
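A minimal sketch of building such a lookup table, assuming `corpus_tokens` is simply a list of token strings (in practice you would run this over the token ids of a large corpus):

```python
from collections import Counter, defaultdict

def build_next_token_table(corpus_tokens):
    # Slide a window of 1: count which token follows each token.
    counts = defaultdict(Counter)
    for cur, nxt in zip(corpus_tokens, corpus_tokens[1:]):
        counts[cur][nxt] += 1
    # Keep only the single most common successor for each token.
    return {tok: followers.most_common(1)[0][0] for tok, followers in counts.items()}

table = build_next_token_table(["Once", "upon", "a", "time", "there", "was", "a", "dog"])
print(table["there"])  # -> "was"
```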
Say your tokens are this:
[Once, upon, a, time, there]
Then you arrange them as a batch of two like this. In the second row, you simply append the next token guessed from your dictionary. (In this case your dictionary says that the most common word to follow 'there' is 'was'.)
[ _, Once, upon, a, time, there]
[Once, upon, a, time, there, was]
So now, if the output is this:
[Once, upon, a, time, there, was]
[upon, a, time, there, was, a]
then it means you have got two tokens for the price of one: [was, a]. I'm not sure what percentage of the time you will get lucky like this. You might only do a double batch if you are fairly certain of the next word(s), and you can always do bigger batches if you are less certain, or even guess several words ahead.
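Here is a rough sketch of that accept/reject step, assuming `model_next_tokens(batch)` is a hypothetical function that runs one forward pass over a batch and returns the greedy next token for each row (padding the shorter row, like the `_` above, is left out for clarity):

```python
def step_with_guess(context, table, model_next_tokens):
    guess = table.get(context[-1])          # dictionary guess for the next token
    if guess is None:
        return [model_next_tokens([context])[0]]   # no guess available: plain single step

    batch = [context, context + [guess]]    # row 0 verifies the guess, row 1 looks one token ahead
    pred_for_context, pred_after_guess = model_next_tokens(batch)
    if pred_for_context == guess:
        # Guess confirmed: we get the guessed token *and* the token after it.
        return [guess, pred_after_guess]
    # Guess wrong: keep only the model's real prediction and discard row 1.
    return [pred_for_context]
```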
Thus, with dictionary lookups and guessing ahead, you might be able to speed up inference by maybe a factor of two!
This is the simplest way; a more complicated approach would be to train a very small neural network (or use the same NN but with a very small window) to guess the next word before running the full neural network. This means that if the small NN guesses correctly, you skip ahead several tokens! 🚀
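A sketch of what guessing several tokens ahead could look like, assuming `draft_next` (the cheap guesser) and `model_next_tokens` (the full model, run once over a batch of prefixes) are hypothetical callables:

```python
def speculative_step(context, draft_next, model_next_tokens, k=4):
    # 1. Draft k tokens cheaply with the small guesser.
    draft, ctx = [], list(context)
    for _ in range(k):
        tok = draft_next(ctx)
        draft.append(tok)
        ctx.append(tok)

    # 2. One batched pass of the full model over the k+1 prefixes.
    rows = [list(context) + draft[:i] for i in range(k + 1)]
    preds = model_next_tokens(rows)

    # 3. Accept drafted tokens until the first disagreement, then take the
    #    full model's own prediction at that point.
    accepted = []
    for i in range(k):
        if preds[i] != draft[i]:
            accepted.append(preds[i])
            return accepted
        accepted.append(draft[i])
    accepted.append(preds[k])   # all k confirmed: bonus token from the last row
    return accepted
```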
(I wonder if such an algorithm is implemented by ChatGPT or Bard 🤔)
Unfortunately, using the "window of 1" method, the most common token to follow any word is usually one of these:
,
.
and
to
of
the
which may make the method not so useful 🤔 Although for some words, such as 'suggest', the most likely word to follow is 'that'.
I have found that I can use a smaller LLM, such as the 111M Cerebras model, to make a good initial guess for the next word in 0.1 seconds and then run a batch of 2. It gets the guess right a lot of the time. So in this way you can use a bad model to speed up a good model!
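A hedged sketch of pairing a small draft model with a larger target model via Hugging Face transformers. The model names and sizes here are assumptions (substitute whatever small/large pair you actually use), and for simplicity it verifies the guess with a single extended sequence rather than a padded batch of 2, which costs the same single forward pass:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed model pair: a tiny drafter and a bigger target model.
tokenizer = AutoTokenizer.from_pretrained("cerebras/Cerebras-GPT-111M")
draft = AutoModelForCausalLM.from_pretrained("cerebras/Cerebras-GPT-111M")
target = AutoModelForCausalLM.from_pretrained("cerebras/Cerebras-GPT-1.3B")

prompt = tokenizer("Once upon a time there", return_tensors="pt")

# Cheap guess for the next token from the small model...
with torch.no_grad():
    draft_logits = draft(**prompt).logits
guess_id = int(draft_logits[0, -1].argmax())

# ...verified by the big model in one pass over context + guess.
extended = torch.cat([prompt["input_ids"], torch.tensor([[guess_id]])], dim=1)
with torch.no_grad():
    target_logits = target(input_ids=extended).logits

# Position -2 predicts the guessed slot; position -1 gives the bonus token if accepted.
accepted = int(target_logits[0, -2].argmax()) == guess_id
print("draft guess accepted:", accepted)
```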
This is indeed a great idea - it's called speculative decoding. Your specific idea of having a dictionary lookup is close to staged speculative decoding, where there is a hierarchy of LMs, starting with (essentially) a lookup ngram model, then a small (transformer) LM, then the biggest "oracle" LM.