
Commit 0c31a33

Extending ConvBNReLU for reuse.
1 parent: 2ebe8ba

File tree

1 file changed: +10 -3 lines


torchvision/models/mobilenetv2.py

Lines changed: 10 additions & 3 deletions
@@ -32,26 +32,33 @@ def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) ->
     return new_v


-class ConvBNReLU(nn.Sequential):
+class ConvBNActivation(nn.Sequential):
     def __init__(
         self,
         in_planes: int,
         out_planes: int,
         kernel_size: int = 3,
         stride: int = 1,
         groups: int = 1,
-        norm_layer: Optional[Callable[..., nn.Module]] = None
+        norm_layer: Optional[Callable[..., nn.Module]] = None,
+        activation_layer: Optional[Callable[..., nn.Module]] = None,
     ) -> None:
         padding = (kernel_size - 1) // 2
         if norm_layer is None:
             norm_layer = nn.BatchNorm2d
+        if activation_layer is None:
+            activation_layer = nn.ReLU6
         super(ConvBNReLU, self).__init__(
             nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
             norm_layer(out_planes),
-            nn.ReLU6(inplace=True)
+            activation_layer(inplace=True)
         )


+# necessary for backwards compatibility
+ConvBNReLU = ConvBNActivation
+
+
 class InvertedResidual(nn.Module):
     def __init__(
         self,
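In short, the ReLU6 that was hard-coded in ConvBNReLU becomes an injectable activation_layer argument (defaulting to nn.ReLU6), and the old name is kept as a module-level alias so existing imports keep working. The snippet below is a minimal usage sketch, not part of the commit; it assumes a torchvision build containing this change and passes nn.ReLU purely as an illustration of a non-default activation (any module whose constructor accepts inplace= works, since the block instantiates activation_layer(inplace=True)).

import torch
from torch import nn
from torchvision.models.mobilenetv2 import ConvBNActivation, ConvBNReLU

# Default behaviour is unchanged: Conv2d -> BatchNorm2d -> ReLU6.
block_relu6 = ConvBNActivation(3, 16, kernel_size=3, stride=2)

# Reuse the same block with a different activation.
block_relu = ConvBNActivation(16, 32, kernel_size=3, activation_layer=nn.ReLU)

# Backwards compatibility: the old name is just an alias for the new class.
assert ConvBNReLU is ConvBNActivation

x = torch.randn(1, 3, 224, 224)
out = block_relu(block_relu6(x))
print(out.shape)  # torch.Size([1, 32, 112, 112])

Keeping ConvBNReLU as an alias rather than a separate subclass means code that still imports or references the old name resolves to the same class with no extra indirection.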
