Skip to content

Commit

Permalink
Add required example_args argument to prepare_fx and prepare_qat_fx (…
Browse files Browse the repository at this point in the history
…#77608)

Summary:
X-link: pytorch/pytorch#77608

X-link: pytorch/fx2trt#76

X-link: facebookresearch/d2go#249

X-link: fairinternal/ClassyVision#104

Pull Request resolved: pytorch#916

X-link: facebookresearch/ClassyVision#791

X-link: facebookresearch/mobile-vision#68

FX Graph Mode Quantization needs to know whether an fx node is a floating point Tensor before it can decide whether to
insert observer/fake_quantize module or not, since we only insert observer/fake_quantize module for floating point Tensors.
Currently we have some hacks to support this by defining some rules like NON_OBSERVABLE_ARG_DICT (https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/fx/utils.py#L496), but this approach is fragile and we do not plan to maintain it long term in the pytorch code base.

As we discussed in the design review, we'd need to ask users to provide sample args and sample keyword args
so that we can infer the type in a more robust way. This PR starts with changing the prepare_fx and prepare_qat_fx api to require user to either provide
example arguments thrugh example_inputs, Note this api doesn't support kwargs, kwargs can make pytorch/pytorch#76496 (comment) (comment) simpler, but
it will be rare, and even then we can still workaround with positional arguments, also torch.jit.trace(https://pytorch.org/docs/stable/generated/torch.jit.trace.html) and ShapeProp: https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/shape_prop.py#L140 just have single positional args, we'll just use a single example_inputs argument for now.

If needed, we can extend the api with an optional example_kwargs. e.g. in case when there are a lot of arguments for forward and it makes more sense to
pass the arguments by keyword

BC-breaking Note:
Before:
m = resnet18(...)
m = prepare_fx(m, qconfig_dict)

After:
m = resnet18(...)
m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 224, 224),))

Reviewed By: vkuzo, andrewor14

Differential Revision: D35984526

fbshipit-source-id: 7e1ce6dc13a1ecc4d46939c8e3b3f3721248c727
  • Loading branch information
jerryzh168 authored and facebook-github-bot committed May 17, 2022
1 parent 211a0ef commit a5b8463
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self, test, device, jit=False, batch_size=None, extra_args=[]):
def prep_qat_train(self):
qconfig_dict = {"": torch.quantization.get_default_qat_qconfig('fbgemm')}
self.model.train()
self.model = quantize_fx.prepare_qat_fx(self.model, qconfig_dict)
self.model = quantize_fx.prepare_qat_fx(self.model, qconfig_dict, self.example_inputs)

def train(self, niter=3):
optimizer = optim.Adam(self.model.parameters())
Expand Down
2 changes: 1 addition & 1 deletion torchbenchmark/models/resnet50_quantized_qat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self, test, device, jit=False, batch_size=None, extra_args=[]):
def prep_qat_train(self):
qconfig_dict = {"": torch.quantization.get_default_qat_qconfig('fbgemm')}
self.model.train()
self.model = quantize_fx.prepare_qat_fx(self.model, qconfig_dict)
self.model = quantize_fx.prepare_qat_fx(self.model, qconfig_dict, self.example_inputs)

def get_module(self):
return self.model, self.example_inputs
Expand Down

0 comments on commit a5b8463

Please sign in to comment.