diff --git a/nni/retiarii/evaluator/pytorch/lightning.py b/nni/retiarii/evaluator/pytorch/lightning.py
index 72409f313a..f2a78b30aa 100644
--- a/nni/retiarii/evaluator/pytorch/lightning.py
+++ b/nni/retiarii/evaluator/pytorch/lightning.py
@@ -4,7 +4,7 @@
 import os
 import warnings
 from pathlib import Path
-from typing import Dict, Union, Optional, List, Callable, Type
+from typing import Any, Dict, Union, Optional, List, Callable, Type
 
 import pytorch_lightning as pl
 import torch.nn as nn
@@ -22,6 +22,7 @@
     cgo_import_failed = True
 
 from nni.retiarii.graph import Evaluator
+from nni.typehint import Literal
 
 __all__ = ['LightningModule', 'Trainer', 'DataLoader', 'Lightning', 'Classification', 'Regression']
 
@@ -36,6 +37,11 @@ class LightningModule(pl.LightningModule):
     See https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html
     """
 
+    running_mode: Literal['multi', 'oneshot'] = 'multi'
+    """An indicator of whether the current module is running in a multi-trial experiment or a one-shot one.
+    This flag should be automatically set by experiments when they start to run.
+    """
+
     def set_model(self, model: Union[Callable[[], nn.Module], nn.Module]) -> None:
         """Set the inner model (architecture) to train / evaluate.
 
@@ -59,6 +65,7 @@ def set_model(self, model: Union[Callable[[], nn.Module], nn.Module]) -> None:
 Traced version of ``torch.utils.data.DataLoader``. See https://pytorch.org/docs/stable/data.html
 """
+
 @nni.trace
 class Lightning(Evaluator):
     """
@@ -74,51 +81,67 @@
     Parameters
     ----------
-    lightning_module : LightningModule
+    lightning_module
         Lightning module that defines the training logic.
-    trainer : Trainer
+    trainer
         Lightning trainer that handles the training.
-    train_dataloders : DataLoader
+    train_dataloaders
        Used in ``trainer.fit()``. A PyTorch DataLoader with training samples.
        If the ``lightning_module`` has a predefined train_dataloader method this will be skipped.
-    val_dataloaders : DataLoader or List of DataLoader
+        It can be any type of dataloader supported by Lightning.
+    val_dataloaders
        Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
        If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
+        It can be any type of dataloader supported by Lightning.
    """
 
    def __init__(self, lightning_module: LightningModule, trainer: Trainer,
-                 train_dataloader: Optional[DataLoader] = None,
-                 val_dataloaders: Union[DataLoader, List[DataLoader], None] = None):
+                 train_dataloaders: Optional[Any] = None,
+                 val_dataloaders: Optional[Any] = None,
+                 train_dataloader: Optional[Any] = None):
        assert isinstance(lightning_module, LightningModule), f'Lightning module must be an instance of {__name__}.LightningModule.'
+        if train_dataloader is not None:
+            warnings.warn('`train_dataloader` is deprecated and replaced with `train_dataloaders`.', DeprecationWarning)
+            train_dataloaders = train_dataloader
        if cgo_import_failed:
            assert isinstance(trainer, pl.Trainer) and is_traceable(trainer), f'Trainer must be imported from {__name__}'
        else:
            # this is not isinstance(trainer, Trainer) because with a different trace call, it can be different
            assert (isinstance(trainer, pl.Trainer) and is_traceable(trainer)) or isinstance(trainer, cgo_trainer.Trainer), \
                f'Trainer must be imported from {__name__} or nni.retiarii.evaluator.pytorch.cgo.trainer'
-        assert _check_dataloader(train_dataloader), f'Wrong dataloader type. Try import DataLoader from {__name__}.'
-        assert _check_dataloader(val_dataloaders), f'Wrong dataloader type. Try import DataLoader from {__name__}.'
+        if not _check_dataloader(train_dataloaders):
+            warnings.warn(f'Please try to wrap PyTorch DataLoader with nni.trace or '
+                          f'import DataLoader from {__name__}: {train_dataloaders}',
+                          RuntimeWarning)
+        if not _check_dataloader(val_dataloaders):
+            warnings.warn(f'Please try to wrap PyTorch DataLoader with nni.trace or '
+                          f'import DataLoader from {__name__}: {val_dataloaders}',
+                          RuntimeWarning)
        self.module = lightning_module
        self.trainer = trainer
-        self.train_dataloader = train_dataloader
+        self.train_dataloaders = train_dataloaders
        self.val_dataloaders = val_dataloaders
 
    @staticmethod
    def _load(ir):
-        return Lightning(ir['module'], ir['trainer'], ir['train_dataloader'], ir['val_dataloaders'])
+        return Lightning(ir['module'], ir['trainer'], ir['train_dataloaders'], ir['val_dataloaders'])
 
    def _dump(self):
        return {
            'type': self.__class__,
            'module': self.module,
            'trainer': self.trainer,
-            'train_dataloader': self.train_dataloader,
+            'train_dataloaders': self.train_dataloaders,
            'val_dataloaders': self.val_dataloaders
        }
 
    def _execute(self, model_cls):
        return self.fit(model_cls)
 
+    @property
+    def train_dataloader(self):
+        warnings.warn('train_dataloader is deprecated, please use `train_dataloaders`.', DeprecationWarning)
+        return self.train_dataloaders
+
    def __eq__(self, other):
        eq_func = False
        eq_args = False
@@ -146,15 +169,18 @@ def fit(self, model):
            The model to fit.
        """
        self.module.set_model(model)
-        return self.trainer.fit(self.module, self.train_dataloader, self.val_dataloaders)
+        return self.trainer.fit(self.module, self.train_dataloaders, self.val_dataloaders)
 
 
 def _check_dataloader(dataloader):
-    if dataloader is None:
-        return True
+    # Check the type of dataloader recursively.
    if isinstance(dataloader, list):
        return all([_check_dataloader(d) for d in dataloader])
-    return isinstance(dataloader, torch_data.DataLoader) and is_traceable(dataloader)
+    if isinstance(dataloader, dict):
+        return all([_check_dataloader(v) for v in dataloader.values()])
+    if isinstance(dataloader, torch_data.DataLoader):
+        return is_traceable(dataloader)
+    return True
 
 
 ### The following are some commonly used Lightning modules ###
 
@@ -176,7 +202,6 @@ def __init__(self, criterion: Type[nn.Module], metrics: Dict[str, Type[torchmetr
        if export_onnx is None or export_onnx is True:
            self.export_onnx = Path(os.environ.get('NNI_OUTPUT_DIR', '.')) / 'model.onnx'
-            self.export_onnx.parent.mkdir(exist_ok=True)
        elif export_onnx:
            self.export_onnx = Path(export_onnx)
        else:
@@ -199,7 +224,8 @@ def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
 
-        if self.export_onnx is not None:
+        if self.running_mode == 'multi' and self.export_onnx is not None:
+            self.export_onnx.parent.mkdir(exist_ok=True)
            try:
                self.to_onnx(self.export_onnx, x, export_params=True)
            except RuntimeError as e:
@@ -221,10 +247,12 @@ def configure_optimizers(self):
        return self.optimizer(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)  # type: ignore
 
    def on_validation_epoch_end(self):
-        nni.report_intermediate_result(self._get_validation_metrics())
+        if self.running_mode == 'multi':
+            nni.report_intermediate_result(self._get_validation_metrics())
 
    def on_fit_end(self):
-        nni.report_final_result(self._get_validation_metrics())
+        if self.running_mode == 'multi':
+            nni.report_final_result(self._get_validation_metrics())
 
    def _get_validation_metrics(self):
        if len(self.metrics) == 1:
@@ -283,14 +311,18 @@ def __init__(self, criterion: Type[nn.Module] = nn.CrossEntropyLoss,
                 learning_rate: float = 0.001,
                 weight_decay: float = 0.,
                 optimizer: Type[optim.Optimizer] = optim.Adam,
-                 train_dataloader: Optional[DataLoader] = None,
+                 train_dataloaders: Optional[DataLoader] = None,
                 val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
                 export_onnx: bool = True,
+                 train_dataloader: Optional[DataLoader] = None,
                 **trainer_kwargs):
+        if train_dataloader is not None:
+            warnings.warn('`train_dataloader` is deprecated and replaced with `train_dataloaders`.', DeprecationWarning)
+            train_dataloaders = train_dataloader
        module = _ClassificationModule(criterion=criterion, learning_rate=learning_rate, weight_decay=weight_decay,
                                       optimizer=optimizer, export_onnx=export_onnx)
        super().__init__(module, Trainer(**trainer_kwargs),
-                         train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
+                         train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders)
 
 
 @nni.trace
@@ -336,11 +368,15 @@ def __init__(self, criterion: Type[nn.Module] = nn.MSELoss,
                 learning_rate: float = 0.001,
                 weight_decay: float = 0.,
                 optimizer: Type[optim.Optimizer] = optim.Adam,
-                 train_dataloader: Optional[DataLoader] = None,
+                 train_dataloaders: Optional[DataLoader] = None,
                 val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
                 export_onnx: bool = True,
+                 train_dataloader: Optional[DataLoader] = None,
                 **trainer_kwargs):
+        if train_dataloader is not None:
+            warnings.warn('`train_dataloader` is deprecated and replaced with `train_dataloaders`.', DeprecationWarning)
+            train_dataloaders = train_dataloader
        module = _RegressionModule(criterion=criterion, learning_rate=learning_rate, weight_decay=weight_decay,
                                   optimizer=optimizer, export_onnx=export_onnx)
        super().__init__(module, Trainer(**trainer_kwargs),
-                         train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
+                         train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders)
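A minimal usage sketch of the renamed evaluator arguments above (the MNIST dataset and transform are only illustrative; ``DataLoader`` here is the traced class re-exported by this module, so it passes ``_check_dataloader`` without the new warning)::

    from torchvision import transforms
    from torchvision.datasets import MNIST
    from nni.retiarii.evaluator.pytorch.lightning import Classification, DataLoader

    transform = transforms.ToTensor()
    train_set = MNIST('data/mnist', train=True, download=True, transform=transform)
    val_set = MNIST('data/mnist', train=False, download=True, transform=transform)

    evaluator = Classification(
        train_dataloaders=DataLoader(train_set, batch_size=64),   # new spelling of the argument
        val_dataloaders=DataLoader(val_set, batch_size=64),
        max_epochs=1,                                             # forwarded to Trainer via **trainer_kwargs
    )
    # Passing train_dataloader=... still works but now emits a DeprecationWarning.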
diff --git a/nni/retiarii/oneshot/pytorch/base_lightning.py b/nni/retiarii/oneshot/pytorch/base_lightning.py
index 0047c93a57..53fb06d546 100644
--- a/nni/retiarii/oneshot/pytorch/base_lightning.py
+++ b/nni/retiarii/oneshot/pytorch/base_lightning.py
@@ -18,6 +18,7 @@
 from nni.common.hpo_utils import ParameterSpec
 from nni.common.serializer import is_traceable
 from nni.retiarii.nn.pytorch.api import ValueChoiceX
+from nni.typehint import Literal
 from .supermodule.base import BaseSuperNetModule
 
 __all__ = ['MutationHook', 'BaseSuperNetModule', 'BaseOneShotLightningModule', 'traverse_and_mutate_submodules']
@@ -334,21 +335,21 @@ def configure_optimizers(self):
        return arc_optimizers + w_optimizers, lr_schedulers
 
    def on_train_start(self):
-        # redirect the access to trainer/log to this module
-        # but note that we might be missing other attributes,
-        # which could potentially be a problem
-        self.model.trainer = self.trainer  # type: ignore
-        self.model.log = self.log
        return self.model.on_train_start()
 
    def on_train_end(self):
        return self.model.on_train_end()
 
    def on_fit_start(self):
-        return self.model.on_train_start()
+        # redirect the access to trainer/log to this module
+        # but note that we might be missing other attributes,
+        # which could potentially be a problem
+        self.model.trainer = self.trainer  # type: ignore
+        self.model.log = self.log
+        return self.model.on_fit_start()
 
    def on_fit_end(self):
-        return self.model.on_train_end()
+        return self.model.on_fit_end()
 
    def on_train_batch_start(self, batch, batch_idx, unused=0):
        return self.model.on_train_batch_start(batch, batch_idx, unused)
@@ -356,6 +357,7 @@ def on_train_batch_start(self, batch, batch_idx, unused=0):
    def on_train_batch_end(self, outputs, batch, batch_idx, unused=0):
        return self.model.on_train_batch_end(outputs, batch, batch_idx, unused)
 
+    # Deprecated hooks in pytorch-lightning
    def on_epoch_start(self):
        return self.model.on_epoch_start()
 
@@ -427,7 +429,7 @@ def apply(lr_scheduler):
        else:
            apply(lr_schedulers)
 
-    def call_weight_optimizers(self, method):
+    def call_weight_optimizers(self, method: Literal['step', 'zero_grad']):
        """
        Function that imitates lightning trainer's behavior of calling user's optimizers.
        Since auto_optimization is turned off by this class, you can use this function
        to make user optimizers behave as they were automatically handled by the lightning trainer.
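For context on the ``Literal['step', 'zero_grad']`` hint added above, a condensed sketch of how a one-shot module is expected to drive the user-defined weight optimizers (it mirrors the ENAS step later in this diff; the subclass shown is hypothetical)::

    from nni.retiarii.oneshot.pytorch.base_lightning import BaseOneShotLightningModule

    class MyOneShotModule(BaseOneShotLightningModule):
        # hypothetical subclass, for illustration only
        def training_step(self, batch, batch_idx):
            # imitate automatic optimization: zero_grad -> forward/backward -> step
            self.call_weight_optimizers('zero_grad')
            out = self.model.training_step(batch, batch_idx)
            loss = out['loss'] if isinstance(out, dict) else out
            self.manual_backward(loss)
            self.call_weight_optimizers('step')
            return out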
+ """ + + def __init__(self, loaders: dict[str, Any], mode: str = 'min_size'): + # FIXME: max_cycle will make dataloaders cycle iterators, + # causing extra problems. + if mode != 'min_size': + raise ValueError('Only min_size mode is supported now.') + super().__init__(loaders, mode) + + def __iter__(self) -> Any: + """Replace the super-class iterator with ours.""" + self._try_to_patch_pytorch_dataloader() + iterator = ConcatLoaderIterator(self.loaders) + # handle fault tolerant restart. + self.on_restart(iterator) + self._iterator = iterator + return iterator + + @staticmethod + def _try_to_patch_pytorch_dataloader(): + """Copied from CombinedLoader.""" + from torch.utils.data.dataloader import _BaseDataLoaderIter + + # prevent `NotImplementedError` from PyTorch: + # https://github.com/pytorch/pytorch/blob/v1.9.0/torch/utils/data/dataloader.py#L541 + def __getstate__patch__(*_): + return {} + + _BaseDataLoaderIter.__getstate__ = __getstate__patch__ # type: ignore + + def __len__(self) -> int: + return int(sum(self._calc_num_batches(loader) for loader in self.loaders.values())) + + +class ConcatLoaderIterator(CombinedLoaderIterator): + """Similar to CombinedLoaderIterator in Lightning, but in a concat manner.""" + + def __next__(self) -> Any: + """Fetches the next batch from multiple data loaders, + by looking for the first iterator that isn't exhausted yet. + """ + if not len(self.loader_iters) == len(self.loaders): + raise RuntimeError('loader_iters must have the same length as loaders.') + for i, (loader_name, iterator) in enumerate(self.loader_iters.items()): + try: + return (self.request_next_batch(iterator), loader_name) + except StopIteration: + if i + 1 == len(self.loader_iters): + raise diff --git a/nni/retiarii/oneshot/pytorch/differentiable.py b/nni/retiarii/oneshot/pytorch/differentiable.py index 199134d39c..6167585ec3 100644 --- a/nni/retiarii/oneshot/pytorch/differentiable.py +++ b/nni/retiarii/oneshot/pytorch/differentiable.py @@ -75,8 +75,9 @@ def training_step(self, batch, batch_idx): if not isinstance(arc_optim, optim.Optimizer): raise TypeError(f'Expect arc_optim to be a single Optimizer, but found: {arc_optim}') - # The InterleavedTrainValDataLoader yields both train and val data in a batch - trn_batch, val_batch = batch + # DARTS strategy makes sure that ``train`` and ``val`` must be in the batch + trn_batch = batch['train'] + val_batch = batch['val'] # phase 1: architecture step # The _resample hook is kept for some darts-based NAS methods like proxyless. diff --git a/nni/retiarii/oneshot/pytorch/sampling.py b/nni/retiarii/oneshot/pytorch/sampling.py index 5376524752..1c21b078d8 100644 --- a/nni/retiarii/oneshot/pytorch/sampling.py +++ b/nni/retiarii/oneshot/pytorch/sampling.py @@ -133,29 +133,30 @@ def __init__(self, def configure_architecture_optimizers(self): return optim.Adam(self.controller.parameters(), lr=3.5e-4) - def training_step(self, batch, batch_idx): - # The ConcatenateTrainValDataloader yields both data and which dataloader it comes from. 
diff --git a/nni/retiarii/oneshot/pytorch/sampling.py b/nni/retiarii/oneshot/pytorch/sampling.py
index 5376524752..1c21b078d8 100644
--- a/nni/retiarii/oneshot/pytorch/sampling.py
+++ b/nni/retiarii/oneshot/pytorch/sampling.py
@@ -133,29 +133,30 @@ def __init__(self,
 
    def configure_architecture_optimizers(self):
        return optim.Adam(self.controller.parameters(), lr=3.5e-4)
 
-    def training_step(self, batch, batch_idx):
-        # The ConcatenateTrainValDataloader yields both data and which dataloader it comes from.
-        batch, source = batch
+    def training_step(self, batch_packed, batch_idx):
+        batch, mode = batch_packed
 
-        if source == 'train':
-            # step 1: train model params
-            self.resample()
+        if mode == 'train':
+            # train model params
+            with torch.no_grad():
+                self.resample()
            self.call_weight_optimizers('zero_grad')
-            loss_and_metrics = self.model.training_step(batch, batch_idx)
-            w_step_loss = loss_and_metrics['loss'] \
-                if isinstance(loss_and_metrics, dict) else loss_and_metrics
+            step_output = self.model.training_step(batch, batch_idx)
+            w_step_loss = step_output['loss'] \
+                if isinstance(step_output, dict) else step_output
            self.manual_backward(w_step_loss)
            self.call_weight_optimizers('step')
-            return loss_and_metrics
 
-        if source == 'val':
-            # step 2: train ENAS agent
+        else:
+            # train ENAS agent
            arc_opt = self.architecture_optimizers()
            if not isinstance(arc_opt, optim.Optimizer):
                raise TypeError(f'Expect arc_opt to be a single Optimizer, but found: {arc_opt}')
            arc_opt.zero_grad()
            self.resample()
-            self.model.validation_step(batch, batch_idx)
+
+            step_output = self.model.validation_step(batch, batch_idx)
+
            # use the default metric of self.model as reward function
            if len(self.trainer.callback_metrics) == 1:
                _, metric = next(iter(self.trainer.callback_metrics.items()))
@@ -163,7 +164,9 @@ def training_step(self, batch, batch_idx):
                metric_name = self.reward_metric_name or 'default'
                if metric_name not in self.trainer.callback_metrics:
                    raise KeyError(f'Model reported metrics should contain a ``{metric_name}`` key but '
-                                   f'found multiple metrics without default: {self.trainer.callback_metrics.keys()}')
+                                   f'found multiple (or zero) metrics without default: {list(self.trainer.callback_metrics.keys())}. '
+                                   f'Try to use self.log to report metrics with the specified key ``{metric_name}`` in validation_step, '
+                                   'and remember to set on_step=True.')
                metric = self.trainer.callback_metrics[metric_name]
            reward: float = metric.item()
@@ -183,6 +186,8 @@ def training_step(self, batch, batch_idx):
            arc_opt.step()
            arc_opt.zero_grad()
 
+        return step_output
+
    def resample(self):
        """Resample the architecture with ENAS controller."""
        sample = self.controller.resample()
diff --git a/nni/retiarii/oneshot/pytorch/strategy.py b/nni/retiarii/oneshot/pytorch/strategy.py
index 61479336c4..58041aaae4 100644
--- a/nni/retiarii/oneshot/pytorch/strategy.py
+++ b/nni/retiarii/oneshot/pytorch/strategy.py
@@ -16,7 +16,6 @@
 from typing import Any, Type
 
 import torch.nn as nn
-from torch.utils.data import DataLoader
 
 from nni.retiarii.graph import Model
 from nni.retiarii.strategy.base import BaseStrategy
@@ -25,7 +24,6 @@
 from .base_lightning import BaseOneShotLightningModule
 from .differentiable import DartsLightningModule, ProxylessLightningModule, GumbelDartsLightningModule
 from .sampling import EnasLightningModule, RandomSamplingLightningModule
-from .utils import InterleavedTrainValDataLoader, ConcatenateTrainValDataLoader
 
 
 class OneShotStrategy(BaseStrategy):
@@ -37,15 +35,18 @@ def __init__(self, oneshot_module: Type[BaseOneShotLightningModule], **kwargs):
 
        self.model: BaseOneShotLightningModule | None = None
 
-    def _get_dataloader(self, train_dataloader: DataLoader, val_dataloaders: DataLoader | list[DataLoader]) \
-            -> DataLoader | tuple[DataLoader, DataLoader]:
+    def preprocess_dataloader(self, train_dataloaders: Any, val_dataloaders: Any) -> tuple[Any, Any]:
        """
-        One-shot strategy typically requires a customized dataloader.
-
-        If only train dataloader is produced, return one dataloader.
-        Otherwise, return train dataloader and valid loader as a tuple.
+        One-shot strategy typically requires fusing the train and validation dataloaders in an ad-hoc way.
+        As one-shot strategy doesn't try to open the black box of a batch,
+        theoretically, these dataloaders can be
+        any dataloader type supported by Lightning.
+
+        Returns
+        -------
+        A tuple of preprocessed train dataloaders and validation dataloaders.
        """
-        raise NotImplementedError()
+        return train_dataloaders, val_dataloaders
 
    def run(self, base_model: Model, applied_mutators):
        # one-shot strategy doesn't use ``applied_mutators``
@@ -64,18 +65,15 @@ def run(self, base_model: Model, applied_mutators):
            raise TypeError('Evaluator needs to be a lightning evaluator to make one-shot strategy work.')
        evaluator_module: LightningModule = base_model.evaluator.module
+        evaluator_module.running_mode = 'oneshot'
        evaluator_module.set_model(py_model)
        self.model = self.oneshot_module(evaluator_module, **self.oneshot_kwargs)
 
        evaluator: Lightning = base_model.evaluator
-        if evaluator.train_dataloader is None or evaluator.val_dataloaders is None:
-            raise TypeError('Train or val dataloader is not set.')
-        dataloader = self._get_dataloader(evaluator.train_dataloader, evaluator.val_dataloaders)
-        if isinstance(dataloader, tuple):
-            dataloader, val_loader = dataloader
-            evaluator.trainer.fit(self.model, dataloader, val_loader)
-        else:
-            evaluator.trainer.fit(self.model, dataloader)
+        if evaluator.train_dataloaders is None or evaluator.val_dataloaders is None:
+            raise TypeError('Training and validation dataloaders must both be set in the evaluator for one-shot strategy.')
+        train_loader, val_loader = self.preprocess_dataloader(evaluator.train_dataloaders, evaluator.val_dataloaders)
+        evaluator.trainer.fit(self.model, train_loader, val_loader)
 
    def export_top_models(self, top_k: int = 1) -> list[Any]:
        if self.model is None:
@@ -91,8 +89,12 @@ class DARTS(OneShotStrategy):
    def __init__(self, **kwargs):
        super().__init__(DartsLightningModule, **kwargs)
 
-    def _get_dataloader(self, train_dataloader, val_dataloaders):
-        return InterleavedTrainValDataLoader(train_dataloader, val_dataloaders)
+    def preprocess_dataloader(self, train_dataloaders, val_dataloaders):
+        # By returning a dict, we make a CombinedLoader (in Lightning)
+        return {
+            'train': train_dataloaders,
+            'val': val_dataloaders
+        }, None
 
 
 class Proxyless(OneShotStrategy):
@@ -101,8 +103,11 @@ class Proxyless(OneShotStrategy):
    def __init__(self, **kwargs):
        super().__init__(ProxylessLightningModule, **kwargs)
 
-    def _get_dataloader(self, train_dataloader, val_dataloaders):
-        return InterleavedTrainValDataLoader(train_dataloader, val_dataloaders)
+    def preprocess_dataloader(self, train_dataloaders, val_dataloaders):
+        return {
+            'train': train_dataloaders,
+            'val': val_dataloaders
+        }, None
 
 
 class GumbelDARTS(OneShotStrategy):
@@ -111,8 +116,11 @@ class GumbelDARTS(OneShotStrategy):
    def __init__(self, **kwargs):
        super().__init__(GumbelDartsLightningModule, **kwargs)
 
-    def _get_dataloader(self, train_dataloader, val_dataloaders):
-        return InterleavedTrainValDataLoader(train_dataloader, val_dataloaders)
+    def preprocess_dataloader(self, train_dataloaders, val_dataloaders):
+        return {
+            'train': train_dataloaders,
+            'val': val_dataloaders
+        }, None
 
 
 class ENAS(OneShotStrategy):
@@ -121,8 +129,13 @@ class ENAS(OneShotStrategy):
    def __init__(self, **kwargs):
        super().__init__(EnasLightningModule, **kwargs)
 
-    def _get_dataloader(self, train_dataloader, val_dataloaders):
-        return ConcatenateTrainValDataLoader(train_dataloader, val_dataloaders)
+    def preprocess_dataloader(self, train_dataloaders, val_dataloaders):
+        # Import locally to avoid import error on legacy PL version
+        from .dataloader import ConcatLoader
+        return ConcatLoader({
+            'train': train_dataloaders,
+            'val': val_dataloaders
+        }), None
 
 
 class RandomOneShot(OneShotStrategy):
@@ -130,6 +143,3 @@ class RandomOneShot(OneShotStrategy):
    def __init__(self, **kwargs):
        super().__init__(RandomSamplingLightningModule, **kwargs)
-
-    def _get_dataloader(self, train_dataloader, val_dataloaders):
-        return train_dataloader, val_dataloaders
diff --git a/nni/retiarii/oneshot/pytorch/utils.py b/nni/retiarii/oneshot/pytorch/utils.py
index ee5caab4bc..9a56070267 100644
--- a/nni/retiarii/oneshot/pytorch/utils.py
+++ b/nni/retiarii/oneshot/pytorch/utils.py
@@ -132,6 +132,7 @@ def summary(self):
 
 def _replace_module_with_type(root_module, init_fn, type_name, modules):
    if modules is None:
        modules = []
+
    def apply(m):
        for name, child in m.named_children():
            if isinstance(child, type_name):
diff --git a/nni/retiarii/strategy/tpe_strategy.py b/nni/retiarii/strategy/tpe_strategy.py
index e77b87afe9..68d7841a8f 100644
--- a/nni/retiarii/strategy/tpe_strategy.py
+++ b/nni/retiarii/strategy/tpe_strategy.py
@@ -5,8 +5,6 @@
 import time
 from typing import Optional
 
-from nni.algorithms.hpo.hyperopt_tuner import HyperoptTuner
-
 from .. import Sampler, submit_models, query_available_resources, is_stopped_exec, budget_exhausted
 from .base import BaseStrategy
 
@@ -15,6 +13,9 @@ class TPESampler(Sampler):
    def __init__(self, optimize_mode='minimize'):
+        # Move import here to eliminate some warning messages about dill.
+        from nni.algorithms.hpo.hyperopt_tuner import HyperoptTuner
+
        self.tpe_tuner = HyperoptTuner('tpe', optimize_mode)
        self.cur_sample: Optional[dict] = None
        self.index: Optional[int] = None
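Before the tests, a short sketch of the contract the ENAS ``training_step`` above relies on: ``ConcatLoader`` yields ``(batch, loader_name)`` pairs, exhausting the ``'train'`` loader first and then ``'val'`` (loader sizes below are arbitrary)::

    from torch.utils.data import DataLoader
    from nni.retiarii.oneshot.pytorch.dataloader import ConcatLoader

    loader = ConcatLoader({
        'train': DataLoader(range(10), batch_size=4),
        'val': DataLoader(range(20), batch_size=5),
    })

    for batch, mode in loader:          # 3 'train' batches, then 4 'val' batches
        if mode == 'train':
            pass                        # weight-optimization step
        else:
            pass                        # controller / architecture step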
diff --git a/test/ut/retiarii/test_oneshot.py b/test/ut/retiarii/test_oneshot.py
index 98c6fd6976..646435a714 100644
--- a/test/ut/retiarii/test_oneshot.py
+++ b/test/ut/retiarii/test_oneshot.py
@@ -15,6 +15,9 @@
 from nni.retiarii.strategy import BaseStrategy
 
 
+pytestmark = pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
+
+
 class DepthwiseSeparableConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
@@ -171,7 +174,7 @@ def forward(self, x):
        return F.log_softmax(x, dim=1)
 
 
-def _mnist_net(type_):
+def _mnist_net(type_, evaluator_kwargs):
    if type_ == 'simple':
        base_model = SimpleNet(False)
    elif type_ == 'simple_value_choice':
@@ -187,17 +190,18 @@
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_dataset = MNIST('data/mnist', train=True, download=True, transform=transform)
+    # The multi-GPU combined dataloader will break this subset sampler. This is expected.
    train_random_sampler = RandomSampler(train_dataset, True, int(len(train_dataset) / 20))
    train_loader = DataLoader(train_dataset, 64, sampler=train_random_sampler)
    valid_dataset = MNIST('data/mnist', train=False, download=True, transform=transform)
    valid_random_sampler = RandomSampler(valid_dataset, True, int(len(valid_dataset) / 20))
    valid_loader = DataLoader(valid_dataset, 64, sampler=valid_random_sampler)
-    evaluator = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, max_epochs=1)
+    evaluator = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, **evaluator_kwargs)
 
    return base_model, evaluator
 
 
-def _multihead_attention_net():
+def _multihead_attention_net(evaluator_kwargs):
    base_model = MultiHeadAttentionNet(1)
 
    class AttentionRandDataset(Dataset):
@@ -222,19 +226,29 @@ def __len__(self):
    train_loader = DataLoader(train_set, batch_size=32)
    val_loader = DataLoader(val_set, batch_size=32)
 
-    evaluator = Regression(train_dataloader=train_loader, val_dataloaders=val_loader, max_epochs=1)
+    evaluator = Regression(train_dataloader=train_loader, val_dataloaders=val_loader, **evaluator_kwargs)
 
    return base_model, evaluator
 
 
-def _test_strategy(strategy_, support_value_choice=True):
+def _test_strategy(strategy_, support_value_choice=True, multi_gpu=False):
+    evaluator_kwargs = {
+        'max_epochs': 1
+    }
+    if multi_gpu:
+        evaluator_kwargs.update(
+            strategy='ddp',
+            accelerator='gpu',
+            devices=torch.cuda.device_count()
+        )
+
    to_test = [
        # (model, evaluator), support_or_net
-        (_mnist_net('simple'), True),
-        (_mnist_net('simple_value_choice'), support_value_choice),
-        (_mnist_net('value_choice'), support_value_choice),
-        (_mnist_net('repeat'), False),  # no strategy supports repeat currently
-        (_mnist_net('custom_op'), False),  # this is definitely a NO
-        (_multihead_attention_net(), support_value_choice),
+        (_mnist_net('simple', evaluator_kwargs), True),
+        (_mnist_net('simple_value_choice', evaluator_kwargs), support_value_choice),
+        (_mnist_net('value_choice', evaluator_kwargs), support_value_choice),
+        (_mnist_net('repeat', evaluator_kwargs), False),  # no strategy supports repeat currently
+        (_mnist_net('custom_op', evaluator_kwargs), False),  # this is definitely a NO
+        (_multihead_attention_net(evaluator_kwargs), support_value_choice),
    ]
 
    for (base_model, evaluator), support_or_not in to_test:
@@ -256,17 +270,19 @@ def _test_strategy(strategy_, support_value_choice=True):
    experiment.run(config)
 
 
-@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
 def test_darts():
    _test_strategy(strategy.DARTS())
 
 
-@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
+@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() <= 1, reason='Must have multiple GPUs.')
+def test_darts_multi_gpu():
+    _test_strategy(strategy.DARTS(), multi_gpu=True)
+
+
 def test_proxyless():
    _test_strategy(strategy.Proxyless(), False)
 
 
-@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
 def test_enas():
    def strategy_fn(base_model, evaluator):
        if isinstance(base_model, MultiHeadAttentionNet):
            return strategy.ENAS(reward_metric_name='val_mse')
@@ -276,12 +292,20 @@ def test_enas():
        return strategy.ENAS(reward_metric_name='val_acc')
 
    _test_strategy(strategy_fn)
 
 
-@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
+@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() <= 1, reason='Must have multiple GPUs.')
+def test_enas_multi_gpu():
+    def strategy_fn(base_model, evaluator):
+        if isinstance(base_model, MultiHeadAttentionNet):
+            return strategy.ENAS(reward_metric_name='val_mse')
+        return strategy.ENAS(reward_metric_name='val_acc')
+
+    _test_strategy(strategy_fn, multi_gpu=True)
+
+
 def test_random():
    _test_strategy(strategy.RandomOneShot())
 
 
-@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
 def test_gumbel_darts():
    _test_strategy(strategy.GumbelDARTS())
diff --git a/test/ut/retiarii/test_oneshot_utils.py b/test/ut/retiarii/test_oneshot_utils.py
new file mode 100644
index 0000000000..0c919e5e10
--- /dev/null
+++ b/test/ut/retiarii/test_oneshot_utils.py
@@ -0,0 +1,131 @@
+import math
+from typing import Union
+
+import pytest
+import torch
+import pytorch_lightning
+from pytorch_lightning import LightningModule, Trainer
+from torch.utils.data import DataLoader, Dataset
+
+pytestmark = pytest.mark.skipif(pytorch_lightning.__version__ < '1.0', reason='Incompatible APIs')
+
+
+class RandomDataset(Dataset):
+    def __init__(self, size, length):
+        self.len = length
+        self.data = torch.randn(length, size)
+
+    def __getitem__(self, index):
+        return self.data[index]
+
+    def __len__(self):
+        return self.len
+
+
+class BoringModel(LightningModule):
+    def __init__(self):
+        super().__init__()
+        self.layer = torch.nn.Linear(32, 2)
+
+    def forward(self, x):
+        return self.layer(x)
+
+    def training_step(self, batch, batch_idx):
+        loss = self(batch).sum()
+        self.log('train_loss', loss)
+        return {'loss': loss}
+
+    def validation_step(self, batch, batch_idx):
+        loss = self(batch).sum()
+        self.log('valid_loss', loss)
+
+    def test_step(self, batch, batch_idx):
+        loss = self(batch).sum()
+        self.log('test_loss', loss)
+
+    def configure_optimizers(self):
+        return torch.optim.SGD(self.layer.parameters(), lr=0.1)
+
+
+def test_concat_loader():
+    from nni.retiarii.oneshot.pytorch.dataloader import ConcatLoader
+
+    loaders = {
+        'a': DataLoader(range(10), batch_size=4),
+        'b': DataLoader(range(20), batch_size=5),
+    }
+    dataloader = ConcatLoader(loaders)
+    assert len(dataloader) == 7
+    for i, (data, label) in enumerate(dataloader):
+        if i < 3:
+            assert len(data) <= 4
+            assert label == 'a'
+        else:
+            assert len(data) <= 5
+            assert label == 'b'
+
+
+def test_concat_loader_nested():
+    from nni.retiarii.oneshot.pytorch.dataloader import ConcatLoader
+
+    loaders = {
+        'a': [DataLoader(range(10), batch_size=4), DataLoader(range(20), batch_size=6)],
+        'b': DataLoader(range(20), batch_size=5),
+    }
+    dataloader = ConcatLoader(loaders)
+    assert len(dataloader) == 7
+    for i, (data, label) in enumerate(dataloader):
+        if i < 3:
+            assert isinstance(data, list) and len(data) == 2
+            assert label == 'a'
+        else:
+            assert label == 'b'
+
+
+@pytest.mark.parametrize('replace_sampler_ddp', [False, True])
+@pytest.mark.parametrize('is_min_size_mode', [True])
+@pytest.mark.parametrize('num_devices', ['auto', 1, 3, 10])
+def test_concat_loader_with_ddp(
+    replace_sampler_ddp: bool, is_min_size_mode: bool, num_devices: Union[int, str]
+):
+    """Inspired by tests/trainer/test_supporters.py in lightning."""
+    from nni.retiarii.oneshot.pytorch.dataloader import ConcatLoader
+
+    mode = 'min_size' if is_min_size_mode else 'max_size_cycle'
+    dim = 3
+    n1 = 8
+    n2 = 6
+    n3 = 9
+    dataloader = ConcatLoader({
+        'a': {
+            'a1': DataLoader(RandomDataset(dim, n1), batch_size=1),
+            'a2': DataLoader(RandomDataset(dim, n2), batch_size=1),
+        },
+        'b': DataLoader(RandomDataset(dim, n3), batch_size=1),
+    }, mode=mode)
+    expected_length_before_ddp = n3 + (min(n1, n2) if is_min_size_mode else max(n1, n2))
+    print(len(dataloader))
+    assert len(dataloader) == expected_length_before_ddp
+    model = BoringModel()
+    trainer = Trainer(
+        strategy='ddp',
+        accelerator='auto',
+        devices=num_devices,
+        replace_sampler_ddp=replace_sampler_ddp,
+    )
+    trainer._data_connector.attach_data(
+        model=model, train_dataloaders=dataloader, val_dataloaders=None, datamodule=None
+    )
+    expected_length_after_ddp = (
+        math.ceil(n3 / trainer.num_devices) + \
+        math.ceil((min(n1, n2) if is_min_size_mode else max(n1, n2)) / trainer.num_devices)
+        if replace_sampler_ddp
+        else expected_length_before_ddp
+    )
+    print('Num devices =', trainer.num_devices)
+    trainer.reset_train_dataloader(model=model)
+    assert trainer.train_dataloader is not None
+    assert trainer.train_dataloader.mode == mode
+
+    assert trainer.num_training_batches == expected_length_after_ddp
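A worked instance of the length bookkeeping asserted above (``min_size`` mode, ``n1=8``, ``n2=6``, ``n3=9``, ``batch_size=1``, three devices)::

    import math

    expected_before = 9 + min(8, 6)                                # 15 batches in total
    expected_after = math.ceil(9 / 3) + math.ceil(min(8, 6) / 3)   # 3 + 2 = 5 with replace_sampler_ddp
    assert expected_before == 15 and expected_after == 5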