Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Schedule] Refactor subgraph matching #35

Merged
merged 21 commits into from
Feb 1, 2023

Conversation

chhzh123
Copy link
Contributor

@chhzh123 chhzh123 commented Feb 1, 2023

Description

This PR refactors the subgraph matching API and changes it to s.find(regex_or_pattern_fn) that is presented in the paper. Also, it adds support for more general patterns (e.g., subgraphs with multiple input arguments). A lightweight pattern language is also proposed to support fuzzy matching.

Basically, there are several use cases:

# 1. Find a `call_module` node in the dataflow graph
#    we can use regex to express the module name pattern
sch.find(r"fc[0-9]")

# 2. Find a general node in the dataflow graph
#    To avoid confusion, this case is separated from the `.find()` API
#    and users can only use `.find_node()` to achieve this goal
sch.find_node(lambda node: node.name == "fc")

# 3. Find a subgraph in the dataflow graph
#    We provide different kinds of matching facilities to enable maximum flexibility
# a) Exact matching
def pattern(x: torch.Tensor):
    return F.relu(x) + x
sch.find(pattern)

# b) Use nn.functional to match a nn.module
def pattern(x: torch.Tensor):
    # `F.relu` can match `nn.ReLU` in the graph
    return F.relu(x) + x
sch.find(pattern)

# c) Use Pattern class to specify a pattern
from slapo.pattern import Pattern
class ReLUAddPattern(Pattern):
    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor):
        return self.relu(x) + x
sch.find(ReLUAddPattern())

# d) Use pattern language
#    Here we provide `call` function as a wildcard that can be match any
#    nodes satisfying the constraints
from slapo.pattern import call
def pattern(x: torch.Tensor):
    return call("relu", x) + x
sch.find(pattern)

By leveraging the pattern language, we can express complex patterns in a very simple way, like the below horizontal QKV example.

def pattern(x):
    x = call(r"[qkv]_proj", x)
    return x.permute(0, 2, 1, 3)
subgraph = sch.find(pattern)

Checklist

  • PR's title starts with a category (e.g. [Bugfix], [Model], [Tutorial], etc)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage
  • Code is well-documented

cc @comaniac

@chhzh123 chhzh123 mentioned this pull request Feb 1, 2023
5 tasks
@comaniac
Copy link
Contributor

comaniac commented Feb 1, 2023

Nice done. Will review tomorrow.

@chhzh123 chhzh123 mentioned this pull request Feb 1, 2023
15 tasks
Copy link
Contributor

@comaniac comaniac left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall LGTM. Just need some minor refactors.

examples/albert/schedule.py Outdated Show resolved Hide resolved
slapo/schedule.py Outdated Show resolved Hide resolved
slapo/schedule.py Outdated Show resolved Hide resolved
slapo/schedule.py Outdated Show resolved Hide resolved
slapo/schedule.py Outdated Show resolved Hide resolved
slapo/schedule.py Show resolved Hide resolved
slapo/schedule.py Show resolved Hide resolved
slapo/schedule.py Show resolved Hide resolved
slapo/schedule.py Outdated Show resolved Hide resolved
tests/test_subgraph.py Outdated Show resolved Hide resolved
Copy link
Contributor

@comaniac comaniac left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.
nit:

        Returns
        ----------

Should be

        Returns
        -------

But you could fix it in later PRs.

@comaniac comaniac merged commit 3393bdf into awslabs:main Feb 1, 2023
@comaniac
Copy link
Contributor

comaniac commented Feb 1, 2023

Thanks @chhzh123

@chhzh123 chhzh123 mentioned this pull request May 29, 2023
4 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants