Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Add support for detecting write drivers using GDAL #270

Merged
merged 7 commits into from
Sep 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
performance impacts for some data sources that would otherwise return an
unknown count (count is used in `read_info`, `read`, `read_dataframe`) (#271).

- Automatically detect the supported driver by extension for all available
write drivers, and add `detect_write_driver` (#270)

### Bug fixes

- Fix int32 overflow when reading int64 columns (#260)
Expand Down
2 changes: 1 addition & 1 deletion docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ Core
----

.. automodule:: pyogrio
:members: list_drivers, list_layers, read_bounds, read_info, set_gdal_config_options, get_gdal_config_option, __gdal_version__, __gdal_version_string__
:members: list_drivers, detect_write_driver, list_layers, read_bounds, read_info, set_gdal_config_options, get_gdal_config_option, __gdal_version__, __gdal_version_string__

GeoPandas integration
---------------------
Expand Down
2 changes: 2 additions & 0 deletions pyogrio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from pyogrio.core import (
list_drivers,
detect_write_driver,
list_layers,
read_bounds,
read_info,
Expand All @@ -27,6 +28,7 @@

__all__ = [
"list_drivers",
"detect_write_driver",
"list_layers",
"read_bounds",
"read_info",
Expand Down
43 changes: 43 additions & 0 deletions pyogrio/_ogr.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -317,3 +317,46 @@ def _get_driver_metadata_item(driver, metadata_item):
metadata = None

return metadata


def _get_drivers_for_path(path):
    """Return the names of write-capable drivers that match ``path``.

    A driver matches when the (lowercased) extension of ``path`` is listed in
    the driver's ``DMD_EXTENSIONS`` metadata or, failing that, when the
    (lowercased) path starts with the driver's ``DMD_CONNECTION_PREFIX``.
    Drivers that do not support writing are ignored.

    Parameters
    ----------
    path : object
        Converted with ``str()`` and lowercased before matching.

    Returns
    -------
    list of str
        Matching driver names, in driver registration order; may be empty.
    """
    cdef OGRSFDriverH driver = NULL
    cdef int i
    cdef char *name_c

    path = str(path).lower()

    # splitext keeps the leading dot in the suffix, so a suffix of length <= 1
    # ("" or ".") means there is no usable extension
    suffix = os.path.splitext(path)[1]
    ext = suffix[1:] if len(suffix) > 1 else None

    # allow specific drivers to have a .zip extension to match GDAL behavior
    if ext == 'zip':
        if path.endswith('.shp.zip'):
            ext = 'shp.zip'
        elif path.endswith('.gpkg.zip'):
            ext = 'gpkg.zip'

    drivers = []
    for i in range(OGRGetDriverCount()):
        driver = OGRGetDriver(i)
        name_c = <char *>OGR_Dr_GetName(driver)
        name = get_string(name_c)

        if not ogr_driver_supports_write(name):
            continue

        # DMD_EXTENSIONS is a space-delimited list of the extensions
        # supported by the driver
        extensions = _get_driver_metadata_item(name, "DMD_EXTENSIONS")
        extension_matches = (
            ext is not None
            and extensions is not None
            and ext in extensions.lower().split(' ')
        )
        if extension_matches:
            drivers.append(name)
            continue

        # no extension match: fall back to the connection prefix
        prefix = _get_driver_metadata_item(name, "DMD_CONNECTION_PREFIX")
        if prefix is not None and path.startswith(prefix.lower()):
            drivers.append(name)

    return drivers
40 changes: 38 additions & 2 deletions pyogrio/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from pyogrio._env import GDALEnv
from pyogrio.raw import _preprocess_options_key_value
from pyogrio.util import get_vsi_path
from pyogrio.util import get_vsi_path, _preprocess_options_key_value


with GDALEnv():
Expand All @@ -16,6 +15,7 @@
init_proj_data as _init_proj_data,
remove_virtual_file,
_register_drivers,
_get_drivers_for_path,
)
from pyogrio._err import _register_error_handler
from pyogrio._io import ogr_list_layers, ogr_read_bounds, ogr_read_info
Expand Down Expand Up @@ -58,6 +58,42 @@ def list_drivers(read=False, write=False):
return drivers


def detect_write_driver(path):
    """Infer the write driver for a path from its extension or prefix.

    Only drivers that support writing are considered. Exactly one driver
    must match; otherwise the caller has to name the driver explicitly.

    Parameters
    ----------
    path : str

    Returns
    -------
    str
        name of the single driver detected for ``path``

    Raises
    ------
    ValueError
        if no driver, or more than one driver, matches ``path``
    """
    candidates = _get_drivers_for_path(path)

    # nothing matched the extension or prefix
    if not candidates:
        raise ValueError(
            f"Could not infer driver from path: {path}; please specify driver "
            "explicitly"
        )

    # ambiguous match: the user needs to pick the correct driver manually
    if len(candidates) > 1:
        raise ValueError(
            f"Could not infer driver from path: {path}; multiple drivers are "
            "available for that extension. Please specify driver explicitly"
        )

    return candidates[0]


def list_layers(path_or_buffer, /):
"""List layers available in an OGR data source.

Expand Down
12 changes: 9 additions & 3 deletions pyogrio/geopandas.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import numpy as np
from pyogrio.raw import DRIVERS_NO_MIXED_SINGLE_MULTI, DRIVERS_NO_MIXED_DIMENSIONS
from pyogrio.raw import detect_driver, read, read_arrow, write
from pyogrio.raw import (
DRIVERS_NO_MIXED_SINGLE_MULTI,
DRIVERS_NO_MIXED_DIMENSIONS,
detect_write_driver,
read,
read_arrow,
write,
)
from pyogrio.errors import DataSourceError


Expand Down Expand Up @@ -312,7 +318,7 @@ def write_dataframe(
raise ValueError("'df' must be a DataFrame or GeoDataFrame")

if driver is None:
driver = detect_driver(path)
driver = detect_write_driver(path)

geometry_columns = df.columns[df.dtypes == "geometry"]
if len(geometry_columns) > 1:
Expand Down
58 changes: 3 additions & 55 deletions pyogrio/raw.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import warnings
import os

from pyogrio._env import GDALEnv
from pyogrio.core import detect_write_driver
from pyogrio.errors import DataSourceError
from pyogrio.util import get_vsi_path, vsi_path
from pyogrio.util import get_vsi_path, vsi_path, _preprocess_options_key_value

with GDALEnv():
from pyogrio._io import ogr_open_arrow, ogr_read, ogr_write
Expand All @@ -16,17 +16,6 @@
)


# Mapping of lowercase file extension (including the leading dot) to the name
# of the driver used for it; `detect_driver` looks extensions up here.
DRIVERS = {
".fgb": "FlatGeobuf",
".geojson": "GeoJSON",
".geojsonl": "GeoJSONSeq",
".geojsons": "GeoJSONSeq",
".gpkg": "GPKG",
".json": "GeoJSON",
".shp": "ESRI Shapefile",
}


DRIVERS_NO_MIXED_SINGLE_MULTI = {
"FlatGeobuf",
"GPKG",
Expand Down Expand Up @@ -310,26 +299,6 @@ def open_arrow(
remove_virtual_file(path)


def detect_driver(path):
    """Infer the driver name from the file extension of ``path``.

    The extension (including the leading dot) is lowercased and looked up in
    the module-level ``DRIVERS`` mapping.

    Parameters
    ----------
    path : str

    Returns
    -------
    str
        name of the driver registered for the extension

    Raises
    ------
    ValueError
        if the extension is missing or not present in ``DRIVERS``
    """
    # os.path.splitext always returns a (root, ext) pair, so there is no
    # separate "could not split" case to handle; a missing extension is simply
    # "" and falls through to the lookup failure below
    ext = os.path.splitext(path)[1].lower()
    driver = DRIVERS.get(ext)
    if driver is None:
        raise ValueError(
            f"Could not infer driver from path: {path}; please specify driver "
            "explicitly"
        )

    return driver


def _parse_options_names(xml):
"""Convert metadata xml to list of names"""
# Based on Fiona's meta.py
Expand All @@ -347,27 +316,6 @@ def _parse_options_names(xml):
return options


def _preprocess_options_key_value(options):
"""
Preprocess options, eg `spatial_index=True` gets converted
to `SPATIAL_INDEX="YES"`.
"""
if not isinstance(options, dict):
raise TypeError(f"Expected options to be a dict, got {type(options)}")

result = {}
for k, v in options.items():
if v is None:
continue
k = k.upper()
if isinstance(v, bool):
v = "ON" if v else "OFF"
else:
v = str(v)
result[k] = v
return result


def write(
path,
geometry,
Expand All @@ -393,7 +341,7 @@ def write(
path = vsi_path(str(path))

if driver is None:
driver = detect_driver(path)
driver = detect_write_driver(path)

# verify that driver supports writing
if not ogr_driver_supports_write(driver):
Expand Down
15 changes: 15 additions & 0 deletions pyogrio/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,21 @@


_data_dir = Path(__file__).parent.resolve() / "fixtures"

# mapping of driver extension to driver name for well-supported drivers
DRIVERS = {
".fgb": "FlatGeobuf",
".geojson": "GeoJSON",
".geojsonl": "GeoJSONSeq",
".geojsons": "GeoJSONSeq",
".gpkg": "GPKG",
".json": "GeoJSON",
".shp": "ESRI Shapefile",
}

# mapping of driver name to extension; where several extensions share a driver
# (GeoJSON, GeoJSONSeq) the one listed last in DRIVERS wins, i.e. ".json" and
# ".geojsons"
DRIVER_EXT = {driver: ext for ext, driver in DRIVERS.items()}

# the DRIVERS extensions minus ".geojsons" and ".json", which share a driver
# with ".geojsonl" and ".geojson" respectively
ALL_EXTS = [".fgb", ".geojson", ".geojsonl", ".gpkg", ".shp"]


Expand Down
53 changes: 53 additions & 0 deletions pyogrio/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
get_gdal_config_option,
get_gdal_data_path,
)
from pyogrio.core import detect_write_driver
from pyogrio.errors import DataSourceError, DataLayerError

from pyogrio._env import GDALEnv
Expand Down Expand Up @@ -44,6 +45,58 @@ def test_gdal_geos_version():
assert __gdal_geos_version__ is None or isinstance(__gdal_geos_version__, tuple)


# Each path must resolve to exactly the expected write driver name. Cases that
# depend on a minimum GDAL version or on an optional driver are skipped when
# that requirement is not met.
@pytest.mark.parametrize(
"path,expected",
[
("test.shp", "ESRI Shapefile"),
("test.shp.zip", "ESRI Shapefile"),
("test.geojson", "GeoJSON"),
("test.geojsonl", "GeoJSONSeq"),
("test.gpkg", "GPKG"),
pytest.param(
"test.gpkg.zip",
"GPKG",
marks=pytest.mark.skipif(
__gdal_version__ < (3, 7, 0),
reason="writing *.gpkg.zip requires GDAL >= 3.7.0",
),
),
# postgres can be detected by prefix instead of extension
pytest.param(
"PG:dbname=test",
"PostgreSQL",
marks=pytest.mark.skipif(
"PostgreSQL" not in list_drivers(),
reason="PostgreSQL path test requires PostgreSQL driver",
),
),
],
)
def test_detect_write_driver(path, expected):
assert detect_write_driver(path) == expected


# Paths for which no write driver is expected to match must raise ValueError
# with the "Could not infer driver from path" message.
@pytest.mark.parametrize(
"path",
[
"test.svg",  # only supports read
"test.",  # not a valid extension
"test",  # no extension or prefix
"test.foo",  # not a valid extension
"FOO:test",  # not a valid prefix
],
)
def test_detect_write_driver_unsupported(path):
with pytest.raises(ValueError, match="Could not infer driver from path"):
detect_write_driver(path)


# Paths expected to match more than one write driver must raise ValueError
# with the "multiple drivers are available" message instead of picking one.
@pytest.mark.parametrize("path", ["test.xml", "test.txt"])
def test_detect_write_driver_multiple_unsupported(path):
with pytest.raises(ValueError, match="multiple drivers are available"):
detect_write_driver(path)


@pytest.mark.parametrize(
"driver,expected",
[
Expand Down
3 changes: 1 addition & 2 deletions pyogrio/tests/test_geopandas_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@
from pyogrio.errors import DataLayerError, DataSourceError, FeatureError, GeometryError
from pyogrio.geopandas import read_dataframe, write_dataframe
from pyogrio.raw import (
DRIVERS,
DRIVERS_NO_MIXED_DIMENSIONS,
DRIVERS_NO_MIXED_SINGLE_MULTI,
)
from pyogrio.tests.conftest import ALL_EXTS
from pyogrio.tests.conftest import ALL_EXTS, DRIVERS

try:
import pandas as pd
Expand Down
7 changes: 2 additions & 5 deletions pyogrio/tests/test_raw_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,9 @@
set_gdal_config_options,
__gdal_version__,
)
from pyogrio.raw import DRIVERS, read, write
from pyogrio.raw import read, write
from pyogrio.errors import DataSourceError, DataLayerError, FeatureError
from pyogrio.tests.conftest import prepare_testfile

# mapping of driver name to extension
DRIVER_EXT = {driver: ext for ext, driver in DRIVERS.items()}
from pyogrio.tests.conftest import prepare_testfile, DRIVERS, DRIVER_EXT


def test_read(naturalearth_lowres):
Expand Down
Loading
Loading