forked from facebookresearch/ClassyVision
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add required example_inputs argument to prepare_fx and prepare_qat_fx…
… (#77608) Summary: X-link: pytorch/pytorch#77608 X-link: pytorch/fx2trt#76 X-link: facebookresearch/d2go#249 X-link: fairinternal/ClassyVision#104 X-link: pytorch/benchmark#916 Pull Request resolved: facebookresearch#791 X-link: facebookresearch/mobile-vision#68 FX Graph Mode Quantization needs to know whether an fx node is a floating point Tensor before it can decide whether to insert observer/fake_quantize module or not, since we only insert observer/fake_quantize module for floating point Tensors. Currently we have some hacks to support this by defining some rules like NON_OBSERVABLE_ARG_DICT (https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/fx/utils.py#L496), but this approach is fragile and we do not plan to maintain it long term in the pytorch code base. As we discussed in the design review, we'd need to ask users to provide sample args and sample keyword args so that we can infer the type in a more robust way. This PR starts with changing the prepare_fx and prepare_qat_fx api to require user to either provide example arguments thrugh example_inputs, Note this api doesn't support kwargs, kwargs can make pytorch/pytorch#76496 (comment) (comment) simpler, but it will be rare, and even then we can still workaround with positional arguments, also torch.jit.trace(https://pytorch.org/docs/stable/generated/torch.jit.trace.html) and ShapeProp: https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/shape_prop.py#L140 just have single positional args, we'll just use a single example_inputs argument for now. If needed, we can extend the api with an optional example_kwargs. e.g. in case when there are a lot of arguments for forward and it makes more sense to pass the arguments by keyword BC-breaking Note: Before: ```python m = resnet18(...) m = prepare_fx(m, qconfig_dict) # or m = prepare_qat_fx(m, qconfig_dict) ``` After: ```python m = resnet18(...) m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 224, 224),)) # or m = prepare_qat_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 224, 224),)) ``` Reviewed By: vkuzo, andrewor14 Differential Revision: D35984526 fbshipit-source-id: 68861512ebb3dd1ae6e069a144c30eee2d63e5da
- Loading branch information
1 parent
35c8f33
commit 4846183
Showing
4 changed files
with
47 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters