Support loading from state_dict of supernet #4544
@@ -304,6 +304,8 @@ def make_list(x): return x if isinstance(x, list) else [x]
[op_candidates[selected[f'{label}/op{i}']] for i in range(1, num_nodes - 1)],
adjacency_list, in_features, out_features, num_nodes, projection)

# FIXME: weight inheritance on nasbench101 is not supported yet
Why is nasbench101 different?
NAS-Bench-101 has a customized recipe to create a fixed cell, which includes removing unused nodes and reordering nodes to make the cell more compact. This makes supporting weight sharing itself on NAS-Bench-101 difficult, not to mention loading from super-net.
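To make the difficulty concrete, here is a minimal sketch of why pruning and renumbering breaks name-based weight lookup. It is not the actual NAS-Bench-101 recipe: `prune_cell` and the edge-list representation are assumptions made for illustration only. The point is that once an unused node is dropped, the surviving nodes get new indices, so a name such as `op3` in the super-net no longer refers to the same position in the fixed cell.

```python
def prune_cell(num_nodes, edges):
    """Hypothetical helper: keep only nodes that lie on some path from the
    input (node 0) to the output (node num_nodes - 1), then renumber them."""
    def reachable(start, neighbours):
        seen, stack = {start}, [start]
        while stack:
            node = stack.pop()
            for nxt in neighbours.get(node, []):
                if nxt not in seen:
                    seen.add(nxt)
                    stack.append(nxt)
        return seen

    forward, backward = {}, {}
    for src, dst in edges:
        forward.setdefault(src, []).append(dst)
        backward.setdefault(dst, []).append(src)

    # A node is useful only if it is reachable from the input AND reaches the output.
    kept = sorted(reachable(0, forward) & reachable(num_nodes - 1, backward))
    new_index = {old: new for new, old in enumerate(kept)}
    new_edges = [(new_index[s], new_index[d])
                 for s, d in edges if s in new_index and d in new_index]
    return new_index, new_edges


# Five nodes; node 2 is dangling because it never reaches the output node 4.
edges = [(0, 1), (0, 2), (1, 3), (3, 4)]
new_index, new_edges = prune_cell(5, edges)
```

In this toy cell, the node that was number 3 ends up as number 2 after pruning, so any weight keyed by the old index would have to be remapped before it could be inherited.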
So does NAS-Bench-101 support weight sharing?
To double-check: does the graph engine support weight-sharing NAS?
> So does NAS-Bench-101 support weight sharing?
Currently, no.
> To double-check: does the graph engine support weight-sharing NAS?
The graph engine and weight sharing are unrelated components, so there is no notion of one "supporting" the other.
nni/retiarii/nn/pytorch/api.py
Outdated
# map the named hierarchies to support weight inheritance for python engine
if hasattr(result, STATE_DICT_PY_MAPPING_PARTIAL):
    # already has a mapping, will merge with it
    prev_mapping = getattr(result, STATE_DICT_PY_MAPPING_PARTIAL)
    setattr(result, STATE_DICT_PY_MAPPING_PARTIAL, {k: f'{chosen}.{v}' for k, v in prev_mapping.items()})
else:
    setattr(result, STATE_DICT_PY_MAPPING_PARTIAL, {'': str(chosen)})
return result
I can't see how this part connects to the supernet and to state dict load/dump. Maybe you could walk through this PR in a quick discussion.
The modules have an attribute called STATE_DICT_PY_MAPPING_PARTIAL, which records which module each variable originally belonged to. The hooks then use those attributes to recover the full mapping, so that the original state dict can be mapped onto the current module hierarchy.
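As a rough illustration of that last step, the sketch below remaps the keys of a supernet state dict onto a fixed model once the full mapping is known. It uses plain dictionaries rather than real modules or hooks; `remap_state_dict`, the key names, and the placeholder string values are all hypothetical and are not taken from the PR.

```python
def remap_state_dict(supernet_state, full_mapping):
    """Hypothetical helper. `full_mapping` maps a module name in the fixed
    model to the module name it had inside the supernet."""
    remapped = {}
    for fixed_prefix, super_prefix in full_mapping.items():
        for key, value in supernet_state.items():
            # Match the module itself or anything nested below it.
            if key == super_prefix or key.startswith(super_prefix + "."):
                suffix = key[len(super_prefix):]
                remapped[fixed_prefix + suffix] = value
    return remapped


# Supernet: `block` is a layer choice with two candidates, 0 and 1.
supernet_state = {
    "block.0.weight": "w_conv3x3",
    "block.1.weight": "w_conv5x5",
    "block.1.bias": "b_conv5x5",
    "fc.weight": "w_fc",
}
# Fixed model: candidate 1 was chosen, so `block` now is that candidate.
full_mapping = {"block": "block.1", "fc": "fc"}
fixed_state = remap_state_dict(supernet_state, full_mapping)
```

Weights of the candidate that was not chosen (`block.0.*`) are simply dropped, and the chosen candidate's weights appear under the shorter names the fixed model expects.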
return candidates[chosen]
result = candidates[chosen]

# map the named hierarchies to support weight inheritance for python engine
Value choice is not supported yet, right?
Value choices are just "values"; there are no weights to load into a value choice.
nni/retiarii/nn/pytorch/api.py
Outdated
    prev_mapping = getattr(result, STATE_DICT_PY_MAPPING_PARTIAL)
    setattr(result, STATE_DICT_PY_MAPPING_PARTIAL, {k: f'{chosen}.{v}' for k, v in prev_mapping.items()})
else:
    setattr(result, STATE_DICT_PY_MAPPING_PARTIAL, {'': str(chosen)})
I suggest giving an example here; it would greatly help understanding.
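One possible example, as a sketch only: the snippet below replays the branch logic from the diff on plain Python objects to show how the partial mapping composes when choices are nested. The attribute name string, the `Candidate` class, and the `pick` helper are hypothetical stand-ins, not the names used in NNI.

```python
# Hypothetical attribute name; the real constant's value is not shown in the diff.
STATE_DICT_PY_MAPPING_PARTIAL = "_partial_mapping"


class Candidate:
    """Stand-in for a candidate module inside a layer choice."""


def pick(candidates, chosen):
    """Mirror of the diff: select a candidate and record where it came from."""
    result = candidates[chosen]
    if hasattr(result, STATE_DICT_PY_MAPPING_PARTIAL):
        # Already has a mapping (nested choice): prepend the outer choice.
        prev_mapping = getattr(result, STATE_DICT_PY_MAPPING_PARTIAL)
        setattr(result, STATE_DICT_PY_MAPPING_PARTIAL,
                {k: f"{chosen}.{v}" for k, v in prev_mapping.items()})
    else:
        # First choice seen: the module itself ('') came from `chosen`.
        setattr(result, STATE_DICT_PY_MAPPING_PARTIAL, {"": str(chosen)})
    return result


# Inner choice keyed by name, outer choice keyed by position.
inner = pick({"conv3x3": Candidate(), "conv5x5": Candidate()}, "conv5x5")
outer = pick([Candidate(), inner], 1)
mapping = getattr(outer, STATE_DICT_PY_MAPPING_PARTIAL)
```

Here the module that survives both choices records that it originally lived at `1.conv5x5` relative to the outer choice, which is the kind of information a load hook would need to find its weights in the supernet's state dict.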
Description
Add a state dict hook that supports loading from a supernet's state dict.
Checklist
How to test