1+ import torch
12from torch import nn , Tensor
23from torchvision .models .utils import load_state_dict_from_url
34from torchvision .models .mobilenetv3 import InvertedResidual , InvertedResidualConfig , ConvBNActivation , MobileNetV3 ,\
45 SqueezeExcitation , model_urls , _mobilenet_v3_conf
56from torch .quantization import QuantStub , DeQuantStub , fuse_modules
6- from typing import Any , List
7- from .utils import _replace_relu , quantize_model
7+ from typing import Any , List , Optional
8+ from .utils import _replace_relu
89
910
1011__all__ = ['QuantizableMobileNetV3' , 'mobilenet_v3_large' , 'mobilenet_v3_small' ]
1112
12- # TODO: Add URLs
1313quant_model_urls = {
14- 'mobilenet_v3_large_qnnpack' : None ,
14+ 'mobilenet_v3_large_qnnpack' :
15+ "https://github.com/datumbox/torchvision-models/raw/main/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth" ,
1516 'mobilenet_v3_small_qnnpack' : None ,
1617}
1718
@@ -69,6 +70,18 @@ def fuse_model(self):
6970 m .fuse_model ()
7071
7172
73+ def _load_weights (
74+ arch : str ,
75+ model : QuantizableMobileNetV3 ,
76+ model_url : Optional [str ],
77+ progress : bool ,
78+ ):
79+ if model_url is None :
80+ raise ValueError ("No checkpoint is available for {}" .format (arch ))
81+ state_dict = load_state_dict_from_url (model_url , progress = progress )
82+ model .load_state_dict (state_dict )
83+
84+
7285def _mobilenet_v3_model (
7386 arch : str ,
7487 inverted_residual_setting : List [InvertedResidualConfig ],
@@ -83,17 +96,18 @@ def _mobilenet_v3_model(
8396
8497 if quantize :
8598 backend = 'qnnpack'
86- quantize_model (model , backend )
87- model_url = quant_model_urls .get (arch + '_' + backend , None )
99+
100+ model .fuse_model ()
101+ model .qconfig = torch .quantization .get_default_qat_qconfig (backend )
102+ torch .quantization .prepare_qat (model , inplace = True )
103+
104+ if pretrained :
105+ _load_weights (arch , model , quant_model_urls .get (arch + '_' + backend , None ), progress )
106+
107+ torch .quantization .convert (model , inplace = True )
88108 else :
89- assert pretrained in [True , False ]
90- model_url = model_urls .get (arch , None )
91-
92- if pretrained :
93- if model_url is None :
94- raise ValueError ("No checkpoint is available for {}" .format (arch ))
95- state_dict = load_state_dict_from_url (model_url , progress = progress )
96- model .load_state_dict (state_dict )
109+ if pretrained :
110+ _load_weights (arch , model , model_urls .get (arch , None ), progress )
97111
98112 return model
99113
0 commit comments