Add warnings to on_before/after_batch_transfer hooks #6059

Merged: 6 commits, Feb 18, 2021
Changes from 1 commit
docs/source/extensions/datamodules.rst (8 changes: 4 additions & 4 deletions)
@@ -314,13 +314,13 @@ Override to alter or apply augmentations to your batch before it is transferred
 .. testcode::

     class MNISTDataModule(LightningDataModule):
-        def on_before_batch_transfer(self, batch):
+        def on_before_batch_transfer(self, batch, dataloader_idx):
             batch['x'] = transforms(batch['x'])
             return batch


 .. warning::
-    The hook signature will change once the dataloader_idx is supported as an argument.
+    Currently dataloader_idx always returns 0 and will be updated to support the true idx in the future.

 .. note:: This hook only runs on single GPU training and DDP (no data-parallel).

@@ -332,13 +332,13 @@ Override to alter or apply augmentations to your batch after it is transferred to the device.
 .. testcode::

     class MNISTDataModule(LightningDataModule):
-        def on_after_batch_transfer(self, batch):
+        def on_after_batch_transfer(self, batch, dataloader_idx):
             batch['x'] = gpu_transforms(batch['x'])
             return batch


 .. warning::
-    The hook signature will change once the dataloader_idx is supported as an argument.
+    Currently dataloader_idx always returns 0 and will be updated to support the true idx in the future.

 .. note::
     This hook only runs on single GPU training and DDP (no data-parallel). This hook
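The two documentation hunks above describe the same change to both hooks: each now receives a dataloader_idx argument in addition to the batch. A minimal sketch of a DataModule overriding both hooks with the new signature (the transform bodies are placeholders, not part of this PR):

    from pytorch_lightning import LightningDataModule

    class MNISTDataModule(LightningDataModule):
        def on_before_batch_transfer(self, batch, dataloader_idx):
            # Runs on CPU, before the batch is moved to the device.
            # dataloader_idx is currently always 0 (see the warning above).
            batch['x'] = batch['x'].float() / 255.0  # stand-in for a CPU-side transform
            return batch

        def on_after_batch_transfer(self, batch, dataloader_idx):
            # Runs after the batch has been moved to the device (e.g. a GPU).
            batch['x'] = batch['x'] - 0.5  # stand-in for a GPU-side transform
            return batch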
pytorch_lightning/core/hooks.py (14 changes: 8 additions & 6 deletions)
@@ -616,24 +616,25 @@ def transfer_batch_to_device(self, batch, device):
         device = device or self.device
         return move_data_to_device(batch, device)

-    def on_before_batch_transfer(self, batch):
+    def on_before_batch_transfer(self, batch, dataloader_idx):
         """
         Override to alter or apply batch augmentations to your batch before it is transferred to the device.

-        .. warning:: The hook signature will change once the dataloader_idx is supported as an argument.
+        .. warning:: dataloader_idx always returns 0, and will be updated to support the true idx in the future.

         Note:
             This hook only runs on single GPU training and DDP (no data-parallel).

         Args:
             batch: A batch of data that needs to be altered or augmented.
+            dataloader_idx: DataLoader idx for batch (Default: 0)

         Returns:
             A batch of data

         Example::

-            def on_before_batch_transfer(self, batch):
+            def on_before_batch_transfer(self, batch, dataloader_idx):
                 batch['x'] = transforms(batch['x'])
                 return batch

@@ -643,24 +644,25 @@ def on_before_batch_transfer(self, batch):
         """
         return batch

-    def on_after_batch_transfer(self, batch):
+    def on_after_batch_transfer(self, batch, dataloader_idx):
         """
         Override to alter or apply batch augmentations to your batch after it is transferred to the device.

-        .. warning:: The hook signature will change once the dataloader_idx is supported as an argument.
+        .. warning:: dataloader_idx always returns 0, and will be updated to support the true idx in the future.

         Note:
             This hook only runs on single GPU training and DDP (no data-parallel).

         Args:
             batch: A batch of data that needs to be altered or augmented.
+            dataloader_idx: DataLoader idx for batch (Default: 0)

         Returns:
             A batch of data

         Example::

-            def on_after_batch_transfer(self, batch):
+            def on_after_batch_transfer(self, batch, dataloader_idx):
                 batch['x'] = gpu_transforms(batch['x'])
                 return batch

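Since both hooks gained a positional parameter, user code that overrode the old one-argument form needs to add dataloader_idx as well. A minimal sketch of the updated overrides on a LightningModule, assuming a (x, y) tuple batch (the normalization is purely illustrative):

    from pytorch_lightning import LightningModule

    class LitModel(LightningModule):
        def on_before_batch_transfer(self, batch, dataloader_idx):
            # dataloader_idx is currently always 0 (see the warning above).
            x, y = batch
            return x.float(), y

        def on_after_batch_transfer(self, batch, dataloader_idx):
            # The batch is already on the device here; normalize in place of a GPU transform.
            x, y = batch
            return (x - x.mean()) / (x.std() + 1e-6), y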
pytorch_lightning/core/lightning.py (6 changes: 3 additions & 3 deletions)
@@ -179,10 +179,10 @@ def logger(self):
         """ Reference to the logger object in the Trainer. """
         return self.trainer.logger if self.trainer else None

-    def _apply_batch_transfer_handler(self, batch: Any, device: Optional[torch.device] = None):
-        batch = self.on_before_batch_transfer(batch)
+    def _apply_batch_transfer_handler(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0):
+        batch = self.on_before_batch_transfer(batch, dataloader_idx)
         batch = self.transfer_batch_to_device(batch, device)
-        batch = self.on_after_batch_transfer(batch)
+        batch = self.on_after_batch_transfer(batch, dataloader_idx)
         return batch

     def print(self, *args, **kwargs) -> None:
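The handler above fixes the per-batch call order: CPU-side hook, device transfer, device-side hook, with the same dataloader_idx passed to both hooks. A standalone sketch of that sequence, with the default transfer step written out as move_data_to_device (apply_batch_transfer is a hypothetical helper for illustration, not a Lightning API):

    from pytorch_lightning.utilities import move_data_to_device

    def apply_batch_transfer(module, batch, device, dataloader_idx=0):
        # Mirrors LightningModule._apply_batch_transfer_handler:
        # CPU-side hook -> move to device -> device-side hook.
        batch = module.on_before_batch_transfer(batch, dataloader_idx)
        batch = move_data_to_device(batch, device)  # default transfer_batch_to_device behaviour
        batch = module.on_after_batch_transfer(batch, dataloader_idx)
        return batch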
pytorch_lightning/plugins/training_type/deepspeed.py (8 changes: 7 additions & 1 deletion)
@@ -305,7 +305,13 @@ def _format_precision_config(self):
         precision = self.lightning_module.trainer.accelerator_connector.precision
         if precision == 16:
             if "amp" not in self.config and amp_type == AMPType.NATIVE:
-                self.config["fp16"] = {"enabled": True}
+                self.config["fp16"] = {
+                    "enabled": True,
+                    "loss_scale": 0,
+                    "loss_scale_window": 1000,
+                    "hysteresis": 2,
+                    "min_loss_scale": 1
+                }
elif "apex" not in self.config and amp_type == AMPType.APEX:
self.config["amp"] = {
"enabled": True,
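The new default turns on DeepSpeed's dynamic loss scaling for native AMP: loss_scale set to 0 selects dynamic scaling, loss_scale_window and hysteresis govern how the scale is adjusted, and min_loss_scale bounds it from below. Because these defaults are only written when the user config does not already provide the section, different values can be supplied explicitly. A sketch of doing so, assuming DeepSpeedPlugin accepts a config dict and can be passed via Trainer(plugins=...) in this version (check the installed release for the exact argument):

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import DeepSpeedPlugin

    # Same fp16 block as the new default above; adjust the values as needed.
    ds_config = {
        "fp16": {
            "enabled": True,
            "loss_scale": 0,  # 0 selects dynamic loss scaling
            "loss_scale_window": 1000,
            "hysteresis": 2,
            "min_loss_scale": 1,
        },
    }

    trainer = Trainer(gpus=1, precision=16, plugins=[DeepSpeedPlugin(config=ds_config)])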
tests/core/test_datamodules.py (4 changes: 2 additions & 2 deletions)
@@ -438,13 +438,13 @@ class CurrentTestDM(LightningDataModule):
         on_before_batch_transfer_hook_rank = None
         on_after_batch_transfer_hook_rank = None

-        def on_before_batch_transfer(self, batch):
+        def on_before_batch_transfer(self, batch, dataloader_idx):
             self.on_before_batch_transfer_hook_rank = self.rank
             self.rank += 1
             batch.samples += 1
             return batch

-        def on_after_batch_transfer(self, batch):
+        def on_after_batch_transfer(self, batch, dataloader_idx):
             assert batch.samples.device == batch.targets.device == expected_device
             self.on_after_batch_transfer_hook_rank = self.rank
             self.rank += 1
tests/models/test_hooks.py (4 changes: 2 additions & 2 deletions)
@@ -160,13 +160,13 @@ class CurrentTestModel(BoringModel):
         on_before_batch_transfer_hook_rank = None
         on_after_batch_transfer_hook_rank = None

-        def on_before_batch_transfer(self, batch):
+        def on_before_batch_transfer(self, batch, dataloader_idx):
             self.on_before_batch_transfer_hook_rank = self.rank
             self.rank += 1
             batch.samples += 1
             return batch

-        def on_after_batch_transfer(self, batch):
+        def on_after_batch_transfer(self, batch, dataloader_idx):
             assert batch.samples.device == batch.targets.device == expected_device
             self.on_after_batch_transfer_hook_rank = self.rank
             self.rank += 1