This repository has been archived by the owner on Sep 18, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1.8k
[Retiarii] New API: Repeat and Cell #3481
Merged
Merged
Changes from all commits
Commits
Show all changes
17 commits
Select commit
Hold shift + click to select a range
959b10f
Add API draft
ultmaster e83e854
Merge branch 'master' of https://github.com/microsoft/nni into retiar…
ultmaster eb1c947
Add implementation for Repeat
ultmaster 75ea38a
Update cell API
ultmaster d5f0953
Merge branch 'master' of https://github.com/microsoft/nni into retiar…
ultmaster 50b9d86
support nested ModuleList
QuanluZhang 7f80b06
add test
QuanluZhang 32e1a2f
Checkpoint for testing code
ultmaster 922e3ff
fix issue of list inplace append
QuanluZhang 026585f
minor
QuanluZhang 72a0fc9
Update API docs
ultmaster 82e1e10
Refine documents
ultmaster 4fd4755
Move api to component and make arguments optional
ultmaster 9d0bc08
Merge branch 'master' of https://github.com/microsoft/nni into retiar…
ultmaster 23e1dd5
Merge branch 'master' of https://github.com/microsoft/nni into retiar…
ultmaster 4f18b0f
Add support in pure-python engine
ultmaster 6af0f2b
Fix pylint
ultmaster File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
from .api import * | ||
from .component import * | ||
from .nn import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
import copy | ||
from typing import Callable, List, Union, Tuple, Optional | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
from .api import LayerChoice, InputChoice | ||
from .nn import ModuleList | ||
|
||
from .utils import generate_new_label, get_fixed_value | ||
|
||
|
||
__all__ = ['Repeat', 'Cell'] | ||
|
||
|
||
class Repeat(nn.Module): | ||
""" | ||
Repeat a block by a variable number of times. | ||
|
||
Parameters | ||
---------- | ||
blocks : function, list of function, module or list of module | ||
The block to be repeated. If not a list, it will be replicated into a list. | ||
If a list, it should be of length ``max_depth``, the modules will be instantiated in order and a prefix will be taken. | ||
If a function, it will be called to instantiate a module. Otherwise the module will be deep-copied. | ||
depth : int or tuple of int | ||
If one number, the block will be repeated by a fixed number of times. If a tuple, it should be (min, max), | ||
meaning that the block will be repeated at least `min` times and at most `max` times. | ||
""" | ||
|
||
def __new__(cls, blocks: Union[Callable[[], nn.Module], List[Callable[[], nn.Module]], nn.Module, List[nn.Module]], | ||
depth: Union[int, Tuple[int, int]], label: Optional[str] = None): | ||
try: | ||
repeat = get_fixed_value(label) | ||
return nn.Sequential(*cls._replicate_and_instantiate(blocks, repeat)) | ||
except AssertionError: | ||
return super().__new__(cls) | ||
|
||
def __init__(self, | ||
blocks: Union[Callable[[], nn.Module], List[Callable[[], nn.Module]], nn.Module, List[nn.Module]], | ||
depth: Union[int, Tuple[int, int]], label: Optional[str] = None): | ||
super().__init__() | ||
self._label = generate_new_label(label) | ||
self.min_depth = depth if isinstance(depth, int) else depth[0] | ||
self.max_depth = depth if isinstance(depth, int) else depth[1] | ||
assert self.max_depth >= self.min_depth > 0 | ||
self.blocks = nn.ModuleList(self._replicate_and_instantiate(blocks, self.max_depth)) | ||
|
||
@property | ||
def label(self): | ||
return self._label | ||
|
||
def forward(self, x): | ||
for block in self.blocks: | ||
x = block(x) | ||
return x | ||
|
||
@staticmethod | ||
def _replicate_and_instantiate(blocks, repeat): | ||
if not isinstance(blocks, list): | ||
if isinstance(blocks, nn.Module): | ||
blocks = [blocks] + [copy.deepcopy(blocks) for _ in range(repeat - 1)] | ||
else: | ||
blocks = [blocks for _ in range(repeat)] | ||
assert len(blocks) > 0 | ||
assert repeat <= len(blocks), f'Not enough blocks to be used. {repeat} expected, only found {len(blocks)}.' | ||
blocks = blocks[:repeat] | ||
if not isinstance(blocks[0], nn.Module): | ||
blocks = [b() for b in blocks] | ||
return blocks | ||
|
||
|
||
class Cell(nn.Module): | ||
""" | ||
Cell structure [1]_ [2]_ that is popularly used in NAS literature. | ||
|
||
A cell consists of multiple "nodes". Each node is a sum of multiple operators. Each operator is chosen from | ||
``op_candidates``, and takes one input from previous nodes and predecessors. Predecessor means the input of cell. | ||
The output of cell is the concatenation of some of the nodes in the cell (currently all the nodes). | ||
|
||
Parameters | ||
---------- | ||
op_candidates : function or list of module | ||
A list of modules to choose from, or a function that returns a list of modules. | ||
num_nodes : int | ||
Number of nodes in the cell. | ||
num_ops_per_node: int | ||
Number of operators in each node. The output of each node is the sum of all operators in the node. Default: 1. | ||
num_predecessors : int | ||
Number of inputs of the cell. The input to forward should be a list of tensors. Default: 1. | ||
merge_op : str | ||
Currently only ``all`` is supported, which has slight difference with that described in reference. Default: all. | ||
label : str | ||
Identifier of the cell. Cell sharing the same label will semantically share the same choice. | ||
|
||
References | ||
---------- | ||
.. [1] Barret Zoph, Quoc V. Le, "Neural Architecture Search with Reinforcement Learning". https://arxiv.org/abs/1611.01578 | ||
.. [2] Barret Zoph, Vijay Vasudevan, Jonathon Shlens, Quoc V. Le, | ||
"Learning Transferable Architectures for Scalable Image Recognition". https://arxiv.org/abs/1707.07012 | ||
""" | ||
|
||
# TODO: | ||
# Support loose end concat (shape inference on the following cells) | ||
# How to dynamically create convolution with stride as the first node | ||
|
||
def __init__(self, | ||
op_candidates: Union[Callable, List[nn.Module]], | ||
num_nodes: int, | ||
num_ops_per_node: int = 1, | ||
num_predecessors: int = 1, | ||
merge_op: str = 'all', | ||
label: str = None): | ||
super().__init__() | ||
self._label = generate_new_label(label) | ||
self.ops = ModuleList() | ||
self.inputs = ModuleList() | ||
self.num_nodes = num_nodes | ||
self.num_ops_per_node = num_ops_per_node | ||
self.num_predecessors = num_predecessors | ||
for i in range(num_nodes): | ||
self.ops.append(ModuleList()) | ||
self.inputs.append(ModuleList()) | ||
for k in range(num_ops_per_node): | ||
if isinstance(op_candidates, list): | ||
assert len(op_candidates) > 0 and isinstance(op_candidates[0], nn.Module) | ||
ops = copy.deepcopy(op_candidates) | ||
else: | ||
ops = op_candidates() | ||
self.ops[-1].append(LayerChoice(ops, label=f'{self.label}__op_{i}_{k}')) | ||
self.inputs[-1].append(InputChoice(i + num_predecessors, 1, label=f'{self.label}/input_{i}_{k}')) | ||
assert merge_op in ['all'] # TODO: loose_end | ||
self.merge_op = merge_op | ||
|
||
@property | ||
def label(self): | ||
return self._label | ||
|
||
def forward(self, x: List[torch.Tensor]): | ||
states = x | ||
for ops, inps in zip(self.ops, self.inputs): | ||
current_state = [] | ||
for op, inp in zip(ops, inps): | ||
current_state.append(op(inp(states))) | ||
current_state = torch.sum(torch.stack(current_state), 0) | ||
states.append(current_state) | ||
return torch.cat(states[self.num_predecessors:], 1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,8 +8,9 @@ | |
|
||
from ...mutator import Mutator | ||
from ...graph import Cell, Graph, Model, ModelStatus, Node | ||
from ...utils import uid | ||
from .api import LayerChoice, InputChoice, ValueChoice, Placeholder | ||
from .component import Repeat | ||
from ...utils import uid | ||
|
||
|
||
class LayerChoiceMutator(Mutator): | ||
|
@@ -80,6 +81,42 @@ def mutate(self, model): | |
target.update_operation(target.operation.type, {**target.operation.parameters, argname: chosen_value}) | ||
|
||
|
||
class RepeatMutator(Mutator): | ||
def __init__(self, nodes: List[Node]): | ||
# nodes is a subgraph consisting of repeated blocks. | ||
super().__init__() | ||
self.nodes = nodes | ||
|
||
def _retrieve_chain_from_graph(self, graph: Graph) -> List[Node]: | ||
u = graph.input_node | ||
chain = [] | ||
while u != graph.output_node: | ||
if u != graph.input_node: | ||
chain.append(u) | ||
assert len(u.successors) == 1, f'This graph is an illegal chain. {u} has output {u.successor}.' | ||
u = u.successors[0] | ||
return chain | ||
|
||
def mutate(self, model): | ||
min_depth = self.nodes[0].operation.parameters['min_depth'] | ||
max_depth = self.nodes[0].operation.parameters['max_depth'] | ||
if min_depth < max_depth: | ||
chosen_depth = self.choice(list(range(min_depth, max_depth + 1))) | ||
for node in self.nodes: | ||
# the logic here is similar to layer choice. We find cell attached to each node. | ||
target: Graph = model.graphs[node.operation.cell_name] | ||
chain = self._retrieve_chain_from_graph(target) | ||
for edge in chain[chosen_depth - 1].outgoing_edges: | ||
edge.remove() | ||
target.add_edge((chain[chosen_depth - 1], None), (target.output_node, None)) | ||
for rm_node in chain[chosen_depth:]: | ||
for edge in rm_node.outgoing_edges: | ||
edge.remove() | ||
rm_node.remove() | ||
# to delete the unused parameters. | ||
model.get_node_by_name(node.name).update_operation(Cell(node.operation.cell_name)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what is the meaning of this line, why it is necessary? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To delete unused parameters. Otherwise, codegen will complain. |
||
|
||
|
||
def process_inline_mutation(model: Model) -> Optional[List[Mutator]]: | ||
applied_mutators = [] | ||
|
||
|
@@ -120,6 +157,15 @@ def process_inline_mutation(model: Model) -> Optional[List[Mutator]]: | |
mutator = LayerChoiceMutator(node_list) | ||
applied_mutators.append(mutator) | ||
|
||
repeat_nodes = _group_by_label(filter(lambda d: d.operation.parameters.get('mutation') == 'repeat', | ||
model.get_nodes_by_type('_cell'))) | ||
for node_list in repeat_nodes: | ||
assert _is_all_equal(map(lambda node: node.operation.parameters['max_depth'], node_list)) and \ | ||
_is_all_equal(map(lambda node: node.operation.parameters['min_depth'], node_list)), \ | ||
'Repeat with the same label must have the same number of candidates.' | ||
mutator = RepeatMutator(node_list) | ||
applied_mutators.append(mutator) | ||
|
||
if applied_mutators: | ||
return applied_mutators | ||
return None | ||
|
@@ -190,6 +236,11 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op | |
if isinstance(module, ValueChoice): | ||
node = graph.add_node(name, 'ValueChoice', {'candidates': module.candidates}) | ||
node.label = module.label | ||
if isinstance(module, Repeat) and module.min_depth <= module.max_depth: | ||
node = graph.add_node(name, 'Repeat', { | ||
'candidates': list(range(module.min_depth, module.max_depth + 1)) | ||
}) | ||
node.label = module.label | ||
if isinstance(module, Placeholder): | ||
raise NotImplementedError('Placeholder is not supported in python execution mode.') | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from typing import Optional | ||
|
||
from ...utils import uid, get_current_context | ||
|
||
|
||
def generate_new_label(label: Optional[str]): | ||
if label is None: | ||
return '_mutation_' + str(uid('mutation')) | ||
return label | ||
|
||
|
||
def get_fixed_value(label: str): | ||
ret = get_current_context('fixed') | ||
try: | ||
return ret[generate_new_label(label)] | ||
except KeyError: | ||
raise KeyError(f'Fixed context with {label} not found. Existing values are: {ret}') |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
better to explain the meaning of
nodes
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok