Support deterministic="warn" in Trainer for PyTorch 1.11+ #12588

Merged · 19 commits · Apr 27, 2022
16 changes: 2 additions & 14 deletions .github/workflows/code-checks.yml
@@ -13,17 +13,10 @@ concurrency:
jobs:
mypy:
runs-on: ubuntu-20.04
#strategy:
# fail-fast: false
# matrix:
# include:
# - {python-version: "3.8", pytorch-version: "1.8"}
# - {python-version: "3.9", pytorch-version: "1.10"}
steps:
- uses: actions/checkout@master
- uses: actions/setup-python@v2
with:
# python-version: ${{ matrix.python-version }}
python-version: 3.9

# Note: This uses an internal pip API and may not always work
@@ -37,15 +30,10 @@ jobs:
${{ runner.os }}-pip-

- name: Install dependencies
env:
# TORCH_VERSION: ${{ matrix.pytorch-version }}
TORCH_VERSION: "1.10"
run: |
pip install "torch==$TORCH_VERSION" --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
# adjust versions according to the installed Torch version
pip install torch==1.11 --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
python ./requirements/adjust-versions.py requirements/extra.txt
python ./requirements/adjust-versions.py requirements/examples.txt
pip install '.[dev]' --upgrade-strategy only-if-needed --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
pip install '.[dev]'
pip list

- name: Type check
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -22,7 +22,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Include the `pytorch_lightning` version as a header in the CLI config files ([#12532](https://github.com/PyTorchLightning/pytorch-lightning/pull/12532))


-
- Added support for `Trainer(deterministic="warn")` to warn instead of fail when a non-deterministic operation is encountered ([#12588](https://github.com/PyTorchLightning/pytorch-lightning/pull/12588))


-
20 changes: 16 additions & 4 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -18,6 +18,7 @@
from typing import Dict, List, Optional, Union

import torch
from typing_extensions import Literal

from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.accelerators.cpu import CPUAccelerator
@@ -80,13 +81,21 @@
rank_zero_warn,
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _HOROVOD_AVAILABLE, _HPU_AVAILABLE, _IPU_AVAILABLE, _TPU_AVAILABLE
from pytorch_lightning.utilities.imports import (
_HOROVOD_AVAILABLE,
_HPU_AVAILABLE,
_IPU_AVAILABLE,
_TORCH_GREATER_EQUAL_1_11,
_TPU_AVAILABLE,
)

log = logging.getLogger(__name__)

if _HOROVOD_AVAILABLE:
import horovod.torch as hvd

_LITERAL_WARN = Literal["warn"]


class AcceleratorConnector:
def __init__(
@@ -102,7 +111,7 @@ def __init__(
sync_batchnorm: bool = False,
benchmark: Optional[bool] = None,
replace_sampler_ddp: bool = True,
deterministic: bool = False,
deterministic: Union[bool, _LITERAL_WARN] = False,
auto_select_gpus: bool = False,
num_processes: Optional[int] = None, # deprecated
tpu_cores: Optional[Union[List[int], str, int]] = None, # deprecated
@@ -205,9 +214,12 @@ def __init__(
# 6. Instantiate Strategy - Part 2
self._lazy_init_strategy()

def _init_deterministic(self, deterministic: bool) -> None:
def _init_deterministic(self, deterministic: Union[bool, _LITERAL_WARN]) -> None:
self.deterministic = deterministic
torch.use_deterministic_algorithms(deterministic)
if _TORCH_GREATER_EQUAL_1_11 and deterministic == "warn":
torch.use_deterministic_algorithms(True, warn_only=True)
else:
torch.use_deterministic_algorithms(deterministic)
if deterministic:
# fixing non-deterministic part of horovod
# https://github.com/PyTorchLightning/pytorch-lightning/pull/1572/files#r420279383
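For context, the core of this change is the `warn_only` flag that PyTorch 1.11 added to `torch.use_deterministic_algorithms`. Below is a standalone sketch of the version gate above; it approximates the connector logic rather than reproducing the Lightning source, and the version parsing via `packaging` is an assumption:

```python
from typing import Union

import torch
from packaging.version import Version  # assumed available alongside torch

# Mirrors Lightning's _TORCH_GREATER_EQUAL_1_11 check (assumption: strip any
# local build suffix such as "+cpu" before comparing).
_TORCH_GREATER_EQUAL_1_11 = Version(torch.__version__.split("+")[0]) >= Version("1.11.0")


def init_deterministic(deterministic: Union[bool, str]) -> None:
    if _TORCH_GREATER_EQUAL_1_11 and deterministic == "warn":
        # Emit a UserWarning instead of raising when an op has no
        # deterministic implementation.
        torch.use_deterministic_algorithms(True, warn_only=True)
    else:
        # Older PyTorch has no warn_only, so this sketch coerces "warn"
        # to strict determinism.
        torch.use_deterministic_algorithms(bool(deterministic))
```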
6 changes: 4 additions & 2 deletions pytorch_lightning/trainer/trainer.py
@@ -65,7 +65,7 @@
from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
from pytorch_lightning.trainer.configuration_validator import verify_loop_configurations
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
from pytorch_lightning.trainer.connectors.accelerator_connector import _LITERAL_WARN, AcceleratorConnector
from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
from pytorch_lightning.trainer.connectors.data_connector import DataConnector
@@ -173,7 +173,7 @@ def __init__(
resume_from_checkpoint: Optional[Union[Path, str]] = None,
profiler: Optional[Union[BaseProfiler, str]] = None,
benchmark: Optional[bool] = None,
deterministic: bool = False,
deterministic: Union[bool, _LITERAL_WARN] = False,
reload_dataloaders_every_n_epochs: int = 0,
auto_lr_find: Union[bool, str] = False,
replace_sampler_ddp: bool = True,
@@ -257,6 +257,8 @@ def __init__(
Default: ``False``.

deterministic: If ``True``, sets whether PyTorch operations must use deterministic algorithms.
Set to ``"warn"`` to use deterministic algorithms whenever possible, emitting warnings on operations
that don't support deterministic mode (requires PyTorch 1.11+).
Default: ``False``.

devices: Will be mapped to either `gpus`, `tpu_cores`, `num_processes` or `ipus`,
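Put together, the new value is passed like any other `Trainer` flag. A minimal usage sketch, assuming `pytorch_lightning` at this PR's version and PyTorch 1.11+:

```python
from pytorch_lightning import Trainer

# With "warn", ops that lack a deterministic implementation emit a
# UserWarning instead of raising a RuntimeError, so training continues.
trainer = Trainer(deterministic="warn")
```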
2 changes: 1 addition & 1 deletion tests/accelerators/test_accelerator_connector.py
@@ -699,7 +699,7 @@ def test_parallel_devices_in_strategy_confilict_with_accelerator(parallel_device
Trainer(strategy=DDPStrategy(parallel_devices=parallel_devices), accelerator=accelerator)


@pytest.mark.parametrize("deterministic", [True, False])
@pytest.mark.parametrize("deterministic", [True, False, pytest.param("warn", marks=RunIf(min_torch="1.11.0"))])
def test_deterministic_init(deterministic):
trainer = Trainer(accelerator="auto", deterministic=deterministic)
assert trainer._accelerator_connector.deterministic == deterministic
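`RunIf` is a helper from Lightning's own test suite; with plain pytest, the same version-gated parametrization could be written roughly as follows (a sketch, with the version check via `packaging` as an assumption):

```python
import pytest
import torch
from packaging.version import Version

_TORCH_1_11 = Version(torch.__version__.split("+")[0]) >= Version("1.11.0")


@pytest.mark.parametrize(
    "deterministic",
    [
        True,
        False,
        # Skip the "warn" case where warn_only is unavailable.
        pytest.param("warn", marks=pytest.mark.skipif(not _TORCH_1_11, reason="requires torch>=1.11")),
    ],
)
def test_deterministic_flag(deterministic):
    # Placeholder body; the real test builds a Trainer and checks the flag.
    assert deterministic in (True, False, "warn")
```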