
[feat] Add better support for predict + ddp 2/3 #7215

Merged
50 commits merged on Apr 27, 2021

Changes from 42 commits

Commits (50)
2e9b932
wip
tchaton Apr 20, 2021
7effee2
update
tchaton Apr 20, 2021
9321a73
update
tchaton Apr 20, 2021
128cc45
update
tchaton Apr 20, 2021
c7e49e9
update
tchaton Apr 20, 2021
9f82f7a
update
tchaton Apr 20, 2021
ce85174
typo
tchaton Apr 20, 2021
d3f9f30
update on comments
tchaton Apr 21, 2021
e1ccd1a
update
tchaton Apr 21, 2021
2a994db
update
tchaton Apr 21, 2021
69b6d77
update
tchaton Apr 21, 2021
bcf3c2b
update
tchaton Apr 22, 2021
643c8e5
update changelog
tchaton Apr 22, 2021
7109c16
update
tchaton Apr 22, 2021
fea8294
Merge branch 'master' into predict_loop_1
carmocca Apr 22, 2021
ce2656d
Fix merge
carmocca Apr 22, 2021
4ba47ed
Fix merge
carmocca Apr 22, 2021
0705ca7
Merge branch 'master' into predict_loop_1
tchaton Apr 22, 2021
54a5008
Merge branch 'predict_loop_1' of https://github.com/PyTorchLightning/…
tchaton Apr 22, 2021
1bf0325
move code
tchaton Apr 22, 2021
5243c91
resolve test
tchaton Apr 22, 2021
550d3f3
add extra test
tchaton Apr 22, 2021
0169e9e
add an extra test
tchaton Apr 22, 2021
4962459
update on comments
tchaton Apr 23, 2021
a371c5c
add typing
tchaton Apr 23, 2021
a163c2d
resolve flake8
tchaton Apr 23, 2021
63551ca
Refactor and Docs
carmocca Apr 23, 2021
0937e73
Fix tests
carmocca Apr 23, 2021
d4f523e
Fix tests
carmocca Apr 23, 2021
9a44529
Fix tests
carmocca Apr 23, 2021
d66d704
Duplicate
carmocca Apr 23, 2021
71685f2
Fix tests
carmocca Apr 23, 2021
89b281e
resolve bug
tchaton Apr 26, 2021
4416fa5
update
tchaton Apr 26, 2021
b627ed0
update on comments
tchaton Apr 26, 2021
ca64408
update
tchaton Apr 26, 2021
689bde2
update changelog
tchaton Apr 26, 2021
c40c4fa
update
tchaton Apr 26, 2021
174e50c
update
tchaton Apr 26, 2021
f6b6ae0
remove tpu
tchaton Apr 26, 2021
666a526
resolve flake8
tchaton Apr 26, 2021
26ba61e
update on comments
tchaton Apr 26, 2021
13405a8
update on comments
tchaton Apr 26, 2021
48a100a
update on comment
tchaton Apr 27, 2021
775c5c5
resolve flake8
tchaton Apr 27, 2021
b00d903
add a cpu test for predict
tchaton Apr 27, 2021
6c481af
add None test
tchaton Apr 27, 2021
1654030
update
tchaton Apr 27, 2021
db9eda8
Update CHANGELOG.md
tchaton Apr 27, 2021
2171f77
resolve tests
tchaton Apr 27, 2021
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -111,6 +111,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added new `EarlyStopping` parameters `stopping_threshold` and `divergence_threshold` ([#6868](https://github.com/PyTorchLightning/pytorch-lightning/pull/6868))


- Added new `UnrepeatedDistributedSampler` and `IndexBatchSamplerWrapper` for tracking distributed predictions ([#7215](https://github.com/PyTorchLightning/pytorch-lightning/pull/7215))

- Added `trainer.predict(return_predictions=None|False|True)` ([#7215](https://github.com/PyTorchLightning/pytorch-lightning/pull/7215))


### Changed

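For context on the two changelog entries above, a minimal usage sketch follows. The `LitModel` module, the toy dataset, and the Trainer arguments (`gpus=2, accelerator="ddp"`, Lightning 1.3-era API) are illustrative assumptions; only `return_predictions` and its rough semantics come from this PR.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    """Tiny illustrative module, not part of the PR."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 2)

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        (x,) = batch
        return self.layer(x)


dataloader = DataLoader(TensorDataset(torch.randn(10, 4)), batch_size=4)
trainer = pl.Trainer(gpus=2, accelerator="ddp")  # assumes 2 GPUs are available

# True: gather and return the outputs of predict_step.
# False: run prediction for its side effects only, without keeping the outputs.
# None (default): let Lightning decide based on the training type plugin, which
#                 is where the new `use_spawn` property comes in.
predictions = trainer.predict(LitModel(), dataloaders=dataloader, return_predictions=True)
```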
2 changes: 1 addition & 1 deletion pytorch_lightning/core/datamodule.py
@@ -26,7 +26,7 @@

class _DataModuleWrapper(type):

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.__has_added_checks = False

2 changes: 1 addition & 1 deletion pytorch_lightning/core/lightning.py
@@ -74,7 +74,7 @@ class LightningModule(
"model_size",
] + DeviceDtypeModuleMixin.__jit_unused_properties__

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)

# see (https://github.com/pytorch/pytorch/blob/3e6bb5233f9ca2c5aa55d9cda22a7ee85439aa6e/
70 changes: 69 additions & 1 deletion pytorch_lightning/overrides/distributed.py
@@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
from typing import Any
from typing import Any, Iterator, List, Optional

import torch
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import BatchSampler, DistributedSampler, Sampler

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
@@ -75,3 +76,70 @@ def prepare_for_backward(model: DistributedDataParallel, output: Any):
model.reducer.prepare_for_backward([])
else:
model.require_forward_param_sync = False


class UnrepeatedDistributedSampler(DistributedSampler):
Review comment (Contributor): Is `NoPaddingDistributedSampler` clearer?
"""
A fork of the pytorch DistributedSampler that doesn't repeat data, instead
allowing the number of batches per process to be off-by-one from each other.
This makes this sampler usable for predictions (it's deterministic and
doesn't require shuffling). It is potentially unsafe to use this sampler for
training, because during training the DistributedDataParallel syncs buffers
on each forward pass, so it could freeze if one of the processes runs one
fewer batch. During prediction, buffers are only synced on the first batch,
so this is safe to use as long as each process runs at least one batch. We
verify this in an assert.

Taken from https://github.com/jpuigcerver/PyLaia/blob/v1.0.0/laia/data/unpadded_distributed_sampler.py
and https://github.com/pytorch/pytorch/issues/25162#issuecomment-634146002
"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.num_samples = len(range(self.rank, len(self.dataset), self.num_replicas))
self.total_size = len(self.dataset)
# If any process has at least one batch, every other process needs to
# have at least one batch, or the DistributedDataParallel could lock up.
assert self.num_samples >= 1 or self.total_size == 0

def __iter__(self) -> Iterator[List[int]]:
if self.shuffle:
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch)
indices = torch.randperm(len(self.dataset), generator=g).tolist()
else:
indices = list(range(len(self.dataset)))

assert len(indices) == self.total_size

# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples

return iter(indices)


class IndexBatchSamplerWrapper:
"""This class is used to wrap a :class:`torch.utils.data.BatchSampler` and capture its indices."""

def __init__(self, sampler: BatchSampler) -> None:
self.batch_sampler = sampler
self.batch_indices: Optional[List[int]] = None

def __iter__(self) -> Iterator[List[int]]:
for batch in self.batch_sampler:
self.batch_indices = batch
Review comment (Member): probably shall be protected...
Suggested change: `self.batch_indices = batch` → `self._indices = batch`
Reply (Contributor Author): Same as before, let's keep it public.
Reply (Member): Above, in #7215 (comment), you say private, so what do you mean?
yield batch

@property
def drop_last(self) -> bool:
return self.batch_sampler.drop_last

@property
def batch_size(self) -> int:
return self.batch_sampler.batch_size

@property
def sampler(self) -> Sampler:
return self.batch_sampler.sampler
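A short, self-contained sketch of how the two classes above behave together; the toy dataset size, `num_replicas=2`, and `rank=0` are illustrative assumptions, not values from the PR.

```python
from torch.utils.data import BatchSampler

from pytorch_lightning.overrides.distributed import (
    IndexBatchSamplerWrapper,
    UnrepeatedDistributedSampler,
)

# Pretend we are rank 0 of 2 processes predicting over 5 samples: rank 0 gets
# indices [0, 2, 4] and rank 1 gets [1, 3]. Batch counts may differ by one
# across ranks, but no sample is padded/repeated as DistributedSampler would do.
sampler = UnrepeatedDistributedSampler(list(range(5)), num_replicas=2, rank=0, shuffle=False)
batch_sampler = IndexBatchSamplerWrapper(BatchSampler(sampler, batch_size=2, drop_last=False))

for batch in batch_sampler:
    # `batch_indices` mirrors the last yielded batch, which lets the Trainer map
    # each prediction back to its original dataset position.
    assert batch_sampler.batch_indices == batch

print(batch_sampler.batch_indices)  # [4] -- the final, shorter batch on rank 0
```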
4 changes: 4 additions & 0 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -94,6 +94,10 @@ def distributed_sampler_kwargs(self):
distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
return distributed_sampler_kwargs

@property
def use_spawn(self) -> bool:
return True

@property
def _is_single_process_single_device(self):
return True
4 changes: 4 additions & 0 deletions pytorch_lightning/plugins/training_type/parallel.py
@@ -47,6 +47,10 @@ def root_device(self):
def on_gpu(self):
return self.root_device.type == "cuda" and torch.cuda.is_available()

@property
def use_spawn(self) -> bool:
return False
Review comment (Contributor): I would avoid this. Can we do an isinstance check?
Reply (Contributor Author): Yes, but I find it less clean.
Reply (Contributor): The problem is that this style of adding properties to identify each plugin will not scale long term, and at the finest granularity it is the same as an instance check. It could be part of the accelerator connector refactor one day; `use_spawn` could be handled by the connector.

@property
def lightning_module(self):
return unwrap_lightning_module(self._model)
4 changes: 4 additions & 0 deletions pytorch_lightning/plugins/training_type/single_device.py
@@ -36,6 +36,10 @@ def on_tpu(self) -> bool:
def on_gpu(self) -> bool:
return self.device.type == "cuda" and torch.cuda.is_available()

@property
def use_spawn(self) -> bool:
return False

def reduce(self, tensor: Union[Any, torch.Tensor], *args: Any, **kwargs: Any) -> Union[Any, torch.Tensor]:
"""
Reduces a tensor from several distributed processes to one aggregated tensor.
pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -73,6 +73,11 @@ def model_to_device(self) -> None:
def is_global_zero(self) -> bool:
"""Whether the current process is the rank zero process not only on the local node, but for all nodes."""

@property
@abstractmethod
def use_spawn(self) -> bool:
"""Whether the current processes are being spawned"""

@abstractmethod
def reduce(self, tensor: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]:
"""
8 changes: 4 additions & 4 deletions pytorch_lightning/trainer/connectors/callback_connector.py
@@ -13,7 +13,7 @@
# limitations under the License.
import os
from datetime import timedelta
from typing import List, Union, Optional, Dict
from typing import Dict, List, Optional, Union

from pytorch_lightning.callbacks import Callback, ModelCheckpoint, ProgressBar, ProgressBarBase
from pytorch_lightning.callbacks.timer import Timer
@@ -58,6 +58,8 @@ def on_trainer_init(
# configure swa callback
self._configure_swa_callbacks()

# configure the timer callback.
# responsible to stop the training when max_time is reached.
self._configure_timer_callback(max_time)

# init progress bar
Expand Down Expand Up @@ -115,9 +117,7 @@ def _configure_timer_callback(self, max_time: Optional[Union[str, timedelta, Dic
if max_time is None:
return
if any(isinstance(cb, Timer) for cb in self.trainer.callbacks):
rank_zero_info(
"Ignoring `Trainer(max_time=...)`, callbacks list already contains a Timer."
)
rank_zero_info("Ignoring `Trainer(max_time=...)`, callbacks list already contains a Timer.")
return
timer = Timer(duration=max_time, interval="step")
self.trainer.callbacks.append(timer)
38 changes: 25 additions & 13 deletions pytorch_lightning/trainer/data_loading.py
@@ -17,14 +17,16 @@
from abc import ABC
from copy import deepcopy
from functools import partial
from typing import Iterable, List, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

from pytorch_lightning.accelerators import Accelerator
from pytorch_lightning.core import LightningModule
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper, UnrepeatedDistributedSampler
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_6, rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -107,7 +109,9 @@ def auto_add_worker_init_fn(self, dataloader: DataLoader) -> None:
if int(os.environ.get("PL_SEED_WORKERS", 0)) and dataloader.worker_init_fn is None:
dataloader.worker_init_fn = partial(pl_worker_init_function, rank=self.global_rank)

def auto_add_sampler(self, dataloader: DataLoader, shuffle: bool) -> DataLoader:
def auto_add_sampler(
self, dataloader: DataLoader, shuffle: bool, mode: Optional[RunningStage] = None
) -> DataLoader:
# don't do anything if it's not a dataloader
is_dataloader = isinstance(dataloader, DataLoader)
# don't manipulate iterable datasets
@@ -133,20 +137,23 @@ def auto_add_sampler(self, dataloader: DataLoader, shuffle: bool) -> DataLoader:
)

# replace with distributed sampler
sampler = self._get_distributed_sampler(dataloader, shuffle)
dataloader = self.replace_sampler(dataloader, sampler)
sampler = self._get_distributed_sampler(dataloader, shuffle, mode=mode)
dataloader = self.replace_sampler(dataloader, sampler, mode=mode)

return dataloader

@staticmethod
def _resolve_batch_sampler(dl_args, dataloader, sampler):
def _resolve_batch_sampler(dl_args, dataloader, sampler, mode: Optional[RunningStage] = None) -> Dict[str, Any]:
batch_sampler = getattr(dataloader, "batch_sampler")
if batch_sampler is not None and type(batch_sampler) is not BatchSampler:
is_predicting = mode == RunningStage.PREDICTING
if (batch_sampler is not None and type(batch_sampler) is not BatchSampler) or is_predicting:
batch_sampler = type(batch_sampler)(
sampler,
batch_size=batch_sampler.batch_size,
drop_last=batch_sampler.drop_last,
drop_last=False if is_predicting else batch_sampler.drop_last,
)
if is_predicting:
batch_sampler = IndexBatchSamplerWrapper(batch_sampler)
dl_args['batch_sampler'] = batch_sampler
dl_args['batch_size'] = 1
dl_args['shuffle'] = False
Expand All @@ -159,7 +166,7 @@ def _resolve_batch_sampler(dl_args, dataloader, sampler):

return dl_args

def replace_sampler(self, dataloader, sampler):
def replace_sampler(self, dataloader: DataLoader, sampler, mode: Optional[RunningStage] = None) -> DataLoader:
skip_keys = ('sampler', 'batch_sampler', 'dataset_kind')
skip_signature_keys = ('args', 'kwargs', 'self')

@@ -174,7 +181,7 @@ def replace_sampler(self, dataloader, sampler):

dl_args = {name: attrs[name] for name in params if name in attrs and name not in skip_keys}

dl_args = self._resolve_batch_sampler(dl_args, dataloader, sampler)
dl_args = self._resolve_batch_sampler(dl_args, dataloader, sampler, mode=mode)

multiprocessing_context = dataloader.multiprocessing_context
dl_args['multiprocessing_context'] = multiprocessing_context
@@ -205,12 +212,15 @@ def __init__(self, num_features, dataset, *args, **kwargs):
dataloader.multiprocessing_context = multiprocessing_context
return dataloader

def _get_distributed_sampler(self, dataloader, shuffle):
def _get_distributed_sampler(
self, dataloader: DataLoader, shuffle: bool, mode: Optional[RunningStage] = None
) -> DistributedSampler:
kwargs = self.distributed_sampler_kwargs
kwargs["shuffle"] = shuffle and not self.overfit_batches
if _TORCH_GREATER_EQUAL_1_6:
kwargs.setdefault("seed", int(os.getenv("PL_GLOBAL_SEED", 0)))
sampler = DistributedSampler(dataloader.dataset, **kwargs)
cls = UnrepeatedDistributedSampler if mode == RunningStage.PREDICTING else DistributedSampler
sampler = cls(dataloader.dataset, **kwargs)
return sampler

def reset_train_dataloader(self, model: LightningModule) -> None:
@@ -296,7 +306,7 @@ def _reset_eval_dataloader(

Args:
model: The current `LightningModule`
mode: Either `'val'` or `'test'`
mode: Either `'val'`, `'test'` or `'predict'`

Returns:
Tuple (num_batches, dataloaders)
@@ -342,7 +352,9 @@ def _reset_eval_dataloader(
rank_zero_warn("One of given dataloaders is None and it will be skipped.")

# add samplers
dataloaders = [self.auto_add_sampler(dl, shuffle=False) for dl in dataloaders if dl is not None]
dataloaders = [
self.auto_add_sampler(dl, shuffle=False, mode=self._running_stage) for dl in dataloaders if dl is not None
]

# add worker_init_fn for correct seeding in worker processes
apply_to_collection(dataloaders, dtype=DataLoader, function=self.auto_add_worker_init_fn)
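To make the predicting branch of `_resolve_batch_sampler` above concrete, here is a hedged standalone sketch of the same rebuilding logic outside the Trainer; the `SequentialSampler` over ten samples and the batch size are illustrative.

```python
from torch.utils.data import BatchSampler, SequentialSampler

from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper

is_predicting = True  # i.e. mode == RunningStage.PREDICTING
old = BatchSampler(SequentialSampler(range(10)), batch_size=4, drop_last=True)

# Rebuild the batch sampler: never drop the last (possibly smaller) batch while
# predicting, so every sample receives a prediction.
new = type(old)(
    old.sampler,
    batch_size=old.batch_size,
    drop_last=False if is_predicting else old.drop_last,
)
if is_predicting:
    # Wrap it so the indices of every batch can be captured alongside the outputs.
    new = IndexBatchSamplerWrapper(new)

print(list(new))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```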