diff --git a/CHANGES.md b/CHANGES.md index 242cf518..1bf5855b 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -2,6 +2,10 @@ ## 0.10.0 (yyyy-mm-dd) +### Improvements + +- Add support to read, write, list, and remove `/vsimem/` files (#457) + ### Bug fixes - Silence warning from `write_dataframe` with `GeoSeries.notna()` (#435). diff --git a/docs/source/api.rst b/docs/source/api.rst index 007470d7..105fbb3f 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -5,7 +5,7 @@ Core ---- .. automodule:: pyogrio - :members: list_drivers, detect_write_driver, list_layers, read_bounds, read_info, set_gdal_config_options, get_gdal_config_option, __gdal_version__, __gdal_version_string__ + :members: list_drivers, detect_write_driver, list_layers, read_bounds, read_info, set_gdal_config_options, get_gdal_config_option, vsi_listtree, vsi_rmtree, vsi_unlink, __gdal_version__, __gdal_version_string__ GeoPandas integration --------------------- diff --git a/environment-dev.yml b/environment-dev.yml index c6de8e69..d92ac999 100644 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -3,13 +3,13 @@ channels: - conda-forge dependencies: # Required - - numpy - libgdal-core + - numpy - shapely>=2 # Optional - geopandas-base - - pyproj - pyarrow + - pyproj # Specific for dev - cython - pre-commit diff --git a/pyogrio/__init__.py b/pyogrio/__init__.py index a5f32511..c5450c6e 100644 --- a/pyogrio/__init__.py +++ b/pyogrio/__init__.py @@ -21,6 +21,9 @@ read_bounds, read_info, set_gdal_config_options, + vsi_listtree, + vsi_rmtree, + vsi_unlink, ) from pyogrio.geopandas import read_dataframe, write_dataframe from pyogrio.raw import open_arrow, read_arrow, write_arrow @@ -37,10 +40,13 @@ "set_gdal_config_options", "get_gdal_config_option", "get_gdal_data_path", - "read_arrow", "open_arrow", - "write_arrow", + "read_arrow", "read_dataframe", + "vsi_listtree", + "vsi_rmtree", + "vsi_unlink", + "write_arrow", "write_dataframe", "__gdal_version__", "__gdal_version_string__", diff --git a/pyogrio/_io.pyx b/pyogrio/_io.pyx index 7c6427ce..a9c934e5 100644 --- a/pyogrio/_io.pyx +++ b/pyogrio/_io.pyx @@ -12,6 +12,7 @@ import math import os import sys import warnings +from pathlib import Path from libc.stdint cimport uint8_t, uintptr_t from libc.stdlib cimport malloc, free @@ -1184,7 +1185,7 @@ def ogr_read( ): cdef int err = 0 - cdef bint is_vsimem = isinstance(path_or_buffer, bytes) + cdef bint use_tmp_vsimem = isinstance(path_or_buffer, bytes) cdef const char *path_c = NULL cdef char **dataset_options = NULL cdef const char *where_c = NULL @@ -1224,7 +1225,7 @@ def ogr_read( raise ValueError("'max_features' must be >= 0") try: - path = read_buffer_to_vsimem(path_or_buffer) if is_vsimem else path_or_buffer + path = read_buffer_to_vsimem(path_or_buffer) if use_tmp_vsimem else path_or_buffer if encoding: # for shapefiles, SHAPE_ENCODING must be set before opening the file @@ -1362,8 +1363,8 @@ def ogr_read( CPLFree(prev_shape_encoding) prev_shape_encoding = NULL - if is_vsimem: - delete_vsimem_file(path) + if use_tmp_vsimem: + vsimem_rmtree_toplevel(path) return ( meta, @@ -1424,7 +1425,7 @@ def ogr_open_arrow( ): cdef int err = 0 - cdef bint is_vsimem = isinstance(path_or_buffer, bytes) + cdef bint use_tmp_vsimem = isinstance(path_or_buffer, bytes) cdef const char *path_c = NULL cdef char **dataset_options = NULL cdef const char *where_c = NULL @@ -1480,7 +1481,7 @@ def ogr_open_arrow( reader = None try: - path = read_buffer_to_vsimem(path_or_buffer) if is_vsimem else path_or_buffer + path = read_buffer_to_vsimem(path_or_buffer) if use_tmp_vsimem else path_or_buffer if encoding: override_shape_encoding = True @@ -1679,8 +1680,8 @@ def ogr_open_arrow( CPLFree(prev_shape_encoding) prev_shape_encoding = NULL - if is_vsimem: - delete_vsimem_file(path) + if use_tmp_vsimem: + vsimem_rmtree_toplevel(path) def ogr_read_bounds( @@ -1697,7 +1698,7 @@ def ogr_read_bounds( object mask=None): cdef int err = 0 - cdef bint is_vsimem = isinstance(path_or_buffer, bytes) + cdef bint use_tmp_vsimem = isinstance(path_or_buffer, bytes) cdef const char *path_c = NULL cdef const char *where_c = NULL cdef OGRDataSourceH ogr_dataset = NULL @@ -1715,7 +1716,7 @@ def ogr_read_bounds( raise ValueError("'max_features' must be >= 0") try: - path = read_buffer_to_vsimem(path_or_buffer) if is_vsimem else path_or_buffer + path = read_buffer_to_vsimem(path_or_buffer) if use_tmp_vsimem else path_or_buffer ogr_dataset = ogr_open(path.encode('UTF-8'), 0, NULL) if layer is None: @@ -1744,8 +1745,8 @@ def ogr_read_bounds( GDALClose(ogr_dataset) ogr_dataset = NULL - if is_vsimem: - delete_vsimem_file(path) + if use_tmp_vsimem: + vsimem_rmtree_toplevel(path) return bounds @@ -1758,7 +1759,7 @@ def ogr_read_info( int force_feature_count=False, int force_total_bounds=False): - cdef bint is_vsimem = isinstance(path_or_buffer, bytes) + cdef bint use_tmp_vsimem = isinstance(path_or_buffer, bytes) cdef const char *path_c = NULL cdef char **dataset_options = NULL cdef OGRDataSourceH ogr_dataset = NULL @@ -1767,7 +1768,7 @@ def ogr_read_info( cdef bint override_shape_encoding = False try: - path = read_buffer_to_vsimem(path_or_buffer) if is_vsimem else path_or_buffer + path = read_buffer_to_vsimem(path_or_buffer) if use_tmp_vsimem else path_or_buffer if encoding: override_shape_encoding = True @@ -1826,19 +1827,19 @@ def ogr_read_info( if prev_shape_encoding != NULL: CPLFree(prev_shape_encoding) - if is_vsimem: - delete_vsimem_file(path) + if use_tmp_vsimem: + vsimem_rmtree_toplevel(path) return meta def ogr_list_layers(object path_or_buffer): - cdef bint is_vsimem = isinstance(path_or_buffer, bytes) + cdef bint use_tmp_vsimem = isinstance(path_or_buffer, bytes) cdef const char *path_c = NULL cdef OGRDataSourceH ogr_dataset = NULL try: - path = read_buffer_to_vsimem(path_or_buffer) if is_vsimem else path_or_buffer + path = read_buffer_to_vsimem(path_or_buffer) if use_tmp_vsimem else path_or_buffer ogr_dataset = ogr_open(path.encode('UTF-8'), 0, NULL) layers = get_layer_names(ogr_dataset) @@ -1847,8 +1848,8 @@ def ogr_list_layers(object path_or_buffer): GDALClose(ogr_dataset) ogr_dataset = NULL - if is_vsimem: - delete_vsimem_file(path) + if use_tmp_vsimem: + vsimem_rmtree_toplevel(path) return layers @@ -1931,6 +1932,16 @@ cdef void * ogr_create(const char* path_c, const char* driver_c, char** options) except CPLE_BaseError as exc: raise DataSourceError(str(exc)) + # For /vsimem/ files, with GDAL >= 3.8 parent directories are created automatically. + IF CTE_GDAL_VERSION < (3, 8, 0): + path = path_c.decode("UTF-8") + if "/vsimem/" in path: + parent = str(Path(path).parent.as_posix()) + if not parent.endswith("/vsimem"): + retcode = VSIMkdirRecursive(parent.encode("UTF-8"), 0666) + if retcode != 0: + raise OSError(f"Could not create parent directory '{parent}'") + # Create the dataset try: ogr_dataset = exc_wrap_pointer(GDALCreate(ogr_driver, path_c, 0, 0, 0, GDT_Unknown, options)) @@ -2014,7 +2025,7 @@ cdef infer_field_types(list dtypes): cdef create_ogr_dataset_layer( str path, - bint is_vsi, + bint use_tmp_vsimem, str layer, str driver, str crs, @@ -2048,6 +2059,8 @@ cdef create_ogr_dataset_layer( encoding : str Only used if `driver` is "ESRI Shapefile". If not None, it overrules the default shapefile encoding, which is "UTF-8" in pyogrio. + use_tmp_vsimem : bool + Whether the file path is meant to save a temporary memory file to. Returns ------- @@ -2075,8 +2088,8 @@ cdef create_ogr_dataset_layer( driver_b = driver.encode('UTF-8') driver_c = driver_b - # in-memory dataset is always created from scratch - path_exists = os.path.exists(path) if not is_vsi else False + # temporary in-memory dataset is always created from scratch + path_exists = os.path.exists(path) if not use_tmp_vsimem else False if not layer: layer = os.path.splitext(os.path.split(path)[1])[0] @@ -2112,10 +2125,7 @@ cdef create_ogr_dataset_layer( raise exc # otherwise create from scratch - if is_vsi: - VSIUnlink(path_c) - else: - os.unlink(path) + os.unlink(path) ogr_dataset = NULL @@ -2250,7 +2260,7 @@ def ogr_write( cdef int num_records = -1 cdef int num_field_data = len(field_data) if field_data is not None else 0 cdef int num_fields = len(fields) if fields is not None else 0 - cdef bint is_vsi = False + cdef bint use_tmp_vsimem = False if num_fields != num_field_data: raise ValueError("field_data array needs to be same length as fields array") @@ -2291,12 +2301,11 @@ def ogr_write( try: # Setup in-memory handler if needed - path = get_ogr_vsimem_write_path(path_or_fp, driver) - is_vsi = path.startswith('/vsimem/') + path, use_tmp_vsimem = get_ogr_vsimem_write_path(path_or_fp, driver) # Setup dataset and layer layer_created = create_ogr_dataset_layer( - path, is_vsi, layer, driver, crs, geometry_type, encoding, + path, use_tmp_vsimem, layer, driver, crs, geometry_type, encoding, dataset_kwargs, layer_kwargs, append, dataset_metadata, layer_metadata, &ogr_dataset, &ogr_layer, @@ -2501,7 +2510,7 @@ def ogr_write( raise DataSourceError(f"Failed to write features to dataset {path}; {exc}") # copy in-memory file back to path_or_fp object - if is_vsi: + if use_tmp_vsimem: read_vsimem_to_buffer(path, path_or_fp) finally: @@ -2523,8 +2532,8 @@ def ogr_write( if ogr_dataset != NULL: ogr_close(ogr_dataset) - if is_vsi: - delete_vsimem_file(path) + if use_tmp_vsimem: + vsimem_rmtree_toplevel(path) def ogr_write_arrow( @@ -2548,7 +2557,7 @@ def ogr_write_arrow( cdef OGRDataSourceH ogr_dataset = NULL cdef OGRLayerH ogr_layer = NULL cdef char **options = NULL - cdef bint is_vsi = False + cdef bint use_tmp_vsimem = False cdef ArrowArrayStream* stream = NULL cdef ArrowSchema schema cdef ArrowArray array @@ -2557,11 +2566,11 @@ def ogr_write_arrow( array.release = NULL try: - path = get_ogr_vsimem_write_path(path_or_fp, driver) - is_vsi = path.startswith('/vsimem/') + # Setup in-memory handler if needed + path, use_tmp_vsimem = get_ogr_vsimem_write_path(path_or_fp, driver) layer_created = create_ogr_dataset_layer( - path, is_vsi, layer, driver, crs, geometry_type, encoding, + path, use_tmp_vsimem, layer, driver, crs, geometry_type, encoding, dataset_kwargs, layer_kwargs, append, dataset_metadata, layer_metadata, &ogr_dataset, &ogr_layer, @@ -2622,7 +2631,7 @@ def ogr_write_arrow( raise DataSourceError(f"Failed to write features to dataset {path}; {exc}") # copy in-memory file back to path_or_fp object - if is_vsi: + if use_tmp_vsimem: read_vsimem_to_buffer(path, path_or_fp) finally: @@ -2642,8 +2651,8 @@ def ogr_write_arrow( if ogr_dataset != NULL: ogr_close(ogr_dataset) - if is_vsi: - delete_vsimem_file(path) + if use_tmp_vsimem: + vsimem_rmtree_toplevel(path) cdef get_arrow_extension_metadata(const ArrowSchema* schema): diff --git a/pyogrio/_ogr.pxd b/pyogrio/_ogr.pxd index 9369ba71..8ce6a578 100644 --- a/pyogrio/_ogr.pxd +++ b/pyogrio/_ogr.pxd @@ -36,6 +36,10 @@ cdef extern from "cpl_error.h" nogil: void CPLPopErrorHandler() +cdef extern from "cpl_port.h": + ctypedef char **CSLConstList + + cdef extern from "cpl_string.h": char** CSLAddNameValue(char **list, const char *name, const char *value) char** CSLSetNameValue(char **list, const char *name, const char *value) @@ -53,6 +57,9 @@ cdef extern from "cpl_vsi.h" nogil: long st_mode int st_mtime + int VSIStatL(const char *path, VSIStatBufL *psStatBuf) + int VSI_ISDIR(int mode) + char** VSIReadDirRecursive(const char *path) int VSIFCloseL(VSILFILE *fp) int VSIFFlushL(VSILFILE *fp) int VSIUnlink(const char *path) @@ -61,7 +68,8 @@ cdef extern from "cpl_vsi.h" nogil: unsigned char *VSIGetMemFileBuffer(const char *path, vsi_l_offset *data_len, int take_ownership) int VSIMkdir(const char *path, long mode) - int VSIRmdirRecursive(const char *pszDirname) + int VSIMkdirRecursive(const char *path, long mode) + int VSIRmdirRecursive(const char *path) cdef extern from "ogr_core.h": diff --git a/pyogrio/_vsi.pxd b/pyogrio/_vsi.pxd index afa2633a..1c464489 100644 --- a/pyogrio/_vsi.pxd +++ b/pyogrio/_vsi.pxd @@ -1,4 +1,4 @@ -cdef str get_ogr_vsimem_write_path(object path_or_fp, str driver) +cdef tuple get_ogr_vsimem_write_path(object path_or_fp, str driver) cdef str read_buffer_to_vsimem(bytes bytes_buffer) cdef read_vsimem_to_buffer(str path, object out_buffer) -cdef delete_vsimem_file(str path) \ No newline at end of file +cpdef vsimem_rmtree_toplevel(str path) \ No newline at end of file diff --git a/pyogrio/_vsi.pyx b/pyogrio/_vsi.pyx index 47b8c11d..757c2c78 100644 --- a/pyogrio/_vsi.pyx +++ b/pyogrio/_vsi.pyx @@ -1,3 +1,4 @@ +import fnmatch from io import BytesIO from uuid import uuid4 @@ -8,28 +9,44 @@ from pyogrio._ogr cimport * from pyogrio._ogr import _get_driver_metadata_item -cdef str get_ogr_vsimem_write_path(object path_or_fp, str driver): - """ Return the original path or a /vsimem/ path +cdef tuple get_ogr_vsimem_write_path(object path_or_fp, str driver): + """Return the path to write to and whether it is a tmp vsimem filepath. - If passed a io.BytesIO object, this will return a /vsimem/ path that can be - used to create a new in-memory file with an extension inferred from the driver - if possible. Path will be contained in an in-memory directory to contain - sibling files (though drivers that create sibling files are not supported for - in-memory files). + If passed a io.BytesIO object to write to, a temporary vsimem file will be + used to write the data directly to memory. + Hence, a tuple will be returned with a /vsimem/ path and True to indicate + the path will be to a tmp vsimem file. + The path will have an extension inferred from the driver if possible. Path + will be contained in an in-memory directory to contain sibling files + (though drivers that create sibling files are not supported for in-memory + files). - Caller is responsible for deleting the directory via delete_vsimem_file() + Caller is responsible for deleting the directory via + vsimem_rmtree_toplevel(). Parameters ---------- path_or_fp : str or io.BytesIO object driver : str - """ + Returns + ------- + tuple of (path, use_tmp_vsimem) + Tuple of the path to write to and a bool indicating if the path is a + temporary vsimem filepath. + + """ + # The write path is not a BytesIO object, so return path as-is if not isinstance(path_or_fp, BytesIO): - return path_or_fp + return (path_or_fp, False) - # Create in-memory directory to contain auxiliary files - memfilename = uuid4().hex + # Check for existing bytes + if path_or_fp.getbuffer().nbytes > 0: + raise NotImplementedError("writing to existing in-memory object is not supported") + + # Create in-memory directory to contain auxiliary files. + # Prefix with "pyogrio_" so it is clear the directory was created by pyogrio. + memfilename = f"pyogrio_{uuid4().hex}" VSIMkdir(f"/vsimem/{memfilename}".encode("UTF-8"), 0666) # file extension is required for some drivers, set it based on driver metadata @@ -40,11 +57,7 @@ cdef str get_ogr_vsimem_write_path(object path_or_fp, str driver): path = f"/vsimem/{memfilename}/{memfilename}{ext}" - # check for existing bytes - if path_or_fp.getbuffer().nbytes > 0: - raise NotImplementedError("writing to existing in-memory object is not supported") - - return path + return (path, True) cdef str read_buffer_to_vsimem(bytes bytes_buffer): @@ -54,7 +67,8 @@ cdef str read_buffer_to_vsimem(bytes bytes_buffer): will be prefixed with /vsizip/ and suffixed with .zip to enable proper reading by GDAL. - Caller is responsible for deleting the in-memory file via delete_vsimem_file(). + Caller is responsible for deleting the in-memory file via + vsimem_rmtree_toplevel(). Parameters ---------- @@ -65,12 +79,15 @@ cdef str read_buffer_to_vsimem(bytes bytes_buffer): is_zipped = len(bytes_buffer) > 4 and bytes_buffer[:4].startswith(b"PK\x03\x04") ext = ".zip" if is_zipped else "" - path = f"/vsimem/{uuid4().hex}{ext}" + # Prefix with "pyogrio_" so it is clear the file was created by pyogrio. + path = f"/vsimem/pyogrio_{uuid4().hex}{ext}" # Create an in-memory object that references bytes_buffer # NOTE: GDAL does not copy the contents of bytes_buffer; it must remain # in scope through the duration of using this file - vsi_handle = VSIFileFromMemBuffer(path.encode("UTF-8"), bytes_buffer, num_bytes, 0) + vsi_handle = VSIFileFromMemBuffer( + path.encode("UTF-8"), bytes_buffer, num_bytes, 0 + ) if vsi_handle == NULL: raise OSError("failed to read buffer into in-memory file") @@ -88,8 +105,8 @@ cdef read_vsimem_to_buffer(str path, object out_buffer): """Copy bytes from in-memory file to buffer This will automatically unlink the in-memory file pointed to by path; caller - is still responsible for calling delete_vsimem_file() to cleanup any other - files contained in the in-memory directory. + is still responsible for calling vsimem_rmtree_toplevel() to cleanup any + other files contained in the in-memory directory. Parameters: ----------- @@ -118,23 +135,155 @@ cdef read_vsimem_to_buffer(str path, object out_buffer): CPLFree(vsi_buffer) -cdef delete_vsimem_file(str path): - """ Recursively delete in-memory path or directory containing path +cpdef vsimem_rmtree_toplevel(str path): + """Remove the top-level file or top-level directory containing the file. + + This is used for final cleanup of an in-memory dataset. The path can point + to either: + - a top-level file (directly in /vsimem/). + - a file in a directory, which may include sibling files. + - a zip file (reported as a directory by VSI_ISDIR). - This is used for final cleanup of an in-memory dataset, which may have been - created within a directory to contain sibling files. + Except for the first case, the top-level directory (direct subdirectory of + /vsimem/) will be determined and will be removed recursively. Additional VSI handlers may be chained to the left of /vsimem/ in path and will be ignored. + Even though it is only meant for "internal use", the function is declared + as cpdef, so it can be called from tests as well. + Parameters: ----------- path : str path to in-memory file + """ + cdef VSIStatBufL st_buf if "/vsimem/" not in path: - return + raise ValueError(f"Path is not a /vsimem/ path: '{path}'") + + # Determine the top-level directory of the file + mempath_parts = path.split("/vsimem/")[1].split("/") + if len(mempath_parts) == 0: + raise OSError("path to in-memory file or directory is required") + + toplevel_path = f"/vsimem/{mempath_parts[0]}" + + if not VSIStatL(toplevel_path.encode("UTF-8"), &st_buf) == 0: + raise FileNotFoundError(f"Path does not exist: '{path}'") + + if VSI_ISDIR(st_buf.st_mode): + errcode = VSIRmdirRecursive(toplevel_path.encode("UTF-8")) + else: + errcode = VSIUnlink(toplevel_path.encode("UTF-8")) + + if errcode != 0: + raise OSError(f"Error removing '{path}': {errcode=}") + + +def ogr_vsi_listtree(str path, str pattern): + """Recursively list the contents in a VSI directory. + + An fnmatch pattern can be specified to filter the directories/files + returned. + + Parameters: + ----------- + path : str + Path to the VSI directory to be listed. + pattern : str + Pattern to filter results, in fnmatch format. + + """ + cdef const char *path_c + cdef int n + cdef char** papszFiles + cdef VSIStatBufL st_buf - root = "/vsimem/" + path.split("/vsimem/")[1].split("/")[0] - VSIRmdirRecursive(root.encode("UTF-8")) + path_b = path.encode("UTF-8") + path_c = path_b + + if not VSIStatL(path_c, &st_buf) == 0: + raise FileNotFoundError(f"Path does not exist: '{path}'") + if not VSI_ISDIR(st_buf.st_mode): + raise NotADirectoryError(f"Path is not a directory: '{path}'") + + try: + papszFiles = VSIReadDirRecursive(path_c) + n = CSLCount(papszFiles) + files = [] + for i in range(n): + files.append(papszFiles[i].decode("UTF-8")) + finally: + CSLDestroy(papszFiles) + + # Apply filter pattern + if pattern is not None: + files = fnmatch.filter(files, pattern) + + # Prepend files with the base path + if not path.endswith("/"): + path = f"{path}/" + files = [f"{path}{file}" for file in files] + + return files + + +def ogr_vsi_rmtree(str path): + """Recursively remove VSI directory. + + Parameters: + ----------- + path : str + path to the VSI directory to be removed. + + """ + cdef const char *path_c + cdef VSIStatBufL st_buf + + try: + path_b = path.encode("UTF-8") + except UnicodeDecodeError: + path_b = path + path_c = path_b + if not VSIStatL(path_c, &st_buf) == 0: + raise FileNotFoundError(f"Path does not exist: '{path}'") + if not VSI_ISDIR(st_buf.st_mode): + raise NotADirectoryError(f"Path is not a directory: '{path}'") + if path.endswith("/vsimem") or path.endswith("/vsimem/"): + raise OSError("path to in-memory file or directory is required") + + errcode = VSIRmdirRecursive(path_c) + if errcode != 0: + raise OSError(f"Error in rmtree of '{path}': {errcode=}") + + +def ogr_vsi_unlink(str path): + """Remove VSI file. + + Parameters: + ----------- + path : str + path to the VSI file to be removed. + + """ + cdef const char *path_c + cdef VSIStatBufL st_buf + + try: + path_b = path.encode("UTF-8") + except UnicodeDecodeError: + path_b = path + path_c = path_b + + if not VSIStatL(path_c, &st_buf) == 0: + raise FileNotFoundError(f"Path does not exist: '{path}'") + + if VSI_ISDIR(st_buf.st_mode): + raise IsADirectoryError(f"Path is a directory: '{path}'") + + errcode = VSIUnlink(path_c) + if errcode != 0: + raise OSError(f"Error removing '{path}': {errcode=}") diff --git a/pyogrio/core.py b/pyogrio/core.py index add4725f..1fa18fa4 100644 --- a/pyogrio/core.py +++ b/pyogrio/core.py @@ -1,5 +1,8 @@ """Core functions to interact with OGR data sources.""" +from pathlib import Path +from typing import Optional, Union + from pyogrio._env import GDALEnv from pyogrio.util import ( _mask_to_wkb, @@ -23,6 +26,11 @@ ogr_list_drivers, set_gdal_config_options as _set_gdal_config_options, ) + from pyogrio._vsi import ( + ogr_vsi_listtree, + ogr_vsi_rmtree, + ogr_vsi_unlink, + ) _init_gdal_data() _init_proj_data() @@ -326,3 +334,53 @@ def get_gdal_data_path(): """ return _get_gdal_data_path() + + +def vsi_listtree(path: Union[str, Path], pattern: Optional[str] = None): + """Recursively list the contents of a VSI directory. + + An fnmatch pattern can be specified to filter the directories/files + returned. + + Parameters + ---------- + path : str or pathlib.Path + Path to the VSI directory to be listed. + pattern : str, optional + Pattern to filter results, in fnmatch format. + + """ + if isinstance(path, Path): + path = path.as_posix() + + return ogr_vsi_listtree(path, pattern=pattern) + + +def vsi_rmtree(path: Union[str, Path]): + """Recursively remove VSI directory. + + Parameters + ---------- + path : str or pathlib.Path + path to the VSI directory to be removed. + + """ + if isinstance(path, Path): + path = path.as_posix() + + ogr_vsi_rmtree(path) + + +def vsi_unlink(path: Union[str, Path]): + """Remove a VSI file. + + Parameters + ---------- + path : str or pathlib.Path + path to vsimem file to be removed + + """ + if isinstance(path, Path): + path = path.as_posix() + + ogr_vsi_unlink(path) diff --git a/pyogrio/raw.py b/pyogrio/raw.py index aaac0285..0f0c3063 100644 --- a/pyogrio/raw.py +++ b/pyogrio/raw.py @@ -563,7 +563,7 @@ def _get_write_path_driver(path, driver, append=False): ) else: - path = vsi_path(str(path)) + path = vsi_path(path) if driver is None: driver = detect_write_driver(path) diff --git a/pyogrio/tests/conftest.py b/pyogrio/tests/conftest.py index 262bc1a3..d6bea86b 100644 --- a/pyogrio/tests/conftest.py +++ b/pyogrio/tests/conftest.py @@ -17,6 +17,7 @@ HAS_PYPROJ, HAS_SHAPELY, ) +from pyogrio.core import vsi_rmtree from pyogrio.raw import read, write import pytest @@ -38,6 +39,15 @@ ALL_EXTS = [".fgb", ".geojson", ".geojsonl", ".gpkg", ".shp"] +START_FID = { + ".fgb": 0, + ".geojson": 0, + ".geojsonl": 0, + ".geojsons": 0, + ".gpkg": 1, + ".shp": 0, +} + def pytest_report_header(config): drivers = ", ".join( @@ -116,7 +126,7 @@ def naturalearth_lowres_all_ext(tmp_path, naturalearth_lowres, request): @pytest.fixture(scope="function") def naturalearth_lowres_vsi(tmp_path, naturalearth_lowres): - """Wrap naturalearth_lowres as a zip file for vsi tests""" + """Wrap naturalearth_lowres as a zip file for VSI tests""" path = tmp_path / f"{naturalearth_lowres.name}.zip" with ZipFile(path, mode="w", compression=ZIP_DEFLATED, compresslevel=5) as out: @@ -127,6 +137,22 @@ def naturalearth_lowres_vsi(tmp_path, naturalearth_lowres): return path, f"/vsizip/{path}/{naturalearth_lowres.name}" +@pytest.fixture(scope="function") +def naturalearth_lowres_vsimem(naturalearth_lowres): + """Write naturalearth_lowres to a vsimem file for VSI tests""" + + meta, _, geometry, field_data = read(naturalearth_lowres) + name = f"pyogrio_fixture_{naturalearth_lowres.stem}" + dst_path = Path(f"/vsimem/{name}/{name}.gpkg") + meta["spatial_index"] = False + meta["geometry_type"] = "MultiPolygon" + + write(dst_path, geometry, field_data, layer="naturalearth_lowres", **meta) + yield dst_path + + vsi_rmtree(dst_path.parent) + + @pytest.fixture(scope="session") def line_zm_file(): return _data_dir / "line_zm.gpkg" diff --git a/pyogrio/tests/test_arrow.py b/pyogrio/tests/test_arrow.py index 7b2d6673..0a89a92a 100644 --- a/pyogrio/tests/test_arrow.py +++ b/pyogrio/tests/test_arrow.py @@ -17,6 +17,7 @@ read_dataframe, read_info, set_gdal_config_options, + vsi_listtree, ) from pyogrio.errors import DataLayerError, DataSourceError, FieldError from pyogrio.raw import open_arrow, read_arrow, write, write_arrow @@ -162,6 +163,10 @@ def test_read_arrow_vsi(naturalearth_lowres_vsi): table = read_arrow(naturalearth_lowres_vsi[1])[1] assert len(table) == 177 + # Check temp file was cleaned up. Filter to files created by pyogrio, as GDAL keeps + # cache files in /vsimem/. + assert vsi_listtree("/vsimem/", pattern="pyogrio_*") == [] + def test_read_arrow_bytes(geojson_bytes): meta, table = read_arrow(geojson_bytes) @@ -169,12 +174,18 @@ def test_read_arrow_bytes(geojson_bytes): assert meta["fields"].shape == (5,) assert len(table) == 3 + # Check temp file was cleaned up. Filter, as gdal keeps cache files in /vsimem/. + assert vsi_listtree("/vsimem/", pattern="pyogrio_*") == [] + def test_read_arrow_nonseekable_bytes(nonseekable_bytes): meta, table = read_arrow(nonseekable_bytes) assert meta["fields"].shape == (0,) assert len(table) == 1 + # Check temp file was cleaned up. Filter, as gdal keeps cache files in /vsimem/. + assert vsi_listtree("/vsimem/", pattern="pyogrio_*") == [] + def test_read_arrow_filelike(geojson_filelike): meta, table = read_arrow(geojson_filelike) @@ -182,6 +193,9 @@ def test_read_arrow_filelike(geojson_filelike): assert meta["fields"].shape == (5,) assert len(table) == 3 + # Check temp file was cleaned up. Filter, as gdal keeps cache files in /vsimem/. + assert vsi_listtree("/vsimem/", pattern="pyogrio_*") == [] + def test_open_arrow_pyarrow(naturalearth_lowres): with open_arrow(naturalearth_lowres, use_pyarrow=True) as (meta, reader): @@ -968,6 +982,9 @@ def test_write_memory_driver_required(naturalearth_lowres): geometry_name=meta["geometry_name"] or "wkb_geometry", ) + # Check temp file was cleaned up. Filter, as gdal keeps cache files in /vsimem/. + assert vsi_listtree("/vsimem/", pattern="pyogrio_*") == [] + @requires_arrow_write_api @pytest.mark.parametrize("driver", ["ESRI Shapefile", "OpenFileGDB"]) @@ -1074,6 +1091,9 @@ def test_write_open_file_handle(tmp_path, naturalearth_lowres): geometry_name=meta["geometry_name"] or "wkb_geometry", ) + # Check temp file was cleaned up. Filter, as gdal keeps cache files in /vsimem/. + assert vsi_listtree("/vsimem/", pattern="pyogrio_*") == [] + @requires_arrow_write_api def test_non_utf8_encoding_io_shapefile(tmp_path, encoded_text): diff --git a/pyogrio/tests/test_core.py b/pyogrio/tests/test_core.py index 1d593466..e0ff6d49 100644 --- a/pyogrio/tests/test_core.py +++ b/pyogrio/tests/test_core.py @@ -1,9 +1,12 @@ +from pathlib import Path + import numpy as np from numpy import allclose, array_equal from pyogrio import ( __gdal_geos_version__, __gdal_version__, + detect_write_driver, get_gdal_config_option, get_gdal_data_path, list_drivers, @@ -11,12 +14,15 @@ read_bounds, read_info, set_gdal_config_options, + vsi_listtree, + vsi_rmtree, + vsi_unlink, ) from pyogrio._compat import GDAL_GE_38 from pyogrio._env import GDALEnv -from pyogrio.core import detect_write_driver from pyogrio.errors import DataLayerError, DataSourceError -from pyogrio.tests.conftest import prepare_testfile, requires_shapely +from pyogrio.raw import read, write +from pyogrio.tests.conftest import START_FID, prepare_testfile, requires_shapely import pytest @@ -154,6 +160,7 @@ def test_list_drivers(): def test_list_layers( naturalearth_lowres, naturalearth_lowres_vsi, + naturalearth_lowres_vsimem, line_zm_file, curve_file, curve_polygon_file, @@ -168,6 +175,11 @@ def test_list_layers( list_layers(naturalearth_lowres_vsi[1]), [["naturalearth_lowres", "Polygon"]] ) + assert array_equal( + list_layers(naturalearth_lowres_vsimem), + [["naturalearth_lowres", "MultiPolygon"]], + ) + # Measured 3D is downgraded to plain 3D during read # Make sure this warning is raised with pytest.warns( @@ -207,22 +219,18 @@ def test_list_layers_filelike(geojson_filelike): assert layers[0, 0] == "test" -def test_read_bounds(naturalearth_lowres): - fids, bounds = read_bounds(naturalearth_lowres) - assert fids.shape == (177,) - assert bounds.shape == (4, 177) - - assert fids[0] == 0 - # Fiji; wraps antimeridian - assert allclose(bounds[:, 0], [-180.0, -18.28799, 180.0, -16.02088]) - +@pytest.mark.parametrize( + "testfile", + ["naturalearth_lowres", "naturalearth_lowres_vsimem", "naturalearth_lowres_vsi"], +) +def test_read_bounds(testfile, request): + path = request.getfixturevalue(testfile) + path = path if not isinstance(path, tuple) else path[1] -def test_read_bounds_vsi(naturalearth_lowres_vsi): - fids, bounds = read_bounds(naturalearth_lowres_vsi[1]) + fids, bounds = read_bounds(path) assert fids.shape == (177,) assert bounds.shape == (4, 177) - - assert fids[0] == 0 + assert fids[0] == START_FID[Path(path).suffix] # Fiji; wraps antimeridian assert allclose(bounds[:, 0], [-180.0, -18.28799, 180.0, -16.02088]) @@ -308,12 +316,9 @@ def test_read_bounds_bbox(naturalearth_lowres_all_ext): fids, bounds = read_bounds(naturalearth_lowres_all_ext, bbox=(-85, 8, -80, 10)) assert fids.shape == (2,) - if naturalearth_lowres_all_ext.suffix == ".gpkg": - # fid in gpkg is 1-based - assert array_equal(fids, [34, 35]) # PAN, CRI - else: - # fid in other formats is 0-based - assert array_equal(fids, [33, 34]) # PAN, CRI + fids_expected = np.array([33, 34]) # PAN, CRI + fids_expected += START_FID[naturalearth_lowres_all_ext.suffix] + assert array_equal(fids, fids_expected) assert bounds.shape == (4, 2) assert allclose( @@ -378,12 +383,8 @@ def test_read_bounds_mask(naturalearth_lowres_all_ext, mask, expected): fids = read_bounds(naturalearth_lowres_all_ext, mask=mask)[0] - if naturalearth_lowres_all_ext.suffix == ".gpkg": - # fid in gpkg is 1-based - assert array_equal(fids, np.array(expected) + 1) - else: - # fid in other formats is 0-based - assert array_equal(fids, expected) + fids_expected = np.array(expected) + START_FID[naturalearth_lowres_all_ext.suffix] + assert array_equal(fids, fids_expected) @pytest.mark.skipif( @@ -399,21 +400,15 @@ def test_read_bounds_bbox_intersects_vs_envelope_overlaps(naturalearth_lowres_al if __gdal_geos_version__ is None: # bboxes for CAN, RUS overlap but do not intersect geometries assert fids.shape == (4,) - if naturalearth_lowres_all_ext.suffix == ".gpkg": - # fid in gpkg is 1-based - assert array_equal(fids, [4, 5, 19, 28]) # CAN, USA, RUS, MEX - else: - # fid in other formats is 0-based - assert array_equal(fids, [3, 4, 18, 27]) # CAN, USA, RUS, MEX + fids_expected = np.array([3, 4, 18, 27]) # CAN, USA, RUS, MEX + fids_expected += START_FID[naturalearth_lowres_all_ext.suffix] + assert array_equal(fids, fids_expected) else: assert fids.shape == (2,) - if naturalearth_lowres_all_ext.suffix == ".gpkg": - # fid in gpkg is 1-based - assert array_equal(fids, [5, 28]) # USA, MEX - else: - # fid in other formats is 0-based - assert array_equal(fids, [4, 27]) # USA, MEX + fids_expected = np.array([4, 27]) # USA, MEX + fids_expected += START_FID[naturalearth_lowres_all_ext.suffix] + assert array_equal(fids, fids_expected) @pytest.mark.parametrize("naturalearth_lowres", [".shp", ".gpkg"], indirect=True) @@ -453,8 +448,14 @@ def test_read_info(naturalearth_lowres): raise ValueError(f"test not implemented for ext {naturalearth_lowres.suffix}") -def test_read_info_vsi(naturalearth_lowres_vsi): - meta = read_info(naturalearth_lowres_vsi[1]) +@pytest.mark.parametrize( + "testfile", ["naturalearth_lowres_vsimem", "naturalearth_lowres_vsi"] +) +def test_read_info_vsi(testfile, request): + path = request.getfixturevalue(testfile) + path = path if not isinstance(path, tuple) else path[1] + + meta = read_info(path) assert meta["fields"].shape == (5,) assert meta["features"] == 177 @@ -611,3 +612,67 @@ def test_error_handling_warning(capfd, naturalearth_lowres): read_info(naturalearth_lowres, INVALID="YES") assert capfd.readouterr().err == "" + + +def test_vsimem_listtree_rmtree_unlink(naturalearth_lowres): + """Test all basic functionalities of file handling in /vsimem/.""" + # Prepare test data in /vsimem + meta, _, geometry, field_data = read(naturalearth_lowres) + meta["spatial_index"] = False + meta["geometry_type"] = "MultiPolygon" + test_file_path = Path("/vsimem/pyogrio_test_naturalearth_lowres.gpkg") + test_dir_path = Path(f"/vsimem/pyogrio_dir_test/{naturalearth_lowres.stem}.gpkg") + + write(test_file_path, geometry, field_data, **meta) + write(test_dir_path, geometry, field_data, **meta) + + # Check if everything was created properly with listtree + files = vsi_listtree("/vsimem/") + assert test_file_path.as_posix() in files + assert test_dir_path.as_posix() in files + + # Check listtree with pattern + files = vsi_listtree("/vsimem/", pattern="pyogrio_dir_test*.gpkg") + assert test_file_path.as_posix() not in files + assert test_dir_path.as_posix() in files + + files = vsi_listtree("/vsimem/", pattern="pyogrio_test*.gpkg") + assert test_file_path.as_posix() in files + assert test_dir_path.as_posix() not in files + + # Remove test_dir and its contents + vsi_rmtree(test_dir_path.parent) + files = vsi_listtree("/vsimem/") + assert test_file_path.as_posix() in files + assert test_dir_path.as_posix() not in files + + # Remove test_file + vsi_unlink(test_file_path) + + +def test_vsimem_rmtree_error(naturalearth_lowres_vsimem): + with pytest.raises(NotADirectoryError, match="Path is not a directory"): + vsi_rmtree(naturalearth_lowres_vsimem) + + with pytest.raises(FileNotFoundError, match="Path does not exist"): + vsi_rmtree("/vsimem/non-existent") + + with pytest.raises( + OSError, match="path to in-memory file or directory is required" + ): + vsi_rmtree("/vsimem") + with pytest.raises( + OSError, match="path to in-memory file or directory is required" + ): + vsi_rmtree("/vsimem/") + + # Verify that naturalearth_lowres_vsimem still exists. + assert naturalearth_lowres_vsimem.as_posix() in vsi_listtree("/vsimem") + + +def test_vsimem_unlink_error(naturalearth_lowres_vsimem): + with pytest.raises(IsADirectoryError, match="Path is a directory"): + vsi_unlink(naturalearth_lowres_vsimem.parent) + + with pytest.raises(FileNotFoundError, match="Path does not exist"): + vsi_unlink("/vsimem/non-existent.gpkg") diff --git a/pyogrio/tests/test_geopandas_io.py b/pyogrio/tests/test_geopandas_io.py index 74efa6f7..c70ac820 100644 --- a/pyogrio/tests/test_geopandas_io.py +++ b/pyogrio/tests/test_geopandas_io.py @@ -7,7 +7,14 @@ import numpy as np -from pyogrio import __gdal_version__, list_drivers, list_layers, read_info +from pyogrio import ( + __gdal_version__, + list_drivers, + list_layers, + read_info, + vsi_listtree, + vsi_unlink, +) from pyogrio._compat import HAS_ARROW_WRITE_API, HAS_PYPROJ, PANDAS_GE_15 from pyogrio.errors import DataLayerError, DataSourceError, FeatureError, GeometryError from pyogrio.geopandas import PANDAS_GE_20, read_dataframe, write_dataframe @@ -18,6 +25,7 @@ from pyogrio.tests.conftest import ( ALL_EXTS, DRIVERS, + START_FID, requires_arrow_write_api, requires_gdal_geos, requires_pyarrow_api, @@ -371,12 +379,9 @@ def test_read_fid_as_index(naturalearth_lowres_all_ext, use_arrow): fid_as_index=True, **kwargs, ) - if naturalearth_lowres_all_ext.suffix in [".gpkg"]: - # File format where fid starts at 1 - assert_index_equal(df.index, pd.Index([3, 4], name="fid")) - else: - # File format where fid starts at 0 - assert_index_equal(df.index, pd.Index([2, 3], name="fid")) + fids_expected = pd.Index([2, 3], name="fid") + fids_expected += START_FID[naturalearth_lowres_all_ext.suffix] + assert_index_equal(df.index, fids_expected) def test_read_fid_as_index_only(naturalearth_lowres, use_arrow): @@ -1568,6 +1573,22 @@ def test_write_read_null(tmp_path, use_arrow): assert result_gdf["object_str"][2] is None +@pytest.mark.requires_arrow_write_api +def test_write_read_vsimem(naturalearth_lowres_vsi, use_arrow): + path, _ = naturalearth_lowres_vsi + mem_path = f"/vsimem/{path.name}" + + input = read_dataframe(path, use_arrow=use_arrow) + assert len(input) == 177 + + try: + write_dataframe(input, mem_path, use_arrow=use_arrow) + result = read_dataframe(mem_path, use_arrow=use_arrow) + assert len(result) == 177 + finally: + vsi_unlink(mem_path) + + @pytest.mark.parametrize( "wkt,geom_types", [ @@ -1974,6 +1995,9 @@ def test_write_memory(naturalearth_lowres, driver): check_dtype=not is_json, ) + # Check temp file was cleaned up. Filter, as gdal keeps cache files in /vsimem/. + assert vsi_listtree("/vsimem/", pattern="pyogrio_*") == [] + def test_write_memory_driver_required(naturalearth_lowres): df = read_dataframe(naturalearth_lowres) @@ -1986,6 +2010,9 @@ def test_write_memory_driver_required(naturalearth_lowres): ): write_dataframe(df.head(1), buffer, driver=None, layer="test") + # Check temp file was cleaned up. Filter, as gdal keeps cache files in /vsimem/. + assert vsi_listtree("/vsimem/", pattern="pyogrio_*") == [] + @pytest.mark.parametrize("driver", ["ESRI Shapefile", "OpenFileGDB"]) def test_write_memory_unsupported_driver(naturalearth_lowres, driver): @@ -2001,6 +2028,9 @@ def test_write_memory_unsupported_driver(naturalearth_lowres, driver): ): write_dataframe(df, buffer, driver=driver, layer="test") + # Check temp file was cleaned up. Filter, as gdal keeps cache files in /vsimem/. + assert vsi_listtree("/vsimem/", pattern="pyogrio_*") == [] + @pytest.mark.parametrize("driver", ["GeoJSON", "GPKG"]) def test_write_memory_append_unsupported(naturalearth_lowres, driver): @@ -2013,6 +2043,9 @@ def test_write_memory_append_unsupported(naturalearth_lowres, driver): ): write_dataframe(df.head(1), buffer, driver=driver, layer="test", append=True) + # Check temp file was cleaned up. Filter, as gdal keeps cache files in /vsimem/. + assert vsi_listtree("/vsimem/", pattern="pyogrio_*") == [] + def test_write_memory_existing_unsupported(naturalearth_lowres): df = read_dataframe(naturalearth_lowres) @@ -2024,6 +2057,9 @@ def test_write_memory_existing_unsupported(naturalearth_lowres): ): write_dataframe(df.head(1), buffer, driver="GeoJSON", layer="test") + # Check temp file was cleaned up. Filter, as gdal keeps cache files in /vsimem/. + assert vsi_listtree("/vsimem/", pattern="pyogrio_*") == [] + def test_write_open_file_handle(tmp_path, naturalearth_lowres): """Verify that writing to an open file handle is not currently supported""" @@ -2045,6 +2081,9 @@ def test_write_open_file_handle(tmp_path, naturalearth_lowres): with z.open("test.geojson", "w") as f: write_dataframe(df.head(1), f) + # Check temp file was cleaned up. Filter, as gdal keeps cache files in /vsimem/. + assert vsi_listtree("/vsimem/", pattern="pyogrio_*") == [] + @pytest.mark.parametrize("ext", ["gpkg", "geojson"]) def test_non_utf8_encoding_io(tmp_path, ext, encoded_text): diff --git a/pyogrio/tests/test_path.py b/pyogrio/tests/test_path.py index 6a3010e7..9cc7943c 100644 --- a/pyogrio/tests/test_path.py +++ b/pyogrio/tests/test_path.py @@ -33,6 +33,7 @@ def change_cwd(path): [ # local file paths that should be passed through as is ("data.gpkg", "data.gpkg"), + (Path("data.gpkg"), "data.gpkg"), ("/home/user/data.gpkg", "/home/user/data.gpkg"), (r"C:\User\Documents\data.gpkg", r"C:\User\Documents\data.gpkg"), ("file:///home/user/data.gpkg", "/home/user/data.gpkg"), @@ -85,6 +86,8 @@ def change_cwd(path): "s3://testing/test.zip!a/b/item.shp", "/vsizip/vsis3/testing/test.zip/a/b/item.shp", ), + ("/vsimem/data.gpkg", "/vsimem/data.gpkg"), + (Path("/vsimem/data.gpkg"), "/vsimem/data.gpkg"), ], ) def test_vsi_path(path, expected): @@ -339,19 +342,23 @@ def test_uri_s3_dataframe(aws_env_setup): assert len(df) == 67 -def test_get_vsi_path_or_buffer_obj_to_string(): - path = Path("/tmp/test.gpkg") - assert get_vsi_path_or_buffer(path) == str(path) +@pytest.mark.parametrize( + "path, expected", + [ + (Path("/tmp/test.gpkg"), str(Path("/tmp/test.gpkg"))), + (Path("/vsimem/test.gpkg"), "/vsimem/test.gpkg"), + ], +) +def test_get_vsi_path_or_buffer_obj_to_string(path, expected): + """Verify that get_vsi_path_or_buffer retains forward slashes in /vsimem paths. + + The /vsimem paths should keep forward slashes for GDAL to recognize them as such. + However, on Windows systems, forward slashes are by default replaced by backslashes, + so this test verifies that this doesn't happen for /vsimem paths. + """ + assert get_vsi_path_or_buffer(path) == expected def test_get_vsi_path_or_buffer_fixtures_to_string(tmp_path): path = tmp_path / "test.gpkg" assert get_vsi_path_or_buffer(path) == str(path) - - -@pytest.mark.parametrize( - "raw_path", ["/vsimem/test.shp.zip", "/vsizip//vsimem/test.shp.zip"] -) -def test_vsimem_path_exception(raw_path): - with pytest.raises(ValueError, match=""): - vsi_path(raw_path) diff --git a/pyogrio/tests/test_util.py b/pyogrio/tests/test_util.py new file mode 100644 index 00000000..52ef2a83 --- /dev/null +++ b/pyogrio/tests/test_util.py @@ -0,0 +1,56 @@ +from pathlib import Path + +from pyogrio import vsi_listtree, vsi_unlink +from pyogrio.raw import read, write +from pyogrio.util import vsimem_rmtree_toplevel + +import pytest + + +def test_vsimem_rmtree_toplevel(naturalearth_lowres): + # Prepare test data in /vsimem/ + meta, _, geometry, field_data = read(naturalearth_lowres) + meta["spatial_index"] = False + meta["geometry_type"] = "MultiPolygon" + test_dir_path = Path(f"/vsimem/test/{naturalearth_lowres.stem}.gpkg") + test_dir2_path = Path(f"/vsimem/test2/test2/{naturalearth_lowres.stem}.gpkg") + + write(test_dir_path, geometry, field_data, **meta) + write(test_dir2_path, geometry, field_data, **meta) + + # Check if everything was created properly with listtree + files = vsi_listtree("/vsimem/") + assert test_dir_path.as_posix() in files + assert test_dir2_path.as_posix() in files + + # Test deleting parent dir of file in single directory + vsimem_rmtree_toplevel(test_dir_path) + files = vsi_listtree("/vsimem/") + assert test_dir_path.parent.as_posix() not in files + assert test_dir2_path.as_posix() in files + + # Test deleting top-level dir of file in a subdirectory + vsimem_rmtree_toplevel(test_dir2_path) + assert test_dir2_path.as_posix() not in vsi_listtree("/vsimem/") + + +def test_vsimem_rmtree_toplevel_error(naturalearth_lowres): + # Prepare test data in /vsimem + meta, _, geometry, field_data = read(naturalearth_lowres) + meta["spatial_index"] = False + meta["geometry_type"] = "MultiPolygon" + test_file_path = Path(f"/vsimem/pyogrio_test_{naturalearth_lowres.stem}.gpkg") + + write(test_file_path, geometry, field_data, **meta) + assert test_file_path.as_posix() in vsi_listtree("/vsimem/") + + # Deleting parent dir of non-existent file should raise an error. + with pytest.raises(FileNotFoundError, match="Path does not exist"): + vsimem_rmtree_toplevel("/vsimem/test/non-existent.gpkg") + + # File should still be there + assert test_file_path.as_posix() in vsi_listtree("/vsimem/") + + # Cleanup. + vsi_unlink(test_file_path) + assert test_file_path not in vsi_listtree("/vsimem/") diff --git a/pyogrio/util.py b/pyogrio/util.py index d0e76446..b018ad79 100644 --- a/pyogrio/util.py +++ b/pyogrio/util.py @@ -4,11 +4,14 @@ import sys from packaging.version import Version from pathlib import Path +from typing import Union from urllib.parse import urlparse +from pyogrio._vsi import vsimem_rmtree_toplevel as _vsimem_rmtree_toplevel + def get_vsi_path_or_buffer(path_or_buffer): - """Get vsi-prefixed path or bytes buffer depending on type of path_or_buffer. + """Get VSI-prefixed path or bytes buffer depending on type of path_or_buffer. If path_or_buffer is a bytes object, it will be returned directly and will be read into an in-memory dataset when passed to one of the Cython functions. @@ -29,9 +32,10 @@ def get_vsi_path_or_buffer(path_or_buffer): str or bytes """ - # force path objects to string to specifically ignore their read method + # treat Path objects here already to ignore their read method + to avoid backslashes + # on Windows. if isinstance(path_or_buffer, Path): - return vsi_path(str(path_or_buffer)) + return vsi_path(path_or_buffer) if isinstance(path_or_buffer, bytes): return path_or_buffer @@ -48,13 +52,14 @@ def get_vsi_path_or_buffer(path_or_buffer): return vsi_path(str(path_or_buffer)) -def vsi_path(path: str) -> str: - """Ensure path is a local path or a GDAL-compatible vsi path.""" - if "/vsimem/" in path: - raise ValueError( - "path cannot contain /vsimem/ directly; to use an in-memory dataset a " - "bytes object must be passed instead" - ) +def vsi_path(path: Union[str, Path]) -> str: + """Ensure path is a local path or a GDAL-compatible VSI path.""" + # Convert Path objects to string, but for VSI paths, keep posix style path. + if isinstance(path, Path): + if sys.platform == "win32" and path.as_posix().startswith("/vsi"): + path = path.as_posix() + else: + path = str(path) # path is already in GDAL format if path.startswith("/vsi"): @@ -217,3 +222,26 @@ def _mask_to_wkb(mask): raise ValueError("'mask' parameter must be a Shapely geometry") return shapely.to_wkb(mask) + + +def vsimem_rmtree_toplevel(path: Union[str, Path]): + """Remove the parent directory of the file path recursively. + + This is used for final cleanup of an in-memory dataset, which may have been + created within a directory to contain sibling files. + + Additional VSI handlers may be chained to the left of /vsimem/ in path and + will be ignored. + + Remark: function is defined here to be able to run tests on it. + + Parameters + ---------- + path : str or pathlib.Path + path to in-memory file + + """ + if isinstance(path, Path): + path = path.as_posix() + + _vsimem_rmtree_toplevel(path)