diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py index e668e749840..7ddbe08e597 100644 --- a/torchgeo/datasets/utils.py +++ b/torchgeo/datasets/utils.py @@ -378,9 +378,11 @@ def _list_dict_to_dict_list( .. versionadded:: 0.2 """ - collated = collections.defaultdict(list) + collated: dict[Any, list[Any]] = dict() for sample in samples: for key, value in sample.items(): + if key not in collated: + collated[key] = [] collated[key].append(value) return collated