Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve SDK UX with task creation #5502

Merged
merged 18 commits into from
Jan 2, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cvat-cli/src/cvat_cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ def tasks_create(
self,
name: str,
labels: List[Dict[str, str]],
resource_type: ResourceType,
resources: Sequence[str],
*,
resource_type: ResourceType = ResourceType.LOCAL,
annotation_path: str = "",
annotation_format: str = "CVAT XML 1.1",
status_check_period: int = 2,
Expand Down
67 changes: 41 additions & 26 deletions cvat-sdk/cvat_sdk/core/proxies/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,13 @@ class Task(

def upload_data(
self,
resource_type: ResourceType,
SpecLad marked this conversation as resolved.
Show resolved Hide resolved
resources: Sequence[StrPath],
*,
resource_type: ResourceType = ResourceType.LOCAL,
pbar: Optional[ProgressReporter] = None,
params: Optional[Dict[str, Any]] = None,
wait_for_completion: bool = True,
status_check_period: Optional[int] = None,
) -> None:
"""
Add local, remote, or shared files to an existing task.
Expand Down Expand Up @@ -119,6 +121,35 @@ def upload_data(
url, list(map(Path, resources)), pbar=pbar, **data
)

if wait_for_completion:
if status_check_period is None:
status_check_period = self._client.config.status_check_period

self._client.logger.info("Awaiting for task %s creation...", self.id)
status: models.RqStatus = None
while status != models.RqStatusStateEnum.allowed_values[("value",)]["FINISHED"]:
sleep(status_check_period)
(status, response) = self.api.retrieve_status(self.id)

self._client.logger.info(
"Task %s creation status=%s, message=%s",
SpecLad marked this conversation as resolved.
Show resolved Hide resolved
self.id,
status.state.value,
status.message,
)

if (
status.state.value
== models.RqStatusStateEnum.allowed_values[("value",)]["FAILED"]
):
raise exceptions.ApiException(
status=status.state.value, reason=status.message, http_resp=response
)

status = status.state.value
SpecLad marked this conversation as resolved.
Show resolved Hide resolved

self.fetch()

def import_annotations(
self,
format_name: str,
Expand Down Expand Up @@ -294,9 +325,9 @@ class TasksRepo(
def create_from_data(
self,
spec: models.ITaskWriteRequest,
resource_type: ResourceType,
resources: Sequence[str],
*,
resource_type: ResourceType = ResourceType.LOCAL,
data_params: Optional[Dict[str, Any]] = None,
annotation_path: str = "",
annotation_format: str = "CVAT XML 1.1",
Expand All @@ -311,9 +342,6 @@ def create_from_data(

Returns: id of the created task
"""
if status_check_period is None:
status_check_period = self._client.config.status_check_period

if getattr(spec, "project_id", None) and getattr(spec, "labels", None):
raise exceptions.ApiValueError(
"Can't set labels to a task inside a project. "
Expand All @@ -324,27 +352,14 @@ def create_from_data(
task = self.create(spec=spec)
self._client.logger.info("Created task ID: %s NAME: %s", task.id, task.name)

task.upload_data(resource_type, resources, pbar=pbar, params=data_params)

self._client.logger.info("Awaiting for task %s creation...", task.id)
status: models.RqStatus = None
while status != models.RqStatusStateEnum.allowed_values[("value",)]["FINISHED"]:
sleep(status_check_period)
(status, response) = self.api.retrieve_status(task.id)

self._client.logger.info(
"Task %s creation status=%s, message=%s",
task.id,
status.state.value,
status.message,
)

if status.state.value == models.RqStatusStateEnum.allowed_values[("value",)]["FAILED"]:
raise exceptions.ApiException(
status=status.state.value, reason=status.message, http_resp=response
)

status = status.state.value
task.upload_data(
resource_type=resource_type,
resources=resources,
pbar=pbar,
params=data_params,
wait_for_completion=True,
status_check_period=status_check_period,
)

if annotation_path:
task.import_annotations(annotation_format, annotation_path, pbar=pbar)
Expand Down
77 changes: 41 additions & 36 deletions cvat-sdk/cvat_sdk/core/uploading.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from __future__ import annotations

import os
from contextlib import ExitStack, closing
from contextlib import ExitStack
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

Expand Down Expand Up @@ -206,40 +206,6 @@ def _wait_for_completion(
positive_statuses=positive_statuses,
)

def _split_files_by_requests(
    self, filenames: List[Path]
) -> Tuple[List[Tuple[List[Path], int]], Dict[Path, int], int]:
    """
    Partition files into bulk-upload groups and files that are sent separately.

    Every entry is coerced with ``Path()`` and resolved before use, so
    plain string paths are accepted as well.

    Returns a 3-tuple:
    - a list of ``(paths, combined_size)`` groups, where each group's
      combined size does not exceed ``MAX_REQUEST_SIZE``;
    - a mapping ``path -> size`` of the files that are individually larger
      than ``MAX_REQUEST_SIZE`` (iterating it yields the paths);
    - the total size in bytes of all the files.

    Raises whatever ``Path.stat()`` raises for a missing or unreadable file.
    """
    bulk_files: Dict[Path, int] = {}
    separate_files: Dict[Path, int] = {}

    # Classify by size (input order is kept, nothing is sorted): a file that
    # cannot fit into a single request on its own must be sent separately.
    for filename in filenames:
        filename = Path(filename).resolve()
        file_size = filename.stat().st_size
        if MAX_REQUEST_SIZE < file_size:
            separate_files[filename] = file_size
        else:
            bulk_files[filename] = file_size

    total_size = sum(bulk_files.values()) + sum(separate_files.values())

    # Greedily pack the small files, in order, into groups that fit a request.
    bulk_file_groups: List[Tuple[List[Path], int]] = []
    current_group: List[Path] = []
    current_group_size = 0
    for filename, file_size in bulk_files.items():
        # Close the current group as soon as the next file would overflow it.
        # This never fires for an empty group, because every bulk file fits
        # into a request on its own.
        if MAX_REQUEST_SIZE < current_group_size + file_size:
            bulk_file_groups.append((current_group, current_group_size))
            current_group = []
            current_group_size = 0

        current_group.append(filename)
        current_group_size += file_size
    if current_group:
        bulk_file_groups.append((current_group, current_group_size))

    return bulk_file_groups, separate_files, total_size

@staticmethod
def _make_tus_uploader(api_client: ApiClient, url: str, **kwargs):
# Add headers required by CVAT server
Expand Down Expand Up @@ -353,6 +319,10 @@ def upload_file_and_wait(


class DataUploader(Uploader):
def __init__(self, client: Client, *, max_request_size: int = MAX_REQUEST_SIZE) -> None:
    """
    Create an uploader bound to ``client``.

    ``max_request_size`` is the size threshold, in bytes, used by
    ``_split_files_by_requests``: files larger than it are set aside to be
    handled separately, and groups of smaller files are capped at it.
    """
    super().__init__(client)
    self.max_request_size = max_request_size

def upload_files(
self,
url: str,
Expand All @@ -374,7 +344,7 @@ def upload_files(
for i, filename in enumerate(group):
files[f"client_files[{i}]"] = (
os.fspath(filename),
es.enter_context(closing(open(filename, "rb"))).read(),
es.enter_context(open(filename, "rb")).read(),
SpecLad marked this conversation as resolved.
Show resolved Hide resolved
)
response = self._client.api_client.rest_client.POST(
url,
Expand All @@ -401,3 +371,38 @@ def upload_files(
)

self._tus_finish_upload(url, fields=kwargs)

def _split_files_by_requests(
    self, filenames: List[Path]
) -> Tuple[List[Tuple[List[Path], int]], Dict[Path, int], int]:
    """
    Partition files into bulk-upload groups and files that are sent separately.

    Every entry is coerced with ``Path()`` and resolved before use, so
    plain string paths are accepted as well.

    Returns a 3-tuple:
    - a list of ``(paths, combined_size)`` groups, where each group's
      combined size does not exceed ``self.max_request_size``;
    - a mapping ``path -> size`` of the files that are individually larger
      than ``self.max_request_size`` (iterating it yields the paths);
    - the total size in bytes of all the files.

    Raises whatever ``Path.stat()`` raises for a missing or unreadable file.
    """
    max_request_size = self.max_request_size

    bulk_files: Dict[Path, int] = {}
    separate_files: Dict[Path, int] = {}

    # Classify by size (input order is kept, nothing is sorted): a file that
    # cannot fit into a single request on its own must be sent separately.
    for filename in filenames:
        filename = Path(filename).resolve()
        file_size = filename.stat().st_size
        if max_request_size < file_size:
            separate_files[filename] = file_size
        else:
            bulk_files[filename] = file_size

    total_size = sum(bulk_files.values()) + sum(separate_files.values())

    # Greedily pack the small files, in order, into groups that fit a request.
    bulk_file_groups: List[Tuple[List[Path], int]] = []
    current_group: List[Path] = []
    current_group_size = 0
    for filename, file_size in bulk_files.items():
        # Close the current group as soon as the next file would overflow it.
        # This never fires for an empty group, because every bulk file fits
        # into a request on its own.
        if max_request_size < current_group_size + file_size:
            bulk_file_groups.append((current_group, current_group_size))
            current_group = []
            current_group_size = 0

        current_group.append(filename)
        current_group_size += file_size
    if current_group:
        bulk_file_groups.append((current_group, current_group_size))

    return bulk_file_groups, separate_files, total_size
4 changes: 2 additions & 2 deletions tests/python/sdk/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ def _create_task(self):
models.PatchedLabelRequest(name="car"),
],
),
ResourceType.LOCAL,
list(map(os.fspath, image_paths)),
resource_type=ResourceType.LOCAL,
resources=list(map(os.fspath, image_paths)),
data_params={"chunk_size": 3},
)

Expand Down
33 changes: 32 additions & 1 deletion tests/python/sdk/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ def fxt_new_task(self, fxt_image_file: Path):
"name": "test_task",
"labels": [{"name": "car"}, {"name": "person"}],
},
resource_type=ResourceType.LOCAL,
resources=[fxt_image_file],
data_params={"image_quality": 80},
)
Expand Down Expand Up @@ -202,6 +201,38 @@ def test_can_create_task_with_git_repo(self, fxt_image_file: Path):
assert response_json["format"] == "CVAT for images 1.1"
assert response_json["lfs"] is False

def test_can_upload_data_to_empty_task(self):
    """Check that local files can be uploaded to a task created without data."""
    image_count = 7

    pbar_out = io.StringIO()
    pbar = make_pbar(file=pbar_out)

    task = self.client.tasks.create(
        {
            "name": "test task",
            "labels": [{"name": "car"}],
        }
    )

    data_params = {
        "image_quality": 75,
    }

    # Write each generated image (an object exposing ``name`` and
    # ``getvalue()``) into the test's temporary directory, so the upload
    # receives real file paths.
    task_files = []
    for image in generate_image_files(image_count):
        image_path = self.tmp_path / image.name
        image_path.write_bytes(image.getvalue())
        task_files.append(image_path)

    task.upload_data(
        resources=task_files,
        resource_type=ResourceType.LOCAL,
        params=data_params,
        pbar=pbar,
    )

    assert task.size == image_count
    # The last progress bar refresh must report completion.
    assert "100%" in pbar_out.getvalue().strip("\r").split("\r")[-1]
    # Nothing may be printed to stdout during the upload.
    assert self.stdout.getvalue() == ""

def test_can_retrieve_task(self, fxt_new_task: Task):
task_id = fxt_new_task.id

Expand Down