Skip to content

Commit

Permalink
🧑‍💻 torch.compile is not compatible with Windows. (TissueImageAnaly…
Browse files Browse the repository at this point in the history
…tics#888)

- `torch.compile` is not currently compatible with Windows. See pytorch/pytorch#122094
  • Loading branch information
shaneahmed authored Nov 29, 2024
1 parent 5f1cecb commit 4a1940d
Showing 1 changed file with 23 additions and 4 deletions.
27 changes: 23 additions & 4 deletions tiatoolbox/models/architecture/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from __future__ import annotations

import sys
from typing import NoReturn

import numpy as np
import torch
Expand All @@ -12,13 +11,17 @@
from tiatoolbox import logger


def is_torch_compile_compatible() -> NoReturn:
def is_torch_compile_compatible() -> bool:
"""Check if the current GPU is compatible with torch-compile.
Returns:
True if current GPU is compatible with torch-compile, False otherwise.
Raises:
Warning if GPU is not compatible with `torch.compile`.
"""
gpu_compatibility = True
if torch.cuda.is_available(): # pragma: no cover
device_cap = torch.cuda.get_device_capability()
if device_cap not in ((7, 0), (8, 0), (9, 0)):
Expand All @@ -28,13 +31,17 @@ def is_torch_compile_compatible() -> NoReturn:
"Speedup numbers may be lower than expected.",
stacklevel=2,
)
gpu_compatibility = False
else:
logger.warning(
"No GPU detected or cuda not installed, "
"torch.compile is only supported on selected NVIDIA GPUs. "
"Speedup numbers may be lower than expected.",
stacklevel=2,
)
gpu_compatibility = False

return gpu_compatibility


def compile_model(
Expand Down Expand Up @@ -68,12 +75,24 @@ def compile_model(
return model

# Check if GPU is compatible with torch.compile
is_torch_compile_compatible()
gpu_compatibility = is_torch_compile_compatible()

if not gpu_compatibility:
return model

if sys.platform == "win32": # pragma: no cover
msg = (
"`torch.compile` is not supported on Windows. Please see "
"https://github.com/pytorch/pytorch/issues/122094."
)
logger.warning(msg=msg)
return model

# This check will be removed when torch.compile is supported in Python 3.12+
if sys.version_info > (3, 12): # pragma: no cover
msg = "torch-compile is currently not supported in Python 3.12+."
logger.warning(
("torch-compile is currently not supported in Python 3.12+. ",),
msg=msg,
)
return model

Expand Down

0 comments on commit 4a1940d

Please sign in to comment.