Add testing for PyTorch 2.4 (Trainer) #20010

Merged
merged 2 commits into from
Jul 11, 2024
10 changes: 8 additions & 2 deletions .azure/gpu-tests-pytorch.yml
@@ -55,6 +55,9 @@ jobs:
"Lightning | latest":
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.11-torch2.3-cuda12.1.0"
PACKAGE_NAME: "lightning"
"Lightning | future":
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.11-torch2.4-cuda12.1.0"
PACKAGE_NAME: "lightning"
pool: lit-rtx-3090
variables:
DEVICES: $( python -c 'print("$(Agent.Name)".split("_")[-1])' )
@@ -76,9 +79,12 @@ jobs:
echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/cu${cuda_ver}/torch_stable.html"
scope=$(python -c 'n = "$(PACKAGE_NAME)" ; print(dict(pytorch="pytorch_lightning").get(n, n))')
echo "##vso[task.setvariable variable=COVERAGE_SOURCE]$scope"
python_ver=$(python -c "import sys; print(f'{sys.version_info.major}{sys.version_info.minor}')")
echo "##vso[task.setvariable variable=PYTHON_VERSION_MM]$python_ver"
displayName: "set env. vars"
- bash: |
echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM}/torch_test.html"
echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM}"
echo "##vso[task.setvariable variable=TORCHVISION_URL]https://download.pytorch.org/whl/test/cu124/torchvision-0.19.0%2Bcu124-cp${PYTHON_VERSION_MM}-cp${PYTHON_VERSION_MM}-linux_x86_64.whl"
condition: endsWith(variables['Agent.JobName'], 'future')
displayName: "set env. vars 4 future"

@@ -107,7 +113,7 @@ jobs:

- bash: |
extra=$(python -c "print({'lightning': 'pytorch-'}.get('$(PACKAGE_NAME)', ''))")
pip install -e ".[${extra}dev]" pytest-timeout -U --find-links="${TORCH_URL}"
pip install -e ".[${extra}dev]" pytest-timeout -U --find-links="${TORCH_URL}" --find-links="${TORCHVISION_URL}"
displayName: "Install package & dependencies"

- bash: pip uninstall -y lightning
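The "future" job above derives a compact Python tag (for example `311`) and splices it into a pinned torchvision wheel URL. A minimal sketch of that string assembly, written in Python rather than bash; the function names `python_version_mm` and `torchvision_wheel_url` are hypothetical, and only the URL template is copied from the step above:

```python
import sys
from typing import Optional, Tuple


def python_version_mm(version_info: Optional[Tuple[int, ...]] = None) -> str:
    """Join major and minor without a dot, e.g. (3, 11) -> '311'."""
    major, minor = (version_info or sys.version_info)[:2]
    return f"{major}{minor}"


def torchvision_wheel_url(py_mm: str) -> str:
    """Hypothetical helper: fill the wheel URL template used by the 'future' job."""
    return (
        "https://download.pytorch.org/whl/test/cu124/"
        f"torchvision-0.19.0%2Bcu124-cp{py_mm}-cp{py_mm}-linux_x86_64.whl"
    )


if __name__ == "__main__":
    # Prints the URL for the interpreter running this script.
    print(torchvision_wheel_url(python_version_mm()))
```

The tag appears twice in the wheel name (`cp311-cp311`), which is why the pipeline stores it once in `PYTHON_VERSION_MM` and reuses it.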
4 changes: 4 additions & 0 deletions .github/checkgroup.yml
@@ -142,11 +142,15 @@ subprojects:
- "build-cuda (3.10, 2.2, 12.1.0)"
- "build-cuda (3.11, 2.1, 12.1.0)"
- "build-cuda (3.11, 2.2, 12.1.0)"
- "build-cuda (3.11, 2.3, 12.1.0)"
- "build-cuda (3.11, 2.4, 12.1.0)"
#- "build-NGC"
- "build-pl (3.10, 2.1, 12.1.0)"
- "build-pl (3.10, 2.2, 12.1.0)"
- "build-pl (3.11, 2.1, 12.1.0)"
- "build-pl (3.11, 2.2, 12.1.0)"
- "build-pl (3.11, 2.3, 12.1.0)"
- "build-pl (3.11, 2.4, 12.1.0)"

# SECTION: lightning_fabric

9 changes: 7 additions & 2 deletions .github/workflows/ci-tests-pytorch.yml
@@ -53,6 +53,9 @@ jobs:
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.3" }
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.3" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.3" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.4" }
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.4" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.4" }
# only run PyTorch latest with Python latest, use PyTorch scope to limit dependency issues
- { os: "macOS-12", pkg-name: "pytorch", python-version: "3.11", pytorch-version: "2.1" }
- { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.11", pytorch-version: "2.1" }
@@ -82,7 +85,7 @@ jobs:
PACKAGE_NAME: ${{ matrix.pkg-name }}
TORCH_URL: "https://download.pytorch.org/whl/cpu/torch_stable.html"
TORCH_URL_STABLE: "https://download.pytorch.org/whl/cpu/torch_stable.html"
TORCH_URL_TEST: "https://download.pytorch.org/whl/test/cpu/torch_test.html"
TORCH_URL_TEST: "https://download.pytorch.org/whl/test/cpu/torch"
FREEZE_REQUIREMENTS: ${{ ! (github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/heads/release/')) }}
PYPI_CACHE_DIR: "_pip-wheels"
# TODO: Remove this - Enable running MPS tests on this platform
@@ -124,11 +127,13 @@ jobs:
- name: Env. variables
run: |
# Switch PyTorch URL
python -c "print('TORCH_URL=' + str('${{env.TORCH_URL_TEST}}' if '${{ matrix.pytorch-version }}' == '2.3' else '${{env.TORCH_URL_STABLE}}'))" >> $GITHUB_ENV
python -c "print('TORCH_URL=' + str('${{env.TORCH_URL_TEST}}' if '${{ matrix.pytorch-version }}' == '2.4' else '${{env.TORCH_URL_STABLE}}'))" >> $GITHUB_ENV
# Switch coverage scope
python -c "print('COVERAGE_SCOPE=' + str('lightning' if '${{matrix.pkg-name}}' == 'lightning' else 'pytorch_lightning'))" >> $GITHUB_ENV
# if you install mono-package set dependency only for this subpackage
python -c "print('EXTRA_PREFIX=' + str('' if '${{matrix.pkg-name}}' != 'lightning' else 'pytorch-'))" >> $GITHUB_ENV
# Avoid issue on Windows with PyTorch 2.4: "RuntimeError: use_libuv was requested but PyTorch was build without libuv support"
python -c "print('USE_LIBUV=0' if '${{matrix.os}}' == 'windows-2022' and '${{matrix.pytorch-version}}' == '2.4' else '')" >> $GITHUB_ENV

- name: Install package & dependencies
timeout-minutes: 20
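The "Env. variables" step above packs several decisions into one-line `python -c` calls. A minimal sketch of the same decisions as a plain function, assuming the PyTorch 2.4 variant of the URL switch; the name `env_lines` is hypothetical, and the two URLs are copied from the workflow's `env` block:

```python
from typing import List

TORCH_URL_STABLE = "https://download.pytorch.org/whl/cpu/torch_stable.html"
TORCH_URL_TEST = "https://download.pytorch.org/whl/test/cpu/torch"


def env_lines(os_name: str, pytorch_version: str, pkg_name: str) -> List[str]:
    """Return the KEY=VALUE lines the step would append to $GITHUB_ENV."""
    lines = [
        # The pre-release index is used only for the version under test.
        "TORCH_URL=" + (TORCH_URL_TEST if pytorch_version == "2.4" else TORCH_URL_STABLE),
        "COVERAGE_SCOPE=" + ("lightning" if pkg_name == "lightning" else "pytorch_lightning"),
        "EXTRA_PREFIX=" + ("pytorch-" if pkg_name == "lightning" else ""),
    ]
    # Work around the libuv error on Windows with PyTorch 2.4.
    if os_name == "windows-2022" and pytorch_version == "2.4":
        lines.append("USE_LIBUV=0")
    return lines


if __name__ == "__main__":
    print("\n".join(env_lines("windows-2022", "2.4", "lightning")))
```

Unlike the workflow, which prints an empty line when the libuv condition is false, this sketch simply omits the entry.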
5 changes: 4 additions & 1 deletion .github/workflows/docker-build.yml
@@ -47,6 +47,8 @@ jobs:
- { python_version: "3.10", pytorch_version: "2.2", cuda_version: "12.1.0" }
- { python_version: "3.11", pytorch_version: "2.1", cuda_version: "12.1.0" }
- { python_version: "3.11", pytorch_version: "2.2", cuda_version: "12.1.0" }
- { python_version: "3.11", pytorch_version: "2.3", cuda_version: "12.1.0" }
- { python_version: "3.11", pytorch_version: "2.4", cuda_version: "12.1.0" }
steps:
- uses: actions/checkout@v4
with:
@@ -74,7 +76,7 @@ jobs:
tags = [f"latest-py{py_ver}-torch{pt_ver}-cuda{cuda_ver}"]
if ver:
tags += [f"{ver}-py{py_ver}-torch{pt_ver}-cuda{cuda_ver}"]
if py_ver == '3.10' and pt_ver == '2.1' and cuda_ver == '12.1.0':
if py_ver == '3.11' and pt_ver == '2.3' and cuda_ver == '12.1.0':
tags += ["latest"]

tags = [f"{repo}:{tag}" for tag in tags]
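The tagging logic in this hunk can be read as a small pure function. The sketch below assumes the two `if` statements are siblings (indentation is not visible here) and uses the new `3.11` / `2.3` / `12.1.0` condition for the bare `latest` tag; the function name `docker_tags` and the repository name in the usage line are hypothetical:

```python
from typing import List, Optional


def docker_tags(
    repo: str, py_ver: str, pt_ver: str, cuda_ver: str, ver: Optional[str] = None
) -> List[str]:
    """Build the image tags for one matrix entry."""
    tags = [f"latest-py{py_ver}-torch{pt_ver}-cuda{cuda_ver}"]
    if ver:
        # A release version adds a second, version-pinned tag.
        tags += [f"{ver}-py{py_ver}-torch{pt_ver}-cuda{cuda_ver}"]
    if py_ver == "3.11" and pt_ver == "2.3" and cuda_ver == "12.1.0":
        # Exactly one matrix entry also receives the bare "latest" tag.
        tags += ["latest"]
    return [f"{repo}:{tag}" for tag in tags]


if __name__ == "__main__":
    print(docker_tags("example/repo", "3.11", "2.3", "12.1.0"))
```

Under this reading, the newly added PyTorch 2.4 image is built and tagged but does not become `latest`.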
@@ -108,6 +110,7 @@ jobs:
- { python_version: "3.11", pytorch_version: "2.1", cuda_version: "12.1.0" }
- { python_version: "3.11", pytorch_version: "2.2", cuda_version: "12.1.0" }
- { python_version: "3.11", pytorch_version: "2.3", cuda_version: "12.1.0" }
- { python_version: "3.11", pytorch_version: "2.4", cuda_version: "12.1.0" }
# - { python_version: "3.12", pytorch_version: "2.2", cuda_version: "12.1.0" } # todo: pending on `onnxruntime`
steps:
- uses: actions/checkout@v4
5 changes: 3 additions & 2 deletions dockers/base-cuda/Dockerfile
@@ -20,7 +20,7 @@ FROM nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu${UBUNTU_VERSION}

ARG PYTHON_VERSION=3.10
ARG PYTORCH_VERSION=2.1
ARG MAX_ALLOWED_NCCL=2.17.1
ARG MAX_ALLOWED_NCCL=2.22.3

SHELL ["/bin/bash", "-c"]
# https://techoverflow.net/2019/05/18/how-to-fix-configuring-tzdata-interactive-input-when-building-docker-images/
@@ -92,7 +92,8 @@ RUN \
-r requirements/pytorch/test.txt \
-r requirements/pytorch/strategies.txt \
--find-links="https://download.pytorch.org/whl/cu${CUDA_VERSION_MM//'.'/''}/torch_stable.html" \
--find-links="https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM//'.'/''}/torch_test.html"
--find-links="https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM//'.'/''}/torch" \
--find-links="https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM//'.'/''}/pytorch-triton"

RUN \
# Show what we have
12 changes: 12 additions & 0 deletions docs/source-pytorch/versioning.rst
@@ -79,6 +79,18 @@ The table below indicates the coverage of tested versions in our CI. Versions ou
- ``torch``
- ``torchmetrics``
- Python
* - 2.4
- 2.4
- 2.4
- ≥2.1, ≤2.4
- ≥0.7.0
- ≥3.9, ≤3.12
* - 2.3
- 2.3
- 2.3
- ≥2.0, ≤2.3
- ≥0.7.0
- ≥3.8, ≤3.11
* - 2.2
- 2.2
- 2.2
2 changes: 1 addition & 1 deletion requirements/pytorch/base.txt
@@ -2,7 +2,7 @@
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

numpy >=1.21.0, <1.27.0
torch >=2.1.0, <2.4.0
torch >=2.1.0, <2.5.0
tqdm >=4.57.0, <4.67.0
PyYAML >=5.4, <6.1.0
fsspec[http] >=2022.5.0, <2024.4.0
2 changes: 1 addition & 1 deletion requirements/pytorch/examples.txt
@@ -2,7 +2,7 @@
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

requests <2.32.0
torchvision >=0.16.0, <0.19.0
torchvision >=0.16.0, <0.20.0
ipython[all] <8.15.0
torchmetrics >=0.10.0, <1.3.0
lightning-utilities >=0.8.0, <0.12.0
9 changes: 8 additions & 1 deletion src/lightning/fabric/utilities/cloud_io.py
@@ -33,6 +33,7 @@
def _load(
path_or_url: Union[IO, _PATH],
map_location: _MAP_LOCATION_TYPE = None,
weights_only: bool = False,
) -> Any:
"""Loads a checkpoint.

@@ -46,15 +47,21 @@ def _load(
return torch.load(
path_or_url,
map_location=map_location, # type: ignore[arg-type] # upstream annotation is not correct
weights_only=weights_only,
)
if str(path_or_url).startswith("http"):
return torch.hub.load_state_dict_from_url(
str(path_or_url),
map_location=map_location, # type: ignore[arg-type]
weights_only=weights_only,
)
fs = get_filesystem(path_or_url)
with fs.open(path_or_url, "rb") as f:
return torch.load(f, map_location=map_location) # type: ignore[arg-type]
return torch.load(
f,
map_location=map_location, # type: ignore[arg-type]
weights_only=weights_only,
)


def get_filesystem(path: _PATH, **kwargs: Any) -> AbstractFileSystem:
5 changes: 3 additions & 2 deletions src/lightning/pytorch/plugins/precision/amp.py
@@ -19,6 +19,7 @@

import lightning.pytorch as pl
from lightning.fabric.plugins.precision.amp import _optimizer_handles_unscaling
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.fabric.utilities.types import Optimizable
from lightning.pytorch.plugins.precision.precision import Precision
from lightning.pytorch.utilities import GradClipAlgorithmType
@@ -39,7 +40,7 @@ def __init__(
self,
precision: Literal["16-mixed", "bf16-mixed"],
device: str,
scaler: Optional[torch.cuda.amp.GradScaler] = None,
scaler: Optional["torch.cuda.amp.GradScaler"] = None,
) -> None:
if precision not in ("16-mixed", "bf16-mixed"):
raise ValueError(
@@ -49,7 +50,7 @@ def __init__(

self.precision = precision
if scaler is None and self.precision == "16-mixed":
scaler = torch.cuda.amp.GradScaler()
scaler = torch.amp.GradScaler(device=device) if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler()
if scaler is not None and self.precision == "bf16-mixed":
raise MisconfigurationException(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.")
self.device = device
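The scaler choice in this hunk hinges on a version flag, `_TORCH_GREATER_EQUAL_2_4`. How Lightning computes that flag is not shown in this diff; the sketch below is only an illustration of the gating idea using the standard library, with hypothetical helpers `release_tuple`, `at_least`, and `grad_scaler_path`. It treats pre-release and local suffixes (such as `rc1` or `+cu121`) loosely by ignoring them:

```python
import re
from typing import Tuple


def release_tuple(version: str) -> Tuple[int, ...]:
    """Parse the leading numeric release, ignoring suffixes like '+cu121'."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"no numeric release found in {version!r}")
    return tuple(int(part) for part in match.group().split("."))


def at_least(installed: str, minimum: str) -> bool:
    """Compare numerically, so that '2.10' counts as newer than '2.4'."""
    a, b = release_tuple(installed), release_tuple(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def grad_scaler_path(torch_version: str) -> str:
    """Name the scaler class the hunk above would pick for a given version."""
    if at_least(torch_version, "2.4"):
        return "torch.amp.GradScaler"
    return "torch.cuda.amp.GradScaler"


if __name__ == "__main__":
    print(grad_scaler_path("2.4.0+cu121"))
```

Comparing integer tuples rather than strings matters here: a plain string comparison would rank `"2.10"` below `"2.4"`.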
25 changes: 21 additions & 4 deletions src/lightning/pytorch/profilers/pytorch.py
@@ -28,6 +28,7 @@
from typing_extensions import override

from lightning.fabric.accelerators.cuda import is_cuda_available
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.pytorch.profilers.profiler import Profiler
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_warn
@@ -295,7 +296,7 @@ def __init__(
self._emit_nvtx = emit_nvtx
self._export_to_chrome = export_to_chrome
self._row_limit = row_limit
self._sort_by_key = sort_by_key or f"{'cuda' if profiler_kwargs.get('use_cuda', False) else 'cpu'}_time_total"
self._sort_by_key = sort_by_key or _default_sort_by_key(profiler_kwargs)
self._record_module_names = record_module_names
self._profiler_kwargs = profiler_kwargs
self._table_kwargs = table_kwargs if table_kwargs is not None else {}
@@ -403,10 +404,16 @@ def _default_activities(self) -> List["ProfilerActivity"]:
activities: List[ProfilerActivity] = []
if not _KINETO_AVAILABLE:
return activities
if self._profiler_kwargs.get("use_cpu", True):
if _TORCH_GREATER_EQUAL_2_4:
activities.append(ProfilerActivity.CPU)
if self._profiler_kwargs.get("use_cuda", is_cuda_available()):
activities.append(ProfilerActivity.CUDA)
if is_cuda_available():
activities.append(ProfilerActivity.CUDA)
else:
# `use_cpu` and `use_cuda` are deprecated in PyTorch >= 2.4
if self._profiler_kwargs.get("use_cpu", True):
activities.append(ProfilerActivity.CPU)
if self._profiler_kwargs.get("use_cuda", is_cuda_available()):
activities.append(ProfilerActivity.CUDA)
return activities

@override
@@ -565,3 +572,13 @@ def teardown(self, stage: Optional[str]) -> None:
self._recording_map = {}

super().teardown(stage=stage)


def _default_sort_by_key(profiler_kwargs: dict) -> str:
activities = profiler_kwargs.get("activities", [])
is_cuda = (
profiler_kwargs.get("use_cuda", False) # `use_cuda` is deprecated in PyTorch >= 2.4
or (activities and ProfilerActivity.CUDA in activities)
or (not activities and is_cuda_available())
)
return f"{'cuda' if is_cuda else 'cpu'}_time_total"
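The precedence in `_default_sort_by_key` is easier to check in isolation. The sketch below mirrors the three-way condition without importing PyTorch: `Activity` is a stand-in for `torch.profiler.ProfilerActivity`, and CUDA availability is passed in as a parameter instead of calling `is_cuda_available()`. Both substitutions are assumptions made so the example runs anywhere:

```python
from enum import Enum
from typing import Any, Dict


class Activity(Enum):
    """Stand-in for torch.profiler.ProfilerActivity (assumption)."""

    CPU = "cpu"
    CUDA = "cuda"


def default_sort_by_key(profiler_kwargs: Dict[str, Any], cuda_available: bool) -> str:
    """Pick the default sort column from the profiler keyword arguments."""
    activities = profiler_kwargs.get("activities", [])
    is_cuda = bool(
        profiler_kwargs.get("use_cuda", False)  # legacy flag
        or (activities and Activity.CUDA in activities)  # explicit request
        or (not activities and cuda_available)  # nothing requested: follow the hardware
    )
    return f"{'cuda' if is_cuda else 'cpu'}_time_total"


if __name__ == "__main__":
    print(default_sort_by_key({"activities": [Activity.CPU]}, cuda_available=True))
```

Note the case exercised in the usage line: an explicit CPU-only `activities` list yields `cpu_time_total` even on a machine with CUDA, because the hardware fallback applies only when no activities are given.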
Original file line number Diff line number Diff line change
@@ -254,7 +254,7 @@ def update_main_process_results(self, trainer: "pl.Trainer", extra: Dict[str, An
"""
# NOTE: `get_extra_results` needs to be called before
callback_metrics_bytes = extra["callback_metrics_bytes"]
callback_metrics = torch.load(io.BytesIO(callback_metrics_bytes))
callback_metrics = torch.load(io.BytesIO(callback_metrics_bytes), weights_only=True)
trainer.callback_metrics.update(callback_metrics)

@override
10 changes: 5 additions & 5 deletions src/lightning/pytorch/strategies/model_parallel.py
@@ -38,7 +38,7 @@
_sync_ddp_if_available,
)
from lightning.fabric.utilities.distributed import group as _group
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.fabric.utilities.init import _materialize_distributed_module
from lightning.fabric.utilities.load import _METADATA_FILENAME
from lightning.fabric.utilities.optimizer import _optimizers_to_device
@@ -64,7 +64,7 @@ class ModelParallelStrategy(ParallelStrategy):
Currently supports up to 2D parallelism. Specifically, it supports the combination of
Fully Sharded Data-Parallel 2 (FSDP2) with Tensor Parallelism (DTensor). These PyTorch APIs are currently still
experimental in PyTorch (see https://pytorch.org/docs/stable/distributed.tensor.parallel.html).
Requires PyTorch 2.3 or newer.
Requires PyTorch 2.4 or newer.

Arguments:
data_parallel_size: The number of devices within a data-parallel group. Defaults to ``"auto"``, which
@@ -86,8 +86,8 @@ def __init__(
timeout: Optional[timedelta] = default_pg_timeout,
) -> None:
super().__init__()
if not _TORCH_GREATER_EQUAL_2_3:
raise ImportError(f"{type(self).__name__} requires PyTorch 2.3 or higher.")
if not _TORCH_GREATER_EQUAL_2_4:
raise ImportError(f"{type(self).__name__} requires PyTorch 2.4 or higher.")
self._data_parallel_size = data_parallel_size
self._tensor_parallel_size = tensor_parallel_size
self._save_distributed_checkpoint = save_distributed_checkpoint
@@ -170,7 +170,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
if any(isinstance(mod, FullyShardedDataParallel) for mod in self.model.modules()):
raise TypeError(
"Found modules that are wrapped with `torch.distributed.fsdp.FullyShardedDataParallel`."
f" The `{self.__class__.__name__}` only supports the new FSDP2 APIs in PyTorch >= 2.3."
f" The `{self.__class__.__name__}` only supports the new FSDP2 APIs in PyTorch >= 2.4."
)

_materialize_distributed_module(self.model, self.root_device)
1 change: 1 addition & 0 deletions tests/tests_pytorch/conftest.py
@@ -156,6 +156,7 @@ def thread_police_duuu_daaa_duuu_daaa():
elif (
thread.name == "QueueFeederThread" # tensorboardX
or thread.name == "QueueManagerThread" # torch.compile
or "(_read_thread)" in thread.name # torch.compile
):
thread.join(timeout=20)
elif isinstance(thread, TMonitor):
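The fixture above tolerates a short list of background threads by name and joins them with a timeout. A minimal standalone sketch of that matching rule, using only `threading`; the helper names `is_known_background_thread` and `join_known_threads` are hypothetical, and the real fixture handles further cases that are not reproduced here:

```python
import threading
from typing import List

KNOWN_EXACT_NAMES = ("QueueFeederThread", "QueueManagerThread")


def is_known_background_thread(name: str) -> bool:
    """Match exact names, or any name containing the '(_read_thread)' marker."""
    return name in KNOWN_EXACT_NAMES or "(_read_thread)" in name


def join_known_threads(timeout: float = 20.0) -> List[str]:
    """Join every live known background thread and return the names joined."""
    joined = []
    for thread in threading.enumerate():
        if thread is threading.current_thread():
            continue
        if is_known_background_thread(thread.name):
            thread.join(timeout=timeout)
            joined.append(thread.name)
    return joined


if __name__ == "__main__":
    stop = threading.Event()
    worker = threading.Thread(target=stop.wait, args=(0.2,), name="Thread-1 (_read_thread)")
    worker.start()
    print(join_known_threads(timeout=5))
```

The new check is a substring match rather than an equality test, presumably because the thread name carries a varying prefix in front of `(_read_thread)`.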
3 changes: 3 additions & 0 deletions tests/tests_pytorch/models/test_torchscript.py
@@ -19,13 +19,15 @@
import torch
from fsspec.implementations.local import LocalFileSystem
from lightning.fabric.utilities.cloud_io import get_filesystem
from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_4
from lightning.pytorch.core.module import LightningModule
from lightning.pytorch.demos.boring_classes import BoringModel

from tests_pytorch.helpers.advanced_models import BasicGAN, ParityModuleRNN
from tests_pytorch.helpers.runif import RunIf


@pytest.mark.skipif(_IS_WINDOWS and _TORCH_GREATER_EQUAL_2_4, reason="not close on Windows + PyTorch 2.4")
@pytest.mark.parametrize("modelclass", [BoringModel, ParityModuleRNN, BasicGAN])
def test_torchscript_input_output(modelclass):
"""Test that scripted LightningModule forward works."""
@@ -45,6 +47,7 @@ def test_torchscript_input_output(modelclass):
assert torch.allclose(script_output, model_output)


@pytest.mark.skipif(_IS_WINDOWS and _TORCH_GREATER_EQUAL_2_4, reason="not close on Windows + PyTorch 2.4")
@pytest.mark.parametrize("modelclass", [BoringModel, ParityModuleRNN, BasicGAN])
def test_torchscript_example_input_output_trace(modelclass):
"""Test that traced LightningModule forward works with example_input_array."""
Original file line number Diff line number Diff line change
@@ -15,6 +15,7 @@

import torch
from lightning.fabric import seed_everything
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.plugins.precision import MixedPrecision
@@ -28,7 +29,8 @@ def __init__(self, fused=False):
self.fused = fused

def configure_optimizers(self):
assert isinstance(self.trainer.precision_plugin.scaler, torch.cuda.amp.GradScaler)
scaler_cls = torch.amp.GradScaler if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler
assert isinstance(self.trainer.precision_plugin.scaler, scaler_cls)
return torch.optim.Adam(self.parameters(), lr=1.0, fused=self.fused)

