diff --git a/pygmt/helpers/decorators.py b/pygmt/helpers/decorators.py index 046cffa5514..28041911d23 100644 --- a/pygmt/helpers/decorators.py +++ b/pygmt/helpers/decorators.py @@ -254,6 +254,18 @@ input and skip trailing text. **Note**: If ``incols`` is also used then the columns given to ``outcols`` correspond to the order after the ``incols`` selection has taken place.""", + "outfile": """ + outfile + File name for saving the result data. Required if ``output_type="file"``. + If specified, ``output_type`` will be forced to be ``"file"``.""", + "output_type": """ + output_type + Desired output type of the result data. + + - ``pandas`` will return a :class:`pandas.DataFrame` object. + - ``numpy`` will return a :class:`numpy.ndarray` object. + - ``file`` will save the result to the file specified by the ``outfile`` + parameter.""", "outgrid": """ outgrid : str or None Name of the output netCDF grid file. For writing a specific grid diff --git a/pygmt/src/grd2xyz.py b/pygmt/src/grd2xyz.py index eade93473c2..17cfcb246bc 100644 --- a/pygmt/src/grd2xyz.py +++ b/pygmt/src/grd2xyz.py @@ -2,12 +2,13 @@ grd2xyz - Convert grid to data table """ +from typing import TYPE_CHECKING, Literal + import pandas as pd import xarray as xr from pygmt.clib import Session from pygmt.exceptions import GMTInvalidInput from pygmt.helpers import ( - GMTTempFile, build_arg_string, fmt_docstring, kwargs_to_strings, @@ -15,6 +16,9 @@ validate_output_table_type, ) +if TYPE_CHECKING: + from collections.abc import Hashable + __doctest_skip__ = ["grd2xyz"] @@ -33,7 +37,12 @@ s="skiprows", ) @kwargs_to_strings(R="sequence", o="sequence_comma") -def grd2xyz(grid, output_type="pandas", outfile=None, **kwargs): +def grd2xyz( + grid, + output_type: Literal["pandas", "numpy", "file"] = "pandas", + outfile: str | None = None, + **kwargs, +) -> pd.DataFrame | xr.DataArray | None: r""" Convert grid to data table. @@ -47,15 +56,8 @@ def grd2xyz(grid, output_type="pandas", outfile=None, **kwargs): Parameters ---------- {grid} - output_type : str - Determine the format the xyz data will be returned in [Default is - ``pandas``]: - - - ``numpy`` - :class:`numpy.ndarray` - - ``pandas``- :class:`pandas.DataFrame` - - ``file`` - ASCII file (requires ``outfile``) - outfile : str - The file name for the output ASCII file. + {output_type} + {outfile} cstyle : str [**f**\|\ **i**]. Replace the x- and y-coordinates on output with the corresponding @@ -118,13 +120,12 @@ def grd2xyz(grid, output_type="pandas", outfile=None, **kwargs): Returns ------- - ret : pandas.DataFrame or numpy.ndarray or None + ret Return type depends on ``outfile`` and ``output_type``: - - None if ``outfile`` is set (output will be stored in file set by - ``outfile``) - - :class:`pandas.DataFrame` or :class:`numpy.ndarray` if ``outfile`` is - not set (depends on ``output_type``) + - None if ``outfile`` is set (output will be stored in file set by ``outfile``) + - :class:`pandas.DataFrame` or :class:`numpy.ndarray` if ``outfile`` is not set + (depends on ``output_type``) Example ------- @@ -149,31 +150,22 @@ def grd2xyz(grid, output_type="pandas", outfile=None, **kwargs): "or 'file'." ) - # Set the default column names for the pandas dataframe header - dataframe_header = ["x", "y", "z"] + # Set the default column names for the pandas dataframe header. + column_names: list[Hashable] = ["x", "y", "z"] # Let output pandas column names match input DataArray dimension names - if isinstance(grid, xr.DataArray) and output_type == "pandas": + if output_type == "pandas" and isinstance(grid, xr.DataArray): # Reverse the dims because it is rows, columns ordered. - dataframe_header = [grid.dims[1], grid.dims[0], grid.name] - - with GMTTempFile() as tmpfile: - with Session() as lib: - with lib.virtualfile_in(check_kind="raster", data=grid) as vingrd: - if outfile is None: - outfile = tmpfile.name - lib.call_module( - module="grd2xyz", - args=build_arg_string(kwargs, infile=vingrd, outfile=outfile), - ) - - # Read temporary csv output to a pandas table - if outfile == tmpfile.name: # if user did not set outfile, return pd.DataFrame - result = pd.read_csv( - tmpfile.name, sep="\t", names=dataframe_header, comment=">" + column_names = [grid.dims[1], grid.dims[0], grid.name] + + with Session() as lib: + with ( + lib.virtualfile_in(check_kind="raster", data=grid) as vingrd, + lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl, + ): + lib.call_module( + module="grd2xyz", + args=build_arg_string(kwargs, infile=vingrd, outfile=vouttbl), + ) + return lib.virtualfile_to_dataset( + output_type=output_type, vfname=vouttbl, column_names=column_names ) - elif outfile != tmpfile.name: # return None if outfile set, output in outfile - result = None - - if output_type == "numpy": - result = result.to_numpy() - return result diff --git a/pygmt/tests/test_grd2xyz.py b/pygmt/tests/test_grd2xyz.py index b6f8e92c1ea..ab3feccf80c 100644 --- a/pygmt/tests/test_grd2xyz.py +++ b/pygmt/tests/test_grd2xyz.py @@ -2,14 +2,11 @@ Test pygmt.grd2xyz. """ -from pathlib import Path - import numpy as np import pandas as pd import pytest from pygmt import grd2xyz from pygmt.exceptions import GMTInvalidInput -from pygmt.helpers import GMTTempFile from pygmt.helpers.testing import load_static_earth_relief @@ -24,70 +21,17 @@ def fixture_grid(): @pytest.mark.benchmark def test_grd2xyz(grid): """ - Make sure grd2xyz works as expected. - """ - xyz_data = grd2xyz(grid=grid, output_type="numpy") - assert xyz_data.shape == (112, 3) - - -def test_grd2xyz_format(grid): + Test the basic functionality of grd2xyz. """ - Test that correct formats are returned. - """ - lon = -50.5 - lat = -18.5 - orig_val = grid.sel(lon=lon, lat=lat).to_numpy() - xyz_default = grd2xyz(grid=grid) - xyz_val = xyz_default[(xyz_default["lon"] == lon) & (xyz_default["lat"] == lat)][ - "z" - ].to_numpy() - assert isinstance(xyz_default, pd.DataFrame) - assert orig_val.size == 1 - assert xyz_val.size == 1 - np.testing.assert_allclose(orig_val, xyz_val) - xyz_array = grd2xyz(grid=grid, output_type="numpy") - assert isinstance(xyz_array, np.ndarray) - xyz_df = grd2xyz(grid=grid, output_type="pandas", outcols=None) + xyz_df = grd2xyz(grid=grid) assert isinstance(xyz_df, pd.DataFrame) assert list(xyz_df.columns) == ["lon", "lat", "z"] + assert xyz_df.shape == (112, 3) - -def test_grd2xyz_file_output(grid): - """ - Test that grd2xyz returns a file output when it is specified. - """ - with GMTTempFile(suffix=".xyz") as tmpfile: - result = grd2xyz(grid=grid, outfile=tmpfile.name, output_type="file") - assert result is None # return value is None - assert Path(tmpfile.name).stat().st_size > 0 # check that outfile exists - - -def test_grd2xyz_invalid_format(grid): - """ - Test that grd2xyz fails with incorrect format. - """ - with pytest.raises(GMTInvalidInput): - grd2xyz(grid=grid, output_type=1) - - -def test_grd2xyz_no_outfile(grid): - """ - Test that grd2xyz fails when a string output is set with no outfile. - """ - with pytest.raises(GMTInvalidInput): - grd2xyz(grid=grid, output_type="file") - - -def test_grd2xyz_outfile_incorrect_output_type(grid): - """ - Test that grd2xyz raises a warning when an outfile filename is set but the - output_type is not set to 'file'. - """ - with pytest.warns(RuntimeWarning): - with GMTTempFile(suffix=".xyz") as tmpfile: - result = grd2xyz(grid=grid, outfile=tmpfile.name, output_type="numpy") - assert result is None # return value is None - assert Path(tmpfile.name).stat().st_size > 0 # check that outfile exists + lon, lat = -50.5, -18.5 + orig_val = grid.sel(lon=lon, lat=lat).to_numpy() + xyz_val = xyz_df[(xyz_df["lon"] == lon) & (xyz_df["lat"] == lat)]["z"].to_numpy() + np.testing.assert_allclose(orig_val, xyz_val) def test_grd2xyz_pandas_output_with_o(grid):