From 6805418dc8dd52ed28d45ee8520982208a11cba0 Mon Sep 17 00:00:00 2001 From: Wei Ji <23487320+weiji14@users.noreply.github.com> Date: Sun, 14 Aug 2022 00:37:00 -0400 Subject: [PATCH] :triangular_flag_on_post: Support line and polygon vector type Enable rasterization of line and polygon inputs too! Pretty much just two more elif statements. However, because rasterizing lines and polygons using datashader results in boolean type xarray.DataArray outputs that can't be reprojected by rioxarray, had to cast them to uint8. Added parametrized unit tests that ensure the three vector input types work. --- zen3geo/datapipes/datashader.py | 31 +++++++++++---- zen3geo/tests/test_datapipes_datashader.py | 44 ++++++++++++++++++++-- 2 files changed, 64 insertions(+), 11 deletions(-) diff --git a/zen3geo/datapipes/datashader.py b/zen3geo/datapipes/datashader.py index d77f4aa..2a13360 100644 --- a/zen3geo/datapipes/datashader.py +++ b/zen3geo/datapipes/datashader.py @@ -9,6 +9,14 @@ datashader = None try: import spatialpandas + from spatialpandas.geometry import ( + PointDtype, + MultiPointDtype, + LineDtype, + MultiLineDtype, + PolygonDtype, + MultiPolygonDtype, + ) except ImportError: spatialpandas = None @@ -155,7 +163,8 @@ def __iter__(self) -> Iterator[xr.DataArray]: # Convert vector to spatialpandas format to allow datashader's # rasterization methods to work try: - _vector = spatialpandas.GeoDataFrame(data=vector.geometry) + columns = ["geometry"] if not hasattr(vector, "columns") else None + _vector = spatialpandas.GeoDataFrame(data=vector, columns=columns) except ValueError as e: raise ValueError( f"Unsupported geometry type(s) {set(vector.geom_type)} detected, " @@ -164,17 +173,23 @@ def __iter__(self) -> Iterator[xr.DataArray]: # Determine geometry type to know which rasterization method to use vector_dtype: spatialpandas.geometry.GeometryDtype = _vector.geometry.dtype - if isinstance( - vector_dtype, - ( - spatialpandas.geometry.PointDtype, -
spatialpandas.geometry.MultiPointDtype, - ), - ): + + if isinstance(vector_dtype, (PointDtype, MultiPointDtype)): raster: xr.DataArray = canvas.points( source=_vector, geometry="geometry", **self.kwargs ) + elif isinstance(vector_dtype, (LineDtype, MultiLineDtype)): + raster: xr.DataArray = canvas.line( + source=_vector, geometry="geometry", **self.kwargs + ) + elif isinstance(vector_dtype, (PolygonDtype, MultiPolygonDtype)): + raster: xr.DataArray = canvas.polygons( + source=_vector, geometry="geometry", **self.kwargs + ) + # Convert boolean dtype rasters to uint8 to enable reprojection + if raster.dtype == "bool": + raster: xr.DataArray = raster.astype(dtype="uint8") # Set coordinate transform for raster and ensure affine # transform is correct (the y-coordinate goes from North to South) raster: xr.DataArray = raster.rio.set_crs(input_crs=canvas.crs) diff --git a/zen3geo/tests/test_datapipes_datashader.py b/zen3geo/tests/test_datapipes_datashader.py index 0d749f5..4f64b70 100644 --- a/zen3geo/tests/test_datapipes_datashader.py +++ b/zen3geo/tests/test_datapipes_datashader.py @@ -20,8 +20,9 @@ def fixture_geoms(): geometries = shapely.geometry.GeometryCollection( geoms=[ - shapely.geometry.Point(1, 0), - shapely.geometry.LineString([(10, 0), (10, 5), (0, 0)]), + shapely.geometry.MultiPoint([(4.5, 4.5), (3.5, 1), (6, 3.5)]), + shapely.geometry.LineString([(3, 5), (5, 3), (3, 2), (5, 0)]), + shapely.geometry.Polygon([(6, 5), (3.5, 2.5), (6, 0), (6, 2.5), (5, 2.5)]), ] ) return geometries @@ -57,9 +58,46 @@ def test_datashader_canvas_dataset(): assert hasattr(canvas, "raster") +@pytest.mark.parametrize( + ("geom_type", "sum_val"), [("Point", 3), ("Line", 13), ("Polygon", 15)] +) +def test_datashader_rasterize_vector_geometry(geometries, geom_type, sum_val): + """ + Ensure that DatashaderRasterizer works to rasterize a geopandas.GeoSeries + object of point, line or polygon type into an xarray.DataArray grid. 
+ """ + gpd = pytest.importorskip("geopandas") + + canvas = datashader.Canvas( + plot_width=14, plot_height=10, x_range=(1, 8), y_range=(0, 5) + ) + canvas.crs = "EPSG:4326" + dp = IterableWrapper(iterable=[canvas]) + + geoms = [geom for geom in geometries.geoms if geom_type in geom.type] + vector = gpd.GeoSeries(data=geoms) + vector = vector.set_crs(epsg=4326) + dp_vector = IterableWrapper(iterable=[vector]) + + # Using class constructors + dp_canvas = DatashaderRasterizer(source_datapipe=dp, vector_datapipe=dp_vector) + # Using functional form (recommended) + dp_datashader = dp.rasterize_with_datashader(vector_datapipe=dp_vector) + + assert len(dp_datashader) == 1 + it = iter(dp_datashader) + dataarray = next(it) + + assert dataarray.data.sum() == sum_val + assert dataarray.dims == ("y", "x") + assert dataarray.rio.crs == "EPSG:4326" + assert dataarray.rio.shape == (10, 14) + assert dataarray.rio.transform().e == -0.5 + + def test_datashader_rasterize_missing_crs(geometries): """ - Ensure that DatashaderRasterizer raises a ValueError when either of the + Ensure that DatashaderRasterizer raises a ValueError when either the input datashader.Canvas or geopandas.GeoDataFrame has no crs attribute. """ gpd = pytest.importorskip("geopandas")