Skip to content

Commit 3b10147

Browse files
authored
fix category file generation (#5770)
* fix category file generation * revert unrelated change * revert unrelated change
1 parent 9ea341a commit 3b10147

File tree

6 files changed

+8
-9
lines changed

6 files changed

+8
-9
lines changed

torchvision/prototype/datasets/_builtin/country211.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,6 @@ def __len__(self) -> int:
7373
}[self._split]
7474

7575
def _generate_categories(self) -> List[str]:
76-
resources = self.resources()
77-
dp = resources[0].load(self.root)
76+
resources = self._resources()
77+
dp = resources[0].load(self._root)
7878
return sorted({pathlib.Path(path).parent.name for path, _ in dp})

torchvision/prototype/datasets/_builtin/dtd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def _filter_images(self, data: Tuple[str, Any]) -> bool:
135135
return self._classify_archive(data) == DTDDemux.IMAGES
136136

137137
def _generate_categories(self) -> List[str]:
138-
resources = self.resources()
138+
resources = self._resources()
139139

140140
dp = resources[0].load(self._root)
141141
dp = Filter(dp, self._filter_images)

torchvision/prototype/datasets/_builtin/food101.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str,
9696
return Mapper(dp, self._prepare_sample)
9797

9898
def _generate_categories(self) -> List[str]:
99-
resources = self.resources()
99+
resources = self._resources()
100100
dp = resources[0].load(self._root)
101101
dp = Filter(dp, path_comparator("name", "classes.txt"))
102102
dp = LineReader(dp, decode=True, return_path=False)

torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,12 +136,11 @@ def _filter_split_and_classification_anns(self, data: Tuple[str, Any]) -> bool:
136136
return self._classify_anns(data) == OxfordIIITPetDemux.SPLIT_AND_CLASSIFICATION
137137

138138
def _generate_categories(self) -> List[str]:
139-
config = self.default_config
140-
resources = self.resources(config)
139+
resources = self._resources()
141140

142141
dp = resources[1].load(self._root)
143142
dp = Filter(dp, self._filter_split_and_classification_anns)
144-
dp = Filter(dp, path_comparator("name", f"{config.split}.txt"))
143+
dp = Filter(dp, path_comparator("name", "trainval.txt"))
145144
dp = CSVDictParser(dp, fieldnames=("image_id", "label"), delimiter=" ")
146145

147146
raw_categories_and_labels = {(data["image_id"].rsplit("_", 1)[0], data["label"]) for data in dp}

torchvision/prototype/datasets/_builtin/sbd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def _make_datapipe(
121121
return Mapper(dp, self._prepare_sample)
122122

123123
def _generate_categories(self, root: pathlib.Path) -> Tuple[str, ...]:
124-
resources = self.resources(self.default_config)
124+
resources = self._resources(self.default_config)
125125

126126
dp = resources[0].load(root)
127127
dp = Filter(dp, path_comparator("name", "category_names.m"))

torchvision/prototype/datasets/_builtin/voc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def _generate_categories(self) -> List[str]:
218218
resources = self._resources()
219219

220220
archive_dp = resources[0].load(self._root)
221-
dp = Filter(archive_dp, self._filter_detection_anns)
221+
dp = Filter(archive_dp, self._filter_anns)
222222
dp = Mapper(dp, self._parse_detection_ann, input_col=1)
223223

224224
return sorted({instance["name"] for _, anns in dp for instance in anns["object"]})

0 commit comments

Comments
 (0)