From 6af99c553d5d2b54ec6f5235890cf827fa0d1420 Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Fri, 9 Jul 2021 12:35:09 +0800 Subject: [PATCH 1/4] Multi-trial example for SPOS (demo for nnmeter development) (#3876) --- examples/nas/oneshot/spos/blocks.py | 25 ++-- examples/nas/oneshot/spos/multi_trial.py | 164 +++++++++++++++++++++++ 2 files changed, 177 insertions(+), 12 deletions(-) create mode 100644 examples/nas/oneshot/spos/multi_trial.py diff --git a/examples/nas/oneshot/spos/blocks.py b/examples/nas/oneshot/spos/blocks.py index e07a43f5f0..d5410d8403 100644 --- a/examples/nas/oneshot/spos/blocks.py +++ b/examples/nas/oneshot/spos/blocks.py @@ -27,17 +27,18 @@ def __init__(self, inp, oup, mid_channels, ksize, stride, sequence="pdp", affine self.branch_main = nn.Sequential(*self._decode_point_depth_conv(sequence)) - if stride == 2: - self.branch_proj = nn.Sequential( - # dw - nn.Conv2d(self.channels, self.channels, ksize, stride, self.pad, - groups=self.channels, bias=False), - nn.BatchNorm2d(self.channels, affine=affine), - # pw-linear - nn.Conv2d(self.channels, self.channels, 1, 1, 0, bias=False), - nn.BatchNorm2d(self.channels, affine=affine), - nn.ReLU(inplace=True) - ) + # FIXME: restore before merging into master + # remove if stride == 2 for torchscript + self.branch_proj = nn.Sequential( + # dw + nn.Conv2d(self.channels, self.channels, ksize, stride, self.pad, + groups=self.channels, bias=False), + nn.BatchNorm2d(self.channels, affine=affine), + # pw-linear + nn.Conv2d(self.channels, self.channels, 1, 1, 0, bias=False), + nn.BatchNorm2d(self.channels, affine=affine), + nn.ReLU(inplace=True) + ) def forward(self, x): if self.stride == 2: @@ -76,7 +77,7 @@ def _decode_point_depth_conv(self, sequence): return result def _channel_shuffle(self, x): - bs, num_channels, height, width = x.data.size() + bs, num_channels, height, width = x.size() assert (num_channels % 4 == 0) x = x.reshape(bs * num_channels // 2, 2, height * width) x = x.permute(1, 0, 2) diff --git a/examples/nas/oneshot/spos/multi_trial.py b/examples/nas/oneshot/spos/multi_trial.py new file mode 100644 index 0000000000..97475055a9 --- /dev/null +++ b/examples/nas/oneshot/spos/multi_trial.py @@ -0,0 +1,164 @@ +import click +import nni.retiarii.evaluator.pytorch as pl +import nni.retiarii.nn.pytorch as nn +import nni.retiarii.strategy as strategy +import torch +from nni.retiarii import serialize +from nni.retiarii.nn.pytorch import LayerChoice +from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment +from torchvision import transforms +from torchvision.datasets import CIFAR10 + +from blocks import ShuffleNetBlock, ShuffleXceptionBlock + + +class ShuffleNetV2(nn.Module): + block_keys = [ + 'shufflenet_3x3', + 'shufflenet_5x5', + 'shufflenet_7x7', + 'xception_3x3', + ] + + def __init__(self, input_size=224, first_conv_channels=16, last_conv_channels=1024, n_classes=1000, affine=False): + super().__init__() + + assert input_size % 32 == 0 + + self.stage_blocks = [4, 4, 8, 4] + self.stage_channels = [64, 160, 320, 640] + self._parsed_flops = dict() + self._input_size = input_size + self._feature_map_size = input_size + self._first_conv_channels = first_conv_channels + self._last_conv_channels = last_conv_channels + self._n_classes = n_classes + self._affine = affine + + # building first layer + self.first_conv = nn.Sequential( + nn.Conv2d(3, first_conv_channels, 3, 2, 1, bias=False), + nn.BatchNorm2d(first_conv_channels, affine=affine), + nn.ReLU(inplace=True), + ) + self._feature_map_size //= 2 + + 
p_channels = first_conv_channels + features = [] + for num_blocks, channels in zip(self.stage_blocks, self.stage_channels): + features.extend(self._make_blocks(num_blocks, p_channels, channels)) + p_channels = channels + self.features = nn.Sequential(*features) + + self.conv_last = nn.Sequential( + nn.Conv2d(p_channels, last_conv_channels, 1, 1, 0, bias=False), + nn.BatchNorm2d(last_conv_channels, affine=affine), + nn.ReLU(inplace=True), + ) + self.globalpool = nn.AvgPool2d(self._feature_map_size) + self.dropout = nn.Dropout(0.1) + self.classifier = nn.Sequential( + nn.Linear(last_conv_channels, n_classes, bias=False), + ) + + self._initialize_weights() + + def _make_blocks(self, blocks, in_channels, channels): + result = [] + for i in range(blocks): + stride = 2 if i == 0 else 1 + inp = in_channels if i == 0 else channels + oup = channels + + base_mid_channels = channels // 2 + mid_channels = int(base_mid_channels) # prepare for scale + choice_block = LayerChoice([ + serialize(ShuffleNetBlock, inp, oup, mid_channels=mid_channels, ksize=3, stride=stride, affine=self._affine), + serialize(ShuffleNetBlock, inp, oup, mid_channels=mid_channels, ksize=5, stride=stride, affine=self._affine), + serialize(ShuffleNetBlock, inp, oup, mid_channels=mid_channels, ksize=7, stride=stride, affine=self._affine), + serialize(ShuffleXceptionBlock, inp, oup, mid_channels=mid_channels, stride=stride, affine=self._affine) + ]) + result.append(choice_block) + + if stride == 2: + self._feature_map_size //= 2 + return result + + def forward(self, x): + bs = x.size(0) + x = self.first_conv(x) + x = self.features(x) + x = self.conv_last(x) + x = self.globalpool(x) + + x = self.dropout(x) + x = x.contiguous().view(bs, -1) + x = self.classifier(x) + return x + + def _initialize_weights(self): + # FIXME this won't work in base engine + for name, m in self.named_modules(): + if isinstance(m, nn.Conv2d): + if 'first' in name: + torch.nn.init.normal_(m.weight, 0, 0.01) + else: + torch.nn.init.normal_(m.weight, 0, 1.0 / m.weight.shape[1]) + if m.bias is not None: + torch.nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + if m.weight is not None: + torch.nn.init.constant_(m.weight, 1) + if m.bias is not None: + torch.nn.init.constant_(m.bias, 0.0001) + torch.nn.init.constant_(m.running_mean, 0) + elif isinstance(m, nn.BatchNorm1d): + torch.nn.init.constant_(m.weight, 1) + if m.bias is not None: + torch.nn.init.constant_(m.bias, 0.0001) + torch.nn.init.constant_(m.running_mean, 0) + elif isinstance(m, nn.Linear): + torch.nn.init.normal_(m.weight, 0, 0.01) + if m.bias is not None: + torch.nn.init.constant_(m.bias, 0) + + +@click.command() +@click.option('--port', default=8081, help='On which port the experiment is run.') +def _main(port): + base_model = ShuffleNetV2(32) + transf = [ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip() + ] + normalize = [ + transforms.ToTensor(), + transforms.Normalize([0.49139968, 0.48215827, 0.44653124], [0.24703233, 0.24348505, 0.26158768]) + ] + train_dataset = serialize(CIFAR10, 'data', train=True, download=True, transform=transforms.Compose(transf + normalize)) + test_dataset = serialize(CIFAR10, 'data', train=False, transform=transforms.Compose(normalize)) + + trainer = pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=64), + val_dataloaders=pl.DataLoader(test_dataset, batch_size=64), + max_epochs=2, gpus=1) + + simple_strategy = strategy.Random() + + exp = RetiariiExperiment(base_model, trainer, [], simple_strategy) + + 
exp_config = RetiariiExeConfig('local')
+    exp_config.trial_concurrency = 2
+    exp_config.max_trial_number = 2
+    exp_config.trial_gpu_number = 1
+    exp_config.training_service.use_active_gpu = False
+    exp_config.execution_engine = 'base'
+
+    exp.run(exp_config, port)
+
+    print('Exported models:')
+    for model in exp.export_top_models(formatter='dict'):
+        print(model)
+
+
+if __name__ == '__main__':
+    _main()

From 99aa8226c67a4a9097a029ebf8700894f9314c05 Mon Sep 17 00:00:00 2001
From: kalineid
Date: Thu, 15 Jul 2021 16:01:16 +0800
Subject: [PATCH 2/4] [Retiarii]: Add info required by nn-meter to graph ir
 (#3910)

* Fix mutable default
* LayerChoice.forward now runs the first candidate by default to support trace (#3910)
* New GraphConverter to parse shape info required by nn-meter (#3910)
* Support model filter in Random strategy
* Support latency-aware search in SPOS multi-trial example
* Fix for review (#3910)
* Add doc for hardware-aware NAS
* Fix lint python & Add nn_meter to sphinx mock
* Add comments
* Move LatencyFilter to examples
* Move example inputs into configs
* Support nested layer choice

Co-authored-by: Jianyu Wei
Co-authored-by: kalineid
Co-authored-by: Yuge Zhang
Co-authored-by: Yuge Zhang
---
 docs/en_US/NAS/HardwareAwareNAS.rst      |  41 ++++
 docs/en_US/NAS/multi_trial_nas.rst       |   1 +
 docs/en_US/conf.py                       |   2 +-
 examples/nas/oneshot/spos/blocks.py      |   3 +-
 examples/nas/oneshot/spos/multi_trial.py |  44 +++-
 nni/retiarii/converter/graph_gen.py      | 274 ++++++++++++++++++++---
 nni/retiarii/converter/utils.py          |  94 ++++++++
 nni/retiarii/experiment/pytorch.py       |  17 +-
 nni/retiarii/graph.py                    |  12 +-
 nni/retiarii/nn/pytorch/api.py           |   3 +-
 nni/retiarii/operation.py                |   8 +-
 nni/retiarii/strategy/bruteforce.py      |  14 +-
 nni/retiarii/strategy/utils.py           |  17 ++
 test/ut/retiarii/test_highlevel_apis.py  |  27 +++
 14 files changed, 497 insertions(+), 60 deletions(-)
 create mode 100644 docs/en_US/NAS/HardwareAwareNAS.rst

diff --git a/docs/en_US/NAS/HardwareAwareNAS.rst b/docs/en_US/NAS/HardwareAwareNAS.rst
new file mode 100644
index 0000000000..81eaeb7393
--- /dev/null
+++ b/docs/en_US/NAS/HardwareAwareNAS.rst
@@ -0,0 +1,41 @@
+Hardware-aware NAS
+==================
+
+.. contents::
+
+End-to-End Multi-trial SPOS Demo
+--------------------------------
+
+This demo selects only the models whose predicted latency satisfies the constraint for training.
+
+To run this demo, first install nn-Meter from source code (the package has not been released yet, so a development installation is required):
+
+.. code-block:: bash
+
+  python setup.py develop
+
+Then run the multi-trial SPOS demo:
+
+.. code-block:: bash
+
+  python ${NNI_ROOT}/examples/nas/oneshot/spos/multi_trial.py
+
+How the demo works
+------------------
+
+To support latency-aware NAS, you first need a `Strategy` that supports filtering the models by latency. We provide such a filter named `LatencyFilter` in NNI and initialize a `Random` strategy with the filter:
+
+.. code-block:: python
+
+  simple_strategy = strategy.Random(model_filter=LatencyFilter(100))
+
+``LatencyFilter`` will predict the models' latency by using nn-Meter and filter out the models whose latency are larger than the threshold (i.e., ``100`` in this example).
+You can also build your own strategies and filters to support more flexible NAS such as sorting the models according to latency.
+
+Then, pass this strategy to ``RetiariiExperiment`` along with some additional arguments: ``parse_shape=True, example_inputs=example_inputs``:
+
+.. code-block:: python
+
+   RetiariiExperiment(base_model, trainer, [], simple_strategy, True, example_inputs)
+
+Here, ``parse_shape=True`` means extracting shape info from the torch model as it is required by nn-Meter to predict latency. ``example_inputs`` is required for tracing shape info.
diff --git a/docs/en_US/NAS/multi_trial_nas.rst b/docs/en_US/NAS/multi_trial_nas.rst
index 9e77216fb8..8914038bd8 100644
--- a/docs/en_US/NAS/multi_trial_nas.rst
+++ b/docs/en_US/NAS/multi_trial_nas.rst
@@ -11,3 +11,4 @@ In multi-trial NAS, users need model evaluator to evaluate the performance of ea
    Exploration Strategies
    Customize Exploration Strategies
    Execution Engines
+   Hardware-aware NAS
diff --git a/docs/en_US/conf.py b/docs/en_US/conf.py
index 6df286bde0..3f11cece89 100644
--- a/docs/en_US/conf.py
+++ b/docs/en_US/conf.py
@@ -51,7 +51,7 @@
 ]
 
 # Add mock modules
-autodoc_mock_imports = ['apex', 'nni_node', 'tensorrt', 'pycuda']
+autodoc_mock_imports = ['apex', 'nni_node', 'tensorrt', 'pycuda', 'nn_meter']
 
 # Add any paths that contain templates here, relative to this directory.
 templates_path = ['_templates']
diff --git a/examples/nas/oneshot/spos/blocks.py b/examples/nas/oneshot/spos/blocks.py
index d5410d8403..0c7e5c8ed7 100644
--- a/examples/nas/oneshot/spos/blocks.py
+++ b/examples/nas/oneshot/spos/blocks.py
@@ -2,7 +2,7 @@
 # Licensed under the MIT license.
 
 import torch
-import torch.nn as nn
+import nni.retiarii.nn.pytorch as nn
 
 
 class ShuffleNetBlock(nn.Module):
@@ -78,7 +78,6 @@ def _decode_point_depth_conv(self, sequence):
 
     def _channel_shuffle(self, x):
         bs, num_channels, height, width = x.size()
-        assert (num_channels % 4 == 0)
         x = x.reshape(bs * num_channels // 2, 2, height * width)
         x = x.permute(1, 0, 2)
         x = x.reshape(2, -1, num_channels // 2, height, width)
diff --git a/examples/nas/oneshot/spos/multi_trial.py b/examples/nas/oneshot/spos/multi_trial.py
index 97475055a9..730e688142 100644
--- a/examples/nas/oneshot/spos/multi_trial.py
+++ b/examples/nas/oneshot/spos/multi_trial.py
@@ -1,3 +1,5 @@
+# This file demonstrates multi-trial NAS on the SPOS search space.
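+#
+# A typical invocation (assuming nn-Meter and its default latency predictors are installed):
+#
+#     python multi_trial.py --port 8081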
+
 import click
 import nni.retiarii.evaluator.pytorch as pl
 import nni.retiarii.nn.pytorch as nn
 import nni.retiarii.strategy as strategy
 import torch
@@ -11,6 +13,8 @@
 
 from blocks import ShuffleNetBlock, ShuffleXceptionBlock
 
+from nn_meter import get_default_config, load_latency_predictors
+
 
 class ShuffleNetV2(nn.Module):
     block_keys = [
@@ -73,10 +77,10 @@ def _make_blocks(self, blocks, in_channels, channels):
         base_mid_channels = channels // 2
         mid_channels = int(base_mid_channels)  # prepare for scale
         choice_block = LayerChoice([
-            serialize(ShuffleNetBlock, inp, oup, mid_channels=mid_channels, ksize=3, stride=stride, affine=self._affine),
-            serialize(ShuffleNetBlock, inp, oup, mid_channels=mid_channels, ksize=5, stride=stride, affine=self._affine),
-            serialize(ShuffleNetBlock, inp, oup, mid_channels=mid_channels, ksize=7, stride=stride, affine=self._affine),
-            serialize(ShuffleXceptionBlock, inp, oup, mid_channels=mid_channels, stride=stride, affine=self._affine)
+            ShuffleNetBlock(inp, oup, mid_channels=mid_channels, ksize=3, stride=stride, affine=self._affine),
+            ShuffleNetBlock(inp, oup, mid_channels=mid_channels, ksize=5, stride=stride, affine=self._affine),
+            ShuffleNetBlock(inp, oup, mid_channels=mid_channels, ksize=7, stride=stride, affine=self._affine),
+            ShuffleXceptionBlock(inp, oup, mid_channels=mid_channels, stride=stride, affine=self._affine)
         ])
         result.append(choice_block)
 
@@ -123,6 +127,35 @@ def _initialize_weights(self):
                     torch.nn.init.constant_(m.bias, 0)
 
 
+class LatencyFilter:
+    def __init__(self, threshold, config=None, hardware='', reverse=False):
+        """
+        Filter the models according to predicted latency.
+
+        Parameters
+        ----------
+        threshold: `float`
+            the latency threshold
+        config, hardware:
+            determine the targeted device
+        reverse: `bool`
+            if reverse is `False`, the filter returns `True` when `latency < threshold`;
+            otherwise the comparison is reversed
+        """
+        default_config, default_hardware = get_default_config()
+        if config is None:
+            config = default_config
+        if not hardware:
+            hardware = default_hardware
+
+        self.predictors = load_latency_predictors(config, hardware)
+        self.threshold = threshold
+
+    def __call__(self, ir_model):
+        latency = self.predictors.predict(ir_model, 'nni')
+        return latency < self.threshold
+
+
 @click.command()
 @click.option('--port', default=8081, help='On which port the experiment is run.')
 def _main(port):
@@ -142,7 +175,7 @@ def _main(port):
                                 val_dataloaders=pl.DataLoader(test_dataset, batch_size=64),
                                 max_epochs=2, gpus=1)
 
-    simple_strategy = strategy.Random()
+    simple_strategy = strategy.Random(model_filter=LatencyFilter(100))
 
     exp = RetiariiExperiment(base_model, trainer, [], simple_strategy)
 
@@ -152,6 +185,7 @@ def _main(port):
     exp_config.trial_gpu_number = 1
     exp_config.training_service.use_active_gpu = False
     exp_config.execution_engine = 'base'
+    exp_config.example_inputs = [1, 3, 32, 32]
 
     exp.run(exp_config, port)
 
diff --git a/nni/retiarii/converter/graph_gen.py b/nni/retiarii/converter/graph_gen.py
index f8b06b887a..e402082fc1 100644
--- a/nni/retiarii/converter/graph_gen.py
+++ b/nni/retiarii/converter/graph_gen.py
@@ -5,13 +5,17 @@
 
 import torch
 
-from ..graph import Graph, Model, Node
-from ..nn.pytorch import InputChoice, Placeholder
+from ..graph import Graph, Model, Node, Edge
+from ..nn.pytorch import InputChoice, Placeholder, LayerChoice
 from ..operation import Cell, Operation
 from ..serializer import get_init_parameters_or_fail
 from ..utils import get_importable_name
 from .op_types import MODULE_EXCEPT_LIST, OpTypeName
-from .utils import _convert_name, build_full_name
+from .utils import (
+ _convert_name, build_full_name, _without_shape_info, + _extract_info_from_trace_node, get_full_name_by_scope_name, + is_layerchoice_node, match_node, build_cand_name +) class GraphConverter: @@ -305,9 +309,9 @@ def handle_function_callmethod(node): submodule_full_name = build_full_name(module_name, submodule_name) submodule_obj = getattr(module, submodule_name) - subgraph, sub_m_attrs = self.convert_module(script_module._modules[submodule_name], - submodule_obj, - submodule_full_name, ir_model) + subgraph, sub_m_attrs = self._convert_module(script_module._modules[submodule_name], + submodule_obj, + submodule_full_name, ir_model) else: # %8 : __torch__.nni.retiarii.model_apis.nn.___torch_mangle_37.ModuleList = prim::GetAttr[name="cells"](%self) # %10 : __torch__.darts_model.Cell = prim::GetAttr[name="0"](%8) @@ -339,7 +343,7 @@ def handle_function_callmethod(node): for each_name in list(reversed(module_name_space)): submodule_obj = getattr(submodule_obj, each_name) script_submodule = script_submodule._modules[each_name] - subgraph, sub_m_attrs = self.convert_module(script_submodule, submodule_obj, submodule_full_name, ir_model) + subgraph, sub_m_attrs = self._convert_module(script_submodule, submodule_obj, submodule_full_name, ir_model) else: raise RuntimeError('Unsupported module case: {}'.format(submodule.inputsAt(0).type().str())) @@ -566,29 +570,7 @@ def _handle_valuechoice(self, module): 'accessor': module._accessor } - def convert_module(self, script_module, module, module_name, ir_model): - """ - Convert a module to its graph ir (i.e., Graph) along with its input arguments - - Parameters - ---------- - script_module : torch.jit.RecursiveScriptModule - the script module of ```module``` obtained with torch.jit.script - module : nn.Module - the targeted module instance - module_name : str - the constructed name space of ```module``` - ir_model : Model - the whole graph ir - - Returns - ------- - Graph - the built graph ir from module, ```None``` means do not further parse the module - dict - the input arguments of this module - """ - + def _convert_module(self, script_module, module, module_name, ir_model): # NOTE: have not supported nested LayerChoice, i.e., a candidate module # also has LayerChoice or InputChoice or ValueChoice original_type_name = script_module.original_name @@ -597,10 +579,18 @@ def convert_module(self, script_module, module, module_name, ir_model): pass # do nothing elif original_type_name == OpTypeName.LayerChoice: graph = Graph(ir_model, -100, module_name, _internal=True) # graph_id is not used now - candidate_name_list = [f'layerchoice_{module.label}_{cand_name}' for cand_name in module.names] - for cand_name, cand in zip(candidate_name_list, module): - cand_type = '__torch__.' + get_importable_name(cand.__class__) - graph.add_node(cand_name, cand_type, get_init_parameters_or_fail(cand)) + candidate_name_list = [] + for cand_name in module.names: + cand = module[cand_name] + script_cand = script_module._modules[cand_name] + cand_name = build_cand_name(cand_name, module.label) + candidate_name_list.append(cand_name) + subgraph, attrs = self._convert_module(script_cand, cand, cand_name, ir_model) + if subgraph is not None: + graph.add_node(subgraph.name, Cell(cell_name=subgraph.name, parameters=attrs)) + else: + cand_type = '__torch__.' 
+ get_importable_name(cand.__class__)
+                graph.add_node(cand_name, cand_type, attrs)
         graph._register()
         return graph, {'mutation': 'layerchoice', 'label': module.label, 'candidates': candidate_name_list}
     elif original_type_name == OpTypeName.InputChoice:
@@ -654,8 +644,214 @@ def convert_module(self, script_module, module, module_name, ir_model):
 
         return ir_graph, {}
 
+    def convert_module(self, script_module, module, module_name, ir_model):
+        """
+        Convert a module to its graph ir (i.e., Graph) along with its input arguments
 
-def convert_to_graph(script_module, module):
+        Parameters
+        ----------
+        script_module : torch.jit.RecursiveScriptModule
+            the script module of ```module``` obtained with torch.jit.script
+        module : nn.Module
+            the targeted module instance
+        module_name : str
+            the constructed name space of ```module```
+        ir_model : Model
+            the whole graph ir
+
+        Returns
+        -------
+        Graph
+            the built graph ir from module, ```None``` means do not further parse the module
+        dict
+            the input arguments of this module
+        """
+
+        return self._convert_module(script_module, module, module_name, ir_model)
+
+
+class GraphConverterWithShape(GraphConverter):
+    """
+    Convert a PyTorch model to NNI IR along with input/output shape info.
+    The base IR is acquired through `torch.jit.script`;
+    shape info is acquired through `torch.jit.trace`.
+
+    Known issues
+    ------------
+    1. `InputChoice` and `ValueChoice` not supported yet.
+    2. Currently random inputs are fed while tracing layerchoice.
+       If forward path of candidates depends on input data, then wrong path will be traced.
+       This will result in incomplete shape info.
+    """
+    def convert_module(self, script_module, module, module_name, ir_model, example_inputs):
+        module.eval()
+
+        ir_graph, attrs = self._convert_module(script_module, module, module_name, ir_model)
+        self.remove_dummy_nodes(ir_model)
+        self._initialize_parameters(ir_model)
+        self._trace_module(module, module_name, ir_model, example_inputs)
+        return ir_graph, attrs
+
+    def _initialize_parameters(self, ir_model: 'Model'):
+        for ir_node in ir_model.get_nodes():
+            if ir_node.operation.parameters is None:
+                ir_node.operation.parameters = {}
+            ir_node.operation.parameters.setdefault('input_shape', [])
+            ir_node.operation.parameters.setdefault('output_shape', [])
+
+    def _trace_module(self, module, module_name, ir_model: 'Model', example_inputs):
+        # First, trace the whole graph
+        tm_graph = self._trace(module, example_inputs)
+
+        for node in tm_graph.nodes():
+            parameters = _extract_info_from_trace_node(node)
+            # scope names look like '__module.convpool/__module.convpool.1/__module.convpool.1.conv'
+            ir_node = match_node(ir_model, node, module_name)
+            if ir_node is not None:
+                ir_node.operation.parameters.update(parameters)
+
+        self.propagate_shape(ir_model)
+
+        # trace each layerchoice
+        for name, submodule in module.named_modules():
+            # TODO: support InputChoice and ValueChoice
+            if isinstance(submodule, LayerChoice):
+                full_name = get_full_name_by_scope_name(ir_model, name.split('.'), module_name)
+                lc_node = ir_model.get_node_by_name(full_name)
+
+                for cand_name in submodule.names:
+                    cand = submodule[cand_name]
+                    cand_name = build_cand_name(cand_name, submodule.label)
+                    # TODO: Feed the exact input tensor if user provides input,
+                    # in case the path changes according to input data.
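+                    # NOTE: lc_inputs reuses the input shapes recorded for this LayerChoice node
+                    # during the whole-graph trace; each candidate is then traced with fresh
+                    # random tensors of those shapes (hence known issue 2 above).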
+ lc_inputs = [torch.randn(shape) for shape in lc_node.operation.parameters['input_shape']] + self._trace_module(cand, cand_name, ir_model, lc_inputs) + + def propagate_shape(self, ir_model: 'Model'): + + def propagate_shape_for_graph(graph: 'Graph'): + if graph == ir_model.root_graph: + return + + graph_node = ir_model.get_node_by_name(graph.name) + if not _without_shape_info(graph_node): + return + + if is_layerchoice_node(graph_node): + cand_name = graph_node.operation.parameters['candidates'][0] + cand_node = ir_model.get_node_by_name(cand_name) + if _without_shape_info(cand_node): + propagate_shape_for_graph(ir_model.graphs[cand_name]) + graph_node.operation.parameters['input_shape'] = cand_node.operation.parameters['input_shape'] + graph_node.operation.parameters['output_shape'] = cand_node.operation.parameters['output_shape'] + else: + input_shape = [[]] * len(graph.input_node.operation.io_names or []) + output_shape = [[]] * len(graph.output_node.operation.io_names or []) + for edge in graph.input_node.outgoing_edges: + node = edge.tail + if _without_shape_info(node): + if node.name in ir_model.graphs: + propagate_shape_for_graph(ir_model.graphs[node.name]) + if node.operation.parameters['input_shape']: + input_shape[edge.head_slot or 0] = node.operation.parameters['input_shape'][edge.tail_slot or 0] + graph_node.operation.parameters['input_shape'] = input_shape + for edge in graph.output_node.incoming_edges: + node = edge.head + if _without_shape_info(node): + if node.name in ir_model.graphs: + propagate_shape_for_graph(ir_model.graphs[node.name]) + if node.operation.parameters['output_shape']: + output_shape[edge.tail_slot or 0] = node.operation.parameters['output_shape'][edge.head_slot or 0] + graph_node.operation.parameters['output_shape'] = output_shape + + propagate_shape_for_graph(graph_node.graph) + + # propagate from node to graph + for node in ir_model.get_nodes(): + propagate_shape_for_graph(node.graph) + + def flatten(self, ir_model: 'Model'): + """ + Flatten the subgraph into root graph. 
+ """ + def _flatten(graph: 'Graph'): + """ + flatten this graph + """ + model = graph.model + node_to_remove = [] + + for node in graph.hidden_nodes: + node_graph = model.graphs.get(node.name) + if node_graph is not None: + _flatten(node_graph) + + # flatten node graph into this graph + id_to_new_node = {} + for node_graph_node in node_graph.hidden_nodes: + new_node = Node(graph, node_graph_node.id, node_graph_node.name, node_graph_node.operation, _internal=True) + new_node.update_label(node_graph_node.label) + new_node._register() + id_to_new_node[new_node.id] = new_node + + # reconnect node edges + for in_edge in node.incoming_edges: + graph.del_edge(in_edge) + for input_node_edge in node_graph.input_node.outgoing_edges: + if input_node_edge.head_slot == in_edge.tail_slot: + graph.add_edge( + head=(in_edge.head, in_edge.head_slot), + tail=(id_to_new_node[input_node_edge.tail.id], input_node_edge.tail_slot)) + + for out_edge in node.outgoing_edges: + graph.del_edge(out_edge) + for output_node_edge in node_graph.output_node.incoming_edges: + if output_node_edge.head_slot == out_edge.tail_slot: + try: + graph.add_edge( + head=(id_to_new_node[output_node_edge.head.id], output_node_edge.head_slot), + tail=(out_edge.tail, out_edge.tail_slot)) + except: + import pdb; pdb.set_trace() + + for edge in node_graph.edges: + if edge.head == node_graph.input_node or edge.tail == node_graph.output_node: + continue + new_head = id_to_new_node[edge.head.id] + new_tail = id_to_new_node[edge.tail.id] + Edge((new_head, edge.head_slot), (new_tail, edge.tail_slot), _internal=True)._register() + + node_to_remove.append(node) + del model.graphs[node.name] + + for node in node_to_remove: + node.remove() + + _flatten(ir_model.root_graph) + + # remove subgraphs + ir_model.graphs = {ir_model._root_graph_name: ir_model.root_graph} + + def _trace(self, module, example_inputs): + traced_module = torch.jit.trace(module, example_inputs) + torch._C._jit_pass_inline(traced_module.graph) + return traced_module.graph + + def remove_dummy_nodes(self, ir_model: 'Model'): + # remove identity nodes + for node in ir_model.get_nodes_by_type('noop_identity'): + graph = node.graph + for in_edge in node.incoming_edges: + for out_edge in node.outgoing_edges: + if in_edge.tail_slot == out_edge.head_slot: + graph.add_edge(head=(in_edge.head, in_edge.head_slot), tail=(out_edge.tail, out_edge.tail_slot)) + graph.del_edge(in_edge) + graph.del_edge(out_edge) + break + node.remove() + + +def convert_to_graph(script_module, module, converter=None, **kwargs): """ Convert module to our graph ir, i.e., build a ```Model``` type @@ -665,6 +861,10 @@ def convert_to_graph(script_module, module): the script module obtained with torch.jit.script module : nn.Module the targeted module instance + converter : `TorchConverter` + default `GraphConverter` is used + kwargs: + will be passed to `converter.convert_module()` Returns ------- @@ -674,6 +874,8 @@ def convert_to_graph(script_module, module): model = Model(_internal=True) module_name = '_model' - GraphConverter().convert_module(script_module, module, module_name, model) + if converter is None: + converter = GraphConverter() + converter.convert_module(script_module, module, module_name, model, **kwargs) return model diff --git a/nni/retiarii/converter/utils.py b/nni/retiarii/converter/utils.py index c43f62176e..8b47d7b983 100644 --- a/nni/retiarii/converter/utils.py +++ b/nni/retiarii/converter/utils.py @@ -1,6 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
+from ..operation import Cell +from ..graph import Model, Node + + def build_full_name(prefix, name, seq=None): if isinstance(name, list): name = '__'.join(name) @@ -10,8 +14,98 @@ def build_full_name(prefix, name, seq=None): return '{}__{}{}'.format(prefix, name, str(seq)) +def build_cand_name(name, label): + return f'layerchoice_{label}_{name}' + + def _convert_name(name: str) -> str: """ Convert the names using separator '.' to valid variable name in code """ return name.replace('.', '__') + + +def _extract_info_from_trace_node(trace_node): + """ + Extract parameters from a trace node. + + Parameters + ---------- + trace_node: torch._C.Value + """ + input_shape = [] + output_shape = [] + + inputs = list(trace_node.inputs()) + + # cat input tensors are in a strange place + if trace_node.kind() == 'aten::cat': + input_shape = [input.type().sizes() for input in inputs[0].node().inputs()] + else: + for _input in inputs: + input_type = _input.type() + if input_type.kind() == 'TensorType': + shape = input_type.sizes() + if shape: + input_shape.append(shape) + + for _output in trace_node.outputs(): + output_type = _output.type() + if output_type.kind() == 'TensorType': + shape = output_type.sizes() + if shape: + output_shape.append(shape) + + parameters = { + 'input_shape': input_shape, + 'output_shape': output_shape, + } + + if trace_node.kind() == 'aten::cat': + parameters['dim'] = inputs[1].toIValue() + + return parameters + + +def is_layerchoice_node(ir_node: Node): + if ir_node is not None and isinstance(ir_node.operation, Cell) and ir_node.operation.parameters.get('mutation') == 'layerchoice': + return True + else: + return False + + +def get_full_name_by_scope_name(ir_model: Model, scope_names, prefix=''): + full_name = prefix + + for last_scope in range(len(scope_names)): + ir_node = ir_model.get_node_by_name(full_name) + # check if it's layerchoice + if is_layerchoice_node(ir_node): + full_name = f'layerchoice_{ir_node.operation.parameters["label"]}_{scope_names[last_scope]}' + else: + full_name = build_full_name(full_name, scope_names[last_scope]) + + return full_name + + +def match_node(ir_model: Model, torch_node, prefix=''): + """ + Match the corresponding node of a torch._C.Value + """ + scope_names = torch_node.scopeName().split('/')[-1].split('.')[1:] + full_name = get_full_name_by_scope_name(ir_model, scope_names, prefix) + # handle the case when node is not nn.Module, but directly used in forward() + # Because name can't be directly matched, so I use a hacky way. 
+ # I match the first unshaped node of that kind + graph = ir_model.graphs.get(full_name) + if graph is not None: + for node in graph.get_nodes_by_type(torch_node.kind()): + if not node.operation.parameters['input_shape']: + return node + return None + else: + return ir_model.get_node_by_name(full_name) + + +def _without_shape_info(node: Node): + return not node.operation.parameters['input_shape'] and not node.operation.parameters['output_shape'] diff --git a/nni/retiarii/experiment/pytorch.py b/nni/retiarii/experiment/pytorch.py index 0780439cff..5ce34db0c9 100644 --- a/nni/retiarii/experiment/pytorch.py +++ b/nni/retiarii/experiment/pytorch.py @@ -28,6 +28,7 @@ from ..codegen import model_to_pytorch_script from ..converter import convert_to_graph +from ..converter.graph_gen import GraphConverterWithShape from ..execution import list_models, set_execution_engine from ..execution.python import get_mutation_dict from ..graph import Model, Evaluator @@ -58,6 +59,9 @@ class RetiariiExeConfig(ConfigBase): training_service: TrainingServiceConfig execution_engine: str = 'py' + # input used in GraphConverterWithShape. Currently support shape tuple only. + example_inputs: Optional[List[int]] = None + def __init__(self, training_service_platform: Optional[str] = None, **kwargs): super().__init__(**kwargs) if training_service_platform is not None: @@ -106,7 +110,7 @@ def _validation_rules(self): 'training_service': lambda value: (type(value) is not TrainingServiceConfig, 'cannot be abstract base class') } -def preprocess_model(base_model, trainer, applied_mutators, full_ir=True): +def preprocess_model(base_model, trainer, applied_mutators, full_ir=True, example_inputs=None): # TODO: this logic might need to be refactored into execution engine if full_ir: try: @@ -114,7 +118,13 @@ def preprocess_model(base_model, trainer, applied_mutators, full_ir=True): except Exception as e: _logger.error('Your base model cannot be parsed by torch.jit.script, please fix the following error:') raise e - base_model_ir = convert_to_graph(script_module, base_model) + if example_inputs is not None: + # FIXME: this is a workaround as full tensor is not supported in configs + example_inputs = torch.randn(*example_inputs) + converter = GraphConverterWithShape() + base_model_ir = convert_to_graph(script_module, base_model, converter, example_inputs=example_inputs) + else: + base_model_ir = convert_to_graph(script_module, base_model) # handle inline mutations mutators = process_inline_mutation(base_model_ir) else: @@ -171,7 +181,8 @@ def __init__(self, base_model: nn.Module, trainer: Union[Evaluator, BaseOneShotT def _start_strategy(self): base_model_ir, self.applied_mutators = preprocess_model( - self.base_model, self.trainer, self.applied_mutators, full_ir=self.config.execution_engine != 'py') + self.base_model, self.trainer, self.applied_mutators, full_ir=self.config.execution_engine != 'py', + example_inputs=self.config.example_inputs) _logger.info('Start strategy...') self.strategy.run(base_model_ir, self.applied_mutators) diff --git a/nni/retiarii/graph.py b/nni/retiarii/graph.py index 3eba65805d..0adf0bdaf3 100644 --- a/nni/retiarii/graph.py +++ b/nni/retiarii/graph.py @@ -307,9 +307,9 @@ def _add_output(self, output_name) -> None: @overload def add_node(self, name: str, operation: Operation) -> 'Node': ... @overload - def add_node(self, name: str, type_name: str, parameters: Dict[str, Any] = {}) -> 'Node': ... + def add_node(self, name: str, type_name: str, parameters: Dict[str, Any] = None) -> 'Node': ... 
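+    # NOTE: the overloads above previously used a shared mutable default (parameters={});
+    # passing None and letting Operation.new allocate a fresh dict avoids state leaking
+    # across calls.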
- def add_node(self, name, operation_or_type, parameters={}): + def add_node(self, name, operation_or_type, parameters=None): if isinstance(operation_or_type, Operation): op = operation_or_type else: @@ -319,9 +319,9 @@ def add_node(self, name, operation_or_type, parameters={}): @overload def insert_node_on_edge(self, edge: 'Edge', name: str, operation: Operation) -> 'Node': ... @overload - def insert_node_on_edge(self, edge: 'Edge', name: str, type_name: str, parameters: Dict[str, Any] = {}) -> 'Node': ... + def insert_node_on_edge(self, edge: 'Edge', name: str, type_name: str, parameters: Dict[str, Any] = None) -> 'Node': ... - def insert_node_on_edge(self, edge, name, operation_or_type, parameters={}) -> 'Node': + def insert_node_on_edge(self, edge, name, operation_or_type, parameters=None) -> 'Node': if isinstance(operation_or_type, Operation): op = operation_or_type else: @@ -562,9 +562,9 @@ def update_label(self, label: str) -> None: @overload def update_operation(self, operation: Operation) -> None: ... @overload - def update_operation(self, type_name: str, parameters: Dict[str, Any] = {}) -> None: ... + def update_operation(self, type_name: str, parameters: Dict[str, Any] = None) -> None: ... - def update_operation(self, operation_or_type, parameters={}): + def update_operation(self, operation_or_type, parameters=None): if isinstance(operation_or_type, Operation): self.operation = operation_or_type else: diff --git a/nni/retiarii/nn/pytorch/api.py b/nni/retiarii/nn/pytorch/api.py index 69d12fb908..931119a230 100644 --- a/nni/retiarii/nn/pytorch/api.py +++ b/nni/retiarii/nn/pytorch/api.py @@ -90,6 +90,7 @@ def __init__(self, candidates: Union[Dict[str, nn.Module], List[nn.Module]], lab self.names.append(str(i)) else: raise TypeError("Unsupported candidates type: {}".format(type(candidates))) + self._first_module = self._modules[self.names[0]] # to make the dummy forward meaningful @property def key(self): @@ -143,7 +144,7 @@ def _choices(self): def forward(self, x): warnings.warn('You should not run forward of this module directly.') - return x + return self._first_module(x) def __repr__(self): return f'LayerChoice({self.candidates}, label={repr(self.label)})' diff --git a/nni/retiarii/operation.py b/nni/retiarii/operation.py index d97f87f46b..d8b23d1d60 100644 --- a/nni/retiarii/operation.py +++ b/nni/retiarii/operation.py @@ -52,7 +52,9 @@ def __bool__(self) -> bool: return True @staticmethod - def new(type_name: str, parameters: Dict[str, Any] = {}, cell_name: str = None) -> 'Operation': + def new(type_name: str, parameters: Dict[str, Any] = None, cell_name: str = None) -> 'Operation': + if parameters is None: + parameters = {} if type_name == '_cell': # NOTE: cell_name is the same as its Node's name, when the cell is wrapped within the node return Cell(cell_name, parameters) @@ -199,9 +201,11 @@ def forward(...): No real usage. Exists for compatibility with base class. """ - def __init__(self, cell_name: str, parameters: Dict[str, Any] = {}): + def __init__(self, cell_name: str, parameters: Dict[str, Any] = None): self.type = '_cell' self.cell_name = cell_name + if parameters is None: + parameters = {} self.parameters = parameters def _to_class_name(self): diff --git a/nni/retiarii/strategy/bruteforce.py b/nni/retiarii/strategy/bruteforce.py index 971711f0d9..04669b46be 100644 --- a/nni/retiarii/strategy/bruteforce.py +++ b/nni/retiarii/strategy/bruteforce.py @@ -10,7 +10,7 @@ from .. 
import Sampler, submit_models, query_available_resources, budget_exhausted from .base import BaseStrategy -from .utils import dry_run_for_search_space, get_targeted_model +from .utils import dry_run_for_search_space, get_targeted_model, filter_model _logger = logging.getLogger(__name__) @@ -84,15 +84,18 @@ class Random(BaseStrategy): Do not dry run to get the full search space. Used when the search space has variational size or candidates. Default: false. dedup : bool Do not try the same configuration twice. When variational is true, deduplication is not supported. Default: true. + model_filter: Callable[[Model], bool] + Feed the model and return a bool. This will filter the models in search space and select which to submit. """ - def __init__(self, variational=False, dedup=True): + def __init__(self, variational=False, dedup=True, model_filter=None): self.variational = variational self.dedup = dedup if variational and dedup: raise ValueError('Dedup is not supported in variational mode.') self.random_sampler = _RandomSampler() self._polling_interval = 2. + self.filter = model_filter def run(self, base_model, applied_mutators): if self.variational: @@ -107,7 +110,8 @@ def run(self, base_model, applied_mutators): for mutator in applied_mutators: model = mutator.apply(model) _logger.debug('New model created. Applied mutators are: %s', str(applied_mutators)) - submit_models(model) + if filter_model(self.filter, model): + submit_models(model) elif budget_exhausted(): break else: @@ -121,4 +125,6 @@ def run(self, base_model, applied_mutators): if budget_exhausted(): return time.sleep(self._polling_interval) - submit_models(get_targeted_model(base_model, applied_mutators, sample)) + model = get_targeted_model(base_model, applied_mutators, sample) + if filter_model(self.filter, model): + submit_models(model) diff --git a/nni/retiarii/strategy/utils.py b/nni/retiarii/strategy/utils.py index 8b687c4c21..87d9d24037 100644 --- a/nni/retiarii/strategy/utils.py +++ b/nni/retiarii/strategy/utils.py @@ -1,11 +1,15 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. + import collections +import logging from typing import Dict, Any, List from ..graph import Model from ..mutator import Mutator, Sampler +_logger = logging.getLogger(__name__) + class _FixedSampler(Sampler): def __init__(self, sample): @@ -30,3 +34,16 @@ def get_targeted_model(base_model: Model, mutators: List[Mutator], sample: dict) for mutator in mutators: model = mutator.bind_sampler(sampler).apply(model) return model + + +def filter_model(model_filter, ir_model): + if model_filter is not None: + _logger.debug(f'Check if model satisfies constraints.') + if model_filter(ir_model): + _logger.debug(f'Model satisfied. Submit the model.') + return True + else: + _logger.debug(f'Model unsatisfied. 
Discard the model.') + return False + else: + return True diff --git a/test/ut/retiarii/test_highlevel_apis.py b/test/ut/retiarii/test_highlevel_apis.py index e246e748f2..6426f049ba 100644 --- a/test/ut/retiarii/test_highlevel_apis.py +++ b/test/ut/retiarii/test_highlevel_apis.py @@ -111,6 +111,33 @@ def forward(self, x): self.assertEqual(self._get_converted_pytorch_model(model_new)(torch.randn(1, 3, 3, 3)).size(), torch.Size([1, i, 3, 3])) + def test_nested_layer_choice(self): + @self.get_serializer() + class Net(nn.Module): + def __init__(self): + super().__init__() + self.module = nn.LayerChoice([ + nn.LayerChoice([nn.Conv2d(3, 3, kernel_size=1), + nn.Conv2d(3, 4, kernel_size=1), + nn.Conv2d(3, 5, kernel_size=1)]), + nn.Conv2d(3, 1, kernel_size=1) + ]) + + def forward(self, x): + return self.module(x) + + model, mutators = self._get_model_with_mutators(Net()) + self.assertEqual(len(mutators), 2) + mutators[0].bind_sampler(EnumerateSampler()) + mutators[1].bind_sampler(EnumerateSampler()) + input = torch.randn(1, 3, 5, 5) + self.assertEqual(self._get_converted_pytorch_model(mutators[1].apply(mutators[0].apply(model)))(input).size(), + torch.Size([1, 3, 5, 5])) + self.assertEqual(self._get_converted_pytorch_model(mutators[1].apply(mutators[0].apply(model)))(input).size(), + torch.Size([1, 1, 5, 5])) + self.assertEqual(self._get_converted_pytorch_model(mutators[1].apply(mutators[0].apply(model)))(input).size(), + torch.Size([1, 5, 5, 5])) + def test_input_choice(self): @self.get_serializer() class Net(nn.Module): From bb3974807e3d9ab847de71a58a8f2985810a8166 Mon Sep 17 00:00:00 2001 From: kalineid Date: Tue, 20 Jul 2021 14:24:18 +0800 Subject: [PATCH 3/4] Add tests for GraphConverterWithShape (#3951) --- test/ut/retiarii/convert_mixin.py | 19 ++++++ test/ut/retiarii/test_convert.py | 11 +-- test/ut/retiarii/test_convert_basic.py | 24 ++++--- test/ut/retiarii/test_convert_models.py | 11 +-- test/ut/retiarii/test_convert_operators.py | 37 +++++----- test/ut/retiarii/test_convert_pytorch.py | 13 ++-- test/ut/retiarii/test_convert_shape.py | 79 ++++++++++++++++++++++ 7 files changed, 154 insertions(+), 40 deletions(-) create mode 100644 test/ut/retiarii/convert_mixin.py create mode 100644 test/ut/retiarii/test_convert_shape.py diff --git a/test/ut/retiarii/convert_mixin.py b/test/ut/retiarii/convert_mixin.py new file mode 100644 index 0000000000..bfe54238fc --- /dev/null +++ b/test/ut/retiarii/convert_mixin.py @@ -0,0 +1,19 @@ +import torch + +from nni.retiarii.converter.graph_gen import convert_to_graph, GraphConverterWithShape + + +class ConvertMixin: + @staticmethod + def _convert_model(model, input): + script_module = torch.jit.script(model) + model_ir = convert_to_graph(script_module, model) + return model_ir + + +class ConvertWithShapeMixin: + @staticmethod + def _convert_model(model, input): + script_module = torch.jit.script(model) + model_ir = convert_to_graph(script_module, model, converter=GraphConverterWithShape(), example_inputs=input) + return model_ir diff --git a/test/ut/retiarii/test_convert.py b/test/ut/retiarii/test_convert.py index c59d0aa9f7..c79c7696af 100644 --- a/test/ut/retiarii/test_convert.py +++ b/test/ut/retiarii/test_convert.py @@ -13,9 +13,10 @@ import nni.retiarii.nn.pytorch as nn from nni.retiarii import basic_unit -from nni.retiarii.converter import convert_to_graph from nni.retiarii.codegen import model_to_pytorch_script +from .convert_mixin import ConvertMixin, ConvertWithShapeMixin + class MnistNet(nn.Module): def __init__(self): super(MnistNet, 
self).__init__() @@ -48,7 +49,7 @@ def forward(self, input): out = self.linear(input.view(size[0] * size[1], -1)) return out.view(size[0], size[1], -1) -class TestConvert(unittest.TestCase): +class TestConvert(unittest.TestCase, ConvertMixin): @staticmethod def _match_state_dict(current_values, expected_format): result = {} @@ -61,8 +62,7 @@ def _match_state_dict(current_values, expected_format): return result def checkExportImport(self, model, input): - script_module = torch.jit.script(model) - model_ir = convert_to_graph(script_module, model) + model_ir = self._convert_model(model, input) model_code = model_to_pytorch_script(model_ir) exec_vars = {} @@ -579,3 +579,6 @@ def test_alexnet(self): self.checkExportImport(model, (x,)) finally: remove_inject_pytorch_nn() + +class TestConvertWithShape(TestConvert, ConvertWithShapeMixin): + pass diff --git a/test/ut/retiarii/test_convert_basic.py b/test/ut/retiarii/test_convert_basic.py index b2148f4cf0..145f62f636 100644 --- a/test/ut/retiarii/test_convert_basic.py +++ b/test/ut/retiarii/test_convert_basic.py @@ -9,12 +9,13 @@ import nni.retiarii.nn.pytorch as nn from nni.retiarii import basic_unit -from nni.retiarii.converter import convert_to_graph + +from .convert_mixin import ConvertMixin, ConvertWithShapeMixin from nni.retiarii.codegen import model_to_pytorch_script # following pytorch v1.7.1 -class TestConvert(unittest.TestCase): +class TestConvert(unittest.TestCase, ConvertMixin): @staticmethod def _match_state_dict(current_values, expected_format): result = {} @@ -27,8 +28,7 @@ def _match_state_dict(current_values, expected_format): return result def checkExportImport(self, model, input, check_value=True): - script_module = torch.jit.script(model) - model_ir = convert_to_graph(script_module, model) + model_ir = self._convert_model(model, input) model_code = model_to_pytorch_script(model_ir) print(model_code) @@ -188,7 +188,7 @@ def forward(self, x, y, z): out2 = torch.addmv(x, y, z, beta=0.1, alpha=0.2) return out1, out2 self.checkExportImport(SimpleOp(), (torch.randn(2), torch.randn(2, 3), torch.randn(3), )) - + def test_basic_addr(self): class SimpleOp(nn.Module): def forward(self, x, y, z): @@ -204,7 +204,7 @@ def forward(self, x, y): out2 = torch.allclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False) return out1, out2 self.checkExportImport(SimpleOp(), (torch.tensor([10000., 1e-07]), torch.tensor([10000.1, 1e-08]), )) - + def test_basic_angle(self): class SimpleOp(nn.Module): def forward(self, x): @@ -229,7 +229,7 @@ def forward(self, x): o4 = x.argmin(dim=1, keepdim=True) return out1, out2, out3, out4, out5, o1, o2, o3, o4 self.checkExportImport(SimpleOp(), (torch.randn(4, 4), )) - + def test_basic_argsort(self): class SimpleOp(nn.Module): def forward(self, x): @@ -241,7 +241,7 @@ def forward(self, x): self.checkExportImport(SimpleOp(), (torch.randn(4, 4), )) # skip backward(gradient=None, retain_graph=None, create_graph=False) - + def test_basic_bernoulli(self): class SimpleOp(nn.Module): def forward(self, x): @@ -261,7 +261,7 @@ def forward(self, x, y): out4 = x.bincount(weights=y, minlength=2) return out1, out2, out3, out4 self.checkExportImport(SimpleOp(), (torch.randint(0, 8, (5,), dtype=torch.int64), torch.linspace(0, 1, steps=5), )) - + def test_basic_bitwise(self): class SimpleOp(nn.Module): def forward(self, x, y): @@ -279,4 +279,8 @@ class SimpleOp(nn.Module): def forward(self, x): out1 = x.ceil() return out1 - self.checkExportImport(SimpleOp(), (torch.randn(4), )) \ No newline at end of file + 
self.checkExportImport(SimpleOp(), (torch.randn(4), )) + + +class TestConvertWithShape(TestConvert, ConvertWithShapeMixin): + pass diff --git a/test/ut/retiarii/test_convert_models.py b/test/ut/retiarii/test_convert_models.py index 26ee327671..f8a7ae8665 100644 --- a/test/ut/retiarii/test_convert_models.py +++ b/test/ut/retiarii/test_convert_models.py @@ -10,11 +10,12 @@ import nni.retiarii.nn.pytorch as nn from nni.retiarii import serialize -from nni.retiarii.converter import convert_to_graph from nni.retiarii.codegen import model_to_pytorch_script +from .convert_mixin import ConvertMixin, ConvertWithShapeMixin -class TestModels(unittest.TestCase): + +class TestModels(unittest.TestCase, ConvertMixin): @staticmethod def _match_state_dict(current_values, expected_format): result = {} @@ -27,8 +28,7 @@ def _match_state_dict(current_values, expected_format): return result def run_test(self, model, input, check_value=True): - script_module = torch.jit.script(model) - model_ir = convert_to_graph(script_module, model) + model_ir = self._convert_model(model, input) model_code = model_to_pytorch_script(model_ir) print(model_code) @@ -89,3 +89,6 @@ def forward(self, x: List[torch.Tensor]): model = Net(4) x = torch.rand((1, 16), dtype=torch.float) self.run_test(model, ([x], )) + +class TestModelsWithShape(TestModels, ConvertWithShapeMixin): + pass diff --git a/test/ut/retiarii/test_convert_operators.py b/test/ut/retiarii/test_convert_operators.py index 8500892375..a30c1f0843 100644 --- a/test/ut/retiarii/test_convert_operators.py +++ b/test/ut/retiarii/test_convert_operators.py @@ -15,13 +15,14 @@ import torchvision import nni.retiarii.nn.pytorch as nn -from nni.retiarii.converter import convert_to_graph from nni.retiarii.codegen import model_to_pytorch_script +from .convert_mixin import ConvertMixin, ConvertWithShapeMixin + # following pytorch v1.7.1 -class TestOperators(unittest.TestCase): +class TestOperators(unittest.TestCase, ConvertMixin): @staticmethod def _match_state_dict(current_values, expected_format): result = {} @@ -34,8 +35,7 @@ def _match_state_dict(current_values, expected_format): return result def checkExportImport(self, model, input, check_value=True): - script_module = torch.jit.script(model) - model_ir = convert_to_graph(script_module, model) + model_ir = self._convert_model(model, input) model_code = model_to_pytorch_script(model_ir) #print(model_code) @@ -1042,7 +1042,7 @@ def forward(self, x): x = torch.tensor([[[[0.0, 1.0, 1.0, 1.0], [2.0, 3.0, 7.0, 7.0]]]], requires_grad=True) self.checkExportImport(SimpleOp(), (x, )) - + def test_basic_batchnorm(self): class SimpleOp(nn.Module): @@ -1056,7 +1056,7 @@ def forward(self, x): x = torch.ones(2, 2, 2, 2, requires_grad=True) self.checkExportImport(SimpleOp(), (x, )) - + def test_basic_batchnorm_1d(self): class SimpleOp(nn.Module): @@ -1084,7 +1084,7 @@ def forward(self, x): x = torch.ones(20, 16, 50, 40, requires_grad=True) self.checkExportImport(SimpleOp(), (x, )) - + def test_conv_onnx_irv4_opset8(self): # This test point checks that for opset 8 (or lower), even if # keep_initializers_as_inputs is set to False, it is ignored, @@ -1129,7 +1129,7 @@ def forward(self, x): x = torch.randn(20, 16, 50) self.checkExportImport(SimpleOp(), (x, )) - + def test_basic_maxpool_dilations(self): class SimpleOp(nn.Module): @@ -1143,7 +1143,7 @@ def forward(self, x): x = torch.randn(20, 16, 50) self.checkExportImport(SimpleOp(), (x, )) - + def test_basic_avg_pool2d(self): class SimpleOp(nn.Module): @@ -1157,7 +1157,7 @@ def forward(self, x): x 
= torch.randn(20, 16, 50, 32) self.checkExportImport(SimpleOp(), (x, )) - + @unittest.skip('jit error: "Return value was annotated as having type Tensor but is actually of type Tuple[Tensor, Tensor]"') def test_basic_maxpool_indices(self): class SimpleOp(nn.Module): @@ -1200,7 +1200,7 @@ def forward(self, x): x = torch.randn(1, 2, 3, 4, requires_grad=True) self.checkExportImport(SimpleOp(), (x, )) - + def test_basic_elu(self): class SimpleOp(nn.Module): @@ -1214,7 +1214,7 @@ def forward(self, x): x = torch.randn(1, 2, 3, 4, requires_grad=True) self.checkExportImport(SimpleOp(), (x, )) - + def test_basic_selu(self): class SimpleOp(nn.Module): @@ -1261,7 +1261,7 @@ def forward(self, x): x = torch.randn(128, 128, 1, 1, requires_grad=True) self.checkExportImport(SimpleOp(), (x, )) - + def test_embedding_bags(self): class SimpleOp(nn.Module): def __init__(self): @@ -1288,7 +1288,7 @@ def forward(self, x): x = torch.randn(1, 2, 3, 4) self.checkExportImport(SimpleOp(), (x, )) - + def test_basic_prelu(self): class SimpleOp(nn.Module): @@ -1302,7 +1302,7 @@ def forward(self, x): x = torch.randn(1, 2, 3, 4) self.checkExportImport(SimpleOp(), (x, )) - + def test_basic_log_sigmoid(self): class SimpleOp(nn.Module): @@ -1316,7 +1316,7 @@ def forward(self, x): x = torch.randn(1, 2, 3, 4) self.checkExportImport(SimpleOp(), (x, )) - + def test_basic_linear(self): class SimpleOp(nn.Module): @@ -1385,4 +1385,7 @@ def forward(self, x): return out x = torch.randn(20, 5, 10, 10) - self.checkExportImport(SimpleOp(), (x, )) \ No newline at end of file + self.checkExportImport(SimpleOp(), (x, )) + +class TestOperatorsWithShape(TestOperators, ConvertWithShapeMixin): + pass diff --git a/test/ut/retiarii/test_convert_pytorch.py b/test/ut/retiarii/test_convert_pytorch.py index 51857b6815..267703456d 100644 --- a/test/ut/retiarii/test_convert_pytorch.py +++ b/test/ut/retiarii/test_convert_pytorch.py @@ -15,11 +15,12 @@ import nni.retiarii.nn.pytorch as nn from nni.retiarii import serialize -from nni.retiarii.converter import convert_to_graph from nni.retiarii.codegen import model_to_pytorch_script +from .convert_mixin import ConvertMixin, ConvertWithShapeMixin -class TestPytorch(unittest.TestCase): + +class TestPytorch(unittest.TestCase, ConvertMixin): @staticmethod def _match_state_dict(current_values, expected_format): result = {} @@ -32,8 +33,7 @@ def _match_state_dict(current_values, expected_format): return result def run_test(self, model, input, check_value=True): - script_module = torch.jit.script(model) - model_ir = convert_to_graph(script_module, model) + model_ir = self._convert_model(model, input) model_code = model_to_pytorch_script(model_ir) print(model_code) @@ -1230,4 +1230,7 @@ def forward(self, input): return torch.arange(input.size(0)), torch.arange(input.size(-1)), torch.ones(input.shape) x = torch.randn(5, 3, 2) - self.run_test(SizeModel(10, 5), (x, )) \ No newline at end of file + self.run_test(SizeModel(10, 5), (x, )) + +class TestPytorchWithShape(TestPytorch, ConvertWithShapeMixin): + pass diff --git a/test/ut/retiarii/test_convert_shape.py b/test/ut/retiarii/test_convert_shape.py new file mode 100644 index 0000000000..a6aca72892 --- /dev/null +++ b/test/ut/retiarii/test_convert_shape.py @@ -0,0 +1,79 @@ +import unittest +import torch + +import nni.retiarii.nn.pytorch as nn + +from .convert_mixin import ConvertWithShapeMixin + + +class TestShape(unittest.TestCase, ConvertWithShapeMixin): + def test_simple_convnet(self): + class ConvNet(nn.Module): + def __init__(self): + super().__init__() + 
self.conv = nn.Conv2d(3, 1, 3) + self.relu = nn.ReLU() + self.pool = nn.MaxPool2d(kernel_size=2) + def forward(self, x): + return self.pool(self.relu(self.conv(x))) + + net = ConvNet() + input = torch.randn((1, 3, 224, 224)) + model_ir = self._convert_model(net, input) + + conv_node = model_ir.get_nodes_by_type('__torch__.torch.nn.modules.conv.Conv2d')[0] + relu_node = model_ir.get_nodes_by_type('__torch__.torch.nn.modules.activation.ReLU')[0] + pool_node = model_ir.get_nodes_by_type('__torch__.torch.nn.modules.pooling.MaxPool2d')[0] + self.assertEqual(conv_node.operation.parameters.get('input_shape'), [[1, 3, 224, 224]]) + self.assertEqual(conv_node.operation.parameters.get('output_shape'), [[1, 1, 222, 222]]) + self.assertEqual(relu_node.operation.parameters.get('input_shape'), [[1, 1, 222, 222]]) + self.assertEqual(relu_node.operation.parameters.get('output_shape'), [[1, 1, 222, 222]]) + self.assertEqual(pool_node.operation.parameters.get('input_shape'), [[1, 1, 222, 222]]) + self.assertEqual(pool_node.operation.parameters.get('output_shape'), [[1, 1, 111, 111]]) + + def test_nested_module(self): + class ConvRelu(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(3, 1, 3) + self.relu = nn.ReLU() + def forward(self, x): + return self.relu(self.conv(x)) + + class ConvNet(nn.Module): + def __init__(self): + super().__init__() + self.conv = ConvRelu() + self.pool = nn.MaxPool2d(kernel_size=2) + def forward(self, x): + return self.pool(self.conv(x)) + + net = ConvNet() + input = torch.randn((1, 3, 224, 224)) + model_ir = self._convert_model(net, input) + + # check if shape propagation works + cell_node = model_ir.get_nodes_by_type('_cell')[0] + self.assertEqual(cell_node.operation.parameters.get('input_shape'), [[1, 3, 224, 224]]) + self.assertEqual(cell_node.operation.parameters.get('output_shape'), [[1, 1, 222, 222]]) + + def test_layerchoice(self): + class ConvNet(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.LayerChoice([ + nn.Conv2d(3, 1, 3), + nn.Conv2d(3, 1, 5, padding=1), + ]) + self.pool = nn.MaxPool2d(kernel_size=2) + def forward(self, x): + return self.pool(self.conv(x)) + + net = ConvNet() + input = torch.randn((1, 3, 224, 224)) + model_ir = self._convert_model(net, input) + + # check shape info of each candidates + conv_nodes = model_ir.get_nodes_by_type('__torch__.torch.nn.modules.conv.Conv2d') + self.assertEqual(conv_nodes[0].operation.parameters.get('output_shape'), [[1, 1, 222, 222]]) + self.assertEqual(conv_nodes[1].operation.parameters.get('output_shape'), [[1, 1, 222, 222]]) From 5e04d56c51323348ece23de1fc8f49fb25527882 Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Mon, 26 Jul 2021 22:17:21 +0800 Subject: [PATCH 4/4] Replace example_inputs with dummy_input (#3983) --- docs/en_US/NAS/HardwareAwareNAS.rst | 6 +++--- examples/nas/oneshot/spos/multi_trial.py | 2 +- nni/retiarii/converter/graph_gen.py | 12 ++++++------ nni/retiarii/experiment/pytorch.py | 12 ++++++------ test/ut/retiarii/convert_mixin.py | 2 +- 5 files changed, 17 insertions(+), 17 deletions(-) diff --git a/docs/en_US/NAS/HardwareAwareNAS.rst b/docs/en_US/NAS/HardwareAwareNAS.rst index 81eaeb7393..cec143c28b 100644 --- a/docs/en_US/NAS/HardwareAwareNAS.rst +++ b/docs/en_US/NAS/HardwareAwareNAS.rst @@ -32,10 +32,10 @@ To support latency-aware NAS, you first need a `Strategy` that supports filterin ``LatencyFilter`` will predict the models\' latency by using nn-Meter and filter out the models whose latency are larger than the threshold (i.e., ``100`` in 
 this example). You can also build your own strategies and filters to support more flexible NAS, such as sorting the models according to latency.
 
-Then, pass this strategy to ``RetiariiExperiment`` along with some additional arguments: ``parse_shape=True, example_inputs=example_inputs``:
+Then, pass this strategy to ``RetiariiExperiment`` along with some additional arguments: ``parse_shape=True, dummy_input=dummy_input``:
 
 .. code-block:: python
 
-   RetiariiExperiment(base_model, trainer, [], simple_strategy, True, example_inputs)
+   RetiariiExperiment(base_model, trainer, [], simple_strategy, True, dummy_input)
 
-Here, ``parse_shape=True`` means extracting shape info from the torch model as it is required by nn-Meter to predict latency. ``example_inputs`` is required for tracing shape info.
+Here, ``parse_shape=True`` means shape info is extracted from the torch model, since nn-Meter requires it to predict latency. ``dummy_input`` is required for tracing shape info.
diff --git a/examples/nas/oneshot/spos/multi_trial.py b/examples/nas/oneshot/spos/multi_trial.py
index 730e688142..71bdb283e1 100644
--- a/examples/nas/oneshot/spos/multi_trial.py
+++ b/examples/nas/oneshot/spos/multi_trial.py
@@ -185,7 +185,7 @@ def _main(port):
     exp_config.trial_gpu_number = 1
     exp_config.training_service.use_active_gpu = False
     exp_config.execution_engine = 'base'
-    exp_config.example_inputs = [1, 3, 32, 32]
+    exp_config.dummy_input = [1, 3, 32, 32]
 
     exp.run(exp_config, port)
 
diff --git a/nni/retiarii/converter/graph_gen.py b/nni/retiarii/converter/graph_gen.py
index e402082fc1..e3dc323ae9 100644
--- a/nni/retiarii/converter/graph_gen.py
+++ b/nni/retiarii/converter/graph_gen.py
@@ -683,13 +683,13 @@ class GraphConverterWithShape(GraphConverter):
     If the forward path of candidates depends on input data, then the wrong path will be traced.
     This will result in incomplete shape info.
""" - def convert_module(self, script_module, module, module_name, ir_model, example_inputs): + def convert_module(self, script_module, module, module_name, ir_model, dummy_input): module.eval() ir_graph, attrs = self._convert_module(script_module, module, module_name, ir_model) self.remove_dummy_nodes(ir_model) self._initialize_parameters(ir_model) - self._trace_module(module, module_name, ir_model, example_inputs) + self._trace_module(module, module_name, ir_model, dummy_input) return ir_graph, attrs def _initialize_parameters(self, ir_model: 'Model'): @@ -699,9 +699,9 @@ def _initialize_parameters(self, ir_model: 'Model'): ir_node.operation.parameters.setdefault('input_shape', []) ir_node.operation.parameters.setdefault('output_shape', []) - def _trace_module(self, module, module_name, ir_model: 'Model', example_inputs): + def _trace_module(self, module, module_name, ir_model: 'Model', dummy_input): # First, trace the whole graph - tm_graph = self._trace(module, example_inputs) + tm_graph = self._trace(module, dummy_input) for node in tm_graph.nodes(): parameters = _extract_info_from_trace_node(node) @@ -832,8 +832,8 @@ def _flatten(graph: 'Graph'): # remove subgraphs ir_model.graphs = {ir_model._root_graph_name: ir_model.root_graph} - def _trace(self, module, example_inputs): - traced_module = torch.jit.trace(module, example_inputs) + def _trace(self, module, dummy_input): + traced_module = torch.jit.trace(module, dummy_input) torch._C._jit_pass_inline(traced_module.graph) return traced_module.graph diff --git a/nni/retiarii/experiment/pytorch.py b/nni/retiarii/experiment/pytorch.py index 5ce34db0c9..76252e1e83 100644 --- a/nni/retiarii/experiment/pytorch.py +++ b/nni/retiarii/experiment/pytorch.py @@ -60,7 +60,7 @@ class RetiariiExeConfig(ConfigBase): execution_engine: str = 'py' # input used in GraphConverterWithShape. Currently support shape tuple only. 
-    example_inputs: Optional[List[int]] = None
+    dummy_input: Optional[List[int]] = None
 
     def __init__(self, training_service_platform: Optional[str] = None, **kwargs):
         super().__init__(**kwargs)
@@ -110,7 +110,7 @@ def _validation_rules(self):
         'training_service': lambda value: (type(value) is not TrainingServiceConfig, 'cannot be abstract base class')
     }
 
-def preprocess_model(base_model, trainer, applied_mutators, full_ir=True, example_inputs=None):
+def preprocess_model(base_model, trainer, applied_mutators, full_ir=True, dummy_input=None):
     # TODO: this logic might need to be refactored into execution engine
     if full_ir:
         try:
@@ -118,11 +118,11 @@ def preprocess_model(base_model, trainer, applied_mutators, full_ir=True, exampl
         except Exception as e:
             _logger.error('Your base model cannot be parsed by torch.jit.script, please fix the following error:')
             raise e
-        if example_inputs is not None:
+        if dummy_input is not None:
             # FIXME: this is a workaround as full tensor is not supported in configs
-            example_inputs = torch.randn(*example_inputs)
+            dummy_input = torch.randn(*dummy_input)
         converter = GraphConverterWithShape()
-        base_model_ir = convert_to_graph(script_module, base_model, converter, example_inputs=example_inputs)
+        base_model_ir = convert_to_graph(script_module, base_model, converter, dummy_input=dummy_input)
     else:
         base_model_ir = convert_to_graph(script_module, base_model)
     # handle inline mutations
@@ -182,7 +182,7 @@ def __init__(self, base_model: nn.Module, trainer: Union[Evaluator, BaseOneShotT
     def _start_strategy(self):
         base_model_ir, self.applied_mutators = preprocess_model(
             self.base_model, self.trainer, self.applied_mutators, full_ir=self.config.execution_engine != 'py',
-            example_inputs=self.config.example_inputs)
+            dummy_input=self.config.dummy_input)
 
         _logger.info('Start strategy...')
         self.strategy.run(base_model_ir, self.applied_mutators)
diff --git a/test/ut/retiarii/convert_mixin.py b/test/ut/retiarii/convert_mixin.py
index bfe54238fc..f538c9d277 100644
--- a/test/ut/retiarii/convert_mixin.py
+++ b/test/ut/retiarii/convert_mixin.py
@@ -15,5 +15,5 @@ class ConvertWithShapeMixin:
     @staticmethod
     def _convert_model(model, input):
         script_module = torch.jit.script(model)
-        model_ir = convert_to_graph(script_module, model, converter=GraphConverterWithShape(), example_inputs=input)
+        model_ir = convert_to_graph(script_module, model, converter=GraphConverterWithShape(), dummy_input=input)
         return model_ir
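
As a usage reference for the renamed field, the sketch below wires ``execution_engine = 'base'`` and ``dummy_input`` together the way this patch series expects. Only those two config lines come from the diffs above; the toy model space, CIFAR-10 setup, trial budget and port are illustrative placeholders, not part of the patches.

.. code-block:: python

    import nni.retiarii.evaluator.pytorch as pl
    import nni.retiarii.nn.pytorch as nn
    import nni.retiarii.strategy as strategy
    import torch
    from nni.retiarii import serialize
    from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment
    from torchvision import transforms
    from torchvision.datasets import CIFAR10


    class Net(nn.Module):
        """Toy model space: a LayerChoice between two shape-preserving convs."""
        def __init__(self):
            super().__init__()
            self.conv = nn.LayerChoice([
                nn.Conv2d(3, 6, 3, padding=1),
                nn.Conv2d(3, 6, 5, padding=2),
            ])
            self.pool = nn.MaxPool2d(2)
            self.fc = nn.Linear(6 * 16 * 16, 10)

        def forward(self, x):
            x = self.pool(self.conv(x))
            return self.fc(torch.flatten(x, 1))


    if __name__ == '__main__':
        transform = transforms.Compose([transforms.ToTensor()])
        train_set = serialize(CIFAR10, 'data', train=True, download=True, transform=transform)
        test_set = serialize(CIFAR10, 'data', train=False, download=True, transform=transform)
        evaluator = pl.Classification(train_dataloader=pl.DataLoader(train_set, batch_size=64),
                                      val_dataloaders=pl.DataLoader(test_set, batch_size=64),
                                      max_epochs=1)

        exp = RetiariiExperiment(Net(), evaluator, [], strategy.Random())
        exp_config = RetiariiExeConfig('local')
        exp_config.trial_concurrency = 1
        exp_config.max_trial_number = 4
        # The graph-based engine is what invokes GraphConverterWithShape.
        exp_config.execution_engine = 'base'
        # Renamed from `example_inputs`; a shape list, which the framework expands
        # into a random tensor (torch.randn(*dummy_input)) before tracing.
        exp_config.dummy_input = [1, 3, 32, 32]
        exp.run(exp_config, 8081)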