Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tracer] Remove SelfAttention renaming #44

Merged
merged 1 commit into from
Feb 7, 2023
Merged

Conversation

chhzh123
Copy link
Contributor

@chhzh123 chhzh123 commented Feb 7, 2023

Description

This PR fixes #13 and #30 by removing the legacy code of renaming self attention module for HuggingFace models. As now we only trace the module when needed and also replace the attention module with our own module, the naming issue probably may not be a problem. We also do not refer to self_m in our example schedules, so I remove the renaming part from the tracer, and only prompt a warning when the tracer meets those self modules.

Checklist

  • PR's title starts with a category (e.g. [Bugfix], [Model], [Tutorial], etc)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage
  • Code is well-documented

cc @szhengac @comaniac

Comment on lines +41 to +43
logger.warning(
"`self` in {root.__class__.__name__} is a Python keyword, please rename it to avoid possible conflicts."
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just a warning. What happen if we keep running?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tested the traced module and it can execute normally. If we only call the SelfAttention module with the self prefix, it seems to work. I forget why at the first place I met the naming error, which was probably related to the previous implementation of the tracer.

class Model(nn.Module):
  def __init__(self):
    self.self = ...

  def forward(self, x):
    x = self.self(x)
    ...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok got it.

@comaniac comaniac merged commit cd22329 into awslabs:main Feb 7, 2023
@comaniac
Copy link
Contributor

comaniac commented Feb 7, 2023

Thanks @chhzh123

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Bug] Buffer is not maintained and redundant buffer is returned after tracing bert example
2 participants