diff --git a/pygmt/src/grdhisteq.py b/pygmt/src/grdhisteq.py index c12199fe717..09555d05e94 100644 --- a/pygmt/src/grdhisteq.py +++ b/pygmt/src/grdhisteq.py @@ -141,7 +141,7 @@ def _grdhisteq(grid, output_type, **kwargs): def equalize_grid( grid, *, - outgrid=True, + outgrid=None, divisions=None, region=None, gaussian=None, @@ -163,7 +163,7 @@ def equalize_grid( ---------- grid : str or xarray.DataArray The file name of the input grid or the grid loaded as a DataArray. - outgrid : str or bool or None + outgrid : str or None The name of the output netCDF file with extension .nc to store the grid in. divisions : int @@ -183,7 +183,7 @@ def equalize_grid( ret: xarray.DataArray or None Return type depends on the ``outgrid`` parameter: - - xarray.DataArray if ``outgrid`` is True or None + - xarray.DataArray if ``outgrid`` is None - None if ``outgrid`` is a str (grid output is stored in ``outgrid``) @@ -211,9 +211,11 @@ def equalize_grid( with GMTTempFile(suffix=".nc") as tmpfile: if isinstance(outgrid, str): output_type = "file" - else: + elif outgrid is None: output_type = "xarray" outgrid = tmpfile.name + else: + raise GMTInvalidInput("Must specify 'outgrid' as a string or None.") return grdhisteq._grdhisteq( grid=grid, output_type=output_type, @@ -281,12 +283,13 @@ def compute_bins( Returns ------- - ret: pandas.DataFrame or None - Return type depends on the ``outfile`` parameter: + ret : pandas.DataFrame or numpy.ndarray or None + Return type depends on ``outfile`` and ``output_type``: - - pandas.DataFrame if ``outfile`` is True or None - - None if ``outfile`` is a str (file output is stored in + - None if ``outfile`` is set (output will be stored in file set by ``outfile``) + - :class:`pandas.DataFrame` or :class:`numpy.ndarray` if + ``outfile`` is not set (depends on ``output_type``) Example ------- diff --git a/pygmt/tests/test_grdhisteq.py b/pygmt/tests/test_grdhisteq.py index 9a4fd004b3d..a812e0d5565 100644 --- a/pygmt/tests/test_grdhisteq.py +++ b/pygmt/tests/test_grdhisteq.py @@ -66,13 +66,12 @@ def test_equalize_grid_outgrid_file(grid, expected_grid, region): xr.testing.assert_allclose(a=temp_grid, b=expected_grid) -@pytest.mark.parametrize("outgrid", [True, None]) -def test_equalize_grid_outgrid(grid, outgrid, expected_grid, region): +def test_equalize_grid_outgrid(grid, expected_grid, region): """ - Test grdhisteq.equalize_grid with ``outgrid=True`` and ``outgrid=None``. + Test grdhisteq.equalize_grid with ``outgrid=None``. """ temp_grid = grdhisteq.equalize_grid( - grid=grid, divisions=2, region=region, outgrid=outgrid + grid=grid, divisions=2, region=region, outgrid=None ) assert temp_grid.gmt.gtype == 1 # Geographic grid assert temp_grid.gmt.registration == 1 # Pixel registration @@ -135,3 +134,11 @@ def test_compute_bins_invalid_format(grid): grdhisteq.compute_bins(grid=grid, output_type=1) with pytest.raises(GMTInvalidInput): grdhisteq.compute_bins(grid=grid, output_type="pandas", header="o+c") + + +def test_equalize_grid_invalid_format(grid): + """ + Test that equalize_grid fails with incorrect format. + """ + with pytest.raises(GMTInvalidInput): + grdhisteq.equalize_grid(grid=grid, outgrid=True)