From 6490996b3974019a254f2750f5c01e6d38ff6a6e Mon Sep 17 00:00:00 2001
From: Wei Ji <23487320+weiji14@users.noreply.github.com>
Date: Wed, 27 Apr 2022 08:05:26 -0400
Subject: [PATCH] Support deterministic="warn" in Trainer for Pytorch 1.11+
 (#12588)

Co-authored-by: carmocca
Co-authored-by: Akihiro Nitta
---
 .github/workflows/code-checks.yml         | 16 ++--------------
 CHANGELOG.md                              |  2 +-
 .../connectors/accelerator_connector.py   | 20 ++++++++++++++++----
 pytorch_lightning/trainer/trainer.py      |  6 ++++--
 .../test_accelerator_connector.py         |  2 +-
 5 files changed, 24 insertions(+), 22 deletions(-)

diff --git a/.github/workflows/code-checks.yml b/.github/workflows/code-checks.yml
index bb019ecb40811..df34d2c47208d 100644
--- a/.github/workflows/code-checks.yml
+++ b/.github/workflows/code-checks.yml
@@ -13,17 +13,10 @@ concurrency:
 jobs:
   mypy:
     runs-on: ubuntu-20.04
-    #strategy:
-    #  fail-fast: false
-    #  matrix:
-    #    include:
-    #      - {python-version: "3.8", pytorch-version: "1.8"}
-    #      - {python-version: "3.9", pytorch-version: "1.10"}
     steps:
     - uses: actions/checkout@master
     - uses: actions/setup-python@v2
       with:
-        # python-version: ${{ matrix.python-version }}
         python-version: 3.9
 
     # Note: This uses an internal pip API and may not always work
@@ -37,15 +30,10 @@ jobs:
           ${{ runner.os }}-pip-
 
     - name: Install dependencies
-      env:
-        # TORCH_VERSION: ${{ matrix.pytorch-version }}
-        TORCH_VERSION: "1.10"
       run: |
-        pip install "torch==$TORCH_VERSION" --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
-        # adjust versions according installed Torch version
+        pip install torch==1.11 --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
         python ./requirements/adjust-versions.py requirements/extra.txt
-        python ./requirements/adjust-versions.py requirements/examples.txt
-        pip install '.[dev]' --upgrade-strategy only-if-needed --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
+        pip install '.[dev]'
         pip list
 
     - name: Type check
diff --git a/CHANGELOG.md b/CHANGELOG.md
index e5d266c77dfd8..0100e69bbcfd6 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -22,7 +22,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Include the `pytorch_lightning` version as a header in the CLI config files ([#12532](https://github.com/PyTorchLightning/pytorch-lightning/pull/12532))
 
--
+- Added support for `Trainer(deterministic="warn")` to warn instead of fail when a non-deterministic operation is encountered ([#12588](https://github.com/PyTorchLightning/pytorch-lightning/pull/12588))
 
 -
 
diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index 826e278e27a73..b352e91c9759a 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -18,6 +18,7 @@
 from typing import Dict, List, Optional, Union
 
 import torch
+from typing_extensions import Literal
 
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.accelerators.cpu import CPUAccelerator
@@ -80,13 +81,21 @@
     rank_zero_warn,
 )
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _HOROVOD_AVAILABLE, _HPU_AVAILABLE, _IPU_AVAILABLE, _TPU_AVAILABLE
+from pytorch_lightning.utilities.imports import (
+    _HOROVOD_AVAILABLE,
+    _HPU_AVAILABLE,
+    _IPU_AVAILABLE,
+    _TORCH_GREATER_EQUAL_1_11,
+    _TPU_AVAILABLE,
+)
 
 log = logging.getLogger(__name__)
 
 if _HOROVOD_AVAILABLE:
     import horovod.torch as hvd
 
+_LITERAL_WARN = Literal["warn"]
+
 
 class AcceleratorConnector:
     def __init__(
@@ -102,7 +111,7 @@ def __init__(
         sync_batchnorm: bool = False,
         benchmark: Optional[bool] = None,
         replace_sampler_ddp: bool = True,
-        deterministic: bool = False,
+        deterministic: Union[bool, _LITERAL_WARN] = False,
         auto_select_gpus: bool = False,
         num_processes: Optional[int] = None,  # deprecated
         tpu_cores: Optional[Union[List[int], str, int]] = None,  # deprecated
@@ -205,9 +214,12 @@ def __init__(
         # 6. Instantiate Strategy - Part 2
         self._lazy_init_strategy()
 
-    def _init_deterministic(self, deterministic: bool) -> None:
+    def _init_deterministic(self, deterministic: Union[bool, _LITERAL_WARN]) -> None:
         self.deterministic = deterministic
-        torch.use_deterministic_algorithms(deterministic)
+        if _TORCH_GREATER_EQUAL_1_11 and deterministic == "warn":
+            torch.use_deterministic_algorithms(True, warn_only=True)
+        else:
+            torch.use_deterministic_algorithms(deterministic)
         if deterministic:
             # fixing non-deterministic part of horovod
             # https://github.com/PyTorchLightning/pytorch-lightning/pull/1572/files#r420279383
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 41e47684fd81c..70e3e3bfc8e66 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -65,7 +65,7 @@
 from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
 from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
 from pytorch_lightning.trainer.configuration_validator import verify_loop_configurations
-from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
+from pytorch_lightning.trainer.connectors.accelerator_connector import _LITERAL_WARN, AcceleratorConnector
 from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector
 from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
 from pytorch_lightning.trainer.connectors.data_connector import DataConnector
@@ -173,7 +173,7 @@ def __init__(
         resume_from_checkpoint: Optional[Union[Path, str]] = None,
         profiler: Optional[Union[BaseProfiler, str]] = None,
         benchmark: Optional[bool] = None,
-        deterministic: bool = False,
+        deterministic: Union[bool, _LITERAL_WARN] = False,
         reload_dataloaders_every_n_epochs: int = 0,
         auto_lr_find: Union[bool, str] = False,
         replace_sampler_ddp: bool = True,
@@ -257,6 +257,8 @@ def __init__(
                 Default: ``False``.
 
             deterministic: If ``True``, sets whether PyTorch operations must use deterministic algorithms.
+                Set to ``"warn"`` to use deterministic algorithms whenever possible, throwing warnings on operations
+                that don't support deterministic mode (requires Pytorch 1.11+).
                 Default: ``False``.
 
             devices: Will be mapped to either `gpus`, `tpu_cores`, `num_processes` or `ipus`,
diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py
index 5f5d3a3877fa7..2c022ca413382 100644
--- a/tests/accelerators/test_accelerator_connector.py
+++ b/tests/accelerators/test_accelerator_connector.py
@@ -699,7 +699,7 @@ def test_parallel_devices_in_strategy_confilict_with_accelerator(parallel_device
     Trainer(strategy=DDPStrategy(parallel_devices=parallel_devices), accelerator=accelerator)
 
 
-@pytest.mark.parametrize("deterministic", [True, False])
+@pytest.mark.parametrize("deterministic", [True, False, pytest.param("warn", marks=RunIf(min_torch="1.11.0"))])
 def test_deterministic_init(deterministic):
     trainer = Trainer(accelerator="auto", deterministic=deterministic)
     assert trainer._accelerator_connector.deterministic == deterministic
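
For reference, a minimal sketch of what the new option does. This is not part
of the patch: it assumes PyTorch 1.11+ and a pytorch-lightning build that
contains this change; the Trainer call mirrors the parametrized test above,
and the torch-level call mirrors what _init_deterministic does for "warn".

    import torch
    from pytorch_lightning import Trainer

    # deterministic=True maps to torch.use_deterministic_algorithms(True):
    # an op with no deterministic implementation raises a RuntimeError.
    # deterministic="warn" maps to
    # torch.use_deterministic_algorithms(True, warn_only=True): the same op
    # emits a UserWarning and runs its non-deterministic kernel instead.
    trainer = Trainer(accelerator="auto", deterministic="warn")

    # Equivalent PyTorch-level state set by _init_deterministic("warn"):
    torch.use_deterministic_algorithms(True, warn_only=True)
    assert torch.are_deterministic_algorithms_enabled()

On PyTorch older than 1.11 the "warn" value falls through to the plain
torch.use_deterministic_algorithms(deterministic) call, which expects a bool;
this is why the test gates the "warn" case with RunIf(min_torch="1.11.0").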