mamba generation throughput lower than original due to DecodingCGCache #29699
Comments
Hi! Thanks for raising this issue. I agree. I've noticed the same for some weeks while testing Mamba, as previously I was using the mamba-ssm repo and was able to run generation much faster thanks to the DecodingCGCache. I've been tracking that other issue too: although it looks like just a variable refactor, it includes a change in the prepare_inputs_for_generation function to pass the kwargs through (https://github.com/huggingface/transformers/pull/29605/files#diff-e1d4758c08973fdac2c23a8a3710872d943ce8509035205da4a681bc4dcaf1c3R694). I didn't create an issue because I assumed that PR would be merged soon given its simplicity, but it seems it hasn't been. Also, I'm not sure that PR solves the whole issue.
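For reference, a minimal sketch of the idea behind that change (illustrative only, not the exact transformers source; the body shown here is simplified): extra generation kwargs are forwarded through prepare_inputs_for_generation instead of being dropped.
def prepare_inputs_for_generation(self, input_ids, cache_params=None, **kwargs):
    # Once a cache exists, only the newest token needs to be fed to the model.
    if cache_params is not None:
        input_ids = input_ids[:, -1].unsqueeze(-1)
    model_inputs = {"input_ids": input_ids, "cache_params": cache_params}
    # The change under discussion: pass the remaining kwargs along to forward().
    model_inputs.update(kwargs)
    return model_inputs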
Hey! Thanks both, I'll dive into this a bit; contributions are also welcome.
Hi folks 👋 It doesn't look like a caching issue, but a compilation one -- under the hood, the original repo's DecodingCGCache relies on CUDA graph capture. Our implementation is not compatible with torch.compile yet. @ArthurZucker in case you want to reproduce:
from transformers import MambaForCausalLM
import torch
from time import time
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf", device_map="auto")
model.eval()
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
input_ids = torch.arange(100, device="cuda").unsqueeze(0)
max_length = 1000
start = time()
model = torch.compile(model)
out = model.generate(
    input_ids=input_ids,
    max_length=max_length,
)
print(f"Time: {time() - start:.2f}s")
print(out.shape)
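(For context: mode="reduce-overhead" turns on CUDA graph capture in the inductor backend, which is broadly the same mechanism the original repo's DecodingCGCache uses to amortize per-token kernel-launch overhead.)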
Might be the
@gante Nice catch and thank you for your effort to address it!
#29544 is what made me think of that! |
Hi! Thanks for the findings :)
I did not test the compilation with Mamba, and I can't investigate right now 😢 would love it if you can.
I'll set this as a feature request: add support for torch.compile with Mamba.
Hi @ArthurZucker :)
Added it to our generate + torch.compile tracker, which we are actively tackling :) |
Yes, same issue. Is there any update on the solution?
Nice to see these updates! Are there any other pending issues left to complete this one (#29699)? Thanks!
I think it's completed |
Alright! Has anyone re-tested the code above to compare speed again? With which transformers release? Thanks!
@ArthurZucker Thank you for the updates and the nice job! I have tested the throughput again; there is some improvement, but the difference remains significant. Package info:
I ran the generation with the same 100-token prompt length and 1000-token generation length 5 times and took the average:
I assume the reasons for the generation speed difference are:
Reference:
When you said you tested, did you use torch.compile?
Did you read this #31247 (comment) in the linked PR for compile support?
May I ask where the throughput test code for calling the model through transformers can be found?
You can find it here:
import torch
from transformers import AutoTokenizer, MambaForCausalLM
# MambaCache lives in transformers.cache_utils in recent releases; this path also works for older ones.
from transformers.models.mamba.modeling_mamba import MambaCache

@torch.no_grad
def perf():
    tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-1.4b-hf")
    tokenizer.pad_token = tokenizer.eos_token
    inputs = tokenizer("Hey how are you doing today ? " * 100, return_tensors="pt", padding=True).to('cuda')
    model = MambaForCausalLM.from_pretrained("state-spaces/mamba-1.4b-hf", torch_dtype=torch.float16)
    model.config.use_cache = True
    model.to('cuda')
    input_ids = inputs.input_ids

    cache = MambaCache(model.config, 1, device=input_ids.device)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    # Prefill: run the full prompt once to populate the cache.
    logits = model(input_ids, cache_params=cache).logits
    next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]

    # Compile the single-token decoding step.
    model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
    torch.cuda.synchronize()
    for i in range(10):
        start.record()
        logits = model(next_token.clone(), cache_params=cache).logits
        next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
        end.record()
        torch.cuda.synchronize()
        print(f'Step {i}, Total time: {start.elapsed_time(end)} ms, next_token = {next_token.int()}')
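A hypothetical driver for the benchmark above (my addition, not part of the original comment). Note the first compiled iterations are slower due to torch.compile warm-up, so only the later steps reflect steady-state per-token latency:
if __name__ == "__main__":
    # Assumes a CUDA GPU is available.
    perf()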
System Info
Python 3.10.13, CUDA 12.1
GPU = NVIDIA GeForce RTX 2080 Ti. Max memory = 10.747 GB.
torch==2.2.1
torchaudio==2.1.0
torchvision==0.16.0
tokenizers==0.15.2
transformers==git+https://github.com/huggingface/transformers@dd1c9052159ae824c8acef7c2552f9fad5ca020a
triton==2.2.0
causal_conv1d==git+https://github.com/Dao-AILab/causal-conv1d.git@96456720c00393a5c32872d8352d7a7ec31fb3db#egg=causal_conv1d
mamba_ssm==git+https://github.com/state-spaces/mamba.git@9127d1f47f367f5c9cc49c73ad73557089d02cb8#egg=mamba_ssm
Who can help?
text models: @ArthurZucker and @younesbelkada
generate: @gante
Information
Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
Reproduction
The key model initialization and generation parts are given below.
Original code repo
In the original code repo:
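(The exact snippet was not preserved in this thread; below is a minimal sketch in the spirit of the mamba_ssm benchmark script, where the checkpoint name, prompt, and generation settings are my assumptions.)
import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# The state-spaces checkpoints use the GPT-NeoX tokenizer.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-1.4b", device="cuda", dtype=torch.float16)
model.eval()

input_ids = tokenizer("Hey how are you doing today ? " * 20, return_tensors="pt").input_ids.to("cuda")

# cg=True enables the DecodingCGCache path: the per-token decoding step is
# captured into a CUDA graph and replayed, which is where the speedup comes from.
out = model.generate(
    input_ids=input_ids,
    max_length=input_ids.shape[1] + 1000,
    cg=True,
    top_k=1,
)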
Then the throughput for generating a 1K-length sequence is:
Using the HF library
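(Again, the exact snippet was not preserved; a minimal sketch of the HF-side setup, consistent with the reproduction code earlier in the thread, with the checkpoint name and lengths as assumptions.)
import torch
from transformers import AutoTokenizer, MambaForCausalLM

tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-1.4b-hf")
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-1.4b-hf", torch_dtype=torch.float16).to("cuda")
model.eval()

input_ids = tokenizer("Hey how are you doing today ? " * 20, return_tensors="pt").input_ids.to("cuda")

# Plain generate call: it uses the Mamba cache internally, but (at the time this
# issue was opened) no CUDA graph capture or compilation of the decoding step.
out = model.generate(input_ids=input_ids, max_length=input_ids.shape[1] + 1000, do_sample=False)
print(out.shape)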
Then the throughput for generating a 1K-length sequence is:
Expected behavior
The "cg=True" is confirmed to be the part has a significant impact on the generation performance for mamba.
I have tried:
I assume this is related to #29605, but modifying the argument directly does not seem to solve the problem.