Skip to content

Commit

Permalink
[quant][fx][bc-breaking] Add required example_inputs argument to prep…
Browse files Browse the repository at this point in the history
…are_fx and prepare_qat_fx (#77608)

Summary:
X-link: pytorch/pytorch#77608

Pull Request resolved: pytorch/fx2trt#76

X-link: facebookresearch/d2go#249

X-link: fairinternal/ClassyVision#104

X-link: pytorch/benchmark#916

X-link: facebookresearch/ClassyVision#791

X-link: facebookresearch/mobile-vision#68

FX Graph Mode Quantization needs to know whether an fx node is a floating point Tensor before it can decide whether to
insert observer/fake_quantize module or not, since we only insert observer/fake_quantize module for floating point Tensors.
Currently we have some hacks to support this by defining some rules like NON_OBSERVABLE_ARG_DICT (https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/fx/utils.py#L496), but this approach is fragile and we do not plan to maintain it long term in the pytorch code base.

As we discussed in the design review, we'd need to ask users to provide sample args and sample keyword args
so that we can infer the type in a more robust way. This PR starts by changing the prepare_fx and prepare_qat_fx APIs to require the user to provide
example arguments through example_inputs. Note this API doesn't support kwargs. Kwargs could make pytorch/pytorch#76496 (comment) simpler, but
such cases will be rare, and even then we can still work around them with positional arguments. Also, torch.jit.trace (https://pytorch.org/docs/stable/generated/torch.jit.trace.html) and ShapeProp (https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/shape_prop.py#L140) accept only positional args, so we'll just use a single example_inputs argument for now.

If needed, we can extend the API with an optional example_kwargs, e.g. for cases where forward takes many arguments and it makes more sense to
pass them by keyword.

BC-breaking Note:
Before:
```python
m = resnet18(...)
m = prepare_fx(m, qconfig_dict)
# or
m = prepare_qat_fx(m, qconfig_dict)
```
After:
```python
m = resnet18(...)
m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 224, 224),))
# or
m = prepare_qat_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 224, 224),))
```

Reviewed By: vkuzo, andrewor14

Differential Revision: D35984526

fbshipit-source-id: 706c8df71722c9aa5082a6491734f0144f0dd670
  • Loading branch information
jerryzh168 authored and Wei Wei committed Jun 4, 2022
1 parent 4086fdc commit 134de54
Showing 1 changed file with 60 additions and 12 deletions.
72 changes: 60 additions & 12 deletions test/quant/test_quant_trt.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,14 @@ def forward(self, x):

# quantized input, quantized output
m = M()
qconfig_dict = {"": torch.ao.quantization.default_qconfig}
m.eval()
qconfig_dict = {"": torch.ao.quantization.default_qconfig}
example_inputs = (torch.rand(1, 1, 3, 3),)
mp = torch.ao.quantization.quantize_fx.prepare_fx(
m, qconfig_dict, prepare_custom_config_dict=prepare_custom_config_dict
m,
qconfig_dict,
example_inputs,
prepare_custom_config_dict=prepare_custom_config_dict,
)
self.checkGraphModuleNodes(mp, expected_node_occurrence=prepare_count_check)
mp(torch.randn(1, 1, 4, 4))
Expand Down Expand Up @@ -221,20 +225,29 @@ def forward(self, x):
original_m.standalone.conv.bias.detach()
)

sm_example_inputs = (data,)
prepare_config = {
"standalone_module_name": [
("standalone", None, interface_config, backend_config_dict)
(
"standalone",
None,
sm_example_inputs,
interface_config,
backend_config_dict,
)
]
}

original_m_copy = copy.deepcopy(original_m)
original_ref_m_copy = copy.deepcopy(original_ref_m)

qconfig_dict = {"": qconfig}
example_inputs = (data,)
# check prepared model
m = prepare_fx(
original_m_copy,
qconfig_dict,
example_inputs,
prepare_custom_config_dict=prepare_config,
backend_config_dict=backend_config_dict,
)
Expand All @@ -255,7 +268,10 @@ def forward(self, x):

# quantize the reference model
ref_m = prepare_fx(
original_ref_m_copy, qconfig_dict, backend_config_dict=backend_config_dict
original_ref_m_copy,
qconfig_dict,
example_inputs,
backend_config_dict=backend_config_dict,
)
ref_m(data)
ref_m = convert_fx(
Expand Down Expand Up @@ -410,8 +426,12 @@ def _test_module(
else:
m = m.eval()
prepare = prepare_fx
example_inputs = tuple(inputs)
prepared = prepare(
m, {"": self.trt_qconfig}, backend_config_dict=self.trt_backend_config_dict
m,
{"": self.trt_qconfig},
example_inputs,
backend_config_dict=self.trt_backend_config_dict,
)
self.checkGraphModuleNodes(prepared, expected_node_occurrence=no_prepare)
# calibration
Expand Down Expand Up @@ -528,8 +548,12 @@ def forward(self, x):
return x

m = M().eval()
example_inputs = (torch.rand(1, 3, 5, 5),)
m = prepare_fx(
m, {"": self.trt_qconfig}, backend_config_dict=self.trt_backend_config_dict
m,
{"": self.trt_qconfig},
example_inputs,
backend_config_dict=self.trt_backend_config_dict,
)
m = convert_fx(
m, is_reference=True, backend_config_dict=self.trt_backend_config_dict
Expand Down Expand Up @@ -558,9 +582,11 @@ def forward(self, x):

m = LinearModule().eval()
trt_unsupported_qconfig = default_qconfig
example_inputs = (torch.rand(1, 5),)
prepared = prepare_fx(
m,
{"": trt_unsupported_qconfig},
example_inputs=example_inputs,
backend_config_dict=self.trt_backend_config_dict,
)
# calibration
Expand Down Expand Up @@ -588,8 +614,12 @@ def forward(self, x):
return torch.cat([x, x], 1)

m = M().eval()
example_inputs = (torch.rand(2, 2),)
prepared = prepare_fx(
m, {"": self.trt_qconfig}, backend_config_dict=self.trt_backend_config_dict
m,
{"": self.trt_qconfig},
example_inputs,
backend_config_dict=self.trt_backend_config_dict,
)
self.assertTrue(len(dict(prepared.named_children())) == 1)
quantized = convert_fx(
Expand All @@ -615,8 +645,12 @@ def forward(self, x):
return torch.addmm(self.bias, x, self.weight)

m = M().eval()
example_inputs = (torch.rand(1, 5),)
prepared = prepare_fx(
m, {"": self.trt_qconfig}, backend_config_dict=self.trt_backend_config_dict
m,
{"": self.trt_qconfig},
example_inputs,
backend_config_dict=self.trt_backend_config_dict,
)
node_occurrence = {
# weight
Expand Down Expand Up @@ -684,8 +718,12 @@ def conv_add_extra_inputs_getter(pattern):
m = M().eval()
modified_backend_config_dict = copy.deepcopy(self.trt_backend_config_dict)
modified_backend_config_dict["configs"].insert(0, conv_add_config)
example_inputs = (torch.rand(1, 3, 3, 3), torch.rand(1, 3, 1, 1))
m = prepare_fx(
m, {"": self.trt_qconfig}, backend_config_dict=modified_backend_config_dict
m,
{"": self.trt_qconfig},
example_inputs,
backend_config_dict=modified_backend_config_dict,
)
print(m)
node_occurrence = {
Expand Down Expand Up @@ -717,7 +755,7 @@ def __init__(self):
self.conv = torch.nn.Conv2d(3, 3, 3)
self.standalone = Standalone()

def forward(self, x, y):
def forward(self, x):
y = self.conv(x)
return self.standalone(x, y)

Expand Down Expand Up @@ -765,9 +803,16 @@ def forward(self, x, y):
conv_config,
]
}
sm_example_inputs = (torch.rand(1, 3, 3, 3), torch.rand(1, 3, 1, 1))
prepare_custom_config_dict = {
"standalone_module_name": [
("standalone", None, {"input_quantized_idxs": [0, 1]}, None)
(
"standalone",
None,
sm_example_inputs,
{"input_quantized_idxs": [0, 1]},
None,
)
]
}
# TODO: use self.trt_qconfig after input_quantized_idxs and output_quantized_idxs
Expand All @@ -778,9 +823,11 @@ def forward(self, x, y):
),
weight=torch.ao.quantization.default_weight_observer,
)
example_inputs = (torch.rand(1, 3, 5, 5),)
m = prepare_fx(
m,
{"": qconfig},
example_inputs,
prepare_custom_config_dict=prepare_custom_config_dict,
backend_config_dict=backend_config_dict,
)
Expand Down Expand Up @@ -829,10 +876,11 @@ def forward(self, x):

model = LinearModule().eval()
inputs = [torch.rand(8, 5)]

example_inputs = tuple(inputs)
prepared = prepare_fx(
model,
{"": self.trt_qconfig},
example_inputs,
backend_config_dict=self.trt_backend_config_dict,
)
quantized = convert_fx(
Expand Down

0 comments on commit 134de54

Please sign in to comment.