-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
27 lines (22 loc) · 831 Bytes
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from mxnet import ndarray as nd
from mxnet.gluon import HybridBlock, nn
from mxnet.gluon.model_zoo import vision as models
class Net(HybridBlock):
    """Classifier built from a model-zoo ResNet-50 v2 trunk and a dense head.

    Only the ``features`` sub-network of ``models.resnet50_v2`` is kept. It is
    followed by a freshly created ``nn.Dense`` layer with ``classes`` units,
    and the result is passed through ``F.softmax``.

    Parameters
    ----------
    classes : int, default 101
        Number of units in the dense output layer.
    pretrained : bool, default True
        Forwarded to ``models.resnet50_v2``.
    features_lr_mult : float, default 0.1
        Value assigned to the ``lr_mult`` attribute of every backbone
        parameter, so the backbone is updated with a scaled-down learning
        rate relative to the head.
    **kwargs
        Forwarded unchanged to ``HybridBlock.__init__``.
    """

    def __init__(self, *, classes=101, pretrained=True,
                 features_lr_mult=0.1, **kwargs):
        super().__init__(**kwargs)
        with self.name_scope():
            # NOTE: another zoo backbone (e.g. densenet121) can be swapped in,
            # but its `.features` sub-network must be taken, not the whole model.
            self.features = models.resnet50_v2(pretrained=pretrained).features
            # Scale the learning rate of every backbone parameter.
            self.features.collect_params().setattr('lr_mult', features_lr_mult)
            self.output = nn.Dense(units=classes)

    def hybrid_forward(self, F, x, *args, **kwargs):
        """Run ``x`` through the backbone and head and return softmax scores.

        ``F`` is the ndarray/symbol namespace supplied by Gluon. Extra
        positional and keyword arguments are accepted but ignored.
        """
        x = self.features(x)
        x = self.output(x)
        # NOTE(review): softmax is applied here. If training uses
        # gluon.loss.SoftmaxCrossEntropyLoss with its default
        # from_logits=False, softmax is effectively applied twice -- confirm
        # against the training code.
        return F.softmax(x)
if __name__ == '__main__':
    # Smoke test: push a random batch of shape (4, 3, 224, 224) through Net.
    data = nd.random.uniform(shape=(4, 3, 224, 224))
    net = Net()
    # Net() already loads the pretrained backbone weights, so only the new
    # dense head needs initializing. Calling initialize() on every parameter
    # would emit an "already initialized" warning for each backbone weight.
    net.output.initialize()
    output = net(data)
    print(output)