This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

[Retiarii] New API: Repeat and Cell #3481

Merged
merged 17 commits on May 26, 2021
6 changes: 6 additions & 0 deletions docs/en_US/NAS/retiarii/ApiReference.rst
@@ -18,6 +18,12 @@ Inline Mutation APIs
.. autoclass:: nni.retiarii.nn.pytorch.ChosenInputs
:members:

.. autoclass:: nni.retiarii.nn.pytorch.Repeat
:members:

.. autoclass:: nni.retiarii.nn.pytorch.Cell
:members:

Graph Mutation APIs
-------------------

10 changes: 10 additions & 0 deletions nni/retiarii/converter/graph_gen.py
@@ -642,6 +642,16 @@ def convert_module(self, script_module, module, module_name, ir_model):

ir_graph._register()

# add mutation signal for special modules
if original_type_name == OpTypeName.Repeat:
attrs = {
'mutation': 'repeat',
'label': module.label,
'min_depth': module.min_depth,
'max_depth': module.max_depth
}
return ir_graph, attrs

return ir_graph, {}


2 changes: 2 additions & 0 deletions nni/retiarii/converter/op_types.py
@@ -17,3 +17,5 @@ class OpTypeName(str, Enum):
ValueChoice = 'ValueChoice'
Placeholder = 'Placeholder'
MergedSlice = 'MergedSlice'
Repeat = 'Repeat'
Cell = 'Cell'
1 change: 1 addition & 0 deletions nni/retiarii/nn/pytorch/__init__.py
@@ -1,2 +1,3 @@
from .api import *
from .component import *
from .nn import *
40 changes: 13 additions & 27 deletions nni/retiarii/nn/pytorch/api.py
@@ -10,26 +10,12 @@
import torch.nn as nn

from ...serializer import Translatable, basic_unit
from ...utils import uid, get_current_context
from .utils import generate_new_label, get_fixed_value


__all__ = ['LayerChoice', 'InputChoice', 'ValueChoice', 'Placeholder', 'ChosenInputs']


def _generate_new_label(label: Optional[str]):
if label is None:
return '_mutation_' + str(uid('mutation'))
return label


def _get_fixed_value(label: str):
ret = get_current_context('fixed')
try:
return ret[_generate_new_label(label)]
except KeyError:
raise KeyError(f'Fixed context with {label} not found. Existing values are: {ret}')


class LayerChoice(nn.Module):
"""
Layer choice selects one of the ``candidates``, then apply it on inputs and return results.
@@ -69,17 +55,17 @@ class LayerChoice(nn.Module):
``self.op_choice[1] = nn.Conv3d(...)``. Adding more choices is not supported yet.
"""

def __new__(cls, candidates: Union[Dict[str, nn.Module], List[nn.Module]], label: str = None, **kwargs):
def __new__(cls, candidates: Union[Dict[str, nn.Module], List[nn.Module]], label: Optional[str] = None, **kwargs):
try:
chosen = _get_fixed_value(label)
chosen = get_fixed_value(label)
if isinstance(candidates, list):
return candidates[int(chosen)]
else:
return candidates[chosen]
except AssertionError:
return super().__new__(cls)

def __init__(self, candidates: Union[Dict[str, nn.Module], List[nn.Module]], label: str = None, **kwargs):
def __init__(self, candidates: Union[Dict[str, nn.Module], List[nn.Module]], label: Optional[str] = None, **kwargs):
super(LayerChoice, self).__init__()
if 'key' in kwargs:
warnings.warn(f'"key" is deprecated. Assuming label.')
@@ -89,7 +75,7 @@ def __init__(self, candidates: Union[Dict[str, nn.Module], List[nn.Module]], lab
if 'reduction' in kwargs:
warnings.warn(f'"reduction" is deprecated. Ignoring...')
self.candidates = candidates
self._label = _generate_new_label(label)
self._label = generate_new_label(label)

self.names = []
if isinstance(candidates, OrderedDict):
@@ -187,13 +173,13 @@ class InputChoice(nn.Module):
Identifier of the input choice.
"""

def __new__(cls, n_candidates: int, n_chosen: int = 1, reduction: str = 'sum', label: str = None, **kwargs):
def __new__(cls, n_candidates: int, n_chosen: int = 1, reduction: str = 'sum', label: Optional[str] = None, **kwargs):
try:
return ChosenInputs(_get_fixed_value(label), reduction=reduction)
return ChosenInputs(get_fixed_value(label), reduction=reduction)
except AssertionError:
return super().__new__(cls)

def __init__(self, n_candidates: int, n_chosen: int = 1, reduction: str = 'sum', label: str = None, **kwargs):
def __init__(self, n_candidates: int, n_chosen: int = 1, reduction: str = 'sum', label: Optional[str] = None, **kwargs):
super(InputChoice, self).__init__()
if 'key' in kwargs:
warnings.warn(f'"key" is deprecated. Assuming label.')
@@ -206,7 +192,7 @@ def __init__(self, n_candidates: int, n_chosen: int = 1, reduction: str = 'sum',
self.n_chosen = n_chosen
self.reduction = reduction
assert self.reduction in ['mean', 'concat', 'sum', 'none']
self._label = _generate_new_label(label)
self._label = generate_new_label(label)

@property
def key(self):
@@ -295,16 +281,16 @@ def forward(self, x):
Identifier of the value choice.
"""

def __new__(cls, candidates: List[Any], label: str = None):
def __new__(cls, candidates: List[Any], label: Optional[str] = None):
try:
return _get_fixed_value(label)
return get_fixed_value(label)
except AssertionError:
return super().__new__(cls)

def __init__(self, candidates: List[Any], label: str = None):
def __init__(self, candidates: List[Any], label: Optional[str] = None):
super().__init__()
self.candidates = candidates
self._label = _generate_new_label(label)
self._label = generate_new_label(label)
self._accessor = []

@property
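
The ``__new__`` overrides above are what let these primitives collapse into concrete modules once an architecture is fixed: under a ``fixed`` context, construction returns the chosen candidate itself instead of a mutable wrapper. A minimal sketch, assuming ``nni.retiarii.fixed_arch`` accepts an in-memory dict of choices (it pushes the ``fixed`` context that ``get_fixed_value`` reads):

import torch.nn as nn
from nni.retiarii import fixed_arch
from nni.retiarii.nn.pytorch import LayerChoice, ValueChoice

# Under a fixed context, __new__ returns the chosen candidate directly,
# so no LayerChoice/ValueChoice instance remains in the final model.
with fixed_arch({'conv': 1, 'width': 32}):
    layer = LayerChoice([nn.Conv2d(3, 16, 1), nn.Conv2d(3, 16, 3)], label='conv')
    width = ValueChoice([16, 32, 64], label='width')

assert isinstance(layer, nn.Conv2d)  # the second candidate, not a LayerChoice
assert width == 32                   # a plain value, not a ValueChoice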
147 changes: 147 additions & 0 deletions nni/retiarii/nn/pytorch/component.py
@@ -0,0 +1,147 @@
import copy
from typing import Callable, List, Union, Tuple, Optional

import torch
import torch.nn as nn

from .api import LayerChoice, InputChoice
from .nn import ModuleList

from .utils import generate_new_label, get_fixed_value


__all__ = ['Repeat', 'Cell']


class Repeat(nn.Module):
"""
Repeat a block by a variable number of times.

Parameters
----------
blocks : function, list of functions, module, or list of modules
The block to be repeated. If not a list, it is replicated ``max_depth`` times to form one.
If a list, it should have length ``max_depth``; the modules are used in order, and only a prefix of the chosen depth is kept.
Each function is called to instantiate a module; each module is deep-copied.
depth : int or tuple of int
If a single int, the block is repeated a fixed number of times. If a tuple ``(min, max)``,
the block is repeated at least ``min`` and at most ``max`` times.
label : str
Identifier of the module. Repeats sharing the same label will semantically share the same chosen depth.
"""

def __new__(cls, blocks: Union[Callable[[], nn.Module], List[Callable[[], nn.Module]], nn.Module, List[nn.Module]],
depth: Union[int, Tuple[int, int]], label: Optional[str] = None):
try:
repeat = get_fixed_value(label)
return nn.Sequential(*cls._replicate_and_instantiate(blocks, repeat))
except AssertionError:
return super().__new__(cls)

def __init__(self,
blocks: Union[Callable[[], nn.Module], List[Callable[[], nn.Module]], nn.Module, List[nn.Module]],
depth: Union[int, Tuple[int, int]], label: Optional[str] = None):
super().__init__()
self._label = generate_new_label(label)
self.min_depth = depth if isinstance(depth, int) else depth[0]
self.max_depth = depth if isinstance(depth, int) else depth[1]
assert self.max_depth >= self.min_depth > 0
self.blocks = nn.ModuleList(self._replicate_and_instantiate(blocks, self.max_depth))

@property
def label(self):
return self._label

def forward(self, x):
for block in self.blocks:
x = block(x)
return x

@staticmethod
def _replicate_and_instantiate(blocks, repeat):
if not isinstance(blocks, list):
if isinstance(blocks, nn.Module):
blocks = [blocks] + [copy.deepcopy(blocks) for _ in range(repeat - 1)]
else:
blocks = [blocks for _ in range(repeat)]
assert len(blocks) > 0
assert repeat <= len(blocks), f'Not enough blocks to be used. Expected {repeat}, found only {len(blocks)}.'
blocks = blocks[:repeat]
if not isinstance(blocks[0], nn.Module):
blocks = [b() for b in blocks]
return blocks
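
As a usage sketch (shapes and dimensions here are hypothetical): a factory function gives every copy fresh parameters, while a plain module is deep-copied; passing a ``(min, max)`` tuple makes the depth searchable:

import torch
import torch.nn as nn
from nni.retiarii.nn.pytorch import Repeat

# Factory: called once per copy, so each block gets its own parameters.
block = Repeat(lambda: nn.Sequential(nn.Linear(16, 16), nn.ReLU()), depth=(1, 3))
y = block(torch.randn(2, 16))  # in the search space, all max_depth blocks run

# Plain int: fixed depth, nothing to search; the module is deep-copied 3 times.
stack = Repeat(nn.Linear(16, 16), depth=3)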


class Cell(nn.Module):
"""
Cell structure [1]_ [2]_ that is commonly used in NAS literature.

A cell consists of multiple "nodes". Each node is a sum of multiple operators. Each operator is chosen from
``op_candidates`` and takes one input chosen from the previous nodes and the predecessors, where a predecessor
is an input of the cell. The output of the cell is the concatenation of some of the nodes (currently all of them).

Parameters
----------
op_candidates : function or list of module
A list of modules to choose from, or a function that returns a list of modules.
num_nodes : int
Number of nodes in the cell.
num_ops_per_node : int
Number of operators in each node. The output of a node is the sum of its operators' outputs. Default: 1.
num_predecessors : int
Number of inputs of the cell. The input to forward should be a list of tensors. Default: 1.
merge_op : str
Currently only ``all`` is supported, which differs slightly from what is described in the references. Default: ``all``.
label : str
Identifier of the cell. Cells sharing the same label will semantically share the same choices.

References
----------
.. [1] Barret Zoph, Quoc V. Le, "Neural Architecture Search with Reinforcement Learning". https://arxiv.org/abs/1611.01578
.. [2] Barret Zoph, Vijay Vasudevan, Jonathon Shlens, Quoc V. Le,
"Learning Transferable Architectures for Scalable Image Recognition". https://arxiv.org/abs/1707.07012
"""

# TODO:
# Support loose end concat (shape inference on the following cells)
# How to dynamically create convolution with stride as the first node

def __init__(self,
op_candidates: Union[Callable, List[nn.Module]],
num_nodes: int,
num_ops_per_node: int = 1,
num_predecessors: int = 1,
merge_op: str = 'all',
label: Optional[str] = None):
super().__init__()
self._label = generate_new_label(label)
self.ops = ModuleList()
self.inputs = ModuleList()
self.num_nodes = num_nodes
self.num_ops_per_node = num_ops_per_node
self.num_predecessors = num_predecessors
for i in range(num_nodes):
self.ops.append(ModuleList())
self.inputs.append(ModuleList())
for k in range(num_ops_per_node):
if isinstance(op_candidates, list):
assert len(op_candidates) > 0 and isinstance(op_candidates[0], nn.Module)
ops = copy.deepcopy(op_candidates)
else:
ops = op_candidates()
self.ops[-1].append(LayerChoice(ops, label=f'{self.label}__op_{i}_{k}'))
self.inputs[-1].append(InputChoice(i + num_predecessors, 1, label=f'{self.label}/input_{i}_{k}'))
assert merge_op in ['all'] # TODO: loose_end
self.merge_op = merge_op

@property
def label(self):
return self._label

def forward(self, x: List[torch.Tensor]):
states = list(x)  # copy the input list so forward() does not mutate the caller's list
for ops, inps in zip(self.ops, self.inputs):
current_state = []
for op, inp in zip(ops, inps):
current_state.append(op(inp(states)))
current_state = torch.sum(torch.stack(current_state), 0)
states.append(current_state)
return torch.cat(states[self.num_predecessors:], 1)
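
A usage sketch of the cell above (dimensions are hypothetical). With two predecessors, ``forward`` takes a list of two tensors; since every candidate op here preserves the feature dimension, the concatenated output shape is independent of the choices:

import torch
import torch.nn as nn
from nni.retiarii.nn.pytorch import Cell

# Each of the 4 nodes applies one op chosen from the candidates to one input
# chosen from the earlier nodes and the two cell inputs.
cell = Cell(
    op_candidates=lambda: [nn.Linear(16, 16), nn.Identity()],
    num_nodes=4,
    num_ops_per_node=1,
    num_predecessors=2,
)
x = [torch.randn(2, 16), torch.randn(2, 16)]  # one tensor per predecessor
y = cell(x)                                   # concat of 4 node outputs: (2, 64)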
53 changes: 52 additions & 1 deletion nni/retiarii/nn/pytorch/mutator.py
@@ -8,8 +8,9 @@

from ...mutator import Mutator
from ...graph import Cell, Graph, Model, ModelStatus, Node
from ...utils import uid
from .api import LayerChoice, InputChoice, ValueChoice, Placeholder
from .component import Repeat
from ...utils import uid


class LayerChoiceMutator(Mutator):
@@ -80,6 +81,42 @@ def mutate(self, model):
target.update_operation(target.operation.type, {**target.operation.parameters, argname: chosen_value})


class RepeatMutator(Mutator):
def __init__(self, nodes: List[Node]):
Contributor: better to explain the meaning of nodes

Contributor Author: ok

# nodes is a subgraph consisting of repeated blocks.
super().__init__()
self.nodes = nodes

def _retrieve_chain_from_graph(self, graph: Graph) -> List[Node]:
u = graph.input_node
chain = []
while u != graph.output_node:
if u != graph.input_node:
chain.append(u)
assert len(u.successors) == 1, f'This graph is not a legal chain. {u} has successors {u.successors}.'
u = u.successors[0]
return chain

def mutate(self, model):
min_depth = self.nodes[0].operation.parameters['min_depth']
max_depth = self.nodes[0].operation.parameters['max_depth']
if min_depth < max_depth:
chosen_depth = self.choice(list(range(min_depth, max_depth + 1)))
for node in self.nodes:
# The logic here is similar to layer choice: find the cell graph attached to each node.
target: Graph = model.graphs[node.operation.cell_name]
chain = self._retrieve_chain_from_graph(target)
for edge in chain[chosen_depth - 1].outgoing_edges:
edge.remove()
target.add_edge((chain[chosen_depth - 1], None), (target.output_node, None))
for rm_node in chain[chosen_depth:]:
for edge in rm_node.outgoing_edges:
edge.remove()
rm_node.remove()
# Replace the operation to delete the unused parameters; otherwise codegen will complain.
model.get_node_by_name(node.name).update_operation(Cell(node.operation.cell_name))
Contributor: what is the meaning of this line, why is it necessary?

Contributor Author: To delete unused parameters. Otherwise, codegen will complain.
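
In effect, for a repeated-block chain with ``max_depth = 4`` and a chosen depth of 2, ``mutate`` rewires the cell graph roughly as follows (hypothetical node names):

input -> b1 -> b2 -> b3 -> b4 -> output    # before
input -> b1 -> b2 -> output                # after: b3, b4 and their edges are removed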



def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
applied_mutators = []

@@ -120,6 +157,15 @@ def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
mutator = LayerChoiceMutator(node_list)
applied_mutators.append(mutator)

repeat_nodes = _group_by_label(filter(lambda d: d.operation.parameters.get('mutation') == 'repeat',
model.get_nodes_by_type('_cell')))
for node_list in repeat_nodes:
assert _is_all_equal(map(lambda node: node.operation.parameters['max_depth'], node_list)) and \
_is_all_equal(map(lambda node: node.operation.parameters['min_depth'], node_list)), \
'Repeat modules with the same label must have the same depth range.'
mutator = RepeatMutator(node_list)
applied_mutators.append(mutator)

if applied_mutators:
return applied_mutators
return None
@@ -190,6 +236,11 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op
if isinstance(module, ValueChoice):
node = graph.add_node(name, 'ValueChoice', {'candidates': module.candidates})
node.label = module.label
if isinstance(module, Repeat) and module.min_depth <= module.max_depth:
node = graph.add_node(name, 'Repeat', {
'candidates': list(range(module.min_depth, module.max_depth + 1))
})
node.label = module.label
if isinstance(module, Placeholder):
raise NotImplementedError('Placeholder is not supported in python execution mode.')

17 changes: 17 additions & 0 deletions nni/retiarii/nn/pytorch/utils.py
@@ -0,0 +1,17 @@
from typing import Optional

from ...utils import uid, get_current_context


def generate_new_label(label: Optional[str]):
if label is None:
return '_mutation_' + str(uid('mutation'))
return label


def get_fixed_value(label: Optional[str]):
ret = get_current_context('fixed')
try:
return ret[generate_new_label(label)]
except KeyError:
raise KeyError(f'Fixed context with {label} not found. Existing values are: {ret}')
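
A sketch of how these helpers behave (the uid counter is process-global, so the exact suffix below is illustrative):

from nni.retiarii.nn.pytorch.utils import generate_new_label, get_fixed_value

generate_new_label('my_cell')  # -> 'my_cell': explicit labels pass through
generate_new_label(None)       # -> e.g. '_mutation_1': auto-generated from a counter

# get_fixed_value(label) looks the (possibly generated) label up in the current
# 'fixed' context; a missing key raises KeyError listing the available values.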