Problems with KV cache padding to max_sequence length #422
fyi @tt-mpantic
Attaching Padded KV cache.excalidraw.zip P.S.
@nvukobratTT @dgolubovicTT I tried to debug further based upon the inputs from @dgolubovicTT. After passing position_ids and attention_mask to the forward function in the decode phase, I am able to generate the same results as without cache. With and without cache:
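For reference, here is a minimal sketch of the decode loop described above (this is not the attached repro; the checkpoint name, prompt, and max_new_tokens are placeholders), assuming a standard Hugging Face Llama causal LM:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint; swap in the actual Llama 3B variant used in the repro.
model_name = "openlm-research/open_llama_3b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()

inputs = tokenizer("The capital of France is", return_tensors="pt")
input_ids = inputs.input_ids
attention_mask = inputs.attention_mask

# Prefill: run the whole prompt once and keep the KV cache.
with torch.no_grad():
    out = model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)
past_key_values = out.past_key_values
next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)

generated = [next_token]
for _ in range(32):  # placeholder for max_new_tokens
    # Decode phase: extend the attention_mask and pass explicit position_ids,
    # otherwise the cached run diverges from the no-cache run.
    attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)
    position_ids = torch.tensor([[attention_mask.shape[-1] - 1]])
    with torch.no_grad():
        out = model(
            input_ids=next_token,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=True,
        )
    past_key_values = out.past_key_values
    next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
    generated.append(next_token)

print(tokenizer.decode(torch.cat(generated, dim=-1)[0]))
```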
Repro:
Great @chandrasekaranpradeep! Let's see if we can modify it on the current version of transformers on FFE :)) Also, please test it out with a longer seq len :)) It'll be good to see a demo with a few thousand tokens generated ;) @dgolubovicTT @chandrasekaranpradeep once we confirm that this works properly, what do you think about documenting this publicly? This approach can be a good workaround until we bring up more LLMs and start pushing on porting KV cache to metal :)) Let me know your thoughts :))
Sure, I will check with max_new_tokens = 1000 and also see how we can modify it with the current transformers version on FFE.
Sure
We have a few iterations to cover, and then when we have a working KV cache on our device, we can document publicly how we managed to enable it. For now I pushed methods for comparing K and V tensors throughout the decoder layers. When we are sure that the padded KV cache yields the same results on CPU, we will just switch to compiled_model and test that one.
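For illustration only, a comparison of this kind might look like the sketch below (the helper name and tolerances are assumptions, not the methods actually pushed):

```python
import torch

def compare_kv_caches(past_kv_a, past_kv_b, atol=1e-5, rtol=1e-5):
    """Compare K and V tensors of two caches layer by layer.

    Each cache is expected in the legacy tuple format:
    tuple(num_layers) of (key, value) tensors shaped
    [batch, num_heads, seq_len, head_dim].
    """
    for layer_idx, ((k_a, v_a), (k_b, v_b)) in enumerate(zip(past_kv_a, past_kv_b)):
        k_ok = torch.allclose(k_a, k_b, atol=atol, rtol=rtol)
        v_ok = torch.allclose(v_a, v_b, atol=atol, rtol=rtol)
        if not (k_ok and v_ok):
            print(f"Layer {layer_idx}: K match={k_ok}, V match={v_ok}, "
                  f"max K diff={(k_a - k_b).abs().max():.3e}, "
                  f"max V diff={(v_a - v_b).abs().max():.3e}")
```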
I will check and let you guys know here.
@nvukobratTT @dgolubovicTT The current transformers version (4.41.0) in tt-forge-fe will not affect the padded cache test cases running on CPU, but it throws an issue when you try to compile the model on TT, during the TVM to Forge op conversion. The following issue is thrown:
We shouldn't use fast tokenizer for the |
If I am not wrong, in our main we are using AutoTokenizer. What is the difference between AutoTokenizer and LlamaTokenizer? Do we prefer one over the other, and if so, why?
AutoTokenizer is a helper/generic class that automatically instantiates the right tokenizer based on the model_name passed to the from_pretrained method, whereas LlamaTokenizer, as its name says, is specific to Llama models.
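As an illustration (the checkpoint name is a placeholder), both paths usually end up with a Llama tokenizer; the difference is only in how the class gets selected:

```python
from transformers import AutoTokenizer, LlamaTokenizer

model_name = "openlm-research/open_llama_3b"  # placeholder checkpoint

# AutoTokenizer inspects the checkpoint config and picks the matching class;
# by default it prefers the fast (Rust) tokenizer when one is available.
# Pass use_fast=False to force the slow one.
auto_tok = AutoTokenizer.from_pretrained(model_name, use_fast=False)

# LlamaTokenizer is the Llama-specific slow (SentencePiece) tokenizer class.
llama_tok = LlamaTokenizer.from_pretrained(model_name)

print(type(auto_tok).__name__, type(llama_tok).__name__)
```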
@dgolubovicTT I have compared the past key values tensors and hidden states; no tensor mismatch is found between the past key values matrices in the with- and without-padding test cases. I have used Past key values and hidden state tensor comparison result log: Repro:
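Not the attached repro, but a sketch of how such a padded-vs-unpadded comparison can be done, assuming the padding is appended after the valid positions along the sequence dimension:

```python
import torch

def compare_padded_vs_unpadded(padded_kv, unpadded_kv, atol=1e-5, rtol=1e-5):
    """Check that the padded cache matches the unpadded one on valid positions.

    Both caches are legacy tuples of (key, value) tensors shaped
    [batch, num_heads, seq_len, head_dim]; the padded cache is simply
    longer along the sequence dimension.
    """
    all_ok = True
    for layer_idx, ((pk, pv), (uk, uv)) in enumerate(zip(padded_kv, unpadded_kv)):
        valid_len = uk.shape[-2]
        # Slice the padded cache back to the valid length before comparing.
        k_ok = torch.allclose(pk[..., :valid_len, :], uk, atol=atol, rtol=rtol)
        v_ok = torch.allclose(pv[..., :valid_len, :], uv, atol=atol, rtol=rtol)
        if not (k_ok and v_ok):
            all_ok = False
            print(f"Layer {layer_idx}: K match={k_ok}, V match={v_ok}")
    return all_ok
```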
Great job guys! With this last confirmation, are you confident about merging these changes to main? :)) If yes, let's also update the visuals that @dgolubovicTT created for KV cache, and potentially use them as a reference for updating our FFE docs :))
@nvukobratTT When I try to compile the model with/without cache on the TT device, I am seeing unused inputs in the generated forge module forward function.
If I am not wrong, this should be fixed with the following: huggingface/transformers#24233
Good plan. Let's sync offline if you have any blockers...
Let's first do the compile of Llama with KV cache and then we will follow up with documentation...
@nvukobratTT @dgolubovicTT I have resolved the unused inputs in the forge module forward function by calculating the causal_mask inside the LlamaModelWrapper, and the above datatype issue by supporting the cast op in Forge. Now when we compile the model, we will hit
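For context, a sketch of the kind of wrapper meant here; the class name follows the comment, but the signature and mask construction are my assumptions rather than the code on the branch, and whether the wrapped HF model accepts a prebuilt 4D mask depends on the transformers version:

```python
import torch

def build_causal_mask(attention_mask, q_len, dtype=torch.float32):
    """Build an additive [batch, 1, q_len, kv_len] mask from a 2D attention_mask."""
    batch, kv_len = attention_mask.shape
    min_val = torch.finfo(dtype).min
    # True where a query position is allowed to attend to a key position.
    causal = torch.tril(torch.ones(q_len, kv_len, dtype=torch.bool),
                        diagonal=kv_len - q_len)
    allowed = causal[None, None, :, :] & attention_mask[:, None, None, :].bool()
    return torch.where(allowed,
                       torch.zeros((), dtype=dtype),
                       torch.full((), min_val, dtype=dtype))

class LlamaModelWrapper(torch.nn.Module):
    """Compute the causal mask inside forward so the compiled graph's inputs
    stay limited to tensors that are actually used."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask, position_ids, past_key_values=None):
        causal_mask = build_causal_mask(attention_mask, input_ids.shape[-1])
        return self.model(
            input_ids=input_ids,
            attention_mask=causal_mask,  # 4D mask built inside the wrapper
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=True,
        )
```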
Since our compiler deals with operations that have static shapes (inputs and outputs), in order to enable KV cache we have to pad the K and V tensors to some max_seq_len. It is either that or recompiling in every iteration of the decode loop.
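To make the static-shape premise concrete, here is an illustrative sketch (my own, not the compiler's actual mechanism) of padding a Hugging Face-style cache to a fixed max_seq_len so every decode iteration sees identical shapes; the padded positions must then be masked out via the attention mask:

```python
import torch

def pad_kv_cache(past_key_values, max_seq_len):
    """Pad K and V along the sequence dimension to a fixed max_seq_len.

    past_key_values: legacy tuple(num_layers) of (key, value) tensors shaped
    [batch, num_heads, seq_len, head_dim]. The returned cache keeps constant
    shapes across decode iterations, which is what a static-shape compiler needs.
    """
    padded = []
    for key, value in past_key_values:
        pad_len = max_seq_len - key.shape[-2]
        padded.append((
            torch.nn.functional.pad(key, (0, 0, 0, pad_len)),    # pad seq dim
            torch.nn.functional.pad(value, (0, 0, 0, pad_len)),  # pad seq dim
        ))
    return tuple(padded)
```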
With this premise, I lay out three issues I've encountered while debugging data mismatches in Llama 3B token generation.
In order for KV cache to work with padding, we have to understand and solve these in our compiler:
Here is the link for this visualization together with issues: https://excalidraw.com/#json=jBvSa7lieemcWAwXBLyAc,qU8bzz4x8-8kMoELaCbYqw
Issues:
I wish to avoid hacking around these just to enable Llama, and instead solve KV cache more generally.