This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Fix: refine shape attribute #4214

Status: Merged, 30 commits, Oct 11, 2021. The diff below shows changes from 23 of the 30 commits.

Commits
9b7874f  refactor nn-Meter multi-trial to adapt new structure of nn-Meter (JiahangXu, Jul 28, 2021)
a3e5ec2  add github repo link of nn-Meter (JiahangXu, Jul 28, 2021)
51e3f67  refine nn-Meter API and add nn-meter in README.md (JiahangXu, Aug 2, 2021)
3878315  Merge branch 'microsoft:master' into master (JiahangXu, Aug 2, 2021)
aac4c43  Merge branch 'microsoft:master' into master (JiahangXu, Aug 3, 2021)
b07bfba  Merge branch 'microsoft:master' into master (JiahangXu, Aug 5, 2021)
fd1fa27  refine mulit_trial doc (JiahangXu, Aug 5, 2021)
bd7ec88  add dummy_inputs (JiahangXu, Aug 5, 2021)
5e11afd  Merge branch 'microsoft:master' into master (JiahangXu, Aug 6, 2021)
15fb7fa  remove mutator in doc (JiahangXu, Aug 6, 2021)
072a1f1  remove mutator in doc (JiahangXu, Aug 9, 2021)
6efba0a  Merge branch 'master' of github.com:microsoft/nni (JiahangXu, Aug 18, 2021)
c00bbec  fix nn.Sequential bug (JiahangXu, Aug 20, 2021)
0546f04  Merge branch 'microsoft:master' into master (JiahangXu, Aug 20, 2021)
2abb7bb  Merge branch 'master' of github.com:microsoft/nni (JiahangXu, Aug 23, 2021)
9ff8928  add identity node if the graph is empty (JiahangXu, Aug 24, 2021)
477b849  Merge branch 'master' of github.com:microsoft/nni (JiahangXu, Aug 27, 2021)
d3ac334  delete trailing whitespace (JiahangXu, Sep 6, 2021)
66a1555  Fix: change shape to node attr (JiahangXu, Sep 24, 2021)
828760c  Fix: change shape to node attr (complete) (JiahangXu, Sep 26, 2021)
0f38a36  change python type hint (JiahangXu, Sep 27, 2021)
9040c9c  Feature: move i/o shape to operation attr (JiahangXu, Sep 28, 2021)
84467a7  remove redundant code (JiahangXu, Sep 28, 2021)
625be9e  Refactor: add default value for node parameters (JiahangXu, Oct 8, 2021)
da0f42f  Fix pipeline error: tf_operation miss attr (JiahangXu, Oct 11, 2021)
5ee44c9  Fix pipeline error: tf ir json miss attr (JiahangXu, Oct 11, 2021)
589aa84  Refactor: change attr to attributes 1 (JiahangXu, Oct 11, 2021)
72323d1  Merge branch 'dev-refine-shape-attr' of github.com:JiahangXu/nni into… (JiahangXu, Oct 11, 2021)
6615ae0  Refactor: change attr to attributes 2 (JiahangXu, Oct 11, 2021)
2666b42  Fix pylint typo (JiahangXu, Oct 11, 2021)
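In short, this PR moves the traced tensor shapes (input_shape / output_shape) out of Operation.parameters, which holds constructor arguments such as kernel_size, into a dedicated Operation.attr dict, and updates the converter, the IR serialization, and the unit tests to match. A rough sketch of the access pattern after the change, assuming a model_ir produced by the converter as in the tests below:

# Before this PR, traced shapes were mixed into the op's parameters:
#     conv_node.operation.parameters['input_shape']
# After, shapes live in a separate attr dict:
conv_node = model_ir.get_nodes_by_type('__torch__.torch.nn.modules.conv.Conv2d')[0]
print(conv_node.operation.attr['input_shape'])   # e.g. [[1, 3, 224, 224]]
print(conv_node.operation.parameters)            # init arguments only, no shapes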
nni/retiarii/converter/graph_gen.py (47 changes: 30 additions & 17 deletions)

@@ -491,6 +491,17 @@ def handle_single_node(node):
         for node in sm_graph.nodes():
             handle_single_node(node)
 
+        if node_index == {}:
+            # here is an example that the ir_graph is empty
+            # graph(%self : __torch__.torchmodels.googlenet.GoogLeNet,
+            #       %x.1 : Tensor): return (%x.1)
+            # add a noop_identity node to handle this situation
+            self.global_seq += 1
+            ni_node = ir_graph.add_node(build_full_name(module_name, 'noop_identity', self.global_seq), 'noop_identity')
+            ir_graph.add_edge(head=(ir_graph.input_node, 0), tail=(ni_node, None))
+            ir_graph.add_edge(head=(ni_node, None), tail=(ir_graph.output_node, None))
+            for _output in sm_graph.outputs():
+                node_index[_output.node()] = ni_node
         return node_index
 
     def merge_aten_slices(self, ir_graph):
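The empty-graph branch added above is triggered by a submodule whose forward returns its input untouched, so TorchScript records no operation nodes for it (the GoogLeNet graph in the comment is such a case). A minimal sketch of a module that reproduces it; the class is a hypothetical example, not code from this PR:

import torch
import torch.nn as nn

class PassThrough(nn.Module):
    # TorchScript compiles this forward to a graph with an input and an
    # output but no operation nodes, so the converter inserts a
    # noop_identity node to keep the IR graph connected.
    def forward(self, x):
        return x

scripted = torch.jit.script(PassThrough())
print(scripted.graph)  # roughly: graph(%self, %x.1 : Tensor): return (%x.1)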
@@ -575,9 +586,7 @@ def _convert_module(self, script_module, module, module_name, ir_model):
         # also has LayerChoice or InputChoice or ValueChoice
         original_type_name = script_module.original_name
         m_attrs = None
-        if original_type_name in MODULE_EXCEPT_LIST:
-            pass  # do nothing
-        elif original_type_name == OpTypeName.LayerChoice:
+        if original_type_name == OpTypeName.LayerChoice:
             graph = Graph(ir_model, -100, module_name, _internal=True)  # graph_id is not used now
             candidate_name_list = []
             for cand_name in module.names:
@@ -599,7 +608,9 @@ def _convert_module(self, script_module, module, module_name, ir_model):
             m_attrs = self._handle_valuechoice(module)
         elif original_type_name == OpTypeName.Placeholder:
             m_attrs = get_init_parameters_or_fail(module)
-        elif module.__class__.__module__.startswith('torch.nn') and original_type_name in torch.nn.__dict__:
+        elif module.__class__.__module__.startswith('torch.nn') and \
+                original_type_name in torch.nn.__dict__ and \
+                original_type_name not in MODULE_EXCEPT_LIST:
             # this is a basic module from pytorch, no need to parse its graph
             m_attrs = get_init_parameters_or_fail(module)
         elif getattr(module, '_stop_parsing', False):
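The reworked condition above treats a module as a basic PyTorch unit only when it comes from the torch.nn namespace and is not on the except list, so containers such as nn.Sequential now fall through and have their inner graph parsed. A standalone sketch of the check itself (the converter uses script_module.original_name, which for this purpose matches the class name used here):

import torch.nn as nn

MODULE_EXCEPT_LIST = ['Sequential']

for m in (nn.Conv2d(3, 1, kernel_size=3), nn.Sequential(nn.ReLU())):
    name = type(m).__name__
    is_basic = (type(m).__module__.startswith('torch.nn')
                and name in nn.__dict__
                and name not in MODULE_EXCEPT_LIST)
    # Conv2d -> True (kept as a single IR node);
    # Sequential -> False (recursively converted instead)
    print(name, is_basic)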
@@ -696,19 +707,21 @@ def _initialize_parameters(self, ir_model: 'Model'):
         for ir_node in ir_model.get_nodes():
             if ir_node.operation.parameters is None:
                 ir_node.operation.parameters = {}
-            ir_node.operation.parameters.setdefault('input_shape', [])
-            ir_node.operation.parameters.setdefault('output_shape', [])
+            ir_node.operation.attr.setdefault('input_shape', [])
+            ir_node.operation.attr.setdefault('output_shape', [])
 
     def _trace_module(self, module, module_name, ir_model: 'Model', dummy_input):
         # First, trace the whole graph
         tm_graph = self._trace(module, dummy_input)
 
         for node in tm_graph.nodes():
-            parameters = _extract_info_from_trace_node(node)
+            shape_parameters, parameters = _extract_info_from_trace_node(node)
             # '__module.convpool/__module.convpool.1/__module.convpool.1.conv'
             ir_node = match_node(ir_model, node, module_name)
             if ir_node is not None:
-                ir_node.operation.parameters.update(parameters)
+                ir_node.operation.attr.update(shape_parameters)
+                if parameters:
+                    ir_node.operation.parameters.update(parameters)
 
         self.propagate_shape(ir_model)
 
@@ -724,7 +737,7 @@ def _trace_module(self, module, module_name, ir_model: 'Model', dummy_input):
                     cand_name = build_cand_name(cand_name, submodule.label)
                 # TODO: Feed the exact input tensor if user provides input,
                 # in case the path changes according to input data.
-                lc_inputs = [torch.randn(shape) for shape in lc_node.operation.parameters['input_shape']]
+                lc_inputs = [torch.randn(shape) for shape in lc_node.operation.attr['input_shape']]
                 self._trace_module(cand, cand_name, ir_model, lc_inputs)
 
     def propagate_shape(self, ir_model: 'Model'):
@@ -742,8 +755,8 @@ def propagate_shape_for_graph(graph: 'Graph'):
                 cand_node = ir_model.get_node_by_name(cand_name)
                 if _without_shape_info(cand_node):
                     propagate_shape_for_graph(ir_model.graphs[cand_name])
-                graph_node.operation.parameters['input_shape'] = cand_node.operation.parameters['input_shape']
-                graph_node.operation.parameters['output_shape'] = cand_node.operation.parameters['output_shape']
+                graph_node.operation.attr['input_shape'] = cand_node.operation.attr['input_shape']
+                graph_node.operation.attr['output_shape'] = cand_node.operation.attr['output_shape']
             else:
                 input_shape = [[]] * len(graph.input_node.operation.io_names or [])
                 output_shape = [[]] * len(graph.output_node.operation.io_names or [])
@@ -752,17 +765,17 @@ def propagate_shape_for_graph(graph: 'Graph'):
                     if _without_shape_info(node):
                         if node.name in ir_model.graphs:
                             propagate_shape_for_graph(ir_model.graphs[node.name])
-                    if node.operation.parameters['input_shape']:
-                        input_shape[edge.head_slot or 0] = node.operation.parameters['input_shape'][edge.tail_slot or 0]
-                graph_node.operation.parameters['input_shape'] = input_shape
+                    if node.operation.attr['input_shape']:
+                        input_shape[edge.head_slot or 0] = node.operation.attr['input_shape'][edge.tail_slot or 0]
+                graph_node.operation.attr['input_shape'] = input_shape
                 for edge in graph.output_node.incoming_edges:
                     node = edge.head
                     if _without_shape_info(node):
                         if node.name in ir_model.graphs:
                             propagate_shape_for_graph(ir_model.graphs[node.name])
-                    if node.operation.parameters['output_shape']:
-                        output_shape[edge.tail_slot or 0] = node.operation.parameters['output_shape'][edge.head_slot or 0]
-                graph_node.operation.parameters['output_shape'] = output_shape
+                    if node.operation.attr['output_shape']:
+                        output_shape[edge.tail_slot or 0] = node.operation.attr['output_shape'][edge.head_slot or 0]
+                graph_node.operation.attr['output_shape'] = output_shape
 
             propagate_shape_for_graph(graph_node.graph)
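propagate_shape works outside-in: a cell node borrows its input_shape from the inner node on the input boundary of its subgraph, and its output_shape from the node feeding the inner output. A simplified, self-contained sketch of that lifting step over a toy structure; names and types here are illustrative, not the NNI API:

from dataclasses import dataclass, field
from typing import Dict

@dataclass
class ToyNode:
    attr: Dict[str, list] = field(default_factory=lambda: {'input_shape': [], 'output_shape': []})

@dataclass
class ToyGraph:
    boundary_in: ToyNode = field(default_factory=ToyNode)   # first real node after the graph input
    boundary_out: ToyNode = field(default_factory=ToyNode)  # last real node before the graph output

def propagate(cell: ToyNode, inner: ToyGraph) -> None:
    # lift the shapes recorded on the inner boundary nodes up to the cell node
    cell.attr['input_shape'] = inner.boundary_in.attr['input_shape']
    cell.attr['output_shape'] = inner.boundary_out.attr['output_shape']

inner = ToyGraph()
inner.boundary_in.attr['input_shape'] = [[1, 3, 224, 224]]
inner.boundary_out.attr['output_shape'] = [[1, 1, 222, 222]]
cell = ToyNode()
propagate(cell, inner)
print(cell.attr)  # shapes are now visible at the cell level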
nni/retiarii/converter/op_types.py (1 change: 1 addition & 0 deletions)

@@ -3,6 +3,7 @@
 
 from enum import Enum
 
+# except the special case which can not treat as a basic module from pytorch
 MODULE_EXCEPT_LIST = ['Sequential']
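The except list exists because nn.Sequential is both a torch.nn class and a container of arbitrary user modules; treating it as a basic unit hid its children from graph conversion (the "fix nn.Sequential bug" commit above). A hedged sketch of the kind of model this affects, with hypothetical model code (the convpool name echoes the trace comment in graph_gen.py):

import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # a container like this must be parsed into its children
        # (Conv2d, ReLU) rather than kept as one opaque 'Sequential' node
        self.convpool = nn.Sequential(
            nn.Conv2d(3, 1, kernel_size=3),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.convpool(x)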
nni/retiarii/converter/utils.py (13 changes: 7 additions & 6 deletions)

@@ -56,15 +56,16 @@ def _extract_info_from_trace_node(trace_node):
         if shape:
             output_shape.append(shape)
 
-    parameters = {
+    shape_parameters = {
         'input_shape': input_shape,
         'output_shape': output_shape,
    }
 
     if trace_node.kind() == 'aten::cat':
-        parameters['dim'] = inputs[1].toIValue()
-
-        return parameters
+        parameters = {'dim': inputs[1].toIValue()}
+        return shape_parameters, parameters
+    else:
+        return shape_parameters, None
 
 
 def is_layerchoice_node(ir_node: Node):
@@ -100,12 +101,12 @@ def match_node(ir_model: Model, torch_node, prefix=''):
     graph = ir_model.graphs.get(full_name)
     if graph is not None:
         for node in graph.get_nodes_by_type(torch_node.kind()):
-            if not node.operation.parameters['input_shape']:
+            if not node.operation.attr['input_shape']:
                 return node
         return None
     else:
         return ir_model.get_node_by_name(full_name)
 
 
 def _without_shape_info(node: Node):
-    return not node.operation.parameters['input_shape'] and not node.operation.parameters['output_shape']
+    return not node.operation.attr['input_shape'] and not node.operation.attr['output_shape']
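_extract_info_from_trace_node now returns a pair: the dict of traced shapes bound for operation.attr, plus an optional dict of genuine parameters (currently only dim, for aten::cat nodes). A self-contained mock of the contract and of how a caller routes each piece, mirroring _trace_module above; the values are invented for clarity:

def consume(ir_attr: dict, ir_params: dict, trace_result):
    # shapes always go to attr; real parameters only when present
    shape_parameters, parameters = trace_result
    ir_attr.update(shape_parameters)
    if parameters:
        ir_params.update(parameters)

attr, params = {}, {}
# an aten::cat of two [1, 3, 8, 8] tensors along dim 1 yields [1, 6, 8, 8]
consume(attr, params, (
    {'input_shape': [[1, 3, 8, 8], [1, 3, 8, 8]], 'output_shape': [[1, 6, 8, 8]]},
    {'dim': 1},
))
print(attr, params)  # shapes land in attr, dim in params; params stays empty for other node kinds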
nni/retiarii/graph.py (6 changes: 3 additions & 3 deletions)

@@ -603,16 +603,16 @@ def _register(self) -> 'Node':
     @staticmethod
     def _load(graph: Graph, name: str, ir: Any) -> 'Node':
         if ir['operation']['type'] == '_cell':
-            op = Cell(ir['operation']['cell_name'], ir['operation'].get('parameters', {}))
+            op = Cell(ir['operation']['cell_name'], ir['operation'].get('parameters', {}), attr=ir['operation'].get('attr', {}))
         else:
-            op = Operation.new(ir['operation']['type'], ir['operation'].get('parameters', {}))
+            op = Operation.new(ir['operation']['type'], ir['operation'].get('parameters', {}), attr=ir['operation'].get('attr', {}))
         node = Node(graph, uid(), name, op)
         if 'label' in ir:
             node.update_label(ir['label'])
         return node
 
     def _dump(self) -> Any:
-        ret = {'operation': {'type': self.operation.type, 'parameters': self.operation.parameters}}
+        ret = {'operation': {'type': self.operation.type, 'parameters': self.operation.parameters, 'attr': self.operation.attr}}
         if isinstance(self.operation, Cell):
             ret['operation']['cell_name'] = self.operation.cell_name
         if self.label is not None:
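With _dump and _load updated symmetrically, a serialized node now round-trips its attr alongside parameters. A sketch of the resulting IR fragment for a Conv2d node; the parameter names and values are illustrative:

ir_fragment = {
    'operation': {
        'type': '__torch__.torch.nn.modules.conv.Conv2d',
        'parameters': {'in_channels': 3, 'out_channels': 1, 'kernel_size': 3},  # init args stay here
        'attr': {'input_shape': [[1, 3, 224, 224]], 'output_shape': [[1, 1, 222, 222]]},  # new in this PR
    }
}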
nni/retiarii/nn/pytorch/nn.py (2 changes: 1 addition & 1 deletion)

@@ -41,7 +41,7 @@
 
 Module = nn.Module
 
-Sequential = transparent_serialize(nn.Sequential)
+Sequential = nn.Sequential
 ModuleList = transparent_serialize(nn.ModuleList)
 
 Identity = basic_unit(nn.Identity)
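Dropping the transparent_serialize wrapper makes the retiarii Sequential a plain alias of torch's container, so it is handled by the MODULE_EXCEPT_LIST path in the converter instead of being serialized as one unit. Usage is unchanged for model authors; a quick check, assuming the package re-exports nn.py's names under the usual path:

import torch.nn as nn
import nni.retiarii.nn.pytorch as retiarii_nn  # assumed import path

assert retiarii_nn.Sequential is nn.Sequential  # plain alias after this PR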
nni/retiarii/operation.py (19 changes: 10 additions & 9 deletions)

@@ -34,10 +34,11 @@ class Operation:
     Arbitrary key-value parameters (e.g. kernel_size).
     """
 
-    def __init__(self, type_name: str, parameters: Dict[str, Any], _internal: bool = False):
+    def __init__(self, type_name: str, parameters: Dict[str, Any], _internal: bool = False, attr: Dict[str, Any] = {}):
         assert _internal, '`Operation()` is private, use `Operation.new()` instead'
         self.type: str = type_name
         self.parameters: Dict[str, Any] = parameters
+        self.attr: Dict[str, Any] = attr
 
     def to_init_code(self, field: str) -> str:
         raise NotImplementedError()
@@ -52,9 +53,10 @@ def __bool__(self) -> bool:
         return True
 
     @staticmethod
-    def new(type_name: str, parameters: Dict[str, Any] = None, cell_name: str = None) -> 'Operation':
-        if parameters is None:
-            parameters = {}
+    def new(type_name: str, parameters: Dict[str, Any] = None, cell_name: str = None,
+            attr: Dict[str, Any] = None) -> 'Operation':
+        parameters = parameters or {}
+        attr = attr or {}
         if type_name == '_cell':
             # NOTE: cell_name is the same as its Node's name, when the cell is wrapped within the node
             return Cell(cell_name, parameters)
@@ -67,7 +69,7 @@ def new(type_name: str, parameters: Dict[str, Any] = None, cell_name: str = None
             cls = TensorFlowOperation._find_subclass(type_name)
         else:
             raise ValueError(f'Unsupported framework: {debug_configs.framework}')
-        return cls(type_name, parameters, _internal=True)
+        return cls(type_name, parameters, _internal=True, attr=attr)
 
     @classmethod
     def _find_subclass(cls, subclass_name):
@@ -205,12 +207,11 @@ def forward(...):
         No real usage. Exists for compatibility with base class.
         """
 
-    def __init__(self, cell_name: str, parameters: Dict[str, Any] = None):
+    def __init__(self, cell_name: str, parameters: Dict[str, Any] = None, attr: Dict[str, Any] = None):
         self.type = '_cell'
         self.cell_name = cell_name
-        if parameters is None:
-            parameters = {}
-        self.parameters = parameters
+        self.parameters = parameters or {}
+        self.attr = attr or {}
 
     def _to_class_name(self):
         # TODO: ugly, think about how to refactor this part
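One caveat from the review thread at Operation.__init__: attr: Dict[str, Any] = {} is a mutable default, which Python shares across calls. Callers go through Operation.new, which always builds a fresh dict via attr = attr or {}, so the public path avoids the pitfall. A usage sketch of the extended factory, assuming the framework is configured as pytorch:

from nni.retiarii.operation import Operation

op = Operation.new(
    '__torch__.torch.nn.modules.conv.Conv2d',
    parameters={'in_channels': 3, 'out_channels': 1, 'kernel_size': 3},  # illustrative values
    attr={'input_shape': [[1, 3, 224, 224]], 'output_shape': [[1, 1, 222, 222]]},
)
print(op.attr['output_shape'])  # [[1, 1, 222, 222]]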
test/ut/retiarii/test_convert_shape.py (20 changes: 10 additions & 10 deletions)

@@ -24,12 +24,12 @@ def forward(self, x):
         conv_node = model_ir.get_nodes_by_type('__torch__.torch.nn.modules.conv.Conv2d')[0]
         relu_node = model_ir.get_nodes_by_type('__torch__.torch.nn.modules.activation.ReLU')[0]
         pool_node = model_ir.get_nodes_by_type('__torch__.torch.nn.modules.pooling.MaxPool2d')[0]
-        self.assertEqual(conv_node.operation.parameters.get('input_shape'), [[1, 3, 224, 224]])
-        self.assertEqual(conv_node.operation.parameters.get('output_shape'), [[1, 1, 222, 222]])
-        self.assertEqual(relu_node.operation.parameters.get('input_shape'), [[1, 1, 222, 222]])
-        self.assertEqual(relu_node.operation.parameters.get('output_shape'), [[1, 1, 222, 222]])
-        self.assertEqual(pool_node.operation.parameters.get('input_shape'), [[1, 1, 222, 222]])
-        self.assertEqual(pool_node.operation.parameters.get('output_shape'), [[1, 1, 111, 111]])
+        self.assertEqual(conv_node.operation.attr.get('input_shape'), [[1, 3, 224, 224]])
+        self.assertEqual(conv_node.operation.attr.get('output_shape'), [[1, 1, 222, 222]])
+        self.assertEqual(relu_node.operation.attr.get('input_shape'), [[1, 1, 222, 222]])
+        self.assertEqual(relu_node.operation.attr.get('output_shape'), [[1, 1, 222, 222]])
+        self.assertEqual(pool_node.operation.attr.get('input_shape'), [[1, 1, 222, 222]])
+        self.assertEqual(pool_node.operation.attr.get('output_shape'), [[1, 1, 111, 111]])
 
     def test_nested_module(self):
         class ConvRelu(nn.Module):
@@ -54,8 +54,8 @@ def forward(self, x):
 
         # check if shape propagation works
         cell_node = model_ir.get_nodes_by_type('_cell')[0]
-        self.assertEqual(cell_node.operation.parameters.get('input_shape'), [[1, 3, 224, 224]])
-        self.assertEqual(cell_node.operation.parameters.get('output_shape'), [[1, 1, 222, 222]])
+        self.assertEqual(cell_node.operation.attr.get('input_shape'), [[1, 3, 224, 224]])
+        self.assertEqual(cell_node.operation.attr.get('output_shape'), [[1, 1, 222, 222]])
 
     def test_layerchoice(self):
         class ConvNet(nn.Module):
@@ -75,5 +75,5 @@ def forward(self, x):
 
         # check shape info of each candidates
         conv_nodes = model_ir.get_nodes_by_type('__torch__.torch.nn.modules.conv.Conv2d')
-        self.assertEqual(conv_nodes[0].operation.parameters.get('output_shape'), [[1, 1, 222, 222]])
-        self.assertEqual(conv_nodes[1].operation.parameters.get('output_shape'), [[1, 1, 222, 222]])
+        self.assertEqual(conv_nodes[0].operation.attr.get('output_shape'), [[1, 1, 222, 222]])
+        self.assertEqual(conv_nodes[1].operation.attr.get('output_shape'), [[1, 1, 222, 222]])
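These updated assertions pin the new attr location of the shape info. To run just this test module locally, a standard pytest invocation from an nni checkout should suffice:

import pytest

# equivalent to: python -m pytest -v test/ut/retiarii/test_convert_shape.py
pytest.main(['-v', 'test/ut/retiarii/test_convert_shape.py'])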