diff --git a/src/fromager/progress.py b/src/fromager/progress.py index f0a26bf2..32c22c12 100644 --- a/src/fromager/progress.py +++ b/src/fromager/progress.py @@ -1,13 +1,21 @@ import sys import typing +from types import TracebackType import tqdm as _tqdm __all__ = ("progress",) +# fix for runtime errors caused by inheriting classes that are generic in stubs but not runtime +# https://mypy.readthedocs.io/en/latest/runtime_troubles.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime +if typing.TYPE_CHECKING: + ProgressBarTqdm = _tqdm.tqdm[int] | _tqdm.tqdm[typing.Never] | None +else: + ProgressBarTqdm = _tqdm.tqdm | None + class Progressbar: - def __init__(self, tqdm: _tqdm.tqdm | None) -> None: + def __init__(self, tqdm: ProgressBarTqdm) -> None: self._tqdm = tqdm def update_total(self, n: int) -> None: @@ -24,12 +32,19 @@ def __enter__(self) -> "Progressbar": self._tqdm.__enter__() return self - def __exit__(self, typ, value, traceback) -> None: + def __exit__( + self, + typ: type[BaseException] | None, + value: BaseException | None, + traceback: TracebackType | None, + ) -> None: if self._tqdm is not None: self._tqdm.__exit__(typ, value, traceback) -def progress(it: typing.Iterable, *, unit="pkg", **kwargs: typing.Any) -> typing.Any: +def progress( + it: typing.Iterable[typing.Any], *, unit: str = "pkg", **kwargs: typing.Any +) -> typing.Any: """tqdm progress bar""" if not sys.stdout.isatty(): # wider progress bar in CI @@ -37,7 +52,9 @@ def progress(it: typing.Iterable, *, unit="pkg", **kwargs: typing.Any) -> typing yield from _tqdm.tqdm(it, unit=unit, **kwargs) -def progress_context(total: int, *, unit="pkg", **kwargs: typing.Any) -> Progressbar: +def progress_context( + total: int, *, unit: str = "pkg", **kwargs: typing.Any +) -> Progressbar: """Context manager for progress bar with dynamic updates""" if not sys.stdout.isatty(): # wider progress bar in CI