Introduce experimental FX library #42741
Conversation
**This feature is experimental and its stability is not currently guaranteed. Proceed at your own risk.**

FX (Functional Transformations) is a toolkit for capturing and transforming functional PyTorch programs. It consists of GraphModule and a corresponding intermediate representation (IR). When GraphModule is constructed with an `nn.Module` instance as its argument, GraphModule will trace through the computation of that Module's `forward` method symbolically and record those operations in the FX intermediate representation.

```
import torch
from torch.fx import GraphModule

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3)

m = MyModule()
gm = GraphModule(m)
```

The Intermediate Representation centers around a 5-opcode format:

```
from tabulate import tabulate
node_specs = [[n.op, n.name, n.target, n.args, n.kwargs] for n in gm.graph.nodes]
print(tabulate(node_specs, headers=['opcode', 'name', 'target', 'args', 'kwargs']))
```

```
opcode         name           target                                                    args                kwargs
-------------  -------------  --------------------------------------------------------  ------------------  -----------
placeholder    x              x                                                         ()                  {}
get_param      linear_weight  linear.weight                                             ()                  {}
call_function  add_1          <built-in function add>                                   (x, linear_weight)  {}
call_module    linear_1       linear                                                    (add_1,)            {}
call_method    relu_2         relu                                                      [linear_1]          {}
call_function  sum_1          <built-in method sum of type object at 0x7f1c29dd36e0>    (relu_2,)           {'dim': -1}
call_function  topk_1         <built-in method topk of type object at 0x7f1c29dd36e0>   (sum_1, 3)          {}
```

The semantics are as follows:

- `placeholder` represents a function input. The `name` attribute specifies the name this value will take on. `target` is likewise the name of the argument. `args` and `kwargs` are don't-cares.
- `get_param` retrieves a parameter from the module hierarchy. `name` is likewise the name the result of the fetch is assigned to. `target` is the fully-qualified name of the parameter's position in the module hierarchy. `args` and `kwargs` are don't-cares.
- `call_function` applies a free function to some values. `name` is likewise the name of the value to assign to. `target` is the function to be applied. `args` and `kwargs` represent the arguments to the function, following the Python calling convention.
- `call_module` applies a module in the module hierarchy's `forward()` method to the given arguments. `name` is as previous. `target` is the fully-qualified name of the module in the module hierarchy to call. `args` and `kwargs` represent the arguments to invoke the module on, _excluding the self argument_.
- `call_method` calls a method on a value. `name` is as previous. `target` is the string name of the method to apply to the `self` argument. `args` and `kwargs` represent the arguments to invoke the method on, _including the self argument_.

GraphModule automatically generates Python code for the operations it symbolically observed:

```
print(gm.src)
```

```
def forward(self, x):
    self = self.root
    linear_weight = self.linear.weight
    add_1 = x + linear_weight
    linear_1 = self.linear(add_1)
    relu_2 = linear_1.relu()
    sum_1 = torch.sum(relu_2, dim = -1)
    topk_1 = torch.topk(sum_1, 3)
    return topk_1
```

Because this code is valid PyTorch code, the resulting `GraphModule` can be used in any context another `nn.Module` can be used, including in TorchScript tracing/compilation.
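For illustration, here is a small sketch of consuming this IR, written only against what the snippet above shows (the `op`, `name`, `target`, `args`, `kwargs` node attributes and the fact that `GraphModule` behaves as an `nn.Module`); since the feature is experimental, these names may change:

```
import collections
import torch

# Assumes `m` and `gm` from the MyModule / GraphModule example above.

# Walk the captured graph: count nodes per opcode and collect the
# fully-qualified targets of every call_module node.
op_histogram = collections.Counter(n.op for n in gm.graph.nodes)
called_modules = [n.target for n in gm.graph.nodes if n.op == 'call_module']
print(op_histogram)    # e.g. Counter({'call_function': 3, 'placeholder': 1, ...})
print(called_modules)  # e.g. ['linear']

# Because GraphModule is an nn.Module backed by the generated code, it can be
# run directly and should agree with the original module. The input shape is
# chosen so that `x + self.linear.weight` broadcasts against the (5, 4) weight.
x = torch.rand(5, 4)
assert torch.allclose(gm(x).values, m(x).values)
```

A transformation would then manipulate these nodes (or emit a new graph) and rely on the generated code to execute the result.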
Differential Revision: [D23006383](https://our.internmc.facebook.com/intern/diff/D23006383)
torch/fx/graph_module.py
```
def __iter__(self):
    frame = inspect.currentframe()
    calling_frame = frame.f_back
    inst = list(dis.get_instructions(calling_frame.f_code))[calling_frame.f_lasti // 2]
```
add a comment on what this is doing? trying to allow only unpack sequence iteration?
Yeah, this code is super confusing, who wrote this!? (I wrote this...) Normally, unpacking a tuple would try to call `__iter__`. In most cases we want to error on calling `__iter__`, because the length of the symbolic value being iterated is unknown. However, in the pattern-matching case `x, y = rhs` we know `rhs` must have length 2, otherwise this will error. This code snoops the caller's code to see whether an unpack is taking place and whether there are two elements being matched. It then returns two things from the iterator to make sure the match succeeds. This makes things like `torch.max`, which returns two values, traceable when they otherwise wouldn't be. However, it isn't strictly needed, since you can always rewrite the calling code to match explicitly: `x = rhs[0]; y = rhs[1]`.
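To make the trick concrete, here is a minimal self-contained sketch of the same idea outside of torch.fx. The `SymbolicPair` class is made up for illustration, and the `f_lasti // 2` indexing assumes the two-byte instruction layout of the CPython versions current when this PR was written; in the tracer, the two yielded values would be symbolic values rather than strings.

```
import dis
import inspect

class SymbolicPair:
    """Toy stand-in for a symbolic value whose length is unknown at trace time."""

    def __iter__(self):
        # Inspect the caller's frame and find the bytecode instruction that is
        # currently executing, i.e. the one that triggered this __iter__ call.
        calling_frame = inspect.currentframe().f_back
        instructions = list(dis.get_instructions(calling_frame.f_code))
        current = instructions[calling_frame.f_lasti // 2]
        # Only allow the two-target unpack pattern `x, y = value`, which
        # compiles to UNPACK_SEQUENCE with an argument of 2.
        if current.opname == 'UNPACK_SEQUENCE' and current.argval == 2:
            return iter(('first', 'second'))
        raise RuntimeError('cannot iterate over a symbolic value of unknown length')

x, y = SymbolicPair()    # allowed: the caller is doing a two-element unpack
# list(SymbolicPair())   # would raise: arbitrary iteration is rejected
```

Rejecting every other call site preserves the "length unknown" error, while the two-element unpack is the one case where the length is pinned down by the caller.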
Nice! I just have minor organization comments.
Also a general question: will we use this file as the source of truth? :-)
Yes, that's the intent!
Summary: This PR added graph mode quantization on FX (#42741). Currently it matches eager mode quantization for torchvision with static/dynamic/QAT; the DDP/SyncBN test is still WIP.

File structure:

```
torch/quantization
  - _quantize_fx.py  - top-level APIs, currently Quantizer and fuse
  - fx               - implementations
    - fuse.py          - fuse conv + bn, similar to eager mode fusion but also works for functional relu
    - quantize.py      - implementation of Quantizer; static/QAT/dynamic patterns are registered here as well
    - pattern_utils.py - utility functions for pattern registration
    - utils.py         - other utility functions
test/quantization/test_quantize_fx.py - tests for FX-based graph mode quantization
```

Note that the current API is not final.

Next step:
- Add support for all quantized ops that are implemented right now: https://github.com/pytorch/pytorch/tree/master/aten/src/ATen/native/quantized/cpu

Test Plan: `python test/test_quantization.py TestQuantizeFx`

Differential Revision: [D23178602](https://our.internmc.facebook.com/intern/diff/D23178602)
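As background for the `fuse.py` entry above, here is a minimal sketch of the standard conv + batchnorm folding arithmetic. This is the textbook identity, not the code added in this PR, and the helper name `fold_bn_into_conv` is made up for illustration:

```
import torch

def fold_bn_into_conv(conv: torch.nn.Conv2d, bn: torch.nn.BatchNorm2d) -> torch.nn.Conv2d:
    """Fold an eval-mode BatchNorm2d into the preceding Conv2d.

    bn(conv(x)) = gamma * (W*x + b - mean) / sqrt(var + eps) + beta
                = (scale * W) * x + (b - mean) * scale + beta,   scale = gamma / sqrt(var + eps)
    """
    fused = torch.nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                            stride=conv.stride, padding=conv.padding, bias=True)
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    conv_bias = conv.bias if conv.bias is not None else torch.zeros(conv.out_channels)
    with torch.no_grad():
        # Scale each output channel's filter, then fold the BN shift into the bias.
        fused.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))
        fused.bias.copy_((conv_bias - bn.running_mean) * scale + bn.bias)
    return fused

conv = torch.nn.Conv2d(3, 8, 3, padding=1)
bn = torch.nn.BatchNorm2d(8).eval()   # folding is only valid with frozen (eval) statistics
x = torch.rand(1, 3, 16, 16)
assert torch.allclose(bn(conv(x)), fold_bn_into_conv(conv, bn)(x), atol=1e-5)
```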
Summary: Pull Request resolved: #43175. This PR added graph mode quantization on FX (#42741). Currently it matches eager mode quantization for torchvision with static/dynamic/QAT; the DDP/SyncBN test is still WIP. Test Plan: `python test/test_quantization.py TestQuantizeFx`. Imported from OSS. Reviewed By: vkuzo. Differential Revision: D23178602. fbshipit-source-id: 8e7e0322846fbda2cfa79ad188abd7235326f879