Basic curvefit implementation (#4849)
slevang authored Mar 31, 2021
1 parent 57a4479 commit ddc352f
Showing 6 changed files with 452 additions and 1 deletion.
3 changes: 2 additions & 1 deletion doc/api.rst
@@ -179,6 +179,7 @@ Computation
Dataset.integrate
Dataset.map_blocks
Dataset.polyfit
Dataset.curvefit

**Aggregation**:
:py:attr:`~Dataset.all`
@@ -375,7 +376,7 @@ Computation
DataArray.integrate
DataArray.polyfit
DataArray.map_blocks

DataArray.curvefit

**Aggregation**:
:py:attr:`~DataArray.all`
83 changes: 83 additions & 0 deletions doc/user-guide/computation.rst
@@ -444,6 +444,89 @@ The inverse operation is done with :py:meth:`~xarray.polyval`,
.. note::
    These methods replicate the behaviour of :py:func:`numpy.polyfit` and :py:func:`numpy.polyval`.


.. _compute.curvefit:

Fitting arbitrary functions
===========================

Xarray objects also provide an interface for fitting more complex functions using
:py:func:`scipy.optimize.curve_fit`. :py:meth:`~xarray.DataArray.curvefit` accepts
user-defined functions and can fit along multiple coordinates.

For example, we can fit a relationship between two ``DataArray`` objects, maintaining
a unique fit at each spatial coordinate but aggregating over the time dimension:

.. ipython:: python

    def exponential(x, a, xc):
        return np.exp((x - xc) / a)


    x = np.arange(-5, 5, 0.1)
    t = np.arange(-5, 5, 0.1)
    X, T = np.meshgrid(x, t)
    Z1 = np.random.uniform(low=-5, high=5, size=X.shape)
    Z2 = exponential(Z1, 3, X)
    Z3 = exponential(Z1, 1, -X)

    ds = xr.Dataset(
        data_vars=dict(
            var1=(["t", "x"], Z1), var2=(["t", "x"], Z2), var3=(["t", "x"], Z3)
        ),
        coords={"t": t, "x": x},
    )
    ds[["var2", "var3"]].curvefit(
        coords=ds.var1,
        func=exponential,
        reduce_dims="t",
        bounds={"a": (0.5, 5), "xc": (-5, 5)},
    )

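The result of such a call is a ``Dataset`` whose variables follow the ``[var]_curvefit_coefficients`` and ``[var]_curvefit_covariance`` naming described in the method docstring. The following is a minimal sketch of reading fitted values back out. The ``decay`` model and its synthetic data are invented for illustration, and the sketch assumes the fitted parameters are indexed along a dimension named ``param``, which this commit's documentation does not state explicitly.

```python
import numpy as np
import xarray as xr


def decay(t, amplitude, tau):
    # Hypothetical model used only for this illustration.
    return amplitude * np.exp(-t / tau)


t = np.linspace(0.0, 5.0, 50)
true_tau = np.array([1.0, 2.0, 3.0])
# Noise-free synthetic data: one decay curve per "x" position.
signal = 2.0 * np.exp(-t[np.newaxis, :] / true_tau[:, np.newaxis])
ds = xr.Dataset({"signal": (["x", "t"], signal)}, coords={"x": [0, 1, 2], "t": t})

# Fit along "t"; "x" is preserved, so each x position gets its own fit.
fit = ds.curvefit(coords="t", func=decay, p0={"amplitude": 1.0, "tau": 1.5})

coeffs = fit["signal_curvefit_coefficients"]  # named [var]_curvefit_coefficients
tau_hat = coeffs.sel(param="tau")  # assumes the parameter dimension is "param"
print(tau_hat.values)
```

Passing only some parameters in ``p0`` is also allowed by the docstring; the remaining ones then fall back to the default scipy initial values.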
We can also fit multi-dimensional functions, and even use a wrapper function to
simultaneously fit a summation of several functions, such as this field containing
two Gaussian peaks:

.. ipython:: python

    def gaussian_2d(coords, a, xc, yc, xalpha, yalpha):
        x, y = coords
        z = a * np.exp(
            -np.square(x - xc) / 2 / np.square(xalpha)
            - np.square(y - yc) / 2 / np.square(yalpha)
        )
        return z


    def multi_peak(coords, *args):
        z = np.zeros(coords[0].shape)
        for i in range(len(args) // 5):
            z += gaussian_2d(coords, *args[i * 5 : i * 5 + 5])
        return z


    x = np.arange(-5, 5, 0.1)
    y = np.arange(-5, 5, 0.1)
    X, Y = np.meshgrid(x, y)

    n_peaks = 2
    names = ["a", "xc", "yc", "xalpha", "yalpha"]
    names = [f"{name}{i}" for i in range(n_peaks) for name in names]

    Z = gaussian_2d((X, Y), 3, 1, 1, 2, 1) + gaussian_2d((X, Y), 2, -1, -2, 1, 1)
    Z += np.random.normal(scale=0.1, size=Z.shape)

    da = xr.DataArray(Z, dims=["y", "x"], coords={"y": y, "x": x})
    da.curvefit(
        coords=["x", "y"],
        func=multi_peak,
        param_names=names,
        kwargs={"maxfev": 10000},
    )

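Because ``multi_peak`` takes ``*args``, the parameter names cannot be read from its signature, which is why ``param_names`` is supplied explicitly. A smaller sketch of the same idea follows, using a hypothetical variadic ``polynomial`` model and noise-free data so the fit is easy to check. It assumes the supplied names become the labels of the parameter dimension in the result.

```python
import numpy as np
import xarray as xr


def polynomial(x, *coeffs):
    # Hypothetical variadic model: coeffs[i] multiplies x**i.
    result = np.zeros_like(x, dtype=float)
    for power, coeff in enumerate(coeffs):
        result = result + coeff * x**power
    return result


x = np.linspace(-2.0, 2.0, 41)
y = 1.0 + 2.0 * x + 0.5 * x**2  # noise-free quadratic
ds = xr.Dataset({"y": ("x", y)}, coords={"x": x})

# The number of names determines how many parameters are fitted.
names = ["c0", "c1", "c2"]
fit = ds.curvefit(coords="x", func=polynomial, param_names=names)
print(fit["y_curvefit_coefficients"].values)
```

The order of ``param_names`` has to match the order in which the function consumes its positional parameters, as it does for the five-parameter groups in ``multi_peak``.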
.. note::
    This method replicates the behavior of :py:func:`scipy.optimize.curve_fit`.
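To make that relationship concrete, here is a rough sketch that fits the same noise-free straight line once with ``scipy.optimize.curve_fit`` directly and once through ``curvefit``. The ``line`` model and the data are invented for illustration; agreement between the two results is what the note above would lead one to expect, not a guarantee stated by this commit.

```python
import numpy as np
import xarray as xr
from scipy.optimize import curve_fit


def line(x, slope, intercept):
    # Hypothetical model used only for this comparison.
    return slope * x + intercept


x = np.linspace(0.0, 1.0, 20)
y = 3.0 * x - 1.0  # noise-free data with slope 3 and intercept -1
ds = xr.Dataset({"y": ("x", y)}, coords={"x": x})

# Direct scipy call on plain arrays.
popt_scipy, pcov_scipy = curve_fit(line, x, y)

# Equivalent call through xarray, fitting along the "x" dimension.
fit = ds.curvefit(coords="x", func=line)
popt_xarray = fit["y_curvefit_coefficients"].values
print(popt_scipy, popt_xarray)
```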


.. _compute.broadcasting:

Broadcasting by dimension name
2 changes: 2 additions & 0 deletions doc/whats-new.rst
@@ -65,6 +65,8 @@ New Features
:py:meth:`~pandas.core.groupby.GroupBy.get_group`.
By `Deepak Cherian <https://github.com/dcherian>`_.
- Disable the `cfgrib` backend if the `eccodes` library is not installed (:pull:`5083`). By `Baudouin Raoult <https://github.com/b8raoult>`_.
- Added :py:meth:`DataArray.curvefit` and :py:meth:`Dataset.curvefit` for general curve fitting applications. (:issue:`4300`, :pull:`4849`)
By `Sam Levang <https://github.com/slevang>`_.

Breaking changes
~~~~~~~~~~~~~~~~
78 changes: 78 additions & 0 deletions xarray/core/dataarray.py
@@ -4418,6 +4418,84 @@ def query(
        )
        return ds[self.name]

    def curvefit(
        self,
        coords: Union[Union[str, "DataArray"], Iterable[Union[str, "DataArray"]]],
        func: Callable[..., Any],
        reduce_dims: Union[Hashable, Iterable[Hashable]] = None,
        skipna: bool = True,
        p0: Dict[str, Any] = None,
        bounds: Dict[str, Any] = None,
        param_names: Sequence[str] = None,
        kwargs: Dict[str, Any] = None,
    ):
        """
        Curve fitting optimization for arbitrary functions.

        Wraps `scipy.optimize.curve_fit` with `apply_ufunc`.

        Parameters
        ----------
        coords : DataArray, str or sequence of DataArray, str
            Independent coordinate(s) over which to perform the curve fitting. Must share
            at least one dimension with the calling object. When fitting multi-dimensional
            functions, supply `coords` as a sequence in the same order as arguments in
            `func`. To fit along existing dimensions of the calling object, `coords` can
            also be specified as a str or sequence of strs.
        func : callable
            User specified function in the form `f(x, *params)` which returns a numpy
            array of length `len(x)`. `params` are the fittable parameters which are optimized
            by scipy curve_fit. `x` can also be specified as a sequence containing multiple
            coordinates, e.g. `f((x0, x1), *params)`.
        reduce_dims : str or sequence of str
            Additional dimension(s) over which to aggregate while fitting. For example,
            calling `ds.curvefit(coords='time', reduce_dims=['lat', 'lon'], ...)` will
            aggregate all lat and lon points and fit the specified function along the
            time dimension.
        skipna : bool, optional
            Whether to skip missing values when fitting. Default is True.
        p0 : dictionary, optional
            Optional dictionary of parameter names to initial guesses passed to the
            `curve_fit` `p0` arg. If none or only some parameters are passed, the rest will
            be assigned initial values following the default scipy behavior.
        bounds : dictionary, optional
            Optional dictionary of parameter names to bounding values passed to the
            `curve_fit` `bounds` arg. If none or only some parameters are passed, the rest
            will be unbounded following the default scipy behavior.
        param_names : seq, optional
            Sequence of names for the fittable parameters of `func`. If not supplied,
            this will be automatically determined by arguments of `func`. `param_names`
            should be manually supplied when fitting a function that takes a variable
            number of parameters.
        kwargs : dictionary
            Additional keyword arguments to pass to scipy curve_fit.

        Returns
        -------
        curvefit_results : Dataset
            A single dataset which contains:

            [var]_curvefit_coefficients
                The coefficients of the best fit.
            [var]_curvefit_covariance
                The covariance matrix of the coefficient estimates.

        See Also
        --------
        DataArray.polyfit
        scipy.optimize.curve_fit
        """
        return self._to_temp_dataset().curvefit(
            coords,
            func,
            reduce_dims=reduce_dims,
            skipna=skipna,
            p0=p0,
            bounds=bounds,
            param_names=param_names,
            kwargs=kwargs,
        )

    # this needs to be at the end, or mypy will confuse with `str`
    # https://mypy.readthedocs.io/en/latest/common_issues.html#dealing-with-conflicting-names
    str = utils.UncachedAccessor(StringAccessor)