Skip to content

Commit e4a3298

Browse files
davnov134 authored and facebook-github-bot committed
CO3Dv2 multi-category extension
Summary: Allows loading of multiple categories. Multiple categories are provided in a comma-separated list of category names. Reviewed By: bottler, shapovalov Differential Revision: D40803297 fbshipit-source-id: 863938be3aa6ffefe9e563aede4a2e9e66aeeaa8
1 parent c54e048 commit e4a3298

File tree

9 files changed

+272
-25
lines changed

9 files changed

+272
-25
lines changed

projects/implicitron_trainer/tests/experiment.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ data_source_ImplicitronDataSource_args:
6262
test_on_train: false
6363
only_test_set: false
6464
load_eval_batches: true
65+
num_load_workers: 4
6566
n_known_frames_for_test: 0
6667
dataset_class_type: JsonIndexDataset
6768
path_manager_factory_class_type: PathManagerFactory

pytorch3d/implicitron/dataset/dataset_base.py

+23
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import (
1010
Any,
1111
ClassVar,
12+
Dict,
1213
Iterable,
1314
Iterator,
1415
List,
@@ -259,6 +260,12 @@ def get_frame_numbers_and_timestamps(
259260
"""
260261
raise ValueError("This dataset does not contain videos.")
261262

263+
def join(self, other_datasets: Iterable["DatasetBase"]) -> None:
264+
"""
265+
Joins the current dataset with a list of other datasets of the same type.
266+
"""
267+
raise NotImplementedError()
268+
262269
def get_eval_batches(self) -> Optional[List[List[int]]]:
263270
return None
264271

@@ -267,6 +274,22 @@ def sequence_names(self) -> Iterable[str]:
267274
# pyre-ignore[16]
268275
return self._seq_to_idx.keys()
269276

277+
def category_to_sequence_names(self) -> Dict[str, List[str]]:
278+
"""
279+
Returns a dict mapping from each dataset category to a list of its
280+
sequence names.
281+
282+
Returns:
283+
category_to_sequence_names: Dict {category_i: [..., sequence_name_j, ...]}
284+
"""
285+
c2seq = defaultdict(list)
286+
for sequence_name in self.sequence_names():
287+
first_frame_idx = next(self.sequence_indices_in_order(sequence_name))
288+
# crashes without overriding __getitem__
289+
sequence_category = self[first_frame_idx].sequence_category
290+
c2seq[sequence_category].append(sequence_name)
291+
return dict(c2seq)
292+
270293
def sequence_frames_in_order(
271294
self, seq_name: str
272295
) -> Iterator[Tuple[float, int, int]]:

pytorch3d/implicitron/dataset/dataset_map_provider.py

+29-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import logging
88
import os
99
from dataclasses import dataclass
10-
from typing import Iterator, Optional
10+
from typing import Iterable, Iterator, Optional
1111

1212
from iopath.common.file_io import PathManager
1313
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
@@ -51,6 +51,34 @@ def iter_datasets(self) -> Iterator[DatasetBase]:
5151
if self.test is not None:
5252
yield self.test
5353

54+
def join(self, other_dataset_maps: Iterable["DatasetMap"]) -> None:
55+
"""
56+
Joins the current DatasetMap with other dataset maps from the input list.
57+
58+
For each subset of each dataset map (train/val/test), the function
59+
omits joining the subsets that are None.
60+
61+
Note the train/val/test datasets of the current dataset map will be
62+
modified in-place.
63+
64+
Args:
65+
other_dataset_maps: The list of dataset maps to be joined into the
66+
current dataset map.
67+
"""
68+
for set_ in ["train", "val", "test"]:
69+
dataset_list = [
70+
getattr(self, set_),
71+
*[getattr(dmap, set_) for dmap in other_dataset_maps],
72+
]
73+
dataset_list = [d for d in dataset_list if d is not None]
74+
if len(dataset_list) == 0:
75+
setattr(self, set_, None)
76+
continue
77+
d0 = dataset_list[0]
78+
if len(dataset_list) > 1:
79+
d0.join(dataset_list[1:])
80+
setattr(self, set_, d0)
81+
5482

5583
class DatasetMapProviderBase(ReplaceableBase):
5684
"""

pytorch3d/implicitron/dataset/json_index_dataset.py

+70-2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from typing import (
2020
Any,
2121
ClassVar,
22+
Dict,
23+
Iterable,
2224
List,
2325
Optional,
2426
Sequence,
@@ -188,7 +190,44 @@ def _extract_and_set_eval_batches(self):
188190
self.eval_batch_index
189191
)
190192

191-
def is_filtered(self):
193+
def join(self, other_datasets: Iterable[DatasetBase]) -> None:
194+
"""
195+
Join the dataset with other JsonIndexDataset objects.
196+
197+
Args:
198+
other_datasets: A list of JsonIndexDataset objects to be joined
199+
into the current dataset.
200+
"""
201+
if not all(isinstance(d, JsonIndexDataset) for d in other_datasets):
202+
raise ValueError("This function can only join a list of JsonIndexDataset")
203+
# pyre-ignore[16]
204+
self.frame_annots.extend([fa for d in other_datasets for fa in d.frame_annots])
205+
# pyre-ignore[16]
206+
self.seq_annots.update(
207+
# https://gist.github.com/treyhunner/f35292e676efa0be1728
208+
functools.reduce(
209+
lambda a, b: {**a, **b},
210+
[d.seq_annots for d in other_datasets], # pyre-ignore[16]
211+
)
212+
)
213+
all_eval_batches = [
214+
self.eval_batches,
215+
# pyre-ignore
216+
*[d.eval_batches for d in other_datasets],
217+
]
218+
if not (
219+
all(ba is None for ba in all_eval_batches)
220+
or all(ba is not None for ba in all_eval_batches)
221+
):
222+
raise ValueError(
223+
"When joining datasets, either all joined datasets have to have their"
224+
" eval_batches defined, or all should have their eval batches undefined."
225+
)
226+
if self.eval_batches is not None:
227+
self.eval_batches = sum(all_eval_batches, [])
228+
self._invalidate_indexes(filter_seq_annots=True)
229+
230+
def is_filtered(self) -> bool:
192231
"""
193232
Returns `True` in case the dataset has been filtered and thus some frame annotations
194233
stored on the disk might be missing in the dataset object.
@@ -211,6 +250,7 @@ def seq_frame_index_to_dataset_index(
211250
seq_frame_index: List[List[Union[Tuple[str, int, str], Tuple[str, int]]]],
212251
allow_missing_indices: bool = False,
213252
remove_missing_indices: bool = False,
253+
suppress_missing_index_warning: bool = True,
214254
) -> List[List[Union[Optional[int], int]]]:
215255
"""
216256
Obtain indices into the dataset object given a list of frame ids.
@@ -228,6 +268,11 @@ def seq_frame_index_to_dataset_index(
228268
If `False`, returns `None` in place of `seq_frame_index` entries that
229269
are not present in the dataset.
230270
If `True` removes missing indices from the returned indices.
271+
suppress_missing_index_warning:
272+
Active if `allow_missing_indices==True`. Suppresses a warning message
273+
in case an entry from `seq_frame_index` is missing in the dataset
274+
(expected in certain cases - e.g. when setting
275+
`self.remove_empty_masks=True`).
231276
232277
Returns:
233278
dataset_idx: Indices of dataset entries corresponding to`seq_frame_index`.
@@ -254,7 +299,8 @@ def _get_dataset_idx(
254299
)
255300
if not allow_missing_indices:
256301
raise IndexError(msg)
257-
warnings.warn(msg)
302+
if not suppress_missing_index_warning:
303+
warnings.warn(msg)
258304
return idx
259305
if path is not None:
260306
# Check that the loaded frame path is consistent
@@ -288,6 +334,21 @@ def subset_from_frame_index(
288334
frame_index: List[Union[Tuple[str, int], Tuple[str, int, str]]],
289335
allow_missing_indices: bool = True,
290336
) -> "JsonIndexDataset":
337+
"""
338+
Generate a dataset subset given the list of frames specified in `frame_index`.
339+
340+
Args:
341+
frame_index: The list of frame identifiers (as stored in the metadata)
342+
specified as `List[Tuple[sequence_name:str, frame_number:int]]`. Optionally,
343+
Image paths relative to the dataset_root can be specified as well:
344+
`List[Tuple[sequence_name:str, frame_number:int, image_path:str]]`,
345+
in the latter case, if image_path does not match the stored paths, an error
346+
is raised.
347+
allow_missing_indices: If `False`, throws an IndexError upon reaching the first
348+
entry from `frame_index` which is missing in the dataset.
349+
Otherwise, generates a subset consisting of frames entries that actually
350+
exist in the dataset.
351+
"""
291352
# Get the indices into the frame annots.
292353
dataset_indices = self.seq_frame_index_to_dataset_index(
293354
[frame_index],
@@ -838,6 +899,13 @@ def get_frame_numbers_and_timestamps(
838899
)
839900
return out
840901

902+
def category_to_sequence_names(self) -> Dict[str, List[str]]:
903+
c2seq = defaultdict(list)
904+
# pyre-ignore
905+
for sequence_name, sa in self.seq_annots.items():
906+
c2seq[sa.category].append(sequence_name)
907+
return dict(c2seq)
908+
841909
def get_eval_batches(self) -> Optional[List[List[int]]]:
842910
return self.eval_batches
843911

pytorch3d/implicitron/dataset/json_index_dataset_map_provider_v2.py

+39-15
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import copy
99
import json
1010
import logging
11+
import multiprocessing
1112
import os
1213
import warnings
1314
from collections import defaultdict
@@ -30,6 +31,7 @@
3031
)
3132

3233
from pytorch3d.renderer.cameras import CamerasBase
34+
from tqdm import tqdm
3335

3436

3537
_CO3DV2_DATASET_ROOT: str = os.getenv("CO3DV2_DATASET_ROOT", "")
@@ -147,7 +149,8 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
147149
(test frames can repeat across batches).
148150
149151
Args:
150-
category: The object category of the dataset.
152+
category: Dataset categories to load expressed as a string of comma-separated
153+
category names (e.g. `"apple,car,orange"`).
151154
subset_name: The name of the dataset subset. For CO3Dv2, these include
152155
e.g. "manyview_dev_0", "fewview_test", ...
153156
dataset_root: The root folder of the dataset.
@@ -173,6 +176,7 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
173176
test_on_train: bool = False
174177
only_test_set: bool = False
175178
load_eval_batches: bool = True
179+
num_load_workers: int = 4
176180

177181
n_known_frames_for_test: int = 0
178182

@@ -189,11 +193,33 @@ def __post_init__(self):
189193
if self.only_test_set and self.test_on_train:
190194
raise ValueError("Cannot have only_test_set and test_on_train")
191195

192-
frame_file = os.path.join(
193-
self.dataset_root, self.category, "frame_annotations.jgz"
194-
)
196+
if "," in self.category:
197+
# a comma-separated list of categories to load
198+
categories = [c.strip() for c in self.category.split(",")]
199+
logger.info(f"Loading a list of categories: {str(categories)}.")
200+
with multiprocessing.Pool(
201+
processes=min(self.num_load_workers, len(categories))
202+
) as pool:
203+
category_dataset_maps = list(
204+
tqdm(
205+
pool.imap(self._load_category, categories),
206+
total=len(categories),
207+
)
208+
)
209+
dataset_map = category_dataset_maps[0]
210+
dataset_map.join(category_dataset_maps[1:])
211+
212+
else:
213+
# one category to load
214+
dataset_map = self._load_category(self.category)
215+
216+
self.dataset_map = dataset_map
217+
218+
def _load_category(self, category: str) -> DatasetMap:
219+
220+
frame_file = os.path.join(self.dataset_root, category, "frame_annotations.jgz")
195221
sequence_file = os.path.join(
196-
self.dataset_root, self.category, "sequence_annotations.jgz"
222+
self.dataset_root, category, "sequence_annotations.jgz"
197223
)
198224

199225
path_manager = self.path_manager_factory.get()
@@ -232,7 +258,7 @@ def __post_init__(self):
232258

233259
dataset = dataset_type(**common_dataset_kwargs)
234260

235-
available_subset_names = self._get_available_subset_names()
261+
available_subset_names = self._get_available_subset_names(category)
236262
logger.debug(f"Available subset names: {str(available_subset_names)}.")
237263
if self.subset_name not in available_subset_names:
238264
raise ValueError(
@@ -242,20 +268,20 @@ def __post_init__(self):
242268

243269
# load the list of train/val/test frames
244270
subset_mapping = self._load_annotation_json(
245-
os.path.join(
246-
self.category, "set_lists", f"set_lists_{self.subset_name}.json"
247-
)
271+
os.path.join(category, "set_lists", f"set_lists_{self.subset_name}.json")
248272
)
249273

250274
# load the evaluation batches
251275
if self.load_eval_batches:
252276
eval_batch_index = self._load_annotation_json(
253277
os.path.join(
254-
self.category,
278+
category,
255279
"eval_batches",
256280
f"eval_batches_{self.subset_name}.json",
257281
)
258282
)
283+
else:
284+
eval_batch_index = None
259285

260286
train_dataset = None
261287
if not self.only_test_set:
@@ -313,9 +339,7 @@ def __post_init__(self):
313339
)
314340
logger.info(f"# eval batches: {len(test_dataset.eval_batches)}")
315341

316-
self.dataset_map = DatasetMap(
317-
train=train_dataset, val=val_dataset, test=test_dataset
318-
)
342+
return DatasetMap(train=train_dataset, val=val_dataset, test=test_dataset)
319343

320344
@classmethod
321345
def dataset_tweak_args(cls, type, args: DictConfig) -> None:
@@ -381,10 +405,10 @@ def _load_annotation_json(self, json_filename: str):
381405
data = json.load(f)
382406
return data
383407

384-
def _get_available_subset_names(self):
408+
def _get_available_subset_names(self, category: str):
385409
return get_available_subset_names(
386410
self.dataset_root,
387-
self.category,
411+
category,
388412
path_manager=self.path_manager_factory.get(),
389413
)
390414

0 commit comments

Comments
 (0)