-
Notifications
You must be signed in to change notification settings - Fork 6.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
migrate country211 prototype dataset (#5753)
- Loading branch information
Showing
2 changed files
with
50 additions
and
33 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,56 +1,78 @@ | ||
import pathlib | ||
from typing import Any, Dict, List, Tuple | ||
from typing import Any, Dict, List, Tuple, Union | ||
|
||
from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter | ||
from torchvision.prototype.datasets.utils import Dataset, DatasetConfig, DatasetInfo, HttpResource, OnlineResource | ||
from torchvision.prototype.datasets.utils._internal import path_comparator, hint_sharding, hint_shuffling | ||
from torchvision.prototype.datasets.utils import Dataset2, DatasetInfo, HttpResource, OnlineResource | ||
from torchvision.prototype.datasets.utils._internal import path_comparator, hint_sharding, hint_shuffling, BUILTIN_DIR | ||
from torchvision.prototype.features import EncodedImage, Label | ||
|
||
from .._api import register_dataset, register_info | ||
|
||
class Country211(Dataset): | ||
def _make_info(self) -> DatasetInfo: | ||
return DatasetInfo( | ||
"country211", | ||
homepage="https://github.com/openai/CLIP/blob/main/data/country211.md", | ||
valid_options=dict(split=("train", "val", "test")), | ||
) | ||
NAME = "country211" | ||
|
||
CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories")) | ||
|
||
|
||
@register_info(NAME) | ||
def _info() -> Dict[str, Any]: | ||
return dict(categories=CATEGORIES) | ||
|
||
|
||
@register_dataset(NAME) | ||
class Country211(Dataset2): | ||
""" | ||
- **homepage**: https://github.com/openai/CLIP/blob/main/data/country211.md | ||
""" | ||
|
||
def resources(self, config: DatasetConfig) -> List[OnlineResource]: | ||
def __init__( | ||
self, | ||
root: Union[str, pathlib.Path], | ||
*, | ||
split: str = "train", | ||
skip_integrity_check: bool = False, | ||
) -> None: | ||
self._split = self._verify_str_arg(split, "split", ("train", "val", "test")) | ||
self._split_folder_name = "valid" if split == "val" else split | ||
|
||
self._categories = _info()["categories"] | ||
|
||
super().__init__(root, skip_integrity_check=skip_integrity_check) | ||
|
||
def _resources(self) -> List[OnlineResource]: | ||
return [ | ||
HttpResource( | ||
"https://openaipublic.azureedge.net/clip/data/country211.tgz", | ||
sha256="c011343cdc1296a8c31ff1d7129cf0b5e5b8605462cffd24f89266d6e6f4da3c", | ||
) | ||
] | ||
|
||
_SPLIT_NAME_MAPPER = { | ||
"train": "train", | ||
"val": "valid", | ||
"test": "test", | ||
} | ||
|
||
def _prepare_sample(self, data: Tuple[str, Any]) -> Dict[str, Any]: | ||
path, buffer = data | ||
category = pathlib.Path(path).parent.name | ||
return dict( | ||
label=Label.from_category(category, categories=self.categories), | ||
label=Label.from_category(category, categories=self._categories), | ||
path=path, | ||
image=EncodedImage.from_file(buffer), | ||
) | ||
|
||
def _filter_split(self, data: Tuple[str, Any], *, split: str) -> bool: | ||
return pathlib.Path(data[0]).parent.parent.name == split | ||
|
||
def _make_datapipe( | ||
self, resource_dps: List[IterDataPipe], *, config: DatasetConfig | ||
) -> IterDataPipe[Dict[str, Any]]: | ||
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: | ||
dp = resource_dps[0] | ||
dp = Filter(dp, path_comparator("parent.parent.name", self._SPLIT_NAME_MAPPER[config.split])) | ||
dp = Filter(dp, path_comparator("parent.parent.name", self._split_folder_name)) | ||
dp = hint_shuffling(dp) | ||
dp = hint_sharding(dp) | ||
return Mapper(dp, self._prepare_sample) | ||
|
||
def _generate_categories(self, root: pathlib.Path) -> List[str]: | ||
resources = self.resources(self.default_config) | ||
dp = resources[0].load(root) | ||
def __len__(self) -> int: | ||
return { | ||
"train": 31_650, | ||
"val": 10_550, | ||
"test": 21_100, | ||
}[self._split] | ||
|
||
def _generate_categories(self) -> List[str]: | ||
resources = self.resources() | ||
dp = resources[0].load(self.root) | ||
return sorted({pathlib.Path(path).parent.name for path, _ in dp}) |