Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
Prettify the export format of NAS trainer (#2389)
Browse files Browse the repository at this point in the history
  • Loading branch information
ultmaster authored May 11, 2020
1 parent af80021 commit bf7daa8
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 29 deletions.
17 changes: 14 additions & 3 deletions docs/en_US/NAS/NasGuide.md
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,23 @@ model = Net()
apply_fixed_architecture(model, "model_dir/final_architecture.json")
```

The JSON is simply a mapping from mutable keys to one-hot or multi-hot representation of choices. For example
The JSON is simply a mapping from mutable keys to choices. Choices can be expressed in:

* A string: select the candidate with corresponding name.
* A number: select the candidate with corresponding index.
* A list of string: select the candidates with corresponding names.
* A list of number: select the candidates with corresponding indices.
* A list of boolean values: a multi-hot array.

For example,

```json
{
"LayerChoice1": [false, true, false, false],
"InputChoice2": [true, true, false]
"LayerChoice1": "conv5x5",
"LayerChoice2": 6,
"InputChoice3": ["layer1", "layer3"],
"InputChoice4": [1, 2],
"InputChoice5": [false, true, false, false, true]
}
```

Expand Down
51 changes: 33 additions & 18 deletions src/sdk/pynni/nni/nas/pytorch/fixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@

import json

import torch

from nni.nas.pytorch.mutables import MutableScope
from nni.nas.pytorch.mutator import Mutator
from .mutables import InputChoice, LayerChoice, MutableScope
from .mutator import Mutator
from .utils import to_list


class FixedArchitecture(Mutator):
Expand All @@ -17,8 +16,8 @@ class FixedArchitecture(Mutator):
----------
model : nn.Module
A mutable network.
fixed_arc : str or dict
Path to the architecture checkpoint (a string), or preloaded architecture object (a dict).
fixed_arc : dict
Preloaded architecture object.
strict : bool
Force everything that appears in ``fixed_arc`` to be used at least once.
"""
Expand All @@ -33,6 +32,34 @@ def __init__(self, model, fixed_arc, strict=True):
raise RuntimeError("Unexpected keys found in fixed architecture: {}.".format(fixed_arc_keys - mutable_keys))
if mutable_keys - fixed_arc_keys:
raise RuntimeError("Missing keys in fixed architecture: {}.".format(mutable_keys - fixed_arc_keys))
self._fixed_arc = self._from_human_readable_architecture(self._fixed_arc)

def _from_human_readable_architecture(self, human_arc):
# convert from an exported architecture
result_arc = {k: to_list(v) for k, v in human_arc.items()} # there could be tensors, numpy arrays, etc.
# First, convert non-list to list, because there could be {"op1": 0} or {"op1": "conv"},
# which means {"op1": [0, ]} ir {"op1": ["conv", ]}
result_arc = {k: v if isinstance(v, list) else [v] for k, v in result_arc.items()}
# Second, infer which ones are multi-hot arrays and which ones are in human-readable format.
# This is non-trivial, since if an array in [0, 1], we cannot know for sure it means [false, true] or [true, true].
# Here, we assume an multihot array has to be a boolean array or a float array and matches the length.
for mutable in self.mutables:
if mutable.key not in result_arc:
continue # skip silently
choice_arr = result_arc[mutable.key]
if all(isinstance(v, bool) for v in choice_arr) or all(isinstance(v, float) for v in choice_arr):
if (isinstance(mutable, LayerChoice) and len(mutable) == len(choice_arr)) or \
(isinstance(mutable, InputChoice) and mutable.n_candidates == len(choice_arr)):
# multihot, do nothing
continue
if isinstance(mutable, LayerChoice):
choice_arr = [mutable.names.index(val) if isinstance(val, str) else val for val in choice_arr]
choice_arr = [i in choice_arr for i in range(len(mutable))]
elif isinstance(mutable, InputChoice):
choice_arr = [mutable.choose_from.index(val) if isinstance(val, str) else val for val in choice_arr]
choice_arr = [i in choice_arr for i in range(mutable.n_candidates)]
result_arc[mutable.key] = choice_arr
return result_arc

def sample_search(self):
"""
Expand All @@ -47,17 +74,6 @@ def sample_final(self):
return self._fixed_arc


def _encode_tensor(data):
if isinstance(data, list):
if all(map(lambda o: isinstance(o, bool), data)):
return torch.tensor(data, dtype=torch.bool) # pylint: disable=not-callable
else:
return torch.tensor(data, dtype=torch.float) # pylint: disable=not-callable
if isinstance(data, dict):
return {k: _encode_tensor(v) for k, v in data.items()}
return data


def apply_fixed_architecture(model, fixed_arc):
"""
Load architecture from `fixed_arc` and apply to model.
Expand All @@ -78,7 +94,6 @@ def apply_fixed_architecture(model, fixed_arc):
if isinstance(fixed_arc, str):
with open(fixed_arc) as f:
fixed_arc = json.load(f)
fixed_arc = _encode_tensor(fixed_arc)
architecture = FixedArchitecture(model, fixed_arc)
architecture.reset()
return architecture
85 changes: 77 additions & 8 deletions src/sdk/pynni/nni/nas/pytorch/mutator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
import numpy as np
import torch

from nni.nas.pytorch.base_mutator import BaseMutator
from .base_mutator import BaseMutator
from .mutables import LayerChoice, InputChoice
from .utils import to_list

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -58,7 +60,16 @@ def export(self):
dict
A mapping from key of mutables to decisions.
"""
return self.sample_final()
sampled = self.sample_final()
result = dict()
for mutable in self.mutables:
if not isinstance(mutable, (LayerChoice, InputChoice)):
# not supported as built-in
continue
result[mutable.key] = self._convert_mutable_decision_to_human_readable(mutable, sampled.pop(mutable.key))
if sampled:
raise ValueError("Unexpected keys returned from 'sample_final()': %s", list(sampled.keys()))
return result

def status(self):
"""
Expand Down Expand Up @@ -159,7 +170,7 @@ def _map_fn(op, args, kwargs):
mask = self._get_decision(mutable)
assert len(mask) == len(mutable), \
"Invalid mask, expected {} to be of length {}.".format(mask, len(mutable))
out = self._select_with_mask(_map_fn, [(choice, args, kwargs) for choice in mutable], mask)
out, mask = self._select_with_mask(_map_fn, [(choice, args, kwargs) for choice in mutable], mask)
return self._tensor_reduction(mutable.reduction, out), mask

def on_forward_input_choice(self, mutable, tensor_list):
Expand All @@ -185,17 +196,41 @@ def on_forward_input_choice(self, mutable, tensor_list):
mask = self._get_decision(mutable)
assert len(mask) == mutable.n_candidates, \
"Invalid mask, expected {} to be of length {}.".format(mask, mutable.n_candidates)
out = self._select_with_mask(lambda x: x, [(t,) for t in tensor_list], mask)
out, mask = self._select_with_mask(lambda x: x, [(t,) for t in tensor_list], mask)
return self._tensor_reduction(mutable.reduction, out), mask

def _select_with_mask(self, map_fn, candidates, mask):
if "BoolTensor" in mask.type():
"""
Select masked tensors and return a list of tensors.
Parameters
----------
map_fn : function
Convert candidates to target candidates. Can be simply identity.
candidates : list of torch.Tensor
Tensor list to apply the decision on.
mask : list-like object
Can be a list, an numpy array or a tensor (recommended). Needs to
have the same length as ``candidates``.
Returns
-------
tuple of list of torch.Tensor and torch.Tensor
Output and mask.
"""
if (isinstance(mask, list) and len(mask) >= 1 and isinstance(mask[0], bool)) or \
(isinstance(mask, np.ndarray) and mask.dtype == np.bool) or \
"BoolTensor" in mask.type():
out = [map_fn(*cand) for cand, m in zip(candidates, mask) if m]
elif "FloatTensor" in mask.type():
elif (isinstance(mask, list) and len(mask) >= 1 and isinstance(mask[0], (float, int))) or \
(isinstance(mask, np.ndarray) and mask.dtype in (np.float32, np.float64, np.int32, np.int64)) or \
"FloatTensor" in mask.type():
out = [map_fn(*cand) * m for cand, m in zip(candidates, mask) if m]
else:
raise ValueError("Unrecognized mask")
return out
raise ValueError("Unrecognized mask '%s'" % mask)
if not torch.is_tensor(mask):
mask = torch.tensor(mask) # pylint: disable=not-callable
return out, mask

def _tensor_reduction(self, reduction_type, tensor_list):
if reduction_type == "none":
Expand Down Expand Up @@ -237,3 +272,37 @@ def _get_decision(self, mutable):
result = self._cache[mutable.key]
logger.debug("Decision %s: %s", mutable.key, result)
return result

def _convert_mutable_decision_to_human_readable(self, mutable, sampled):
# Assert the existence of mutable.key in returned architecture.
# Also check if there is anything extra.
multihot_list = to_list(sampled)
converted = None
# If it's a boolean array, we can do optimization.
if all([t == 0 or t == 1 for t in multihot_list]):
if isinstance(mutable, LayerChoice):
assert len(multihot_list) == len(mutable), \
"Results returned from 'sample_final()' (%s: %s) either too short or too long." \
% (mutable.key, multihot_list)
# check if all modules have different names and they indeed have names
if len(set(mutable.names)) == len(mutable) and not all(d.isdigit() for d in mutable.names):
converted = [name for i, name in enumerate(mutable.names) if multihot_list[i]]
else:
converted = [i for i in range(len(multihot_list)) if multihot_list[i]]
if isinstance(mutable, InputChoice):
assert len(multihot_list) == mutable.n_candidates, \
"Results returned from 'sample_final()' (%s: %s) either too short or too long." \
% (mutable.key, multihot_list)
# check if all input candidates have different names
if len(set(mutable.choose_from)) == mutable.n_candidates:
converted = [name for i, name in enumerate(mutable.choose_from) if multihot_list[i]]
else:
converted = [i for i in range(len(multihot_list)) if multihot_list[i]]
if converted is not None:
# if only one element, then remove the bracket
if len(converted) == 1:
converted = converted[0]
else:
# do nothing
converted = multihot_list
return converted
11 changes: 11 additions & 0 deletions src/sdk/pynni/nni/nas/pytorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
from collections import OrderedDict

import numpy as np
import torch

_counter = 0
Expand Down Expand Up @@ -45,6 +46,16 @@ def to_device(obj, device):
raise ValueError("'%s' has unsupported type '%s'" % (obj, type(obj)))


def to_list(arr):
if torch.is_tensor(arr):
return arr.cpu().numpy().tolist()
if isinstance(arr, np.ndarray):
return arr.tolist()
if isinstance(arr, (list, tuple)):
return list(arr)
return arr


class AverageMeterGroup:
"""
Average meter group for multiple average meters.
Expand Down

0 comments on commit bf7daa8

Please sign in to comment.