This repository has been archived by the owner on Sep 18, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Prettify the export format of NAS trainer #2389
Merged
ultmaster
merged 8 commits into
microsoft:master
from
ultmaster:nas-human-friendly-export
May 11, 2020
Merged
Changes from 3 commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
9a2243f
update
ultmaster 39922d9
fix a few bugs
ultmaster a679bf2
Merge branch 'master' into nas-human-friendly-export
ultmaster c720d16
fix pylint and add comments
ultmaster 601ed83
Merge branch 'nas-human-friendly-export' of github.com:ultmaster/nni …
ultmaster e9e79d2
update docs
ultmaster 3ff5b72
update
ultmaster 04a0cf2
update
ultmaster File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,8 +5,9 @@ | |
|
||
import torch | ||
|
||
from nni.nas.pytorch.mutables import MutableScope | ||
from nni.nas.pytorch.mutator import Mutator | ||
from .mutables import InputChoice, LayerChoice, MutableScope | ||
from .mutator import Mutator | ||
from .utils import to_list | ||
|
||
|
||
class FixedArchitecture(Mutator): | ||
|
@@ -17,8 +18,8 @@ class FixedArchitecture(Mutator): | |
---------- | ||
model : nn.Module | ||
A mutable network. | ||
fixed_arc : str or dict | ||
Path to the architecture checkpoint (a string), or preloaded architecture object (a dict). | ||
fixed_arc : dict | ||
Preloaded architecture object. | ||
strict : bool | ||
Force everything that appears in ``fixed_arc`` to be used at least once. | ||
""" | ||
|
@@ -33,6 +34,32 @@ def __init__(self, model, fixed_arc, strict=True): | |
raise RuntimeError("Unexpected keys found in fixed architecture: {}.".format(fixed_arc_keys - mutable_keys)) | ||
if mutable_keys - fixed_arc_keys: | ||
raise RuntimeError("Missing keys in fixed architecture: {}.".format(mutable_keys - fixed_arc_keys)) | ||
self._fixed_arc = self._convert_human_readable_architecture(self._fixed_arc) | ||
|
||
def _convert_human_readable_architecture(self, human_arc): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please provide docstring for this function, though it is private. |
||
result_arc = {k: to_list(v) for k, v in human_arc.items()} # there could be tensors, numpy arrays, etc. | ||
# First, convert non-list to list. | ||
result_arc = {k: v if isinstance(v, list) else [v] for k, v in result_arc.items()} | ||
QuanluZhang marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# Second, infer which ones are multi-hot arrays and which ones are in human-readable format. | ||
# This is non-trivial, since if an array in [0, 1], we cannot know for sure it means [false, true] or [true, true]. | ||
# Here, we assume an multihot array has to be a boolean array or a float array and matches the length. | ||
QuanluZhang marked this conversation as resolved.
Show resolved
Hide resolved
|
||
for mutable in self.mutables: | ||
if mutable.key not in result_arc: | ||
continue # skip silently | ||
choice_arr = result_arc[mutable.key] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what is the meaning of "arr"? |
||
if all(isinstance(v, bool) for v in choice_arr) or all(isinstance(v, float) for v in choice_arr): | ||
if (isinstance(mutable, LayerChoice) and len(mutable) == len(choice_arr)) or \ | ||
(isinstance(mutable, InputChoice) and mutable.n_candidates == len(choice_arr)): | ||
# multihot, do nothing | ||
continue | ||
if isinstance(mutable, LayerChoice): | ||
choice_arr = [mutable.names.index(val) if isinstance(val, str) else val for val in choice_arr] | ||
choice_arr = [i in choice_arr for i in range(len(mutable))] | ||
elif isinstance(mutable, InputChoice): | ||
choice_arr = [mutable.choose_from.index(val) if isinstance(val, str) else val for val in choice_arr] | ||
choice_arr = [i in choice_arr for i in range(mutable.n_candidates)] | ||
result_arc[mutable.key] = choice_arr | ||
return result_arc | ||
|
||
def sample_search(self): | ||
""" | ||
|
@@ -47,17 +74,6 @@ def sample_final(self): | |
return self._fixed_arc | ||
|
||
|
||
def _encode_tensor(data): | ||
if isinstance(data, list): | ||
if all(map(lambda o: isinstance(o, bool), data)): | ||
return torch.tensor(data, dtype=torch.bool) # pylint: disable=not-callable | ||
else: | ||
return torch.tensor(data, dtype=torch.float) # pylint: disable=not-callable | ||
if isinstance(data, dict): | ||
return {k: _encode_tensor(v) for k, v in data.items()} | ||
return data | ||
|
||
|
||
def apply_fixed_architecture(model, fixed_arc): | ||
""" | ||
Load architecture from `fixed_arc` and apply to model. | ||
|
@@ -78,7 +94,6 @@ def apply_fixed_architecture(model, fixed_arc): | |
if isinstance(fixed_arc, str): | ||
with open(fixed_arc) as f: | ||
fixed_arc = json.load(f) | ||
fixed_arc = _encode_tensor(fixed_arc) | ||
architecture = FixedArchitecture(model, fixed_arc) | ||
architecture.reset() | ||
return architecture |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,7 +7,9 @@ | |
import numpy as np | ||
import torch | ||
|
||
from nni.nas.pytorch.base_mutator import BaseMutator | ||
from .base_mutator import BaseMutator | ||
from .mutables import LayerChoice, InputChoice | ||
from .utils import to_list | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
@@ -58,7 +60,16 @@ def export(self): | |
dict | ||
A mapping from key of mutables to decisions. | ||
""" | ||
return self.sample_final() | ||
sampled = self.sample_final() | ||
result = dict() | ||
for mutable in self.mutables: | ||
if not isinstance(mutable, (LayerChoice, InputChoice)): | ||
# not supported as built-in | ||
continue | ||
result[mutable.key] = self._convert_mutable_decision(mutable, sampled.pop(mutable.key)) | ||
if sampled: | ||
raise ValueError("Unexpected keys returned from 'sample_final()': %s", list(sampled.keys())) | ||
return result | ||
|
||
def status(self): | ||
""" | ||
|
@@ -159,7 +170,7 @@ def _map_fn(op, args, kwargs): | |
mask = self._get_decision(mutable) | ||
assert len(mask) == len(mutable), \ | ||
"Invalid mask, expected {} to be of length {}.".format(mask, len(mutable)) | ||
out = self._select_with_mask(_map_fn, [(choice, args, kwargs) for choice in mutable], mask) | ||
out, mask = self._select_with_mask(_map_fn, [(choice, args, kwargs) for choice in mutable], mask) | ||
return self._tensor_reduction(mutable.reduction, out), mask | ||
|
||
def on_forward_input_choice(self, mutable, tensor_list): | ||
|
@@ -185,17 +196,23 @@ def on_forward_input_choice(self, mutable, tensor_list): | |
mask = self._get_decision(mutable) | ||
assert len(mask) == mutable.n_candidates, \ | ||
"Invalid mask, expected {} to be of length {}.".format(mask, mutable.n_candidates) | ||
out = self._select_with_mask(lambda x: x, [(t,) for t in tensor_list], mask) | ||
out, mask = self._select_with_mask(lambda x: x, [(t,) for t in tensor_list], mask) | ||
return self._tensor_reduction(mutable.reduction, out), mask | ||
|
||
def _select_with_mask(self, map_fn, candidates, mask): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the new implementation of this function has complex logic, please add docstring for this function. |
||
if "BoolTensor" in mask.type(): | ||
if (isinstance(mask, list) and len(mask) >= 1 and isinstance(mask[0], bool)) or \ | ||
(isinstance(mask, np.ndarray) and mask.dtype == np.bool) or \ | ||
"BoolTensor" in mask.type(): | ||
out = [map_fn(*cand) for cand, m in zip(candidates, mask) if m] | ||
elif "FloatTensor" in mask.type(): | ||
elif (isinstance(mask, list) and len(mask) >= 1 and isinstance(mask[0], (float, int))) or \ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it can be |
||
(isinstance(mask, np.ndarray) and mask.dtype in (np.float32, np.float64, np.int32, np.int64)) or \ | ||
"FloatTensor" in mask.type(): | ||
out = [map_fn(*cand) * m for cand, m in zip(candidates, mask) if m] | ||
else: | ||
raise ValueError("Unrecognized mask") | ||
return out | ||
raise ValueError("Unrecognized mask '%s'" % mask) | ||
if not torch.is_tensor(mask): | ||
mask = torch.tensor(mask) | ||
return out, mask | ||
|
||
def _tensor_reduction(self, reduction_type, tensor_list): | ||
if reduction_type == "none": | ||
|
@@ -237,3 +254,37 @@ def _get_decision(self, mutable): | |
result = self._cache[mutable.key] | ||
logger.debug("Decision %s: %s", mutable.key, result) | ||
return result | ||
|
||
def _convert_mutable_decision(self, mutable, sampled): | ||
# Assert the existence of mutable.key in returned architecture. | ||
# Also check if there is anything extra. | ||
multihot_list = to_list(sampled) | ||
converted = None | ||
# If it's a boolean array, we can do optimization. | ||
if all([t == 0 or t == 1 for t in multihot_list]): | ||
if isinstance(mutable, LayerChoice): | ||
assert len(multihot_list) == len(mutable), \ | ||
"Results returned from 'sample_final()' (%s: %s) either too short or too long." \ | ||
% (mutable.key, multihot_list) | ||
# check if all modules have different names and they indeed have names | ||
if len(set(mutable.names)) == len(mutable) and not all(d.isdigit() for d in mutable.names): | ||
converted = [name for i, name in enumerate(mutable.names) if multihot_list[i]] | ||
else: | ||
converted = [i for i in range(len(multihot_list)) if multihot_list[i]] | ||
if isinstance(mutable, InputChoice): | ||
assert len(multihot_list) == mutable.n_candidates, \ | ||
"Results returned from 'sample_final()' (%s: %s) either too short or too long." \ | ||
% (mutable.key, multihot_list) | ||
# check if all input candidates have different names | ||
if len(set(mutable.choose_from)) == mutable.n_candidates: | ||
converted = [name for i, name in enumerate(mutable.choose_from) if multihot_list[i]] | ||
else: | ||
converted = [i for i in range(len(multihot_list)) if multihot_list[i]] | ||
if converted is not None: | ||
# if only one element, then remove the bracket | ||
if len(converted) == 1: | ||
converted = converted[0] | ||
else: | ||
# do nothing | ||
converted = multihot_list | ||
return converted |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this function is convert to or from human readable architecture?