
Introduce experimental FX library #42741

Closed
wants to merge 11 commits

Conversation

@jamesr66a (Collaborator) commented Aug 7, 2020

Stack from ghstack:

**This feature is experimental and its stability is not currently guaranteed. Proceed at your own risk.**

FX (Functional Transformations) is a toolkit for capturing and transforming functional PyTorch programs. It consists of `GraphModule` and a corresponding intermediate representation (IR). When `GraphModule` is constructed with an `nn.Module` instance as its argument, `GraphModule` will trace through the computation of that Module's `forward` method symbolically and record those operations in the FX intermediate representation.

```
import torch
from torch.fx import GraphModule

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3)

m = MyModule()
gm = GraphModule(m)
```

The Intermediate Representation centers around a 5-opcode format:

```
from tabulate import tabulate
node_specs = [[n.op, n.name, n.target, n.args, n.kwargs] for n in gm.graph.nodes]
print(tabulate(node_specs, headers=['opcode', 'name', 'target', 'args', 'kwargs']))
```

```
opcode         name           target                                                   args                kwargs
-------------  -------------  -------------------------------------------------------  ------------------  -----------
placeholder    x              x                                                        ()                  {}
get_param      linear_weight  linear.weight                                            ()                  {}
call_function  add_1          <built-in function add>                                  (x, linear_weight)  {}
call_module    linear_1       linear                                                   (add_1,)            {}
call_method    relu_2         relu                                                     [linear_1]          {}
call_function  sum_1          <built-in method sum of type object at 0x7f1c29dd36e0>   (relu_2,)           {'dim': -1}
call_function  topk_1         <built-in method topk of type object at 0x7f1c29dd36e0>  (sum_1, 3)          {}
```

The semantics are as follows:

- `placeholder` represents a function input. The `name` attribute specifies the name this value will take on. `target` is likewise the name of the argument. `args` and `kwargs` are don't-care.
- `get_param` retrieves a parameter from the module hierarchy. `name` is the name the result of the fetch is assigned to. `target` is the fully-qualified name of the parameter's position in the module hierarchy. `args` and `kwargs` are don't-care.
- `call_function` applies a free function to some values. `name` is the name of the value to assign to. `target` is the function to be applied. `args` and `kwargs` represent the arguments to the function, following the Python calling convention.
- `call_module` applies a module in the module hierarchy's `forward()` method to given arguments. `name` is as previous. `target` is the fully-qualified name of the module in the module hierarchy to call. `args` and `kwargs` represent the arguments to invoke the module on, _excluding the self argument_ (the module itself is identified by `target`, as the `linear_1` row above shows).
- `call_method` calls a method on a value. `name` is as previous. `target` is the string name of the method to apply to the `self` argument. `args` and `kwargs` represent the arguments to invoke the method on, _including the self argument_.
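
To make these semantics concrete, here is a minimal sketch (illustrative only, not part of the FX API) that walks the traced graph and pretty-prints each node using just the fields shown in the table above:

```
# Illustrative sketch: dispatch on the five opcodes described above,
# using only the documented node fields (op, name, target, args, kwargs).
def describe(node):
    if node.op == 'placeholder':
        return f'input {node.name}'
    if node.op == 'get_param':
        return f'{node.name} = self.{node.target}'
    if node.op == 'call_function':
        return f'{node.name} = {node.target}(*{node.args}, **{node.kwargs})'
    if node.op == 'call_module':
        # target names the module; args/kwargs exclude the module itself
        return f'{node.name} = self.{node.target}(*{node.args}, **{node.kwargs})'
    if node.op == 'call_method':
        self_arg, *rest = node.args  # for call_method, args includes the self argument
        return f'{node.name} = {self_arg}.{node.target}(*{tuple(rest)}, **{node.kwargs})'

for n in gm.graph.nodes:
    print(describe(n))
```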

GraphModule automatically generates Python code for the operations it symbolically observed:

```
print(gm.src)
```

```
def forward(self, x):
    self = self.root
    linear_weight = self.linear.weight
    add_1 = x + linear_weight
    linear_1 = self.linear(add_1)
    relu_2 = linear_1.relu()
    sum_1 = torch.sum(relu_2, dim = -1)
    topk_1 = torch.topk(sum_1, 3)
    return topk_1
```

Because this code is valid PyTorch code, the resulting `GraphModule` can be used in any context another `nn.Module` can be used, including in TorchScript tracing/compilation.
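
For example (a hedged sketch: the input shape is chosen so that `x + self.linear.weight` broadcasts, and `torch.jit.script` is assumed as the TorchScript entry point — the PR does not guarantee every traced graph scripts cleanly):

```
x = torch.rand(5, 4)              # broadcasts against linear.weight, which has shape (5, 4)
values, indices = gm(x)           # runs the generated forward; topk returns two tensors
scripted = torch.jit.script(gm)   # TorchScript compilation of the generated code
print(scripted(x))
```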

Differential Revision: [D23006383](https://our.internmc.facebook.com/intern/diff/D23006383)

dr-ci bot commented Aug 7, 2020

💊 CI failures summary and remediations

As of commit 515d9d6 (more details on the Dr. CI page):


  • 1/1 failures possibly* introduced in this PR
    • 1/1 non-CircleCI failure(s)

Extra GitHub checks: 1 failed



@jamesr66a jamesr66a requested review from zdevito and suo August 7, 2020 19:28
jamesr66a pushed a commit that referenced this pull request Aug 7, 2020
ghstack-source-id: 04a9dc36c6a7d56a74be2d46cceb73a2f91921f3
Pull Request resolved: #42741
@dzhulgakov dzhulgakov self-requested a review August 7, 2020 22:45
```
def __iter__(self):
    # Snoop the calling frame's bytecode to see whether this value is being
    # consumed by a two-target unpack such as `x, y = rhs` (see the
    # discussion below); in that case, yield exactly two proxies.
    frame = inspect.currentframe()
    calling_frame = frame.f_back
    inst = list(dis.get_instructions(calling_frame.f_code))[calling_frame.f_lasti // 2]
```
A reviewer (Collaborator) commented:
add a comment on what this is doing? trying to allow only unpack sequence iteration?

A Contributor replied:
yeah this code is super confusing, who wrote this!? (I wrote this...) Normally, unpacking a tuple would try to call `__iter__`. In most cases, we want to error on calling `__iter__` because the length of the symbolic value being iterated is unknown. However, in the pattern-matching case `x, y = rhs`, we know `rhs` must have length 2, otherwise this will error. This code snoops the caller's code to see if an unpack is taking place and that there are two elements being matched. It then returns two things in the iterator to make sure the match succeeds. This makes things like `torch.max`, which returns two values, traceable when they otherwise wouldn't be. However, it isn't strictly needed, since you can always rewrite the calling code to explicitly match: `x = rhs[0]; y = rhs[1]`.
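
For reference, a standalone sketch of the bytecode pattern the snooping looks for — a two-target unpack compiles to a single `UNPACK_SEQUENCE 2` instruction:

```
import dis

def caller(rhs):
    x, y = rhs   # compiles to UNPACK_SEQUENCE 2
    return x, y

# The disassembly shows UNPACK_SEQUENCE with argument 2 -- exactly what
# the __iter__ above detects in the calling frame before yielding two proxies.
dis.dis(caller)
```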

@zdevito zdevito left a comment

Nice! I just have minor organization comments.

@houseroad houseroad left a comment

Also a general question, will we use this file as the source of the truth? :-)

cc: @jamesr66a @zdevito @dzhulgakov

@jamesr66a (Collaborator, Author) replied:
> Also a general question, will we use this file as the source of the truth? :-)

Yes, that's the intent!

jamesr66a pushed a commit that referenced this pull request Aug 10, 2020
ghstack-source-id: b5c3ce6ea01ad230882c1b462e4fe0fe2dd7773a
Pull Request resolved: #42741
@jamesr66a jamesr66a requested a review from zdevito August 10, 2020 17:55
jerryzh168 added a commit that referenced this pull request Aug 18, 2020
Summary:
This PR added graph mode quantization on fx: #42741
Currently it matches eager mode quantization for torchvision with static/dynamic/qat
ddp/synbn test is still wip

File structure
```
torch/quantization
        - _quantize_fx.py - top level APIs, currently Quantizer and fuse
        - fx - implementations
          - fuse.py - fuse conv bn, similar to eager mode fusion but also works for functional relu
          - quantize.py - implementation for Quantizer, static/qat/dynamic patterns are registered here as well
          - pattern_utils.py - utility function for pattern registration
          - utils.py - other utility functions

test/quantization/test_quantize_fx.py -- test for fx based graph mode quantization
```
Notice that the current API is not final.
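
As a purely hypothetical sketch of how these pieces might fit together — `Quantizer` and `fuse` are named above, but the import path is internal and non-final, and the `prepare`/`convert` method names and qconfig handling are assumptions borrowed from eager-mode quantization, which this work is said to match:

```
import torch
from torch.quantization._quantize_fx import Quantizer, fuse  # internal, non-final API

class SmallModel(torch.nn.Module):  # hypothetical toy model
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3)
        self.bn = torch.nn.BatchNorm2d(8)

    def forward(self, x):
        return torch.nn.functional.relu(self.bn(self.conv(x)))

model = SmallModel().eval()
model = fuse(model)                      # conv+bn (+ functional relu) fusion, per the message
quantizer = Quantizer()
prepared = quantizer.prepare(model)      # assumed method name; qconfig arguments omitted
prepared(torch.rand(1, 3, 16, 16))       # hypothetical one-batch calibration
quantized = quantizer.convert(prepared)  # assumed method name
```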

Next Step:
- Add support for all quantized ops that are implemented right now: https://github.com/pytorch/pytorch/tree/master/aten/src/ATen/native/quantized/cpu

Test Plan:
python test/test_quantization.py TestQuantizeFx

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D23178602](https://our.internmc.facebook.com/intern/diff/D23178602)

[ghstack-poisoned]
facebook-github-bot pushed a commit that referenced this pull request Aug 20, 2020
Summary:
Pull Request resolved: #43175

This PR added graph mode quantization on fx: #42741
Currently it matches eager mode quantization for torchvision with static/dynamic/qat
ddp/syncbn test is still wip

Test Plan:
python test/test_quantization.py TestQuantizeFx

Imported from OSS

Reviewed By: vkuzo

Differential Revision: D23178602

fbshipit-source-id: 8e7e0322846fbda2cfa79ad188abd7235326f879