diff --git a/torchvision/prototype/datapoints/_dataset_wrapper.py b/torchvision/prototype/datapoints/_dataset_wrapper.py
index e60d61e5f90..dc4578c49f4 100644
--- a/torchvision/prototype/datapoints/_dataset_wrapper.py
+++ b/torchvision/prototype/datapoints/_dataset_wrapper.py
@@ -74,7 +74,7 @@ def __getitem__(self, idx):
         # of this class
         sample = self._dataset[idx]
 
-        sample = self._wrapper(sample)
+        sample = self._wrapper(idx, sample)
 
         # Regardless of whether the user has supplied the transforms individually (`transform` and `target_transform`)
         # or joint (`transforms`), we can access the full functionality through `transforms`
@@ -125,7 +125,10 @@ def wrap_target_by_type(target, *, target_types, type_wrappers):
 
 
 def classification_wrapper_factory(dataset):
-    return identity
+    def wrapper(idx, sample):
+        return sample
+
+    return wrapper
 
 
 for dataset_cls in [
@@ -143,7 +146,7 @@ def classification_wrapper_factory(dataset):
 
 
 def segmentation_wrapper_factory(dataset):
-    def wrapper(sample):
+    def wrapper(idx, sample):
         image, mask = sample
         return image, pil_image_to_mask(mask)
 
@@ -163,7 +166,7 @@ def video_classification_wrapper_factory(dataset):
             f"since it is not compatible with the transformations. Please use `output_format='TCHW'` instead."
         )
 
-    def wrapper(sample):
+    def wrapper(idx, sample):
         video, audio, label = sample
 
         video = datapoints.Video(video)
@@ -201,14 +204,17 @@ def segmentation_to_mask(segmentation, *, spatial_size):
         )
         return torch.from_numpy(mask.decode(segmentation))
 
-    def wrapper(sample):
+    def wrapper(idx, sample):
+        image_id = dataset.ids[idx]
+
         image, target = sample
 
+        if not target:
+            return image, dict(image_id=image_id)
+
         batched_target = list_of_dicts_to_dict_of_lists(target)
 
-        image_ids = batched_target.pop("image_id")
-        image_id = batched_target["image_id"] = image_ids.pop()
-        assert all(other_image_id == image_id for other_image_id in image_ids)
+        batched_target["image_id"] = image_id
 
         spatial_size = tuple(F.get_spatial_size(image))
         batched_target["boxes"] = datapoints.BoundingBox(
@@ -259,7 +265,7 @@ def wrapper(sample):
 
 @WRAPPER_FACTORIES.register(datasets.VOCDetection)
 def voc_detection_wrapper_factory(dataset):
-    def wrapper(sample):
+    def wrapper(idx, sample):
         image, target = sample
 
         batched_instances = list_of_dicts_to_dict_of_lists(target["annotation"]["object"])
@@ -294,7 +300,7 @@ def celeba_wrapper_factory(dataset):
     if any(target_type in dataset.target_type for target_type in ["attr", "landmarks"]):
         raise_not_supported("`CelebA` dataset with `target_type=['attr', 'landmarks', ...]`")
 
-    def wrapper(sample):
+    def wrapper(idx, sample):
         image, target = sample
 
         target = wrap_target_by_type(
@@ -318,7 +324,7 @@ def wrapper(sample):
 
 @WRAPPER_FACTORIES.register(datasets.Kitti)
 def kitti_wrapper_factory(dataset):
-    def wrapper(sample):
+    def wrapper(idx, sample):
         image, target = sample
 
         if target is not None:
@@ -336,7 +342,7 @@ def wrapper(sample):
 
 @WRAPPER_FACTORIES.register(datasets.OxfordIIITPet)
 def oxford_iiit_pet_wrapper_factor(dataset):
-    def wrapper(sample):
+    def wrapper(idx, sample):
         image, target = sample
 
         if target is not None:
@@ -371,7 +377,7 @@ def instance_segmentation_wrapper(mask):
             labels.append(label)
         return dict(masks=datapoints.Mask(torch.stack(masks)), labels=torch.stack(labels))
 
-    def wrapper(sample):
+    def wrapper(idx, sample):
         image, target = sample
 
         target = wrap_target_by_type(
@@ -390,7 +396,7 @@ def wrapper(sample):
 
 @WRAPPER_FACTORIES.register(datasets.WIDERFace)
 def widerface_wrapper(dataset):
-    def wrapper(sample):
+    def wrapper(idx, sample):
         image, target = sample
 
         if target is not None: