Skip to content

Commit 2ff2c7c

Browse files
davnov134facebook-github-bot
authored andcommitted
Enable additional test-time source views for json dataset provider v2
Summary: Adds additional source views to the eval batches for evaluating many-view models on CO3D Challenge Reviewed By: bottler Differential Revision: D38705904 fbshipit-source-id: cf7d00dc7db926fbd1656dd97a729674e9ff5adb
1 parent e8616cc commit 2ff2c7c

File tree

4 files changed

+102
-23
lines changed

4 files changed

+102
-23
lines changed

Diff for: projects/implicitron_trainer/tests/experiment.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ data_source_ImplicitronDataSource_args:
6161
test_on_train: false
6262
only_test_set: false
6363
load_eval_batches: true
64+
n_known_frames_for_test: 0
6465
dataset_class_type: JsonIndexDataset
6566
path_manager_factory_class_type: PathManagerFactory
6667
dataset_JsonIndexDataset_args:

Diff for: pytorch3d/implicitron/dataset/json_index_dataset_map_provider_v2.py

+56-1
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,15 @@
55
# LICENSE file in the root directory of this source tree.
66

77

8+
import copy
89
import json
910
import logging
1011
import os
1112
import warnings
12-
from typing import Dict, List, Optional, Tuple, Type
13+
from collections import defaultdict
14+
from typing import Dict, List, Optional, Tuple, Type, Union
15+
16+
import numpy as np
1317

1418
from omegaconf import DictConfig
1519
from pytorch3d.implicitron.dataset.dataset_map_provider import (
@@ -152,6 +156,9 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
152156
only_test_set: Load only the test set. Incompatible with `test_on_train`.
153157
load_eval_batches: Load the file containing eval batches pointing to the
154158
test dataset.
159+
n_known_frames_for_test: Add a certain number of known frames to each
160+
eval batch. Useful for evaluating models that require
161+
source views as input (e.g. NeRF-WCE / PixelNeRF).
155162
dataset_args: Specifies additional arguments to the
156163
JsonIndexDataset constructor call.
157164
path_manager_factory: (Optional) An object that generates an instance of
@@ -167,6 +174,8 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
167174
only_test_set: bool = False
168175
load_eval_batches: bool = True
169176

177+
n_known_frames_for_test: int = 0
178+
170179
dataset_class_type: str = "JsonIndexDataset"
171180
dataset: JsonIndexDataset
172181

@@ -264,6 +273,18 @@ def __post_init__(self):
264273
val_dataset = dataset.subset_from_frame_index(subset_mapping["val"])
265274
logger.info(f"Val dataset: {str(val_dataset)}")
266275
logger.debug("Extracting test dataset.")
276+
277+
if (self.n_known_frames_for_test > 0) and self.load_eval_batches:
278+
# extend the test subset mapping and the dataset with additional
279+
# known views from the train dataset
280+
(
281+
eval_batch_index,
282+
subset_mapping["test"],
283+
) = self._extend_test_data_with_known_views(
284+
subset_mapping,
285+
eval_batch_index,
286+
)
287+
267288
test_dataset = dataset.subset_from_frame_index(subset_mapping["test"])
268289
logger.info(f"Test dataset: {str(test_dataset)}")
269290
if self.load_eval_batches:
@@ -369,6 +390,40 @@ def _get_available_subset_names(self):
369390
dataset_root = self.dataset_root
370391
return get_available_subset_names(dataset_root, self.category)
371392

393+
def _extend_test_data_with_known_views(
394+
self,
395+
subset_mapping: Dict[str, List[Union[Tuple[str, int], Tuple[str, int, str]]]],
396+
eval_batch_index: List[List[Union[Tuple[str, int, str], Tuple[str, int]]]],
397+
):
398+
# convert the train subset mapping to a dict:
399+
# sequence_to_train_frames: {sequence_name: frame_index}
400+
sequence_to_train_frames = defaultdict(list)
401+
for frame_entry in subset_mapping["train"]:
402+
sequence_name = frame_entry[0]
403+
sequence_to_train_frames[sequence_name].append(frame_entry)
404+
sequence_to_train_frames = dict(sequence_to_train_frames)
405+
test_subset_mapping_set = {tuple(s) for s in subset_mapping["test"]}
406+
407+
# extend the eval batches / subset mapping with the additional examples
408+
eval_batch_index_out = copy.deepcopy(eval_batch_index)
409+
generator = np.random.default_rng(seed=0)
410+
for batch in eval_batch_index_out:
411+
sequence_name = batch[0][0]
412+
sequence_known_entries = sequence_to_train_frames[sequence_name]
413+
idx_to_add = generator.permutation(len(sequence_known_entries))[
414+
: self.n_known_frames_for_test
415+
]
416+
entries_to_add = [sequence_known_entries[a] for a in idx_to_add]
417+
assert all(e in subset_mapping["train"] for e in entries_to_add)
418+
419+
# extend the eval batch with the known views
420+
batch.extend(entries_to_add)
421+
422+
# also add these new entries to the test subset mapping
423+
test_subset_mapping_set.update(tuple(e) for e in entries_to_add)
424+
425+
return eval_batch_index_out, list(test_subset_mapping_set)
426+
372427

373428
def get_available_subset_names(dataset_root: str, category: str) -> List[str]:
374429
"""

Diff for: tests/implicitron/data/data_source.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ dataset_map_provider_JsonIndexDatasetMapProviderV2_args:
4949
test_on_train: false
5050
only_test_set: false
5151
load_eval_batches: true
52+
n_known_frames_for_test: 0
5253
dataset_class_type: JsonIndexDataset
5354
path_manager_factory_class_type: PathManagerFactory
5455
dataset_JsonIndexDataset_args:

Diff for: tests/implicitron/test_json_index_dataset_provider_v2.py

+44-22
Original file line numberDiff line numberDiff line change
@@ -37,29 +37,47 @@ def test_random_dataset(self):
3737
expand_args_fields(JsonIndexDatasetMapProviderV2)
3838
categories = ["A", "B"]
3939
subset_name = "test"
40+
eval_batch_size = 5
4041
with tempfile.TemporaryDirectory() as tmpd:
41-
_make_random_json_dataset_map_provider_v2_data(tmpd, categories)
42-
for category in categories:
43-
dataset_provider = JsonIndexDatasetMapProviderV2(
44-
category=category,
45-
subset_name="test",
46-
dataset_root=tmpd,
47-
)
48-
dataset_map = dataset_provider.get_dataset_map()
49-
for set_ in ["train", "val", "test"]:
50-
dataloader = torch.utils.data.DataLoader(
51-
getattr(dataset_map, set_),
52-
batch_size=3,
53-
shuffle=True,
54-
collate_fn=FrameData.collate,
42+
_make_random_json_dataset_map_provider_v2_data(
43+
tmpd,
44+
categories,
45+
eval_batch_size=eval_batch_size,
46+
)
47+
for n_known_frames_for_test in [0, 2]:
48+
for category in categories:
49+
dataset_provider = JsonIndexDatasetMapProviderV2(
50+
category=category,
51+
subset_name="test",
52+
dataset_root=tmpd,
53+
n_known_frames_for_test=n_known_frames_for_test,
5554
)
56-
for _ in dataloader:
57-
pass
58-
category_to_subset_list = (
59-
dataset_provider.get_category_to_subset_name_list()
60-
)
61-
category_to_subset_list_ = {c: [subset_name] for c in categories}
62-
self.assertTrue(category_to_subset_list == category_to_subset_list_)
55+
dataset_map = dataset_provider.get_dataset_map()
56+
for set_ in ["train", "val", "test"]:
57+
if set_ in ["train", "val"]:
58+
dataloader = torch.utils.data.DataLoader(
59+
getattr(dataset_map, set_),
60+
batch_size=3,
61+
shuffle=True,
62+
collate_fn=FrameData.collate,
63+
)
64+
else:
65+
dataloader = torch.utils.data.DataLoader(
66+
getattr(dataset_map, set_),
67+
batch_sampler=dataset_map[set_].get_eval_batches(),
68+
collate_fn=FrameData.collate,
69+
)
70+
for batch in dataloader:
71+
if set_ == "test":
72+
self.assertTrue(
73+
batch.image_rgb.shape[0]
74+
== n_known_frames_for_test + eval_batch_size
75+
)
76+
category_to_subset_list = (
77+
dataset_provider.get_category_to_subset_name_list()
78+
)
79+
category_to_subset_list_ = {c: [subset_name] for c in categories}
80+
self.assertTrue(category_to_subset_list == category_to_subset_list_)
6381

6482

6583
def _make_random_json_dataset_map_provider_v2_data(
@@ -70,6 +88,7 @@ def _make_random_json_dataset_map_provider_v2_data(
7088
H: int = 50,
7189
W: int = 30,
7290
subset_name: str = "test",
91+
eval_batch_size: int = 5,
7392
):
7493
os.makedirs(root, exist_ok=True)
7594
category_to_subset_list = {}
@@ -142,7 +161,10 @@ def _make_random_json_dataset_map_provider_v2_data(
142161
with open(set_list_file, "w") as f:
143162
json.dump(set_list, f)
144163

145-
eval_batches = [random.sample(test_frame_index, 5) for _ in range(10)]
164+
eval_batches = [
165+
random.sample(test_frame_index, eval_batch_size) for _ in range(10)
166+
]
167+
146168
eval_b_dir = os.path.join(root, category, "eval_batches")
147169
os.makedirs(eval_b_dir, exist_ok=True)
148170
eval_b_file = os.path.join(eval_b_dir, f"eval_batches_{subset_name}.json")

0 commit comments

Comments
 (0)