From 342be64543c41ea0db766019c2882b5b78e94ea4 Mon Sep 17 00:00:00 2001
From: frgfm
Date: Wed, 10 Feb 2021 23:52:32 +0100
Subject: [PATCH] refactor: Replaced Flatten by torch version

---
 pyrovision/models/utils.py | 12 +-----------
 1 file changed, 1 insertion(+), 11 deletions(-)

diff --git a/pyrovision/models/utils.py b/pyrovision/models/utils.py
index b2a94b5b..90b0e271 100644
--- a/pyrovision/models/utils.py
+++ b/pyrovision/models/utils.py
@@ -30,16 +30,6 @@ def init_module(m, init=nn.init.kaiming_normal_):
         m.bias.data.fill_(0.)
 
 
-class Flatten(nn.Module):
-    """Implements a flattening layer"""
-    def __init__(self):
-        super(Flatten, self).__init__()
-
-    @staticmethod
-    def forward(x):
-        return x.view(x.size(0), -1)
-
-
 def head_stack(in_features, out_features, bn=True, p=0., actn=None):
     """Stacks batch norm, dropout and fully connected layers together
 
@@ -98,7 +88,7 @@ def create_head(in_features, num_classes, lin_features=512, dropout_prob=0.5,
         activations = [nn.ReLU(inplace=True)] * (len(lin_features) - 2) + [None]
 
     # Flatten pooled feature maps
-    layers = [pool, Flatten()]
+    layers = [pool, nn.Flatten()]
     for in_feats, out_feats, prob, activation in zip(lin_features[:-1], lin_features[1:], dropout_prob, activations):
        layers.extend(head_stack(in_feats, out_feats, True, prob, activation))
     # Final batch norm
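
A minimal sketch (not part of the patch; the input shape is an assumed example) checking that torch's built-in nn.Flatten reproduces the removed custom Flatten: with its defaults (start_dim=1, end_dim=-1), it maps a (N, C, H, W) tensor to (N, C*H*W), same as x.view(x.size(0), -1).

    import torch
    from torch import nn

    x = torch.randn(4, 512, 7, 7)      # assumed pooled feature maps, shape (N, C, H, W)
    old = x.view(x.size(0), -1)        # behavior of the removed custom Flatten
    new = nn.Flatten()(x)              # built-in, defaults start_dim=1, end_dim=-1
    assert torch.equal(old, new)       # both yield shape (4, 25088)

One practical difference worth noting: nn.Flatten is backed by torch.flatten, so it also accepts non-contiguous inputs (copying only when needed), whereas .view raises on incompatible strides.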