Skip to content

Commit

Permalink
Add check for lat long location
Browse files Browse the repository at this point in the history
  • Loading branch information
Sukhil Patel authored and Sukhil Patel committed Jul 30, 2024
1 parent 75aad90 commit fb95122
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 9 deletions.
21 changes: 15 additions & 6 deletions ocf_datapipes/select/select_spatial_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,24 +103,33 @@ def _get_idx_of_pixel_closest_to_poi(

def _get_idx_of_pixel_closest_to_poi_geostationary(
xr_data: xr.DataArray,
center_osgb: Location,
center_coordinate: Location,
) -> Location:
"""
Return x and y index location of pixel at center of region of interest.
Args:
xr_data: Xarray dataset
center_osgb: Center in OSGB coordinates
center_coordinate: Central coordinate
Returns:
Location for the center pixel in geostationary coordinates
"""

xr_coords, xr_x_dim, xr_y_dim = spatial_coord_type(xr_data)
if center_coordinate.coordinate_system == "osgb":
x, y = osgb_to_geostationary_area_coords(x=center_coordinate.x,
y=center_coordinate.y,
xr_data=xr_data)
elif center_coordinate.coordinate_system == "lon_lat":
x, y = lon_lat_to_geostationary_area_coords(x=center_coordinate.x,
y=center_coordinate.y,
xr_data=xr_data)
else:
raise NotImplementedError(f"Only 'osgb' and 'lon_lat' location coordinates are \
supported in conversion to geostationary \
- not '{center_coordinate.coordinate_system}'")

x, y = osgb_to_geostationary_area_coords(x=center_osgb.x, y=center_osgb.y, xr_data=xr_data)
center_geostationary = Location(x=x, y=y, coordinate_system="geostationary")

# Check that the requested point lies within the data
assert xr_data[xr_x_dim].min() < x < xr_data[xr_x_dim].max()
assert xr_data[xr_y_dim].min() < y < xr_data[xr_y_dim].max()
Expand Down Expand Up @@ -390,7 +399,7 @@ def select_spatial_slice_pixels(
if xr_coords == "geostationary":
center_idx: Location = _get_idx_of_pixel_closest_to_poi_geostationary(
xr_data=xr_data,
center_osgb=location,
center_coordinate=location,
)
else:
center_idx: Location = _get_idx_of_pixel_closest_to_poi(
Expand Down
32 changes: 29 additions & 3 deletions tests/select/test_select_spatial_slice.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import numpy as np
import xarray as xr
from ocf_datapipes.utils import Location

from ocf_datapipes.select import (
PickLocations,
SelectSpatialSliceMeters,
SelectSpatialSlicePixels,
)

from ocf_datapipes.select.select_spatial_slice import slice_spatial_pixel_window_from_xarray
from ocf_datapipes.select.select_spatial_slice import (
_get_idx_of_pixel_closest_to_poi_geostationary,
slice_spatial_pixel_window_from_xarray,
)
from ocf_datapipes.utils import Location


def test_slice_spatial_pixel_window_from_xarray_function():
Expand Down Expand Up @@ -158,3 +160,27 @@ def test_select_spatial_slice_meters_icon_global(passiv_datapipe, icon_global_da
# ICON global has roughly 13km spacing, so this should be around 7x7 grid
assert len(data.longitude) == 49
assert len(data.latitude) == 49

def test_get_idx_of_pixel_closest_to_poi_geostationary_lon_lat_location():
# Create dummy data
x = np.arange(5000000, -5000000, -5000)
y = np.arange(5000000, -5000000, -5000)[::-1]

xr_data = xr.Dataset(
data_vars=dict(
data=(["x_geostationary", "y_geostationary"], np.random.normal(size=(len(x), len(y)))),
),
coords=dict(
x_geostationary=(["x_geostationary"], x),
y_geostationary=(["y_geostationary"], y),
),
)
xr_data.attrs["area"] = 'msg_seviri_iodc_3km:\n description: MSG SEVIRI Indian Ocean Data Coverage service area definition with\n 3 km resolution\n projection:\n proj: geos\n lon_0: 41.5\n h: 35785831\n x_0: 0\n y_0: 0\n a: 6378169\n rf: 295.488065897014\n no_defs: null\n type: crs\n shape:\n height: 3712\n width: 3712\n area_extent:\n lower_left_xy: [5000000, 5000000]\n upper_right_xy: [-5000000, -5000000]\n units: m\n'


center = Location(x=77.1, y=28.6, coordinate_system="lon_lat")

location_center_idx = _get_idx_of_pixel_closest_to_poi_geostationary(xr_data=xr_data, center_coordinate=center)

assert location_center_idx.coordinate_system == 'idx'
assert location_center_idx.x == 2000

0 comments on commit fb95122

Please sign in to comment.