
[feat] Add better support for predict + ddp 2/3 #7215

Merged: 50 commits, Apr 27, 2021

Commits (50)
2e9b932  wip (tchaton, Apr 20, 2021)
7effee2  update (tchaton, Apr 20, 2021)
9321a73  update (tchaton, Apr 20, 2021)
128cc45  update (tchaton, Apr 20, 2021)
c7e49e9  update (tchaton, Apr 20, 2021)
9f82f7a  update (tchaton, Apr 20, 2021)
ce85174  typo (tchaton, Apr 20, 2021)
d3f9f30  update on comments (tchaton, Apr 21, 2021)
e1ccd1a  update (tchaton, Apr 21, 2021)
2a994db  update (tchaton, Apr 21, 2021)
69b6d77  update (tchaton, Apr 21, 2021)
bcf3c2b  update (tchaton, Apr 22, 2021)
643c8e5  update changelog (tchaton, Apr 22, 2021)
7109c16  update (tchaton, Apr 22, 2021)
fea8294  Merge branch 'master' into predict_loop_1 (carmocca, Apr 22, 2021)
ce2656d  Fix merge (carmocca, Apr 22, 2021)
4ba47ed  Fix merge (carmocca, Apr 22, 2021)
0705ca7  Merge branch 'master' into predict_loop_1 (tchaton, Apr 22, 2021)
54a5008  Merge branch 'predict_loop_1' of https://github.com/PyTorchLightning/… (tchaton, Apr 22, 2021)
1bf0325  move code (tchaton, Apr 22, 2021)
5243c91  resolve test (tchaton, Apr 22, 2021)
550d3f3  add extra test (tchaton, Apr 22, 2021)
0169e9e  add an extra test (tchaton, Apr 22, 2021)
4962459  update on comments (tchaton, Apr 23, 2021)
a371c5c  add typing (tchaton, Apr 23, 2021)
a163c2d  resolve flake8 (tchaton, Apr 23, 2021)
63551ca  Refactor and Docs (carmocca, Apr 23, 2021)
0937e73  Fix tests (carmocca, Apr 23, 2021)
d4f523e  Fix tests (carmocca, Apr 23, 2021)
9a44529  Fix tests (carmocca, Apr 23, 2021)
d66d704  Duplicate (carmocca, Apr 23, 2021)
71685f2  Fix tests (carmocca, Apr 23, 2021)
89b281e  resolve bug (tchaton, Apr 26, 2021)
4416fa5  update (tchaton, Apr 26, 2021)
b627ed0  update on comments (tchaton, Apr 26, 2021)
ca64408  update (tchaton, Apr 26, 2021)
689bde2  update changelog (tchaton, Apr 26, 2021)
c40c4fa  update (tchaton, Apr 26, 2021)
174e50c  update (tchaton, Apr 26, 2021)
f6b6ae0  remove tpu (tchaton, Apr 26, 2021)
666a526  resolve flake8 (tchaton, Apr 26, 2021)
26ba61e  update on comments (tchaton, Apr 26, 2021)
13405a8  update on comments (tchaton, Apr 26, 2021)
48a100a  update on comment (tchaton, Apr 27, 2021)
775c5c5  resolve flake8 (tchaton, Apr 27, 2021)
b00d903  add a cpu test for predict (tchaton, Apr 27, 2021)
6c481af  add None test (tchaton, Apr 27, 2021)
1654030  update (tchaton, Apr 27, 2021)
db9eda8  Update CHANGELOG.md (tchaton, Apr 27, 2021)
2171f77  resolve tests (tchaton, Apr 27, 2021)
Files changed
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -111,6 +111,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added new `EarlyStopping` parameters `stopping_threshold` and `divergence_threshold` ([#6868](https://github.com/PyTorchLightning/pytorch-lightning/pull/6868))


- Added new `UnrepeatedDistributedSampler` and `IndexBatchSamplerWrapper` for tracking distributed predictions ([#7215](https://github.com/PyTorchLightning/pytorch-lightning/pull/7215))


- Added `trainer.predict(return_predictions=None|False|True)` ([#7215](https://github.com/PyTorchLightning/pytorch-lightning/pull/7215))


### Changed

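To put the two CHANGELOG entries above in context, here is a minimal usage sketch (not part of this diff); the model, the dummy data, and the 1.3-era `accelerator="ddp"` / `gpus=2` arguments are illustrative assumptions, as is the `predict_step` hook name.

```python
# Hypothetical usage sketch for `return_predictions` (added by this PR); names below are
# placeholders, and running under DDP assumes a machine with two GPUs.
import torch
from torch.utils.data import DataLoader
from pytorch_lightning import LightningModule, Trainer


class LinearModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        return self.layer(batch)


dataloader = DataLoader(torch.randn(64, 32), batch_size=8)
trainer = Trainer(accelerator="ddp", gpus=2)

# True (the default for non-spawn plugins) collects and returns the per-batch outputs;
# False skips collection, e.g. when a callback writes predictions to disk instead.
predictions = trainer.predict(LinearModel(), dataloader, return_predictions=True)
```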
2 changes: 1 addition & 1 deletion pytorch_lightning/core/datamodule.py
@@ -26,7 +26,7 @@

class _DataModuleWrapper(type):

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.__has_added_checks = False

2 changes: 1 addition & 1 deletion pytorch_lightning/core/lightning.py
@@ -74,7 +74,7 @@ class LightningModule(
"model_size",
] + DeviceDtypeModuleMixin.__jit_unused_properties__

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)

# see (https://github.com/pytorch/pytorch/blob/3e6bb5233f9ca2c5aa55d9cda22a7ee85439aa6e/
70 changes: 69 additions & 1 deletion pytorch_lightning/overrides/distributed.py
@@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
from typing import Any
from typing import Any, Iterator, List, Optional

import torch
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import BatchSampler, DistributedSampler, Sampler

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
@@ -75,3 +76,70 @@ def prepare_for_backward(model: DistributedDataParallel, output: Any):
model.reducer.prepare_for_backward([])
else:
model.require_forward_param_sync = False


class UnrepeatedDistributedSampler(DistributedSampler):
[Review comment (resolved), from a Contributor: Is `NoPaddingDistributedSampler` clearer?]
"""
A fork of the pytorch DistributedSampler that doesn't repeat data, instead
allowing the number of batches per process to be off-by-one from each other.
This makes this sampler usable for predictions (it's deterministic and
doesn't require shuffling). It is potentially unsafe to use this sampler for
training, because during training the DistributedDataParallel syncs buffers
on each forward pass, so it could freeze if one of the processes runs one
fewer batch. During prediction, buffers are only synced on the first batch,
so this is safe to use as long as each process runs at least one batch. We
verify this in an assert.

Taken from https://github.com/jpuigcerver/PyLaia/blob/v1.0.0/laia/data/unpadded_distributed_sampler.py
and https://github.com/pytorch/pytorch/issues/25162#issuecomment-634146002
"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.num_samples = len(range(self.rank, len(self.dataset), self.num_replicas))
self.total_size = len(self.dataset)
# If any process has at least one batch, every other process needs to
# have at least one batch, or the DistributedDataParallel could lock up.
assert self.num_samples >= 1 or self.total_size == 0

def __iter__(self) -> Iterator[List[int]]:
if self.shuffle:
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch)
indices = torch.randperm(len(self.dataset), generator=g).tolist()
else:
indices = list(range(len(self.dataset)))

assert len(indices) == self.total_size

# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples

return iter(indices)


class IndexBatchSamplerWrapper:
"""This class is used to wrap a :class:`torch.utils.data.BatchSampler` and capture its indices."""

def __init__(self, sampler: BatchSampler) -> None:
self._sampler = sampler
self.batch_indices: Optional[List[int]] = None

def __iter__(self) -> Iterator[List[int]]:
for batch in self._sampler:
self.batch_indices = batch
[Review thread:
  Member: This should probably be protected. Suggested change: `self.batch_indices = batch` -> `self._indices = batch`
  Contributor (author): Same as before, let's keep it public.
  Member: Above in #7215 (comment) you said private, so which do you mean?]
yield batch

@property
def drop_last(self) -> bool:
return self._sampler.drop_last

@property
def batch_size(self) -> int:
return self._sampler.batch_size

@property
def sampler(self) -> Sampler:
return self._sampler.sampler
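Because the behavioral difference between the stock sampler and the new one is easy to miss, here is a CPU-only sketch (not part of the diff) comparing the indices each rank receives. Passing `num_replicas`/`rank` explicitly avoids needing an initialized process group, and the import path assumes the module layout added above.

```python
# Compare index assignment for a dataset of length 10 split across 4 replicas:
# DistributedSampler pads by repeating indices so every rank gets the same count,
# while UnrepeatedDistributedSampler lets ranks differ by one batch and never repeats.
from torch.utils.data import DistributedSampler

from pytorch_lightning.overrides.distributed import UnrepeatedDistributedSampler

dataset = list(range(10))

for rank in range(4):
    padded = DistributedSampler(dataset, num_replicas=4, rank=rank, shuffle=False)
    unrepeated = UnrepeatedDistributedSampler(dataset, num_replicas=4, rank=rank, shuffle=False)
    print(rank, list(padded), list(unrepeated))

# Expected indices per rank (padded vs. unrepeated):
# 0  [0, 4, 8]  [0, 4, 8]
# 1  [1, 5, 9]  [1, 5, 9]
# 2  [2, 6, 0]  [2, 6]      the padded sampler repeats sample 0
# 3  [3, 7, 1]  [3, 7]      the padded sampler repeats sample 1
```

The repeated samples are exactly why the padded variant is a poor fit for prediction: they would show up as duplicate predictions that have to be filtered out afterwards.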
8 changes: 4 additions & 4 deletions pytorch_lightning/trainer/connectors/callback_connector.py
@@ -13,7 +13,7 @@
# limitations under the License.
import os
from datetime import timedelta
from typing import List, Union, Optional, Dict
from typing import Dict, List, Optional, Union

from pytorch_lightning.callbacks import Callback, ModelCheckpoint, ProgressBar, ProgressBarBase
from pytorch_lightning.callbacks.timer import Timer
@@ -58,6 +58,8 @@ def on_trainer_init(
# configure swa callback
self._configure_swa_callbacks()

# configure the timer callback.
# responsible to stop the training when max_time is reached.
self._configure_timer_callback(max_time)

# init progress bar
@@ -115,9 +117,7 @@ def _configure_timer_callback(self, max_time: Optional[Union[str, timedelta, Dic
if max_time is None:
return
if any(isinstance(cb, Timer) for cb in self.trainer.callbacks):
rank_zero_info(
"Ignoring `Trainer(max_time=...)`, callbacks list already contains a Timer."
)
rank_zero_info("Ignoring `Trainer(max_time=...)`, callbacks list already contains a Timer.")
return
timer = Timer(duration=max_time, interval="step")
self.trainer.callbacks.append(timer)
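The comment added above only documents existing behavior, but for completeness here is a hedged sketch of the `max_time` forms implied by the type hint in this hunk; the exact `"DD:HH:MM:SS"` string format is an assumption about the 1.3-era Timer callback, not something stated in this diff.

```python
# Sketch: the Timer callback configured above stops training once `max_time` elapses.
# The accepted forms follow the `Optional[Union[str, timedelta, Dict]]` hint in the diff.
from datetime import timedelta
from pytorch_lightning import Trainer

Trainer(max_time="00:12:00:00")             # string, assumed DD:HH:MM:SS form
Trainer(max_time=timedelta(hours=12))       # datetime.timedelta
Trainer(max_time={"days": 0, "hours": 12})  # dict of timedelta keyword arguments
```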
39 changes: 26 additions & 13 deletions pytorch_lightning/trainer/data_loading.py
@@ -17,14 +17,16 @@
from abc import ABC
from copy import deepcopy
from functools import partial
from typing import Iterable, List, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

from pytorch_lightning.accelerators import Accelerator
from pytorch_lightning.core import LightningModule
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper, UnrepeatedDistributedSampler
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_6, rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -107,7 +109,9 @@ def auto_add_worker_init_fn(self, dataloader: DataLoader) -> None:
if int(os.environ.get("PL_SEED_WORKERS", 0)) and dataloader.worker_init_fn is None:
dataloader.worker_init_fn = partial(pl_worker_init_function, rank=self.global_rank)

def auto_add_sampler(self, dataloader: DataLoader, shuffle: bool) -> DataLoader:
def auto_add_sampler(
self, dataloader: DataLoader, shuffle: bool, mode: Optional[RunningStage] = None
) -> DataLoader:
# don't do anything if it's not a dataloader
is_dataloader = isinstance(dataloader, DataLoader)
# don't manipulate iterable datasets
@@ -133,20 +137,24 @@ def auto_add_sampler(self, dataloader: DataLoader, shuffle: bool) -> DataLoader:
)

# replace with distributed sampler
sampler = self._get_distributed_sampler(dataloader, shuffle)
dataloader = self.replace_sampler(dataloader, sampler)
sampler = self._get_distributed_sampler(dataloader, shuffle, mode=mode)
dataloader = self.replace_sampler(dataloader, sampler, mode=mode)

return dataloader

@staticmethod
def _resolve_batch_sampler(dl_args, dataloader, sampler):
def _resolve_batch_sampler(dl_args, dataloader, sampler, mode: Optional[RunningStage] = None) -> Dict[str, Any]:
batch_sampler = getattr(dataloader, "batch_sampler")
if batch_sampler is not None and type(batch_sampler) is not BatchSampler:
is_predicting = mode == RunningStage.PREDICTING
# checking the batch sampler type is different than PyTorch default.
if (batch_sampler is not None and type(batch_sampler) is not BatchSampler) or is_predicting:
batch_sampler = type(batch_sampler)(
sampler,
batch_size=batch_sampler.batch_size,
drop_last=batch_sampler.drop_last,
drop_last=(False if is_predicting else batch_sampler.drop_last),
)
if is_predicting:
batch_sampler = IndexBatchSamplerWrapper(batch_sampler)
dl_args['batch_sampler'] = batch_sampler
dl_args['batch_size'] = 1
dl_args['shuffle'] = False
@@ -159,7 +167,7 @@ def _resolve_batch_sampler(dl_args, dataloader, sampler):

return dl_args

def replace_sampler(self, dataloader, sampler):
def replace_sampler(self, dataloader: DataLoader, sampler, mode: Optional[RunningStage] = None) -> DataLoader:
skip_keys = ('sampler', 'batch_sampler', 'dataset_kind')
skip_signature_keys = ('args', 'kwargs', 'self')

@@ -174,7 +182,7 @@ def replace_sampler(self, dataloader, sampler):

dl_args = {name: attrs[name] for name in params if name in attrs and name not in skip_keys}

dl_args = self._resolve_batch_sampler(dl_args, dataloader, sampler)
dl_args = self._resolve_batch_sampler(dl_args, dataloader, sampler, mode=mode)

multiprocessing_context = dataloader.multiprocessing_context
dl_args['multiprocessing_context'] = multiprocessing_context
@@ -205,12 +213,15 @@ def __init__(self, num_features, dataset, *args, **kwargs):
dataloader.multiprocessing_context = multiprocessing_context
return dataloader

def _get_distributed_sampler(self, dataloader, shuffle):
def _get_distributed_sampler(
self, dataloader: DataLoader, shuffle: bool, mode: Optional[RunningStage] = None
) -> DistributedSampler:
kwargs = self.distributed_sampler_kwargs
kwargs["shuffle"] = shuffle and not self.overfit_batches
if _TORCH_GREATER_EQUAL_1_6:
kwargs.setdefault("seed", int(os.getenv("PL_GLOBAL_SEED", 0)))
sampler = DistributedSampler(dataloader.dataset, **kwargs)
cls = UnrepeatedDistributedSampler if mode == RunningStage.PREDICTING else DistributedSampler
sampler = cls(dataloader.dataset, **kwargs)
return sampler

def reset_train_dataloader(self, model: LightningModule) -> None:
@@ -296,7 +307,7 @@ def _reset_eval_dataloader(

Args:
model: The current `LightningModule`
mode: Either `'val'` or `'test'`
mode: Either `'val'`, `'test'` or `'predict'`

Returns:
Tuple (num_batches, dataloaders)
Expand Down Expand Up @@ -342,7 +353,9 @@ def _reset_eval_dataloader(
rank_zero_warn("One of given dataloaders is None and it will be skipped.")

# add samplers
dataloaders = [self.auto_add_sampler(dl, shuffle=False) for dl in dataloaders if dl is not None]
dataloaders = [
self.auto_add_sampler(dl, shuffle=False, mode=self._running_stage) for dl in dataloaders if dl is not None
]

# add worker_init_fn for correct seeding in worker processes
apply_to_collection(dataloaders, dtype=DataLoader, function=self.auto_add_worker_init_fn)
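For the predict path wired up above (`mode == RunningStage.PREDICTING`), the batch sampler ends up wrapped so that the loop can recover which dataset indices each batch came from. A standalone sketch of the wrapper's behavior (not part of the diff), assuming the import path introduced in this PR:

```python
# Sketch: the wrapper records the indices of the batch it just yielded, which the
# predict loop later reads to map distributed predictions back to dataset positions.
from torch.utils.data import BatchSampler, SequentialSampler

from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper

batch_sampler = BatchSampler(SequentialSampler(range(10)), batch_size=4, drop_last=False)
wrapped = IndexBatchSamplerWrapper(batch_sampler)

for batch in wrapped:
    # `batch_indices` always refers to the most recently yielded batch
    print(batch, wrapped.batch_indices)

# [0, 1, 2, 3] [0, 1, 2, 3]
# [4, 5, 6, 7] [4, 5, 6, 7]
# [8, 9] [8, 9]
```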
78 changes: 67 additions & 11 deletions pytorch_lightning/trainer/predict_loop.py
@@ -11,9 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional, Union

import torch
from torch.utils.data.dataloader import DataLoader

from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
from pytorch_lightning.plugins import DDPSpawnPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.warnings import WarningCache


@@ -24,6 +29,27 @@ def __init__(self, trainer):
self.max_batches = None
self.num_dataloaders = None
self.warning_cache = WarningCache()
self.batch_indices: Optional[List[int]] = None
self.epoch_batch_indices: Optional[List[List[int]]] = None
# `DDPSpawnPlugin` plugins and derivate don't support return predictions.
self._return_predictions: Optional[bool] = None
self._previous_grad_status: Optional[bool] = None

@property
def return_predictions(self) -> bool:
return self._return_predictions

@return_predictions.setter
def return_predictions(self, return_predictions: Optional[bool] = None) -> None:
[Review thread:
  Member: What is the use case for passing None?
  Contributor (author): None is the default; it means the training type plugin sets its own default, i.e. `return_predictions = not training_type_plugin.use_spawn`.]
# ``DDPSpawnPlugin`` plugins and derivate don't support return predictions.
is_ddp_spawn = isinstance(self.trainer.training_type_plugin, DDPSpawnPlugin)
if return_predictions and is_ddp_spawn:
raise MisconfigurationException(
"`return_predictions` should be set to `False` when using the `DDPSpawnPlugin` or children class. "
f"Found {return_predictions} with training_type_plugin {type(self.trainer.training_type_plugin)}."
)
# For non ``DDPSpawnPlugin`` plugin, the `return_predictions` is True by default unless user decide otherwise.
self._return_predictions = not is_ddp_spawn if return_predictions is None else return_predictions

def on_trainer_init(self):
self.trainer.num_predict_batches = []
@@ -54,22 +80,26 @@ def setup(self, model, max_batches, dataloaders):

self.max_batches = max_batches
self.num_dataloaders = self._get_num_dataloaders(dataloaders)
self._predictions = [[] for _ in range(self.num_dataloaders)]
self.predictions = [[] for _ in range(self.num_dataloaders)]
self.epoch_batch_indices = [[] for _ in range(self.num_dataloaders)]

def _get_num_dataloaders(self, dataloaders):
def _get_num_dataloaders(self, dataloaders: List[DataLoader]) -> int:
# case where user does:
# return dl1, dl2
length = len(dataloaders)
if len(dataloaders) > 0 and isinstance(dataloaders[0], (list, tuple)):
length = len(dataloaders[0])
return length

def predict_step(self, batch, batch_idx, dataloader_idx):
def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
# configure args
args = [batch, batch_idx]
if self.num_dataloaders:
args.append(dataloader_idx)

# extract batch_indices and store them
self._store_batch_indices(dataloader_idx)

model_ref = self.trainer.lightning_module

self.trainer.call_hook("on_predict_batch_start", batch, batch_idx, dataloader_idx)
@@ -82,18 +112,44 @@ def predict_step(self, batch, batch_idx, dataloader_idx):

self.trainer.call_hook("on_predict_batch_end", predictions, batch, batch_idx, dataloader_idx)

self._predictions[dataloader_idx].append(predictions)
if self.return_predictions:
self.predictions[dataloader_idx].append(predictions)

def _store_batch_indices(self, dataloader_idx: int) -> None:
batch_sampler = self.trainer.predict_dataloaders[dataloader_idx].batch_sampler
if isinstance(batch_sampler, IndexBatchSamplerWrapper):
self.batch_indices = batch_sampler.batch_indices
if self.return_predictions:
self.epoch_batch_indices[dataloader_idx].append(batch_sampler.batch_indices)

def on_predict_epoch_end(self):
def on_predict_start(self) -> None:
# enable eval mode + no grads
self.on_predict_model_eval()
self.trainer.lightning_module.zero_grad()
self._previous_grad_status = torch.is_grad_enabled()
torch.set_grad_enabled(False)

# hook
self.trainer.call_hook("on_predict_start")
self.trainer.call_hook("on_predict_epoch_start")

def on_predict_epoch_end(self) -> Optional[Union[List[Any], List[List[Any]]]]:
self.trainer.profiler.describe()

results = self._predictions
results: List[List[Any]] = self.predictions

self.trainer.call_hook("on_predict_epoch_end", results)

def _convert_to_numpy(v):
return v.cpu().numpy()
if self.return_predictions:
return results[0] if self.num_dataloaders == 1 else results

def on_predict_end(self):
# clear memory. the predictions are extracted in `on_predict_epoch_end`.
self.predictions = None
self.batch_indices = None

results = apply_to_collection(results, torch.Tensor, _convert_to_numpy)
# reset grad to its previous status.
torch.set_grad_enabled(self._previous_grad_status)

return results[0] if len(results) == 1 else results
# hook
self.trainer.call_hook("on_predict_end")