Upgrade MMTask #97

Merged · 8 commits · Dec 16, 2022
Changes from all commits
7 changes: 2 additions & 5 deletions configs/mmcls/mmcls_cifar_100_asynchb_nevergrad_pso.py
@@ -5,14 +5,11 @@
]

space = {
'data.samples_per_gpu': {{_base_.batch_size}},
'model': {{_base_.model}},
'model.head.num_classes': 100,
'optimizer': {{_base_.optimizer}},
'data.samples_per_gpu': {{_base_.batch_size}},
}

task = dict(type='MMClassification')
metric = 'val/accuracy_top-1'
mode = 'max'
raise_on_failed_trial = False
num_samples = 256
tune_cfg = dict(num_samples=8, metric='val/accuracy_top-1', mode='max')
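
Note: the updated configs express the search space as a flat dict whose dotted keys address nested config fields, and collapse the old metric/mode/num_samples variables into a single tune_cfg dict. As a minimal, hypothetical sketch (not siatune code), dotted keys of this kind can be expanded back into a nested config before a trial runs; the unflatten helper and the sample values below are illustrative only.

# Hypothetical helper, not part of this PR: expand dotted keys such as
# 'model.head.num_classes' into a nested dict.
def unflatten(sample: dict) -> dict:
    nested: dict = {}
    for dotted_key, value in sample.items():
        *parents, leaf = dotted_key.split('.')
        cursor = nested
        for key in parents:
            cursor = cursor.setdefault(key, {})
        cursor[leaf] = value
    return nested

print(unflatten({'data.samples_per_gpu': 32, 'model.head.num_classes': 100}))
# {'data': {'samples_per_gpu': 32}, 'model': {'head': {'num_classes': 100}}}
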
7 changes: 2 additions & 5 deletions configs/mmdet/mmdet_asynchb_nevergrad_pso.py
@@ -5,13 +5,10 @@
]

space = {
'data.samples_per_gpu': {{_base_.batch_size}},
'model': {{_base_.model}},
'optimizer': {{_base_.optimizer}},
'data.samples_per_gpu': {{_base_.batch_size}},
}

task = dict(type='MMDetection')
metric = 'val/AP'
mode = 'max'
raise_on_failed_trial = False
num_samples = 256
tune_cfg = dict(num_samples=8, metric='val/AP', mode='max')
11 changes: 7 additions & 4 deletions configs/mmseg/mmseg_asynchb_nevergrad_pso.py
@@ -4,10 +4,13 @@
'../_base_/space/optimizer.py', '../_base_/space/batch_size.py'
]

space = dict(
data=dict(samples_per_gpu={{_base_.batch_size}}),
model={{_base_.model}},
optimizer={{_base_.optimizer}})
space = {
'data.samples_per_gpu': {{_base_.batch_size}},
'model': {{_base_.model}},
'model.decode_head.num_classes': 21,
'model.auxiliary_head.num_classes': 21,
'optimizer': {{_base_.optimizer}},
}

task = dict(type='MMSegmentation')
tune_cfg = dict(num_samples=8, metric='val/mIoU', mode='max')
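
Note: the tune_cfg fields added across these configs (num_samples, metric, mode) mirror arguments of ray.tune.run, which siatune presumably drives under the hood. The following is a rough, hypothetical sketch of that mapping under the pre-2.0 tune.run API current when this PR was merged; the trainable stub and its reported value are placeholders, not siatune internals.

# Hypothetical sketch: tune_cfg fields map onto ray.tune.run arguments.
from ray import tune

def trainable(config):
    # A real trial would build and train a segmentor from `config`
    # and report its validation mIoU; 0.0 is a placeholder.
    tune.report(**{'val/mIoU': 0.0})

analysis = tune.run(
    trainable,
    config={'data.samples_per_gpu': tune.choice([2, 4])},
    num_samples=8,        # tune_cfg['num_samples']
    metric='val/mIoU',    # tune_cfg['metric']
    mode='max',           # tune_cfg['mode']
)
print(analysis.best_config)
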
4 changes: 4 additions & 0 deletions setup.cfg
@@ -1,3 +1,7 @@
[flake8]
per-file-ignores =
siatune/mm/tasks/mm*.py: E251,E501

[isort]
line_length = 79
multi_line_output = 0
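
Note: the two codes suppressed for siatune/mm/tasks/mm*.py are pycodestyle's E251 (unexpected spaces around a keyword/parameter equals) and E501 (line too long). A small illustrative snippet of the kind of code they would otherwise flag, not taken from the repository:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--seed', type = int, default = None)  # E251: unexpected spaces around keyword equals
parser.add_argument('--cfg-options', nargs='+', help='override some settings in the used config, the key-value pair in xxx=yyy format will be merged into the config file')  # E501: line too long
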
242 changes: 111 additions & 131 deletions siatune/mm/tasks/mmcls.py
@@ -3,31 +3,28 @@
import copy
import os
import time
import warnings
from os import path as osp
from typing import Optional, Sequence

import mmcv
import torch
import torch.distributed as dist
from mmcv.runner import get_dist_info
from mmcv.utils import Config, DictAction, get_git_hash
from typing import Sequence

from .builder import TASKS
from .mmtrainbase import MMTrainBasedTask


@TASKS.register_module()
class MMClassification(MMTrainBasedTask):
"""MMClassification Wrapping class for ray tune."""
"""MMClassification wrapper class for `ray.tune`.

def parse_args(self, args: Sequence[str]) -> argparse.Namespace:
"""Define and parse the necessary arguments for the task.
It is modified from https://github.com/open-mmlab/mmclassification/blob/v0.23.2/tools/train.py

Args:
args (Sequence[str]): The args.
Returns:
argparse.Namespace: The parsed args.
"""
Attributes:
args (Sequence[str]):
"""

VERSION = 'v0.23.2'

def parse_args(self, task_args: Sequence[str]):
from mmcv import DictAction

parser = argparse.ArgumentParser(description='Train a model')
parser.add_argument('config', help='train config file path')
@@ -39,6 +36,31 @@ def parse_args(self, args: Sequence[str]) -> argparse.Namespace:
'--no-validate',
action='store_true',
help='whether not to evaluate the checkpoint during training')
group_gpus = parser.add_mutually_exclusive_group()
group_gpus.add_argument(
'--device', help='device used for training. (Deprecated)')
group_gpus.add_argument(
'--gpus',
type=int,
help='(Deprecated, please use --gpu-id) number of gpus to use '
'(only applicable to non-distributed training)')
group_gpus.add_argument(
'--gpu-ids',
type=int,
nargs='+',
help='(Deprecated, please use --gpu-id) ids of gpus to use '
'(only applicable to non-distributed training)')
group_gpus.add_argument(
'--gpu-id',
type=int,
default=0,
help='id of gpu to use '
'(only applicable to non-distributed training)')
parser.add_argument(
'--ipu-replicas',
type=int,
default=None,
help='num of ipu replicas to use')
parser.add_argument(
'--seed', type=int, default=None, help='random seed')
parser.add_argument(
@@ -53,108 +75,55 @@ def parse_args(self, args: Sequence[str]) -> argparse.Namespace:
'--cfg-options',
nargs='+',
action=DictAction,
help='override some settings in the used config, the key-value '
'pair in xxx=yyy format will be merged into config file. If the '
'value to be overwritten is a list, it should be like key="[a,b]" '
'or key=a,b It also allows nested list/tuple values, e.g. '
'key="[(a,b),(c,d)]" Note that the quotation marks are necessary '
'and that no white space is allowed.')
args = parser.parse_args(args)
return args

def build_model(self,
cfg: Config,
train_cfg: Optional[Config] = None,
test_cfg: Optional[Config] = None) -> torch.nn.Module:
"""Build the model from configs.

Args:
cfg (Config): The configs.
train_cfg (Optional[Config]):
The train opt. Defaults to None.
test_cfg (Optional[Config]):
The Test opt. Defaults to None.

Returns:
torch.nn.Module: The model.
"""

from mmcls.models import build_classifier
return build_classifier(cfg)

def build_dataset(
self,
cfg: Config,
default_args: Optional[Config] = None) -> torch.utils.data.Dataset:
"""Build the dataset from configs.

Args:
cfg (Config): The configs.
default_args (Optional[Config]):
The default args. Defaults to None.

Returns:
torch.utils.data.Dataset: The dataset.
"""

from mmcls.datasets.builder import build_dataset
return build_dataset(cfg, default_args)

def train_model(self,
model: torch.nn.Module,
dataset: torch.utils.data.Dataset,
cfg: Config,
distributed: bool = True,
validate: bool = False,
timestamp: Optional[str] = None,
meta: Optional[dict] = None) -> None:
from mmcls.apis.train import train_model
"""Train the model.

Args:
model (torch.nn.Module): The model.
dataset (torch.utils.data.Dataset): The dataset.
cfg (Config): The configs.
distributed (bool):
Whether or not distributed. Defaults to True.
validate (bool):
Whether or not validate. Defaults to False.
timestamp (Optional[str]):
The timestamp. Defaults to None.
meta (Optional[dict]):
The meta. Defaults to None.
"""
help='override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file. If the value to '
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
'Note that the quotation marks are necessary and that no white space '
'is allowed.')
parser.add_argument(
'--launcher',
choices=['none', 'pytorch', 'slurm', 'mpi'],
default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args(task_args)
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(args.local_rank)

train_model(
model, dataset, cfg, distributed, validate, timestamp, meta=meta)
return
return args

def run(self, *, args: argparse.Namespace, **kwargs) -> None:
def run(self, args: argparse.Namespace):
"""Run the task.

Args:
args (argparse.Namespace):
The args that received from context manager.
"""

import mmcv
import torch
import torch.distributed as dist
from mmcls import __version__
from mmcls.apis import init_random_seed, set_random_seed
from mmcls.utils import (collect_env, get_root_logger,
setup_multi_processes)

if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(dist.get_rank())
from mmcls.apis import init_random_seed, set_random_seed, train_model
from mmcls.datasets import build_dataset
from mmcls.models import build_classifier
from mmcls.utils import (auto_select_device, collect_env,
get_root_logger, setup_multi_processes)
from mmcv import Config
from mmcv.runner import get_dist_info, init_dist

cfg = Config.fromfile(args.config)
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)

# set multi-process settings
setup_multi_processes(cfg)

# set cudnn_benchmark
if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True

# work_dir is determined in this priority:
# CLI > segment in file > filename
# work_dir is determined in this priority: CLI > segment in file > filename
if args.work_dir is not None:
# update configs according to CLI args if args.work_dir is not None
cfg.work_dir = args.work_dir
@@ -164,12 +133,32 @@ def run(self, *, args: argparse.Namespace, **kwargs) -> None:
osp.splitext(osp.basename(args.config))[0])
if args.resume_from is not None:
cfg.resume_from = args.resume_from
if args.gpus is not None:
cfg.gpu_ids = range(1)
warnings.warn('`--gpus` is deprecated because we only support '
'single GPU mode in non-distributed training. '
'Use `gpus=1` now.')
if args.gpu_ids is not None:
cfg.gpu_ids = args.gpu_ids[0:1]
warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. '
'Because we only support single GPU mode in '
'non-distributed training. Use the first GPU '
'in `gpu_ids` now.')
if args.gpus is None and args.gpu_ids is None:
cfg.gpu_ids = [args.gpu_id]

if args.ipu_replicas is not None:
cfg.ipu_replicas = args.ipu_replicas
args.device = 'ipu'

# init distributed env first, since logger depends on the dist info.
distributed = True
# gpu_ids is used to calculate iter when resuming checkpoint
_, world_size = get_dist_info()
cfg.gpu_ids = range(world_size)
if args.launcher == 'none':
distributed = False
else:
distributed = True
init_dist(args.launcher, **cfg.dist_params)
_, world_size = get_dist_info()
cfg.gpu_ids = range(world_size)

# create work_dir
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
@@ -180,64 +169,55 @@ def run(self, *, args: argparse.Namespace, **kwargs) -> None:
log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

# set multi-process settings
setup_multi_processes(cfg)

# init the meta dict to record some important information such as
# environment info and seed, which will be logged
meta = dict()
# log env info
env_info_dict = collect_env()
env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()])
env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
dash_line = '-' * 60 + '\n'
logger.info('Environment info:\n' + dash_line + env_info + # noqa W504
'\n' + dash_line)
logger.info('Environment info:\n' + dash_line + env_info + '\n' +
dash_line)
meta['env_info'] = env_info

# log some basic info
logger.info(f'Distributed training: {distributed}')
logger.info(f'Config:\n{cfg.pretty_text}')

# set random seeds
seed = init_random_seed(args.seed)
cfg.device = args.device or auto_select_device()
seed = init_random_seed(args.seed, device=cfg.device)
seed = seed + dist.get_rank() if args.diff_seed else seed
logger.info(f'Set random seed to {seed}, '
f'deterministic: {args.deterministic}')
set_random_seed(seed, deterministic=args.deterministic)
cfg.seed = seed
meta['seed'] = seed
meta['exp_name'] = osp.basename(args.config)

model = self.build_model(
cfg.model,
train_cfg=cfg.get('train_cfg'),
test_cfg=cfg.get('test_cfg'))
model = build_classifier(cfg.model)
model.init_weights()

# SyncBN is not support for DP
logger.info(model)

datasets = [self.build_dataset(cfg.data.train)]
datasets = [build_dataset(cfg.data.train)]
if len(cfg.workflow) == 2:
val_dataset = copy.deepcopy(cfg.data.val)
val_dataset.pipeline = cfg.data.train.pipeline
datasets.append(self.build_dataset(val_dataset))
if cfg.checkpoint_config is not None:
# save mmcls version, config file content and class names in
# checkpoints as meta data
cfg.checkpoint_config.meta = dict(
mmcls_version=f'{__version__}+{get_git_hash()[:7]}',
datasets.append(build_dataset(val_dataset))

# save mmcls version, config file content and class names in
# runner as meta data
meta.update(
dict(
mmcls_version=__version__,
config=cfg.pretty_text,
CLASSES=datasets[0].CLASSES)
CLASSES=datasets[0].CLASSES))

# add an attribute for visualization convenience
model.CLASSES = datasets[0].CLASSES
# passing checkpoint meta for saving best checkpoint
meta.update(cfg.checkpoint_config.meta)
self.train_model(
train_model(
model,
datasets,
cfg,
distributed=True,
distributed=distributed,
validate=(not args.no_validate),
timestamp=timestamp,
device=cfg.device,
meta=meta)