Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace tqdm with Rich for progress bar #877

Merged
merged 13 commits into from
Mar 11, 2022
6 changes: 1 addition & 5 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ show_traceback = True
warn_redundant_casts = True
warn_unused_configs = True
warn_unused_ignores = True
; Enabling this will fail on subclasses of untyped imports, e.g. tqdm
; Enabling this will fail on subclasses of untyped imports, e.g. pkginfo
; disallow_subclassing_any = True
disallow_any_generics = True
disallow_untyped_calls = True
Expand Down Expand Up @@ -33,10 +33,6 @@ ignore_missing_imports = True
[mypy-rfc3986]
ignore_missing_imports = True

[mypy-tqdm]
; https://github.com/tqdm/tqdm/issues/260
ignore_missing_imports = True

[mypy-urllib3]
; https://github.com/urllib3/urllib3/issues/867
ignore_missing_imports = True
Expand Down
1 change: 0 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ install_requires=
requests >= 2.20
requests-toolbelt >= 0.8.0, != 0.9.0
urllib3 >= 1.26.0
tqdm >= 4.14
importlib_metadata >= 3.6
keyring >= 15.1
rfc3986 >= 1.4.0
Expand Down
66 changes: 40 additions & 26 deletions tests/test_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,18 +195,21 @@ def test_package_is_registered(default_repo):


@pytest.mark.parametrize("disable_progress_bar", [True, False])
def test_disable_progress_bar_is_forwarded_to_tqdm(
def test_disable_progress_bar_is_forwarded_to_rich(
monkeypatch, tmpdir, disable_progress_bar, default_repo
):
"""Toggle display of upload progress bar."""

@contextmanager
def progressbarstub(*args, **kwargs):
def ProgressStub(*args, **kwargs):
assert "disable" in kwargs
assert kwargs["disable"] == disable_progress_bar
yield
yield pretend.stub(
add_task=lambda description, total: None,
update=lambda task_id, completed: None,
)

monkeypatch.setattr(repository, "ProgressBar", progressbarstub)
monkeypatch.setattr(repository.rich.progress, "Progress", ProgressStub)
default_repo.disable_progress_bar = disable_progress_bar

default_repo.session = pretend.stub(
Expand All @@ -230,6 +233,27 @@ def dictfunc():
default_repo.upload(package)


@pytest.mark.parametrize("finished", [False, True])
@pytest.mark.parametrize(
"task_time, formatted",
[
(None, "--:--"),
(0, "00:00"),
(59, "00:59"),
(71, "01:11"),
(4210, "1:10:10"),
],
)
def test_time_column_renders_condensed_time(finished, task_time, formatted):
if finished:
task = pretend.stub(finished=finished, finished_time=task_time)
else:
task = pretend.stub(finished=finished, time_remaining=task_time)

column = repository.CondensedTimeColumn()
assert str(column.render(task)) == formatted


def test_upload_retry(tmpdir, default_repo, capsys):
"""Print retry messages when the upload response indicates a server error."""
default_repo.disable_progress_bar = True
Expand All @@ -251,35 +275,25 @@ def test_upload_retry(tmpdir, default_repo, capsys):
metadata_dictionary=lambda: {"name": "fake"},
)

def assert_retries(output, total):
retries = [line for line in output.splitlines() if line.startswith("Received")]
assert retries == [
(
'Received "500: Internal server error" '
f"Package upload appears to have failed. Retry {i} of {total}"
)
for i in range(1, total + 1)
]

# Upload with default max_redirects of 5
default_repo.upload(package)

msg = [
(
"Uploading fake.whl\n"
'Received "500: Internal server error" '
f"Package upload appears to have failed. Retry {i} of 5"
)
for i in range(1, 6)
]

captured = capsys.readouterr()
assert captured.out == "\n".join(msg) + "\n"
assert_retries(capsys.readouterr().out, 5)
bhrutledge marked this conversation as resolved.
Show resolved Hide resolved

# Upload with custom max_redirects of 3
default_repo.upload(package, 3)

msg = [
(
"Uploading fake.whl\n"
'Received "500: Internal server error" '
f"Package upload appears to have failed. Retry {i} of 3"
)
for i in range(1, 4)
]

captured = capsys.readouterr()
assert captured.out == "\n".join(msg) + "\n"
assert_retries(capsys.readouterr().out, 3)


@pytest.mark.parametrize(
Expand Down
58 changes: 40 additions & 18 deletions twine/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import sys
from typing import Any, Dict, List, Optional, Set, Tuple, cast

import requests
import requests_toolbelt
import tqdm
import rich.progress
import rich.text
import urllib3
from requests import adapters
from requests_toolbelt.utils import user_agent
Expand All @@ -37,14 +37,28 @@
logger = logging.getLogger(__name__)


class ProgressBar(tqdm.tqdm):
def update_to(self, n: int) -> None:
"""Update the bar in the way compatible with requests-toolbelt.
class CondensedTimeColumn(rich.progress.ProgressColumn):
"""Renders estimated time remaining, or elapsed time when the task is finished."""

This is identical to tqdm.update, except ``n`` will be the current
value - not the delta as tqdm expects.
"""
self.update(n - self.n) # will also do self.n = n
# Only refresh twice a second to prevent jitter
max_refresh = 0.5

def render(self, task: rich.progress.Task) -> rich.text.Text:
"""Show time."""
style = "progress.elapsed" if task.finished else "progress.remaining"
task_time = task.finished_time if task.finished else task.time_remaining
if task_time is None:
return rich.text.Text("--:--", style=style)

# Based on https://github.com/tqdm/tqdm/blob/master/tqdm/std.py
minutes, seconds = divmod(int(task_time), 60)
hours, minutes = divmod(minutes, 60)
if hours:
formatted = f"{hours:d}:{minutes:02d}:{seconds:02d}"
else:
formatted = f"{minutes:02d}:{seconds:02d}"

return rich.text.Text(formatted, style=style)
bhrutledge marked this conversation as resolved.
Show resolved Hide resolved


class Repository:
Expand Down Expand Up @@ -159,17 +173,25 @@ def _upload(self, package: package_file.PackageFile) -> requests.Response:
("content", (package.basefilename, fp, "application/octet-stream"))
)
encoder = requests_toolbelt.MultipartEncoder(data_to_send)
with ProgressBar(
total=encoder.len,
unit="B",
unit_scale=True,
unit_divisor=1024,
miniters=1,
file=sys.stdout,

with rich.progress.Progress(
"[progress.percentage]{task.percentage:>3.0f}%",
rich.progress.BarColumn(),
rich.progress.DownloadColumn(),
"•",
CondensedTimeColumn(),
"•",
rich.progress.TransferSpeedColumn(),
disable=self.disable_progress_bar,
) as bar:
) as progress:
task_id = progress.add_task("", total=encoder.len)

monitor = requests_toolbelt.MultipartEncoderMonitor(
encoder, lambda monitor: bar.update_to(monitor.bytes_read)
encoder,
lambda monitor: progress.update(
task_id,
completed=monitor.bytes_read,
),
)

resp = self.session.post(
Expand Down