Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions timm/models/mobilenetv5.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,13 @@
)
from ._features import feature_take_indices
from ._features_fx import register_notrace_module
from ._manipulate import checkpoint_seq, checkpoint
from ._manipulate import checkpoint_seq
from ._registry import generate_default_cfgs, register_model

__all__ = ['MobileNetV5', 'MobileNetV5Encoder']

_GELU = partial(nn.GELU, approximate='tanh')


@register_notrace_module
class MobileNetV5MultiScaleFusionAdapter(nn.Module):
Expand Down Expand Up @@ -68,7 +70,7 @@ def __init__(
self.layer_scale_init_value = layer_scale_init_value
self.noskip = noskip

act_layer = act_layer or nn.GELU
act_layer = act_layer or _GELU
norm_layer = norm_layer or RmsNorm2d
self.ffn = UniversalInvertedResidual(
in_chs=self.in_channels,
Expand Down Expand Up @@ -167,7 +169,7 @@ def __init__(
global_pool: Type of pooling to use for global pooling features of the FC head.
"""
super().__init__()
act_layer = act_layer or nn.GELU
act_layer = act_layer or _GELU
norm_layer = get_norm_layer(norm_layer) or RmsNorm2d
norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
se_layer = se_layer or SqueezeExcite
Expand Down Expand Up @@ -410,7 +412,7 @@ def __init__(
block_args: BlockArgs,
in_chans: int = 3,
stem_size: int = 64,
stem_bias: bool = False,
stem_bias: bool = True,
fix_stem: bool = False,
pad_type: str = '',
msfa_indices: Sequence[int] = (-2, -1),
Expand All @@ -426,7 +428,7 @@ def __init__(
layer_scale_init_value: Optional[float] = None,
):
super().__init__()
act_layer = act_layer or nn.GELU
act_layer = act_layer or _GELU
norm_layer = get_norm_layer(norm_layer) or RmsNorm2d
se_layer = se_layer or SqueezeExcite
self.num_classes = 0 # Exists to satisfy ._hub module APIs.
Expand Down Expand Up @@ -526,6 +528,7 @@ def forward_intermediates(
feat_idx = 0 # stem is index 0
x = self.conv_stem(x)
if feat_idx in take_indices:
print("conv_stem is captured")
intermediates.append(x)
if feat_idx in self.msfa_indices:
msfa_intermediates.append(x)
Expand Down Expand Up @@ -777,7 +780,7 @@ def _gen_mobilenet_v5(
fix_stem=channel_multiplier < 1.0,
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
norm_layer=RmsNorm2d,
act_layer=nn.GELU,
act_layer=_GELU,
layer_scale_init_value=1e-5,
)
model_kwargs = dict(model_kwargs, **kwargs)
Expand Down