diff --git a/python/paddle/vision/models/mobilenetv2.py b/python/paddle/vision/models/mobilenetv2.py index 60914b48f008f..9122e1cfdb381 100644 --- a/python/paddle/vision/models/mobilenetv2.py +++ b/python/paddle/vision/models/mobilenetv2.py @@ -12,6 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + +from typing import ( + TYPE_CHECKING, + TypedDict, +) + +from typing_extensions import NotRequired, Unpack + import paddle from paddle import nn from paddle.utils.download import get_weights_path_from_url @@ -19,6 +28,9 @@ from ..ops import ConvNormActivation from ._utils import _make_divisible +if TYPE_CHECKING: + from paddle import Tensor + __all__ = [] model_urls = { @@ -29,10 +41,21 @@ } +class _MobileNetV2Options(TypedDict): + scale: NotRequired[float] + num_classes: NotRequired[int] + with_pool: NotRequired[bool] + + class InvertedResidual(nn.Layer): def __init__( - self, inp, oup, stride, expand_ratio, norm_layer=nn.BatchNorm2D - ): + self, + inp: int, + oup: int, + stride: int, + expand_ratio: float, + norm_layer: nn.Layer = nn.BatchNorm2D, + ) -> None: super().__init__() self.stride = stride assert stride in [1, 2] @@ -67,7 +90,7 @@ def __init__( ) self.conv = nn.Sequential(*layers) - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: if self.use_res_connect: return x + self.conv(x) else: @@ -102,7 +125,12 @@ class MobileNetV2(nn.Layer): [1, 1000] """ - def __init__(self, scale=1.0, num_classes=1000, with_pool=True): + def __init__( + self, + scale: float = 1.0, + num_classes: int = 1000, + with_pool: bool = True, + ) -> None: super().__init__() self.num_classes = num_classes self.with_pool = with_pool @@ -171,7 +199,7 @@ def __init__(self, scale=1.0, num_classes=1000, with_pool=True): nn.Dropout(0.2), nn.Linear(self.last_channel, num_classes) ) - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: x = self.features(x) if self.with_pool: @@ -183,7 +211,9 @@ def forward(self, x): return x -def _mobilenet(arch, pretrained=False, **kwargs): +def _mobilenet( + arch: str, pretrained: bool = False, **kwargs: Unpack[_MobileNetV2Options] +) -> MobileNetV2: model = MobileNetV2(**kwargs) if pretrained: assert ( @@ -199,7 +229,11 @@ def _mobilenet(arch, pretrained=False, **kwargs): return model -def mobilenet_v2(pretrained=False, scale=1.0, **kwargs): +def mobilenet_v2( + pretrained: bool = False, + scale: float = 1.0, + **kwargs: Unpack[_MobileNetV2Options], +) -> MobileNetV2: """MobileNetV2 from `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" `_.