
Replace python random with torch.rand to enable dynamo.export #24434

Merged · 5 commits · Jun 23, 2023

Conversation

@BowenBao (Contributor) commented on Jun 22, 2023:

What does this PR do?

Related to and fixes pytorch/pytorch#102794.

TL;DR: dynamo graph-breaks on the Python call `random.uniform(0, 1)`. The graph break can be prevented by replacing it with `torch.rand([])`.
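For illustration, the substitution looks roughly like this (a minimal sketch; the actual PR applies it to the LayerDrop gates in the affected modeling files, and the variable name here is illustrative):

```python
import random
import torch

# Before: Python-level RNG returns a plain float, and dynamo
# graph-breaks on the call to random.uniform.
dropout_probability = random.uniform(0, 1)

# After: torch.rand([]) draws a 0-dim tensor from [0, 1),
# which dynamo can trace without breaking the graph.
dropout_probability = torch.rand([])
```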

Example repro script

```python
import torch
import torch._dynamo

from transformers import AutoTokenizer, BartForCausalLM

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
model = BartForCausalLM.from_pretrained("facebook/bart-base", add_cross_attention=False)
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
outputs = model(**inputs)  # eager forward pass works either way

# Before this PR, export graph-breaks on random.uniform(0, 1) inside the model.
torch._dynamo.export(model, return_dict=False, **inputs)
```

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sgugger (Collaborator) left a comment:


You are touching multiple Flax models, which shouldn't depend on torch; could you revert that?

@BowenBao (Contributor, Author) replied:

> You are touching multiple Flax models, which shouldn't depend on torch; could you revert that?

Nice catch! Done.

@sgugger (Collaborator) left a comment:


Thanks! @amyeroberts could you have a quick second look here?

@HuggingFaceDocBuilderDev commented on Jun 23, 2023:

The documentation is not available anymore as the PR was closed or merged.

@amyeroberts (Collaborator) left a comment:


Thanks for fixing this!

Changes LGTM 👍 Technically this is slightly different, since `random.uniform(0, 1)` samples the closed interval [0, 1] and can return 1, while `torch.rand` samples the half-open interval [0, 1); for all practical purposes it doesn't change anything.
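(An illustrative aside, not from the PR, showing the interval difference in question:)

```python
import random
import torch

x = random.uniform(0, 1)  # closed interval [0, 1]: 1.0 is a possible value
assert 0.0 <= x <= 1.0

y = torch.rand([])         # half-open interval [0, 1): 1.0 is excluded
assert 0.0 <= y.item() < 1.0
```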

@sgugger merged commit a28325e into huggingface:main on Jun 23, 2023.
@anijain2305 (Contributor) commented:

@BowenBao Does this really solve the export problem? We are seeing export issues here: pytorch/pytorch#107587

If one peeks at the tensor value behind the conditional, it's legitimate dynamic control flow. We might have to use `torch.where` or `torch.cond`.
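For illustration, one hypothetical way such a branch could be made traceable with `torch.where` (a sketch with made-up names, not code from this PR; note that unlike `torch.cond`, both sides are always computed):

```python
import torch

def maybe_layerdrop(hidden_states, layer, layerdrop=0.1):
    # Data-dependent gate expressed as tensor ops rather than a Python `if`,
    # so dynamo traces a single graph with no break.
    keep = torch.rand([]) >= layerdrop                  # 0-dim bool tensor
    layer_out = layer(hidden_states)                    # always evaluated
    return torch.where(keep, layer_out, hidden_states)  # select per the gate
```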

@BowenBao (Contributor, Author) replied:

@anijain2305 It does for inference export, since the actual condition is short-circuited by `self.training` being False. Without this change, `random.uniform(0, 1)` caused a graph break even though its value was unused.
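A toy sketch of that short-circuit (hypothetical module, not code from the PR, using the same `torch._dynamo.export` call style as the repro above):

```python
import torch
import torch._dynamo

class Toy(torch.nn.Module):
    layerdrop = 0.1

    def forward(self, x):
        p = torch.rand([])  # traceable, unlike random.uniform(0, 1)
        # In eval mode self.training is False, so Python's `and`
        # short-circuits: the tensor comparison is never evaluated
        # and dynamo sees no data-dependent control flow here.
        if self.training and (p < self.layerdrop):
            return x * 0  # "skip the layer" stand-in
        return x + 1

m = Toy().eval()
graph_module, guards = torch._dynamo.export(m, torch.ones(2))
```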

Development

Successfully merging this pull request may close these issues:

[dynamo.export] AssertionError: Dynamo attempts to add additional input during export for Huggingface Bart