refactor: Replaced Flatten by torch version (#102)
frgfm authored Feb 11, 2021
1 parent b5cbeb3 commit bad74ad
Showing 1 changed file with 1 addition and 11 deletions.
12 changes: 1 addition & 11 deletions pyrovision/models/utils.py
@@ -30,16 +30,6 @@ def init_module(m, init=nn.init.kaiming_normal_):
         m.bias.data.fill_(0.)
 
 
-class Flatten(nn.Module):
-    """Implements a flattening layer"""
-    def __init__(self):
-        super(Flatten, self).__init__()
-
-    @staticmethod
-    def forward(x):
-        return x.view(x.size(0), -1)
-
-
 def head_stack(in_features, out_features, bn=True, p=0., actn=None):
     """Stacks batch norm, dropout and fully connected layers together

@@ -98,7 +88,7 @@ def create_head(in_features, num_classes, lin_features=512, dropout_prob=0.5,
         activations = [nn.ReLU(inplace=True)] * (len(lin_features) - 2) + [None]
 
     # Flatten pooled feature maps
-    layers = [pool, Flatten()]
+    layers = [pool, nn.Flatten()]
     for in_feats, out_feats, prob, activation in zip(lin_features[:-1], lin_features[1:], dropout_prob, activations):
         layers.extend(head_stack(in_feats, out_feats, True, prob, activation))
     # Final batch norm
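For context on the change: torch's built-in nn.Flatten with its defaults (start_dim=1, end_dim=-1) reproduces the behaviour of the removed custom layer, which reshaped its input to (batch_size, -1). A minimal sketch to check that equivalence; the input shape is an arbitrary example, not taken from the repository:

import torch
from torch import nn

# Pooled feature maps, e.g. (batch, channels, 1, 1) after adaptive pooling
x = torch.randn(4, 512, 1, 1)

# Behaviour of the removed custom Flatten: collapse everything after the batch dim
manual = x.view(x.size(0), -1)

# Built-in replacement used by this commit (defaults: start_dim=1, end_dim=-1)
built_in = nn.Flatten()(x)

assert manual.shape == built_in.shape == (4, 512)
assert torch.equal(manual, built_in)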
