Skip to content

Commit

Permalink
[Fix] Ensure the type of filename parameter in Config is str (#1725)
Browse files Browse the repository at this point in the history
* ensure type of filename is str

* check filename for func: fromfile

* add ut for fromfile
  • Loading branch information
teamwong111 authored Apr 27, 2022
1 parent de0c103 commit c324b1f
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 15 deletions.
6 changes: 6 additions & 0 deletions mmcv/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from argparse import Action, ArgumentParser
from collections import abc
from importlib import import_module
from pathlib import Path

from addict import Dict
from yapf.yapflib.yapf_api import FormatCode
Expand Down Expand Up @@ -334,6 +335,8 @@ def _merge_a_into_b(a, b, allow_list_keys=False):
def fromfile(filename,
use_predefined_variables=True,
import_custom_modules=True):
if isinstance(filename, Path):
filename = str(filename)
cfg_dict, cfg_text = Config._file2dict(filename,
use_predefined_variables)
if import_custom_modules and cfg_dict.get('custom_imports', None):
Expand Down Expand Up @@ -390,6 +393,9 @@ def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
if key in RESERVED_KEYS:
raise KeyError(f'{key} is reserved for config file')

if isinstance(filename, Path):
filename = str(filename)

super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict))
super(Config, self).__setattr__('_filename', filename)
if cfg_text:
Expand Down
36 changes: 21 additions & 15 deletions tests/test_utils/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,19 @@ def test_construct():
cfg_dict = dict(item1=[1, 2], item2=dict(a=0), item3=True, item4='test')
# test a.py
cfg_file = osp.join(data_path, 'config/a.py')
cfg = Config(cfg_dict, filename=cfg_file)
assert isinstance(cfg, Config)
assert cfg.filename == cfg_file
assert cfg.text == open(cfg_file, 'r').read()
assert cfg.dump() == cfg.pretty_text
with tempfile.TemporaryDirectory() as temp_config_dir:
dump_file = osp.join(temp_config_dir, 'a.py')
cfg.dump(dump_file)
assert cfg.dump() == open(dump_file, 'r').read()
assert Config.fromfile(dump_file)
cfg_file_path = Path(cfg_file)
file_list = [cfg_file, cfg_file_path]
for item in file_list:
cfg = Config(cfg_dict, filename=item)
assert isinstance(cfg, Config)
assert isinstance(cfg.filename, str) and cfg.filename == str(item)
assert cfg.text == open(item, 'r').read()
assert cfg.dump() == cfg.pretty_text
with tempfile.TemporaryDirectory() as temp_config_dir:
dump_file = osp.join(temp_config_dir, 'a.py')
cfg.dump(dump_file)
assert cfg.dump() == open(dump_file, 'r').read()
assert Config.fromfile(dump_file)

# test b.json
cfg_file = osp.join(data_path, 'config/b.json')
Expand Down Expand Up @@ -142,11 +145,14 @@ def test_construct():
def test_fromfile():
for filename in ['a.py', 'a.b.py', 'b.json', 'c.yaml']:
cfg_file = osp.join(data_path, 'config', filename)
cfg = Config.fromfile(cfg_file)
assert isinstance(cfg, Config)
assert cfg.filename == cfg_file
assert cfg.text == osp.abspath(osp.expanduser(cfg_file)) + '\n' + \
open(cfg_file, 'r').read()
cfg_file_path = Path(cfg_file)
file_list = [cfg_file, cfg_file_path]
for item in file_list:
cfg = Config.fromfile(item)
assert isinstance(cfg, Config)
assert isinstance(cfg.filename, str) and cfg.filename == str(item)
assert cfg.text == osp.abspath(osp.expanduser(item)) + '\n' + \
open(item, 'r').read()

# test custom_imports for Config.fromfile
cfg_file = osp.join(data_path, 'config', 'q.py')
Expand Down

0 comments on commit c324b1f

Please sign in to comment.