swin pipeline setting (#215)
* set layer idx

* disable clip grad

* add set_pipeline_stage_id function to swin model

* fix 3d parallel training convergence bug

* add checkpointing setting

* format code
Ldpe2G authored Mar 30, 2022
1 parent c3936dd commit 7bf3820
Showing 4 changed files with 98 additions and 27 deletions.
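
The common thread across these files is a per-block layer_idx that lets every Swin block, norm, and head be pinned to a pipeline stage. As a rough mental model (a minimal sketch, not LiBai's actual dist util; the helper name layer_stage_id is hypothetical), layers are assumed to be split into contiguous chunks across stages:

# Minimal sketch of the assumed layer-to-stage mapping.
def layer_stage_id(layer_idx, pipeline_num_layers, pipeline_parallel_size):
    layers_per_stage = pipeline_num_layers // pipeline_parallel_size
    if layer_idx < 0:  # -1 is used below to mean "the last layer"
        layer_idx += pipeline_num_layers
    return min(layer_idx // layers_per_stage, pipeline_parallel_size - 1)

# e.g. Swin-T depths (2, 2, 6, 2) give pipeline_num_layers = 12; with 4 stages,
# blocks 0-2 land on stage 0, 3-5 on stage 1, ..., and layer_idx = -1 on stage 3.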
configs/swin_cifar100.py: 14 additions & 2 deletions
@@ -26,21 +26,33 @@
optim.lr = 5e-4
optim.eps = 1e-8
optim.weight_decay = 0.05
optim.params.clip_grad_max_norm = 5.0
optim.params.clip_grad_max_norm = None
optim.params.clip_grad_norm_type = None

# Refine train cfg for swin model
train.train_micro_batch_size = 32
train.num_accumulation_steps = 1
train.test_micro_batch_size = 32
train.train_epoch = 300
train.warmup_ratio = 20 / 300
train.evaluation.eval_period = 200
train.log_period = 1
train.log_period = 20

# Scheduler
train.scheduler.warmup_factor = 5e-7
train.scheduler.alpha = 0.0
train.scheduler.warmup_method = "linear"

# parallel strategy settings
train.dist.data_parallel_size = 8
train.dist.tensor_parallel_size = 1
train.dist.pipeline_parallel_size = 1
train.dist.pipeline_num_layers = sum(model.depths)
train.output_dir = "./output"

# AMP (fp16) setting
train.amp.enabled = False
train.activation_checkpoint.enabled = False
# train.zero_optimization.enabled = True
# train.zero_optimization.stage = 1
graph.enabled = False
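
Two arithmetic points follow from the values above (the usual Swin-T depths of (2, 2, 6, 2) are assumed here): pipeline_num_layers = sum(model.depths) counts the transformer blocks that can be cut into pipeline stages, and the three parallel sizes must multiply out to the number of devices. A quick sanity-check sketch, with a hypothetical helper name:

def check_parallel_sizes(data, tensor, pipeline, world_size):
    # 8 (data) * 1 (tensor) * 1 (pipeline) = 8 ranks for this config
    assert data * tensor * pipeline == world_size, "parallel sizes must cover all ranks"

check_parallel_sizes(data=8, tensor=1, pipeline=1, world_size=8)

depths = (2, 2, 6, 2)     # assumed Swin-T depths
assert sum(depths) == 12  # -> train.dist.pipeline_num_layers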
libai/models/swin_transformer.py: 76 additions & 20 deletions
@@ -62,6 +62,7 @@ def __init__(
attn_drop=0.0,
proj_drop=0.0,
fused_bias_add_dropout=False,
layer_idx=0,
):

super().__init__()
@@ -76,7 +77,7 @@ def __init__(
flow.zeros(
(2 * window_size[0] - 1) * (2 * window_size[1] - 1),
num_heads,
placement=dist.get_layer_placement(0),
placement=dist.get_layer_placement(layer_idx),
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
)
) # 2*Wh-1 * 2*Ww-1, nH
@@ -96,14 +97,14 @@ def __init__(
self.register_buffer(
"relative_position_index",
relative_position_index.to_global(
placement=dist.get_layer_placement(0),
placement=dist.get_layer_placement(layer_idx),
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
),
)

self.qkv = Linear(dim, dim * 3, bias=qkv_bias)
self.qkv = Linear(dim, dim * 3, bias=qkv_bias, layer_idx=layer_idx)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = Linear(dim, dim)
self.proj = Linear(dim, dim, layer_idx=layer_idx)
self.proj_drop = nn.Dropout(proj_drop)
self.softmax = nn.Softmax(dim=-1)
self.fused_bias_add_dropout = fused_bias_add_dropout
@@ -191,6 +192,7 @@ def __init__(
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=LayerNorm,
layer_idx=0,
):
super().__init__()
self.dim = dim
@@ -199,13 +201,14 @@ def __init__(
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
self.layer_idx = layer_idx
if min(self.input_resolution) <= self.window_size:
# if window size is larger than input resolution, we don't partition windows
self.shift_size = 0
self.window_size = min(self.input_resolution)
assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"

self.norm1 = norm_layer(dim)
self.norm1 = norm_layer(dim, layer_idx=layer_idx)
self.attn = WindowAttention(
dim,
window_size=to_2tuple(self.window_size),
@@ -215,17 +218,19 @@ def __init__(
attn_drop=attn_drop,
proj_drop=drop,
fused_bias_add_dropout=True,
layer_idx=layer_idx,
)

self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
self.norm2 = norm_layer(dim, layer_idx=layer_idx)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = MLP(
hidden_size=dim,
ffn_hidden_size=mlp_hidden_dim,
output_dropout_prob=drop,
bias_gelu_fusion=True,
bias_dropout_fusion=True,
layer_idx=layer_idx,
)

if self.shift_size > 0:
@@ -257,7 +262,7 @@ def __init__(
attn_mask == 0, float(0.0)
)
attn_mask = attn_mask.to_global(
placement=dist.get_layer_placement(0),
placement=dist.get_layer_placement(layer_idx),
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
)
else:
@@ -317,12 +322,13 @@ class PatchMerging(nn.Module):
norm_layer (nn.Module, optional): Normalization layer. Default: libai.layers.LayerNorm
"""

def __init__(self, input_resolution, dim, norm_layer=LayerNorm):
def __init__(self, input_resolution, dim, norm_layer=LayerNorm, layer_idx=0):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
self.reduction = Linear(4 * dim, 2 * dim, bias=False, layer_idx=layer_idx)
self.norm = norm_layer(4 * dim, layer_idx=layer_idx)
self.layer_idx = layer_idx

def forward(self, x):
"""
@@ -358,7 +364,9 @@ class PatchEmbed(nn.Module):
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""

def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
def __init__(
self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, layer_idx=0
):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
@@ -377,11 +385,11 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_la
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
).to_global(
placement=dist.get_layer_placement(0),
placement=dist.get_layer_placement(layer_idx),
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
self.norm = norm_layer(embed_dim, layer_idx=layer_idx)
else:
self.norm = None

@@ -432,14 +440,15 @@ def __init__(
norm_layer=LayerNorm,
downsample=None,
use_checkpoint=False,
layer_id_offset=0,
):

super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.use_checkpoint = use_checkpoint

self.layer_id_offset = layer_id_offset
# build blocks
self.blocks = nn.ModuleList(
[
@@ -456,23 +465,32 @@
attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer,
layer_idx=layer_id_offset + i,
)
for i in range(depth)
]
)

# patch merging layer
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
self.downsample = downsample(
input_resolution,
dim=dim,
norm_layer=norm_layer,
layer_idx=layer_id_offset + depth - 1,
)
else:
self.downsample = None

def forward(self, x):
for blk in self.blocks:
layer_idx = self.layer_id_offset
for i in range(len(self.blocks)):
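# moving x to the placement of the block that is about to run it is what hands
# activations across pipeline-stage boundaries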
x = x.to_global(placement=dist.get_layer_placement(layer_idx))
if self.use_checkpoint:
raise Exception("Checkpointing is not supported yet!")
else:
x = blk(x)
x = self.blocks[i](x)
layer_idx += 1
if self.downsample is not None:
x = self.downsample(x)
return x
@@ -550,14 +568,19 @@ def __init__(
in_chans=in_chans,
embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None,
layer_idx=0,
)
num_patches = self.patch_embed.num_patches
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution

# absolute position embedding
if self.ape:
self.absolute_pos_embed = nn.Parameter(flow.zeros(1, num_patches, embed_dim))
self.absolute_pos_embed = nn.Parameter(
flow.zeros(
1, num_patches, embed_dim,
placement=dist.get_layer_placement(0),
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
)
)
trunc_normal_(self.absolute_pos_embed, std=0.02)

self.pos_drop = nn.Dropout(p=drop_rate)
@@ -569,6 +592,7 @@ def __init__(

# build layers
self.layers = nn.ModuleList()
layer_id_offset = 0
for i_layer in range(self.num_layers):
layer = BasicLayer(
dim=int(embed_dim * 2 ** i_layer),
@@ -588,12 +612,18 @@
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint,
layer_id_offset=layer_id_offset,
)
layer_id_offset += depths[i_layer]
self.layers.append(layer)

self.norm = norm_layer(self.num_features)
self.norm = norm_layer(self.num_features, layer_idx=-1)
self.avgpool = nn.AdaptiveAvgPool1d(1)
self.head = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.head = (
Linear(self.num_features, num_classes, layer_idx=-1)
if num_classes > 0
else nn.Identity()
)

# Loss func
self.loss_func = nn.CrossEntropyLoss() if loss_func is None else loss_func
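
For a concrete feel of the indexing (Swin-T depths assumed), layer_id_offset advances by each stage's depth, so every SwinTransformerBlock gets a unique global index and each PatchMerging reuses the index of the last block in its stage:

# Worked example, not library code; depths are assumed.
depths = [2, 2, 6, 2]
offsets, layer_id_offset = [], 0
for d in depths:
    offsets.append(layer_id_offset)
    layer_id_offset += d
print(offsets)  # [0, 2, 4, 10] -> block indices 0..11 overall
# The downsample (PatchMerging) of stage i uses offsets[i] + depths[i] - 1,
# i.e. 1, 3, 9 for the first three stages, keeping it on the same pipeline
# stage as the last block it follows.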
@@ -666,3 +696,29 @@ def forward(self, images, labels=None):
return {"losses": losses}
else:
return {"prediction_scores": x}

@staticmethod
def set_pipeline_stage_id(model):
dist_utils = dist.get_dist_util()

model.patch_embed.config.stage_id = dist_utils.get_layer_stage_id(0)
model.pos_drop.config.stage_id = dist_utils.get_layer_stage_id(0)

# Set pipeline parallelism stage_id
for module_block in model.modules():
# module.origin can get the original module
if isinstance(module_block.origin, SwinTransformerBlock):
module_block.config.stage_id = dist_utils.get_layer_stage_id(module_block.layer_idx)
elif isinstance(module_block.origin, PatchMerging):
module_block.config.stage_id = dist_utils.get_layer_stage_id(module_block.layer_idx)

model.norm.config.stage_id = dist_utils.get_layer_stage_id(-1)
model.head.config.stage_id = dist_utils.get_layer_stage_id(-1)
model.avgpool.config.stage_id = dist_utils.get_layer_stage_id(-1)
model.loss_func.config.stage_id = dist_utils.get_layer_stage_id(-1)

@staticmethod
def set_activation_checkpoint(model):
for module_block in model.modules():
if isinstance(module_block.origin, SwinTransformerBlock):
module_block.config.activation_checkpointing = True
libai/models/utils/graph_base.py: 6 additions & 3 deletions
@@ -87,9 +87,12 @@ def build(self, **kwargs):
return self.model(**kwargs)

def set_activation_checkpoint(self):
for module_block in self.model.modules():
if isinstance(module_block.origin, TransformerLayer):
module_block.config.activation_checkpointing = True
if hasattr(type(self.model.origin), "set_activation_checkpoint"):
type(self.model.origin).set_activation_checkpoint(self.model)
else:
for module_block in self.model.modules():
if isinstance(module_block.origin, TransformerLayer):
module_block.config.activation_checkpointing = True

def set_pipeline_stage_id(self):
if hasattr(type(self.model.origin), "set_pipeline_stage_id"):
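
The graph_base.py change is a dispatch pattern: prefer a class-level hook defined by the model, otherwise fall back to the generic TransformerLayer scan. A self-contained sketch of the same idea, with a dummy class standing in for the wrapped model and a hypothetical function name:

class DummyModel:
    @staticmethod
    def set_activation_checkpoint(model):
        print(f"model-specific checkpointing for {type(model).__name__}")

def configure_checkpointing(model):
    hook = getattr(type(model), "set_activation_checkpoint", None)
    if hook is not None:
        hook(model)  # the Swin path added in this commit
    else:
        print("generic fallback: checkpoint every TransformerLayer")

configure_checkpointing(DummyModel())  # -> model-specific checkpointing for DummyModel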
libai/utils/distributed.py: 2 additions & 2 deletions
@@ -172,7 +172,7 @@ def pipeline_parallel_size(self):

@property
def model_parallel_size(self):
return self._tensor_parallel_size * self._pipeline_parallel_size
return self._tensor_parallel_size

@property
def data_parallel_size(self):
@@ -299,7 +299,7 @@ def get_hidden_sbp():

def get_data_parallel_rank():
dist_util = get_dist_util()
return flow.env.get_rank() // dist_util.model_parallel_size
return (flow.env.get_rank() // dist_util.model_parallel_size) % dist_util.data_parallel_size


def get_data_parallel_size():
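
A quick worked example of the two distributed.py fixes (my reading of the "3d parallel training convergence bug", with assumed sizes): on 8 ranks with tensor_parallel_size = 2, pipeline_parallel_size = 2 and data_parallel_size = 2, model_parallel_size is now just the tensor size, and the added modulo keeps the data-parallel rank inside [0, data_parallel_size):

# Assumed rank layout: pipeline outermost, then data, then tensor.
tensor_parallel_size, data_parallel_size = 2, 2
model_parallel_size = tensor_parallel_size  # after this commit
for rank in range(8):
    dp_rank = (rank // model_parallel_size) % data_parallel_size
    print(rank, dp_rank)  # 0->0, 1->0, 2->1, 3->1, 4->0, 5->0, 6->1, 7->1
# Before the commit, model_parallel_size also folded in the pipeline size, so the
# same formula produced values that tracked the pipeline stage rather than the
# data-parallel group under this layout.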
