support layer_decay in optim_factory
sageyou committed Jan 15, 2024
1 parent a086e4e commit b468028
Showing 3 changed files with 171 additions and 5 deletions.
2 changes: 2 additions & 0 deletions config.py
@@ -167,6 +167,8 @@ def create_parser():
help='Whether use clip grad (default=False)')
group.add_argument('--clip_value', type=float, default=15.0,
help='Clip value (default=15.0)')
group.add_argument('--layer_decay', type=float, default=None,
help='layer-wise learning rate decay (default: None)')
group.add_argument('--gradient_accumulation_steps', type=int, default=1,
help="Accumulate the gradients of n batches before update.")

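For reference, a minimal standalone sketch of how the new flag behaves, using a plain argparse parser as a hypothetical stand-in for the repository's create_parser (flag name and defaults mirror the two lines added above):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--layer_decay", type=float, default=None,
                    help="layer-wise learning rate decay (default: None)")
args = parser.parse_args(["--layer_decay", "0.65"])
print(args.layer_decay)  # 0.65; stays None when the flag is omitted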
171 changes: 167 additions & 4 deletions mindcv/optim/optim_factory.py
@@ -1,7 +1,11 @@
""" optim factory """
import collections
import logging
import os
from typing import Optional
import re
from collections import defaultdict
from itertools import chain, islice
from typing import Any, Callable, Dict, Iterator, Optional, Tuple, Union

from mindspore import load_checkpoint, load_param_into_net, nn

@@ -14,6 +18,8 @@

_logger = logging.getLogger(__name__)

MATCH_PREV_GROUP = [9]


def init_group_params(params, weight_decay, weight_decay_filter, no_weight_decay):
if weight_decay_filter == "disable":
@@ -37,6 +43,152 @@ def init_group_params(params, weight_decay, weight_decay_filter, no_weight_decay
]


def param_groups_layer_decay(
model: nn.Cell,
lr: Optional[float] = 1e-3,
weight_decay: float = 0.05,
no_weight_decay_list: Tuple[str] = (),
layer_decay: float = 0.75,
):
"""
Parameter groups for layer-wise lr decay & weight decay
"""
no_weight_decay_list = set(no_weight_decay_list)
param_group_names = {} # NOTE for debugging
param_groups = {}
if hasattr(model, "group_matcher"):
layer_map = group_with_matcher(model.trainable_params(), model.group_matcher(coarse=False), reverse=True)
else:
layer_map = _layer_map(model)

num_layers = max(layer_map.values()) + 1
layer_max = num_layers - 1
layer_scales = list(layer_decay ** (layer_max - i) for i in range(num_layers))

for name, param in model.parameters_and_names():
if not param.requires_grad:
continue

# no decay: all 1D parameters and model specific ones
if param.ndim == 1 or name in no_weight_decay_list:
g_decay = "no_decay"
this_decay = 0.0
else:
g_decay = "decay"
this_decay = weight_decay

layer_id = layer_map.get(name, layer_max)
group_name = "layer_%d_%s" % (layer_id, g_decay)

if group_name not in param_groups:
this_scale = layer_scales[layer_id]
param_group_names[group_name] = {
"lr": [learning_rate * this_scale for learning_rate in lr],
"weight_decay": this_decay,
"param_names": [],
}
param_groups[group_name] = {
"lr": [learning_rate * this_scale for learning_rate in lr],
"weight_decay": this_decay,
"params": [],
}

param_group_names[group_name]["param_names"].append(name)
param_groups[group_name]["params"].append(param)

return list(param_groups.values())
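A worked example of the geometric schedule computed above, assuming layer_decay=0.75 and four layer groups: the deepest group keeps the full learning rate and each earlier group is scaled down by another factor of 0.75.

layer_decay, num_layers = 0.75, 4
layer_max = num_layers - 1
layer_scales = [layer_decay ** (layer_max - i) for i in range(num_layers)]
print(layer_scales)  # [0.421875, 0.5625, 0.75, 1.0]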


MATCH_PREV_GROUP = (99999,)


def group_with_matcher(
named_objects: Iterator[Tuple[str, Any]], group_matcher: Union[Dict, Callable], reverse: bool = False
):
if isinstance(group_matcher, dict):
# dictionary matcher contains a dict of raw-string regex expr that must be compiled
compiled = []
for group_ordinal, (_, mspec) in enumerate(group_matcher.items()):
if mspec is None:
continue
# map all matching specifications into 3-tuple (compiled re, prefix, suffix)
if isinstance(mspec, (tuple, list)):
# multi-entry match specifications require each sub-spec to be a 2-tuple (re, suffix)
for sspec in mspec:
compiled += [(re.compile(sspec[0]), (group_ordinal,), sspec[1])]
else:
compiled += [(re.compile(mspec), (group_ordinal,), None)]
group_matcher = compiled

def _get_grouping(name):
if isinstance(group_matcher, (list, tuple)):
for match_fn, prefix, suffix in group_matcher:
r = match_fn.match(name)
if r:
parts = (prefix, r.groups(), suffix)
# map all tuple elem to int for numeric sort, filter out None entries
return tuple(map(float, chain.from_iterable(filter(None, parts))))
return (float("inf"),) # un-matched layers (neck, head) mapped to largest ordinal
else:
ord = group_matcher(name)
if not isinstance(ord, collections.abc.Iterable):
return (ord,)
return tuple(ord)

grouping = defaultdict(list)
for param in named_objects:
grouping[_get_grouping(param.name)].append(param.name)
# remap to integers
layer_id_to_param = defaultdict(list)
lid = -1
for k in sorted(filter(lambda x: x is not None, grouping.keys())):
if lid < 0 or k[-1] != MATCH_PREV_GROUP[0]:
lid += 1
layer_id_to_param[lid].extend(grouping[k])

if reverse:
# output reverse mapping
param_to_layer_id = {}
for lid, lm in layer_id_to_param.items():
for n in lm:
param_to_layer_id[n] = lid
return param_to_layer_id

return layer_id_to_param
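A small sketch of how a dictionary matcher is consumed by group_with_matcher, using hypothetical ViT-style parameter names and a namedtuple stand-in for a mindspore Parameter (only the .name attribute is read above); captured block indices become part of the sort key, and unmatched names fall into the last group.

from collections import namedtuple

P = namedtuple("P", ["name"])  # stand-in exposing .name, like a Parameter
params = [P("patch_embed.proj.weight"), P("blocks.0.attn.qkv.weight"),
          P("blocks.1.mlp.fc1.weight"), P("head.weight")]
matcher = {"stem": r"^patch_embed", "blocks": r"^blocks\.(\d+)"}
print(group_with_matcher(params, matcher, reverse=True))
# {'patch_embed.proj.weight': 0, 'blocks.0.attn.qkv.weight': 1,
#  'blocks.1.mlp.fc1.weight': 2, 'head.weight': 3}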


def _group(it, size):
it = iter(it)
return iter(lambda: tuple(islice(it, size)), ())


def _layer_map(model, layers_per_group=12, num_groups=None):
def _in_head(n, hp):
if not hp:
return True
elif isinstance(hp, (tuple, list)):
return any([n.startswith(hpi) for hpi in hp])
else:
return n.startswith(hp)

# attention: need to add pretrained_cfg attr to model
head_prefix = getattr(model, "pretrained_cfg", {}).get("classifier", None)
names_trunk = []
names_head = []
for n, _ in model.parameters_and_names():
names_head.append(n) if _in_head(n, head_prefix) else names_trunk.append(n)

# group non-head layers
num_trunk_layers = len(names_trunk)
if num_groups is not None:
layers_per_group = -(num_trunk_layers // -num_groups)
names_trunk = list(_group(names_trunk, layers_per_group))
num_trunk_groups = len(names_trunk)
layer_map = {n: i for i, l in enumerate(names_trunk) for n in l}
layer_map.update({n: num_trunk_groups for n in names_head})
return layer_map
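Without a matcher, the fallback path simply chunks trunk parameter names in file order, layers_per_group at a time; a quick sketch of _group with hypothetical names and a group size of 2:

names = ["conv1.weight", "bn1.gamma", "layer1.0.conv.weight", "layer1.0.bn.gamma", "fc.weight"]
print(list(_group(names, 2)))
# [('conv1.weight', 'bn1.gamma'), ('layer1.0.conv.weight', 'layer1.0.bn.gamma'), ('fc.weight',)]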


def create_optimizer(
model_or_params,
opt: str = "adam",
@@ -45,6 +197,7 @@ def create_optimizer(
momentum: float = 0.9,
nesterov: bool = False,
weight_decay_filter: str = "disable",
layer_decay: Optional[float] = None,
loss_scale: float = 1.0,
schedule_decay: float = 4e-3,
checkpoint_path: str = "",
@@ -54,9 +207,9 @@
r"""Creates optimizer by name.
Args:
params: network parameters. Union[list[Parameter],list[dict]], which must be the list of parameters
or list of dicts. When the list element is a dictionary, the key of the dictionary can be
"params", "lr", "weight_decay","grad_centralization" and "order_params".
model_or_params: network or network parameters. Union[list[Parameter],list[dict], nn.Cell], which must be
the list of parameters or list of dicts or nn.Cell. When the list element is a dictionary, the key of
the dictionary can be "params", "lr", "weight_decay","grad_centralization" and "order_params".
opt: wrapped optimizer. You can choose from 'sgd', 'nesterov', 'momentum', 'adam', 'adamw', 'lion',
'rmsprop', 'adagrad', 'lamb'. 'adam' is the default choice for convolution-based networks.
'adamw' is recommended for ViT-based networks. Default: 'adam'.
@@ -73,6 +226,7 @@
- "auto": We do not apply weight decay filtering to any parameters. However, MindSpore currently
automatically filters the parameters of Norm layer from weight decay.
- "norm_and_bias": Filter the paramters of Norm layer and Bias from weight decay.
layer_decay: layer-wise learning rate decay factor. When set and a network is passed as model_or_params,
each layer group's lr is scaled by layer_decay ** (num_layers - 1 - layer_id). Default: None.
loss_scale: A floating point value for the loss scale, which must be larger than 0.0. Default: 1.0.
Returns:
@@ -95,6 +249,15 @@
"when creating an mindspore.nn.Optimizer instance. "
"NOTE: mindspore.nn.Optimizer will filter Norm parmas from weight decay. "
)
elif layer_decay is not None and isinstance(model_or_params, nn.Cell):
params = param_groups_layer_decay(
model_or_params,
lr=lr,
weight_decay=weight_decay,
layer_decay=layer_decay,
no_weight_decay_list=no_weight_decay,
)
weight_decay = 0.0
elif weight_decay_filter == "disable" or "norm_and_bias":
params = init_group_params(params, weight_decay, weight_decay_filter, no_weight_decay)
weight_decay = 0.0
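Putting the pieces together, a hedged usage sketch of the new layer-decay branch: it assumes the learning rate is supplied as a per-step list (param_groups_layer_decay scales every entry of the schedule) and that create_model and create_optimizer are importable from mindcv (import paths here are assumed).

from mindcv.models import create_model
from mindcv.optim import create_optimizer

model = create_model("resnet50")  # any registered model; ones exposing group_matcher
                                  # or pretrained_cfg get finer-grained layer grouping
lr_steps = [1e-3] * 1000          # per-step schedule, e.g. from a scheduler
optimizer = create_optimizer(
    model,                        # pass the nn.Cell itself so layers can be grouped
    opt="adamw",
    lr=lr_steps,
    weight_decay=0.05,
    layer_decay=0.75,             # earlier layer groups get geometrically smaller lr
)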
3 changes: 2 additions & 1 deletion train.py
@@ -210,13 +210,14 @@ def main():
else:
optimizer_loss_scale = 1.0
optimizer = create_optimizer(
network.trainable_params(),
network,
opt=args.opt,
lr=lr_scheduler,
weight_decay=args.weight_decay,
momentum=args.momentum,
nesterov=args.use_nesterov,
weight_decay_filter=args.weight_decay_filter,
layer_decay=args.layer_decay,
loss_scale=optimizer_loss_scale,
checkpoint_path=opt_ckpt_path,
eps=args.eps,
