Skip to content

Commit

Permalink
Use base version when comparing torch versions (#16657)
Browse files Browse the repository at this point in the history
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

(cherry picked from commit 63b9034)
  • Loading branch information
awaelchli authored and Borda committed Mar 30, 2023
1 parent 92d1c9f commit 60a04f8
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 6 deletions.
9 changes: 6 additions & 3 deletions src/lightning_fabric/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,11 @@
# 2. The inspection mode via `python -i`: https://stackoverflow.com/a/6879085/1162383
_IS_INTERACTIVE = hasattr(sys, "ps1") or bool(sys.flags.interactive)

_TORCH_GREATER_EQUAL_1_11 = compare_version("torch", operator.ge, "1.11.0")
_TORCH_GREATER_EQUAL_1_12 = compare_version("torch", operator.ge, "1.12.0")
_TORCH_GREATER_EQUAL_1_13 = compare_version("torch", operator.ge, "1.13.0")
# We use "base_version" for non-nightly builds as well, because some environments like NVIDIA's PyTorch dockers
# install PyTorch from source at a commit that doesn't align with the released version tag.
# See: https://github.com/Lightning-AI/lightning/issues/16644
_TORCH_GREATER_EQUAL_1_11 = compare_version("torch", operator.ge, "1.11.0", use_base_version=True)
_TORCH_GREATER_EQUAL_1_12 = compare_version("torch", operator.ge, "1.12.0", use_base_version=True)
_TORCH_GREATER_EQUAL_1_13 = compare_version("torch", operator.ge, "1.13.0", use_base_version=True)
_TORCH_GREATER_EQUAL_2_0 = compare_version("torch", operator.ge, "2.0.0", use_base_version=True)
_TORCH_GREATER_EQUAL_2_1 = compare_version("torch", operator.ge, "2.1.0", use_base_version=True)
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `num_nodes` not being set for `DDPFullyShardedNativeStrategy` ([#17160](https://github.com/Lightning-AI/lightning/pull/17160))


- Fixed an issue with comparing torch versions when using a version of torch built from source ([#16657](https://github.com/Lightning-AI/lightning/pull/16657))


## [1.9.4] - 2023-03-01

### Added
Expand Down
6 changes: 3 additions & 3 deletions src/pytorch_lightning/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@

_PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8)
_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)
_TORCH_LESSER_EQUAL_1_10_2 = compare_version("torch", operator.le, "1.10.2")
_TORCH_LESSER_EQUAL_1_10_2 = compare_version("torch", operator.le, "1.10.2", use_base_version=True)
# duplicated from fabric because HPU is patching it below
_TORCH_GREATER_EQUAL_1_13 = compare_version("torch", operator.ge, "1.13.0")
_TORCH_GREATER_EQUAL_2_0 = compare_version("torch", operator.ge, "2.0.0")
_TORCH_GREATER_EQUAL_1_13 = compare_version("torch", operator.ge, "1.13.0", use_base_version=True)
_TORCH_GREATER_EQUAL_2_0 = compare_version("torch", operator.ge, "2.0.0", use_base_version=True)
_TORCHMETRICS_GREATER_EQUAL_0_9_1 = RequirementCache("torchmetrics>=0.9.1")
_TORCHMETRICS_GREATER_EQUAL_0_11 = compare_version("torchmetrics", operator.ge, "0.11.0") # using new API with task

Expand Down

0 comments on commit 60a04f8

Please sign in to comment.