diff --git a/docs/en_US/NAS/retiarii/Advanced.rst b/docs/en_US/NAS/retiarii/Advanced.rst index 18146ab5a0..12568b8ab6 100644 --- a/docs/en_US/NAS/retiarii/Advanced.rst +++ b/docs/en_US/NAS/retiarii/Advanced.rst @@ -1,7 +1,17 @@ Advanced Tutorial ================= -This document includes two parts. The first part explains the design decision of ``@basic_unit`` and ``serializer``. The second part is the tutorial of how to write a model space with mutators. +Pure-python execution engine (experimental) +------------------------------------------- + +If you are experiencing issues with TorchScript, or the generated model code by Retiarii, there is another execution engine called Pure-python execution engine which doesn't need the code-graph conversion. This should generally not affect models and strategies in most cases, but customized mutation might not be supported. + +This will come as the default execution engine in future version of Retiarii. + +Two steps are needed to enable this engine now. + +1. Add ``@nni.retiarii.model_wrapper`` decorator outside the whole PyTorch model. +2. Add ``config.execution_engine = 'py'`` to ``RetiariiExeConfig``. ``@basic_unit`` and ``serializer`` ---------------------------------- diff --git a/nni/retiarii/__init__.py b/nni/retiarii/__init__.py index 762af7c834..f441367460 100644 --- a/nni/retiarii/__init__.py +++ b/nni/retiarii/__init__.py @@ -5,4 +5,4 @@ from .graph import * from .execution import * from .mutator import * -from .serializer import basic_unit, json_dump, json_dumps, json_load, json_loads, serialize, serialize_cls +from .serializer import basic_unit, json_dump, json_dumps, json_load, json_loads, serialize, serialize_cls, model_wrapper diff --git a/nni/retiarii/execution/api.py b/nni/retiarii/execution/api.py index c53ee56fd0..8027e7e363 100644 --- a/nni/retiarii/execution/api.py +++ b/nni/retiarii/execution/api.py @@ -15,19 +15,18 @@ 'list_models', 'submit_models', 'wait_models', 'query_available_resources', 'set_execution_engine', 'is_stopped_exec', 'budget_exhausted'] -def set_execution_engine(engine) -> None: + +def set_execution_engine(engine: AbstractExecutionEngine) -> None: global _execution_engine if _execution_engine is None: _execution_engine = engine else: - raise RuntimeError('execution engine is already set') + raise RuntimeError('Execution engine is already set.') def get_execution_engine() -> AbstractExecutionEngine: - """ - Currently we assume the default execution engine is BaseExecutionEngine. - """ global _execution_engine + assert _execution_engine is not None, 'You need to set execution engine, before using it.' return _execution_engine diff --git a/nni/retiarii/execution/base.py b/nni/retiarii/execution/base.py index 65ab99fc2b..36d09b505f 100644 --- a/nni/retiarii/execution/base.py +++ b/nni/retiarii/execution/base.py @@ -5,7 +5,7 @@ import os import random import string -from typing import Dict, Iterable, List +from typing import Any, Dict, Iterable, List from .interface import AbstractExecutionEngine, AbstractGraphListener from .. import codegen, utils @@ -59,7 +59,7 @@ def __init__(self) -> None: def submit_models(self, *models: Model) -> None: for model in models: - data = BaseGraphData(codegen.model_to_pytorch_script(model), model.evaluator) + data = self.pack_model_data(model) self._running_models[send_trial(data.dump())] = model self._history.append(model) @@ -108,6 +108,10 @@ def budget_exhausted(self) -> bool: advisor = get_advisor() return advisor.stopping + @classmethod + def pack_model_data(cls, model: Model) -> Any: + return BaseGraphData(codegen.model_to_pytorch_script(model), model.evaluator) + @classmethod def trial_execute_graph(cls) -> None: """ diff --git a/nni/retiarii/execution/python.py b/nni/retiarii/execution/python.py new file mode 100644 index 0000000000..93a4d10333 --- /dev/null +++ b/nni/retiarii/execution/python.py @@ -0,0 +1,53 @@ +from typing import Dict, Any, List + +from ..graph import Evaluator, Model +from ..integration_api import receive_trial_parameters +from ..utils import ContextStack, import_, get_importable_name +from .base import BaseExecutionEngine + + +class PythonGraphData: + def __init__(self, class_name: str, init_parameters: Dict[str, Any], + mutation: Dict[str, Any], evaluator: Evaluator) -> None: + self.class_name = class_name + self.init_parameters = init_parameters + self.mutation = mutation + self.evaluator = evaluator + + def dump(self) -> dict: + return { + 'class_name': self.class_name, + 'init_parameters': self.init_parameters, + 'mutation': self.mutation, + 'evaluator': self.evaluator + } + + @staticmethod + def load(data) -> 'PythonGraphData': + return PythonGraphData(data['class_name'], data['init_parameters'], data['mutation'], data['evaluator']) + + +class PurePythonExecutionEngine(BaseExecutionEngine): + @classmethod + def pack_model_data(cls, model: Model) -> Any: + mutation = {mut.mutator.label: _unpack_if_only_one(mut.samples) for mut in model.history} + graph_data = PythonGraphData(get_importable_name(model.python_class, relocate_module=True), + model.python_init_params, mutation, model.evaluator) + return graph_data + + @classmethod + def trial_execute_graph(cls) -> None: + graph_data = PythonGraphData.load(receive_trial_parameters()) + + class _model(import_(graph_data.class_name)): + def __init__(self): + super().__init__(**graph_data.init_parameters) + + with ContextStack('fixed', graph_data.mutation): + graph_data.evaluator._execute(_model) + + +def _unpack_if_only_one(ele: List[Any]): + if len(ele) == 1: + return ele[0] + return ele diff --git a/nni/retiarii/experiment/pytorch.py b/nni/retiarii/experiment/pytorch.py index 2e5df288c6..351b4ee0be 100644 --- a/nni/retiarii/experiment/pytorch.py +++ b/nni/retiarii/experiment/pytorch.py @@ -28,11 +28,11 @@ from ..codegen import model_to_pytorch_script from ..converter import convert_to_graph -from ..execution import list_models +from ..execution import list_models, set_execution_engine from ..graph import Model, Evaluator from ..integration import RetiariiAdvisor from ..mutator import Mutator -from ..nn.pytorch.mutator import process_inline_mutation +from ..nn.pytorch.mutator import process_inline_mutation, extract_mutation_from_pt_module from ..strategy import BaseStrategy from ..oneshot.interface import BaseOneShotTrainer @@ -43,7 +43,7 @@ class RetiariiExeConfig(ConfigBase): experiment_name: Optional[str] = None search_space: Any = '' # TODO: remove - trial_command: str = 'python3 -m nni.retiarii.trial_entry' + trial_command: str = '_reserved' trial_code_directory: PathLike = '.' trial_concurrency: int trial_gpu_number: int = 0 @@ -55,21 +55,26 @@ class RetiariiExeConfig(ConfigBase): experiment_working_directory: Optional[PathLike] = None # remove configuration of tuner/assessor/advisor training_service: TrainingServiceConfig + execution_engine: str = 'base' def __init__(self, training_service_platform: Optional[str] = None, **kwargs): super().__init__(**kwargs) if training_service_platform is not None: assert 'training_service' not in kwargs self.training_service = util.training_service_config_factory(platform = training_service_platform) + self.__dict__['trial_command'] = 'python3 -m nni.retiarii.trial_entry base' def __setattr__(self, key, value): fixed_attrs = {'search_space': '', - 'trial_command': 'python3 -m nni.retiarii.trial_entry'} + 'trial_command': '_reserved'} if key in fixed_attrs and fixed_attrs[key] != value: raise AttributeError(f'{key} is not supposed to be set in Retiarii mode by users!') # 'trial_code_directory' is handled differently because the path will be converted to absolute path by us if key == 'trial_code_directory' and not (value == Path('.') or os.path.isabs(value)): raise AttributeError(f'{key} is not supposed to be set in Retiarii mode by users!') + if key == 'execution_engine': + assert value in ['base', 'py', 'cgo'], f'The specified execution engine "{value}" is not supported.' + self.__dict__['trial_command'] = 'python3 -m nni.retiarii.trial_entry ' + value self.__dict__[key] = value def validate(self, initialized_tuner: bool = False) -> None: @@ -100,23 +105,27 @@ def _validation_rules(self): 'training_service': lambda value: (type(value) is not TrainingServiceConfig, 'cannot be abstract base class') } -def preprocess_model(base_model, trainer, applied_mutators): +def preprocess_model(base_model, trainer, applied_mutators, full_ir=True): + # TODO: this logic might need to be refactored into execution engine + if full_ir: try: script_module = torch.jit.script(base_model) except Exception as e: _logger.error('Your base model cannot be parsed by torch.jit.script, please fix the following error:') raise e base_model_ir = convert_to_graph(script_module, base_model) - base_model_ir.evaluator = trainer - # handle inline mutations mutators = process_inline_mutation(base_model_ir) - if mutators is not None and applied_mutators: - raise RuntimeError('Have not supported mixed usage of LayerChoice/InputChoice and mutators, ' - 'do not use mutators when you use LayerChoice/InputChoice') - if mutators is not None: - applied_mutators = mutators - return base_model_ir, applied_mutators + else: + base_model_ir, mutators = extract_mutation_from_pt_module(base_model) + base_model_ir.evaluator = trainer + + if mutators is not None and applied_mutators: + raise RuntimeError('Have not supported mixed usage of LayerChoice/InputChoice and mutators, ' + 'do not use mutators when you use LayerChoice/InputChoice') + if mutators is not None: + applied_mutators = mutators + return base_model_ir, applied_mutators def debug_mutated_model(base_model, trainer, applied_mutators): """ @@ -160,7 +169,8 @@ def __init__(self, base_model: nn.Module, trainer: Union[Evaluator, BaseOneShotT self._pipe: Optional[Pipe] = None def _start_strategy(self): - base_model_ir, self.applied_mutators = preprocess_model(self.base_model, self.trainer, self.applied_mutators) + base_model_ir, self.applied_mutators = preprocess_model( + self.base_model, self.trainer, self.applied_mutators, full_ir=self.config.execution_engine != 'py') _logger.info('Start strategy...') self.strategy.run(base_model_ir, self.applied_mutators) @@ -182,6 +192,18 @@ def start(self, port: int = 8080, debug: bool = False) -> None: """ atexit.register(self.stop) + # we will probably need a execution engine factory to make this clean and elegant + if self.config.execution_engine == 'base': + from ..execution.base import BaseExecutionEngine + engine = BaseExecutionEngine() + elif self.config.execution_engine == 'cgo': + from ..execution.cgo_engine import CGOExecutionEngine + engine = CGOExecutionEngine() + elif self.config.execution_engine == 'py': + from ..execution.python import PurePythonExecutionEngine + engine = PurePythonExecutionEngine() + set_execution_engine(engine) + self.id = management.generate_experiment_id() if self.config.experiment_working_directory is not None: diff --git a/nni/retiarii/graph.py b/nni/retiarii/graph.py index 48f2971a74..2255e288f1 100644 --- a/nni/retiarii/graph.py +++ b/nni/retiarii/graph.py @@ -9,12 +9,12 @@ import copy import json from enum import Enum -from typing import (Any, Dict, Iterable, List, Optional, Tuple, Union, overload) +from typing import (Any, Dict, Iterable, List, Optional, Tuple, Type, Union, overload) from .operation import Cell, Operation, _IOPseudoOperation from .utils import get_importable_name, import_, uid -__all__ = ['Model', 'ModelStatus', 'Graph', 'Node', 'Edge', 'IllegalGraphError', 'MetricData'] +__all__ = ['Model', 'ModelStatus', 'Graph', 'Node', 'Edge', 'Mutation', 'IllegalGraphError', 'MetricData'] MetricData = Any @@ -80,6 +80,10 @@ class Model: Attributes ---------- + python_class + Python class that base model is converted from. + python_init_params + Initialization parameters of python class. status See `ModelStatus`. root_graph @@ -102,6 +106,8 @@ class Model: def __init__(self, _internal=False): assert _internal, '`Model()` is private, use `model.fork()` instead' self.model_id: int = uid('model') + self.python_class: Optional[Type] = None + self.python_init_params: Optional[Dict[str, Any]] = None self.status: ModelStatus = ModelStatus.Mutating @@ -116,7 +122,8 @@ def __init__(self, _internal=False): def __repr__(self): return f'Model(model_id={self.model_id}, status={self.status}, graphs={list(self.graphs.keys())}, ' + \ - f'evaluator={self.evaluator}, metric={self.metric}, intermediate_metrics={self.intermediate_metrics})' + f'evaluator={self.evaluator}, metric={self.metric}, intermediate_metrics={self.intermediate_metrics}, ' + \ + f'python_class={self.python_class})' @property def root_graph(self) -> 'Graph': @@ -133,9 +140,12 @@ def fork(self) -> 'Model': """ new_model = Model(_internal=True) new_model._root_graph_name = self._root_graph_name + new_model.python_class = self.python_class + new_model.python_init_params = self.python_init_params new_model.graphs = {name: graph._fork_to(new_model) for name, graph in self.graphs.items()} new_model.evaluator = copy.deepcopy(self.evaluator) # TODO this may be a problem when evaluator is large - new_model.history = self.history + [self] + new_model.history = [*self.history] + # Note: the history is not updated. It will be updated when the model is changed, that is in mutator. return new_model @staticmethod @@ -167,8 +177,8 @@ def get_nodes(self) -> Iterable['Node']: def get_nodes_by_label(self, label: str) -> List['Node']: """ - Traverse all the nodes to find the matched node(s) with the given name. - There could be multiple nodes with the same name. Name space name can uniquely + Traverse all the nodes to find the matched node(s) with the given label. + There could be multiple nodes with the same label. Name space name can uniquely identify a graph or node. NOTE: the implementation does not support the class abstration @@ -493,6 +503,8 @@ class Node: If two models have nodes with same ID, they are semantically the same node. name Mnemonic name. It should have an one-to-one mapping with ID. + label + Optional. If two nodes have the same label, they are considered same by the mutator. operation ... cell @@ -515,7 +527,7 @@ def __init__(self, graph, node_id, name, operation, _internal=False): # TODO: the operation is likely to be considered editable by end-user and it will be hard to debug # maybe we should copy it here or make Operation class immutable, in next release self.operation: Operation = operation - self.label: str = None + self.label: Optional[str] = None def __repr__(self): return f'Node(id={self.id}, name={self.name}, label={self.label}, operation={self.operation})' @@ -673,6 +685,37 @@ def _dump(self) -> Any: } +class Mutation: + """ + An execution of mutation, which consists of four parts: a mutator, a list of decisions (choices), + the model that it comes from, and the model that it becomes. + + In general cases, the mutation logs are not reliable and should not be replayed as the mutators can + be arbitrarily complex. However, for inline mutations, the labels correspond to mutator labels here, + this can be useful for metadata visualization and python execution mode. + + Attributes + ---------- + mutator + Mutator. + samples + Decisions/choices. + from_ + Model that is comes from. + to + Model that it becomes. + """ + + def __init__(self, mutator: 'Mutator', samples: List[Any], from_: Model, to: Model): # noqa: F821 + self.mutator: 'Mutator' = mutator # noqa: F821 + self.samples: List[Any] = samples + self.from_: Model = from_ + self.to: Model = to + + def __repr__(self): + return f'Edge(mutator={self.mutator}, samples={self.samples}, from={self.from_}, to={self.to})' + + class IllegalGraphError(ValueError): def __init__(self, graph, *args): self._debug_dump_graph(graph) diff --git a/nni/retiarii/integration.py b/nni/retiarii/integration.py index 1027062d8e..189db5ff5c 100644 --- a/nni/retiarii/integration.py +++ b/nni/retiarii/integration.py @@ -2,7 +2,6 @@ # Licensed under the MIT license. import logging -import os from typing import Any, Callable from nni.runtime.msg_dispatcher_base import MsgDispatcherBase @@ -10,9 +9,6 @@ from nni.utils import MetricType from .graph import MetricData -from .execution.base import BaseExecutionEngine -from .execution.cgo_engine import CGOExecutionEngine -from .execution.api import set_execution_engine from .integration_api import register_advisor from .serializer import json_dumps, json_loads @@ -62,15 +58,6 @@ def __init__(self): self.parameters_count = 0 - engine = self._create_execution_engine() - set_execution_engine(engine) - - def _create_execution_engine(self): - if os.environ.get('CGO') == 'true': - return CGOExecutionEngine() - else: - return BaseExecutionEngine() - def handle_initialize(self, data): """callback for initializing the advisor Parameters diff --git a/nni/retiarii/mutator.py b/nni/retiarii/mutator.py index fac3350f7c..e7d5708169 100644 --- a/nni/retiarii/mutator.py +++ b/nni/retiarii/mutator.py @@ -3,7 +3,7 @@ from typing import (Any, Iterable, List, Optional) -from .graph import Model +from .graph import Model, Mutation, ModelStatus __all__ = ['Sampler', 'Mutator'] @@ -40,10 +40,13 @@ class Mutator: and then use `Mutator.apply()` to mutate model. For certain mutator subclasses, strategy or sampler can use `Mutator.dry_run()` to predict choice candidates. # Method names are open for discussion. + + If mutator has a label, in most cases, it means that this mutator is applied to nodes with this label. """ - def __init__(self, sampler: Optional[Sampler] = None): + def __init__(self, sampler: Optional[Sampler] = None, label: Optional[str] = None): self.sampler: Optional[Sampler] = sampler + self.label: Optional[str] = label self._cur_model: Optional[Model] = None self._cur_choice_idx: Optional[int] = None @@ -64,9 +67,12 @@ def apply(self, model: Model) -> Model: copy = model.fork() self._cur_model = copy self._cur_choice_idx = 0 + self._cur_samples = [] self.sampler.mutation_start(self, copy) self.mutate(copy) self.sampler.mutation_end(self, copy) + copy.history.append(Mutation(self, self._cur_samples, model, copy)) + copy.status = ModelStatus.Frozen self._cur_model = None self._cur_choice_idx = None return copy @@ -97,6 +103,7 @@ def choice(self, candidates: Iterable[Choice]) -> Choice: """ assert self.sampler is not None and self._cur_model is not None and self._cur_choice_idx is not None ret = self.sampler.choice(list(candidates), self, self._cur_model, self._cur_choice_idx) + self._cur_samples.append(ret) self._cur_choice_idx += 1 return ret diff --git a/nni/retiarii/nn/pytorch/api.py b/nni/retiarii/nn/pytorch/api.py index b394d3ca55..2eef6ac627 100644 --- a/nni/retiarii/nn/pytorch/api.py +++ b/nni/retiarii/nn/pytorch/api.py @@ -4,18 +4,32 @@ import copy import warnings from collections import OrderedDict -from typing import Any, List, Union, Dict +from typing import Any, List, Union, Dict, Optional import torch import torch.nn as nn from ...serializer import Translatable, basic_unit -from ...utils import uid +from ...utils import uid, get_current_context __all__ = ['LayerChoice', 'InputChoice', 'ValueChoice', 'Placeholder', 'ChosenInputs'] +def _generate_new_label(label: Optional[str]): + if label is None: + return '_mutation_' + str(uid('mutation')) + return label + + +def _get_fixed_value(label: str): + ret = get_current_context('fixed') + try: + return ret[_generate_new_label(label)] + except KeyError: + raise KeyError(f'Fixed context with {label} not found. Existing values are: {ret}') + + class LayerChoice(nn.Module): """ Layer choice selects one of the ``candidates``, then apply it on inputs and return results. @@ -55,6 +69,16 @@ class LayerChoice(nn.Module): ``self.op_choice[1] = nn.Conv3d(...)``. Adding more choices is not supported yet. """ + def __new__(cls, candidates: Union[Dict[str, nn.Module], List[nn.Module]], label: str = None, **kwargs): + try: + chosen = _get_fixed_value(label) + if isinstance(candidates, list): + return candidates[int(chosen)] + else: + return candidates[chosen] + except AssertionError: + return super().__new__(cls) + def __init__(self, candidates: Union[Dict[str, nn.Module], List[nn.Module]], label: str = None, **kwargs): super(LayerChoice, self).__init__() if 'key' in kwargs: @@ -65,7 +89,7 @@ def __init__(self, candidates: Union[Dict[str, nn.Module], List[nn.Module]], lab if 'reduction' in kwargs: warnings.warn(f'"reduction" is deprecated. Ignoring...') self.candidates = candidates - self._label = label if label is not None else f'layerchoice_{uid()}' + self._label = _generate_new_label(label) self.names = [] if isinstance(candidates, OrderedDict): @@ -163,6 +187,12 @@ class InputChoice(nn.Module): Identifier of the input choice. """ + def __new__(cls, n_candidates: int, n_chosen: int = 1, reduction: str = 'sum', label: str = None, **kwargs): + try: + return ChosenInputs(_get_fixed_value(label), reduction=reduction) + except AssertionError: + return super().__new__(cls) + def __init__(self, n_candidates: int, n_chosen: int = 1, reduction: str = 'sum', label: str = None, **kwargs): super(InputChoice, self).__init__() if 'key' in kwargs: @@ -176,7 +206,7 @@ def __init__(self, n_candidates: int, n_chosen: int = 1, reduction: str = 'sum', self.n_chosen = n_chosen self.reduction = reduction assert self.reduction in ['mean', 'concat', 'sum', 'none'] - self._label = label if label is not None else f'inputchoice_{uid()}' + self._label = _generate_new_label(label) @property def key(self): @@ -265,10 +295,16 @@ def forward(self, x): Identifier of the value choice. """ + def __new__(cls, candidates: List[Any], label: str = None): + try: + return _get_fixed_value(label) + except AssertionError: + return super().__new__(cls) + def __init__(self, candidates: List[Any], label: str = None): super().__init__() self.candidates = candidates - self._label = label if label is not None else f'valuechoice_{uid()}' + self._label = _generate_new_label(label) self._accessor = [] @property @@ -297,6 +333,14 @@ def access(self, value): raise KeyError(''.join([f'[{a}]' for a in self._accessor]) + f' does not work on {value}') return v + def __copy__(self): + return self + + def __deepcopy__(self, memo): + new_item = ValueChoice(self.candidates, self.label) + new_item._accessor = [*self._accessor] + return new_item + def __getitem__(self, item): """ Get a sub-element of value choice. @@ -331,9 +375,9 @@ class ChosenInputs(nn.Module): The already-chosen version of InputChoice. """ - def __init__(self, chosen: List[int], reduction: str): + def __init__(self, chosen: Union[List[int], int], reduction: str): super().__init__() - self.chosen = chosen + self.chosen = chosen if isinstance(chosen, list) else [chosen] self.reduction = reduction def forward(self, candidate_inputs): diff --git a/nni/retiarii/nn/pytorch/mutator.py b/nni/retiarii/nn/pytorch/mutator.py index 3f4b256da8..8b2f790a69 100644 --- a/nni/retiarii/nn/pytorch/mutator.py +++ b/nni/retiarii/nn/pytorch/mutator.py @@ -1,11 +1,15 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +import inspect from typing import Any, List, Optional, Tuple +import torch.nn as nn + from ...mutator import Mutator -from ...graph import Cell, Model, Node -from .api import ValueChoice +from ...graph import Cell, Graph, Model, ModelStatus, Node +from ...utils import uid +from .api import LayerChoice, InputChoice, ValueChoice, Placeholder class LayerChoiceMutator(Mutator): @@ -40,7 +44,7 @@ def __init__(self, nodes: List[Node]): def mutate(self, model): n_candidates = self.nodes[0].operation.parameters['n_candidates'] - n_chosen = self.nodes[0].operation.parameters['n_chosen'] + n_chosen = self.nodes[0].operation.parameters['n_chosen'] candidates = list(range(n_candidates)) chosen = [self.choice(candidates) for _ in range(n_chosen)] for node in self.nodes: @@ -116,12 +120,96 @@ def process_inline_mutation(model: Model) -> Optional[List[Mutator]]: mutator = LayerChoiceMutator(node_list) applied_mutators.append(mutator) - if applied_mutators: return applied_mutators return None +# The following are written for pure-python mode + + +class ManyChooseManyMutator(Mutator): + """ + Choose based on labels. Will not affect the model itself. + """ + + def __init__(self, label: Optional[str]): + super().__init__(label=label) + + @staticmethod + def candidates(node): + if 'n_candidates' in node.operation.parameters: + return list(range(node.operation.parameters['n_candidates'])) + else: + return node.operation.parameters['candidates'] + + @staticmethod + def number_of_chosen(node): + if 'n_chosen' in node.operation.parameters: + return node.operation.parameters['n_chosen'] + return 1 + + def mutate(self, model: Model): + # this mutate does not have any effect, but it is recorded in the mutation history + for node in model.get_nodes_by_label(self.label): + for _ in range(self.number_of_chosen(node)): + self.choice(self.candidates(node)) + break + + +def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Optional[List[Mutator]]]: + model = Model(_internal=True) + graph = Graph(model, uid(), '_model', _internal=True)._register() + model.python_class = pytorch_model.__class__ + if len(inspect.signature(model.python_class.__init__).parameters) > 1: + if not hasattr(pytorch_model, '_init_parameters'): + raise ValueError('Please annotate the model with @serialize decorator in python execution mode ' + 'if your model has init parameters.') + model.python_init_params = pytorch_model._init_parameters + else: + model.python_init_params = {} + + for name, module in pytorch_model.named_modules(): + # tricky case: value choice that serves as parameters are stored in _init_parameters + if hasattr(module, '_init_parameters'): + for key, value in module._init_parameters.items(): + if isinstance(value, ValueChoice): + node = graph.add_node(name + '.init.' + key, 'ValueChoice', {'candidates': value.candidates}) + node.label = value.label + + if isinstance(module, (LayerChoice, InputChoice, ValueChoice)): + # TODO: check the label of module and warn if it's auto-generated + pass + if isinstance(module, LayerChoice): + node = graph.add_node(name, 'LayerChoice', {'candidates': module.names}) + node.label = module.label + if isinstance(module, InputChoice): + node = graph.add_node(name, 'InputChoice', + {'n_candidates': module.n_candidates, 'n_chosen': module.n_chosen}) + node.label = module.label + if isinstance(module, ValueChoice): + node = graph.add_node(name, 'ValueChoice', {'candidates': module.candidates}) + node.label = module.label + if isinstance(module, Placeholder): + raise NotImplementedError('Placeholder is not supported in python execution mode.') + + model.status = ModelStatus.Frozen + if not graph.hidden_nodes: + return model, None + + mutators = [] + for nodes in _group_by_label_and_type(graph.hidden_nodes): + assert _is_all_equal(map(lambda n: n.operation.type, nodes)), \ + f'Node with label "{nodes[0].label}" does not all have the same type.' + assert _is_all_equal(map(lambda n: n.operation.parameters, nodes)), \ + f'Node with label "{nodes[0].label}" does not agree on parameters.' + mutators.append(ManyChooseManyMutator(nodes[0].label)) + return model, mutators + + +# utility functions + + def _is_all_equal(lst): last = None for x in lst: @@ -131,6 +219,16 @@ def _is_all_equal(lst): return True +def _group_by_label_and_type(nodes: List[Node]) -> List[List[Node]]: + result = {} + for node in nodes: + key = (node.label, node.operation.type) + if key not in result: + result[key] = [] + result[key].append(node) + return list(result.values()) + + def _group_by_label(nodes: List[Node]) -> List[List[Node]]: result = {} for node in nodes: diff --git a/nni/retiarii/serializer.py b/nni/retiarii/serializer.py index 9aad75c9c6..e0c2a26115 100644 --- a/nni/retiarii/serializer.py +++ b/nni/retiarii/serializer.py @@ -9,7 +9,7 @@ import json_tricks -from .utils import get_importable_name, get_module_name, import_ +from .utils import get_importable_name, get_module_name, import_, reset_uid def get_init_parameters_or_fail(obj, silently=False): @@ -83,9 +83,11 @@ def _translate(self) -> Any: pass -def _create_wrapper_cls(cls, store_init_parameters=True): +def _create_wrapper_cls(cls, store_init_parameters=True, reset_mutation_uid=False): class wrapper(cls): def __init__(self, *args, **kwargs): + if reset_mutation_uid: + reset_uid('mutation') if store_init_parameters: argname_list = list(inspect.signature(cls.__init__).parameters.keys())[1:] full_args = {} @@ -149,3 +151,15 @@ def basic_unit(cls): import torch.nn as nn assert issubclass(cls, nn.Module), 'When using @basic_unit, the class must be a subclass of nn.Module.' return serialize_cls(cls) + + +def model_wrapper(cls): + """ + Wrap the model if you are using pure-python execution engine. + + The wrapper serves two purposes: + + 1. Capture the init parameters of python class so that it can be re-instantiated in another process. + 2. Reset uid in `mutation` namespace so that each model counts from zero. Can be useful in unittest and other multi-model scenarios. + """ + return _create_wrapper_cls(cls, reset_mutation_uid=True) diff --git a/nni/retiarii/trial_entry.py b/nni/retiarii/trial_entry.py index 16855ca820..7d805dd47f 100644 --- a/nni/retiarii/trial_entry.py +++ b/nni/retiarii/trial_entry.py @@ -6,13 +6,20 @@ Assuming execution engine is BaseExecutionEngine. """ -import os +import argparse -from .execution.base import BaseExecutionEngine -from .execution.cgo_engine import CGOExecutionEngine if __name__ == '__main__': - if os.environ.get('CGO') == 'true': - CGOExecutionEngine.trial_execute_graph() - else: - BaseExecutionEngine.trial_execute_graph() + parser = argparse.ArgumentParser() + parser.add_argument('exec', choices=['base', 'py', 'cgo']) + args = parser.parse_args() + if args.exec == 'base': + from .execution.base import BaseExecutionEngine + engine = BaseExecutionEngine + elif args.exec == 'cgo': + from .execution.cgo_engine import CGOExecutionEngine + engine = CGOExecutionEngine + elif args.exec == 'py': + from .execution.python import PurePythonExecutionEngine + engine = PurePythonExecutionEngine + engine.trial_execute_graph() diff --git a/nni/retiarii/utils.py b/nni/retiarii/utils.py index d17ccabcfc..c8b02dfba4 100644 --- a/nni/retiarii/utils.py +++ b/nni/retiarii/utils.py @@ -4,7 +4,7 @@ import inspect import warnings from collections import defaultdict -from typing import Any +from typing import Any, List, Dict from pathlib import Path @@ -31,6 +31,10 @@ def uid(namespace: str = 'default') -> int: return _last_uid[namespace] +def reset_uid(namespace: str = 'default') -> None: + _last_uid[namespace] = 0 + + def get_module_name(cls_or_func): module_name = cls_or_func.__module__ if module_name == '__main__': @@ -61,3 +65,42 @@ def get_module_name(cls_or_func): def get_importable_name(cls, relocate_module=False): module_name = get_module_name(cls) if relocate_module else cls.__module__ return module_name + '.' + cls.__name__ + + +class ContextStack: + """ + This is to maintain a globally-accessible context envinronment that is visible to everywhere. + + Use ``with ContextStack(namespace, value):`` to initiate, and use ``get_current_context(namespace)`` to + get the corresponding value in the namespace. + """ + + _stack: Dict[str, List[Any]] = defaultdict(list) + + def __init__(self, key: str, value: Any): + self.key = key + self.value = value + + def __enter__(self): + self.push(self.key, self.value) + return self + + def __exit__(self, *args, **kwargs): + self.pop(self.key) + + @classmethod + def push(cls, key: str, value: Any): + cls._stack[key].append(value) + + @classmethod + def pop(cls, key: str) -> None: + cls._stack[key].pop() + + @classmethod + def top(cls, key: str) -> Any: + assert cls._stack[key], 'Context is empty.' + return cls._stack[key][-1] + + +def get_current_context(key: str) -> Any: + return ContextStack.top(key) diff --git a/test/ut/retiarii/debug_mnist_pytorch.py b/test/ut/retiarii/debug_mnist_pytorch.py index 4ac3ddff8d..a15977e5f2 100644 --- a/test/ut/retiarii/debug_mnist_pytorch.py +++ b/test/ut/retiarii/debug_mnist_pytorch.py @@ -3,22 +3,24 @@ import torch.nn.functional as F import torch.optim as optim +import torch + class _model(nn.Module): def __init__(self): super().__init__() self.stem = stem() - - self.fc1 = nn.Linear(1024, 256) - self.fc2 = nn.Linear(256, 10) - + self.flatten = torch.nn.Flatten() + self.fc1 = torch.nn.Linear(out_features=256, in_features=1024) + self.fc2 = torch.nn.Linear(out_features=10, in_features=256) + self.softmax = torch.nn.Softmax() def forward(self, image): stem = self.stem(image) - flatten = stem.view(stem.size(0), -1) + flatten = self.flatten(stem) fc1 = self.fc1(flatten) fc2 = self.fc2(fc1) - softmax = F.softmax(fc2, -1) + softmax = self.softmax(fc2) return softmax @@ -26,10 +28,10 @@ def forward(self, image): class stem(nn.Module): def __init__(self): super().__init__() - self.conv1 = nn.Conv2d(out_channels=32, in_channels=1, kernel_size=5) - self.pool1 = nn.MaxPool2d(kernel_size=2) - self.conv2 = nn.Conv2d(out_channels=64, in_channels=32, kernel_size=5) - self.pool2 = nn.MaxPool2d(kernel_size=2) + self.conv1 = torch.nn.Conv2d(out_channels=32, in_channels=1, kernel_size=5) + self.pool1 = torch.nn.MaxPool2d(kernel_size=2) + self.conv2 = torch.nn.Conv2d(out_channels=64, in_channels=32, kernel_size=5) + self.pool2 = torch.nn.MaxPool2d(kernel_size=2) def forward(self, *_inputs): conv1 = self.conv1(_inputs[0]) diff --git a/test/ut/retiarii/mnist_pytorch.json b/test/ut/retiarii/mnist_pytorch.json index 5788136d8a..b5ddc87887 100644 --- a/test/ut/retiarii/mnist_pytorch.json +++ b/test/ut/retiarii/mnist_pytorch.json @@ -5,10 +5,10 @@ "nodes": { "stem": {"operation": {"type": "_cell", "cell_name": "stem"}}, - "flatten": {"operation": {"type": "Flatten"}}, - "fc1": {"operation": {"type": "Dense", "parameters": {"out_features": 256, "in_features": 1024}}}, - "fc2": {"operation": {"type": "Dense", "parameters": {"out_features": 10, "in_features": 256}}}, - "softmax": {"operation": {"type": "Softmax"}} + "flatten": {"operation": {"type": "__torch__.torch.nn.Flatten"}}, + "fc1": {"operation": {"type": "__torch__.torch.nn.Linear", "parameters": {"out_features": 256, "in_features": 1024}}}, + "fc2": {"operation": {"type": "__torch__.torch.nn.Linear", "parameters": {"out_features": 10, "in_features": 256}}}, + "softmax": {"operation": {"type": "__torch__.torch.nn.Softmax"}} }, "edges": [ @@ -23,10 +23,10 @@ "stem": { "nodes": { - "conv1": {"operation": {"type": "__torch__.Conv2d", "parameters": {"out_channels": 32, "in_channels": 1, "kernel_size": 5}}}, - "pool1": {"operation": {"type": "__torch__.MaxPool2d", "parameters": {"kernel_size": 2}}}, - "conv2": {"operation": {"type": "__torch__.Conv2d", "parameters": {"out_channels": 64, "in_channels": 32, "kernel_size": 5}}}, - "pool2": {"operation": {"type": "__torch__.MaxPool2d", "parameters": {"kernel_size": 2}}} + "conv1": {"operation": {"type": "__torch__.torch.nn.Conv2d", "parameters": {"out_channels": 32, "in_channels": 1, "kernel_size": 5}}}, + "pool1": {"operation": {"type": "__torch__.torch.nn.MaxPool2d", "parameters": {"kernel_size": 2}}}, + "conv2": {"operation": {"type": "__torch__.torch.nn.Conv2d", "parameters": {"out_channels": 64, "in_channels": 32, "kernel_size": 5}}}, + "pool2": {"operation": {"type": "__torch__.torch.nn.MaxPool2d", "parameters": {"kernel_size": 2}}} }, "edges": [ @@ -36,26 +36,5 @@ {"head": ["conv2", null], "tail": ["pool2", null]}, {"head": ["pool2", null], "tail": ["_outputs", 0]} ] - }, - - "_evaluator": { - "module": "nni.retiarii.trainer.PyTorchImageClassificationTrainer", - "kwargs": { - "dataset_cls": "MNIST", - "dataset_kwargs": { - "root": "data/mnist", - "download": true - }, - "dataloader_kwargs": { - "batch_size": 32 - }, - "optimizer_cls" : "SGD", - "optimizer_kwargs": { - "lr": 1e-3 - }, - "trainer_kwargs": { - "max_epochs": 1 - } - } } } diff --git a/test/ut/retiarii/test_engine.py b/test/ut/retiarii/test_engine.py index 48dc53e9bb..0a7881308b 100644 --- a/test/ut/retiarii/test_engine.py +++ b/test/ut/retiarii/test_engine.py @@ -1,59 +1,68 @@ import json import os -import sys -import threading import unittest from pathlib import Path -import nni +import nni.retiarii from nni.retiarii import Model, submit_models from nni.retiarii.codegen import model_to_pytorch_script -from nni.retiarii.integration import RetiariiAdvisor, register_advisor -from nni.retiarii.evaluator.pytorch import PyTorchImageClassificationTrainer -from nni.retiarii.utils import import_ +from nni.retiarii.execution import set_execution_engine +from nni.retiarii.execution.base import BaseExecutionEngine +from nni.retiarii.execution.python import PurePythonExecutionEngine +from nni.retiarii.integration import RetiariiAdvisor -@unittest.skip('Skipped in this version') -class CodeGenTest(unittest.TestCase): - def test_mnist_example_pytorch(self): - with open('mnist_pytorch.json') as f: +class EngineTest(unittest.TestCase): + def test_codegen(self): + with open(self.enclosing_dir / 'mnist_pytorch.json') as f: model = Model._load(json.load(f)) script = model_to_pytorch_script(model) - with open('debug_mnist_pytorch.py') as f: + with open(self.enclosing_dir / 'debug_mnist_pytorch.py') as f: reference_script = f.read() self.assertEqual(script.strip(), reference_script.strip()) + def test_base_execution_engine(self): + advisor = RetiariiAdvisor() + set_execution_engine(BaseExecutionEngine()) + with open(self.enclosing_dir / 'mnist_pytorch.json') as f: + model = Model._load(json.load(f)) + submit_models(model, model) -@unittest.skip('Skipped in this version') -class TrainerTest(unittest.TestCase): - def test_trainer(self): - sys.path.insert(0, Path(__file__).parent.as_posix()) - Model = import_('debug_mnist_pytorch._model') - trainer = PyTorchImageClassificationTrainer( - Model(), - dataset_kwargs={'root': (Path(__file__).parent / 'data' / 'mnist').as_posix(), 'download': True}, - dataloader_kwargs={'batch_size': 32}, - optimizer_kwargs={'lr': 1e-3}, - trainer_kwargs={'max_epochs': 1} - ) - trainer.fit() - - -@unittest.skip('Skipped in this version') -class EngineTest(unittest.TestCase): + advisor.stopping = True + advisor.default_worker.join() + advisor.assessor_worker.join() - def test_submit_models(self): - os.makedirs('generated', exist_ok=True) - from nni.runtime import protocol - protocol._out_file = open(Path(__file__).parent / 'generated/debug_protocol_out_file.py', 'wb') + def test_py_execution_engine(self): + advisor = RetiariiAdvisor() - with open('mnist_pytorch.json') as f: - model = Model._load(json.load(f)) + set_execution_engine(PurePythonExecutionEngine()) + model = Model._load({ + '_model': { + 'inputs': None, + 'outputs': None, + 'nodes': { + 'layerchoice_1': { + 'operation': {'type': 'LayerChoice', 'parameters': {'candidates': ['0', '1']}} + } + }, + 'edges': [] + } + }) + model.python_class = object submit_models(model, model) advisor.stopping = True advisor.default_worker.join() advisor.assessor_worker.join() - def test_execution_engine(self): - pass + def setUp(self) -> None: + self.enclosing_dir = Path(__file__).parent + os.makedirs(self.enclosing_dir / 'generated', exist_ok=True) + from nni.runtime import protocol + protocol._out_file = open(self.enclosing_dir / 'generated/debug_protocol_out_file.py', 'wb') + + def tearDown(self) -> None: + from nni.runtime import protocol + protocol._out_file.close() + nni.retiarii.execution.api._execution_engine = None + nni.retiarii.integration_api._advisor = None diff --git a/test/ut/retiarii/test_highlevel_apis.py b/test/ut/retiarii/test_highlevel_apis.py index 20fbbcddb6..c1b9bb0e3d 100644 --- a/test/ut/retiarii/test_highlevel_apis.py +++ b/test/ut/retiarii/test_highlevel_apis.py @@ -8,7 +8,10 @@ from nni.retiarii import Sampler, basic_unit from nni.retiarii.converter import convert_to_graph from nni.retiarii.codegen import model_to_pytorch_script -from nni.retiarii.nn.pytorch.mutator import process_inline_mutation +from nni.retiarii.execution.python import _unpack_if_only_one +from nni.retiarii.nn.pytorch.mutator import process_inline_mutation, extract_mutation_from_pt_module +from nni.retiarii.serializer import model_wrapper +from nni.retiarii.utils import ContextStack class EnumerateSampler(Sampler): @@ -44,7 +47,7 @@ def forward(self, x: torch.Tensor, index: int): return self.conv2(x) -class TestHighLevelAPI(unittest.TestCase): +class GraphIR(unittest.TestCase): def _convert_to_ir(self, model): script_module = torch.jit.script(model) @@ -56,7 +59,19 @@ def _get_converted_pytorch_model(self, model_ir): exec(model_code + '\n\nconverted_model = _model()', exec_vars) return exec_vars['converted_model'] + def _get_model_with_mutators(self, pytorch_model): + model = self._convert_to_ir(pytorch_model) + mutators = process_inline_mutation(model) + return model, mutators + + def get_serializer(self): + def dummy(cls): + return cls + + return dummy + def test_layer_choice(self): + @self.get_serializer() class Net(nn.Module): def __init__(self): super().__init__() @@ -68,8 +83,7 @@ def __init__(self): def forward(self, x): return self.module(x) - model = self._convert_to_ir(Net()) - mutators = process_inline_mutation(model) + model, mutators = self._get_model_with_mutators(Net()) self.assertEqual(len(mutators), 1) mutator = mutators[0].bind_sampler(EnumerateSampler()) model1 = mutator.apply(model) @@ -80,6 +94,7 @@ def forward(self, x): torch.Size([1, 5, 3, 3])) def test_input_choice(self): + @self.get_serializer() class Net(nn.Module): def __init__(self): super().__init__() @@ -92,8 +107,7 @@ def forward(self, x): x2 = self.conv2(x) return self.input([x1, x2]) - model = self._convert_to_ir(Net()) - mutators = process_inline_mutation(model) + model, mutators = self._get_model_with_mutators(Net()) self.assertEqual(len(mutators), 1) mutator = mutators[0].bind_sampler(EnumerateSampler()) model1 = mutator.apply(model) @@ -104,6 +118,7 @@ def forward(self, x): torch.Size([1, 5, 3, 3])) def test_chosen_inputs(self): + @self.get_serializer() class Net(nn.Module): def __init__(self, reduction): super().__init__() @@ -117,8 +132,7 @@ def forward(self, x): return self.input([x1, x2]) for reduction in ['none', 'sum', 'mean', 'concat']: - model = self._convert_to_ir(Net(reduction)) - mutators = process_inline_mutation(model) + model, mutators = self._get_model_with_mutators(Net(reduction)) self.assertEqual(len(mutators), 1) mutator = mutators[0].bind_sampler(EnumerateSampler()) model = mutator.apply(model) @@ -133,6 +147,7 @@ def forward(self, x): self.assertEqual(result.size(), torch.Size([1, 3, 3, 3])) def test_value_choice(self): + @self.get_serializer() class Net(nn.Module): def __init__(self): super().__init__() @@ -142,8 +157,7 @@ def __init__(self): def forward(self, x): return self.conv(x, self.index()) - model = self._convert_to_ir(Net()) - mutators = process_inline_mutation(model) + model, mutators = self._get_model_with_mutators(Net()) self.assertEqual(len(mutators), 1) mutator = mutators[0].bind_sampler(EnumerateSampler()) model1 = mutator.apply(model) @@ -154,6 +168,7 @@ def forward(self, x): torch.Size([1, 5, 3, 3])) def test_value_choice_as_parameter(self): + @self.get_serializer() class Net(nn.Module): def __init__(self): super().__init__() @@ -162,8 +177,7 @@ def __init__(self): def forward(self, x): return self.conv(x) - model = self._convert_to_ir(Net()) - mutators = process_inline_mutation(model) + model, mutators = self._get_model_with_mutators(Net()) self.assertEqual(len(mutators), 1) mutator = mutators[0].bind_sampler(EnumerateSampler()) model1 = mutator.apply(model) @@ -174,6 +188,7 @@ def forward(self, x): torch.Size([1, 5, 1, 1])) def test_value_choice_as_parameter(self): + @self.get_serializer() class Net(nn.Module): def __init__(self): super().__init__() @@ -182,8 +197,7 @@ def __init__(self): def forward(self, x): return self.conv(x) - model = self._convert_to_ir(Net()) - mutators = process_inline_mutation(model) + model, mutators = self._get_model_with_mutators(Net()) self.assertEqual(len(mutators), 1) mutator = mutators[0].bind_sampler(EnumerateSampler()) model1 = mutator.apply(model) @@ -194,6 +208,7 @@ def forward(self, x): torch.Size([1, 5, 1, 1])) def test_value_choice_as_parameter(self): + @self.get_serializer() class Net(nn.Module): def __init__(self): super().__init__() @@ -202,8 +217,7 @@ def __init__(self): def forward(self, x): return self.conv(x) - model = self._convert_to_ir(Net()) - mutators = process_inline_mutation(model) + model, mutators = self._get_model_with_mutators(Net()) self.assertEqual(len(mutators), 2) mutators[0].bind_sampler(EnumerateSampler()) mutators[1].bind_sampler(EnumerateSampler()) @@ -214,6 +228,7 @@ def forward(self, x): torch.Size([1, 8, 1, 1])) def test_value_choice_as_parameter_shared(self): + @self.get_serializer() class Net(nn.Module): def __init__(self): super().__init__() @@ -223,8 +238,7 @@ def __init__(self): def forward(self, x): return self.conv1(x) + self.conv2(x) - model = self._convert_to_ir(Net()) - mutators = process_inline_mutation(model) + model, mutators = self._get_model_with_mutators(Net()) self.assertEqual(len(mutators), 1) mutator = mutators[0].bind_sampler(EnumerateSampler()) model1 = mutator.apply(model) @@ -235,6 +249,7 @@ def forward(self, x): torch.Size([1, 8, 5, 5])) def test_value_choice_in_functional(self): + @self.get_serializer() class Net(nn.Module): def __init__(self): super().__init__() @@ -243,8 +258,7 @@ def __init__(self): def forward(self, x): return F.dropout(x, self.dropout_rate()) - model = self._convert_to_ir(Net()) - mutators = process_inline_mutation(model) + model, mutators = self._get_model_with_mutators(Net()) self.assertEqual(len(mutators), 1) mutator = mutators[0].bind_sampler(EnumerateSampler()) model1 = mutator.apply(model) @@ -254,6 +268,7 @@ def forward(self, x): self.assertAlmostEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 3, 3)).abs().sum().item(), 0) def test_value_choice_in_layer_choice(self): + @self.get_serializer() class Net(nn.Module): def __init__(self): super().__init__() @@ -265,8 +280,7 @@ def __init__(self): def forward(self, x): return self.linear(x) - model = self._convert_to_ir(Net()) - mutators = process_inline_mutation(model) + model, mutators = self._get_model_with_mutators(Net()) self.assertEqual(len(mutators), 3) sz_counter = Counter() sampler = RandomSampler() @@ -278,6 +292,7 @@ def forward(self, x): self.assertEqual(len(sz_counter), 4) def test_shared(self): + @self.get_serializer() class Net(nn.Module): def __init__(self, shared=True): super().__init__() @@ -294,16 +309,14 @@ def __init__(self, shared=True): def forward(self, x): return self.module1(x) + self.module2(x) - model = self._convert_to_ir(Net()) - mutators = process_inline_mutation(model) + model, mutators = self._get_model_with_mutators(Net()) self.assertEqual(len(mutators), 1) sampler = RandomSampler() mutator = mutators[0].bind_sampler(sampler) self.assertEqual(self._get_converted_pytorch_model(mutator.apply(model))(torch.randn(1, 3, 3, 3)).size(0), 1) self.assertEqual(sampler.counter, 1) - model = self._convert_to_ir(Net(shared=False)) - mutators = process_inline_mutation(model) + model, mutators = self._get_model_with_mutators(Net(shared=False)) self.assertEqual(len(mutators), 2) sampler = RandomSampler() # repeat test. Expectation: sometimes succeeds, sometimes fails. @@ -321,6 +334,7 @@ def forward(self, x): self.assertLess(failed_count, 30) def test_valuechoice_access(self): + @self.get_serializer() class Net(nn.Module): def __init__(self): super().__init__() @@ -330,8 +344,7 @@ def __init__(self): def forward(self, x): return self.conv(x) - model = self._convert_to_ir(Net()) - mutators = process_inline_mutation(model) + model, mutators = self._get_model_with_mutators(Net()) self.assertEqual(len(mutators), 1) mutators[0].bind_sampler(EnumerateSampler()) input = torch.randn(1, 3, 5, 5) @@ -340,6 +353,7 @@ def forward(self, x): self.assertEqual(self._get_converted_pytorch_model(mutators[0].apply(model))(input).size(), torch.Size([1, 8, 1, 1])) + @self.get_serializer() class Net2(nn.Module): def __init__(self): super().__init__() @@ -354,14 +368,14 @@ def forward(self, x): x = self.conv(x) return self.conv1(torch.cat((x, x), 1)) - model = self._convert_to_ir(Net2()) - mutators = process_inline_mutation(model) + model, mutators = self._get_model_with_mutators(Net2()) self.assertEqual(len(mutators), 1) mutators[0].bind_sampler(EnumerateSampler()) input = torch.randn(1, 3, 5, 5) self._get_converted_pytorch_model(mutators[0].apply(model))(input) def test_valuechoice_access_functional(self): + @self.get_serializer() class Net(nn.Module): def __init__(self): super().__init__() @@ -370,8 +384,7 @@ def __init__(self): def forward(self, x): return F.dropout(x, self.dropout_rate()[0]) - model = self._convert_to_ir(Net()) - mutators = process_inline_mutation(model) + model, mutators = self._get_model_with_mutators(Net()) self.assertEqual(len(mutators), 1) mutator = mutators[0].bind_sampler(EnumerateSampler()) model1 = mutator.apply(model) @@ -381,6 +394,7 @@ def forward(self, x): self.assertAlmostEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 3, 3)).abs().sum().item(), 0) def test_valuechoice_access_functional_expression(self): + @self.get_serializer() class Net(nn.Module): def __init__(self): super().__init__() @@ -391,8 +405,7 @@ def forward(self, x): # ValueError: dropout probability has to be between 0 and 1, but got 1.05 return F.dropout(x, self.dropout_rate()[0] - .1) - model = self._convert_to_ir(Net()) - mutators = process_inline_mutation(model) + model, mutators = self._get_model_with_mutators(Net()) self.assertEqual(len(mutators), 1) mutator = mutators[0].bind_sampler(EnumerateSampler()) model1 = mutator.apply(model) @@ -400,3 +413,29 @@ def forward(self, x): self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)) self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)).size(), torch.Size([1, 3, 3, 3])) self.assertAlmostEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 3, 3)).abs().sum().item(), 0) + + +class Python(GraphIR): + def _get_converted_pytorch_model(self, model_ir): + mutation = {mut.mutator.label: _unpack_if_only_one(mut.samples) for mut in model_ir.history} + with ContextStack('fixed', mutation): + model = model_ir.python_class(**model_ir.python_init_params) + return model + + def _get_model_with_mutators(self, pytorch_model): + return extract_mutation_from_pt_module(pytorch_model) + + def get_serializer(self): + return model_wrapper + + @unittest.skip + def test_value_choice(self): ... + + @unittest.skip + def test_value_choice_in_functional(self): ... + + @unittest.skip + def test_valuechoice_access_functional(self): ... + + @unittest.skip + def test_valuechoice_access_functional_expression(self): ... diff --git a/test/ut/retiarii/test_mutator.py b/test/ut/retiarii/test_mutator.py index 0c4cfd404b..a0cd05296d 100644 --- a/test/ut/retiarii/test_mutator.py +++ b/test/ut/retiarii/test_mutator.py @@ -60,7 +60,14 @@ def test_mutation(): model2 = mutator.apply(model1) assert _get_pools(model2) == (global_pool, max_pool) - assert model2.history == [model0, model1] + assert len(model2.history) == 2 + assert model2.history[0].from_ == model0 + assert model2.history[0].to == model1 + assert model2.history[1].from_ == model1 + assert model2.history[1].to == model2 + assert model2.history[0].mutator == mutator + assert model2.history[1].mutator == mutator + assert _get_pools(model0) == (max_pool, max_pool) assert _get_pools(model1) == (avg_pool, global_pool)