Skip to content

Commit

Permalink
feat(utility): add "URL" class for get and update url
Browse files Browse the repository at this point in the history
PR Closed: #1088
  • Loading branch information
zhen.chen authored and AChenQ committed Nov 5, 2021
1 parent 1b75b50 commit 7d1fa11
Show file tree
Hide file tree
Showing 9 changed files with 112 additions and 113 deletions.
10 changes: 9 additions & 1 deletion tensorbay/client/cloud_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from tensorbay.client.requests import Client
from tensorbay.dataset import AuthData
from tensorbay.utility import URL


class CloudClient:
Expand Down Expand Up @@ -61,5 +62,12 @@ def list_auth_data(self, path: str = "") -> List[AuthData]:
"""
return [
AuthData(cloud_path, _url_getter=self._get_url) for cloud_path in self._list_files(path)
AuthData(
cloud_path,
url=URL.from_getter(
lambda c=cloud_path: self._get_url(c),
lambda c=cloud_path: self._get_url(c), # type: ignore[misc]
),
)
for cloud_path in self._list_files(path)
]
47 changes: 14 additions & 33 deletions tensorbay/client/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import time
from copy import deepcopy
from itertools import zip_longest
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, Generator, Iterable, Optional, Tuple, Union

import filetype
from requests_toolbelt import MultipartEncoder
Expand All @@ -40,7 +40,7 @@
from tensorbay.exception import FrameError, InvalidParamsError, ResponseError
from tensorbay.label import Label
from tensorbay.sensor.sensor import Sensor, Sensors
from tensorbay.utility import FileMixin, chunked, locked
from tensorbay.utility import URL, FileMixin, chunked, locked

if TYPE_CHECKING:
from tensorbay.client.dataset import DatasetClient, FusionDatasetClient
Expand All @@ -49,18 +49,6 @@
_MASK_KEYS = ("semantic_mask", "instance_mask", "panoptic_mask")


class _UrlGetters:
def __init__(self, urls: LazyPage[str]) -> None:
self._urls = urls

def __getitem__(self, index: int) -> Callable[[str], str]:
return lambda _: self._urls.items[index].get()

def update(self) -> None:
"""Update all urls."""
self._urls.pull()


class SegmentClientBase:
"""This class defines the basic concept of :class:`SegmentClient`.
Expand Down Expand Up @@ -364,41 +352,34 @@ def _generate_data_paths(self, offset: int = 0, limit: int = 128) -> Generator[s
def _generate_data(self, offset: int = 0, limit: int = 128) -> Generator[RemoteData, None, int]:
response = self._list_data_details(offset, limit)

urls = _UrlGetters(
LazyPage.from_items(
offset,
limit,
self._generate_urls,
(item["url"] for item in response["dataDetails"]),
),
urls = LazyPage.from_items(
offset,
limit,
self._generate_urls,
(item["url"] for item in response["dataDetails"]),
)

mask_urls = {}
for key in _MASK_KEYS:
mask_urls[key] = _UrlGetters(
LazyPage(
offset,
limit,
lambda offset, limit, k=key: self._generate_mask_urls( # type: ignore[misc]
k.upper(), offset, limit
),
mask_urls[key] = LazyPage(
offset,
limit,
lambda offset, limit, k=key: self._generate_mask_urls( # type: ignore[misc]
k.upper(), offset, limit
),
)

for i, item in enumerate(response["dataDetails"]):
data = RemoteData.from_response_body(
item,
_url_getter=urls[i],
_url_updater=urls.update,
url=URL.from_getter(urls.items[i].get, urls.pull),
cache_path=self._cache_path,
)
label = data.label
for key in _MASK_KEYS:
mask = getattr(label, key, None)
if mask:
# pylint: disable=protected-access
mask._url_getter = mask_urls[key][i]
mask._url_updater = mask_urls[key].update
mask.url = URL.from_getter(mask_urls[key].items[i].get, mask_urls[key].pull)
mask.cache_path = os.path.join(self._cache_path, key, mask.path)

yield data
Expand Down
34 changes: 15 additions & 19 deletions tensorbay/dataset/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
"""

import os
from typing import Any, Callable, Dict, Optional, Type, TypeVar, Union
from typing import Any, Dict, Optional, Type, TypeVar, Union

from tensorbay.label import Label
from tensorbay.utility import FileMixin, RemoteFileMixin, ReprMixin
from tensorbay.utility import URL, FileMixin, RemoteFileMixin, ReprMixin


class DataBase(ReprMixin):
Expand Down Expand Up @@ -138,14 +138,15 @@ class RemoteData(DataBase, RemoteFileMixin):
Arguments:
remote_path: The file remote path.
timestamp: The timestamp for the file.
_url_getter: The url getter of the remote file.
url: The URL instance used to get and update url.
cache_path: The path to store the cache.
Attributes:
path: The file remote path.
timestamp: The timestamp for the file.
label: The :class:`~tensorbay.label.label.Label` instance that contains
all the label information of the file.
url: The :class:`~tensorbay.utility.file.URL` instance used to get and update url.
"""

Expand All @@ -156,16 +157,14 @@ def __init__(
remote_path: str,
*,
timestamp: Optional[float] = None,
_url_getter: Optional[Callable[[str], str]] = None,
_url_updater: Optional[Callable[[], None]] = None,
url: Optional[URL] = None,
cache_path: str = "",
) -> None:
DataBase.__init__(self, timestamp)
RemoteFileMixin.__init__(
self,
remote_path,
_url_getter=_url_getter,
_url_updater=_url_updater,
url=url,
cache_path=cache_path,
)

Expand All @@ -174,10 +173,9 @@ def from_response_body(
cls: Type[_T],
body: Dict[str, Any],
*,
_url_getter: Optional[Callable[[str], str]],
_url_updater: Optional[Callable[[], None]] = None,
cache_path: str = "", # noqa: DAR101
) -> _T:
url: Optional[URL] = None,
cache_path: str = "",
) -> _T: # noqa: DAR101
"""Loads a :class:`RemoteData` object from a response body.
Arguments:
Expand All @@ -198,9 +196,7 @@ def from_response_body(
"SENTENCE": {...}
}
}
_url_getter: The url getter of the remote file.
_url_updater: The url updater of the remote file.
url: The URL instance used to get and update url.
cache_path: The path to store the cache.
Returns:
Expand All @@ -210,8 +206,7 @@ def from_response_body(
data = cls(
body["remotePath"],
timestamp=body.get("timestamp"),
_url_getter=_url_getter,
_url_updater=_url_updater,
url=url,
cache_path=cache_path,
)
data.label._loads(body["label"]) # pylint: disable=protected-access
Expand All @@ -230,13 +225,14 @@ class AuthData(DataBase, RemoteFileMixin):
cloud_path: The cloud file path.
target_remote_path: The file remote path after uploading to tensorbay.
timestamp: The timestamp for the file.
_url_getter: The url getter of the remote file.
url: The URL instance used to get and update url.
Attributes:
path: The cloud file path.
timestamp: The timestamp for the file.
label: The :class:`~tensorbay.label.label.Label` instance that contains
all the label information of the file.
url: The :class:`~tensorbay.utility.file.URL` instance used to get and update url.
"""

Expand All @@ -246,10 +242,10 @@ def __init__(
*,
target_remote_path: Optional[str] = None,
timestamp: Optional[float] = None,
_url_getter: Optional[Callable[[str], str]] = None,
url: Optional[URL] = None,
) -> None:
DataBase.__init__(self, timestamp)
RemoteFileMixin.__init__(self, cloud_path, _url_getter=_url_getter)
RemoteFileMixin.__init__(self, cloud_path, url=url)
self._target_remote_path = target_remote_path

@property
Expand Down
8 changes: 3 additions & 5 deletions tensorbay/dataset/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from tensorbay.client.lazy import LazyPage
from tensorbay.dataset.data import DataBase, RemoteData
from tensorbay.utility import UserMutableMapping
from tensorbay.utility import URL, UserMutableMapping

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -114,12 +114,10 @@ def from_response_body(
frame = cls(frame_id)
for data_contents in body["frame"]:
sensor_name = data_contents["sensorName"]
url = URL.from_getter(lambda s=sensor_name: urls.items[url_index].get()[s], urls.pull)
frame[sensor_name] = RemoteData.from_response_body(
data_contents,
_url_getter=lambda _, s=sensor_name: urls.items[ # type: ignore[misc]
url_index
].get()[s],
_url_updater=urls.pull,
url=url,
cache_path=cache_path,
)
return frame
Expand Down
14 changes: 7 additions & 7 deletions tensorbay/dataset/tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@
import pytest

from tensorbay.dataset.data import Data, RemoteData
from tensorbay.utility import URL

_REMOTE_DATA = {
"remotePath": "test.json",
"timestamp": 1614667532,
"label": {},
"url": "url",
}
url = URL("url", lambda: "url")


class TestData:
Expand Down Expand Up @@ -62,21 +64,19 @@ class TestRemoteData:
def test_init(self):
remote_path = "A/test.json"
timestamp = 1614667532
remote_data = RemoteData(remote_path, timestamp=timestamp, _url_getter=lambda x: x)
remote_data = RemoteData(remote_path, timestamp=timestamp, url=url)
assert remote_data.path == remote_path
assert remote_data.timestamp == timestamp
assert remote_data.get_url() == remote_path
assert remote_data.url.get() == "url"

def test_get_url(self):
remote_data = RemoteData("A/test.josn")
with pytest.raises(ValueError):
remote_data.get_url()
remote_data.open()

def test_from_response_body(self):
data = RemoteData.from_response_body(
_REMOTE_DATA, _url_getter=lambda _: "url", cache_path="cache_path"
)
data = RemoteData.from_response_body(_REMOTE_DATA, url=url, cache_path="cache_path")
assert data.path == _REMOTE_DATA["remotePath"]
assert data.timestamp == _REMOTE_DATA["timestamp"]
assert data.get_url() == "url"
assert data.url.get() == "url"
assert data.cache_path == os.path.join("cache_path", _REMOTE_DATA["remotePath"])
4 changes: 2 additions & 2 deletions tensorbay/dataset/tests/test_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ def test_from_response_body(self):
assert frame.frame_id == _FRAME_ID
assert frame["sensor1"].path == "test1.png"
assert frame["sensor1"].timestamp == 1614945883
assert frame["sensor1"].get_url() == "url1"
assert frame["sensor1"].url.get() == "url1"
assert frame["sensor1"].cache_path == os.path.join("cache_path", "test1.png")
assert frame["sensor2"].path == "test2.png"
assert frame["sensor2"].timestamp == 1614945884
assert frame["sensor2"].get_url() == "url2"
assert frame["sensor2"].url.get() == "url2"
assert frame["sensor2"].cache_path == os.path.join("cache_path", "test2.png")
10 changes: 4 additions & 6 deletions tensorbay/label/label_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@

"""Mask related classes."""

from typing import Any, Callable, Dict, Optional, Type, TypeVar
from typing import Any, Dict, Optional, Type, TypeVar

from tensorbay.label.basic import AttributeType, SubcatalogBase
from tensorbay.label.supports import AttributesMixin, IsTrackingMixin, MaskCategoriesMixin
from tensorbay.utility import FileMixin, RemoteFileMixin, ReprMixin
from tensorbay.utility import URL, FileMixin, RemoteFileMixin, ReprMixin


class SemanticMaskSubcatalog(SubcatalogBase, MaskCategoriesMixin, AttributesMixin):
Expand Down Expand Up @@ -462,11 +462,9 @@ class RemotePanopticMask(PanopticMaskBase, RemoteFileMixin):

_T = TypeVar("_T", bound="RemotePanopticMask")

def __init__(
self, remote_path: str, *, _url_getter: Optional[Callable[[str], str]] = None
) -> None:
def __init__(self, remote_path: str, *, url: Optional[URL] = None) -> None:
PanopticMaskBase.__init__(self)
RemoteFileMixin.__init__(self, remote_path, _url_getter=_url_getter)
RemoteFileMixin.__init__(self, remote_path, url=url)

@classmethod
def from_response_body(cls: Type[_T], body: Dict[str, Any]) -> _T:
Expand Down
3 changes: 2 additions & 1 deletion tensorbay/utility/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
Disable,
KwargsDeprecated,
)
from tensorbay.utility.file import FileMixin, RemoteFileMixin
from tensorbay.utility.file import URL, FileMixin, RemoteFileMixin
from tensorbay.utility.itertools import chunked
from tensorbay.utility.name import NameList, NameMixin, SortedNameList
from tensorbay.utility.repr import ReprMixin, ReprType, repr_config
Expand Down Expand Up @@ -43,6 +43,7 @@
"TypeEnum",
"TypeMixin",
"TypeRegister",
"URL",
"UserMapping",
"UserMutableMapping",
"UserMutableSequence",
Expand Down
Loading

0 comments on commit 7d1fa11

Please sign in to comment.