Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename MMTask and BackendConfig #119

Merged
merged 4 commits into from
Dec 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions siatune/codebase/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
from .builder import TASKS, build_task_processor
from .cont_test_func import ContinuousTestFunction
from .disc_test_func import DiscreteTestFunction
from .mm import MMBaseTask
from .mmcls import MMClassification
from .mmdet import MMDetection
from .mmedit import MMEditing
from .mmseg import MMSegmentation
from .mmtrainbase import MMTrainBasedTask

__all__ = [
'TASKS',
Expand All @@ -17,7 +17,7 @@
'BlackBoxTask',
'ContinuousTestFunction',
'DiscreteTestFunction',
'MMTrainBasedTask',
'MMBaseTask',
'MMClassification',
'MMDetection',
'MMEditing',
Expand Down
2 changes: 1 addition & 1 deletion siatune/codebase/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from .base import BaseTask

TASKS = Registry('tasks')
TASKS = Registry('task')


def build_task_processor(task: Dict) -> BaseTask:
Expand Down
6 changes: 3 additions & 3 deletions siatune/codebase/mmtrainbase.py → siatune/codebase/mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
from ray.air.config import ScalingConfig
from ray.train.data_parallel_trainer import DataParallelTrainer

from siatune.tune import CustomBackendConfig
from siatune.tune import MMBackendConfig
from .base import BaseTask
from .builder import TASKS


@TASKS.register_module()
class MMTrainBasedTask(BaseTask, metaclass=ABCMeta):
class MMBaseTask(BaseTask, metaclass=ABCMeta):
"""Wrap the apis of open mm train-based projects."""

def create_trainable(self) -> DataParallelTrainer:
Expand All @@ -23,7 +23,7 @@ def create_trainable(self) -> DataParallelTrainer:

return DataParallelTrainer(
self.context_aware_run,
backend_config=CustomBackendConfig(),
backend_config=MMBackendConfig(),
scaling_config=ScalingConfig(
trainer_resources=dict(CPU=self.num_cpus_per_worker),
num_workers=self.num_workers,
Expand Down
4 changes: 2 additions & 2 deletions siatune/codebase/mmcls.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
from typing import Sequence

from .builder import TASKS
from .mmtrainbase import MMTrainBasedTask
from .mm import MMBaseTask


@TASKS.register_module()
class MMClassification(MMTrainBasedTask):
class MMClassification(MMBaseTask):
"""MMClassification wrapper class for `ray.tune`.

It is modified from https://github.com/open-mmlab/mmclassification/blob/v0.23.2/tools/train.py
Expand Down
4 changes: 2 additions & 2 deletions siatune/codebase/mmdet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
from typing import Sequence

from .builder import TASKS
from .mmtrainbase import MMTrainBasedTask
from .mm import MMBaseTask


@TASKS.register_module()
class MMDetection(MMTrainBasedTask):
class MMDetection(MMBaseTask):
"""MMDetection wrapper class for `ray.tune`.

It is modified from https://github.com/open-mmlab/mmdetection/blob/v2.25.2/tools/train.py
Expand Down
4 changes: 2 additions & 2 deletions siatune/codebase/mmedit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
from typing import Sequence

from .builder import TASKS
from .mmtrainbase import MMTrainBasedTask
from .mm import MMBaseTask


@TASKS.register_module()
class MMEditing(MMTrainBasedTask):
class MMEditing(MMBaseTask):
"""MMEditing wrapper class for `ray.tune`.

It is modified from https://github.com/open-mmlab/mmediting/blob/v0.15.0/tools/train.py
Expand Down
4 changes: 2 additions & 2 deletions siatune/codebase/mmseg.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
from typing import Sequence

from .builder import TASKS
from .mmtrainbase import MMTrainBasedTask
from .mm import MMBaseTask


@TASKS.register_module()
class MMSegmentation(MMTrainBasedTask):
class MMSegmentation(MMBaseTask):
"""MMSegmentation wrapper class for `ray.tune`.

It is modified from https://github.com/open-mmlab/mmsegmentation/blob/v0.25.0/tools/train.py
Expand Down
4 changes: 2 additions & 2 deletions siatune/tune/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# Copyright (c) SI-Analytics. All rights reserved.
from .callbacks import * # noqa F403
from .config import CustomBackendConfig
from .config import MMBackendConfig
from .schedulers import * # noqa F403
from .searchers import * # noqa F403
from .spaces import * # noqa F403
from .stoppers import * # noqa F403
from .tuner import Tuner

__all__ = ['CustomBackendConfig', 'Tuner']
__all__ = ['MMBackendConfig', 'Tuner']
2 changes: 1 addition & 1 deletion siatune/tune/callbacks/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
LegacyLoggerCallback, LoggerCallback,
TBXLoggerCallback)

CALLBACKS = Registry('callbacks')
CALLBACKS = Registry('callback')
CALLBACKS.register_module(module=LegacyLoggerCallback)
CALLBACKS.register_module(module=JsonLoggerCallback)
CALLBACKS.register_module(module=CSVLoggerCallback)
Expand Down
2 changes: 1 addition & 1 deletion siatune/tune/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


@dataclass
class CustomBackendConfig(BackendConfig):
class MMBackendConfig(BackendConfig):
"""Configuration for torch process group setup."""

@property
Expand Down
2 changes: 1 addition & 1 deletion siatune/tune/spaces/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from mmcv.utils import Registry

SPACES = Registry('spaces')
SPACES = Registry('space')


def build_space(cfg: dict) -> dict:
Expand Down
2 changes: 1 addition & 1 deletion siatune/tune/stoppers/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from mmcv.utils import Config, Registry
from ray import tune

STOPPERS = Registry('stoppers')
STOPPERS = Registry('stopper')
for stopper in dir(tune.stopper):
if not stopper.endswith('Stopper'):
continue
Expand Down
4 changes: 2 additions & 2 deletions tests/test_codebase/test_blackbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def report_to_session(**kwargs):


@patch('ray.tune.report', side_effect=report_to_session)
def test_continuous_test_function(*not_used):
def test_continuous_test_function(*mocks):
func = ContinuousTestFunction()
predefined_cont_funcs = [
'delayedsphere',
Expand Down Expand Up @@ -67,7 +67,7 @@ def test_continuous_test_function(*not_used):


@patch('ray.tune.report', side_effect=report_to_session)
def test_discrete_test_function(*not_used):
def test_discrete_test_function(*mocks):
func = DiscreteTestFunction()

predefined_discrete_funcs = ['onemax', 'leadingones', 'jump']
Expand Down
8 changes: 4 additions & 4 deletions tests/test_codebase/test_mmtask.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
@patch('mmcls.apis.train_model')
@patch('mmcls.datasets.build_dataset')
@patch('mmcls.models.build_classifier')
def test_mmcls(*not_used):
def test_mmcls(*mocks):
task = MMClassification()
task_args = ['tests/data/config.py']
task.set_args(task_args)
Expand All @@ -22,7 +22,7 @@ def test_mmcls(*not_used):
@patch('mmdet.apis.train_detector')
@patch('mmdet.datasets.build_dataset')
@patch('mmdet.models.build_detector')
def test_mmdet(*not_used):
def test_mmdet(*mocks):
task = MMDetection()
task_args = ['tests/data/config.py']
task.set_args(task_args)
Expand All @@ -32,7 +32,7 @@ def test_mmdet(*not_used):
@patch('mmedit.apis.train_model')
@patch('mmedit.datasets.build_dataset')
@patch('mmedit.models.build_model')
def test_mmedit(*not_used):
def test_mmedit(*mocks):
task = MMEditing()
task_args = ['tests/data/config.py']
task.set_args(task_args)
Expand All @@ -42,7 +42,7 @@ def test_mmedit(*not_used):
@patch('mmseg.apis.train_segmentor')
@patch('mmseg.datasets.build_dataset')
@patch('mmseg.models.build_segmentor')
def test_mmseg(*not_used):
def test_mmseg(*mocks):
task = MMSegmentation()
task_args = ['tests/data/config.py']
task.set_args(task_args)
Expand Down