Add required example_args argument to prepare_fx and prepare_qat_fx
Summary:
X-link: pytorch/pytorch#77608

X-link: pytorch/fx2trt#76

Pull Request resolved: facebookresearch#249

X-link: fairinternal/ClassyVision#104

X-link: pytorch/benchmark#916

X-link: facebookresearch/ClassyVision#791

X-link: facebookresearch/mobile-vision#68

FX Graph Mode Quantization needs to know whether an fx node is a floating point Tensor before it can decide whether to
insert an observer/fake_quantize module, since we only insert observer/fake_quantize modules for floating point Tensors.
Currently we support this with hacks such as the NON_OBSERVABLE_ARG_DICT rules (https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/fx/utils.py#L496), but this approach is fragile and we do not plan to maintain it long term in the PyTorch code base.
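
For illustration only (this is not the implementation in this PR): once example inputs are available, we can propagate them through the traced graph and read off each node's output dtype instead of maintaining hand-written rules. A minimal sketch of the idea using torch.fx's ShapeProp, with a made-up module M:

import torch
import torch.nn as nn
from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp

class M(nn.Module):
    def forward(self, x):
        # x.size(0) produces an int, not a floating point Tensor,
        # so that node should not get an observer
        return x + x.size(0)

traced = symbolic_trace(M())
example_inputs = (torch.randn(2, 3),)
# run the example inputs through the graph, recording per-node output metadata
ShapeProp(traced).propagate(*example_inputs)
for node in traced.graph.nodes:
    tm = node.meta.get("tensor_meta", None)
    is_float_tensor = (
        tm is not None and hasattr(tm, "dtype") and tm.dtype.is_floating_point
    )
    print(node.name, is_float_tensor)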

As we discussed in the design review, we need to ask users to provide sample args (and possibly sample keyword args)
so that we can infer the types in a more robust way. This PR starts by changing the prepare_fx and prepare_qat_fx APIs to require users to provide
example arguments through example_inputs. Note that this API doesn't support kwargs; kwargs could make pytorch/pytorch#76496 (comment) simpler, but
that case will be rare, and even then we can still work around it with positional arguments (see the sketch below). Also, torch.jit.trace (https://pytorch.org/docs/stable/generated/torch.jit.trace.html) and ShapeProp (https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/shape_prop.py#L140) only take positional example arguments, so we'll just use a single example_inputs argument for now.
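
To illustrate the positional workaround mentioned above: a module whose forward takes keyword arguments can still be prepared by passing the values positionally in example_inputs. A rough sketch against the API introduced here (the module M and the fbgemm default qconfig are just example choices, not code from this PR):

import torch
import torch.nn as nn
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x, scale=1.0):
        return self.linear(x) * scale

m = M().eval()
qconfig_dict = {"": get_default_qconfig("fbgemm")}
# `scale` would normally be passed by keyword; since example_inputs is
# positional-only, we supply its value positionally instead
example_inputs = (torch.randn(1, 4), 2.0)
m = prepare_fx(m, qconfig_dict, example_inputs)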

If needed, we can extend the API with an optional example_kwargs later, e.g. for cases where forward takes many arguments and it makes more sense to
pass them by keyword.

BC-breaking Note:
Before:
m = resnet18(...)
m = prepare_fx(m, qconfig_dict)

After:
m = resnet18(...)
m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 224, 224),))
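
For context, a sketch of how the new argument fits into the usual post-training flow (assuming torchvision's resnet18 and the default fbgemm qconfig; the calibration loop uses random stand-in data):

import torch
from torchvision.models import resnet18
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

m = resnet18().eval()
qconfig_dict = {"": get_default_qconfig("fbgemm")}
example_inputs = (torch.randn(1, 3, 224, 224),)
m = prepare_fx(m, qconfig_dict, example_inputs)

# calibrate observers with a few representative batches (random stand-in here)
with torch.no_grad():
    for _ in range(4):
        m(torch.randn(1, 3, 224, 224))

m = convert_fx(m)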

Reviewed By: vkuzo, andrewor14

Differential Revision: D35984526

fbshipit-source-id: 5b7837005a34a095b331dbca7d6a8c2d6fa5ee51
jerryzh168 authored and facebook-github-bot committed May 19, 2022
1 parent 56243ca commit 56701e8
Showing 7 changed files with 54 additions and 14 deletions.
18 changes: 16 additions & 2 deletions d2go/modeling/meta_arch/rcnn.py
@@ -5,6 +5,7 @@
import inspect
import logging

+import torch
import torch.nn as nn
from d2go.export.api import PredictorExportConfig
from d2go.quantization.modeling import set_backend_and_create_qconfig
@@ -199,9 +200,13 @@ def _fx_quant_prepare(self, cfg):
prep_fn = prepare_qat_fx if self.training else prepare_fx
qconfig = {"": self.qconfig}
assert not isinstance(self.backbone, FPN), "FPN is not supported in FX mode"
+# TODO[quant-example-inputs]: Expose example_inputs as argument
+# Note: this is used in quantization for all submodules
+example_inputs = (torch.rand(1, 3, 3, 3),)
self.backbone = prep_fn(
self.backbone,
qconfig,
+example_inputs,
prepare_custom_config_dict={
"preserved_attributes": ["size_divisibility"],
# keep the output of backbone quantized, to avoid
@@ -215,38 +220,47 @@ def _fx_quant_prepare(self, cfg):
self.proposal_generator.rpn_head.rpn_feature = prep_fn(
self.proposal_generator.rpn_head.rpn_feature,
qconfig,
+example_inputs,
prepare_custom_config_dict={
# rpn_feature expecting quantized input, this is used to avoid redundant
# quant
"input_quantized_idxs": [0]
},
)
self.proposal_generator.rpn_head.rpn_regressor.cls_logits = prep_fn(
-self.proposal_generator.rpn_head.rpn_regressor.cls_logits, qconfig
+self.proposal_generator.rpn_head.rpn_regressor.cls_logits,
+qconfig,
+example_inputs,
)
self.proposal_generator.rpn_head.rpn_regressor.bbox_pred = prep_fn(
-self.proposal_generator.rpn_head.rpn_regressor.bbox_pred, qconfig
+self.proposal_generator.rpn_head.rpn_regressor.bbox_pred,
+qconfig,
+example_inputs,
)
self.roi_heads.box_head.roi_box_conv = prep_fn(
self.roi_heads.box_head.roi_box_conv,
qconfig,
+example_inputs,
prepare_custom_config_dict={
"output_quantized_idxs": [0],
},
)
self.roi_heads.box_head.avgpool = prep_fn(
self.roi_heads.box_head.avgpool,
qconfig,
+example_inputs,
prepare_custom_config_dict={"input_quantized_idxs": [0]},
)
self.roi_heads.box_predictor.cls_score = prep_fn(
self.roi_heads.box_predictor.cls_score,
qconfig,
+example_inputs,
prepare_custom_config_dict={"input_quantized_idxs": [0]},
)
self.roi_heads.box_predictor.bbox_pred = prep_fn(
self.roi_heads.box_predictor.bbox_pred,
qconfig,
+example_inputs,
prepare_custom_config_dict={"input_quantized_idxs": [0]},
)

6 changes: 4 additions & 2 deletions d2go/quantization/modeling.py
@@ -344,10 +344,12 @@ def default_prepare_for_quant(cfg, model):
# here, to be consistent with the FX branch
else: # FX graph mode quantization
qconfig_dict = {"": qconfig}
+# TODO[quant-example-inputs]: needs follow up to change the api
+example_inputs = (torch.rand(1, 3, 3, 3),)
if model.training:
-model = prepare_qat_fx(model, qconfig_dict)
+model = prepare_qat_fx(model, qconfig_dict, example_inputs)
else:
-model = prepare_fx(model, qconfig_dict)
+model = prepare_fx(model, qconfig_dict, example_inputs)

logger.info("Setup the model with qconfig:\n{}".format(qconfig))

8 changes: 6 additions & 2 deletions d2go/runner/callbacks/quantization.py
@@ -172,12 +172,16 @@ def prepare(
attr: rgetattr(root, attr) for attr in attrs if rhasattr(root, attr)
}
prepared = root
+# TODO[quant-example-inputs]: expose example_inputs as argument
+# may need a dictionary that stores a map from submodule fqn to example_inputs
+# for submodule
+example_inputs = (torch.rand(1, 3, 3, 3),)
if "" in configs:
-prepared = prep_fn(root, configs[""])
+prepared = prep_fn(root, configs[""], example_inputs)
else:
for name, config in configs.items():
submodule = rgetattr(root, name)
-rsetattr(root, name, prep_fn(submodule, config))
+rsetattr(root, name, prep_fn(submodule, config, example_inputs))
for attr, value in old_attrs.items():
rsetattr(prepared, attr, value)
return prepared
18 changes: 13 additions & 5 deletions d2go/runner/default_runner.py
@@ -235,15 +235,23 @@ def _build_model(self, cfg, eval_only=False):
# Disable fake_quant and observer so that the model will be trained normally
# before QAT being turned on (controlled by QUANTIZATION.QAT.START_ITER).
if hasattr(model, "get_rand_input"):
-model = setup_qat_model(
-cfg, model, enable_fake_quant=eval_only, enable_observer=True
-)
imsize = cfg.INPUT.MAX_SIZE_TRAIN
rand_input = model.get_rand_input(imsize)
-model(rand_input, {})
+example_inputs = (rand_input, {})
+model = setup_qat_model(
+cfg,
+model,
+enable_fake_quant=eval_only,
+enable_observer=True,
+)
+model(*example_inputs)
else:
imsize = cfg.INPUT.MAX_SIZE_TRAIN
model = setup_qat_model(
-cfg, model, enable_fake_quant=eval_only, enable_observer=False
+cfg,
+model,
+enable_fake_quant=eval_only,
+enable_observer=False,
)

if eval_only:
2 changes: 2 additions & 0 deletions d2go/utils/testing/meta_arch_helper.py
@@ -53,9 +53,11 @@ def inference(self, inputs):
return ret

def prepare_for_quant(self, cfg):
+example_inputs = (torch.rand(1, 3, 3, 3),)
self.avgpool = prepare_qat_fx(
self.avgpool,
{"": set_backend_and_create_qconfig(cfg, is_train=self.training)},
+example_inputs,
)
return self

14 changes: 11 additions & 3 deletions tests/runner/test_runner_lightning_quantization.py
@@ -303,7 +303,11 @@ class _CustomQAT(QuantizationAwareTraining):
"""Only quantize TestModule.another_layer."""

def prepare(self, model, configs, attrs):
-model.another_layer = prepare_qat_fx(model.another_layer, configs[""])
+example_inputs = (torch.rand(1, 2),)
+model.another_layer = prepare_qat_fx(
+model.another_layer, configs[""], example_inputs
+)

return model

def convert(self, model, submodules, attrs):
@@ -466,7 +470,11 @@ class _CustomStaticQuant(PostTrainingQuantization):
"""Only quantize TestModule.another_layer."""

def prepare(self, model, configs, attrs):
-model.another_layer = prepare_fx(model.another_layer, configs[""])
+example_inputs = (torch.randn(1, 2),)
+model.another_layer = prepare_fx(
+model.another_layer, configs[""], example_inputs
+)

return model

def convert(self, model, submodules, attrs):
@@ -499,6 +507,6 @@ def convert(self, model, submodules, attrs):
# While quantized/original won't be exact, they should be close.
self.assertLess(
((((test_out - base_out) ** 2).sum(axis=1)) ** (1 / 2)).mean(),
-0.015,
+0.02,
"RMSE should be less than 0.007 between quantized and original.",
)
2 changes: 2 additions & 0 deletions tests/runner/test_runner_lightning_task.py
@@ -172,9 +172,11 @@ def __init__(self, cfg):
self.avgpool.not_preserved_attr = "bar"

def prepare_for_quant(self, cfg):
+example_inputs = (torch.rand(1, 3, 3, 3),)
self.avgpool = prepare_qat_fx(
self.avgpool,
{"": set_backend_and_create_qconfig(cfg, is_train=self.training)},
+example_inputs,
self.custom_config_dict,
)
return self
