-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathresnet_models.py
132 lines (109 loc) · 4.26 KB
/
resnet_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
# Copyright 2022 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Flax implementation of ResNet V1."""
# See issue #620.
# pytype: disable=wrong-arg-count
from functools import partial
from typing import Any, Callable, Sequence, Tuple
from flax import linen as nn
import jax.numpy as jnp
# Annotation alias for flax module classes / factories (e.g. ``nn.Conv`` or a
# ``partial`` of it, and the block classes below). It is just ``Any``, so it
# documents intent only and is not checked by type checkers.
ModuleDef = Any
class ResNetBlock(nn.Module):
    """Basic ResNet block: two dilated 3x3 convolutions, each followed by GroupNorm.

    Residual branch: conv3x3 (with ``strides``) -> GroupNorm -> act ->
    conv3x3 -> GroupNorm (zero-initialized scale). Both 3x3 convolutions use
    ``kernel_dilation=(2, 2)``. The shortcut is ``x`` unchanged, or a 1x1
    ``conv_proj`` convolution followed by GroupNorm when ``x`` and the
    residual branch differ in shape. The output is ``act(shortcut + branch)``.

    Attributes:
      filters: Output channels of every convolution in the block.
      conv: Convolution module factory (e.g. ``nn.Conv`` or a ``partial`` of
        it), called as ``conv(features, kernel_size, strides, ...)``.
      norm: Not read by this block; normalization is always ``nn.GroupNorm``.
        Optional so that callers which omit it (such as ``ResNet``, which
        passes only ``strides``, ``conv`` and ``act``) can construct the
        block. It is still accepted, for backward compatibility.
      act: Activation applied after the first GroupNorm and after the sum.
        Defaults to ``nn.relu``, the same default ``ResNet`` uses.
      strides: Strides of the first 3x3 convolution and of the projection.
      dtype: dtype passed to every ``nn.GroupNorm`` in the block.
    """

    filters: int
    conv: ModuleDef
    # This was a required field that nothing read, which made
    # ``ResNet(block_cls=ResNetBlock)`` fail with a missing-argument error.
    norm: ModuleDef = None
    # A default is required here: a dataclass field without a default may not
    # follow one that has a default. Positional order is unchanged.
    act: Callable = nn.relu
    strides: Tuple[int, int] = (1, 1)
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, x):
        """Apply the block to ``x``.

        Args:
          x: Input feature map. Presumably NHWC (flax's conv default) -- TODO
            confirm against callers.

        Returns:
          ``act(shortcut + residual_branch)``.
        """
        norm_kwargs = {"num_groups": 32, "dtype": self.dtype}
        residual = x
        y = self.conv(self.filters, (3, 3), self.strides,
                      kernel_dilation=(2, 2))(x)
        y = nn.GroupNorm(**norm_kwargs)(y)
        y = self.act(y)
        y = self.conv(self.filters, (3, 3), kernel_dilation=(2, 2))(y)
        # Zero scale init: assuming flax's default zero bias init, the branch
        # starts at zero, so the block initially yields act(shortcut).
        y = nn.GroupNorm(scale_init=nn.initializers.zeros_init(),
                         **norm_kwargs)(y)
        if residual.shape != y.shape:
            # Project the shortcut to the branch's shape. Dilation has no
            # effect on a 1x1 kernel; it is kept so the module is unchanged.
            residual = self.conv(self.filters, (1, 1), self.strides,
                                 kernel_dilation=(2, 2),
                                 name='conv_proj')(residual)
            residual = nn.GroupNorm(**norm_kwargs)(residual)
        return self.act(residual + y)
class BottleneckResNetBlock(nn.Module):
    """Bottleneck ResNet block: 1x1 reduce, dilated 3x3, 1x1 expand (x4).

    Every convolution is followed by a 32-group GroupNorm. The final
    GroupNorm has a zero-initialized scale. The 3x3 convolution uses
    ``kernel_dilation=(2, 2)`` and carries the block's ``strides``. When the
    input and branch shapes differ, the shortcut goes through a 1x1
    ``conv_proj`` convolution (with ``strides``) plus GroupNorm.

    Attributes:
      filters: Width of the 1x1-reduce and 3x3 convolutions. The block
        outputs ``4 * filters`` channels.
      conv: Convolution module factory (e.g. ``nn.Conv`` or a ``partial`` of
        it).
      act: Activation applied after the first two GroupNorms and after the
        residual sum.
      strides: Strides of the 3x3 convolution and of the shortcut projection.
      dtype: dtype passed to every ``nn.GroupNorm`` in the block.
    """

    filters: int
    conv: ModuleDef
    act: Callable
    strides: Tuple[int, int] = (1, 1)
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, x):
        """Run the bottleneck block and return ``act(shortcut + branch)``.

        Args:
          x: Input feature map. Presumably NHWC (flax's conv default) -- TODO
            confirm against callers.
        """
        group_norm = partial(nn.GroupNorm, num_groups=32, dtype=self.dtype)
        expanded = self.filters * 4

        # 1x1 channel reduction.
        out = self.conv(self.filters, (1, 1))(x)
        out = group_norm()(out)
        out = self.act(out)

        # Dilated 3x3 convolution; this is where the block strides.
        out = self.conv(self.filters, (3, 3), self.strides,
                        kernel_dilation=(2, 2))(out)
        out = group_norm()(out)
        out = self.act(out)

        # 1x1 expansion to 4x width, with a zero-initialized GroupNorm scale.
        out = self.conv(expanded, (1, 1))(out)
        out = group_norm(scale_init=nn.initializers.zeros_init())(out)

        shortcut = x
        if shortcut.shape != out.shape:
            shortcut = self.conv(expanded, (1, 1), self.strides,
                                 name='conv_proj')(shortcut)
            shortcut = group_norm()(shortcut)
        return self.act(shortcut + out)
class ResNet(nn.Module):
    """ResNet V1 backbone with GroupNorm; returns a feature map (no head).

    Stem: 7x7 stride-2 convolution (``conv_init``), GroupNorm, ReLU, then a
    3x3 stride-2 max-pool. Next comes one stage per entry of
    ``stage_sizes``. Stage ``i`` stacks that many ``block_cls`` blocks with
    ``num_filters * 2**i`` filters. The first block of every stage after the
    first uses stride 2. The final feature map is cast to ``dtype`` and
    returned; there is no pooling or dense layer.

    Attributes:
      stage_sizes: Number of blocks in each stage.
      block_cls: Block module class (e.g. ``ResNetBlock`` or
        ``BottleneckResNetBlock``). It is constructed with positional
        ``filters`` and keyword ``strides``, ``conv``, ``act`` and ``dtype``.
      num_filters: Filters of the stem and of the first stage.
      dtype: dtype for the convolutions and for every GroupNorm, both in the
        stem and inside the blocks.
      act: Activation passed to the blocks. The stem always uses ``nn.relu``.
      conv: Convolution module class. It is wrapped with ``use_bias=False``
        and ``dtype`` before use.
    """

    stage_sizes: Sequence[int]
    block_cls: ModuleDef
    num_filters: int = 64
    dtype: Any = jnp.float32
    act: Callable = nn.relu
    conv: ModuleDef = nn.Conv

    @nn.compact
    def __call__(self, x, train: bool = True):
        """Compute the backbone feature map for ``x``.

        Args:
          x: Input batch. Presumably NHWC images (flax's conv default) -- TODO
            confirm against callers.
          train: Not read by this method. It is kept so existing call sites
            that pass it keep working.

        Returns:
          The last stage's feature map, cast to ``self.dtype``.
        """
        conv = partial(self.conv, use_bias=False, dtype=self.dtype)
        norm_kwargs = {"num_groups": 32, "dtype": self.dtype}
        x = conv(self.num_filters, (7, 7), (2, 2),
                 padding=[(3, 3), (3, 3)],
                 name='conv_init')(x)
        x = nn.GroupNorm(**norm_kwargs)(x)
        x = nn.relu(x)
        x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
        for i, block_size in enumerate(self.stage_sizes):
            for j in range(block_size):
                # Downsample at the entry of every stage except the first.
                strides = (2, 2) if i > 0 and j == 0 else (1, 1)
                # Forward ``dtype`` so the blocks' GroupNorms use the same
                # dtype as the rest of the network. Previously they silently
                # stayed at the block default (float32).
                x = self.block_cls(self.num_filters * 2 ** i,
                                   strides=strides,
                                   conv=conv,
                                   act=self.act,
                                   dtype=self.dtype)(x)
        x = jnp.asarray(x, self.dtype)
        return x
# Standard depth presets for ``ResNet``. Depths 18 and 34 use the basic
# two-convolution ``ResNetBlock``; depth 50 and deeper use
# ``BottleneckResNetBlock``. Each preset fixes only ``block_cls`` and
# ``stage_sizes`` (plus ``conv`` for the locally connected variant). Other
# ``ResNet`` attributes can be supplied when the preset is called.

# Basic-block variants.
ResNet18 = partial(ResNet, block_cls=ResNetBlock, stage_sizes=[2, 2, 2, 2])
ResNet34 = partial(ResNet, block_cls=ResNetBlock, stage_sizes=[3, 4, 6, 3])

# Bottleneck variants.
ResNet50 = partial(ResNet, block_cls=BottleneckResNetBlock,
                   stage_sizes=[3, 4, 6, 3])
ResNet101 = partial(ResNet, block_cls=BottleneckResNetBlock,
                    stage_sizes=[3, 4, 23, 3])
ResNet152 = partial(ResNet, block_cls=BottleneckResNetBlock,
                    stage_sizes=[3, 8, 36, 3])
ResNet200 = partial(ResNet, block_cls=BottleneckResNetBlock,
                    stage_sizes=[3, 24, 36, 3])

# ResNet18 layout built from locally connected (unshared-weight) convolutions.
ResNet18Local = partial(ResNet, block_cls=ResNetBlock,
                        stage_sizes=[2, 2, 2, 2], conv=nn.ConvLocal)