Skip to content

Commit

Permalink
Standardize output types for grdhisteq methods (GenericMappingTools#1812
Browse files Browse the repository at this point in the history
)

Modifies grdhisteq methods, so that an `xarray.DataArray` is returned
only if outgrid is None. Also updated the docstring returns section for
compute_bins to clarify the dependence on output_type.

Co-authored-by: Wei Ji <23487320+weiji14@users.noreply.github.com>
Co-authored-by: Dongdong Tian <seisman.info@gmail.com>
  • Loading branch information
3 people authored and Josh Sixsmith committed Dec 21, 2022
1 parent c625775 commit 7b483cd
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 12 deletions.
19 changes: 11 additions & 8 deletions pygmt/src/grdhisteq.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def _grdhisteq(grid, output_type, **kwargs):
def equalize_grid(
grid,
*,
outgrid=True,
outgrid=None,
divisions=None,
region=None,
gaussian=None,
Expand All @@ -163,7 +163,7 @@ def equalize_grid(
----------
grid : str or xarray.DataArray
The file name of the input grid or the grid loaded as a DataArray.
outgrid : str or bool or None
outgrid : str or None
The name of the output netCDF file with extension .nc to store the
grid in.
divisions : int
Expand All @@ -183,7 +183,7 @@ def equalize_grid(
ret: xarray.DataArray or None
Return type depends on the ``outgrid`` parameter:
- xarray.DataArray if ``outgrid`` is True or None
- xarray.DataArray if ``outgrid`` is None
- None if ``outgrid`` is a str (grid output is stored in
``outgrid``)
Expand Down Expand Up @@ -211,9 +211,11 @@ def equalize_grid(
with GMTTempFile(suffix=".nc") as tmpfile:
if isinstance(outgrid, str):
output_type = "file"
else:
elif outgrid is None:
output_type = "xarray"
outgrid = tmpfile.name
else:
raise GMTInvalidInput("Must specify 'outgrid' as a string or None.")
return grdhisteq._grdhisteq(
grid=grid,
output_type=output_type,
Expand Down Expand Up @@ -281,12 +283,13 @@ def compute_bins(
Returns
-------
ret: pandas.DataFrame or None
Return type depends on the ``outfile`` parameter:
ret : pandas.DataFrame or numpy.ndarray or None
Return type depends on ``outfile`` and ``output_type``:
- pandas.DataFrame if ``outfile`` is True or None
- None if ``outfile`` is a str (file output is stored in
- None if ``outfile`` is set (output will be stored in file set by
``outfile``)
- :class:`pandas.DataFrame` or :class:`numpy.ndarray` if
``outfile`` is not set (depends on ``output_type``)
Example
-------
Expand Down
15 changes: 11 additions & 4 deletions pygmt/tests/test_grdhisteq.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,12 @@ def test_equalize_grid_outgrid_file(grid, expected_grid, region):
xr.testing.assert_allclose(a=temp_grid, b=expected_grid)


@pytest.mark.parametrize("outgrid", [True, None])
def test_equalize_grid_outgrid(grid, outgrid, expected_grid, region):
def test_equalize_grid_outgrid(grid, expected_grid, region):
"""
Test grdhisteq.equalize_grid with ``outgrid=True`` and ``outgrid=None``.
Test grdhisteq.equalize_grid with ``outgrid=None``.
"""
temp_grid = grdhisteq.equalize_grid(
grid=grid, divisions=2, region=region, outgrid=outgrid
grid=grid, divisions=2, region=region, outgrid=None
)
assert temp_grid.gmt.gtype == 1 # Geographic grid
assert temp_grid.gmt.registration == 1 # Pixel registration
Expand Down Expand Up @@ -135,3 +134,11 @@ def test_compute_bins_invalid_format(grid):
grdhisteq.compute_bins(grid=grid, output_type=1)
with pytest.raises(GMTInvalidInput):
grdhisteq.compute_bins(grid=grid, output_type="pandas", header="o+c")


def test_equalize_grid_invalid_format(grid):
"""
Test that equalize_grid fails with incorrect format.
"""
with pytest.raises(GMTInvalidInput):
grdhisteq.equalize_grid(grid=grid, outgrid=True)

0 comments on commit 7b483cd

Please sign in to comment.