 import copy
 import json
 import logging
+import multiprocessing
 import os
 import warnings
 from collections import defaultdict

 )

 from pytorch3d.renderer.cameras import CamerasBase
+from tqdm import tqdm


 _CO3DV2_DATASET_ROOT: str = os.getenv("CO3DV2_DATASET_ROOT", "")
@@ -147,7 +149,8 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase):  # pyre-ignore [13]
         (test frames can repeat across batches).

     Args:
-        category: The object category of the dataset.
+        category: Dataset categories to load expressed as a string of comma-separated
+            category names (e.g. `"apple,car,orange"`).
         subset_name: The name of the dataset subset. For CO3Dv2, these include
             e.g. "manyview_dev_0", "fewview_test", ...
         dataset_root: The root folder of the dataset.
@@ -173,6 +176,7 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase):  # pyre-ignore [13]
     test_on_train: bool = False
     only_test_set: bool = False
     load_eval_batches: bool = True
+    num_load_workers: int = 4

     n_known_frames_for_test: int = 0

@@ -189,11 +193,33 @@ def __post_init__(self):
         if self.only_test_set and self.test_on_train:
             raise ValueError("Cannot have only_test_set and test_on_train")

-        frame_file = os.path.join(
-            self.dataset_root, self.category, "frame_annotations.jgz"
-        )
+        if "," in self.category:
+            # a comma-separated list of categories to load
+            categories = [c.strip() for c in self.category.split(",")]
+            logger.info(f"Loading a list of categories: {str(categories)}.")
+            with multiprocessing.Pool(
+                processes=min(self.num_load_workers, len(categories))
+            ) as pool:
+                category_dataset_maps = list(
+                    tqdm(
+                        pool.imap(self._load_category, categories),
+                        total=len(categories),
+                    )
+                )
+            dataset_map = category_dataset_maps[0]
+            dataset_map.join(category_dataset_maps[1:])
+
+        else:
+            # one category to load
+            dataset_map = self._load_category(self.category)
+
+        self.dataset_map = dataset_map
+
+    def _load_category(self, category: str) -> DatasetMap:
+
+        frame_file = os.path.join(self.dataset_root, category, "frame_annotations.jgz")
         sequence_file = os.path.join(
-            self.dataset_root, self.category, "sequence_annotations.jgz"
+            self.dataset_root, category, "sequence_annotations.jgz"
         )

         path_manager = self.path_manager_factory.get()
@@ -232,7 +258,7 @@ def __post_init__(self):

         dataset = dataset_type(**common_dataset_kwargs)

-        available_subset_names = self._get_available_subset_names()
+        available_subset_names = self._get_available_subset_names(category)
         logger.debug(f"Available subset names: {str(available_subset_names)}.")
         if self.subset_name not in available_subset_names:
             raise ValueError(
@@ -242,20 +268,20 @@ def __post_init__(self):

         # load the list of train/val/test frames
         subset_mapping = self._load_annotation_json(
-            os.path.join(
-                self.category, "set_lists", f"set_lists_{self.subset_name}.json"
-            )
+            os.path.join(category, "set_lists", f"set_lists_{self.subset_name}.json")
         )

         # load the evaluation batches
         if self.load_eval_batches:
             eval_batch_index = self._load_annotation_json(
                 os.path.join(
-                    self.category,
+                    category,
                     "eval_batches",
                     f"eval_batches_{self.subset_name}.json",
                 )
             )
+        else:
+            eval_batch_index = None

         train_dataset = None
         if not self.only_test_set:
@@ -313,9 +339,7 @@ def __post_init__(self):
             )
             logger.info(f"# eval batches: {len(test_dataset.eval_batches)}")

-        self.dataset_map = DatasetMap(
-            train=train_dataset, val=val_dataset, test=test_dataset
-        )
+        return DatasetMap(train=train_dataset, val=val_dataset, test=test_dataset)

     @classmethod
     def dataset_tweak_args(cls, type, args: DictConfig) -> None:
@@ -381,10 +405,10 @@ def _load_annotation_json(self, json_filename: str):
             data = json.load(f)
         return data

-    def _get_available_subset_names(self):
+    def _get_available_subset_names(self, category: str):
         return get_available_subset_names(
             self.dataset_root,
-            self.category,
+            category,
             path_manager=self.path_manager_factory.get(),
         )

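Below is a minimal usage sketch (not part of the diff above) showing how the new comma-separated `category` string and the `num_load_workers` field could be used together. The module path, the `expand_args_fields` call, and the `get_dataset_map()` accessor follow the usual Implicitron conventions but are assumptions here, as are the dataset root and subset name; adjust them to a local CO3Dv2 setup.

```python
# Hypothetical usage sketch: construct the provider with several categories so that
# __post_init__ loads them in a multiprocessing.Pool of num_load_workers workers.
import os

from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import (
    JsonIndexDatasetMapProviderV2,
)
from pytorch3d.implicitron.tools.config import expand_args_fields

# Implicitron Configurables need their args expanded before direct construction
# (assumption: the provider is built directly here rather than from a config file).
expand_args_fields(JsonIndexDatasetMapProviderV2)

provider = JsonIndexDatasetMapProviderV2(
    category="apple,car,orange",  # comma-separated -> categories are loaded in parallel
    subset_name="fewview_test",   # example subset name; must exist under set_lists/
    dataset_root=os.getenv("CO3DV2_DATASET_ROOT", "/path/to/co3dv2"),
    num_load_workers=4,           # size of the multiprocessing pool
)

# The per-category DatasetMaps are joined into a single map with train/val/test members.
dataset_map = provider.get_dataset_map()
```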