Skip to content

Commit

Permalink
🚩 Support line and polygon vector type
Browse files Browse the repository at this point in the history
Enable rasterization of line and polygon inputs too! Pretty much just two more elif statements. However, because rasterizing line and polygons using datashader results in boolean type xarray.DataArray outputs that can't be reprojected by rioxarray, had to cast them to uint8. Added parametrized unit tests that ensures the three vector input types work.
  • Loading branch information
weiji14 committed Aug 14, 2022
1 parent 5c36f39 commit 6805418
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 11 deletions.
31 changes: 23 additions & 8 deletions zen3geo/datapipes/datashader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,14 @@
datashader = None
try:
import spatialpandas
from spatialpandas.geometry import (
PointDtype,
MultiPointDtype,
LineDtype,
MultiLineDtype,
PolygonDtype,
MultiPolygonDtype,
)
except ImportError:
spatialpandas = None

Expand Down Expand Up @@ -155,7 +163,8 @@ def __iter__(self) -> Iterator[xr.DataArray]:
# Convert vector to spatialpandas format to allow datashader's
# rasterization methods to work
try:
_vector = spatialpandas.GeoDataFrame(data=vector.geometry)
columns = ["geometry"] if not hasattr(vector, "columns") else None
_vector = spatialpandas.GeoDataFrame(data=vector, columns=columns)
except ValueError as e:
raise ValueError(
f"Unsupported geometry type(s) {set(vector.geom_type)} detected, "
Expand All @@ -164,17 +173,23 @@ def __iter__(self) -> Iterator[xr.DataArray]:

# Determine geometry type to know which rasterization method to use
vector_dtype: spatialpandas.geometry.GeometryDtype = _vector.geometry.dtype
if isinstance(
vector_dtype,
(
spatialpandas.geometry.PointDtype,
spatialpandas.geometry.MultiPointDtype,
),
):

if isinstance(vector_dtype, (PointDtype, MultiPointDtype)):
raster: xr.DataArray = canvas.points(
source=_vector, geometry="geometry", **self.kwargs
)
elif isinstance(vector_dtype, (LineDtype, MultiLineDtype)):
raster: xr.DataArray = canvas.line(
source=_vector, geometry="geometry", **self.kwargs
)
elif isinstance(vector_dtype, (PolygonDtype, MultiPolygonDtype)):
raster: xr.DataArray = canvas.polygons(
source=_vector, geometry="geometry", **self.kwargs
)

# Convert boolean dtype rasters to uint8 to enable reprojection
if raster.dtype == "bool":
raster: xr.DataArray = raster.astype(dtype="uint8")
# Set coordinate transform for raster and ensure affine
# transform is correct (the y-coordinate goes from North to South)
raster: xr.DataArray = raster.rio.set_crs(input_crs=canvas.crs)
Expand Down
44 changes: 41 additions & 3 deletions zen3geo/tests/test_datapipes_datashader.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ def fixture_geoms():

geometries = shapely.geometry.GeometryCollection(
geoms=[
shapely.geometry.Point(1, 0),
shapely.geometry.LineString([(10, 0), (10, 5), (0, 0)]),
shapely.geometry.MultiPoint([(4.5, 4.5), (3.5, 1), (6, 3.5)]),
shapely.geometry.LineString([(3, 5), (5, 3), (3, 2), (5, 0)]),
shapely.geometry.Polygon([(6, 5), (3.5, 2.5), (6, 0), (6, 2.5), (5, 2.5)]),
]
)
return geometries
Expand Down Expand Up @@ -57,9 +58,46 @@ def test_datashader_canvas_dataset():
assert hasattr(canvas, "raster")


@pytest.mark.parametrize(
("geom_type", "sum_val"), [("Point", 3), ("Line", 13), ("Polygon", 15)]
)
def test_datashader_rasterize_vector_geometry(geometries, geom_type, sum_val):
"""
Ensure that DatashaderRasterizer works to rasterize a geopandas.GeoSeries
object of point, line or polygon type into an xarray.DataArray grid.
"""
gpd = pytest.importorskip("geopandas")

canvas = datashader.Canvas(
plot_width=14, plot_height=10, x_range=(1, 8), y_range=(0, 5)
)
canvas.crs = "EPSG:4326"
dp = IterableWrapper(iterable=[canvas])

geoms = [geom for geom in geometries.geoms if geom_type in geom.type]
vector = gpd.GeoSeries(data=geoms)
vector = vector.set_crs(epsg=4326)
dp_vector = IterableWrapper(iterable=[vector])

# Using class constructors
dp_canvas = DatashaderRasterizer(source_datapipe=dp, vector_datapipe=dp_vector)
# Using functional form (recommended)
dp_datashader = dp.rasterize_with_datashader(vector_datapipe=dp_vector)

assert len(dp_datashader) == 1
it = iter(dp_datashader)
dataarray = next(it)

assert dataarray.data.sum() == sum_val
assert dataarray.dims == ("y", "x")
assert dataarray.rio.crs == "EPSG:4326"
assert dataarray.rio.shape == (10, 14)
assert dataarray.rio.transform().e == -0.5


def test_datashader_rasterize_missing_crs(geometries):
"""
Ensure that DatashaderRasterizer raises a ValueError when either of the
Ensure that DatashaderRasterizer raises a ValueError when either the input
datashader.Canvas or geopandas.GeoDataFrame has no crs attribute.
"""
gpd = pytest.importorskip("geopandas")
Expand Down

0 comments on commit 6805418

Please sign in to comment.