clean notes
aptsunny committed Oct 22, 2022
1 parent 484a7d2 commit a03d18f
Showing 1 changed file with 5 additions and 78 deletions.
83 changes: 5 additions & 78 deletions mmrazor/models/architectures/backbones/autoformer_backbone.py
@@ -96,58 +96,25 @@ def mutate_encoder_layer(self, mutable_num_heads: BaseMutable,
self.mutable_mlp_ratios = mutable_mlp_ratios

# handle the mutable of the first dynamic LN
# self.norm1.register_mutable_attr('num_features', mutable_embed_dims)
MutableChannelContainer.register_mutable_channel_to_module(
self.norm1, self.mutable_embed_dims, True)

# handle the mutable in multihead attention
# mutable_value = SampleExpandDerivedMutable(64)
# mutable_q_embed_dims = mutable_num_heads.derive_expand_mutable(64)
mutable_q_embed_dims = 64 * mutable_num_heads  # derived mutable: per-head dim (64) x num_heads

# one module may have two mutable channels registered
# if DynamicMHAMixin had attr_mappings, would in_label still be needed?
MutableChannelContainer.register_mutable_channel_to_module(
self.attn, self.mutable_embed_dims, False)
MutableChannelContainer.register_mutable_channel_to_module(
self.attn, mutable_q_embed_dims, True, end=640)
# self.attn, mutable_q_embed_dims, True, in_label='embed_dims')

# MutableChannelContainer.register_mutable_channel_to_module(
# self.attn, mutable_q_embed_dims, True)

# MutableChannelContainer.register_mutable_channel_to_module(
# self.attn.rel_pos_embed_k, self.mutable_head_dims, True)
# MutableChannelContainer.register_mutable_channel_to_module(
# self.attn.rel_pos_embed_v, self.mutable_head_dims, True)

# self.attn.register_mutable_attr('embed_dims', mutable_embed_dims)
self.attn.register_mutable_attr('num_heads', mutable_num_heads)
# self.attn.register_mutable_attr('q_embed_dims', mutable_q_embed_dims)
# self.attn.rel_pos_embed_k.register_mutable_attr(
# 'head_dims', self.mutable_head_dims)
# self.attn.rel_pos_embed_v.register_mutable_attr(
# 'head_dims', self.mutable_head_dims)

# MutableChannelContainer.register_mutable_channel_to_module(
# self.attn, self.mutable_embed_dims, True)

# handle the mutable of the second dynamic LN
# self.norm2.register_mutable_attr('num_features', mutable_embed_dims)
MutableChannelContainer.register_mutable_channel_to_module(
self.norm2, self.mutable_embed_dims, True)


# handle the mutable of FFN
# mutable channel x mutable value
# this case combines the two kinds of mutables
# self.middle_channels = mutable_embed_dims.derive_expand_mutable(
# mutable_mlp_ratios)
self.middle_channels = mutable_mlp_ratios * mutable_embed_dims  # derived mutable: mlp_ratio x embed_dims
# self.middle_channels = mutable_embed_dims

# bugfix: support derive_range
MutableChannelContainer.register_mutable_channel_to_module(
self.fc1, mutable_embed_dims, False)
MutableChannelContainer.register_mutable_channel_to_module(
@@ -159,11 +126,6 @@ def mutate_encoder_layer(self, mutable_num_heads: BaseMutable,
MutableChannelContainer.register_mutable_channel_to_module(
self.fc2, mutable_embed_dims, True)

# self.fc1.register_mutable_attr('in_channels', mutable_embed_dims)
# self.fc1.register_mutable_attr('out_channels', self.middle_channels)
# self.fc2.register_mutable_attr('in_channels', self.middle_channels)
# self.fc2.register_mutable_attr('out_channels', mutable_embed_dims)

def forward(self, x: Tensor) -> Tensor:
"""Forward of Transformer Encode Layer."""
residual = x
@@ -302,11 +264,9 @@ def __init__(self,
norm_cfg, self.mutable_embed_dims.num_channels)
self.add_module(self.norm1_name, norm1)

# register the channel containers used by MutableChannelUnit
MutableChannelUnit._register_channel_container(
self, MutableChannelContainer)

# handle value-related mutables; skip channel-related ones for now
self.register_mutate()

@property
Expand All @@ -327,17 +287,14 @@ def make_layers(self, embed_dims, depth):
qkv_bias=self.qkv_bias,
act_cfg=self.act_cfg)
layers.append(layer)
return DynamicSequential(*layers)  # bug: raises an error if the search space is not added
# return nn.Sequential(*layers)

def register_mutate(self):
"""Mutate the autoformer."""
# handle the mutation of depth
self.blocks.register_mutable_attr('depth', self.mutable_depth)

# handle the mutation of patch embed
# self.patch_embed.register_mutable_attr(
# 'embed_dims', self.mutable_embed_dims.derive_same_mutable())
MutableChannelContainer.register_mutable_channel_to_module(
self.patch_embed, self.mutable_embed_dims, True)

@@ -349,18 +306,11 @@ def register_mutate(self):
mutable_num_heads=self.mutable_num_heads[i],
mutable_mlp_ratios=self.mutable_mlp_ratios[i],
mutable_embed_dims=self.last_mutable)
# mutable_embed_dims=self.last_mutable.derive_same_mutable())

# handle the mutable of final norm
if self.final_norm:
# self.norm1.register_mutable_attr(
# 'num_features', self.last_mutable.derive_same_mutable())
MutableChannelContainer.register_mutable_channel_to_module(
self.norm1, self.last_mutable, True)
# MutableChannelContainer.register_mutable_channel_to_module(
# self.norm1, self.last_mutable.derive_same_mutable(), True)

def forward(self, x: Tensor):
"""Forward of Autoformer."""
@@ -390,29 +340,6 @@ def forward(self, x: Tensor):


if __name__ == '__main__':
# supernet = dict(
# _scope_='mmrazor',
# type='SearchableImageClassifier',
# data_preprocessor=data_preprocessor,
# backbone=dict(_scope_='mmrazor', type='AutoformerBackbone'),
# neck=None,
# head=dict(
# type='DynamicLinearClsHead',
# num_classes=1000,
# in_channels=624,
# loss=dict(
# type='mmcls.LabelSmoothLoss',
# mode='original',
# num_classes=1000,
# label_smooth_val=0.1,
# loss_weight=1.0),
# topk=(1, 5)),

# model = AutoformerBackbone()
# inputs = torch.randn(1, 3, 224, 224)
# outputs = model(inputs)
# print(outputs.shape)

model = AutoformerBackbone()
# inputs = torch.randn(1, 3, 224, 224)
# outputs = model(inputs)
@@ -426,7 +353,7 @@ def forward(self, x: Tensor):
},
parse_cfg={'type': 'Predefined'})

mutator.prepare_from_supernet(model)  # parse the dynamic OPs in the model
print(mutator.sample_choices())
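# Illustrative sketch (not part of this commit): a sampled choice dict can be
# applied back onto the supernet to instantiate one subnet. `set_choices` is
# assumed here to be the counterpart of `sample_choices` on mmrazor mutators;
# verify the exact API before relying on it.
choices = mutator.sample_choices()
mutator.set_choices(choices)
print(choices)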

