This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Multi-GPU support of one-shot NAS #4603

Merged: 99 commits, May 20, 2022

Commits
156fe72
init
v-fangdong Jan 11, 2022
6ec5f85
init
v-fangdong Jan 11, 2022
0fd40e8
differentiable and sampling-based ontshot algs
v-fangdong Jan 13, 2022
32ac1ab
fix import mistake
v-fangdong Jan 13, 2022
8091169
proxyless inherits darts
v-fangdong Jan 17, 2022
f29c39a
test files
v-fangdong Jan 17, 2022
d60f87e
add more comments and fix test files
v-fangdong Jan 18, 2022
09a065f
revert unrelated changes
v-fangdong Jan 20, 2022
a2e2fa8
unify comments style and remove debug codes
v-fangdong Jan 20, 2022
2289c19
remove unnecessary methods
v-fangdong Jan 20, 2022
9193af6
fix link
v-fangdong Jan 21, 2022
50ae803
unify code style
v-fangdong Jan 21, 2022
559518c
SNAS
v-fangdong Jan 21, 2022
1744e74
fix pylint, configure optimizers and training step loss
v-fangdong Jan 23, 2022
e4d5feb
fix pylint
v-fangdong Jan 24, 2022
61150e5
disable pylint unsubscriptable-opject warning
v-fangdong Jan 24, 2022
95038b3
use lib
v-fangdong Jan 25, 2022
6159639
use metrics
v-fangdong Jan 26, 2022
48081c9
remove unused
v-fangdong Jan 26, 2022
badcc83
fix bugs
v-fangdong Jan 26, 2022
a7fb932
fix lint
v-fangdong Jan 27, 2022
0a60dee
solve lr_scheduler
v-fangdong Jan 27, 2022
e1d2995
fix pylint
v-fangdong Jan 28, 2022
f763aaf
fix pylint
v-fangdong Jan 28, 2022
de23dfa
remove validation_step
v-fangdong Jan 28, 2022
1229bc6
fix pylint
v-fangdong Jan 28, 2022
9bebecf
fix bug
v-fangdong Jan 28, 2022
09da5c3
fix bug
v-fangdong Jan 28, 2022
3e71de6
fix bug
v-fangdong Jan 28, 2022
2ac3d3d
fix bug
v-fangdong Jan 28, 2022
54d9c1b
fix legacy bugs
v-fangdong Jan 28, 2022
3cb92d4
fix legacy bugs
v-fangdong Jan 28, 2022
0f7ad72
fix legacy bugs
v-fangdong Jan 29, 2022
1c01442
fix legacy bugs
v-fangdong Jan 29, 2022
32acdc5
fix bugs
v-fangdong Jan 30, 2022
2d99c6e
annealing for snas
v-fangdong Jan 30, 2022
4a1eba1
rerun ut
v-fangdong Feb 2, 2022
c1dce81
fix SNAS loss
v-fangdong Feb 7, 2022
b3ede07
remove too long lines
v-fangdong Feb 7, 2022
f2d56a5
fix snas temp
v-fangdong Feb 7, 2022
4a038b2
unify comments style, abstract optimizer-related behaviours into base…
v-fangdong Feb 8, 2022
284fe44
fix pylint
v-fangdong Feb 8, 2022
6068064
add more comments
v-fangdong Feb 10, 2022
05429ed
add input choice test case
v-fangdong Feb 10, 2022
ef84ae0
add inputchoice testcase
v-fangdong Feb 10, 2022
1c5839a
fix pylint
v-fangdong Feb 10, 2022
ee1120c
fix pylin
v-fangdong Feb 11, 2022
80845ff
fix typos
v-fangdong Feb 14, 2022
af3ae61
resolve conversations
v-fangdong Feb 14, 2022
0e4b0e5
add comments
v-fangdong Feb 15, 2022
897b2ec
update
ultmaster Feb 15, 2022
5e794ff
documents
ultmaster Feb 15, 2022
dfc82d6
Delete irrelevant files
ultmaster Feb 15, 2022
2a99859
Merge remote-tracking branch 'ultmaster/fix-4434' into cherrypick-4434
v-fangdong Feb 15, 2022
35dbf11
correct typos
v-fangdong Feb 15, 2022
0851ec8
solve dict issue
ultmaster Feb 15, 2022
c21ce0f
update testing
ultmaster Feb 15, 2022
3bee06f
update serializer on windows (attempt)
ultmaster Feb 16, 2022
1433559
Merge branch 'fix-4434' of https://github.com/ultmaster/nni into fix-…
ultmaster Feb 16, 2022
a6a8061
update test
ultmaster Feb 16, 2022
6db030a
Avoid abc
ultmaster Feb 16, 2022
213e5b8
fix is_traceable
ultmaster Feb 16, 2022
f590c93
update is_traceable
ultmaster Feb 16, 2022
df3e9a4
fix all
ultmaster Feb 16, 2022
e328f05
update
ultmaster Feb 16, 2022
577e771
fix lightning problem on linux
ultmaster Feb 17, 2022
278b5c9
fix super
ultmaster Feb 17, 2022
7c51470
fix lightning issue
ultmaster Feb 17, 2022
0bfe695
!= linux
ultmaster Feb 17, 2022
d49f6e8
Merge remote-tracking branch 'ultmaster/fix-4434' into cherrypick-4434
v-fangdong Feb 17, 2022
150279a
add handler for trainer
ultmaster Feb 17, 2022
ec9d7fa
Merge remote-tracking branch 'ultmaster/fix-4434' into cherrypick-4434
v-fangdong Feb 17, 2022
abfd0a8
multicard
v-fangdong Feb 17, 2022
124232c
Merge branch 'lightning' into cherrypick-4434
v-fangdong Feb 17, 2022
f279ab0
pseudo dataset
v-fangdong Feb 17, 2022
84b01bd
fix pickle (again)
ultmaster Feb 17, 2022
d2ecda8
Merge remote-tracking branch 'ultmaster/fix-4434' into cherrypick-4434
v-fangdong Feb 17, 2022
146ddc4
enas gpu
v-fangdong Feb 17, 2022
3e6178d
reinit dataloaders
v-fangdong Feb 23, 2022
0603606
always shuffle for trianing
v-fangdong Feb 24, 2022
7a45b8f
Merge branch 'lightning' of github.com:Frandium/nni into multicard-on…
ultmaster Mar 3, 2022
af2124c
Merge branch 'master' of https://github.com/microsoft/nni into multic…
ultmaster Mar 3, 2022
8fc6fff
Merge branch 'master' of github.com:microsoft/nni into multicard-oneshot
ultmaster May 9, 2022
9d9beaf
revert
ultmaster May 9, 2022
77a8c8b
revert
ultmaster May 9, 2022
001b1b4
.
ultmaster May 10, 2022
a4e245b
checkpoint
ultmaster May 10, 2022
6a1974b
finish up
ultmaster May 10, 2022
f31759c
.
ultmaster May 10, 2022
6966aca
add tests
ultmaster May 10, 2022
78374b1
revert
ultmaster May 10, 2022
f191ce6
revert
ultmaster May 10, 2022
4d3f187
revert
ultmaster May 10, 2022
1c66ff0
compat fix
ultmaster May 10, 2022
74edc65
fix tests
ultmaster May 10, 2022
5465417
fix
ultmaster May 11, 2022
7dec65c
Merge branch 'master' of github.com:microsoft/nni into multicard-oneshot
ultmaster May 11, 2022
cad1b2c
fix comments
ultmaster May 19, 2022
b53f98b
.
ultmaster May 19, 2022
84 changes: 60 additions & 24 deletions nni/retiarii/evaluator/pytorch/lightning.py
@@ -4,7 +4,7 @@
import os
import warnings
from pathlib import Path
from typing import Dict, Union, Optional, List, Callable, Type
from typing import Any, Dict, Union, Optional, List, Callable, Type

import pytorch_lightning as pl
import torch.nn as nn
@@ -22,6 +22,7 @@
cgo_import_failed = True

from nni.retiarii.graph import Evaluator
from nni.typehint import Literal


__all__ = ['LightningModule', 'Trainer', 'DataLoader', 'Lightning', 'Classification', 'Regression']
@@ -36,6 +37,11 @@ class LightningModule(pl.LightningModule):
See https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html
"""

running_mode: Literal['multi', 'oneshot'] = 'multi'
Review comment (Contributor): is it possible that a lightning class can be used both for multi and oneshot?

Reply (@matluster, May 19, 2022): When I used this flag, I meant to check it before two actions:

  1. Whether to save an ONNX graph.
  2. Whether to report intermediate / final results.

But on second thought, maybe the evaluator could figure out by itself whether its inner module is a one-shot supernet, so this flag might not be needed.

Follow-up: No, we can't. Revert.

"""An indicator of whether current module is running in a multi-trial experiment or an one-shot.
This flag should be automatically set by experiments when they start to run.
"""

def set_model(self, model: Union[Callable[[], nn.Module], nn.Module]) -> None:
"""Set the inner model (architecture) to train / evaluate.

@@ -59,6 +65,7 @@ def set_model(self, model: Union[Callable[[], nn.Module], nn.Module]) -> None:
Traced version of ``torch.utils.data.DataLoader``. See https://pytorch.org/docs/stable/data.html
"""


@nni.trace
class Lightning(Evaluator):
"""
@@ -74,51 +81,67 @@ class Lightning(Evaluator):

Parameters
----------
lightning_module : LightningModule
lightning_module
Review comment (Contributor): can docs auto-add the type hints for these parameters?

Reply (Contributor): yes
Lightning module that defines the training logic.
trainer : Trainer
trainer
Lightning trainer that handles the training.
train_dataloders : DataLoader
train_dataloders
Used in ``trainer.fit()``. A PyTorch DataLoader with training samples.
If the ``lightning_module`` has a predefined train_dataloader method this will be skipped.
val_dataloaders : DataLoader or List of DataLoader
It can be `any types of dataloader supported by Lightning <https://pytorch-lightning.readthedocs.io/en/stable/guides/data.html>`__.
val_dataloaders
Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
It can be `any types of dataloader supported by Lightning <https://pytorch-lightning.readthedocs.io/en/stable/guides/data.html>`__.
"""

def __init__(self, lightning_module: LightningModule, trainer: Trainer,
train_dataloader: Optional[DataLoader] = None,
val_dataloaders: Union[DataLoader, List[DataLoader], None] = None):
train_dataloaders: Optional[Any] = None,
val_dataloaders: Optional[Any] = None,
train_dataloader: Optional[Any] = None):
Review comment (Contributor): why are there both train_dataloaders and train_dataloader?

Follow-up (same reviewer): I see, backward compatibility.
assert isinstance(lightning_module, LightningModule), f'Lightning module must be an instance of {__name__}.LightningModule.'
if train_dataloader is not None:
warnings.warn('`train_dataloader` is deprecated and replaced with `train_dataloaders`.', DeprecationWarning)
train_dataloaders = train_dataloader
if cgo_import_failed:
assert isinstance(trainer, pl.Trainer) and is_traceable(trainer), f'Trainer must be imported from {__name__}'
else:
# this is not isinstance(trainer, Trainer) because with a different trace call, it can be different
assert (isinstance(trainer, pl.Trainer) and is_traceable(trainer)) or isinstance(trainer, cgo_trainer.Trainer), \
f'Trainer must be imported from {__name__} or nni.retiarii.evaluator.pytorch.cgo.trainer'
assert _check_dataloader(train_dataloader), f'Wrong dataloader type. Try import DataLoader from {__name__}.'
assert _check_dataloader(val_dataloaders), f'Wrong dataloader type. Try import DataLoader from {__name__}.'
if not _check_dataloader(train_dataloaders):
warnings.warn(f'Please try to wrap PyTorch DataLoader with nni.trace or '
f'import DataLoader from {__name__}: {train_dataloaders}',
RuntimeWarning)
if not _check_dataloader(val_dataloaders):
warnings.warn(f'Please try to wrap PyTorch DataLoader with nni.trace or '
f'import DataLoader from {__name__}: {val_dataloaders}',
RuntimeWarning)
self.module = lightning_module
self.trainer = trainer
self.train_dataloader = train_dataloader
self.train_dataloaders = train_dataloaders
self.val_dataloaders = val_dataloaders

@staticmethod
def _load(ir):
return Lightning(ir['module'], ir['trainer'], ir['train_dataloader'], ir['val_dataloaders'])
return Lightning(ir['module'], ir['trainer'], ir['train_dataloaders'], ir['val_dataloaders'])

def _dump(self):
return {
'type': self.__class__,
'module': self.module,
'trainer': self.trainer,
'train_dataloader': self.train_dataloader,
'train_dataloaders': self.train_dataloaders,
'val_dataloaders': self.val_dataloaders
}

def _execute(self, model_cls):
return self.fit(model_cls)

@property
def train_dataloader(self):
warnings.warn('train_dataloader is deprecated, please use `train_dataloaders`.', DeprecationWarning)

def __eq__(self, other):
eq_func = False
eq_args = False
@@ -146,15 +169,18 @@ def fit(self, model):
The model to fit.
"""
self.module.set_model(model)
return self.trainer.fit(self.module, self.train_dataloader, self.val_dataloaders)
return self.trainer.fit(self.module, self.train_dataloaders, self.val_dataloaders)


def _check_dataloader(dataloader):
if dataloader is None:
return True
# Check the type of dataloader recursively.
if isinstance(dataloader, list):
return all([_check_dataloader(d) for d in dataloader])
return isinstance(dataloader, torch_data.DataLoader) and is_traceable(dataloader)
if isinstance(dataloader, dict):
return all([_check_dataloader(v) for v in dataloader.values()])
if isinstance(dataloader, torch_data.DataLoader):
return is_traceable(dataloader)
return True
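The warnings above ask users to wrap plain PyTorch DataLoaders with nni.trace (or import the traced DataLoader re-exported by this module) so that _check_dataloader passes. A small usage sketch, assuming nni and torch are installed; the toy dataset is illustrative.

```python
import nni
import torch
from torch.utils.data import TensorDataset

dataset = TensorDataset(torch.randn(8, 2), torch.randint(0, 2, (8,)))

# Option 1 (assumed import path): use the traced DataLoader re-exported by NNI.
# from nni.retiarii.evaluator.pytorch import DataLoader
# loader = DataLoader(dataset, batch_size=4)

# Option 2: wrap the stock PyTorch DataLoader with nni.trace yourself.
loader = nni.trace(torch.utils.data.DataLoader)(dataset, batch_size=4)
```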


### The following are some commonly used Lightning modules ###
@@ -176,7 +202,6 @@ def __init__(self, criterion: Type[nn.Module], metrics: Dict[str, Type[torchmetr

if export_onnx is None or export_onnx is True:
self.export_onnx = Path(os.environ.get('NNI_OUTPUT_DIR', '.')) / 'model.onnx'
self.export_onnx.parent.mkdir(exist_ok=True)
elif export_onnx:
self.export_onnx = Path(export_onnx)
else:
@@ -199,7 +224,8 @@ def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)

if self.export_onnx is not None:
if self.running_mode == 'multi' and self.export_onnx is not None:
self.export_onnx.parent.mkdir(exist_ok=True)
try:
self.to_onnx(self.export_onnx, x, export_params=True)
except RuntimeError as e:
@@ -221,10 +247,12 @@ def configure_optimizers(self):
return self.optimizer(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay) # type: ignore

def on_validation_epoch_end(self):
nni.report_intermediate_result(self._get_validation_metrics())
if self.running_mode == 'multi':
nni.report_intermediate_result(self._get_validation_metrics())

def on_fit_end(self):
nni.report_final_result(self._get_validation_metrics())
if self.running_mode == 'multi':
nni.report_final_result(self._get_validation_metrics())

def _get_validation_metrics(self):
if len(self.metrics) == 1:
@@ -283,14 +311,18 @@ def __init__(self, criterion: Type[nn.Module] = nn.CrossEntropyLoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: Type[optim.Optimizer] = optim.Adam,
train_dataloader: Optional[DataLoader] = None,
train_dataloaders: Optional[DataLoader] = None,
val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
export_onnx: bool = True,
train_dataloader: Optional[DataLoader] = None,
**trainer_kwargs):
if train_dataloader is not None:
warnings.warn('`train_dataloader` is deprecated and replaced with `train_dataloaders`.', DeprecationWarning)
train_dataloaders = train_dataloader
module = _ClassificationModule(criterion=criterion, learning_rate=learning_rate,
weight_decay=weight_decay, optimizer=optimizer, export_onnx=export_onnx)
super().__init__(module, Trainer(**trainer_kwargs),
train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders)


@nni.trace
Expand Down Expand Up @@ -336,11 +368,15 @@ def __init__(self, criterion: Type[nn.Module] = nn.MSELoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: Type[optim.Optimizer] = optim.Adam,
train_dataloader: Optional[DataLoader] = None,
train_dataloaders: Optional[DataLoader] = None,
val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
export_onnx: bool = True,
train_dataloader: Optional[DataLoader] = None,
**trainer_kwargs):
if train_dataloader is not None:
warnings.warn('`train_dataloader` is deprecated and replaced with `train_dataloaders`.', DeprecationWarning)
train_dataloaders = train_dataloader
module = _RegressionModule(criterion=criterion, learning_rate=learning_rate,
weight_decay=weight_decay, optimizer=optimizer, export_onnx=export_onnx)
super().__init__(module, Trainer(**trainer_kwargs),
train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders)
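For reference, a hedged construction example for the updated evaluators (Classification shown; Regression is analogous): the plural train_dataloaders keyword is now preferred, and extra keyword arguments are forwarded to the Lightning Trainer. The toy tensors and the export_onnx/max_epochs settings below are illustrative.

```python
import torch
from torch.utils.data import TensorDataset
from nni.retiarii.evaluator.pytorch import Classification, DataLoader

train_set = TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))
val_set = TensorDataset(torch.randn(16, 10), torch.randint(0, 2, (16,)))

evaluator = Classification(
    train_dataloaders=DataLoader(train_set, batch_size=16),
    val_dataloaders=DataLoader(val_set, batch_size=16),
    export_onnx=False,
    max_epochs=1,   # forwarded to the Lightning Trainer via **trainer_kwargs
)
```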
18 changes: 10 additions & 8 deletions nni/retiarii/oneshot/pytorch/base_lightning.py
@@ -18,6 +18,7 @@
from nni.common.hpo_utils import ParameterSpec
from nni.common.serializer import is_traceable
from nni.retiarii.nn.pytorch.api import ValueChoiceX
from nni.typehint import Literal
from .supermodule.base import BaseSuperNetModule

__all__ = ['MutationHook', 'BaseSuperNetModule', 'BaseOneShotLightningModule', 'traverse_and_mutate_submodules']
@@ -334,28 +335,29 @@ def configure_optimizers(self):
return arc_optimizers + w_optimizers, lr_schedulers

def on_train_start(self):
# redirect the access to trainer/log to this module
# but note that we might be missing other attributes,
# which could potentially be a problem
self.model.trainer = self.trainer # type: ignore
self.model.log = self.log
return self.model.on_train_start()

def on_train_end(self):
return self.model.on_train_end()

def on_fit_start(self):
return self.model.on_train_start()
# redirect the access to trainer/log to this module
# but note that we might be missing other attributes,
# which could potentially be a problem
self.model.trainer = self.trainer # type: ignore
self.model.log = self.log
return self.model.on_fit_start()

def on_fit_end(self):
return self.model.on_train_end()
return self.model.on_fit_end()

def on_train_batch_start(self, batch, batch_idx, unused=0):
return self.model.on_train_batch_start(batch, batch_idx, unused)

def on_train_batch_end(self, outputs, batch, batch_idx, unused=0):
return self.model.on_train_batch_end(outputs, batch, batch_idx, unused)

# Deprecated hooks in pytorch-lightning
def on_epoch_start(self):
return self.model.on_epoch_start()

Expand Down Expand Up @@ -427,7 +429,7 @@ def apply(lr_scheduler):
else:
apply(lr_schedulers)

def call_weight_optimizers(self, method):
def call_weight_optimizers(self, method: Literal['step', 'zero_grad']):
"""
Function that imitates lightning trainer's behavior of calling user's optimizers. Since auto_optimization is turned off by this
class, you can use this function to make user optimizers behave as they were automatically handled by the lightning trainer.
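To make the contract of call_weight_optimizers concrete, here is a self-contained sketch of the dispatch it performs: applying a named method ('step' or 'zero_grad') to every user weight optimizer. The optimizer list below is illustrative; the real implementation operates on the optimizers returned by the wrapped module.

```python
from typing import Literal
import torch

params = [torch.nn.Parameter(torch.randn(2, 2))]
weight_optimizers = [torch.optim.SGD(params, lr=0.1), torch.optim.Adam(params, lr=1e-3)]

def call_weight_optimizers(method: Literal['step', 'zero_grad']) -> None:
    for opt in weight_optimizers:
        getattr(opt, method)()   # opt.step() or opt.zero_grad()

call_weight_optimizers('zero_grad')
params[0].grad = torch.ones_like(params[0])  # pretend backward() has produced gradients
call_weight_optimizers('step')
```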
77 changes: 77 additions & 0 deletions nni/retiarii/oneshot/pytorch/dataloader.py
@@ -0,0 +1,77 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

from typing import Any

from pytorch_lightning.trainer.supporters import CombinedLoader, CombinedLoaderIterator


class ConcatLoader(CombinedLoader):
"""This loader is same as CombinedLoader in PyTorch-Lightning, but concatenate sub-loaders
instead of loading them in parallel.

Parameters
----------
loaders
For example, ::

{
"train": DataLoader(train_dataset),
"val": DataLoader(val_dataset)
}

In this example, the loader will first produce the batches from "train", then "val".

mode
Only support "min_size" for now.
"""

def __init__(self, loaders: dict[str, Any], mode: str = 'min_size'):
# FIXME: max_cycle will make dataloaders cycle iterators,
# causing extra problems.
if mode != 'min_size':
raise ValueError('Only min_size mode is supported now.')
super().__init__(loaders, mode)

def __iter__(self) -> Any:
"""Replace the super-class iterator with ours."""
self._try_to_patch_pytorch_dataloader()
iterator = ConcatLoaderIterator(self.loaders)
# handle fault tolerant restart.
self.on_restart(iterator)
self._iterator = iterator
return iterator

@staticmethod
def _try_to_patch_pytorch_dataloader():
"""Copied from CombinedLoader."""
from torch.utils.data.dataloader import _BaseDataLoaderIter

# prevent `NotImplementedError` from PyTorch:
# https://github.com/pytorch/pytorch/blob/v1.9.0/torch/utils/data/dataloader.py#L541
def __getstate__patch__(*_):
return {}

_BaseDataLoaderIter.__getstate__ = __getstate__patch__ # type: ignore

def __len__(self) -> int:
return int(sum(self._calc_num_batches(loader) for loader in self.loaders.values()))


class ConcatLoaderIterator(CombinedLoaderIterator):
"""Similar to CombinedLoaderIterator in Lightning, but in a concat manner."""

def __next__(self) -> Any:
"""Fetches the next batch from multiple data loaders,
by looking for the first iterator that isn't exhausted yet.
"""
if not len(self.loader_iters) == len(self.loaders):
raise RuntimeError('loader_iters must have the same length as loaders.')
for i, (loader_name, iterator) in enumerate(self.loader_iters.items()):
try:
return (self.request_next_batch(iterator), loader_name)
except StopIteration:
if i + 1 == len(self.loader_iters):
raise
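A hedged usage sketch for the new ConcatLoader, assuming it is imported from the module added above (nni.retiarii.oneshot.pytorch.dataloader): batches are drawn from the "train" loader until it is exhausted, then from the "val" loader, and each item comes back as a (batch, loader_name) pair. The toy datasets are illustrative.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from nni.retiarii.oneshot.pytorch.dataloader import ConcatLoader

loaders = {
    'train': DataLoader(TensorDataset(torch.arange(4.)), batch_size=2),
    'val': DataLoader(TensorDataset(torch.arange(2.)), batch_size=2),
}
loader = ConcatLoader(loaders)  # 'min_size' is the only supported mode
print(len(loader))              # 2 train batches + 1 val batch = 3
for batch, name in loader:
    print(name, batch)          # 'train', 'train', then 'val'
```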
5 changes: 3 additions & 2 deletions nni/retiarii/oneshot/pytorch/differentiable.py
@@ -75,8 +75,9 @@ def training_step(self, batch, batch_idx):
if not isinstance(arc_optim, optim.Optimizer):
raise TypeError(f'Expect arc_optim to be a single Optimizer, but found: {arc_optim}')

# The InterleavedTrainValDataLoader yields both train and val data in a batch
trn_batch, val_batch = batch
# DARTS strategy makes sure that ``train`` and ``val`` must be in the batch
trn_batch = batch['train']
val_batch = batch['val']

# phase 1: architecture step
# The _resample hook is kept for some darts-based NAS methods like proxyless.
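Finally, a small sketch of the batch layout the updated DARTS training_step consumes: a dict carrying both a training mini-batch and a validation mini-batch, from which the two phases (architecture step on 'val', weight step on 'train') are fed. The tensors below are illustrative placeholders.

```python
import torch

# Shape of the combined batch handed to training_step (illustrative tensors).
batch = {
    'train': (torch.randn(4, 3), torch.randint(0, 2, (4,))),
    'val': (torch.randn(4, 3), torch.randint(0, 2, (4,))),
}

trn_batch = batch['train']  # used for the weight-optimization phase
val_batch = batch['val']    # used for the architecture-optimization phase
```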