[Schedule] Add .fuse() primitive #25
Conversation
It seems we could ask for sample inputs via kwargs if the compiler backend requires them. For example:
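A rough sketch of that idea (the example_inputs keyword, the compiler value, and the schedule calls are only illustrative, assuming a model and the LinearReLUPattern discussed later in this thread; this is not a settled API):

import torch
import slapo

# Sketch of the proposal: if the backend needs sample inputs (e.g. for
# TorchScript tracing mode), the user supplies them through kwargs.
# `example_inputs` is an illustrative name, not a final API.
sch = slapo.create_schedule(model)
sch["mlp"].fuse(
    LinearReLUPattern(),
    compiler="TorchScript",
    example_inputs=[torch.randn(8, 16)],
)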
It seems this PR requires several changes to the subgraph matching mechanism. I will open a new PR to make steps 1 and 2 work first.
Overall LGTM. Later we should apply .fuse to the examples, which currently replace an entire MLP for bias_new_gelu.
slapo/pattern.py (outdated)
@@ -9,5 +9,14 @@ def forward(self, *args):
        raise NotImplementedError


class CallModule(nn.Module):
The names of CallModule and call_module are confusing.
They are just used in different places. CallModule is used in a pattern class, while call_module can be used in a pattern function. Do you have any suggestions about the naming?
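For context, a sketch of how the two interfaces differ (the import path and the call_module signature here are my guesses, not confirmed API):

import torch.nn.functional as F
from slapo.pattern import call_module  # import path assumed for illustration

# call_module is meant for pattern *functions*: it stands for "an fx
# call_module node whose target name matches the given regex".
def linear_relu_pattern(x):
    x = call_module(r"fc?", x)
    return F.relu(x)

# CallModule, by contrast, is a placeholder nn.Module that is constructed
# inside a pattern *class* (see the LinearReLUPattern snippet quoted below).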
Hmm, based on your use cases, such as:
class LinearReLUPattern(slapo.Pattern):
def __init__(self):
super().__init__()
self.fc = CallModule(r"fc?")
self.relu = nn.ReLU()
Here you're actually constructing a module instead of calling it, so I guess you could use a name like ModulePattern or something like that.
It may also lead to confusion with the Pattern class.
Well, to me it is a kind of pattern, especially since it takes a regex as its argument.
Yes, I agree it is a pattern, and ModulePattern is a proper name. But since we have the Pattern base class, which is also a module, I'm afraid it will cause confusion between the two.
Sure, I'll add it in the next PR.
Thanks @chhzh123
Description
This PR adds a new primitive called .fuse(subgraph, compiler) for operator fusion. Currently we only support pattern-based vertical operator fusion using TorchScript as the backend compiler. A simple example is shown below.

The fusion performance still needs to be tested. Since I first create an identical torch.fx subgraph and pass it into TorchScript's scripting mode for optimization, it may cause performance issues if some of the operators are not recognized by TorchScript. Tracing mode can achieve the best performance, but it is hard to leverage since we cannot always obtain example inputs for each subgraph. Scripting a user-defined function is also not a good approach, since the backward pass is not captured by TorchScript. Therefore, only scripting an entire module is a good fit for our case, and we need to test the compatibility of torch.fx and TorchScript.
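A minimal sketch of the workflow described above, assuming a simple LinearReLU submodule; the commented slapo calls are an API guess rather than the exact example from this PR:

import torch
import torch.nn as nn
import torch.fx as fx

class LinearReLU(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 16)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.fc(x))

# Build an identical torch.fx subgraph, then script the whole module so
# TorchScript also captures the backward pass (scripting a bare function
# would not).
gm = fx.symbolic_trace(LinearReLU())
fused = torch.jit.script(gm)
print(fused(torch.randn(4, 16)).shape)

# Hypothetical user-facing call (a sketch, not verbatim from the PR):
# sch = slapo.create_schedule(model)
# sch["mlp"].fuse(LinearReLUPattern(), compiler="TorchScript")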
Checklist
- Support functions in torch.nn.functional matching modules in torch.nn. For example, F.relu and nn.ReLU should be treated as the same in pattern matching, since users cannot specify a module in a function pattern; see the sketch after this checklist ([Schedule] Refactor subgraph matching #35).
- Add a .decompose() primitive and support decoupling bias from nn.Linear.
- Add a flatten argument to .trace() (#29).
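A minimal illustration of why the first item matters (a torch.fx sketch for explanation only, not code from this PR):

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fx as fx

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 16)

    def forward(self, x):
        # F.relu appears as a call_function node in the traced graph,
        # whereas an nn.ReLU submodule would appear as call_module.
        return F.relu(self.fc(x))

gm = fx.symbolic_trace(MLP())
print([(node.op, node.target) for node in gm.graph.nodes])
# The graph contains a call_module node for 'fc' and a call_function node
# for relu, so a pattern written with nn.ReLU will not match this subgraph
# unless the matcher treats the functional and module forms as equivalent.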
Future plan (Updated)
The following features will be added in separate PRs.