Skip to content

Commit

Permalink
first implementation + tests
Browse files Browse the repository at this point in the history
  • Loading branch information
EddyCMWF committed Jul 13, 2023
1 parent 47972d6 commit 26689ba
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 9 deletions.
79 changes: 70 additions & 9 deletions earthkit/data/utils/module_inputs_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
# (C) Copyright 2020 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#

"""
Module containing methods to transform the inputs of functions based on the function type setting,
common signitures or mapping defined at call time
Expand All @@ -9,6 +18,7 @@
from functools import wraps

from earthkit.data import transform
from earthkit.data.wrappers import Wrapper

try:
UNION_TYPES = [T.Union, types.UnionType]
Expand All @@ -27,43 +37,94 @@ def ensure_iterable(input_item):
return [input_item]
return input_item

def ensure_tuple(input_item):
"""Ensure that an item is iterable"""
if not isinstance(input_item, tuple):
return tuple(ensure_iterable(input_item))
return input_item

def transform_function_inputs(function, **kwarg_types):

def transform_function_inputs(
function: T.Callable,
kwarg_types: T.Dict[str, T.Any] = {},
convert_types: T.Union[T.Tuple[T.Any], T.Dict[str, T.Tuple[T.Any]]] = (),
**decorator_kwargs
) -> T.Callable:
"""
Transform the inputs to a function to match the requirements.
earthkit.data handles the input arg/kwarg format.
Parameters
----------
function : Callable
Method to be wrapped
kwarg_types : Dict[str: type]
Mapping of accepted object types for each arg/kwarg
convert_types : Tuple[type]
List of data-types to try to convert, this can be useful when the function is versitile and can
accept a large number of data-types, hence only a small number of types should be converted.
Returns
-------
[type]
[description]
"""
def _wrapper(kwarg_types, *args, **kwargs):
def _wrapper(kwarg_types, convert_types, *args, **kwargs):
kwarg_types = {**kwarg_types}
signature = inspect.signature(function)
mapping = signature_mapping(signature, kwarg_types)

# convert args to kwargs for ease of looping:
# Add args to kwargs for ease of looping:
arg_names = []
for arg, name in zip(args, signature.parameters):
arg_names.append(name)
kwargs[name] = arg

kwargs_with_mapping = [k for k in kwargs if k in mapping]
# transform args/kwargs if mapping available
for key in kwargs_with_mapping:
# Expand any Wrapper objects to their native data format:
for k, v in kwargs.items():
if isinstance(v, Wrapper):
try:
kwargs[k] = v.data
except:
pass

convert_kwargs = [k for k in kwargs if k in mapping]
# Only convert some data-types, this can be used to prevent conversion for for functions which
# accept a long-list of formats, e.g. numpy methods can accept xarray, pandas and more

# Filter for convert_types
if convert_types:
# Ensure convert_types is a dictionary
if not isinstance(convert_types, dict):
convert_types = {key: convert_types for key in convert_kwargs}

convert_kwargs = [
k for k in convert_kwargs if isinstance(kwargs[k], ensure_tuple(convert_types.get(k, ())))
]

# transform args/kwargs
for key in convert_kwargs:
value = kwargs[key]
kwarg_types = ensure_iterable(mapping[key])
# Transform value if necessary
if type(value) not in kwarg_types:
for kwarg_type in kwarg_types:
try:
kwargs[key] = transform(value, kwarg_type)
except ValueError:
except:
# Transform was not possible, move to next kwarg type.
# If no transform is possible, format is unchanged and we rely on function to raise
# an Error.
continue
break

return function(**kwargs)
# Extract args from kwargs:
args = [kwargs.pop(name) for name in arg_names]
return function(*args, **kwargs)

@wraps(function)
def wrapper(*args, **kwargs):
return _wrapper(kwarg_types, *args, **kwargs)
return _wrapper(kwarg_types, convert_types, *args, **kwargs)

return wrapper

Expand Down
76 changes: 76 additions & 0 deletions tests/utils/test_module_inputs_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# (C) Copyright 2020 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#

import numpy as np
import xarray as xr
import pandas as pd

from earthkit.data.utils import module_inputs_wrapper
from earthkit.data import from_source, from_object
from earthkit.data.readers import Reader
from earthkit.data.wrappers import Wrapper

TEST_NP = np.arange(10)
TEST_NP2 = np.arange(10)

TEST_DF = pd.DataFrame({'index': TEST_NP, 'data': TEST_NP2}).set_index('index')

TEST_DA = xr.DataArray(TEST_NP, name='test')
TEST_DA2 = xr.DataArray(TEST_NP2, name='test2')
TEST_DS = TEST_DA.to_dataset()
TEST_DS['test2'] = TEST_DA2

EK_GRIB_READER = from_source("file", "tests/data/test_single.grib")
EK_XARRAY_WRAPPER = from_object(TEST_DS)
EK_NUMPY_WRAPPER = from_object(TEST_NP)


def test_transform_xarray_function_inputs():
xr_types = (xr.Dataset, xr.DataArray, xr.Variable)
this_xr_ones_like = module_inputs_wrapper.transform_function_inputs(
xr.ones_like, kwarg_types = {'other': xr_types}
)

# Check EK GribReader object
ek_reader_result = this_xr_ones_like(EK_GRIB_READER)
assert isinstance(ek_reader_result, xr_types)
assert ek_reader_result == xr.ones_like(EK_GRIB_READER.to_xarray())

# Check EK XarrayWrapper object
ek_wrapper_result = this_xr_ones_like(EK_XARRAY_WRAPPER)
assert isinstance(ek_wrapper_result, xr_types)
assert ek_wrapper_result == xr.ones_like(EK_XARRAY_WRAPPER.data)


def test_transform_numpy_function_inputs():
this_np_mean = module_inputs_wrapper.transform_function_inputs(
np.mean, kwarg_types={"a": np.ndarray},
convert_types=(Reader),
tranformers={Wrapper: lambda x: x.data} # Only conver Earthkit.data.Reader and Wrapper types
)
# Test with Earthkit.data GribReader object
assert this_np_mean(EK_GRIB_READER) == np.mean(EK_GRIB_READER.to_numpy())
assert isinstance(this_np_mean(EK_GRIB_READER), type(np.mean(TEST_NP)))

# Test with Earthkit.data XarrayWrapper object
ek_object_result = this_np_mean(EK_XARRAY_WRAPPER)
assert ek_object_result == np.mean(TEST_DS)
assert isinstance(ek_object_result, type(EK_XARRAY_WRAPPER.data))

# Test with xarray.DataArray object
assert this_np_mean(TEST_DA) == np.mean(TEST_NP)
assert this_np_mean(TEST_DA) == np.mean(TEST_DA)
assert isinstance(this_np_mean(TEST_DA), xr.DataArray)

# Test with pandas.DataFrame object
assert this_np_mean(TEST_DF) == np.mean(TEST_NP)
assert this_np_mean(TEST_DF) == np.mean(TEST_DF)
# assert isinstance(this_np_mean(TEST_DF), )


0 comments on commit 26689ba

Please sign in to comment.