Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update script with dynamic torch-ecosystem versions determinations #318

Merged
merged 5 commits into from
Oct 18, 2024
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 128 additions & 38 deletions scripts/adjust-torch-versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,56 +9,141 @@
import sys
from typing import Dict, List, Optional

VERSIONS = [
{"torch": "2.6.0", "torchvision": "0.21.0", "torchtext": "0.18.0", "torchaudio": "2.6.0"}, # nightly
{"torch": "2.5.0", "torchvision": "0.20.0", "torchtext": "0.18.0", "torchaudio": "2.5.0"}, # stable
{"torch": "2.4.1", "torchvision": "0.19.1", "torchtext": "0.18.0", "torchaudio": "2.4.1"},
{"torch": "2.4.0", "torchvision": "0.19.0", "torchtext": "0.18.0", "torchaudio": "2.4.0"},
{"torch": "2.3.1", "torchvision": "0.18.1", "torchtext": "0.18.0", "torchaudio": "2.3.1"},
{"torch": "2.3.0", "torchvision": "0.18.0", "torchtext": "0.18.0", "torchaudio": "2.3.0"},
{"torch": "2.2.2", "torchvision": "0.17.2", "torchtext": "0.17.2", "torchaudio": "2.2.2"},
{"torch": "2.2.1", "torchvision": "0.17.1", "torchtext": "0.17.1", "torchaudio": "2.2.1"},
{"torch": "2.2.0", "torchvision": "0.17.0", "torchtext": "0.17.0", "torchaudio": "2.2.0"},
{"torch": "2.1.2", "torchvision": "0.16.2", "torchtext": "0.16.2", "torchaudio": "2.1.2"},
{"torch": "2.1.1", "torchvision": "0.16.1", "torchtext": "0.16.1", "torchaudio": "2.1.1"},
{"torch": "2.1.0", "torchvision": "0.16.0", "torchtext": "0.16.0", "torchaudio": "2.1.0"},
{"torch": "2.0.1", "torchvision": "0.15.2", "torchtext": "0.15.2", "torchaudio": "2.0.2"},
{"torch": "2.0.0", "torchvision": "0.15.1", "torchtext": "0.15.1", "torchaudio": "2.0.1"},
{"torch": "1.14.0", "torchvision": "0.15.0", "torchtext": "0.15.0", "torchaudio": "0.14.0"}, # nightly / shifted
{"torch": "1.13.1", "torchvision": "0.14.1", "torchtext": "0.14.1", "torchaudio": "0.13.1"},
{"torch": "1.13.0", "torchvision": "0.14.0", "torchtext": "0.14.0", "torchaudio": "0.13.0"},
{"torch": "1.12.1", "torchvision": "0.13.1", "torchtext": "0.13.1", "torchaudio": "0.12.1"},
{"torch": "1.12.0", "torchvision": "0.13.0", "torchtext": "0.13.0", "torchaudio": "0.12.0"},
{"torch": "1.11.0", "torchvision": "0.12.0", "torchtext": "0.12.0", "torchaudio": "0.11.0"},
{"torch": "1.10.2", "torchvision": "0.11.3", "torchtext": "0.11.2", "torchaudio": "0.10.2"},
{"torch": "1.10.1", "torchvision": "0.11.2", "torchtext": "0.11.1", "torchaudio": "0.10.1"},
{"torch": "1.10.0", "torchvision": "0.11.1", "torchtext": "0.11.0", "torchaudio": "0.10.0"},
{"torch": "1.9.1", "torchvision": "0.10.1", "torchtext": "0.10.1", "torchaudio": "0.9.1"},
{"torch": "1.9.0", "torchvision": "0.10.0", "torchtext": "0.10.0", "torchaudio": "0.9.0"},
{"torch": "1.8.2", "torchvision": "0.9.1", "torchtext": "0.9.1", "torchaudio": "0.8.1"},
{"torch": "1.8.1", "torchvision": "0.9.1", "torchtext": "0.9.1", "torchaudio": "0.8.1"},
{"torch": "1.8.0", "torchvision": "0.9.0", "torchtext": "0.9.0", "torchaudio": "0.8.0"},
]

def _determine_torchaudio(torch_version: str) -> str:
"""Determine the torchaudio version based on the torch version.

>>> _determine_torchaudio("1.9.0")
'0.9.0'
>>> _determine_torchaudio("2.4.1")
'2.4.1'
>>> _determine_torchaudio("1.8.2")
'0.9.1'

"""
_version_exceptions = {
"1.8.2": "0.9.1",
}
# drop all except semantic version
torch_ver = re.search(r"([\.\d]+)", torch_version).groups()[0]
if torch_ver in _version_exceptions:
return _version_exceptions[torch_ver]
ver_major, ver_minor, ver_bugfix = map(int, torch_ver.split("."))
ta_ver_array = [ver_major, ver_minor, ver_bugfix]
if ver_major == 1:
ta_ver_array[0] = 0
ta_ver_array[2] = ver_bugfix
return ".".join(map(str, ta_ver_array))


def _determine_torchtext(torch_version: str) -> str:
"""Determine the torchtext version based on the torch version.

>>> _determine_torchtext("1.9.0")
'0.10.0'
>>> _determine_torchtext("2.4.1")
'0.18.0'
>>> _determine_torchtext("1.8.2")
'0.9.1'

"""
_version_exceptions = {
"2.0.1": "0.15.2",
"2.0.0": "0.15.1",
"1.8.2": "0.9.1",
}
# drop all except semantic version
torch_ver = re.search(r"([\.\d]+)", torch_version).groups()[0]
if torch_ver in _version_exceptions:
return _version_exceptions[torch_ver]
ver_major, ver_minor, ver_bugfix = map(int, torch_ver.split("."))
tt_ver_array = [0, 0, 0]
if ver_major == 1:
tt_ver_array[1] = ver_minor + 1
tt_ver_array[2] = ver_bugfix
elif ver_major == 2:
if ver_minor >= 3:
tt_ver_array[1] = 18
else:
tt_ver_array[1] = ver_minor + 15
tt_ver_array[2] = ver_bugfix
else:
raise ValueError(f"Invalid torch version: {torch_version}")
return ".".join(map(str, tt_ver_array))


def _determine_torchvision(torch_version: str) -> str:
"""Determine the torchvision version based on the torch version.

>>> _determine_torchvision("1.9.0")
'0.10.0'
>>> _determine_torchvision("2.4.1")
'0.19.1'
>>> _determine_torchvision("2.0.1")
'0.15.2'

"""
_version_exceptions = {
"2.0.1": "0.15.2",
"2.0.0": "0.15.1",
"1.10.2": "0.11.3",
"1.10.1": "0.11.2",
"1.10.0": "0.11.1",
"1.8.2": "0.9.1",
}
# drop all except semantic version
torch_ver = re.search(r"([\.\d]+)", torch_version).groups()[0]
if torch_ver in _version_exceptions:
return _version_exceptions[torch_ver]
ver_major, ver_minor, ver_bugfix = map(int, torch_ver.split("."))
tv_ver_array = [0, 0, 0]
if ver_major == 1:
tv_ver_array[1] = ver_minor + 1
elif ver_major == 2:
tv_ver_array[1] = ver_minor + 15
else:
raise ValueError(f"Invalid torch version: {torch_version}")
tv_ver_array[2] = ver_bugfix
return ".".join(map(str, tv_ver_array))


def find_latest(ver: str) -> Dict[str, str]:
"""Find the latest version."""
"""Find the latest version.

>>> from pprint import pprint
>>> pprint(find_latest("2.1.0"))
{'torch': '2.1.0',
'torchaudio': '2.1.0',
'torchtext': '0.16.0',
'torchvision': '0.16.0'}

"""
# drop all except semantic version
ver = re.search(r"([\.\d]+)", ver).groups()[0]
# in case there remaining dot at the end - e.g "1.9.0.dev20210504"
ver = ver[:-1] if ver[-1] == "." else ver
logging.debug(f"finding ecosystem versions for: {ver}")

# find first match
for option in VERSIONS:
if option["torch"].startswith(ver):
return option

raise ValueError(f"Missing {ver} in {VERSIONS}")
return {
"torch": ver,
"torchvision": _determine_torchvision(ver),
"torchtext": _determine_torchtext(ver),
"torchaudio": _determine_torchaudio(ver),
}


def adjust(requires: List[str], pytorch_version: Optional[str] = None) -> List[str]:
"""Adjust the versions to be paired within pytorch ecosystem."""
"""Adjust the versions to be paired within pytorch ecosystem.

>>> from pprint import pprint
>>> pprint(adjust(["torch>=1.9.0", "torchvision>=0.10.0", "torchtext>=0.10.0", "torchaudio>=0.9.0"], "2.1.0"))
['torch==2.1.0',
'torchvision==0.16.0',
'torchtext==0.16.0',
'torchaudio==2.1.0']

"""
if not pytorch_version:
import torch

Expand Down Expand Up @@ -88,7 +173,12 @@ def adjust(requires: List[str], pytorch_version: Optional[str] = None) -> List[s


def _offset_print(reqs: List[str], offset: str = "\t|\t") -> str:
"""Adding offset to each line for the printing requirements."""
r"""Adding offset to each line for the printing requirements.

>>> _offset_print(["torch==2.1.0", "torchvision==0.16.0", "torchtext==0.16.0", "torchaudio==2.1.0"])
'\t|\ttorch==2.1.0\n\t|\ttorchvision==0.16.0\n\t|\ttorchtext==0.16.0\n\t|\ttorchaudio==2.1.0'

"""
reqs = [offset + r for r in reqs]
return os.linesep.join(reqs)

Expand Down
Loading