@@ -74,7 +74,7 @@ def __getitem__(self, idx):
         # of this class
         sample = self._dataset[idx]
 
-        sample = self._wrapper(sample)
+        sample = self._wrapper(idx, sample)
 
         # Regardless of whether the user has supplied the transforms individually (`transform` and `target_transform`)
         # or joint (`transforms`), we can access the full functionality through `transforms`
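
Note on the first hunk: the wrapper is now invoked with the sample index in addition to the raw sample. A minimal sketch of the resulting calling convention, with class and attribute names assumed from the surrounding file rather than copied verbatim:

```python
# Sketch, not the verbatim implementation: class name assumed, attribute
# names taken from the hunk above.
class DatapointWrapper:
    def __init__(self, dataset, wrapper):
        self._dataset = dataset  # raw dataset, transforms disabled
        self._wrapper = wrapper  # callable produced by a wrapper factory

    def __getitem__(self, idx):
        sample = self._dataset[idx]
        # The index is threaded through so wrappers can consult
        # dataset-level metadata (e.g. COCO image ids) for this sample.
        return self._wrapper(idx, sample)
```
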
@@ -125,7 +125,10 @@ def wrap_target_by_type(target, *, target_types, type_wrappers):
 
 
 def classification_wrapper_factory(dataset):
-    return identity
+    def wrapper(idx, sample):
+        return sample
+
+    return wrapper
 
 
 for dataset_cls in [
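
With `identity` no longer usable, every factory returns a callable with the uniform signature `wrapper(idx, sample)`, even when the index is unused. A hedged sketch of what a factory registered through `WRAPPER_FACTORIES.register` looks like under this contract (the factory name and dataset class below are illustrative, not taken from this file):

```python
# Illustrative only: the uniform contract every registered factory now
# satisfies. Dataset class and factory name are arbitrary examples.
@WRAPPER_FACTORIES.register(datasets.SBDataset)
def example_wrapper_factory(dataset):
    def wrapper(idx, sample):  # idx must be accepted even if unused
        image, target = sample
        return image, pil_image_to_mask(target)

    return wrapper
```
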
@@ -143,7 +146,7 @@ def classification_wrapper_factory(dataset):
 
 
 def segmentation_wrapper_factory(dataset):
-    def wrapper(sample):
+    def wrapper(idx, sample):
         image, mask = sample
         return image, pil_image_to_mask(mask)
 
@@ -163,7 +166,7 @@ def video_classification_wrapper_factory(dataset):
             f"since it is not compatible with the transformations. Please use `output_format='TCHW'` instead."
         )
 
-    def wrapper(sample):
+    def wrapper(idx, sample):
         video, audio, label = sample
 
         video = datapoints.Video(video)
@@ -201,14 +204,17 @@ def segmentation_to_mask(segmentation, *, spatial_size):
         )
         return torch.from_numpy(mask.decode(segmentation))
 
-    def wrapper(sample):
+    def wrapper(idx, sample):
+        image_id = dataset.ids[idx]
+
         image, target = sample
 
+        if not target:
+            return image, dict(image_id=image_id)
+
         batched_target = list_of_dicts_to_dict_of_lists(target)
 
-        image_ids = batched_target.pop("image_id")
-        image_id = batched_target["image_id"] = image_ids.pop()
-        assert all(other_image_id == image_id for other_image_id in image_ids)
+        batched_target["image_id"] = image_id
 
         spatial_size = tuple(F.get_spatial_size(image))
         batched_target["boxes"] = datapoints.BoundingBox(
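
The CocoDetection hunk carries the substantive fix: the image id was previously recovered from the per-instance annotation dicts (pop one, assert the rest agree), which cannot work for images that have no annotations. Taking `dataset.ids[idx]` instead is always available, and empty targets now short-circuit to a target dict containing only the id. A rough before/after sketch, assuming `dataset` is a `datasets.CocoDetection` instance:

```python
# Rough sketch of the behavioral difference; helper names are assumed.

# Before: the id only existed inside the annotations, so an image with an
# empty annotation list had no recoverable image_id.
def old_image_id(target):
    ids = {ann["image_id"] for ann in target}
    (image_id,) = ids  # raises if target is empty
    return image_id

# After: CocoDetection keeps the sorted image ids; the sample index maps
# directly to the id the annotations were loaded for.
def new_image_id(dataset, idx):
    return dataset.ids[idx]
```
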
@@ -259,7 +265,7 @@ def wrapper(sample):
 
 @WRAPPER_FACTORIES.register(datasets.VOCDetection)
 def voc_detection_wrapper_factory(dataset):
-    def wrapper(sample):
+    def wrapper(idx, sample):
         image, target = sample
 
         batched_instances = list_of_dicts_to_dict_of_lists(target["annotation"]["object"])
@@ -294,7 +300,7 @@ def celeba_wrapper_factory(dataset):
     if any(target_type in dataset.target_type for target_type in ["attr", "landmarks"]):
         raise_not_supported("`CelebA` dataset with `target_type=['attr', 'landmarks', ...]`")
 
-    def wrapper(sample):
+    def wrapper(idx, sample):
         image, target = sample
 
         target = wrap_target_by_type(
@@ -318,7 +324,7 @@ def wrapper(sample):
 
 @WRAPPER_FACTORIES.register(datasets.Kitti)
 def kitti_wrapper_factory(dataset):
-    def wrapper(sample):
+    def wrapper(idx, sample):
         image, target = sample
 
         if target is not None:
@@ -336,7 +342,7 @@ def wrapper(sample):
 
 @WRAPPER_FACTORIES.register(datasets.OxfordIIITPet)
 def oxford_iiit_pet_wrapper_factor(dataset):
-    def wrapper(sample):
+    def wrapper(idx, sample):
         image, target = sample
 
         if target is not None:
@@ -371,7 +377,7 @@ def instance_segmentation_wrapper(mask):
             labels.append(label)
         return dict(masks=datapoints.Mask(torch.stack(masks)), labels=torch.stack(labels))
 
-    def wrapper(sample):
+    def wrapper(idx, sample):
         image, target = sample
 
         target = wrap_target_by_type(
@@ -390,7 +396,7 @@ def wrapper(sample):
 
 @WRAPPER_FACTORIES.register(datasets.WIDERFace)
 def widerface_wrapper(dataset):
-    def wrapper(sample):
+    def wrapper(idx, sample):
         image, target = sample
 
         if target is not None:
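
For reference, none of this changes the public entry point; wrapped datasets are still constructed the same way. A hedged usage sketch with placeholder paths:

```python
from torchvision import datasets
from torchvision.datasets import wrap_dataset_for_transforms_v2

# Placeholder paths; wrap_dataset_for_transforms_v2 dispatches to the
# wrapper factories defined in this file.
dataset = datasets.CocoDetection("path/to/images", "path/to/annotations.json")
dataset = wrap_dataset_for_transforms_v2(dataset)

image, target = dataset[0]
# After this commit, target["image_id"] is present even for images without
# annotations (in which case it is the only key in the target dict).
```
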