Skip to content

Commit

Permalink
handle list fo configs in json config adpator
Browse files Browse the repository at this point in the history
Summary:
Currently, flsim cannot handle a list of configs
Eg: a json config like the following is not supported.
```
{
  "trainer": {
    "clients": [
      {"_base_": "base_client", "optimizer": {"lr": 0.1}},
      {"_base_": "base_client", "optimizer": {"lr": 0.2}},
    ]
  }
```

The one hiccup in supporting this is hydra still doesn't support overriding/appending lists when values are configs. (see facebookresearch/hydra#1939 (comment))

In order to overcome the above, we treat the list as a dictionary with the key as list index.

The above config in yaml format will look as follows:
```
trainer:
  ...
  clients:
    '0':
      _target_: flsim.clients.base_client.Client
      _recursive_: false
      epochs: 1
      optimizer:
        _target_: ???
        _recursive_: false
        lr: 0.1
        momentum: 0.0
        weight_decay: 0.0
      lr_scheduler:
        _target_: ???
        _recursive_: false
        base_lr: 0.001
      max_clip_norm_normalized: null
      only_federated_params: true
      random_seed: null
      shuffle_batch_order: false
      store_models_and_optimizers: false
      track_multiple_selection: false
    '1':
      _target_: flsim.clients.base_client.Client
      _recursive_: false
      epochs: 1
      optimizer:
        _target_: ???
        _recursive_: false
        lr: 0.1
        momentum: 0.0
        weight_decay: 0.0
      lr_scheduler:
        _target_: ???
        _recursive_: false
        base_lr: 0.001
      max_clip_norm_normalized: null
      only_federated_params: true
      random_seed: null
      shuffle_batch_order: false
      store_models_and_optimizers: false
      track_multiple_selection: false
  ...
```

Reviewed By: Anonymani

Differential Revision: D37638999

fbshipit-source-id: 5444da742742d4cc976875a6dc055a0c71c186e4
  • Loading branch information
karthikprasad authored and facebook-github-bot committed Jul 6, 2022
1 parent bae5e65 commit fb884b3
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 6 deletions.
23 changes: 18 additions & 5 deletions flsim/utils/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,25 @@ def _flatten_dict(
items = []
for k, v in d.items():
new_key = parent_key + sep + k if parent_key else k
# if value is not a dict and is mutable, extend the items and flatten again.
# if value is not supposed to be a dict but is mappable,
# then extend the items and flatten again.
# > hacky way of preserving dict values by checking if key has _dict as suffix.
if not new_key.endswith("_dict") and isinstance(v, abc.MutableMapping):
items.extend(_flatten_dict(v, new_key, sep=sep).items())
else:
elif type(v) is list:
# handle config in lists (doesn't support nested lists yet)
for i in range(len(v)):
if isinstance(v[i], abc.MutableMapping) and "_base_" in v[i].keys():
new_key_i = f"{new_key}.{i}"
items.extend(_flatten_dict(v[i], new_key_i, sep=sep).items())
else:
items.append((new_key, v))
elif type(v) is str and v.replace(".", "", 1).isdigit():
# check if a number needs to be retained as a string
# the repalce with one dot is needed to handle floats
if type(v) is str and v.replace(".", "", 1).isdigit():
v = f'"{v}"' # enclose it with quotes if so.
# the repalce with one dot and check is needed to handle floats
v = f'"{v}"' # enclose it with quotes if so.:
items.append((new_key, v))
else:
items.append((new_key, v))
return dict(items)

Expand Down Expand Up @@ -164,6 +174,9 @@ def fl_json_to_dotlist(
k = k.replace("._base_", "")
# extract aggregator from trainer.aggregator
config_group = k.split(".")[-1]
config_group = (
config_group if not config_group.isdigit() else k.split(".")[-2]
)
# trainer.aggregator --> +aggregator@trainer.aggregator
k = f"+{config_group}@{k}"
# +aggregator@trainer.aggregator=base_fed_avg_with_lr_sync_aggregator
Expand Down
82 changes: 81 additions & 1 deletion flsim/utils/tests/test_config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,52 @@ def test_flatten_dict(self) -> None:
"a": {"b": 1},
"l": [1, 2, 3],
"ld": [{"a": 1, "b": {"bb": 2}, "c": [11, 22]}, {"z": "xyz"}],
"ldc": [
{
"_base_": "parent_base",
"x": 1,
"child": {
"_base_": "child_base",
"yl": [2, 3],
"yd": {"z": "4"},
"yld": [{"z": "4"}],
"yc": [{"_base_": "gc_base", "dummy": "one"}],
},
},
{
"_base_": "parent_base",
"x": 11,
"child": {
"_base_": "child_base",
"yl": [12, 13],
"yd": {"z": "14"},
"yld": [{"z": "14"}],
"yc": [{"_base_": "gc_base", "dummy": "two"}],
},
},
],
}
),
{
"a.b": 1,
"l": [1, 2, 3],
"ld": [{"a": 1, "b": {"bb": 2}, "c": [11, 22]}, {"z": "xyz"}],
"ldc.0._base_": "parent_base",
"ldc.0.x": 1,
"ldc.0.child._base_": "child_base",
"ldc.0.child.yl": [2, 3],
"ldc.0.child.yd.z": '"4"',
"ldc.0.child.yld": [{"z": "4"}],
"ldc.0.child.yc.0._base_": "gc_base",
"ldc.0.child.yc.0.dummy": "one",
"ldc.1._base_": "parent_base",
"ldc.1.x": 11,
"ldc.1.child._base_": "child_base",
"ldc.1.child.yl": [12, 13],
"ldc.1.child.yd.z": '"14"',
"ldc.1.child.yld": [{"z": "14"}],
"ldc.1.child.yc.0._base_": "gc_base",
"ldc.1.child.yc.0.dummy": "two",
},
)

Expand Down Expand Up @@ -226,19 +266,59 @@ def test_json_to_dotlist_append_or_override(self) -> None:
# checks string floats
assertEqual(fl_json_to_dotlist({"e": "5.5"}), ['++e="5.5"'])

# make sure json in list remains untouched
# check handling of list values
assertEqual(
fl_json_to_dotlist(
{
"a": {"b": 1},
"l": [1, 2, 3],
"ld": [{"a": 1, "b": {"bb": 2}, "c": [11, 22]}, {"z": "xyz"}],
"ldc": [
{
"_base_": "parent_base",
"x": 1,
"child": {
"_base_": "child_base",
"yl": [2, 3],
"yd": {"z": "4"},
"yld": [{"z": "4"}],
"yc": [{"_base_": "gc_base", "dummy": "one"}],
},
},
{
"_base_": "parent_base",
"x": 11,
"child": {
"_base_": "child_base",
"yl": [12, 13],
"yd": {"z": "14"},
"yld": [{"z": "14"}],
"yc": [{"_base_": "gc_base", "dummy": "two"}],
},
},
],
}
),
[
"+ldc@ldc.0=parent_base",
"+ldc@ldc.1=parent_base",
"+child@ldc.0.child=child_base",
"+child@ldc.1.child=child_base",
"+yc@ldc.0.child.yc.0=gc_base",
"+yc@ldc.1.child.yc.0=gc_base",
"++l=[1, 2, 3]",
"++ld=[{'a': 1, 'b': {'bb': 2}, 'c': [11, 22]}, {'z': 'xyz'}]",
"++a.b=1",
"++ldc.0.x=1",
"++ldc.1.x=11",
"++ldc.0.child.yl=[2, 3]",
"++ldc.0.child.yld=[{'z': '4'}]",
"++ldc.1.child.yl=[12, 13]",
"++ldc.1.child.yld=[{'z': '14'}]",
'++ldc.0.child.yd.z="4"',
'++ldc.1.child.yd.z="14"',
"++ldc.0.child.yc.0.dummy=one",
"++ldc.1.child.yc.0.dummy=two",
],
)

Expand Down

0 comments on commit fb884b3

Please sign in to comment.