
Problems with KV cache padding to max_sequence length #422

Closed
dgolubovicTT opened this issue Oct 17, 2024 · 19 comments · Fixed by #607
Labels
exploration Needs more research before making concrete changes

Comments

@dgolubovicTT
Contributor

Since our compiler deals with operations that have static shapes (inputs and outputs), enabling a KV cache requires padding the K and V tensors to some max_seq_len. It is either that or recompiling in every iteration of the decode loop.

With this premise, I lay out three issues I've encountered while debugging data mismatches in Llama 3B token generation.
In order for the KV cache to work with padding, we have to understand and solve these in our compiler:

[Padded KV cache visualization]

Here is the link for this visualization together with issues: https://excalidraw.com/#json=jBvSa7lieemcWAwXBLyAc,qU8bzz4x8-8kMoELaCbYqw

Issues:

  1. Rotary embeddings for each token depend on the sequence length, which is changed by padding. Possible action: track the actual sequence length and calculate rotary embeddings based on that.
  2. Llama attention updates the K matrix with k_previous_token by APPENDING it to the end. After that, it applies rotary embeddings to each row of K by rotating the vector at position "i" by Theta*i, where Theta is some angle (depending on seq len - issue 1) and "i" is the index, which for k_previous_token ends up being max_seq_len but should be 6. Possible action: instead of appending the new k, our compiler should remap the append to inserting the new k at the first free index (tracking the actual sequence length).
  3. Without padding, there is no need to provide an attention mask, because the new token is influenced by all previous ones. Here, we have to have an attention mask because we don't want zeros from the K matrix to influence the softmax. Possible action: we need to create an attention mask that takes into account where the padding starts. (A minimal sketch of all three points follows this list.)
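
To make these three points concrete, here is a minimal sketch in plain PyTorch of what tracking the actual sequence length could look like in a single padded decode step. The names and shapes (update_padded_cache, k_cache, cur_len, a single head) are made up for illustration; this is the idea, not our compiler's implementation.

import torch

# Hypothetical single-head cache of shape (batch, max_seq_len, head_dim); names are illustrative.
def update_padded_cache(k_cache, v_cache, k_new, v_new, cur_len):
    max_seq_len = k_cache.shape[-2]

    # Issue 2: write the new K/V row into the first free slot instead of appending,
    # so its rotary position index stays cur_len rather than max_seq_len.
    k_cache[:, cur_len, :] = k_new
    v_cache[:, cur_len, :] = v_new

    # Issue 1: rotary embeddings are computed from the real position index, not the padded length.
    position_ids = torch.tensor([[cur_len]])

    # Issue 3: mask out the padded tail so the zero rows don't contribute to the softmax.
    attention_mask = torch.zeros(1, max_seq_len)
    attention_mask[:, : cur_len + 1] = 1.0

    return position_ids, attention_mask, cur_len + 1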

I wish to avoid hacking around these just to enable Llama; instead, I'd like to solve the KV cache more generally.

@dgolubovicTT dgolubovicTT added the exploration Needs more research before making concrete changes label Oct 17, 2024
@dgolubovicTT
Contributor Author

fyi @tt-mpantic

@nvukobratTT
Contributor

nvukobratTT commented Oct 17, 2024

Attaching .excalidraw file just in case the web link expires :))

Padded KV cache.excalidraw.zip

P.S.
Dark theme is prettier ;)) 😂

@chandrasekaranpradeep
Contributor

@nvukobratTT @dgolubovicTT I tried to debug further based on the inputs from @dgolubovicTT.
We are able to support Llama 3B decode with padded past key values when we pass position_ids and attention_mask. The rotary embedding applied to the query and key states depends on position_ids, which was invalid earlier, and attention_mask was not passed to the forward function, which led to attention being computed over the padding tokens in the past key tensor.
I have made a doc for the position_ids and attention_mask issue - DOC LINK

After passing position_ids and attention_mask to the forward function in the decode phase, I am able to generate the same result as without cache.
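
For reference, the decode-loop call is conceptually something like the sketch below. The function and argument names (decode_step, padded_past_key_values, num_valid_tokens) are placeholders, not the branch's exact code; only the standard transformers forward arguments are assumed.

import torch

def decode_step(model, next_token_id, padded_past_key_values, num_valid_tokens):
    # Padded (max) cache length, taken from the key tensor of the first layer.
    past_len = padded_past_key_values[0][0].shape[-2]

    # Attention mask covers the padded past plus the current token; padded slots stay 0.
    attention_mask = torch.zeros(1, past_len + 1, dtype=torch.long)
    attention_mask[:, :num_valid_tokens] = 1   # real cached tokens
    attention_mask[:, -1] = 1                  # the token being decoded now

    # Rotary embeddings are derived from position_ids, so pass the real position,
    # not one inferred from the padded length.
    position_ids = torch.tensor([[num_valid_tokens]])

    return model(
        input_ids=next_token_id,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=padded_past_key_values,
        use_cache=True,
    )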

With and without cache:

Q: What is the largest animal?
A: The blue whale.
Q: What is the largest animal?
A: The blue whale. It is the largest animal on Earth. It is also the largest mammal. It is the largest creature that has ever lived.

Repro:

git checkout pchandrasekaran/llama_3b_decode
 
# Check whether the transformers version is 4.35.2; otherwise install transformers 4.35.2
pip show transformers
pip install transformers==4.35.2

Test case to check:
pytest forge/test/mlir/llama/tests/test_llama_decode.py::test_llama_prefill_on_cpu_decode_on_tt_cache -vss

@nvukobratTT
Contributor

Great @chandrasekaranpradeep! Let's see if we can adapt it to the current version of transformers on FFE :)) Also, please test it out with a longer seq len :)) It'll be good to see a demo with a few thousand tokens generated ;)

@dgolubovicTT @chandrasekaranpradeep once we confirm that this works properly, what do you think about documenting it publicly? This approach can be a good workaround until we bring up more LLMs and start pushing on porting the KV cache to metal :))

Let me know your thoughts :))

@chandrasekaranpradeep
Contributor

Great @chandrasekaranpradeep! Let's see if we can adapt it to the current version of transformers on FFE :)) Also, please test it out with a longer seq len :)) It'll be good to see a demo with a few thousand tokens generated ;)

Sure, I will check with max new tokens = 1000 and also look into how we can adapt it to the current transformers version on FFE.

@chandrasekaranpradeep
Contributor

@dgolubovicTT @chandrasekaranpradeep once we confirm that this works properly, what do you think about documenting it publicly? This approach can be a good workaround until we bring up more LLMs and start pushing on porting the KV cache to metal :))

Sure

@dgolubovicTT
Contributor Author

We have a few iterations to cover, and then, once we have a working KV cache on our device, we can document publicly how we managed to enable it.

For now, I pushed methods for comparing K and V tensors throughout the decoder layers on dgolubovic/debug-pchandrasekaran/llama_decode_tt. @chandrasekaranpradeep you can use them to verify whether we now get the correct K and V matrices... We can compare hidden states too.

Once we are sure that the padded KV cache yields the same results on CPU, we will just switch to the compiled_model and test that one.

@chandrasekaranpradeep
Contributor

For now, I pushed methods for comparing K and V tensors throughout the decoder layers on dgolubovic/debug-pchandrasekaran/llama_decode_tt. @chandrasekaranpradeep you can use them to verify whether we now get the correct K and V matrices... We can compare hidden states too.

I will check and let you guys know here.

@chandrasekaranpradeep
Contributor

chandrasekaranpradeep commented Oct 25, 2024

@nvukobratTT @dgolubovicTT The current transformers version (4.41.0) in tt-forge-fe does not affect the padded cache test cases running on CPU, but it throws an issue when you try to compile the model on TT during the TVM to Forge op conversion. The issue below will be thrown:

            if run_on_tt_device:
                # Compile the model
>               compiled_model = forge.compile(
                    llama_model, sample_inputs=[model_inputs[0], attention_mask, position_ids, model_inputs[1]]
                )

forge/test/mlir/llama/tests/test_llama_decode.py:200: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
forge/forge/compile.py:245: in compile_main
    return forge_compile_from_context(compile_context)
forge/forge/compile.py:287: in forge_compile_from_context
    next_stage = stage_to_func[current_stage](context)
forge/forge/compile.py:636: in generate_initial_graph
    context.graph, context.outputs, context.intermediate_tensors, context.inputs, _ = generate_graph(
forge/forge/compile.py:1072: in generate_graph
    outputs = module.forward(*outputs)
forge/forge/module.py:689: in wrap_forward
    return orig_forward(*args, **kwargs)
generated_modules/LlamaModelWrapper.py:472: in forward
    matmul_340 = forge.op.Matmul("", broadcast_335, transpose_339).set_src_layer("test_llama_decode.LlamaModelWrapper::/transformers.models.llama.modeling_llama.LlamaForCausalLM::model/transformers.models.llama.modeling_llama.LlamaModel::model/transformers.models.llama.modeling_llama.LlamaDecoderLayer::layers.0/transformers.models.llama.modeling_llama.LlamaAttention::self_attn/transformers.models.llama.modeling_llama.LlamaRotaryEmbedding::rotary_emb")
forge/forge/op/matmul.py:36: in Matmul
    result: Tensor = op("matmul", name, operandA, operandB).get_tensor()
forge/forge/op/common.py:82: in get_tensor
    result.set_value(get_f_forge_eval(self.cpp_op_type)(values))
forge/forge/op/eval/forge/__init__.py:220: in <lambda>
    return lambda *inputs: module_or_class.eval(op_type.op, op_type.attr, *inputs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

type = 'matmul', attr = []
ops = [tensor([[[1.0000e+00],
         [8.3176e-01],
         [6.9183e-01],
         [5.7544e-01],
         [4.7863e-01],
  ...e-04],
         [2.0893e-04],
         [1.7378e-04],
         [1.4454e-04],
         [1.2023e-04]]]), tensor([[[12]]])]

    def eval(type, attr, ops):
        assert len(ops) in [2, 3], "Matrix multiply should have two or three inputs"
        assert len(attr) <= 2, "Matrix multiply should have zero to two attributes"
    
        accumulate = (len(attr) >= 1) and bool(attr[0])
        t_ops = to_torch_operands(*ops)
        t_ops, original_type = cast_for_cpu_eval(t_ops, type)
    
        if type == "matmul":
>           result = torch.matmul(t_ops[0], t_ops[1])
E           RuntimeError: expected scalar type Float but found Long
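
The dtype mismatch itself can be reproduced outside the compiler with plain torch: the second matmul operand (the position tensor [[12]]) is an integer tensor, so one operand has to be cast before the multiply (which lines up with the cast-op fix mentioned later in this thread). A small standalone repro, with made-up shapes:

import torch

inv_freq_col = torch.rand(1, 100, 1)       # stands in for the float rotary-frequency column
position = torch.tensor([[[12]]])          # integer (int64) position tensor, as in the trace

try:
    torch.matmul(inv_freq_col, position)   # float x long raises a RuntimeError
except RuntimeError as err:
    print(err)                             # e.g. "expected scalar type Float but found Long"

ok = torch.matmul(inv_freq_col, position.to(inv_freq_col.dtype))   # casting one operand fixes it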

@chandrasekaranpradeep
Contributor

chandrasekaranpradeep commented Oct 25, 2024

We shouldn't use the fast tokenizer for the openlm-research/open_llama_3b model variants because it produces invalid results.
Instead, we can use LlamaTokenizer or AutoTokenizer with the use_fast=False option.

Reference Link
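
For reference, both options look like this (standard transformers API; the slow tokenizer path requires sentencepiece to be installed):

from transformers import AutoTokenizer, LlamaTokenizer

model_name = "openlm-research/open_llama_3b"

# Either of these avoids the fast (Rust-based) tokenizer:
tokenizer = LlamaTokenizer.from_pretrained(model_name)
# or
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)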

@dgolubovicTT
Contributor Author

If I am not wrong, in our main we are using AutoTokenizer. What is the difference between AutoTokenizer and LlamaTokenizer? Do we prefer one over the other, and if so, why?

@chandrasekaranpradeep
Contributor

AutoTokenizer is a helper/generic class that automatically instantiates the right tokenizer based on the model_name passed to the from_pretrained method, whereas LlamaTokenizer, as its name says, is specific to Llama models.
For more info, please take a look at https://huggingface.co/docs/transformers/en/model_doc/auto.
AutoTokenizer class code - reference

@chandrasekaranpradeep
Contributor

@dgolubovicTT I have compared the past key values tensors and hidden states; there is no tensor mismatch between the past key values matrices with and without padding. I used compare_tensors from the dgolubovic/debug-pchandrasekaran/llama_decode_tt branch and compare_with_golden_pcc in forge to verify the past key values and hidden state tensors. However, the max, min, and scaled differences are not exactly 0.0 in all cases; they are close to zero, e.g. 8.344650268554688e-07.
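
For context, the PCC check is conceptually something like this minimal sketch (not the actual forge compare_with_golden_pcc implementation):

import torch

def pcc(a: torch.Tensor, b: torch.Tensor) -> float:
    # Pearson correlation coefficient between two flattened tensors.
    stacked = torch.stack([a.flatten().float(), b.flatten().float()])
    return torch.corrcoef(stacked)[0, 1].item()

# e.g. compare cached vs. non-cached key states of one decoder layer:
# assert pcc(k_padded_cache, k_no_cache) >= 0.99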

Past key values and Hidden state tensor comparison result log:
test_compare_states.log

Repro:

git checkout pchandrasekaran/llama_3b_decode
pytest forge/test/mlir/llama/tests/test_llama_decode_new.py::test_compare_states -vss

@nvukobratTT
Contributor

Great job guys!

With this last confirmation, are you confident about merging these changes to main? :)) If yes, let's also update the visuals that @dgolubovicTT created for the KV cache, and potentially use that as a reference for updating our FFE docs :))

@chandrasekaranpradeep
Contributor

chandrasekaranpradeep commented Oct 28, 2024

@nvukobratTT When I try to compile the model with/without cache on a TT device, I am seeing unused inputs in the generated forge module forward function (def forward(self, input_1, unused_input_1):); in the TVM frontend conversion, the attention mask is not traced properly. Can I update the FFE docs after fixing this and the above issue and getting it to compile up to lowering?
I will update the visuals that @dgolubovicTT created for the KV cache and merge the current cache and no_cache test cases, which run on CPU (i.e. prefill and decode), into main.

@mstojkovicTT
Contributor

We shouldn't use the fast tokenizer for the openlm-research/open_llama_3b model variants because it produces invalid results. Instead, we can use LlamaTokenizer or AutoTokenizer with the use_fast=False option.

Reference Link

If I am not wrong, this should be fixed by the following: huggingface/transformers#24233

@dgolubovicTT
Contributor Author

@nvukobratTT When I try to compile the model with/without cache on a TT device, I am seeing unused inputs in the generated forge module forward function (def forward(self, input_1, unused_input_1):); in the TVM frontend conversion, the attention mask is not traced properly. Can I update the FFE docs after fixing this and the above issue and getting it to compile up to lowering? I will update the visuals that @dgolubovicTT created for the KV cache and merge the current cache and no_cache test cases, which run on CPU (i.e. prefill and decode), into main.

Good plan. Let's sync offline if you have any blockers...

@dgolubovicTT
Contributor Author

Let's first get Llama with the KV cache compiling, and then we will follow up with documentation...

@chandrasekaranpradeep
Contributor

@nvukobratTT @dgolubovicTT I have resolved the unused-input issue in the forge module forward function by calculating the causal_mask inside the LlamaModelWrapper, and the datatype issue above by supporting the cast op in forge. Now when we compile the model, we hit a "Found Unsupported operations while lowering from TTForge to TTIR in forward graph" error in the lowering-to-MLIR compilation stage. The error is expected at that stage, so I am compiling the model up to SPLIT GRAPH and have created a PR for the Llama 3B decode cache and no_cache tests.
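
For reference, a minimal sketch of the "compute the mask inside the wrapper" idea is below. The class name, the max_seq_len argument, and the flat past-key-value layout are assumptions for illustration, not the exact code on the branch:

import torch

class LlamaModelWrapper(torch.nn.Module):
    def __init__(self, model, max_seq_len: int):
        super().__init__()
        self.model = model
        self.max_seq_len = max_seq_len

    def forward(self, input_ids, position_ids, *past_key_values):
        # Build the padding-aware mask from position_ids inside the traced graph, so TVM
        # sees it as a computed node instead of an extra (and then unused) input tensor.
        total_len = self.max_seq_len + input_ids.shape[-1]
        positions = torch.arange(total_len).unsqueeze(0)
        attention_mask = (positions < position_ids).to(torch.long)   # valid cached tokens
        attention_mask[:, self.max_seq_len:] = 1                     # the new token(s)

        # Reassemble the flat (k0, v0, k1, v1, ...) inputs back into HF's tuple-of-pairs form.
        past = tuple(
            (past_key_values[i], past_key_values[i + 1])
            for i in range(0, len(past_key_values), 2)
        )
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past,
            use_cache=True,
        )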
