Skip to content

Commit 918e230

Browse files
author
sennnnn
committed
[Refactor] Using mmcv bricks to refactor vit
1 parent 66b0525 commit 918e230

File tree

5 files changed

+345
-334
lines changed

5 files changed

+345
-334
lines changed
+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import logging
2+
from abc import ABCMeta, abstractmethod
3+
4+
from mmcv.runner import load_checkpoint
5+
from mmcv.runner.base_module import BaseModule
6+
7+
8+
class BaseBackbone(BaseModule, metaclass=ABCMeta):
9+
"""Base backbone.
10+
11+
This class defines the basic functions of a backbone. Any backbone that
12+
inherits this class should at least define its own `forward` function.
13+
"""
14+
15+
def __init__(self, init_cfg=None):
16+
super(BaseBackbone, self).__init__(init_cfg)
17+
18+
def init_weights(self, pretrained=None):
19+
"""Init backbone weights.
20+
21+
Args:
22+
pretrained (str | None): If pretrained is a string, then it
23+
initializes backbone weights by loading the pretrained
24+
checkpoint. If pretrained is None, then it follows default
25+
initializer or customized initializer in subclasses.
26+
"""
27+
if isinstance(pretrained, str):
28+
logger = logging.getLogger()
29+
load_checkpoint(self, pretrained, strict=False, logger=logger)
30+
elif pretrained is None:
31+
# use default initializer or customized initializer in subclasses
32+
super(BaseBackbone, self).init_weights()
33+
else:
34+
raise TypeError('pretrained must be a str or None.'
35+
f' But received {type(pretrained)}.')
36+
37+
@abstractmethod
38+
def forward(self, x):
39+
"""Forward computation.
40+
41+
Args:
42+
x (tensor | tuple[tensor]): x could be a Torch.tensor or a tuple of
43+
Torch.tensor, containing input data for forward computation.
44+
"""
45+
pass
46+
47+
def train(self, mode=True):
48+
"""Set module status before forward computation.
49+
50+
Args:
51+
mode (bool): Whether it is train_mode or test_mode
52+
"""
53+
super(BaseBackbone, self).train(mode)

0 commit comments

Comments
 (0)