Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Prettify the export format of NAS trainer #2389

Merged
merged 8 commits into from
May 11, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 31 additions & 16 deletions src/sdk/pynni/nni/nas/pytorch/fixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@

import torch

from nni.nas.pytorch.mutables import MutableScope
from nni.nas.pytorch.mutator import Mutator
from .mutables import InputChoice, LayerChoice, MutableScope
from .mutator import Mutator
from .utils import to_list


class FixedArchitecture(Mutator):
Expand All @@ -17,8 +18,8 @@ class FixedArchitecture(Mutator):
----------
model : nn.Module
A mutable network.
fixed_arc : str or dict
Path to the architecture checkpoint (a string), or preloaded architecture object (a dict).
fixed_arc : dict
Preloaded architecture object.
strict : bool
Force everything that appears in ``fixed_arc`` to be used at least once.
"""
Expand All @@ -33,6 +34,32 @@ def __init__(self, model, fixed_arc, strict=True):
raise RuntimeError("Unexpected keys found in fixed architecture: {}.".format(fixed_arc_keys - mutable_keys))
if mutable_keys - fixed_arc_keys:
raise RuntimeError("Missing keys in fixed architecture: {}.".format(mutable_keys - fixed_arc_keys))
self._fixed_arc = self._convert_human_readable_architecture(self._fixed_arc)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this function convert to or from the human-readable architecture?


def _convert_human_readable_architecture(self, human_arc):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please provide a docstring for this function, even though it is private.

result_arc = {k: to_list(v) for k, v in human_arc.items()} # there could be tensors, numpy arrays, etc.
# First, convert non-list to list.
result_arc = {k: v if isinstance(v, list) else [v] for k, v in result_arc.items()}
QuanluZhang marked this conversation as resolved.
Show resolved Hide resolved
# Second, infer which ones are multi-hot arrays and which ones are in human-readable format.
# This is non-trivial: if an array is [0, 1], we cannot know for sure whether it means [false, true] or [true, true].
# Here, we assume a multi-hot array has to be a boolean array or a float array whose length matches the number of candidates.
QuanluZhang marked this conversation as resolved.
Show resolved Hide resolved
for mutable in self.mutables:
if mutable.key not in result_arc:
continue # skip silently
choice_arr = result_arc[mutable.key]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is the meaning of "arr"?

if all(isinstance(v, bool) for v in choice_arr) or all(isinstance(v, float) for v in choice_arr):
if (isinstance(mutable, LayerChoice) and len(mutable) == len(choice_arr)) or \
(isinstance(mutable, InputChoice) and mutable.n_candidates == len(choice_arr)):
# multihot, do nothing
continue
if isinstance(mutable, LayerChoice):
choice_arr = [mutable.names.index(val) if isinstance(val, str) else val for val in choice_arr]
choice_arr = [i in choice_arr for i in range(len(mutable))]
elif isinstance(mutable, InputChoice):
choice_arr = [mutable.choose_from.index(val) if isinstance(val, str) else val for val in choice_arr]
choice_arr = [i in choice_arr for i in range(mutable.n_candidates)]
result_arc[mutable.key] = choice_arr
return result_arc

def sample_search(self):
"""
Expand All @@ -47,17 +74,6 @@ def sample_final(self):
return self._fixed_arc


def _encode_tensor(data):
if isinstance(data, list):
if all(map(lambda o: isinstance(o, bool), data)):
return torch.tensor(data, dtype=torch.bool) # pylint: disable=not-callable
else:
return torch.tensor(data, dtype=torch.float) # pylint: disable=not-callable
if isinstance(data, dict):
return {k: _encode_tensor(v) for k, v in data.items()}
return data


def apply_fixed_architecture(model, fixed_arc):
"""
Load architecture from `fixed_arc` and apply to model.
Expand All @@ -78,7 +94,6 @@ def apply_fixed_architecture(model, fixed_arc):
if isinstance(fixed_arc, str):
with open(fixed_arc) as f:
fixed_arc = json.load(f)
fixed_arc = _encode_tensor(fixed_arc)
architecture = FixedArchitecture(model, fixed_arc)
architecture.reset()
return architecture
67 changes: 59 additions & 8 deletions src/sdk/pynni/nni/nas/pytorch/mutator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
import numpy as np
import torch

from nni.nas.pytorch.base_mutator import BaseMutator
from .base_mutator import BaseMutator
from .mutables import LayerChoice, InputChoice
from .utils import to_list

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -58,7 +60,16 @@ def export(self):
dict
A mapping from key of mutables to decisions.
"""
return self.sample_final()
sampled = self.sample_final()
result = dict()
for mutable in self.mutables:
if not isinstance(mutable, (LayerChoice, InputChoice)):
# not supported as built-in
continue
result[mutable.key] = self._convert_mutable_decision(mutable, sampled.pop(mutable.key))
if sampled:
raise ValueError("Unexpected keys returned from 'sample_final()': %s", list(sampled.keys()))
return result

def status(self):
"""
Expand Down Expand Up @@ -159,7 +170,7 @@ def _map_fn(op, args, kwargs):
mask = self._get_decision(mutable)
assert len(mask) == len(mutable), \
"Invalid mask, expected {} to be of length {}.".format(mask, len(mutable))
out = self._select_with_mask(_map_fn, [(choice, args, kwargs) for choice in mutable], mask)
out, mask = self._select_with_mask(_map_fn, [(choice, args, kwargs) for choice in mutable], mask)
return self._tensor_reduction(mutable.reduction, out), mask

def on_forward_input_choice(self, mutable, tensor_list):
Expand All @@ -185,17 +196,23 @@ def on_forward_input_choice(self, mutable, tensor_list):
mask = self._get_decision(mutable)
assert len(mask) == mutable.n_candidates, \
"Invalid mask, expected {} to be of length {}.".format(mask, mutable.n_candidates)
out = self._select_with_mask(lambda x: x, [(t,) for t in tensor_list], mask)
out, mask = self._select_with_mask(lambda x: x, [(t,) for t in tensor_list], mask)
return self._tensor_reduction(mutable.reduction, out), mask

def _select_with_mask(self, map_fn, candidates, mask):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new implementation of this function has complex logic; please add a docstring for it.

if "BoolTensor" in mask.type():
if (isinstance(mask, list) and len(mask) >= 1 and isinstance(mask[0], bool)) or \
(isinstance(mask, np.ndarray) and mask.dtype == np.bool) or \
"BoolTensor" in mask.type():
out = [map_fn(*cand) for cand, m in zip(candidates, mask) if m]
elif "FloatTensor" in mask.type():
elif (isinstance(mask, list) and len(mask) >= 1 and isinstance(mask[0], (float, int))) or \
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can it be an int?

(isinstance(mask, np.ndarray) and mask.dtype in (np.float32, np.float64, np.int32, np.int64)) or \
"FloatTensor" in mask.type():
out = [map_fn(*cand) * m for cand, m in zip(candidates, mask) if m]
else:
raise ValueError("Unrecognized mask")
return out
raise ValueError("Unrecognized mask '%s'" % mask)
if not torch.is_tensor(mask):
mask = torch.tensor(mask)
return out, mask

def _tensor_reduction(self, reduction_type, tensor_list):
if reduction_type == "none":
Expand Down Expand Up @@ -237,3 +254,37 @@ def _get_decision(self, mutable):
result = self._cache[mutable.key]
logger.debug("Decision %s: %s", mutable.key, result)
return result

def _convert_mutable_decision(self, mutable, sampled):
    """
    Convert one sampled decision into the form used in the exported architecture.

    A decision made only of zeros and ones is read as a multi-hot mask over the candidates of
    ``mutable`` and replaced by the selected candidates. Any other decision is returned as the
    output of ``to_list`` without further conversion.

    Parameters
    ----------
    mutable : LayerChoice or InputChoice
        Mutable the decision belongs to. For any other type the decision is not converted.
    sampled : any
        Decision for ``mutable``; it is passed through ``to_list`` before being inspected.

    Returns
    -------
    list or a single element
        Selected candidates, given by name when the names are distinct (and, for a layer choice,
        not all purely numeric) and by index otherwise. When exactly one candidate is selected,
        that element is returned without the surrounding list. If no conversion applies, the
        list form of ``sampled`` is returned.
    """
    decision = to_list(sampled)
    if not all(entry in (0, 1) for entry in decision):
        # Not a 0/1 mask: nothing to prettify.
        return decision

    length_error = "Results returned from 'sample_final()' (%s: %s) either too short or too long."

    def _pick(labels):
        # Keep the label at every position that is switched on in the mask.
        return [label for position, label in enumerate(labels) if decision[position]]

    chosen = None
    if isinstance(mutable, LayerChoice):
        assert len(decision) == len(mutable), length_error % (mutable.key, decision)
        # Names are only used when they are unique and are not just auto-generated numbers.
        distinct_names = len(set(mutable.names)) == len(mutable)
        purely_numeric = all(name.isdigit() for name in mutable.names)
        if distinct_names and not purely_numeric:
            chosen = _pick(mutable.names)
        else:
            chosen = _pick(range(len(decision)))
    if isinstance(mutable, InputChoice):
        assert len(decision) == mutable.n_candidates, length_error % (mutable.key, decision)
        # Names are only used when every input candidate has a different one.
        if len(set(mutable.choose_from)) == mutable.n_candidates:
            chosen = _pick(mutable.choose_from)
        else:
            chosen = _pick(range(len(decision)))

    if chosen is None:
        # Mutable type without a built-in conversion: fall back to the raw list.
        return decision
    # A single selection is exported without the enclosing brackets.
    return chosen[0] if len(chosen) == 1 else chosen
11 changes: 11 additions & 0 deletions src/sdk/pynni/nni/nas/pytorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
from collections import OrderedDict

import numpy as np
import torch

_counter = 0
Expand Down Expand Up @@ -45,6 +46,16 @@ def to_device(obj, device):
raise ValueError("'%s' has unsupported type '%s'" % (obj, type(obj)))


def to_list(arr):
    """
    Convert an array-like object into plain Python lists.

    Parameters
    ----------
    arr : torch.Tensor, numpy.ndarray, list, tuple or any
        Object to convert.

    Returns
    -------
    list or any
        Tensors and numpy arrays are converted with ``tolist()`` (a 0-dim input therefore yields
        a Python scalar), lists and tuples are shallow-copied into a new list, and any other
        object is returned unchanged.
    """
    if torch.is_tensor(arr):
        # Detach first and skip the numpy round trip: ``Tensor.numpy()`` raises for tensors that
        # require grad and for dtypes numpy cannot represent, while ``Tensor.tolist()`` does not.
        return arr.detach().cpu().tolist()
    if isinstance(arr, np.ndarray):
        return arr.tolist()
    if isinstance(arr, (list, tuple)):
        return list(arr)
    return arr


class AverageMeterGroup:
"""
Average meter group for multiple average meters.
Expand Down