
[Schedule] Add .fuse() primitive #25

Merged · 23 commits · Feb 2, 2023
Conversation

@chhzh123 (Contributor) commented Jan 29, 2023

Description

This PR adds a new primitive, .fuse(subgraph, compiler), for operator fusion. Currently we only support pattern-based vertical operator fusion, using TorchScript as the backend compiler. A simple example is shown below.

import torch
import torch.nn as nn
import torch.nn.functional as F

import slapo

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 32, 3)

    def forward(self, x):
        x = self.conv(x)
        x = F.relu(x)
        x = x + 1
        return x

mod = Model()
sch = slapo.create_schedule(mod)

# Define a function pattern whose ops will be fused
def pattern(x: torch.Tensor):
    x = F.relu(x)
    x = x + 1
    return x

# Find the subgraph in the original module that matches the pattern
subgraph = sch.find(pattern)
# Fuse the matched subgraph with the TorchScript backend
sch.fuse(subgraph, compiler="TorchScript", name="FusedReLU")

The fusion performance still needs to be tested. Since we first create an identical torch.fx subgraph and pass it to TorchScript's scripting mode for optimization, performance may suffer if some of the operators are not recognized by TorchScript. Tracing mode achieves the best performance, but it is hard to leverage because we cannot always obtain example inputs for each subgraph. Scripting a user-defined function is also not a good approach, since TorchScript does not capture the backward pass in that case. Therefore, scripting an entire module is the only approach that fits our case, and we need to test the compatibility of torch.fx and TorchScript.
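For reference, a minimal sketch of the scripting path described above, using only standard torch.fx / TorchScript interop (not slapo code):

import torch
import torch.fx

class Subgraph(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1

# symbolic_trace yields a GraphModule; its generated Python code can be
# handed to TorchScript's scripting mode without any example inputs.
gm = torch.fx.symbolic_trace(Subgraph())
scripted = torch.jit.script(gm)
print(scripted.graph)  # TorchScript IR for the fx-generated forward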

Checklist

Future plan (Updated)

The following features will be added in separate PRs.

@comaniac (Contributor):

Seems like we could ask for sample inputs via kwargs when the compiler backend requires them. For example: compiler="TorchScript", example_inputs=...
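A sketch of how that could look against this PR's example (example_inputs is a hypothetical kwarg here, not part of the merged API):

# Hypothetical: example inputs shaped like the conv output would let
# the backend use torch.jit.trace instead of scripting mode.
example_inputs = [torch.randn(8, 32, 30, 30)]
sch.fuse(subgraph, compiler="TorchScript", name="FusedReLU",
         example_inputs=example_inputs)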

@chhzh123 (Contributor, Author):

It seems this PR requires several changes to the subgraph matching mechanism. I will open a new PR to make steps 1 & 2 work first.

@chhzh123 chhzh123 changed the title [WIP][Schedule] Add .fuse() primitive [Schedule] Add .fuse() primitive Feb 2, 2023
@chhzh123 chhzh123 mentioned this pull request Feb 2, 2023
@comaniac (Contributor) left a comment:


Overall LGTM. Later we should apply .fuse to the examples, which currently replace an entire MLP for bias_new_gelu.
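As a rough illustration (a hypothetical pattern, not code from this PR or the examples), the bias + NewGELU computation could be expressed as a fusable pattern instead of swapping the whole MLP, assuming .find() supports multi-input patterns:

import math
import torch

# Hypothetical pattern: bias add followed by the tanh-approximated
# ("new") GELU activation.
def bias_new_gelu_pattern(x: torch.Tensor, bias: torch.Tensor):
    x = x + bias
    return 0.5 * x * (1.0 + torch.tanh(
        math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))
    ))

subgraph = sch.find(bias_new_gelu_pattern)
sch.fuse(subgraph, compiler="TorchScript", name="FusedBiasNewGELU")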

slapo/pattern.py (Outdated):

@@ -9,5 +9,14 @@ def forward(self, *args):
        raise NotImplementedError


class CallModule(nn.Module):
@comaniac (Contributor):

The names of CallModule and call_module are confusing.

@chhzh123 (Contributor, Author):

They are just used in different places: CallModule is used in a pattern class, while call_module can be used in a pattern function. Do you have any suggestions about the naming?
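For context, a minimal sketch of the two styles being contrasted (CallModule's import follows the diff above; the call_module import path and signature are assumptions, and a forward is added for completeness):

import torch.nn as nn
import torch.nn.functional as F

import slapo
from slapo.pattern import CallModule, call_module  # call_module path assumed

# Pattern class style: CallModule declares a submodule match by name regex.
class LinearReLUPattern(slapo.Pattern):
    def __init__(self):
        super().__init__()
        self.fc = CallModule(r"fc?")
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.fc(x))

# Pattern function style: call_module expresses the same match inline
# (assumed signature: module-name regex, then the input tensor).
def linear_relu_pattern(x):
    x = call_module(r"fc?", x)
    return F.relu(x)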

@comaniac (Contributor):

Hmm, based on your use cases, such as

    class LinearReLUPattern(slapo.Pattern):
        def __init__(self):
            super().__init__()
            self.fc = CallModule(r"fc?")
            self.relu = nn.ReLU()

Here you're actually constructing a module instead of calling it, so I guess you could use a name like ModulePattern or something like that.

@chhzh123 (Contributor, Author):

It may also lead to confusion with the Pattern class.

@comaniac (Contributor):

Well, to me it is a kind of pattern, especially since it takes a regex as its argument.

@chhzh123 (Contributor, Author):

Yes, I agree it is a pattern, and ModulePattern is a proper name. But since we have the Pattern base class, which is also a module, I'm afraid it will cause confusion between the two.

@chhzh123 (Contributor, Author) commented Feb 2, 2023

> Later we should apply .fuse to the examples, which currently replace an entire MLP for bias_new_gelu.

Sure, I'll add it in the next PR.

@comaniac merged commit 2fe133a into awslabs:main on Feb 2, 2023
@comaniac (Contributor) commented Feb 2, 2023

Thanks @chhzh123
