Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revamp the progress reporting API #6556

Merged
merged 4 commits into from
Jul 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## \[Unreleased]
### Added

- TDB
- \[SDK\] A `DeferredTqdmProgressReporter` class, which doesn't have glitchy output
like `TqdmProgressReporter` in certain circumstances
(<https://github.com/opencv/cvat/pull/6556>)

### Changed

- TDB
- \[SDK\] Custom `ProgressReporter` implementations should now override `start2` instead of `start`
(<https://github.com/opencv/cvat/pull/6556>)

### Deprecated

Expand Down
22 changes: 10 additions & 12 deletions cvat-cli/src/cvat_cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@
import json
from typing import Dict, List, Sequence, Tuple

import tqdm
from cvat_sdk import Client, models
from cvat_sdk.core.helpers import TqdmProgressReporter
from cvat_sdk.core.helpers import DeferredTqdmProgressReporter
from cvat_sdk.core.proxies.tasks import ResourceType


Expand Down Expand Up @@ -67,7 +66,7 @@ def tasks_create(
status_check_period=status_check_period,
dataset_repository_url=dataset_repository_url,
use_lfs=lfs,
pbar=self._make_pbar(),
pbar=DeferredTqdmProgressReporter(),
)
print("Created task id", task.id)

Expand Down Expand Up @@ -109,7 +108,7 @@ def tasks_dump(
self.client.tasks.retrieve(obj_id=task_id).export_dataset(
format_name=fileformat,
filename=filename,
pbar=self._make_pbar(),
pbar=DeferredTqdmProgressReporter(),
status_check_period=status_check_period,
include_images=include_images,
)
Expand All @@ -123,22 +122,21 @@ def tasks_upload(
format_name=fileformat,
filename=filename,
status_check_period=status_check_period,
pbar=self._make_pbar(),
pbar=DeferredTqdmProgressReporter(),
)

def tasks_export(self, task_id: str, filename: str, *, status_check_period: int = 2) -> None:
"""Download a task backup"""
self.client.tasks.retrieve(obj_id=task_id).download_backup(
filename=filename, status_check_period=status_check_period, pbar=self._make_pbar()
filename=filename,
status_check_period=status_check_period,
pbar=DeferredTqdmProgressReporter(),
)

def tasks_import(self, filename: str, *, status_check_period: int = 2) -> None:
"""Import a task from a backup file"""
self.client.tasks.create_from_backup(
filename=filename, status_check_period=status_check_period, pbar=self._make_pbar()
)

def _make_pbar(self, title: str = None) -> TqdmProgressReporter:
return TqdmProgressReporter(
tqdm.tqdm(unit_scale=True, unit="B", unit_divisor=1024, desc=title)
filename=filename,
status_check_period=status_check_period,
pbar=DeferredTqdmProgressReporter(),
)
16 changes: 8 additions & 8 deletions cvat-sdk/cvat_sdk/core/downloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import TYPE_CHECKING, Any, Dict, Optional

from cvat_sdk.api_client.api_client import Endpoint
from cvat_sdk.core.progress import ProgressReporter
from cvat_sdk.core.progress import NullProgressReporter, ProgressReporter
from cvat_sdk.core.utils import atomic_writer

if TYPE_CHECKING:
Expand Down Expand Up @@ -41,6 +41,9 @@ def download_file(

assert not output_path.exists()

if pbar is None:
pbar = NullProgressReporter()

response = self._client.api_client.rest_client.GET(
url,
_request_timeout=timeout,
Expand All @@ -53,18 +56,15 @@ def download_file(
except ValueError:
file_size = None

with atomic_writer(output_path, "wb") as fd:
if pbar is not None:
pbar.start(file_size, desc="Downloading")

with atomic_writer(output_path, "wb") as fd, pbar.task(
total=file_size, desc="Downloading", unit_scale=True, unit="B", unit_divisor=1024
):
while True:
chunk = response.read(amt=CHUNK_SIZE, decode_content=False)
if not chunk:
break

if pbar is not None:
pbar.advance(len(chunk))

pbar.advance(len(chunk))
fd.write(chunk)

def prepare_and_download_file_from_endpoint(
Expand Down
98 changes: 68 additions & 30 deletions cvat-sdk/cvat_sdk/core/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@

import io
import json
import warnings
from typing import Any, Dict, Iterable, List, Optional, Union

import tqdm
import urllib3

from cvat_sdk import exceptions
from cvat_sdk.api_client.api_client import Endpoint
from cvat_sdk.core.progress import ProgressReporter
from cvat_sdk.core.progress import BaseProgressReporter, ProgressReporter


def get_paginated_collection(
Expand Down Expand Up @@ -46,62 +47,99 @@ def get_paginated_collection(
return results


class TqdmProgressReporter(ProgressReporter):
class _BaseTqdmProgressReporter(BaseProgressReporter):
tqdm: Optional[tqdm.tqdm]

def report_status(self, progress: int):
super().report_status(progress)
self.tqdm.update(progress - self.tqdm.n)

def advance(self, delta: int):
super().advance(delta)
self.tqdm.update(delta)


class TqdmProgressReporter(_BaseTqdmProgressReporter):
def __init__(self, instance: tqdm.tqdm) -> None:
super().__init__()
warnings.warn(f"use {DeferredTqdmProgressReporter.__name__} instead", DeprecationWarning)

self.tqdm = instance

@property
def period(self) -> float:
return 0
def start2(self, total: int, *, desc: Optional[str] = None, **kwargs) -> None:
super().start2(total=total, desc=desc, **kwargs)

def start(self, total: int, *, desc: Optional[str] = None):
self.tqdm.reset(total)
self.tqdm.set_description_str(desc)

def report_status(self, progress: int):
self.tqdm.update(progress - self.tqdm.n)
def finish(self):
self.tqdm.refresh()
super().finish()

def advance(self, delta: int):
self.tqdm.update(delta)

class DeferredTqdmProgressReporter(_BaseTqdmProgressReporter):
def __init__(self, tqdm_args: Optional[dict] = None) -> None:
super().__init__()
self.tqdm_args = tqdm_args or {}
self.tqdm = None

def start2(
self,
total: int,
*,
desc: Optional[str] = None,
unit: str = "it",
unit_scale: bool = False,
unit_divisor: int = 1000,
**kwargs,
) -> None:
super().start2(
total=total,
desc=desc,
unit=unit,
unit_scale=unit_scale,
unit_divisor=unit_divisor,
**kwargs,
)
assert not self.tqdm

self.tqdm = tqdm.tqdm(
**self.tqdm_args,
total=total,
desc=desc,
unit=unit,
unit_scale=unit_scale,
unit_divisor=unit_divisor,
)

def finish(self):
self.tqdm.refresh()
self.tqdm.close()
self.tqdm = None
super().finish()


class StreamWithProgress:
def __init__(self, stream: io.RawIOBase, pbar: ProgressReporter, length: Optional[int] = None):
def __init__(self, stream: io.RawIOBase, pbar: ProgressReporter):
self.stream = stream
self.pbar = pbar

if hasattr(stream, "__len__"):
length = len(stream)

self.length = length
pbar.start(length)
assert self.stream.tell() == 0

def read(self, size=-1):
chunk = self.stream.read(size)
if chunk is not None:
self.pbar.advance(len(chunk))
return chunk

def __len__(self):
return self.length
def seek(self, pos: int, whence: int = io.SEEK_SET) -> None:
old_pos = self.stream.tell()
new_pos = self.stream.seek(pos, whence)
self.pbar.advance(new_pos - old_pos)
return new_pos

def seek(self, pos, start=0):
self.stream.seek(pos, start)
self.pbar.report_status(pos)

def tell(self):
def tell(self) -> int:
return self.stream.tell()

def __enter__(self) -> StreamWithProgress:
return self

def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
self.pbar.finish()


def expect_status(codes: Union[int, Iterable[int]], response: urllib3.HTTPResponse) -> None:
if not hasattr(codes, "__iter__"):
Expand Down
Loading
Loading