Skip to content

Commit

Permalink
ENH: Add support for detecting write drivers using GDAL (#270)
Browse files Browse the repository at this point in the history
  • Loading branch information
brendan-ward authored Sep 6, 2023
1 parent 3ad88d0 commit aacd996
Show file tree
Hide file tree
Showing 12 changed files with 191 additions and 68 deletions.
3 changes: 3 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
performance impacts for some data sources that would otherwise return an
unknown count (count is used in `read_info`, `read`, `read_dataframe`) (#271).

- Automatically detect the supported driver by extension for all available
write drivers, and add `detect_write_driver` (#270).

### Bug fixes

- Fix int32 overflow when reading int64 columns (#260)
Expand Down
2 changes: 1 addition & 1 deletion docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ Core
----

.. automodule:: pyogrio
:members: list_drivers, list_layers, read_bounds, read_info, set_gdal_config_options, get_gdal_config_option, __gdal_version__, __gdal_version_string__
:members: list_drivers, detect_write_driver, list_layers, read_bounds, read_info, set_gdal_config_options, get_gdal_config_option, __gdal_version__, __gdal_version_string__

GeoPandas integration
---------------------
Expand Down
2 changes: 2 additions & 0 deletions pyogrio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from pyogrio.core import (
list_drivers,
detect_write_driver,
list_layers,
read_bounds,
read_info,
Expand All @@ -27,6 +28,7 @@

__all__ = [
"list_drivers",
"detect_write_driver",
"list_layers",
"read_bounds",
"read_info",
Expand Down
43 changes: 43 additions & 0 deletions pyogrio/_ogr.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -317,3 +317,46 @@ def _get_driver_metadata_item(driver, metadata_item):
metadata = None

return metadata


def _get_drivers_for_path(path):
    """Return the names of write-capable drivers that match a path.

    A driver matches when the (lower-cased) extension of ``path`` is listed
    in the driver's ``DMD_EXTENSIONS`` metadata or, failing that, when the
    lower-cased path starts with the driver's ``DMD_CONNECTION_PREFIX``.
    Drivers for which ``ogr_driver_supports_write`` is false are ignored.

    Parameters
    ----------
    path : str or object convertible with ``str()``

    Returns
    -------
    list of str
        names of all matching drivers, in driver registration order; empty
        if nothing matches
    """
    cdef OGRSFDriverH driver = NULL
    cdef int i
    cdef char *name_c

    path = str(path).lower()

    # splitext keeps the leading dot on the suffix, so a usable extension
    # needs at least one character after it ("test." yields no extension)
    suffix = os.path.splitext(path)[1]
    ext = suffix[1:] if len(suffix) > 1 else None

    # allow specific drivers to have a .zip extension to match GDAL behavior
    if ext == "zip":
        for compound in ("shp.zip", "gpkg.zip"):
            if path.endswith("." + compound):
                ext = compound
                break

    matches = []
    for i in range(OGRGetDriverCount()):
        driver = OGRGetDriver(i)
        name_c = <char *>OGR_Dr_GetName(driver)
        name = get_string(name_c)

        if not ogr_driver_supports_write(name):
            continue

        # extensions is a space-delimited list of supported extensions
        # for driver
        extensions = _get_driver_metadata_item(name, "DMD_EXTENSIONS")
        matched_by_extension = (
            ext is not None
            and extensions is not None
            and ext in extensions.lower().split(" ")
        )
        if matched_by_extension:
            matches.append(name)
            continue

        # the connection prefix is only consulted when the extension did not
        # match this driver
        prefix = _get_driver_metadata_item(name, "DMD_CONNECTION_PREFIX")
        if prefix is not None and path.startswith(prefix.lower()):
            matches.append(name)

    return matches
40 changes: 38 additions & 2 deletions pyogrio/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from pyogrio._env import GDALEnv
from pyogrio.raw import _preprocess_options_key_value
from pyogrio.util import get_vsi_path
from pyogrio.util import get_vsi_path, _preprocess_options_key_value


with GDALEnv():
Expand All @@ -16,6 +15,7 @@
init_proj_data as _init_proj_data,
remove_virtual_file,
_register_drivers,
_get_drivers_for_path,
)
from pyogrio._err import _register_error_handler
from pyogrio._io import ogr_list_layers, ogr_read_bounds, ogr_read_info
Expand Down Expand Up @@ -58,6 +58,42 @@ def list_drivers(read=False, write=False):
return drivers


def detect_write_driver(path):
    """Infer the write driver for a path from its extension or prefix.

    Only drivers that support write capabilities are considered.  The path
    must resolve to exactly one driver.

    Parameters
    ----------
    path : str

    Returns
    -------
    str
        name of the driver, if detected

    Raises
    ------
    ValueError
        if no driver matches the path, or if more than one driver matches
    """
    candidates = _get_drivers_for_path(path)
    n_candidates = len(candidates)

    # exactly one match is the only unambiguous outcome
    if n_candidates == 1:
        return candidates[0]

    if n_candidates == 0:
        raise ValueError(
            f"Could not infer driver from path: {path}; please specify driver "
            "explicitly"
        )

    # several drivers claim this path; the caller has to pick one manually
    raise ValueError(
        f"Could not infer driver from path: {path}; multiple drivers are "
        "available for that extension. Please specify driver explicitly"
    )


def list_layers(path_or_buffer, /):
"""List layers available in an OGR data source.
Expand Down
12 changes: 9 additions & 3 deletions pyogrio/geopandas.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import numpy as np
from pyogrio.raw import DRIVERS_NO_MIXED_SINGLE_MULTI, DRIVERS_NO_MIXED_DIMENSIONS
from pyogrio.raw import detect_driver, read, read_arrow, write
from pyogrio.raw import (
DRIVERS_NO_MIXED_SINGLE_MULTI,
DRIVERS_NO_MIXED_DIMENSIONS,
detect_write_driver,
read,
read_arrow,
write,
)
from pyogrio.errors import DataSourceError


Expand Down Expand Up @@ -312,7 +318,7 @@ def write_dataframe(
raise ValueError("'df' must be a DataFrame or GeoDataFrame")

if driver is None:
driver = detect_driver(path)
driver = detect_write_driver(path)

geometry_columns = df.columns[df.dtypes == "geometry"]
if len(geometry_columns) > 1:
Expand Down
58 changes: 3 additions & 55 deletions pyogrio/raw.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import warnings
import os

from pyogrio._env import GDALEnv
from pyogrio.core import detect_write_driver
from pyogrio.errors import DataSourceError
from pyogrio.util import get_vsi_path, vsi_path
from pyogrio.util import get_vsi_path, vsi_path, _preprocess_options_key_value

with GDALEnv():
from pyogrio._io import ogr_open_arrow, ogr_read, ogr_write
Expand All @@ -16,17 +16,6 @@
)


# Mapping of lower-case file extension (including the leading dot) to the
# name of the driver that detect_driver() selects for that extension.
DRIVERS = {
    ".fgb": "FlatGeobuf",
    ".geojson": "GeoJSON",
    ".geojsonl": "GeoJSONSeq",
    ".geojsons": "GeoJSONSeq",
    ".gpkg": "GPKG",
    ".json": "GeoJSON",
    ".shp": "ESRI Shapefile",
}


DRIVERS_NO_MIXED_SINGLE_MULTI = {
"FlatGeobuf",
"GPKG",
Expand Down Expand Up @@ -310,26 +299,6 @@ def open_arrow(
remove_virtual_file(path)


def detect_driver(path):
    """Infer the driver name from the file extension of ``path``.

    The extension is lower-cased and looked up in the module-level
    ``DRIVERS`` mapping.

    Parameters
    ----------
    path : str

    Returns
    -------
    str
        name of the driver registered for the extension

    Raises
    ------
    ValueError
        if the extension of ``path`` (possibly empty) is not in ``DRIVERS``
    """
    # os.path.splitext always returns a (root, ext) pair, so there is no
    # "wrong number of parts" case to handle; a path without an extension
    # yields ext == "" and simply misses the lookup below
    ext = os.path.splitext(path)[1].lower()

    driver = DRIVERS.get(ext)
    if driver is None:
        raise ValueError(
            f"Could not infer driver from path: {path}; please specify driver "
            "explicitly"
        )

    return driver


def _parse_options_names(xml):
"""Convert metadata xml to list of names"""
# Based on Fiona's meta.py
Expand All @@ -347,27 +316,6 @@ def _parse_options_names(xml):
return options


def _preprocess_options_key_value(options):
"""
Preprocess options, eg `spatial_index=True` gets converted
to `SPATIAL_INDEX="YES"`.
"""
if not isinstance(options, dict):
raise TypeError(f"Expected options to be a dict, got {type(options)}")

result = {}
for k, v in options.items():
if v is None:
continue
k = k.upper()
if isinstance(v, bool):
v = "ON" if v else "OFF"
else:
v = str(v)
result[k] = v
return result


def write(
path,
geometry,
Expand All @@ -393,7 +341,7 @@ def write(
path = vsi_path(str(path))

if driver is None:
driver = detect_driver(path)
driver = detect_write_driver(path)

# verify that driver supports writing
if not ogr_driver_supports_write(driver):
Expand Down
15 changes: 15 additions & 0 deletions pyogrio/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,21 @@


_data_dir = Path(__file__).parent.resolve() / "fixtures"

# mapping of driver extension (lower case, including the leading dot) to
# driver name for well-supported drivers; GeoJSON and GeoJSONSeq are each
# reachable through two extensions
DRIVERS = {
    ".fgb": "FlatGeobuf",
    ".geojson": "GeoJSON",
    ".geojsonl": "GeoJSONSeq",
    ".geojsons": "GeoJSONSeq",
    ".gpkg": "GPKG",
    ".json": "GeoJSON",
    ".shp": "ESRI Shapefile",
}

# mapping of driver name to extension; where DRIVERS lists a driver twice the
# later entry wins, so "GeoJSON" maps to ".json" and "GeoJSONSeq" maps to
# ".geojsons"
DRIVER_EXT = {driver: ext for ext, driver in DRIVERS.items()}

# one extension per driver in DRIVERS (the duplicate extensions ".geojsons"
# and ".json" are left out)
ALL_EXTS = [".fgb", ".geojson", ".geojsonl", ".gpkg", ".shp"]


Expand Down
53 changes: 53 additions & 0 deletions pyogrio/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
get_gdal_config_option,
get_gdal_data_path,
)
from pyogrio.core import detect_write_driver
from pyogrio.errors import DataSourceError, DataLayerError

from pyogrio._env import GDALEnv
Expand Down Expand Up @@ -44,6 +45,58 @@ def test_gdal_geos_version():
assert __gdal_geos_version__ is None or isinstance(__gdal_geos_version__, tuple)


@pytest.mark.parametrize(
    "path,expected",
    [
        ("test.shp", "ESRI Shapefile"),
        ("test.shp.zip", "ESRI Shapefile"),
        ("test.geojson", "GeoJSON"),
        ("test.geojsonl", "GeoJSONSeq"),
        ("test.gpkg", "GPKG"),
        pytest.param(
            "test.gpkg.zip",
            "GPKG",
            marks=pytest.mark.skipif(
                __gdal_version__ < (3, 7, 0),
                reason="writing *.gpkg.zip requires GDAL >= 3.7.0",
            ),
        ),
        # postgres can be detected by prefix instead of extension
        pytest.param(
            "PG:dbname=test",
            "PostgreSQL",
            marks=pytest.mark.skipif(
                "PostgreSQL" not in list_drivers(),
                reason="PostgreSQL path test requires PostgreSQL driver",
            ),
        ),
    ],
)
def test_detect_write_driver(path, expected):
    """A path with a known extension or prefix resolves to a single driver."""
    detected = detect_write_driver(path)
    assert detected == expected


@pytest.mark.parametrize(
    "path",
    [
        "test.svg",  # only supports read
        "test.",  # not a valid extension
        "test",  # no extension or prefix
        "test.foo",  # not a valid extension
        "FOO:test",  # not a valid prefix
    ],
)
def test_detect_write_driver_unsupported(path):
    """Paths that match no write driver raise a ValueError."""
    expected_message = "Could not infer driver from path"
    with pytest.raises(ValueError, match=expected_message):
        detect_write_driver(path)


@pytest.mark.parametrize("path", ["test.xml", "test.txt"])
def test_detect_write_driver_multiple_unsupported(path):
    """Extensions claimed by several write drivers raise a ValueError."""
    expected_message = "multiple drivers are available"
    with pytest.raises(ValueError, match=expected_message):
        detect_write_driver(path)


@pytest.mark.parametrize(
"driver,expected",
[
Expand Down
3 changes: 1 addition & 2 deletions pyogrio/tests/test_geopandas_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@
from pyogrio.errors import DataLayerError, DataSourceError, FeatureError, GeometryError
from pyogrio.geopandas import read_dataframe, write_dataframe
from pyogrio.raw import (
DRIVERS,
DRIVERS_NO_MIXED_DIMENSIONS,
DRIVERS_NO_MIXED_SINGLE_MULTI,
)
from pyogrio.tests.conftest import ALL_EXTS
from pyogrio.tests.conftest import ALL_EXTS, DRIVERS

try:
import pandas as pd
Expand Down
7 changes: 2 additions & 5 deletions pyogrio/tests/test_raw_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,9 @@
set_gdal_config_options,
__gdal_version__,
)
from pyogrio.raw import DRIVERS, read, write
from pyogrio.raw import read, write
from pyogrio.errors import DataSourceError, DataLayerError, FeatureError
from pyogrio.tests.conftest import prepare_testfile

# mapping of driver name to extension
DRIVER_EXT = {driver: ext for ext, driver in DRIVERS.items()}
from pyogrio.tests.conftest import prepare_testfile, DRIVERS, DRIVER_EXT


def test_read(naturalearth_lowres):
Expand Down
Loading

0 comments on commit aacd996

Please sign in to comment.