5
5
# LICENSE file in the root directory of this source tree.
6
6
7
7
8
+ import copy
8
9
import json
9
10
import logging
10
11
import os
11
12
import warnings
12
- from typing import Dict , List , Optional , Tuple , Type
13
+ from collections import defaultdict
14
+ from typing import Dict , List , Optional , Tuple , Type , Union
15
+
16
+ import numpy as np
13
17
14
18
from omegaconf import DictConfig
15
19
from pytorch3d .implicitron .dataset .dataset_map_provider import (
@@ -152,6 +156,9 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
152
156
only_test_set: Load only the test set. Incompatible with `test_on_train`.
153
157
load_eval_batches: Load the file containing eval batches pointing to the
154
158
test dataset.
159
+ n_known_frames_for_test: Add a certain number of known frames to each
160
+ eval batch. Useful for evaluating models that require
161
+ source views as input (e.g. NeRF-WCE / PixelNeRF).
155
162
dataset_args: Specifies additional arguments to the
156
163
JsonIndexDataset constructor call.
157
164
path_manager_factory: (Optional) An object that generates an instance of
@@ -167,6 +174,8 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
167
174
only_test_set : bool = False
168
175
load_eval_batches : bool = True
169
176
177
+ n_known_frames_for_test : int = 0
178
+
170
179
dataset_class_type : str = "JsonIndexDataset"
171
180
dataset : JsonIndexDataset
172
181
@@ -264,6 +273,18 @@ def __post_init__(self):
264
273
val_dataset = dataset .subset_from_frame_index (subset_mapping ["val" ])
265
274
logger .info (f"Val dataset: { str (val_dataset )} " )
266
275
logger .debug ("Extracting test dataset." )
276
+
277
+ if (self .n_known_frames_for_test > 0 ) and self .load_eval_batches :
278
+ # extend the test subset mapping and the dataset with additional
279
+ # known views from the train dataset
280
+ (
281
+ eval_batch_index ,
282
+ subset_mapping ["test" ],
283
+ ) = self ._extend_test_data_with_known_views (
284
+ subset_mapping ,
285
+ eval_batch_index ,
286
+ )
287
+
267
288
test_dataset = dataset .subset_from_frame_index (subset_mapping ["test" ])
268
289
logger .info (f"Test dataset: { str (test_dataset )} " )
269
290
if self .load_eval_batches :
@@ -369,6 +390,40 @@ def _get_available_subset_names(self):
369
390
dataset_root = self .dataset_root
370
391
return get_available_subset_names (dataset_root , self .category )
371
392
393
+ def _extend_test_data_with_known_views (
394
+ self ,
395
+ subset_mapping : Dict [str , List [Union [Tuple [str , int ], Tuple [str , int , str ]]]],
396
+ eval_batch_index : List [List [Union [Tuple [str , int , str ], Tuple [str , int ]]]],
397
+ ):
398
+ # convert the train subset mapping to a dict:
399
+ # sequence_to_train_frames: {sequence_name: frame_index}
400
+ sequence_to_train_frames = defaultdict (list )
401
+ for frame_entry in subset_mapping ["train" ]:
402
+ sequence_name = frame_entry [0 ]
403
+ sequence_to_train_frames [sequence_name ].append (frame_entry )
404
+ sequence_to_train_frames = dict (sequence_to_train_frames )
405
+ test_subset_mapping_set = {tuple (s ) for s in subset_mapping ["test" ]}
406
+
407
+ # extend the eval batches / subset mapping with the additional examples
408
+ eval_batch_index_out = copy .deepcopy (eval_batch_index )
409
+ generator = np .random .default_rng (seed = 0 )
410
+ for batch in eval_batch_index_out :
411
+ sequence_name = batch [0 ][0 ]
412
+ sequence_known_entries = sequence_to_train_frames [sequence_name ]
413
+ idx_to_add = generator .permutation (len (sequence_known_entries ))[
414
+ : self .n_known_frames_for_test
415
+ ]
416
+ entries_to_add = [sequence_known_entries [a ] for a in idx_to_add ]
417
+ assert all (e in subset_mapping ["train" ] for e in entries_to_add )
418
+
419
+ # extend the eval batch with the known views
420
+ batch .extend (entries_to_add )
421
+
422
+ # also add these new entries to the test subset mapping
423
+ test_subset_mapping_set .update (tuple (e ) for e in entries_to_add )
424
+
425
+ return eval_batch_index_out , list (test_subset_mapping_set )
426
+
372
427
373
428
def get_available_subset_names (dataset_root : str , category : str ) -> List [str ]:
374
429
"""
0 commit comments