Skip to content

Commit

Permalink
tests, qa and added basic to_xarray method to ndarray Wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
EddyCMWF committed Jul 14, 2023
1 parent 26689ba commit cc33143
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 55 deletions.
37 changes: 22 additions & 15 deletions earthkit/data/utils/module_inputs_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
Module containing methods to transform the inputs of functions based on the function type setting,
common signatures or mapping defined at call time
"""
from ast import Module
import inspect
import types
import typing as T
from ast import Module
from functools import wraps

from earthkit.data import transform
Expand All @@ -37,6 +37,7 @@ def ensure_iterable(input_item):
return [input_item]
return input_item


def ensure_tuple(input_item):
"""Ensure that an item is iterable"""
if not isinstance(input_item, tuple):
Expand Down Expand Up @@ -69,6 +70,7 @@ def transform_function_inputs(
[type]
[description]
"""

def _wrapper(kwarg_types, convert_types, *args, **kwargs):
kwarg_types = {**kwarg_types}
signature = inspect.signature(function)
Expand All @@ -80,26 +82,20 @@ def _wrapper(kwarg_types, convert_types, *args, **kwargs):
arg_names.append(name)
kwargs[name] = arg

# Expand any Wrapper objects to their native data format:
for k, v in kwargs.items():
if isinstance(v, Wrapper):
try:
kwargs[k] = v.data
except:
pass

convert_kwargs = [k for k in kwargs if k in mapping]
# Only convert some data-types, this can be used to prevent conversion for functions which
# accept a long-list of formats, e.g. numpy methods can accept xarray, pandas and more

# Filter for convert_types
if convert_types:
# Ensure convert_types is a dictionary
if not isinstance(convert_types, dict):
convert_types = {key: convert_types for key in convert_kwargs}

convert_kwargs = [
k for k in convert_kwargs if isinstance(kwargs[k], ensure_tuple(convert_types.get(k, ())))
k
for k in convert_kwargs
if isinstance(kwargs[k], ensure_tuple(convert_types.get(k, ())))
]

# transform args/kwargs
Expand All @@ -111,13 +107,20 @@ def _wrapper(kwarg_types, convert_types, *args, **kwargs):
for kwarg_type in kwarg_types:
try:
kwargs[key] = transform(value, kwarg_type)
except:
except Exception:
# Transform was not possible, move to next kwarg type.
# If no transform is possible, format is unchanged and we rely on function to raise
# an Error.
continue
break

# Anything that is still a Wrapper object, expand to native data format:
for k, v in [(_k, _v) for _k, _v in kwargs.items() if isinstance(_v, Wrapper)]:
try:
kwargs[k] = v.data
except Exception:
pass

# Extract args from kwargs:
args = [kwargs.pop(name) for name in arg_names]
return function(*args, **kwargs)
Expand Down Expand Up @@ -163,8 +166,12 @@ def transform_module_inputs(in_module, **decorator_kwargs):
for name in dir(in_module):
func = getattr(in_module, name)
# Wrap any functions that are not hidden
if not name.startswith('_') and isinstance(func, types.FunctionType):
setattr(wrapped_module, name, transform_function_inputs(func, **decorator_kwargs))
if not name.startswith("_") and isinstance(func, types.FunctionType):
setattr(
wrapped_module,
name,
transform_function_inputs(func, **decorator_kwargs),
)
else:
# If not a func, we just copy
setattr(wrapped_module, name, func)
Expand Down
12 changes: 12 additions & 0 deletions earthkit/data/wrappers/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,18 @@ def to_numpy(self):
"""
return self.data

def to_xarray(self, **kwargs):
    """
    Wrap the underlying array in an xarray.DataArray.

    Parameters
    ----------
    **kwargs
        Forwarded unchanged to the ``xarray.DataArray`` constructor.

    Returns
    -------
    xarray.DataArray
    """
    # Imported inside the method so xarray is only required when this is called.
    from xarray import DataArray

    data_array = DataArray(self.data, **kwargs)
    return data_array


def wrapper(data, *args, **kwargs):
import numpy as np
Expand Down
97 changes: 57 additions & 40 deletions tests/utils/test_module_inputs_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,69 +8,86 @@
#

import numpy as np
import xarray as xr
import pandas as pd
import xarray as xr

from earthkit.data.utils import module_inputs_wrapper
from earthkit.data import from_source, from_object
from earthkit.data import from_object, from_source
from earthkit.data.readers import Reader
from earthkit.data.wrappers import Wrapper
from earthkit.data.utils import module_inputs_wrapper

TEST_NP = np.arange(10)
TEST_NP2 = np.arange(10)

TEST_DF = pd.DataFrame({'index': TEST_NP, 'data': TEST_NP2}).set_index('index')
TEST_DF = pd.DataFrame({"index": TEST_NP, "data": TEST_NP2}).set_index("index")

TEST_DA = xr.DataArray(TEST_NP, name='test')
TEST_DA2 = xr.DataArray(TEST_NP2, name='test2')
TEST_DA = xr.DataArray(TEST_NP, name="test")
TEST_DA2 = xr.DataArray(TEST_NP2, name="test2")
TEST_DS = TEST_DA.to_dataset()
TEST_DS['test2'] = TEST_DA2
TEST_DS["test2"] = TEST_DA2

EK_GRIB_READER = from_source("file", "tests/data/test_single.grib")
EK_XARRAY_WRAPPER = from_object(TEST_DS)
EK_NUMPY_WRAPPER = from_object(TEST_NP)

XR_TYPES = (xr.Dataset, xr.DataArray, xr.Variable)
WRAPPED_XR_ONES_LIKE = module_inputs_wrapper.transform_function_inputs(
xr.ones_like, kwarg_types={"other": XR_TYPES}
)

def test_transform_xarray_function_inputs():
xr_types = (xr.Dataset, xr.DataArray, xr.Variable)
this_xr_ones_like = module_inputs_wrapper.transform_function_inputs(
xr.ones_like, kwarg_types = {'other': xr_types}
)
WRAPPED_NP_MEAN = module_inputs_wrapper.transform_function_inputs(
np.mean,
kwarg_types={"a": np.ndarray},
convert_types=(
Reader
), # Only convert Earthkit.data.Reader (np.mean can handle xarray and pandas)
)


def test_transform_function_inputs_reader_to_xarray():
# Check EK GribReader object
ek_reader_result = this_xr_ones_like(EK_GRIB_READER)
assert isinstance(ek_reader_result, xr_types)
assert ek_reader_result == xr.ones_like(EK_GRIB_READER.to_xarray())

# Check EK XarrayWrapper object
ek_wrapper_result = this_xr_ones_like(EK_XARRAY_WRAPPER)
assert isinstance(ek_wrapper_result, xr_types)
assert ek_wrapper_result == xr.ones_like(EK_XARRAY_WRAPPER.data)


def test_transform_numpy_function_inputs():
this_np_mean = module_inputs_wrapper.transform_function_inputs(
np.mean, kwarg_types={"a": np.ndarray},
convert_types=(Reader),
tranformers={Wrapper: lambda x: x.data} # Only conver Earthkit.data.Reader and Wrapper types
)
ek_reader_result = WRAPPED_XR_ONES_LIKE(EK_GRIB_READER)
assert isinstance(ek_reader_result, XR_TYPES)
assert ek_reader_result.equals(xr.ones_like(EK_GRIB_READER.to_xarray()))


def test_transform_function_inputs_wrapper_to_xarray():
# EK XarrayWrapper object
ek_wrapper_result = WRAPPED_XR_ONES_LIKE(EK_XARRAY_WRAPPER)
assert isinstance(ek_wrapper_result, XR_TYPES)
assert ek_wrapper_result.equals(xr.ones_like(EK_XARRAY_WRAPPER.data))
# EK NumpyWrapper object
ek_wrapper_result = WRAPPED_XR_ONES_LIKE(EK_NUMPY_WRAPPER)
assert isinstance(ek_wrapper_result, XR_TYPES)
assert ek_wrapper_result.equals(xr.ones_like(TEST_DA))


def test_transform_function_inputs_reader_to_numpy():
# Test with Earthkit.data GribReader object
assert this_np_mean(EK_GRIB_READER) == np.mean(EK_GRIB_READER.to_numpy())
assert isinstance(this_np_mean(EK_GRIB_READER), type(np.mean(TEST_NP)))
assert WRAPPED_NP_MEAN(EK_GRIB_READER) == np.mean(EK_GRIB_READER.to_numpy())
assert isinstance(WRAPPED_NP_MEAN(EK_GRIB_READER), np.float64)


def test_transform_function_inputs_wrapper_to_numpy():
# Test with Earthkit.data XarrayWrapper object
ek_object_result = this_np_mean(EK_XARRAY_WRAPPER)
ek_object_result = WRAPPED_NP_MEAN(EK_XARRAY_WRAPPER)
assert ek_object_result == np.mean(TEST_DS)
assert isinstance(ek_object_result, type(EK_XARRAY_WRAPPER.data))

# Test with Earthkit.data NumpyWrapper object
ek_object_result = WRAPPED_NP_MEAN(EK_NUMPY_WRAPPER)
assert ek_object_result == np.mean(TEST_NP)
assert isinstance(ek_object_result, np.float64)


def test_transform_function_inputs_xarray_to_numpy():
# Test with xarray.DataArray object
assert this_np_mean(TEST_DA) == np.mean(TEST_NP)
assert this_np_mean(TEST_DA) == np.mean(TEST_DA)
assert isinstance(this_np_mean(TEST_DA), xr.DataArray)

# Test with pandas.DataFrame object
assert this_np_mean(TEST_DF) == np.mean(TEST_NP)
assert this_np_mean(TEST_DF) == np.mean(TEST_DF)
# assert isinstance(this_np_mean(TEST_DF), )
assert WRAPPED_NP_MEAN(TEST_DA) == np.mean(TEST_NP)
assert WRAPPED_NP_MEAN(TEST_DA) == np.mean(TEST_DA)
assert isinstance(WRAPPED_NP_MEAN(TEST_DA), xr.DataArray)


def test_transform_function_inputs_pandas_to_numpy():
# Test with pandas.DataFrame object
assert WRAPPED_NP_MEAN(TEST_DF) == np.mean(TEST_NP)
assert WRAPPED_NP_MEAN(TEST_DF) == np.mean(TEST_DF)
assert isinstance(WRAPPED_NP_MEAN(TEST_DF), np.float64)

0 comments on commit cc33143

Please sign in to comment.