This repository has been archived by the owner on Sep 18, 2024. It is now read-only.
Multi-GPU support of one-shot NAS #4603
Merged

Changes from all commits (99 commits):
156fe72  init (v-fangdong)
6ec5f85  init (v-fangdong)
0fd40e8  differentiable and sampling-based ontshot algs (v-fangdong)
32ac1ab  fix import mistake (v-fangdong)
8091169  proxyless inherits darts (v-fangdong)
f29c39a  test files (v-fangdong)
d60f87e  add more comments and fix test files (v-fangdong)
09a065f  revert unrelated changes (v-fangdong)
a2e2fa8  unify comments style and remove debug codes (v-fangdong)
2289c19  remove unnecessary methods (v-fangdong)
9193af6  fix link (v-fangdong)
50ae803  unify code style (v-fangdong)
559518c  SNAS (v-fangdong)
1744e74  fix pylint, configure optimizers and training step loss (v-fangdong)
e4d5feb  fix pylint (v-fangdong)
61150e5  disable pylint unsubscriptable-opject warning (v-fangdong)
95038b3  use lib (v-fangdong)
6159639  use metrics (v-fangdong)
48081c9  remove unused (v-fangdong)
badcc83  fix bugs (v-fangdong)
a7fb932  fix lint (v-fangdong)
0a60dee  solve lr_scheduler (v-fangdong)
e1d2995  fix pylint (v-fangdong)
f763aaf  fix pylint (v-fangdong)
de23dfa  remove validation_step (v-fangdong)
1229bc6  fix pylint (v-fangdong)
9bebecf  fix bug (v-fangdong)
09da5c3  fix bug (v-fangdong)
3e71de6  fix bug (v-fangdong)
2ac3d3d  fix bug (v-fangdong)
54d9c1b  fix legacy bugs (v-fangdong)
3cb92d4  fix legacy bugs (v-fangdong)
0f7ad72  fix legacy bugs (v-fangdong)
1c01442  fix legacy bugs (v-fangdong)
32acdc5  fix bugs (v-fangdong)
2d99c6e  annealing for snas (v-fangdong)
4a1eba1  rerun ut (v-fangdong)
c1dce81  fix SNAS loss (v-fangdong)
b3ede07  remove too long lines (v-fangdong)
f2d56a5  fix snas temp (v-fangdong)
4a038b2  unify comments style, abstract optimizer-related behaviours into base… (v-fangdong)
284fe44  fix pylint (v-fangdong)
6068064  add more comments (v-fangdong)
05429ed  add input choice test case (v-fangdong)
ef84ae0  add inputchoice testcase (v-fangdong)
1c5839a  fix pylint (v-fangdong)
ee1120c  fix pylin (v-fangdong)
80845ff  fix typos (v-fangdong)
af3ae61  resolve conversations (v-fangdong)
0e4b0e5  add comments (v-fangdong)
897b2ec  update (ultmaster)
5e794ff  documents (ultmaster)
dfc82d6  Delete irrelevant files (ultmaster)
2a99859  Merge remote-tracking branch 'ultmaster/fix-4434' into cherrypick-4434 (v-fangdong)
35dbf11  correct typos (v-fangdong)
0851ec8  solve dict issue (ultmaster)
c21ce0f  update testing (ultmaster)
3bee06f  update serializer on windows (attempt) (ultmaster)
1433559  Merge branch 'fix-4434' of https://github.com/ultmaster/nni into fix-… (ultmaster)
a6a8061  update test (ultmaster)
6db030a  Avoid abc (ultmaster)
213e5b8  fix is_traceable (ultmaster)
f590c93  update is_traceable (ultmaster)
df3e9a4  fix all (ultmaster)
e328f05  update (ultmaster)
577e771  fix lightning problem on linux (ultmaster)
278b5c9  fix super (ultmaster)
7c51470  fix lightning issue (ultmaster)
0bfe695  != linux (ultmaster)
d49f6e8  Merge remote-tracking branch 'ultmaster/fix-4434' into cherrypick-4434 (v-fangdong)
150279a  add handler for trainer (ultmaster)
ec9d7fa  Merge remote-tracking branch 'ultmaster/fix-4434' into cherrypick-4434 (v-fangdong)
abfd0a8  multicard (v-fangdong)
124232c  Merge branch 'lightning' into cherrypick-4434 (v-fangdong)
f279ab0  pseudo dataset (v-fangdong)
84b01bd  fix pickle (again) (ultmaster)
d2ecda8  Merge remote-tracking branch 'ultmaster/fix-4434' into cherrypick-4434 (v-fangdong)
146ddc4  enas gpu (v-fangdong)
3e6178d  reinit dataloaders (v-fangdong)
0603606  always shuffle for trianing (v-fangdong)
7a45b8f  Merge branch 'lightning' of github.com:Frandium/nni into multicard-on… (ultmaster)
af2124c  Merge branch 'master' of https://github.com/microsoft/nni into multic… (ultmaster)
8fc6fff  Merge branch 'master' of github.com:microsoft/nni into multicard-oneshot (ultmaster)
9d9beaf  revert (ultmaster)
77a8c8b  revert (ultmaster)
001b1b4  . (ultmaster)
a4e245b  checkpoint (ultmaster)
6a1974b  finish up (ultmaster)
f31759c  . (ultmaster)
6966aca  add tests (ultmaster)
78374b1  revert (ultmaster)
f191ce6  revert (ultmaster)
4d3f187  revert (ultmaster)
1c66ff0  compat fix (ultmaster)
74edc65  fix tests (ultmaster)
5465417  fix (ultmaster)
7dec65c  Merge branch 'master' of github.com:microsoft/nni into multicard-oneshot (ultmaster)
cad1b2c  fix comments (ultmaster)
b53f98b  . (ultmaster)
File 1 (modified):
@@ -4,7 +4,7 @@
 import os
 import warnings
 from pathlib import Path
-from typing import Dict, Union, Optional, List, Callable, Type
+from typing import Any, Dict, Union, Optional, List, Callable, Type

 import pytorch_lightning as pl
 import torch.nn as nn

@@ -22,6 +22,7 @@
     cgo_import_failed = True

 from nni.retiarii.graph import Evaluator
+from nni.typehint import Literal

 __all__ = ['LightningModule', 'Trainer', 'DataLoader', 'Lightning', 'Classification', 'Regression']
@@ -36,6 +37,11 @@ class LightningModule(pl.LightningModule):
     See https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html
     """

+    running_mode: Literal['multi', 'oneshot'] = 'multi'
+    """An indicator of whether the current module is running in a multi-trial experiment or a one-shot one.
+    This flag should be automatically set by experiments when they start to run.
+    """
+
     def set_model(self, model: Union[Callable[[], nn.Module], nn.Module]) -> None:
         """Set the inner model (architecture) to train / evaluate.
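To make the gating concrete, here is a minimal, hypothetical sketch (the _DemoModule class is invented for illustration and is not part of this PR) of how the flag is meant to suppress per-trial reporting when a one-shot strategy drives training:

    import nni

    class _DemoModule:
        """Stand-in for the LightningModule above (illustration only)."""
        running_mode = 'multi'  # class-level default, as in the diff

        def on_fit_end(self):
            # Mirrors the guard added later in this diff: only multi-trial
            # runs report results back to the NNI manager.
            if self.running_mode == 'multi':
                nni.report_final_result(0.9)

    module = _DemoModule()
    module.running_mode = 'oneshot'  # an experiment is expected to set this automatically
    module.on_fit_end()              # nothing is reported in one-shot mode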
@@ -59,6 +65,7 @@ def set_model(self, model: Union[Callable[[], nn.Module], nn.Module]) -> None:
     Traced version of ``torch.utils.data.DataLoader``. See https://pytorch.org/docs/stable/data.html
     """


+@nni.trace
 class Lightning(Evaluator):
     """
@@ -74,51 +81,67 @@ class Lightning(Evaluator):

     Parameters
     ----------
-    lightning_module : LightningModule
+    lightning_module
         Lightning module that defines the training logic.
-    trainer : Trainer
+    trainer
         Lightning trainer that handles the training.
-    train_dataloaders : DataLoader
+    train_dataloaders
         Used in ``trainer.fit()``. A PyTorch DataLoader with training samples.
         If the ``lightning_module`` has a predefined train_dataloader method this will be skipped.
+        It can be `any types of dataloader supported by Lightning <https://pytorch-lightning.readthedocs.io/en/stable/guides/data.html>`__.
-    val_dataloaders : DataLoader or List of DataLoader
+    val_dataloaders
         Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
         If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
+        It can be `any types of dataloader supported by Lightning <https://pytorch-lightning.readthedocs.io/en/stable/guides/data.html>`__.
     """

[Review, on the parameter list] docs can auto add type hint for these parameters? — [Reply] yes
     def __init__(self, lightning_module: LightningModule, trainer: Trainer,
-                 train_dataloader: Optional[DataLoader] = None,
-                 val_dataloaders: Union[DataLoader, List[DataLoader], None] = None):
+                 train_dataloaders: Optional[Any] = None,
+                 val_dataloaders: Optional[Any] = None,
+                 train_dataloader: Optional[Any] = None):

[Review, on the signature] why there are both `train_dataloaders` and `train_dataloader`? — [Reply] i see, backward compatibility

         assert isinstance(lightning_module, LightningModule), f'Lightning module must be an instance of {__name__}.LightningModule.'
+        if train_dataloader is not None:
+            warnings.warn('`train_dataloader` is deprecated and replaced with `train_dataloaders`.', DeprecationWarning)
+            train_dataloaders = train_dataloader
         if cgo_import_failed:
             assert isinstance(trainer, pl.Trainer) and is_traceable(trainer), f'Trainer must be imported from {__name__}'
         else:
             # this is not isinstance(trainer, Trainer) because with a different trace call, it can be different
             assert (isinstance(trainer, pl.Trainer) and is_traceable(trainer)) or isinstance(trainer, cgo_trainer.Trainer), \
                 f'Trainer must be imported from {__name__} or nni.retiarii.evaluator.pytorch.cgo.trainer'
-        assert _check_dataloader(train_dataloader), f'Wrong dataloader type. Try import DataLoader from {__name__}.'
-        assert _check_dataloader(val_dataloaders), f'Wrong dataloader type. Try import DataLoader from {__name__}.'
+        if not _check_dataloader(train_dataloaders):
+            warnings.warn(f'Please try to wrap PyTorch DataLoader with nni.trace or '
+                          f'import DataLoader from {__name__}: {train_dataloaders}',
+                          RuntimeWarning)
+        if not _check_dataloader(val_dataloaders):
+            warnings.warn(f'Please try to wrap PyTorch DataLoader with nni.trace or '
+                          f'import DataLoader from {__name__}: {val_dataloaders}',
+                          RuntimeWarning)
         self.module = lightning_module
         self.trainer = trainer
-        self.train_dataloader = train_dataloader
+        self.train_dataloaders = train_dataloaders
         self.val_dataloaders = val_dataloaders
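For context, a hedged usage sketch of the new keyword (the MNIST dataset and trainer arguments are illustrative assumptions, not part of the diff): dataloaders built from the traced DataLoader exported by this module pass _check_dataloader, while a raw torch.utils.data.DataLoader now merely triggers a RuntimeWarning instead of failing an assertion.

    from nni.retiarii.evaluator.pytorch import Classification, DataLoader
    from torchvision import datasets, transforms

    transform = transforms.ToTensor()
    train_set = datasets.MNIST('data', train=True, download=True, transform=transform)
    val_set = datasets.MNIST('data', train=False, transform=transform)

    # Classification forwards these keywords to Lightning.__init__ above.
    evaluator = Classification(
        train_dataloaders=DataLoader(train_set, batch_size=64, shuffle=True),
        val_dataloaders=DataLoader(val_set, batch_size=64),
        max_epochs=1,
    )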
     @staticmethod
     def _load(ir):
-        return Lightning(ir['module'], ir['trainer'], ir['train_dataloader'], ir['val_dataloaders'])
+        return Lightning(ir['module'], ir['trainer'], ir['train_dataloaders'], ir['val_dataloaders'])

     def _dump(self):
         return {
             'type': self.__class__,
             'module': self.module,
             'trainer': self.trainer,
-            'train_dataloader': self.train_dataloader,
+            'train_dataloaders': self.train_dataloaders,
             'val_dataloaders': self.val_dataloaders
         }

     def _execute(self, model_cls):
         return self.fit(model_cls)

+    @property
+    def train_dataloader(self):
+        warnings.warn('train_dataloader is deprecated, please use `train_dataloaders`.', DeprecationWarning)
+
     def __eq__(self, other):
         eq_func = False
         eq_args = False
@@ -146,15 +169,18 @@ def fit(self, model):
         The model to fit.
         """
         self.module.set_model(model)
-        return self.trainer.fit(self.module, self.train_dataloader, self.val_dataloaders)
+        return self.trainer.fit(self.module, self.train_dataloaders, self.val_dataloaders)


 def _check_dataloader(dataloader):
     if dataloader is None:
         return True
+    # Check the type of dataloader recursively.
     if isinstance(dataloader, list):
         return all([_check_dataloader(d) for d in dataloader])
-    return isinstance(dataloader, torch_data.DataLoader) and is_traceable(dataloader)
+    if isinstance(dataloader, dict):
+        return all([_check_dataloader(v) for v in dataloader.values()])
+    if isinstance(dataloader, torch_data.DataLoader):
+        return is_traceable(dataloader)
+    return True
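A small sketch of what the relaxed check accepts (tensor contents are placeholders; this assumes nni.trace wraps DataLoader so that is_traceable holds on the instance, as in NNI's serializer):

    import nni
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.zeros(4, 1))
    traced = nni.trace(DataLoader)(dataset, batch_size=2)

    assert _check_dataloader(None)                                   # nothing to check
    assert _check_dataloader({'train': traced, 'extra': [traced]})   # dicts and lists recurse
    assert not _check_dataloader(DataLoader(dataset, batch_size=2))  # untraced -> warning path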
 ### The following are some commonly used Lightning modules ###

@@ -176,7 +202,6 @@ def __init__(self, criterion: Type[nn.Module], metrics: Dict[str, Type[torchmetr…
         if export_onnx is None or export_onnx is True:
             self.export_onnx = Path(os.environ.get('NNI_OUTPUT_DIR', '.')) / 'model.onnx'
-            self.export_onnx.parent.mkdir(exist_ok=True)
         elif export_onnx:
             self.export_onnx = Path(export_onnx)
         else:
@@ -199,7 +224,8 @@ def validation_step(self, batch, batch_idx):
         x, y = batch
         y_hat = self(x)

-        if self.export_onnx is not None:
+        if self.running_mode == 'multi' and self.export_onnx is not None:
+            self.export_onnx.parent.mkdir(exist_ok=True)
             try:
                 self.to_onnx(self.export_onnx, x, export_params=True)
             except RuntimeError as e:
@@ -221,10 +247,12 @@ def configure_optimizers(self):
         return self.optimizer(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)  # type: ignore

     def on_validation_epoch_end(self):
-        nni.report_intermediate_result(self._get_validation_metrics())
+        if self.running_mode == 'multi':
+            nni.report_intermediate_result(self._get_validation_metrics())

     def on_fit_end(self):
-        nni.report_final_result(self._get_validation_metrics())
+        if self.running_mode == 'multi':
+            nni.report_final_result(self._get_validation_metrics())

     def _get_validation_metrics(self):
         if len(self.metrics) == 1:
@@ -283,14 +311,18 @@ def __init__(self, criterion: Type[nn.Module] = nn.CrossEntropyLoss,
                  learning_rate: float = 0.001,
                  weight_decay: float = 0.,
                  optimizer: Type[optim.Optimizer] = optim.Adam,
-                 train_dataloader: Optional[DataLoader] = None,
+                 train_dataloaders: Optional[DataLoader] = None,
                  val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
                  export_onnx: bool = True,
+                 train_dataloader: Optional[DataLoader] = None,
                  **trainer_kwargs):
+        if train_dataloader is not None:
+            warnings.warn('`train_dataloader` is deprecated and replaced with `train_dataloaders`.', DeprecationWarning)
+            train_dataloaders = train_dataloader
         module = _ClassificationModule(criterion=criterion, learning_rate=learning_rate,
                                        weight_decay=weight_decay, optimizer=optimizer, export_onnx=export_onnx)
         super().__init__(module, Trainer(**trainer_kwargs),
-                         train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
+                         train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders)
 @nni.trace

@@ -336,11 +368,15 @@ def __init__(self, criterion: Type[nn.Module] = nn.MSELoss,
                  learning_rate: float = 0.001,
                  weight_decay: float = 0.,
                  optimizer: Type[optim.Optimizer] = optim.Adam,
-                 train_dataloader: Optional[DataLoader] = None,
+                 train_dataloaders: Optional[DataLoader] = None,
                  val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
                  export_onnx: bool = True,
+                 train_dataloader: Optional[DataLoader] = None,
                  **trainer_kwargs):
+        if train_dataloader is not None:
+            warnings.warn('`train_dataloader` is deprecated and replaced with `train_dataloaders`.', DeprecationWarning)
+            train_dataloaders = train_dataloader
         module = _RegressionModule(criterion=criterion, learning_rate=learning_rate,
                                    weight_decay=weight_decay, optimizer=optimizer, export_onnx=export_onnx)
         super().__init__(module, Trainer(**trainer_kwargs),
-                         train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
+                         train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders)
File 2 (new file):
@@ -0,0 +1,77 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

from typing import Any

from pytorch_lightning.trainer.supporters import CombinedLoader, CombinedLoaderIterator


class ConcatLoader(CombinedLoader):
    """This loader is the same as CombinedLoader in PyTorch-Lightning, but concatenates sub-loaders
    instead of loading them in parallel.

    Parameters
    ----------
    loaders
        For example, ::

            {
                "train": DataLoader(train_dataset),
                "val": DataLoader(val_dataset)
            }

        In this example, the loader will first produce the batches from "train", then "val".

    mode
        Only supports "min_size" for now.
    """

    def __init__(self, loaders: dict[str, Any], mode: str = 'min_size'):
        # FIXME: max_cycle will make dataloaders cycle iterators,
        # causing extra problems.
        if mode != 'min_size':
            raise ValueError('Only min_size mode is supported now.')
        super().__init__(loaders, mode)

    def __iter__(self) -> Any:
        """Replace the super-class iterator with ours."""
        self._try_to_patch_pytorch_dataloader()
        iterator = ConcatLoaderIterator(self.loaders)
        # handle fault tolerant restart.
        self.on_restart(iterator)
        self._iterator = iterator
        return iterator

    @staticmethod
    def _try_to_patch_pytorch_dataloader():
        """Copied from CombinedLoader."""
        from torch.utils.data.dataloader import _BaseDataLoaderIter

        # prevent `NotImplementedError` from PyTorch:
        # https://github.com/pytorch/pytorch/blob/v1.9.0/torch/utils/data/dataloader.py#L541
        def __getstate__patch__(*_):
            return {}

        _BaseDataLoaderIter.__getstate__ = __getstate__patch__  # type: ignore

    def __len__(self) -> int:
        return int(sum(self._calc_num_batches(loader) for loader in self.loaders.values()))


class ConcatLoaderIterator(CombinedLoaderIterator):
    """Similar to CombinedLoaderIterator in Lightning, but in a concat manner."""

    def __next__(self) -> Any:
        """Fetches the next batch from multiple data loaders,
        by looking for the first iterator that isn't exhausted yet.
        """
        if not len(self.loader_iters) == len(self.loaders):
            raise RuntimeError('loader_iters must have the same length as loaders.')
        for i, (loader_name, iterator) in enumerate(self.loader_iters.items()):
            try:
                return (self.request_next_batch(iterator), loader_name)
            except StopIteration:
                if i + 1 == len(self.loader_iters):
                    raise
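A hedged usage sketch (datasets are placeholders, and this assumes the pytorch-lightning version targeted by this PR, where CombinedLoader still lives in trainer.supporters): iterating a ConcatLoader yields (batch, loader_name) pairs, draining "train" before moving on to "val".

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    loader = ConcatLoader({
        'train': DataLoader(TensorDataset(torch.arange(4.)), batch_size=2),
        'val': DataLoader(TensorDataset(torch.arange(2.)), batch_size=2),
    })
    for batch, name in loader:
        print(name, batch)  # two 'train' batches, then one 'val' batch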
[Review, on the `running_mode` flag] is it possible that a lightning class can be used both for multi and oneshot?

[Reply] When I used this flag, I meant to check before two actions (the ONNX export and the NNI result reporting guarded above). But on a second thought, maybe the evaluator could figure out by itself whether its inner module is a one-shot supernet, so this flag might not be needed.

[Follow-up] No we can't. Revert.