-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Prettify the export format of NAS trainer #2389
Prettify the export format of NAS trainer #2389
Conversation
please fix pylint error |
please update doc accordingly |
…into nas-human-friendly-export
docs/en_US/NAS/NasGuide.md
Outdated
"LayerChoice1": [false, true, false, false], | ||
"InputChoice2": [true, true, false] | ||
"LayerChoice1": "conv5x5", | ||
"InputChoice2": [1, 2], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what is the meaning of 1, 2? the index?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
so it could be either index or name? when it is index and when it is name?
@@ -33,6 +32,33 @@ def __init__(self, model, fixed_arc, strict=True): | |||
raise RuntimeError("Unexpected keys found in fixed architecture: {}.".format(fixed_arc_keys - mutable_keys)) | |||
if mutable_keys - fixed_arc_keys: | |||
raise RuntimeError("Missing keys in fixed architecture: {}.".format(mutable_keys - fixed_arc_keys)) | |||
self._fixed_arc = self._convert_human_readable_architecture(self._fixed_arc) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this function is convert to or from human readable architecture?
for mutable in self.mutables: | ||
if mutable.key not in result_arc: | ||
continue # skip silently | ||
choice_arr = result_arc[mutable.key] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what is the meaning of "arr"?
@@ -33,6 +32,33 @@ def __init__(self, model, fixed_arc, strict=True): | |||
raise RuntimeError("Unexpected keys found in fixed architecture: {}.".format(fixed_arc_keys - mutable_keys)) | |||
if mutable_keys - fixed_arc_keys: | |||
raise RuntimeError("Missing keys in fixed architecture: {}.".format(mutable_keys - fixed_arc_keys)) | |||
self._fixed_arc = self._convert_human_readable_architecture(self._fixed_arc) | |||
|
|||
def _convert_human_readable_architecture(self, human_arc): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please provide docstring for this function, though it is private.
@@ -185,17 +196,23 @@ def on_forward_input_choice(self, mutable, tensor_list): | |||
mask = self._get_decision(mutable) | |||
assert len(mask) == mutable.n_candidates, \ | |||
"Invalid mask, expected {} to be of length {}.".format(mask, mutable.n_candidates) | |||
out = self._select_with_mask(lambda x: x, [(t,) for t in tensor_list], mask) | |||
out, mask = self._select_with_mask(lambda x: x, [(t,) for t in tensor_list], mask) | |||
return self._tensor_reduction(mutable.reduction, out), mask | |||
|
|||
def _select_with_mask(self, map_fn, candidates, mask): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the new implementation of this function has complex logic, please add docstring for this function.
out = [map_fn(*cand) for cand, m in zip(candidates, mask) if m] | ||
elif "FloatTensor" in mask.type(): | ||
elif (isinstance(mask, list) and len(mask) >= 1 and isinstance(mask[0], (float, int))) or \ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it can be int
?
Pending #2386 as it uses APIs introduced in that PR. I coded on another branch merging #2386 and cherry picked the changes.
Note that this PR will be backward-compatible itself. Old checkpoints will still be valid.
This PR prettifies the export format of NAS trainer (point 2 of #2316). Here are three examples (of the new export format):
P-DARTS (first):
ENAS macro:
ENAS micro: