From e7ac62960d91e44a6b097277dd1e625d4fbbecba Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Wed, 21 Aug 2024 15:50:09 +0200 Subject: [PATCH] Don't use Any in sample/batch utils --- torchgeo/datasets/utils.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py index ced0460def0..72846c9f5ef 100644 --- a/torchgeo/datasets/utils.py +++ b/torchgeo/datasets/utils.py @@ -14,10 +14,10 @@ import shutil import subprocess import sys -from collections.abc import Iterable, Iterator, Sequence, Mapping +from collections.abc import Iterable, Iterator, Mapping, Sequence from dataclasses import dataclass from datetime import datetime, timedelta -from typing import Any, TypeAlias, cast, overload +from typing import Any, TypeAlias, TypeVar, cast, overload import numpy as np import rasterio @@ -43,6 +43,8 @@ Path: TypeAlias = str | pathlib.Path +K = TypeVar('K') +V = TypeVar('V') @dataclass(frozen=True) @@ -367,7 +369,7 @@ def working_dir(dirname: Path, create: bool = False) -> Iterator[None]: os.chdir(cwd) -def _list_dict_to_dict_list(samples: Iterable[Mapping[Any, Any]]) -> dict[Any, list[Any]]: +def _list_dict_to_dict_list(samples: Iterable[Mapping[K, V]]) -> dict[K, list[V]]: """Convert a list of dictionaries to a dictionary of lists. Args: @@ -385,7 +387,7 @@ def _list_dict_to_dict_list(samples: Iterable[Mapping[Any, Any]]) -> dict[Any, l return collated -def _dict_list_to_list_dict(sample: Mapping[Any, Sequence[Any]]) -> list[dict[Any, Any]]: +def _dict_list_to_list_dict(sample: Mapping[K, Sequence[V]]) -> list[dict[K, V]]: """Convert a dictionary of lists to a list of dictionaries. Args: @@ -396,16 +398,14 @@ def _dict_list_to_list_dict(sample: Mapping[Any, Sequence[Any]]) -> list[dict[An .. versionadded:: 0.2 """ - uncollated: list[dict[Any, Any]] = [ - {} for _ in range(max(map(len, sample.values()))) - ] + uncollated: list[dict[K, V]] = [{} for _ in range(max(map(len, sample.values())))] for key, values in sample.items(): for i, value in enumerate(values): uncollated[i][key] = value return uncollated -def stack_samples(samples: Iterable[Mapping[Any, Any]]) -> dict[Any, Any]: +def stack_samples(samples: Iterable[Mapping[K, V]]) -> dict[K, V]: """Stack a list of samples along a new axis. Useful for forming a mini-batch of samples to pass to @@ -419,14 +419,14 @@ def stack_samples(samples: Iterable[Mapping[Any, Any]]) -> dict[Any, Any]: .. versionadded:: 0.2 """ - collated: dict[Any, Any] = _list_dict_to_dict_list(samples) + collated: dict[K, V] = _list_dict_to_dict_list(samples) for key, value in collated.items(): if isinstance(value[0], Tensor): collated[key] = torch.stack(value) return collated -def concat_samples(samples: Iterable[Mapping[Any, Any]]) -> dict[Any, Any]: +def concat_samples(samples: Iterable[Mapping[K, V]]) -> dict[K, V]: """Concatenate a list of samples along an existing axis. Useful for joining samples in a :class:`torchgeo.datasets.IntersectionDataset`. @@ -439,7 +439,7 @@ def concat_samples(samples: Iterable[Mapping[Any, Any]]) -> dict[Any, Any]: .. versionadded:: 0.2 """ - collated: dict[Any, Any] = _list_dict_to_dict_list(samples) + collated: dict[K, V] = _list_dict_to_dict_list(samples) for key, value in collated.items(): if isinstance(value[0], Tensor): collated[key] = torch.cat(value) @@ -448,7 +448,7 @@ def concat_samples(samples: Iterable[Mapping[Any, Any]]) -> dict[Any, Any]: return collated -def merge_samples(samples: Iterable[Mapping[Any, Any]]) -> dict[Any, Any]: +def merge_samples(samples: Iterable[Mapping[K, V]]) -> dict[K, V]: """Merge a list of samples. Useful for joining samples in a :class:`torchgeo.datasets.UnionDataset`. @@ -461,7 +461,7 @@ def merge_samples(samples: Iterable[Mapping[Any, Any]]) -> dict[Any, Any]: .. versionadded:: 0.2 """ - collated: dict[Any, Any] = {} + collated: dict[K, V] = {} for sample in samples: for key, value in sample.items(): if key in collated and isinstance(value, Tensor): @@ -473,7 +473,7 @@ def merge_samples(samples: Iterable[Mapping[Any, Any]]) -> dict[Any, Any]: return collated -def unbind_samples(sample: Mapping[Any, Sequence[Any]]) -> list[dict[Any, Any]]: +def unbind_samples(sample: Mapping[K, Sequence[V] | Tensor]) -> list[dict[K, V]]: """Reverse of :func:`stack_samples`. Useful for turning a mini-batch of samples into a list of samples. These individual