Conversation
Tests require PL 1.6. Upgrade is in #4814.
@@ -74,51 +79,67 @@ class Lightning(Evaluator):

     Parameters
     ----------
-    lightning_module : LightningModule
+    lightning_module
Can the docs auto-add type hints for these parameters?
yes
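For context, a minimal sketch of what the reply refers to (the class and parameter below are hypothetical, not NNI code): the type lives only in the signature annotation, and a Sphinx extension such as sphinx-autodoc-typehints can merge it into the rendered Parameters section at build time.

```python
from typing import Optional

class Example:
    """Hypothetical class illustrating the new docstring style: the type is
    declared once, in the annotation, rather than repeated in the docstring.

    Parameters
    ----------
    max_epochs
        Maximum number of training epochs.
    """

    def __init__(self, max_epochs: Optional[int] = None):
        self.max_epochs = max_epochs
```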
                          f'You might have forgotten to import DataLoader from {__name__}: {train_dataloaders}',
                          RuntimeWarning)
        if not _check_dataloader(val_dataloaders):
            warnings.warn(f'Unexpected dataloader type: {type(train_dataloaders)}. '
This should be `type(val_dataloaders)`.
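For reference, a sketch of the fixed branch, using the same variables as the diff above:

```python
if not _check_dataloader(val_dataloaders):
    warnings.warn(f'Unexpected dataloader type: {type(val_dataloaders)}. '
                  f'You might have forgotten to import DataLoader from {__name__}: {val_dataloaders}',
                  RuntimeWarning)
```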
-                 val_dataloaders: Union[DataLoader, List[DataLoader], None] = None):
+                 train_dataloaders: Optional[Any] = None,
+                 val_dataloaders: Optional[Any] = None,
+                 train_dataloader: Optional[Any] = None):
Why are there both `train_dataloaders` and `train_dataloader`?
I see, backward compatibility.
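A sketch of the usual backward-compatibility shim (the function below is illustrative, not the actual NNI code): the deprecated singular argument is folded into the new plural one with a warning.

```python
import warnings
from typing import Any, Optional

def resolve_train_loaders(train_dataloaders: Optional[Any] = None,
                          train_dataloader: Optional[Any] = None) -> Optional[Any]:
    # Accept the deprecated singular name, warn, and fold it into the new one.
    if train_dataloader is not None:
        warnings.warn('`train_dataloader` is deprecated; use `train_dataloaders`.',
                      DeprecationWarning)
        if train_dataloaders is None:
            train_dataloaders = train_dataloader
    return train_dataloaders
```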
-        assert _check_dataloader(val_dataloaders), f'Wrong dataloader type. Try import DataLoader from {__name__}.'
+        if not _check_dataloader(train_dataloaders):
+            warnings.warn(f'Unexpected dataloader type: {type(train_dataloaders)}. '
+                          f'You might have forgotten to import DataLoader from {__name__}: {train_dataloaders}',
I may be missing some background. What does this message mean?
`_check_dataloader` used to check whether the dataloader is wrapped with `nni.trace`. But the serializer has improved since then, and because we now support complex dataloader types (e.g., a list of dataloaders, a dict of dataloaders, a list of dicts of dataloaders), there is no easy way to check them. So I gave up. :(
@@ -36,6 +37,9 @@ class LightningModule(pl.LightningModule):

     See https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html
     """

+    running_mode: Literal['multi', 'oneshot'] = 'multi'
Is it possible that a lightning class can be used for both multi and oneshot?
When I added this flag, I meant to check it before two actions:

- whether to save an ONNX graph;
- whether to report intermediate / final results.

But on second thought, maybe the evaluator could figure out by itself whether its inner module is a one-shot supernet, so this flag might not be needed.

No, we can't. Revert.
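For clarity, a sketch of how the flag was meant to gate the two actions (the method names below are invented for illustration, not the real API):

```python
from typing import Literal

class ModuleSketch:
    running_mode: Literal['multi', 'oneshot'] = 'multi'

    def on_fit_end(self) -> None:
        # Only a per-trial (multi-trial) run exports a graph and reports results;
        # a one-shot supernet run skips both.
        if self.running_mode == 'multi':
            self.export_onnx_graph()
            self.report_final_result()

    def export_onnx_graph(self) -> None:
        print('exporting onnx graph')  # placeholder

    def report_final_result(self) -> None:
        print('reporting final result')  # placeholder
```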
+        return {
+            'train': train_dataloaders,
+            'val': val_dataloaders
+        }, None
Is the return signature changed from the base class?
I put both train and val dataloaders into the train dataloader, and we don't need a val dataloader.
I know, but the first returned value is a dict, not a train dataloader?
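For context, a dict of loaders is itself an acceptable "train dataloader" for Lightning. A minimal runnable sketch, assuming PL 1.x where `CombinedLoader` lives under `trainer.supporters` (it moved in later versions):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning.trainer.supporters import CombinedLoader  # PL 1.x path

train = DataLoader(TensorDataset(torch.zeros(8, 1)), batch_size=4)
val = DataLoader(TensorDataset(torch.ones(8, 1)), batch_size=4)

# CombinedLoader yields one dict per step, so training_step can draw
# batches from both splits at once.
combined = CombinedLoader({'train': train, 'val': val})
for batch in combined:
    print(batch['train'][0].shape, batch['val'][0].shape)
```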
Description
Add multi-GPU support for one-shot NAS.

At its core, the technical issue with multi-GPU is supporting data sharding the Lightning way. We replaced the original ad-hoc dataloaders with `CombinedLoader`, provided officially by Lightning, and a subclass `ConcatLoader`. These dataloaders naturally support data sharding. This PR also addresses some minor bugs and deprecated variable/method names.
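The sketch below is not the real `ConcatLoader`; it only illustrates the concept of chaining several loaders behind one dataloader-like interface, which is what lets Lightning shard the data per process:

```python
from typing import Any, Dict, Iterator, Tuple
from torch.utils.data import DataLoader

class ConcatLoaderSketch:
    """Iterate wrapped loaders one after another, labelling each batch."""

    def __init__(self, loaders: Dict[str, DataLoader]):
        self.loaders = loaders

    def __iter__(self) -> Iterator[Tuple[Any, str]]:
        for label, loader in self.loaders.items():
            for batch in loader:
                yield batch, label

    def __len__(self) -> int:
        return sum(len(loader) for loader in self.loaders.values())
```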
Checklist
How to test
Test one-shot on multi-GPU following the guide here.