[Typing][A-79] Add type annotations for `python/paddle/vision/models/mobilenetv2.py` (PaddlePaddle#65326)
DrRyanHuang authored and co63oc committed Jun 25, 2024
1 parent a850c43 commit e80a020
Showing 1 changed file with 41 additions and 7 deletions.
48 changes: 41 additions & 7 deletions python/paddle/vision/models/mobilenetv2.py
@@ -12,13 +12,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import annotations
+
+from typing import (
+    TYPE_CHECKING,
+    TypedDict,
+)
+
+from typing_extensions import NotRequired, Unpack
+
 import paddle
 from paddle import nn
 from paddle.utils.download import get_weights_path_from_url

 from ..ops import ConvNormActivation
 from ._utils import _make_divisible

+if TYPE_CHECKING:
+    from paddle import Tensor
+
 __all__ = []

 model_urls = {
@@ -29,10 +41,21 @@
 }


+class _MobileNetV2Options(TypedDict):
+    scale: NotRequired[float]
+    num_classes: NotRequired[int]
+    with_pool: NotRequired[bool]
+
+
 class InvertedResidual(nn.Layer):
     def __init__(
-        self, inp, oup, stride, expand_ratio, norm_layer=nn.BatchNorm2D
-    ):
+        self,
+        inp: int,
+        oup: int,
+        stride: int,
+        expand_ratio: float,
+        norm_layer: nn.Layer = nn.BatchNorm2D,
+    ) -> None:
         super().__init__()
         self.stride = stride
         assert stride in [1, 2]
@@ -67,7 +90,7 @@ def __init__(
         )
         self.conv = nn.Sequential(*layers)

-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         if self.use_res_connect:
             return x + self.conv(x)
         else:
@@ -102,7 +125,12 @@ class MobileNetV2(nn.Layer):
             [1, 1000]
     """

-    def __init__(self, scale=1.0, num_classes=1000, with_pool=True):
+    def __init__(
+        self,
+        scale: float = 1.0,
+        num_classes: int = 1000,
+        with_pool: bool = True,
+    ) -> None:
         super().__init__()
         self.num_classes = num_classes
         self.with_pool = with_pool
@@ -171,7 +199,7 @@ def __init__(self, scale=1.0, num_classes=1000, with_pool=True):
                 nn.Dropout(0.2), nn.Linear(self.last_channel, num_classes)
             )

-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         x = self.features(x)

         if self.with_pool:
@@ -183,7 +211,9 @@ def forward(self, x):
         return x


-def _mobilenet(arch, pretrained=False, **kwargs):
+def _mobilenet(
+    arch: str, pretrained: bool = False, **kwargs: Unpack[_MobileNetV2Options]
+) -> MobileNetV2:
     model = MobileNetV2(**kwargs)
     if pretrained:
         assert (
@@ -199,7 +229,11 @@ def _mobilenet(arch, pretrained=False, **kwargs):
     return model


-def mobilenet_v2(pretrained=False, scale=1.0, **kwargs):
+def mobilenet_v2(
+    pretrained: bool = False,
+    scale: float = 1.0,
+    **kwargs: Unpack[_MobileNetV2Options],
+) -> MobileNetV2:
     """MobileNetV2 from
     `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.

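A short usage sketch of the annotated public entry point, assuming a working PaddlePaddle installation; the expected output shape matches the docstring example already present in the file:

import paddle
from paddle.vision.models import mobilenet_v2

# Keyword arguments are now checked against _MobileNetV2Options by static type checkers.
model = mobilenet_v2(pretrained=False, scale=1.0, num_classes=1000)
x = paddle.rand([1, 3, 224, 224])
out = model(x)
print(out.shape)  # [1, 1000]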