diff --git a/docs/en_US/NAS/NasGuide.md b/docs/en_US/NAS/NasGuide.md index 6773d28e14..35a2fda442 100644 --- a/docs/en_US/NAS/NasGuide.md +++ b/docs/en_US/NAS/NasGuide.md @@ -156,12 +156,23 @@ model = Net() apply_fixed_architecture(model, "model_dir/final_architecture.json") ``` -The JSON is simply a mapping from mutable keys to one-hot or multi-hot representation of choices. For example +The JSON is simply a mapping from mutable keys to choices. Choices can be expressed in: + +* A string: select the candidate with corresponding name. +* A number: select the candidate with corresponding index. +* A list of string: select the candidates with corresponding names. +* A list of number: select the candidates with corresponding indices. +* A list of boolean values: a multi-hot array. + +For example, ```json { - "LayerChoice1": [false, true, false, false], - "InputChoice2": [true, true, false] + "LayerChoice1": "conv5x5", + "LayerChoice2": 6, + "InputChoice3": ["layer1", "layer3"], + "InputChoice4": [1, 2], + "InputChoice5": [false, true, false, false, true] } ``` diff --git a/src/sdk/pynni/nni/nas/pytorch/fixed.py b/src/sdk/pynni/nni/nas/pytorch/fixed.py index 0be4e0ea79..106368128c 100644 --- a/src/sdk/pynni/nni/nas/pytorch/fixed.py +++ b/src/sdk/pynni/nni/nas/pytorch/fixed.py @@ -3,10 +3,9 @@ import json -import torch - -from nni.nas.pytorch.mutables import MutableScope -from nni.nas.pytorch.mutator import Mutator +from .mutables import InputChoice, LayerChoice, MutableScope +from .mutator import Mutator +from .utils import to_list class FixedArchitecture(Mutator): @@ -17,8 +16,8 @@ class FixedArchitecture(Mutator): ---------- model : nn.Module A mutable network. - fixed_arc : str or dict - Path to the architecture checkpoint (a string), or preloaded architecture object (a dict). + fixed_arc : dict + Preloaded architecture object. strict : bool Force everything that appears in ``fixed_arc`` to be used at least once. """ @@ -33,6 +32,34 @@ def __init__(self, model, fixed_arc, strict=True): raise RuntimeError("Unexpected keys found in fixed architecture: {}.".format(fixed_arc_keys - mutable_keys)) if mutable_keys - fixed_arc_keys: raise RuntimeError("Missing keys in fixed architecture: {}.".format(mutable_keys - fixed_arc_keys)) + self._fixed_arc = self._from_human_readable_architecture(self._fixed_arc) + + def _from_human_readable_architecture(self, human_arc): + # convert from an exported architecture + result_arc = {k: to_list(v) for k, v in human_arc.items()} # there could be tensors, numpy arrays, etc. + # First, convert non-list to list, because there could be {"op1": 0} or {"op1": "conv"}, + # which means {"op1": [0, ]} ir {"op1": ["conv", ]} + result_arc = {k: v if isinstance(v, list) else [v] for k, v in result_arc.items()} + # Second, infer which ones are multi-hot arrays and which ones are in human-readable format. + # This is non-trivial, since if an array in [0, 1], we cannot know for sure it means [false, true] or [true, true]. + # Here, we assume an multihot array has to be a boolean array or a float array and matches the length. + for mutable in self.mutables: + if mutable.key not in result_arc: + continue # skip silently + choice_arr = result_arc[mutable.key] + if all(isinstance(v, bool) for v in choice_arr) or all(isinstance(v, float) for v in choice_arr): + if (isinstance(mutable, LayerChoice) and len(mutable) == len(choice_arr)) or \ + (isinstance(mutable, InputChoice) and mutable.n_candidates == len(choice_arr)): + # multihot, do nothing + continue + if isinstance(mutable, LayerChoice): + choice_arr = [mutable.names.index(val) if isinstance(val, str) else val for val in choice_arr] + choice_arr = [i in choice_arr for i in range(len(mutable))] + elif isinstance(mutable, InputChoice): + choice_arr = [mutable.choose_from.index(val) if isinstance(val, str) else val for val in choice_arr] + choice_arr = [i in choice_arr for i in range(mutable.n_candidates)] + result_arc[mutable.key] = choice_arr + return result_arc def sample_search(self): """ @@ -47,17 +74,6 @@ def sample_final(self): return self._fixed_arc -def _encode_tensor(data): - if isinstance(data, list): - if all(map(lambda o: isinstance(o, bool), data)): - return torch.tensor(data, dtype=torch.bool) # pylint: disable=not-callable - else: - return torch.tensor(data, dtype=torch.float) # pylint: disable=not-callable - if isinstance(data, dict): - return {k: _encode_tensor(v) for k, v in data.items()} - return data - - def apply_fixed_architecture(model, fixed_arc): """ Load architecture from `fixed_arc` and apply to model. @@ -78,7 +94,6 @@ def apply_fixed_architecture(model, fixed_arc): if isinstance(fixed_arc, str): with open(fixed_arc) as f: fixed_arc = json.load(f) - fixed_arc = _encode_tensor(fixed_arc) architecture = FixedArchitecture(model, fixed_arc) architecture.reset() return architecture diff --git a/src/sdk/pynni/nni/nas/pytorch/mutator.py b/src/sdk/pynni/nni/nas/pytorch/mutator.py index 160a20de84..e9cc68857a 100644 --- a/src/sdk/pynni/nni/nas/pytorch/mutator.py +++ b/src/sdk/pynni/nni/nas/pytorch/mutator.py @@ -7,7 +7,9 @@ import numpy as np import torch -from nni.nas.pytorch.base_mutator import BaseMutator +from .base_mutator import BaseMutator +from .mutables import LayerChoice, InputChoice +from .utils import to_list logger = logging.getLogger(__name__) @@ -58,7 +60,16 @@ def export(self): dict A mapping from key of mutables to decisions. """ - return self.sample_final() + sampled = self.sample_final() + result = dict() + for mutable in self.mutables: + if not isinstance(mutable, (LayerChoice, InputChoice)): + # not supported as built-in + continue + result[mutable.key] = self._convert_mutable_decision_to_human_readable(mutable, sampled.pop(mutable.key)) + if sampled: + raise ValueError("Unexpected keys returned from 'sample_final()': %s", list(sampled.keys())) + return result def status(self): """ @@ -159,7 +170,7 @@ def _map_fn(op, args, kwargs): mask = self._get_decision(mutable) assert len(mask) == len(mutable), \ "Invalid mask, expected {} to be of length {}.".format(mask, len(mutable)) - out = self._select_with_mask(_map_fn, [(choice, args, kwargs) for choice in mutable], mask) + out, mask = self._select_with_mask(_map_fn, [(choice, args, kwargs) for choice in mutable], mask) return self._tensor_reduction(mutable.reduction, out), mask def on_forward_input_choice(self, mutable, tensor_list): @@ -185,17 +196,41 @@ def on_forward_input_choice(self, mutable, tensor_list): mask = self._get_decision(mutable) assert len(mask) == mutable.n_candidates, \ "Invalid mask, expected {} to be of length {}.".format(mask, mutable.n_candidates) - out = self._select_with_mask(lambda x: x, [(t,) for t in tensor_list], mask) + out, mask = self._select_with_mask(lambda x: x, [(t,) for t in tensor_list], mask) return self._tensor_reduction(mutable.reduction, out), mask def _select_with_mask(self, map_fn, candidates, mask): - if "BoolTensor" in mask.type(): + """ + Select masked tensors and return a list of tensors. + + Parameters + ---------- + map_fn : function + Convert candidates to target candidates. Can be simply identity. + candidates : list of torch.Tensor + Tensor list to apply the decision on. + mask : list-like object + Can be a list, an numpy array or a tensor (recommended). Needs to + have the same length as ``candidates``. + + Returns + ------- + tuple of list of torch.Tensor and torch.Tensor + Output and mask. + """ + if (isinstance(mask, list) and len(mask) >= 1 and isinstance(mask[0], bool)) or \ + (isinstance(mask, np.ndarray) and mask.dtype == np.bool) or \ + "BoolTensor" in mask.type(): out = [map_fn(*cand) for cand, m in zip(candidates, mask) if m] - elif "FloatTensor" in mask.type(): + elif (isinstance(mask, list) and len(mask) >= 1 and isinstance(mask[0], (float, int))) or \ + (isinstance(mask, np.ndarray) and mask.dtype in (np.float32, np.float64, np.int32, np.int64)) or \ + "FloatTensor" in mask.type(): out = [map_fn(*cand) * m for cand, m in zip(candidates, mask) if m] else: - raise ValueError("Unrecognized mask") - return out + raise ValueError("Unrecognized mask '%s'" % mask) + if not torch.is_tensor(mask): + mask = torch.tensor(mask) # pylint: disable=not-callable + return out, mask def _tensor_reduction(self, reduction_type, tensor_list): if reduction_type == "none": @@ -237,3 +272,37 @@ def _get_decision(self, mutable): result = self._cache[mutable.key] logger.debug("Decision %s: %s", mutable.key, result) return result + + def _convert_mutable_decision_to_human_readable(self, mutable, sampled): + # Assert the existence of mutable.key in returned architecture. + # Also check if there is anything extra. + multihot_list = to_list(sampled) + converted = None + # If it's a boolean array, we can do optimization. + if all([t == 0 or t == 1 for t in multihot_list]): + if isinstance(mutable, LayerChoice): + assert len(multihot_list) == len(mutable), \ + "Results returned from 'sample_final()' (%s: %s) either too short or too long." \ + % (mutable.key, multihot_list) + # check if all modules have different names and they indeed have names + if len(set(mutable.names)) == len(mutable) and not all(d.isdigit() for d in mutable.names): + converted = [name for i, name in enumerate(mutable.names) if multihot_list[i]] + else: + converted = [i for i in range(len(multihot_list)) if multihot_list[i]] + if isinstance(mutable, InputChoice): + assert len(multihot_list) == mutable.n_candidates, \ + "Results returned from 'sample_final()' (%s: %s) either too short or too long." \ + % (mutable.key, multihot_list) + # check if all input candidates have different names + if len(set(mutable.choose_from)) == mutable.n_candidates: + converted = [name for i, name in enumerate(mutable.choose_from) if multihot_list[i]] + else: + converted = [i for i in range(len(multihot_list)) if multihot_list[i]] + if converted is not None: + # if only one element, then remove the bracket + if len(converted) == 1: + converted = converted[0] + else: + # do nothing + converted = multihot_list + return converted diff --git a/src/sdk/pynni/nni/nas/pytorch/utils.py b/src/sdk/pynni/nni/nas/pytorch/utils.py index 7536740eb3..a3f5aabfb7 100644 --- a/src/sdk/pynni/nni/nas/pytorch/utils.py +++ b/src/sdk/pynni/nni/nas/pytorch/utils.py @@ -4,6 +4,7 @@ import logging from collections import OrderedDict +import numpy as np import torch _counter = 0 @@ -45,6 +46,16 @@ def to_device(obj, device): raise ValueError("'%s' has unsupported type '%s'" % (obj, type(obj))) +def to_list(arr): + if torch.is_tensor(arr): + return arr.cpu().numpy().tolist() + if isinstance(arr, np.ndarray): + return arr.tolist() + if isinstance(arr, (list, tuple)): + return list(arr) + return arr + + class AverageMeterGroup: """ Average meter group for multiple average meters.