Skip to content

Commit

Permalink
Recover lazy object during initialization.
Browse files Browse the repository at this point in the history
  • Loading branch information
mzr1996 committed Aug 24, 2023
1 parent 8861ca9 commit 3cdc47b
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 42 deletions.
20 changes: 18 additions & 2 deletions mmengine/config/lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,6 @@ def __str__(self) -> str:
return str(self.source) + '.' + self.name
return self.name

__repr__ = __str__

def __repr__(self) -> str:
return f"<Lazy '{str(self)}'>"

Expand Down Expand Up @@ -136,3 +134,21 @@ def __exit__(self, exc_type, exc_val, exc_tb):

def __repr__(self):
return f'<LazyImportContext (enable={self.enable})>'


def recover_lazy_field(cfg):

if isinstance(cfg, dict):
for k, v in cfg.items():
cfg[k] = recover_lazy_field(v)
return cfg
elif isinstance(cfg, (tuple, list)):
container_type = type(cfg)
cfg = list(cfg)
for i, v in enumerate(cfg):
cfg[i] = recover_lazy_field(v)
return container_type(cfg)
elif isinstance(cfg, str):
recover = LazyObject.from_str(cfg)
return recover if recover is not None else cfg
return cfg
44 changes: 16 additions & 28 deletions mmengine/config/new_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from yapf.yapflib.yapf_api import FormatCode

from .config import Config, ConfigDict
from .lazy import LazyImportContext, LazyObject
from .lazy import LazyImportContext, LazyObject, recover_lazy_field

RESERVED_KEYS = ['filename', 'text', 'pretty_text']
_CFG_UID = 0
Expand All @@ -31,24 +31,6 @@ def format_inpsect(obj):
return msg


def recover_lazy_field(cfg):

if isinstance(cfg, dict):
for k, v in cfg.items():
cfg[k] = recover_lazy_field(v)
return cfg
elif isinstance(cfg, (tuple, list)):
container_type = type(cfg)
cfg = list(cfg)
for i, v in enumerate(cfg):
cfg[i] = recover_lazy_field(v)
return container_type(cfg)
elif isinstance(cfg, str):
recover = LazyObject.from_str(cfg)
return recover if recover is not None else cfg
return cfg


def dump_extra_type(value):
if isinstance(value, LazyObject):
return value.dump_str
Expand Down Expand Up @@ -136,6 +118,9 @@ def __init__(self,

if not isinstance(cfg_dict, ConfigDict):
cfg_dict = ConfigDict(cfg_dict)
# Recover dumped lazy object like '<torch.nn.Linear>' from string
cfg_dict = recover_lazy_field(cfg_dict)

super(Config, self).__setattr__('_cfg_dict', cfg_dict)
super(Config, self).__setattr__('_filename', filename)
super(Config, self).__setattr__('_format_python_code',
Expand Down Expand Up @@ -169,9 +154,10 @@ def _sanity_check(cfg):
raise ValueError(msg)

@staticmethod
def fromfile(filename: Union[str, Path],
keep_imported: bool = False,
format_python_code: bool = True) -> 'ConfigV2':
def fromfile( # type: ignore
filename: Union[str, Path],
keep_imported: bool = False,
format_python_code: bool = True) -> 'ConfigV2':
"""Build a Config instance from config file.
Args:
Expand Down Expand Up @@ -202,8 +188,6 @@ def fromfile(filename: Union[str, Path],
module_dict = dict(filter(filter_imports, module_dict.items()))

cfg_dict = ConfigDict(module_dict)
# Recover dumped lazy object like '<torch.nn.Linear>' from string
cfg_dict = recover_lazy_field(cfg_dict)

cfg = ConfigV2(
cfg_dict,
Expand Down Expand Up @@ -232,7 +216,7 @@ def _get_config_module(filename: Union[str, Path]):
with LazyImportContext():
spec = importlib.util.spec_from_file_location(fullname, file)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
spec.loader.exec_module(module) # type: ignore
sys.modules[fullname] = module

return module
Expand Down Expand Up @@ -345,12 +329,16 @@ def _format_basic_types(input_):

return text

def __getstate__(self) -> Tuple[dict, Optional[str], Optional[str], bool]:
def __getstate__(
self
) -> Tuple[dict, Optional[str], Optional[str], bool]: # type: ignore
return (self._cfg_dict, self._filename, self._text,
self._format_python_code)

def __setstate__(self, state: Tuple[dict, Optional[str], Optional[str],
bool]):
def __setstate__( # type: ignore
self,
state: Tuple[dict, Optional[str], Optional[str], bool],
):
super(Config, self).__setattr__('_cfg_dict', state[0])
super(Config, self).__setattr__('_filename', state[1])
super(Config, self).__setattr__('_text', state[2])
Expand Down
25 changes: 17 additions & 8 deletions mmengine/config/old_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from mmengine.utils import (check_file_exist, get_installed_path,
import_modules_from_strings, is_installed)
from .config import BASE_KEY, Config, ConfigDict
from .lazy import recover_lazy_field
from .utils import (ConfigParsingError, RemoveAssignFromAST,
_get_external_cfg_base_path, _get_external_cfg_path,
_get_package_and_cfg_path)
Expand Down Expand Up @@ -91,6 +92,9 @@ def __init__(self,

if not isinstance(cfg_dict, ConfigDict):
cfg_dict = ConfigDict(cfg_dict)
# Recover dumped lazy object like '<torch.nn.Linear>' from string
cfg_dict = recover_lazy_field(cfg_dict)

super(Config, self).__setattr__('_cfg_dict', cfg_dict)
super(Config, self).__setattr__('_filename', filename)
super(Config, self).__setattr__('_format_python_code',
Expand All @@ -108,11 +112,13 @@ def __init__(self,
super(Config, self).__setattr__('_env_variables', env_variables)

@staticmethod
def fromfile(filename: Union[str, Path],
use_predefined_variables: bool = True,
import_custom_modules: bool = True,
use_environment_variables: bool = True,
format_python_code: bool = True) -> 'ConfigV1':
def fromfile( # type: ignore
filename: Union[str, Path],
use_predefined_variables: bool = True,
import_custom_modules: bool = True,
use_environment_variables: bool = True,
format_python_code: bool = True,
) -> 'ConfigV1':
"""Build a Config instance from config file.
Args:
Expand Down Expand Up @@ -668,13 +674,16 @@ def env_variables(self) -> dict:
return self._env_variables

def __getstate__(
self) -> Tuple[dict, Optional[str], Optional[str], dict, bool]:
self
) -> Tuple[dict, Optional[str], Optional[str], dict, bool]: # type: ignore
state = (self._cfg_dict, self._filename, self._text,
self._env_variables, self._format_python_code)
return state

def __setstate__(self, state: Tuple[dict, Optional[str], Optional[str],
dict, bool]):
def __setstate__( # type: ignore
self,
state: Tuple[dict, Optional[str], Optional[str], dict, bool],
):
super(Config, self).__setattr__('_cfg_dict', state[0])
super(Config, self).__setattr__('_filename', state[1])
super(Config, self).__setattr__('_text', state[2])
Expand Down
4 changes: 2 additions & 2 deletions mmengine/runner/_flexible_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,9 +284,9 @@ def __init__(
if isinstance(cfg, Config):
self.cfg = copy.deepcopy(cfg)
elif isinstance(cfg, dict):
self.cfg = Config(cfg)
self.cfg = Config(cfg) # type: ignore
else:
self.cfg = Config(dict())
self.cfg = Config(dict()) # type: ignore

# lazy initialization
training_related = [train_dataloader, train_cfg, optim_wrapper]
Expand Down
4 changes: 2 additions & 2 deletions mmengine/runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,9 +281,9 @@ def __init__(
if isinstance(cfg, Config):
self.cfg = copy.deepcopy(cfg)
elif isinstance(cfg, dict):
self.cfg = Config(cfg)
self.cfg = Config(cfg) # type: ignore
else:
self.cfg = Config(dict())
self.cfg = Config(dict()) # type: ignore

# lazy initialization
training_related = [train_dataloader, train_cfg, optim_wrapper]
Expand Down

0 comments on commit 3cdc47b

Please sign in to comment.