DO NOT MERGE: generate compatible with torch.compile(fullgraph=True) #29374

Closed
gante wants to merge 20 commits

Conversation

@gante (Member) commented on Feb 29, 2024

What does this PR do?

This PR is a 🔪 mangled 🔪 version of generate where torch.compile(model.generate, fullgraph=True) works and returns the same values. It should NOT be merged, but rather used as a reference -- other PRs will be created that push the needed changes, one at a time, to ensure we don't break other features.


Script to test correctness

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import torch
import copy

torch_device = "cuda"

EXPECTED_GENERATION = [
    "The best color is the one that complements the skin tone of the",
    "We should not undermind the issues at hand.\nWe should not undermind the issues",
]

tokenizer = AutoTokenizer.from_pretrained(
    "NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="<s>"
)
model = AutoModelForCausalLM.from_pretrained(
    "NousResearch/Llama-2-7b-chat-hf",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
).to(torch_device)
inputs = tokenizer(
    ["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt"
).to(model.device)

generation_kwargs = {
    "do_sample": False,
    "max_new_tokens": 10,
}

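# Baseline: greedy decoding with the default dynamic cache, no compilation.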
print("Dynamic cache")
gen_out = model.generate(**inputs, **generation_kwargs)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
print(decoded)
assert decoded == EXPECTED_GENERATION

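# The static cache pre-allocates the key/value tensors to a fixed maximum length,
# so every decoding step sees the same tensor shapes.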
print("Static cache")
model.generation_config.cache_implementation = "static"
gen_out = model.generate(**inputs, **generation_kwargs)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
print(decoded)
assert decoded == EXPECTED_GENERATION

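# Compile the full generate call. fullgraph=True requires the whole call to trace
# without graph breaks; reduce-overhead additionally uses CUDA graphs to cut
# per-step kernel launch overhead.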
print("Compiled static cache")
generation_config = copy.deepcopy(model.generation_config)
generation_config.update(**generation_kwargs)
compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
gen_out = compiled_generate(**inputs, generation_config=generation_config)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
print(decoded)
assert decoded == EXPECTED_GENERATION

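For reference, a related pattern that works without modifying generate is to compile only the per-token forward pass and drive a greedy decoding loop by hand over a StaticCache. The sketch below is not part of this PR and reuses the model, tokenizer, and torch_device defined above; it assumes a transformers version that exposes StaticCache and accepts cache_position in the model forward (exact keyword names may differ between releases).

from transformers import StaticCache

prompt = tokenizer(["The best color is"], return_tensors="pt").to(torch_device)
input_ids = prompt.input_ids
batch_size, prompt_len = input_ids.shape
max_new_tokens = 10

# Pre-allocate the key/value tensors to a fixed size so every decoding step
# sees the same shapes (constructor argument names may vary across versions).
past_key_values = StaticCache(
    config=model.config,
    max_batch_size=batch_size,
    max_cache_len=prompt_len + max_new_tokens,
    device=torch_device,
    dtype=torch.bfloat16,
)

compiled_forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

with torch.no_grad():
    # Prefill: run the whole prompt through the uncompiled forward once, so the
    # compiled graph only ever sees the fixed single-token decode shape.
    cache_position = torch.arange(prompt_len, device=torch_device)
    out = model(input_ids=input_ids, past_key_values=past_key_values,
                cache_position=cache_position, use_cache=True)
    next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
    generated = torch.cat([input_ids, next_token], dim=-1)

    # Decode: one token per step, constant shapes, compiled graph reused each time.
    for _ in range(max_new_tokens - 1):
        cache_position = cache_position[-1:] + 1
        out = compiled_forward(input_ids=next_token, past_key_values=past_key_values,
                               cache_position=cache_position, use_cache=True)
        next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, next_token], dim=-1)

print(tokenizer.batch_decode(generated, skip_special_tokens=True))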
Fixes #27837

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@amyeroberts (Collaborator) commented

@gante If this PR is going to be long-lived, you can add the WIP label and it will stop the bot from closing it as stale.

@ArthurZucker added the WIP, Compilation, and Cache labels on Apr 22, 2024
@gante mentioned this pull request on May 13, 2024
@gante (Member, Author) commented on May 29, 2024

Closed in favor of #30788

@gante closed this on May 29, 2024
Labels
Cache, Compilation, WIP
Projects
None yet
Development

Successfully merging this pull request may close these issues.

torch CUDA graphs with HF generate
4 participants