Skip to content

Commit

Permalink
Merge branch 'develop' into bs/increase_assets_limits
Browse files Browse the repository at this point in the history
  • Loading branch information
bsekachev committed Jul 28, 2023
2 parents 8105b7b + 4b86439 commit ae3db31
Show file tree
Hide file tree
Showing 9 changed files with 304 additions and 163 deletions.
8 changes: 7 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## \[Unreleased]
### Added

- TDB
- \[SDK\] A `DeferredTqdmProgressReporter` class, which doesn't have glitchy output
like `TqdmProgressReporter` in certain circumstances
(<https://github.com/opencv/cvat/pull/6556>)

### Changed

- Increased default guide assets limitations (30 assets, up to 10Mb each)
(<https://github.com/opencv/cvat/pull/6575>)
- \[SDK\] Custom `ProgressReporter` implementations should now override `start2` instead of `start`
(<https://github.com/opencv/cvat/pull/6556>)


### Deprecated

Expand Down
22 changes: 10 additions & 12 deletions cvat-cli/src/cvat_cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@
import json
from typing import Dict, List, Sequence, Tuple

import tqdm
from cvat_sdk import Client, models
from cvat_sdk.core.helpers import TqdmProgressReporter
from cvat_sdk.core.helpers import DeferredTqdmProgressReporter
from cvat_sdk.core.proxies.tasks import ResourceType


Expand Down Expand Up @@ -67,7 +66,7 @@ def tasks_create(
status_check_period=status_check_period,
dataset_repository_url=dataset_repository_url,
use_lfs=lfs,
pbar=self._make_pbar(),
pbar=DeferredTqdmProgressReporter(),
)
print("Created task id", task.id)

Expand Down Expand Up @@ -109,7 +108,7 @@ def tasks_dump(
self.client.tasks.retrieve(obj_id=task_id).export_dataset(
format_name=fileformat,
filename=filename,
pbar=self._make_pbar(),
pbar=DeferredTqdmProgressReporter(),
status_check_period=status_check_period,
include_images=include_images,
)
Expand All @@ -123,22 +122,21 @@ def tasks_upload(
format_name=fileformat,
filename=filename,
status_check_period=status_check_period,
pbar=self._make_pbar(),
pbar=DeferredTqdmProgressReporter(),
)

def tasks_export(self, task_id: str, filename: str, *, status_check_period: int = 2) -> None:
"""Download a task backup"""
self.client.tasks.retrieve(obj_id=task_id).download_backup(
filename=filename, status_check_period=status_check_period, pbar=self._make_pbar()
filename=filename,
status_check_period=status_check_period,
pbar=DeferredTqdmProgressReporter(),
)

def tasks_import(self, filename: str, *, status_check_period: int = 2) -> None:
"""Import a task from a backup file"""
self.client.tasks.create_from_backup(
filename=filename, status_check_period=status_check_period, pbar=self._make_pbar()
)

def _make_pbar(self, title: str = None) -> TqdmProgressReporter:
return TqdmProgressReporter(
tqdm.tqdm(unit_scale=True, unit="B", unit_divisor=1024, desc=title)
filename=filename,
status_check_period=status_check_period,
pbar=DeferredTqdmProgressReporter(),
)
16 changes: 8 additions & 8 deletions cvat-sdk/cvat_sdk/core/downloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import TYPE_CHECKING, Any, Dict, Optional

from cvat_sdk.api_client.api_client import Endpoint
from cvat_sdk.core.progress import ProgressReporter
from cvat_sdk.core.progress import NullProgressReporter, ProgressReporter
from cvat_sdk.core.utils import atomic_writer

if TYPE_CHECKING:
Expand Down Expand Up @@ -41,6 +41,9 @@ def download_file(

assert not output_path.exists()

if pbar is None:
pbar = NullProgressReporter()

response = self._client.api_client.rest_client.GET(
url,
_request_timeout=timeout,
Expand All @@ -53,18 +56,15 @@ def download_file(
except ValueError:
file_size = None

with atomic_writer(output_path, "wb") as fd:
if pbar is not None:
pbar.start(file_size, desc="Downloading")

with atomic_writer(output_path, "wb") as fd, pbar.task(
total=file_size, desc="Downloading", unit_scale=True, unit="B", unit_divisor=1024
):
while True:
chunk = response.read(amt=CHUNK_SIZE, decode_content=False)
if not chunk:
break

if pbar is not None:
pbar.advance(len(chunk))

pbar.advance(len(chunk))
fd.write(chunk)

def prepare_and_download_file_from_endpoint(
Expand Down
98 changes: 68 additions & 30 deletions cvat-sdk/cvat_sdk/core/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@

import io
import json
import warnings
from typing import Any, Dict, Iterable, List, Optional, Union

import tqdm
import urllib3

from cvat_sdk import exceptions
from cvat_sdk.api_client.api_client import Endpoint
from cvat_sdk.core.progress import ProgressReporter
from cvat_sdk.core.progress import BaseProgressReporter, ProgressReporter


def get_paginated_collection(
Expand Down Expand Up @@ -46,62 +47,99 @@ def get_paginated_collection(
return results


class TqdmProgressReporter(ProgressReporter):
class _BaseTqdmProgressReporter(BaseProgressReporter):
tqdm: Optional[tqdm.tqdm]

def report_status(self, progress: int):
super().report_status(progress)
self.tqdm.update(progress - self.tqdm.n)

def advance(self, delta: int):
super().advance(delta)
self.tqdm.update(delta)


class TqdmProgressReporter(_BaseTqdmProgressReporter):
def __init__(self, instance: tqdm.tqdm) -> None:
super().__init__()
warnings.warn(f"use {DeferredTqdmProgressReporter.__name__} instead", DeprecationWarning)

self.tqdm = instance

@property
def period(self) -> float:
return 0
def start2(self, total: int, *, desc: Optional[str] = None, **kwargs) -> None:
super().start2(total=total, desc=desc, **kwargs)

def start(self, total: int, *, desc: Optional[str] = None):
self.tqdm.reset(total)
self.tqdm.set_description_str(desc)

def report_status(self, progress: int):
self.tqdm.update(progress - self.tqdm.n)
def finish(self):
self.tqdm.refresh()
super().finish()

def advance(self, delta: int):
self.tqdm.update(delta)

class DeferredTqdmProgressReporter(_BaseTqdmProgressReporter):
def __init__(self, tqdm_args: Optional[dict] = None) -> None:
super().__init__()
self.tqdm_args = tqdm_args or {}
self.tqdm = None

def start2(
self,
total: int,
*,
desc: Optional[str] = None,
unit: str = "it",
unit_scale: bool = False,
unit_divisor: int = 1000,
**kwargs,
) -> None:
super().start2(
total=total,
desc=desc,
unit=unit,
unit_scale=unit_scale,
unit_divisor=unit_divisor,
**kwargs,
)
assert not self.tqdm

self.tqdm = tqdm.tqdm(
**self.tqdm_args,
total=total,
desc=desc,
unit=unit,
unit_scale=unit_scale,
unit_divisor=unit_divisor,
)

def finish(self):
self.tqdm.refresh()
self.tqdm.close()
self.tqdm = None
super().finish()


class StreamWithProgress:
def __init__(self, stream: io.RawIOBase, pbar: ProgressReporter, length: Optional[int] = None):
def __init__(self, stream: io.RawIOBase, pbar: ProgressReporter):
self.stream = stream
self.pbar = pbar

if hasattr(stream, "__len__"):
length = len(stream)

self.length = length
pbar.start(length)
assert self.stream.tell() == 0

def read(self, size=-1):
chunk = self.stream.read(size)
if chunk is not None:
self.pbar.advance(len(chunk))
return chunk

def __len__(self):
return self.length
def seek(self, pos: int, whence: int = io.SEEK_SET) -> None:
old_pos = self.stream.tell()
new_pos = self.stream.seek(pos, whence)
self.pbar.advance(new_pos - old_pos)
return new_pos

def seek(self, pos, start=0):
self.stream.seek(pos, start)
self.pbar.report_status(pos)

def tell(self):
def tell(self) -> int:
return self.stream.tell()

def __enter__(self) -> StreamWithProgress:
return self

def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
self.pbar.finish()


def expect_status(codes: Union[int, Iterable[int]], response: urllib3.HTTPResponse) -> None:
if not hasattr(codes, "__iter__"):
Expand Down
Loading

0 comments on commit ae3db31

Please sign in to comment.