Skip to content

Commit

Permalink
Patch/2d xarray point (#761)
Browse files Browse the repository at this point in the history
* fix point method for 2D dataarray

* add more tests
  • Loading branch information
vincentsarago authored Oct 29, 2024
1 parent b833783 commit 03cb853
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 4 deletions.
6 changes: 5 additions & 1 deletion rio_tiler/io/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,11 @@ def point(

y, x = rowcol(ds.rio.transform(), ds_lon, ds_lat)

arr = ds[:, int(y[0]), int(x[0])].to_masked_array()
if ds.ndim == 2:
arr = numpy.expand_dims(ds[int(y[0]), int(x[0])].to_masked_array(), axis=0)
else:
arr = ds[:, int(y[0]), int(x[0])].to_masked_array()

arr.mask |= arr.data == ds.rio.nodata

return PointData(
Expand Down
34 changes: 31 additions & 3 deletions tests/test_io_xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def test_xarray_reader():
assert info.count == 1
assert info.attrs

with XarrayReader(data) as dst:
stats = dst.statistics()
assert stats["2022-01-01T00:00:00.000000000"]
assert stats["2022-01-01T00:00:00.000000000"].min == 0.0
Expand Down Expand Up @@ -221,7 +220,6 @@ def test_xarray_reader_external_nodata():
assert info.width == 360
assert info.count == 1

with XarrayReader(data) as dst:
# TILE
img = dst.tile(0, 0, 1)
assert img.mask.all()
Expand Down Expand Up @@ -514,7 +512,6 @@ def test_xarray_reader_no_dims():
assert info.count == 1
assert info.attrs

with XarrayReader(data) as dst:
stats = dst.statistics()
assert stats["value"]
assert stats["value"].min == 0.0
Expand All @@ -533,6 +530,17 @@ def test_xarray_reader_no_dims():
assert img.band_names == ["value"]
assert img.dataset_statistics == ((arr.min(), arr.max()),)

pt = dst.point(0, 0)
assert pt.count == 1
assert pt.band_names == ["value"]
assert pt.coordinates
xys = [[0, 2.499], [0, 2.501], [-4.999, 0], [-5.001, 0], [-170, 80]]
for xy in xys:
x = xy[0]
y = xy[1]
pt = dst.point(x, y)
assert pt.data[0] == data.sel(x=x, y=y, method="nearest")


def test_xarray_reader_Y_axis():
"""test XarrayReader with 2D dataset."""
Expand Down Expand Up @@ -568,6 +576,16 @@ def test_xarray_reader_Y_axis():
img = dst.tile(1, 1, 2)
assert img.array[0, 0, 0] > img.array[0, -1, -1]

pt = dst.point(0, 0)
assert pt.count == 1
assert pt.coordinates
xys = [[0, 2.499], [0, 2.501], [-4.999, 0], [-5.001, 0], [-170, 80]]
for xy in xys:
x = xy[0]
y = xy[1]
pt = dst.point(x, y)
assert pt.data[0] == data.sel(x=x, y=y, method="nearest")

# Create a DataArray where the y coordinates are in decreasing order
# (this is typical for raster data)
# This array will have a negative y resolution in the affine transform
Expand Down Expand Up @@ -599,3 +617,13 @@ def test_xarray_reader_Y_axis():

img = dst.tile(1, 1, 2)
assert img.array[0, 0, 0] < img.array[0, -1, -1]

pt = dst.point(0, 0)
assert pt.count == 1
assert pt.coordinates
xys = [[0, 2.499], [0, 2.501], [-4.999, 0], [-5.001, 0], [-170, 80]]
for xy in xys:
x = xy[0]
y = xy[1]
pt = dst.point(x, y)
assert pt.data[0] == data.sel(x=x, y=y, method="nearest")

0 comments on commit 03cb853

Please sign in to comment.