Skip to content

Commit

Permalink
add self.with_neck (open-mmlab#796)
Browse files Browse the repository at this point in the history
  • Loading branch information
dreamerlin committed Apr 10, 2021
1 parent af4cc63 commit 905f07a
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 9 deletions.
7 changes: 6 additions & 1 deletion mmaction/models/recognizers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ def __init__(self,

self.fp16_enabled = False

@property
def with_neck(self):
"""bool: whether the detector has a neck"""
return hasattr(self, 'neck') and self.neck is not None

def init_weights(self):
"""Initialize the model network weights."""
if self.backbone_from in ['mmcls', 'mmaction2']:
Expand All @@ -105,7 +110,7 @@ def init_weights(self):
f'{self.backbone_from}!')

self.cls_head.init_weights()
if hasattr(self, 'neck'):
if self.with_neck:
self.neck.init_weights()

@auto_fp16()
Expand Down
8 changes: 4 additions & 4 deletions mmaction/models/recognizers/recognizer2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def forward_train(self, imgs, labels, **kwargs):
x = x.reshape((x.shape[0], -1))
x = x.reshape(x.shape + (1, 1))

if hasattr(self, 'neck'):
if self.with_neck:
x = [
each.reshape((-1, num_segs) +
each.shape[1:]).transpose(1, 2).contiguous()
Expand Down Expand Up @@ -60,7 +60,7 @@ def _do_test(self, imgs):
x = x.reshape((x.shape[0], -1))
x = x.reshape(x.shape + (1, 1))

if hasattr(self, 'neck'):
if self.with_neck:
x = [
each.reshape((-1, num_segs) +
each.shape[1:]).transpose(1, 2).contiguous()
Expand Down Expand Up @@ -97,7 +97,7 @@ def _do_fcn_test(self, imgs):
imgs = torch.flip(imgs, [-1])
x = self.extract_feat(imgs)

if hasattr(self, 'neck'):
if self.with_neck:
x = [
each.reshape((-1, num_segs) +
each.shape[1:]).transpose(1, 2).contiguous()
Expand Down Expand Up @@ -147,7 +147,7 @@ def forward_dummy(self, imgs, softmax=False):
num_segs = imgs.shape[0] // batches

x = self.extract_feat(imgs)
if hasattr(self, 'neck'):
if self.with_neck:
x = [
each.reshape((-1, num_segs) +
each.shape[1:]).transpose(1, 2).contiguous()
Expand Down
8 changes: 4 additions & 4 deletions mmaction/models/recognizers/recognizer3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def forward_train(self, imgs, labels, **kwargs):
losses = dict()

x = self.extract_feat(imgs)
if hasattr(self, 'neck'):
if self.with_neck:
x, loss_aux = self.neck(x, labels.squeeze())
losses.update(loss_aux)

Expand All @@ -42,15 +42,15 @@ def _do_test(self, imgs):
while view_ptr < total_views:
batch_imgs = imgs[view_ptr:view_ptr + self.max_testing_views]
x = self.extract_feat(batch_imgs)
if hasattr(self, 'neck'):
if self.with_neck:
x, _ = self.neck(x)
cls_score = self.cls_head(x)
cls_scores.append(cls_score)
view_ptr += self.max_testing_views
cls_score = torch.cat(cls_scores)
else:
x = self.extract_feat(imgs)
if hasattr(self, 'neck'):
if self.with_neck:
x, _ = self.neck(x)
cls_score = self.cls_head(x)

Expand All @@ -76,7 +76,7 @@ def forward_dummy(self, imgs, softmax=False):
imgs = imgs.reshape((-1, ) + imgs.shape[2:])
x = self.extract_feat(imgs)

if hasattr(self, 'neck'):
if self.with_neck:
x, _ = self.neck(x)

outs = self.cls_head(x)
Expand Down

0 comments on commit 905f07a

Please sign in to comment.