Support deterministic="warn" in Trainer for PyTorch 1.11+ (#12588)
Co-authored-by: carmocca <carlossmocholi@gmail.com>
Co-authored-by: Akihiro Nitta <nitta@akihironitta.com>
3 people authored Apr 27, 2022
1 parent a414862 commit 6490996
Showing 5 changed files with 24 additions and 22 deletions.
16 changes: 2 additions & 14 deletions .github/workflows/code-checks.yml

@@ -13,17 +13,10 @@ concurrency:
 jobs:
   mypy:
     runs-on: ubuntu-20.04
-    #strategy:
-    #  fail-fast: false
-    #  matrix:
-    #    include:
-    #      - {python-version: "3.8", pytorch-version: "1.8"}
-    #      - {python-version: "3.9", pytorch-version: "1.10"}
     steps:
     - uses: actions/checkout@master
     - uses: actions/setup-python@v2
       with:
-        # python-version: ${{ matrix.python-version }}
         python-version: 3.9

     # Note: This uses an internal pip API and may not always work
@@ -37,15 +30,10 @@ jobs:
           ${{ runner.os }}-pip-
     - name: Install dependencies
-      env:
-        # TORCH_VERSION: ${{ matrix.pytorch-version }}
-        TORCH_VERSION: "1.10"
       run: |
-        pip install "torch==$TORCH_VERSION" --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
-        # adjust versions according installed Torch version
+        pip install torch==1.11 --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
         python ./requirements/adjust-versions.py requirements/extra.txt
         python ./requirements/adjust-versions.py requirements/examples.txt
-        pip install '.[dev]' --upgrade-strategy only-if-needed --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
+        pip install '.[dev]'
         pip list
     - name: Type check
2 changes: 1 addition & 1 deletion CHANGELOG.md

@@ -22,7 +22,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Include the `pytorch_lightning` version as a header in the CLI config files ([#12532](https://github.com/PyTorchLightning/pytorch-lightning/pull/12532))


--
+- Added support for `Trainer(deterministic="warn")` to warn instead of fail when a non-deterministic operation is encountered ([#12588](https://github.com/PyTorchLightning/pytorch-lightning/pull/12588))


 -
20 changes: 16 additions & 4 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py

@@ -18,6 +18,7 @@
 from typing import Dict, List, Optional, Union

 import torch
+from typing_extensions import Literal

 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.accelerators.cpu import CPUAccelerator
@@ -80,13 +81,21 @@
     rank_zero_warn,
 )
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _HOROVOD_AVAILABLE, _HPU_AVAILABLE, _IPU_AVAILABLE, _TPU_AVAILABLE
+from pytorch_lightning.utilities.imports import (
+    _HOROVOD_AVAILABLE,
+    _HPU_AVAILABLE,
+    _IPU_AVAILABLE,
+    _TORCH_GREATER_EQUAL_1_11,
+    _TPU_AVAILABLE,
+)

 log = logging.getLogger(__name__)

 if _HOROVOD_AVAILABLE:
     import horovod.torch as hvd

+_LITERAL_WARN = Literal["warn"]
+

 class AcceleratorConnector:
     def __init__(
@@ -102,7 +111,7 @@ def __init__(
         sync_batchnorm: bool = False,
         benchmark: Optional[bool] = None,
         replace_sampler_ddp: bool = True,
-        deterministic: bool = False,
+        deterministic: Union[bool, _LITERAL_WARN] = False,
         auto_select_gpus: bool = False,
         num_processes: Optional[int] = None,  # deprecated
         tpu_cores: Optional[Union[List[int], str, int]] = None,  # deprecated
@@ -205,9 +214,12 @@ def __init__(
         # 6. Instantiate Strategy - Part 2
         self._lazy_init_strategy()

-    def _init_deterministic(self, deterministic: bool) -> None:
+    def _init_deterministic(self, deterministic: Union[bool, _LITERAL_WARN]) -> None:
         self.deterministic = deterministic
-        torch.use_deterministic_algorithms(deterministic)
+        if _TORCH_GREATER_EQUAL_1_11 and deterministic == "warn":
+            torch.use_deterministic_algorithms(True, warn_only=True)
+        else:
+            torch.use_deterministic_algorithms(deterministic)
         if deterministic:
             # fixing non-deterministic part of horovod
             # https://github.com/PyTorchLightning/pytorch-lightning/pull/1572/files#r420279383
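What the new branch changes at the PyTorch level: under plain deterministic mode, an op with no deterministic implementation raises a RuntimeError, while warn_only=True (added in PyTorch 1.11) downgrades that failure to a UserWarning. A minimal sketch of the difference, assuming a CUDA device is available (torch.kthvalue on CUDA is the non-deterministic op PyTorch's own documentation uses as its example):

    import torch

    # Strict mode: ops lacking a deterministic implementation raise.
    torch.use_deterministic_algorithms(True)
    try:
        torch.randn(10, device="cuda").kthvalue(1)
    except RuntimeError as err:
        print(err)  # "kthvalue CUDA does not have a deterministic implementation..."

    # Warn-only mode (PyTorch 1.11+), which deterministic="warn" selects:
    # the same op runs to completion but emits a UserWarning.
    torch.use_deterministic_algorithms(True, warn_only=True)
    torch.randn(10, device="cuda").kthvalue(1)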
6 changes: 4 additions & 2 deletions pytorch_lightning/trainer/trainer.py

@@ -65,7 +65,7 @@
 from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
 from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
 from pytorch_lightning.trainer.configuration_validator import verify_loop_configurations
-from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
+from pytorch_lightning.trainer.connectors.accelerator_connector import _LITERAL_WARN, AcceleratorConnector
 from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector
 from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
 from pytorch_lightning.trainer.connectors.data_connector import DataConnector
@@ -173,7 +173,7 @@ def __init__(
         resume_from_checkpoint: Optional[Union[Path, str]] = None,
         profiler: Optional[Union[BaseProfiler, str]] = None,
         benchmark: Optional[bool] = None,
-        deterministic: bool = False,
+        deterministic: Union[bool, _LITERAL_WARN] = False,
         reload_dataloaders_every_n_epochs: int = 0,
         auto_lr_find: Union[bool, str] = False,
         replace_sampler_ddp: bool = True,
@@ -257,6 +257,8 @@ def __init__(
                 Default: ``False``.
             deterministic: If ``True``, sets whether PyTorch operations must use deterministic algorithms.
+                Set to ``"warn"`` to use deterministic algorithms whenever possible, throwing warnings on operations
+                that don't support deterministic mode (requires PyTorch 1.11+).
                 Default: ``False``.
             devices: Will be mapped to either `gpus`, `tpu_cores`, `num_processes` or `ipus`,
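Putting the API together, a minimal usage sketch of the new flag (the value is forwarded to the connector logic above; "warn" requires PyTorch 1.11+):

    from pytorch_lightning import Trainer

    # Pre-existing strict behavior: fail on non-deterministic operations.
    trainer = Trainer(deterministic=True)

    # New in this commit: use deterministic algorithms whenever possible,
    # warning instead of failing on ops that don't support them.
    trainer = Trainer(deterministic="warn")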
2 changes: 1 addition & 1 deletion tests/accelerators/test_accelerator_connector.py

@@ -699,7 +699,7 @@ def test_parallel_devices_in_strategy_confilict_with_accelerator(parallel_device
         Trainer(strategy=DDPStrategy(parallel_devices=parallel_devices), accelerator=accelerator)


-@pytest.mark.parametrize("deterministic", [True, False])
+@pytest.mark.parametrize("deterministic", [True, False, pytest.param("warn", marks=RunIf(min_torch="1.11.0"))])
 def test_deterministic_init(deterministic):
     trainer = Trainer(accelerator="auto", deterministic=deterministic)
     assert trainer._accelerator_connector.deterministic == deterministic
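A note on the test change: pytest.param attaches marks to a single parametrized value, so only the "warn" case is skipped on PyTorch older than 1.11, while True and False keep running everywhere. A generic sketch of the same pattern, substituting a plain skipif for Lightning's RunIf helper (the version parsing and test body here are illustrative assumptions):

    import pytest
    import torch
    from packaging.version import Version

    # Skip only this one parametrization on torch < 1.11, analogous to
    # RunIf(min_torch="1.11.0") in the Lightning test suite.
    requires_torch_1_11 = pytest.mark.skipif(
        Version(torch.__version__.split("+")[0]) < Version("1.11.0"),
        reason="deterministic='warn' needs PyTorch 1.11+",
    )

    @pytest.mark.parametrize(
        "deterministic", [True, False, pytest.param("warn", marks=requires_torch_1_11)]
    )
    def test_deterministic_flag(deterministic):
        assert deterministic in (True, False, "warn")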
