Skip to content

Commit

Permalink
Merge branch 'develop' into keyerror-catch
Browse files Browse the repository at this point in the history
  • Loading branch information
nmanovic authored Jun 28, 2023
2 parents 1b835dd + d950d24 commit 4d91f8c
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 9 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Some internal errors occurring during lambda function invocations
could be mistakenly reported as invalid requests
(<https://github.com/opencv/cvat/pull/6394>)
- \[SDK\] Loading tasks that have been cached with the PyTorch adapter
(<https://github.com/opencv/cvat/issues/6047>)

### Security
- TDB
Expand Down
2 changes: 1 addition & 1 deletion cvat-sdk/cvat_sdk/pytorch/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def _initialize_task_dir(self, task: Task) -> None:
if task_dir.exists():
shutil.rmtree(task_dir)
else:
if saved_task.updated_date < task.updated_date:
if saved_task.api_model.updated_date < task.updated_date:
self._logger.info(
f"Task {task.id} has been updated on the server since it was cached; purging the cache"
)
Expand Down
65 changes: 57 additions & 8 deletions tests/python/sdk/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import os
from logging import Logger
from pathlib import Path
from typing import Tuple
from typing import Container, Tuple
from urllib.parse import urlparse

import pytest
from cvat_sdk import Client, models
Expand All @@ -21,7 +22,10 @@
import torchvision.transforms
import torchvision.transforms.functional as TF
from torch.utils.data import DataLoader
except ImportError:
except ModuleNotFoundError as e:
if e.name.split(".")[0] not in {"torch", "torchvision"}:
raise

cvatpt = None

from shared.utils.helpers import generate_image_files
Expand All @@ -43,11 +47,18 @@ def _common_setup(
api_client.configuration.logger[k] = logger


def _disable_api_requests(monkeypatch: pytest.MonkeyPatch) -> None:
def disabled_request(*args, **kwargs):
raise RuntimeError("Disabled!")
def _restrict_api_requests(
monkeypatch: pytest.MonkeyPatch, allow_paths: Container[str] = ()
) -> None:
original_request = RESTClientObject.request

def restricted_request(self, method, url, *args, **kwargs):
parsed_url = urlparse(url)
if parsed_url.path in allow_paths:
return original_request(self, method, url, *args, **kwargs)
raise RuntimeError("Disallowed!")

monkeypatch.setattr(RESTClientObject, "request", disabled_request)
monkeypatch.setattr(RESTClientObject, "request", restricted_request)


@pytest.mark.skipif(cvatpt is None, reason="PyTorch dependencies are not installed")
Expand Down Expand Up @@ -243,7 +254,7 @@ def test_offline(self, monkeypatch: pytest.MonkeyPatch):

fresh_samples = list(dataset)

_disable_api_requests(monkeypatch)
_restrict_api_requests(monkeypatch)

dataset = cvatpt.TaskVisionDataset(
self.client,
Expand All @@ -255,6 +266,44 @@ def test_offline(self, monkeypatch: pytest.MonkeyPatch):

assert fresh_samples == cached_samples

def test_update(self, monkeypatch: pytest.MonkeyPatch):
dataset = cvatpt.TaskVisionDataset(
self.client,
self.task.id,
)

# Recreating the dataset should only result in minimal requests.
_restrict_api_requests(
monkeypatch, allow_paths={f"/api/tasks/{self.task.id}", "/api/labels"}
)

dataset = cvatpt.TaskVisionDataset(
self.client,
self.task.id,
)

assert dataset[5][1].annotations.tags[0].label_id == self.label_ids[0]

# After an update, the annotations should be redownloaded.
monkeypatch.undo()

self.task.update_annotations(
models.PatchedLabeledDataRequest(
tags=[
models.LabeledImageRequest(
id=dataset[5][1].annotations.tags[0].id, frame=5, label_id=self.label_ids[1]
),
]
)
)

dataset = cvatpt.TaskVisionDataset(
self.client,
self.task.id,
)

assert dataset[5][1].annotations.tags[0].label_id == self.label_ids[1]


@pytest.mark.skipif(cvatpt is None, reason="PyTorch dependencies are not installed")
class TestProjectVisionDataset:
Expand Down Expand Up @@ -398,7 +447,7 @@ def test_offline(self, monkeypatch: pytest.MonkeyPatch):

fresh_samples = list(dataset)

_disable_api_requests(monkeypatch)
_restrict_api_requests(monkeypatch)

dataset = cvatpt.ProjectVisionDataset(
self.client,
Expand Down

0 comments on commit 4d91f8c

Please sign in to comment.