From 1b7ed0300e4a40a1ffe65f339d955eb87b5daa23 Mon Sep 17 00:00:00 2001 From: Adrian Tofting Date: Mon, 3 Jul 2023 10:47:22 +0200 Subject: [PATCH] Revert "Remove required os.exists for paths" This reverts commit 84bf62b944326c33d5ba8efdcab615c65b124792. --- torchgeo/datasets/geo.py | 37 +++++++++++++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index c40f880e2bf..fda6f24b859 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -438,14 +438,43 @@ def list_files(self, filename_glob: Optional[str] = None) -> list[str]: # using set to remove any duplicates if directories are overlapping filepaths: set[str] = set() for dir_or_file in paths: - if os.path.isdir(dir_or_file): - pathname = os.path.join(dir_or_file, "**", filename_glob) - filepaths |= set(glob.iglob(pathname, recursive=True)) + if os.path.exists(dir_or_file): + if os.path.isfile(dir_or_file): + filepaths.add(dir_or_file) + else: + pathname = os.path.join(dir_or_file, "**", filename_glob) + filepaths |= set(glob.iglob(pathname, recursive=True)) else: - filepaths.add(dir_or_file) + filepaths |= self.handle_nonlocal_path(dir_or_file) return list(filepaths) + def handle_nonlocal_path(self, path: str) -> set[str]: + """Override this method if your path can not be interpreted by os module. + + See docs for Advanced Datasets + https://rasterio.readthedocs.io/en/stable/topics/datasets.html + `fiona.listdir` can be used to list files in such directories: + https://fiona.readthedocs.io/en/stable/fiona.html#fiona.io.MemoryFile.listdir + + Args: + path: directory, cloud storage blob or archive to be listed + + Returns: + set of paths pointing to files + + Raises: + RuntimeError: if the path is not found locally, + or this method is not overridden by child + """ + raise RuntimeError( + f"Dataset not found in `{path}` " + "either specify a different `root` directory or make sure you " + "have manually downloaded the dataset as instructed in the documentation." + "If this is a remote file or archive " + "please override this method and return the filepath(s) as a set." + ) + def __getitem__(self, query: BoundingBox) -> dict[str, Any]: """Retrieve image/mask and metadata indexed by query.