This repository has been archived by the owner on Jan 22, 2025. It is now read-only.

expose example_input argument in setup_qat_model()
Summary:

Major changes
- The **example_input** argument of **prepare_fake_quant_model()** is useful in certain cases. For example, the Argos model's **custom_prepare_fx()** method under the FX graph + QAT setup (D52760682) uses it to prepare example inputs for individual sub-modules by running one forward pass and bookkeeping the inputs each sub-module receives. We therefore expose an **example_input** argument on the **setup_qat_model()** function.
- For a QAT model, we currently assert that the number of state dict keys (excluding observers) equals the number of state dict keys in the original model. However, when the assertion fails, it does not log any information useful for debugging. We change it to report the keys that are unique to each state dict.
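
The improved key-count check can be sketched in isolation as follows. This is a minimal standalone sketch of the same set-difference reporting; the helper name `check_state_dict_keys`, the logger setup, and the example keys are illustrative and not part of d2go itself:

```python
# Standalone sketch of the improved assertion: instead of failing with only
# the two key counts, log which keys are unique to each state dict before
# raising. The helper name and example keys are illustrative, not d2go's.
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def check_state_dict_keys(qat_keys, original_keys):
    """Raise if the two key collections differ in size, logging unique keys."""
    if len(qat_keys) != len(original_keys):
        a = set(qat_keys)
        b = set(original_keys)
        logger.info("unique keys in qat model state dict")
        for key in a.difference(b):
            logger.info(f"{key}")
        logger.info("unique keys in original model state dict")
        for key in b.difference(a):
            logger.info(f"{key}")
        raise RuntimeError(
            f"an inconsistent number of keys in state dict of new qat and "
            f"original model: {len(a)} vs {len(b)}"
        )
```

On a mismatch, the log now names the offending keys on each side, which is far more actionable than the bare count comparison in the old assertion message.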

Differential Revision: D52760688
stephenyan1231 authored and facebook-github-bot committed Jan 16, 2024
1 parent 573bd45 commit 7d36eed
Showing 1 changed file with 19 additions and 4 deletions.
23 changes: 19 additions & 4 deletions d2go/quantization/modeling.py
@@ -5,7 +5,7 @@
 import copy
 import logging
 import math
-from typing import Any, Dict, Tuple
+from typing import Any, Dict, Optional, Tuple
 
 import detectron2.utils.comm as comm
 import torch
@@ -471,6 +471,7 @@ def setup_qat_model(
     enable_fake_quant: bool = False,
     enable_observer: bool = False,
     enable_learnable_observer: bool = False,
+    example_input: Optional[Any] = None,
 ):
     assert cfg.QUANTIZATION.QAT.FAKE_QUANT_METHOD in [
         "default",
@@ -490,7 +491,7 @@
     model_fp32_state_dict = model_fp32.state_dict()
 
     # prepare model for qat
-    model = prepare_fake_quant_model(cfg, model_fp32, True)
+    model = prepare_fake_quant_model(cfg, model_fp32, True, example_input=example_input)
 
     # make sure the proper qconfig are used in the model
     learnable_qat.check_for_learnable_fake_quant_ops(qat_method, model)
@@ -554,9 +555,23 @@ def _setup_non_qat_to_qat_state_dict_map(
         )
     )
 
-    assert len(new_state_dict_non_observer_keys_not_ignored) == len(
+    if not len(new_state_dict_non_observer_keys_not_ignored) == len(
         original_state_dict_shapes
-    ), f"keys in state dict of original and new qat model {len(new_state_dict_non_observer_keys_not_ignored)} vs {len(original_state_dict_shapes)}"
+    ):
+        a = set(new_state_dict_non_observer_keys_not_ignored)
+        b = set(original_state_dict_shapes.keys())
+        a_diff_b = a.difference(b)
+        b_diff_a = b.difference(a)
+        logger.info("unique keys in qat model state dict")
+        for key in a_diff_b:
+            logger.info(f"{key}")
+        logger.info("unique keys in original model state dict")
+        for key in b_diff_a:
+            logger.info(f"{key}")
+
+        raise RuntimeError(
+            f"an inconsistent number of keys in state dict of new qat and original model: {len(a)} vs {len(b)}"
+        )
 
     if is_eager_mode:
         for n_k, o_k in zip(
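
The `example_input` use case described in the summary, running one forward pass to bookkeep the inputs seen by each sub-module, can be sketched with PyTorch forward pre-hooks. This is an illustrative sketch only; the helper `record_submodule_inputs` and the toy model are made up here and are not the actual `custom_prepare_fx()` implementation:

```python
# Illustrative sketch: use an example input to record the inputs to
# individual sub-modules by running a single forward pass with forward
# pre-hooks. Helper name and toy model are hypothetical, not d2go code.
import torch
import torch.nn as nn


def record_submodule_inputs(model, example_input):
    """Run one forward pass and return {submodule_name: input_tensors}."""
    recorded = {}
    hooks = []
    for name, module in model.named_modules():
        if name:  # skip the root module itself (named "")

            def make_hook(name):
                def hook(module, inputs):
                    recorded[name] = inputs

                return hook

            hooks.append(module.register_forward_pre_hook(make_hook(name)))
    with torch.no_grad():
        model(example_input)
    for h in hooks:
        h.remove()
    return recorded


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
inputs = record_submodule_inputs(model, torch.randn(1, 4))
```

After the pass, `inputs` maps each sub-module's name to the tensors it received, which is the kind of bookkeeping the exposed `example_input` argument enables downstream.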
