Skip to content

Commit

Permalink
pygmt.grd2xyz: Improve performance by storing output in virtual files
Browse files Browse the repository at this point in the history
  • Loading branch information
seisman authored Mar 13, 2024
1 parent e3c580f commit 752305c
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 104 deletions.
12 changes: 12 additions & 0 deletions pygmt/helpers/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,18 @@
input and skip trailing text. **Note**: If ``incols`` is also
used then the columns given to ``outcols`` correspond to the
order after the ``incols`` selection has taken place.""",
"outfile": """
outfile
File name for saving the result data. Required if ``output_type="file"``.
If specified, ``output_type`` will be forced to be ``"file"``.""",
"output_type": """
output_type
Desired output type of the result data.
- ``pandas`` will return a :class:`pandas.DataFrame` object.
- ``numpy`` will return a :class:`numpy.ndarray` object.
- ``file`` will save the result to the file specified by the ``outfile``
parameter.""",
"outgrid": """
outgrid : str or None
Name of the output netCDF grid file. For writing a specific grid
Expand Down
74 changes: 33 additions & 41 deletions pygmt/src/grd2xyz.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,23 @@
grd2xyz - Convert grid to data table
"""

from typing import TYPE_CHECKING, Literal

import pandas as pd
import xarray as xr
from pygmt.clib import Session
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers import (
GMTTempFile,
build_arg_string,
fmt_docstring,
kwargs_to_strings,
use_alias,
validate_output_table_type,
)

if TYPE_CHECKING:
from collections.abc import Hashable

__doctest_skip__ = ["grd2xyz"]


Expand All @@ -33,7 +37,12 @@
s="skiprows",
)
@kwargs_to_strings(R="sequence", o="sequence_comma")
def grd2xyz(grid, output_type="pandas", outfile=None, **kwargs):
def grd2xyz(
grid,
output_type: Literal["pandas", "numpy", "file"] = "pandas",
outfile: str | None = None,
**kwargs,
) -> pd.DataFrame | xr.DataArray | None:
r"""
Convert grid to data table.
Expand All @@ -47,15 +56,8 @@ def grd2xyz(grid, output_type="pandas", outfile=None, **kwargs):
Parameters
----------
{grid}
output_type : str
Determine the format the xyz data will be returned in [Default is
``pandas``]:
- ``numpy`` - :class:`numpy.ndarray`
- ``pandas``- :class:`pandas.DataFrame`
- ``file`` - ASCII file (requires ``outfile``)
outfile : str
The file name for the output ASCII file.
{output_type}
{outfile}
cstyle : str
[**f**\|\ **i**].
Replace the x- and y-coordinates on output with the corresponding
Expand Down Expand Up @@ -118,13 +120,12 @@ def grd2xyz(grid, output_type="pandas", outfile=None, **kwargs):
Returns
-------
ret : pandas.DataFrame or numpy.ndarray or None
ret
Return type depends on ``outfile`` and ``output_type``:
- None if ``outfile`` is set (output will be stored in file set by
``outfile``)
- :class:`pandas.DataFrame` or :class:`numpy.ndarray` if ``outfile`` is
not set (depends on ``output_type``)
- None if ``outfile`` is set (output will be stored in file set by ``outfile``)
- :class:`pandas.DataFrame` or :class:`numpy.ndarray` if ``outfile`` is not set
(depends on ``output_type``)
Example
-------
Expand All @@ -149,31 +150,22 @@ def grd2xyz(grid, output_type="pandas", outfile=None, **kwargs):
"or 'file'."
)

# Set the default column names for the pandas dataframe header
dataframe_header = ["x", "y", "z"]
# Set the default column names for the pandas dataframe header.
column_names: list[Hashable] = ["x", "y", "z"]
# Let output pandas column names match input DataArray dimension names
if isinstance(grid, xr.DataArray) and output_type == "pandas":
if output_type == "pandas" and isinstance(grid, xr.DataArray):
# Reverse the dims because it is rows, columns ordered.
dataframe_header = [grid.dims[1], grid.dims[0], grid.name]

with GMTTempFile() as tmpfile:
with Session() as lib:
with lib.virtualfile_in(check_kind="raster", data=grid) as vingrd:
if outfile is None:
outfile = tmpfile.name
lib.call_module(
module="grd2xyz",
args=build_arg_string(kwargs, infile=vingrd, outfile=outfile),
)

# Read temporary csv output to a pandas table
if outfile == tmpfile.name: # if user did not set outfile, return pd.DataFrame
result = pd.read_csv(
tmpfile.name, sep="\t", names=dataframe_header, comment=">"
column_names = [grid.dims[1], grid.dims[0], grid.name]

with Session() as lib:
with (
lib.virtualfile_in(check_kind="raster", data=grid) as vingrd,
lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl,
):
lib.call_module(
module="grd2xyz",
args=build_arg_string(kwargs, infile=vingrd, outfile=vouttbl),
)
return lib.virtualfile_to_dataset(
output_type=output_type, vfname=vouttbl, column_names=column_names
)
elif outfile != tmpfile.name: # return None if outfile set, output in outfile
result = None

if output_type == "numpy":
result = result.to_numpy()
return result
70 changes: 7 additions & 63 deletions pygmt/tests/test_grd2xyz.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,11 @@
Test pygmt.grd2xyz.
"""

from pathlib import Path

import numpy as np
import pandas as pd
import pytest
from pygmt import grd2xyz
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers import GMTTempFile
from pygmt.helpers.testing import load_static_earth_relief


Expand All @@ -24,70 +21,17 @@ def fixture_grid():
@pytest.mark.benchmark
def test_grd2xyz(grid):
"""
Make sure grd2xyz works as expected.
"""
xyz_data = grd2xyz(grid=grid, output_type="numpy")
assert xyz_data.shape == (112, 3)


def test_grd2xyz_format(grid):
Test the basic functionality of grd2xyz.
"""
Test that correct formats are returned.
"""
lon = -50.5
lat = -18.5
orig_val = grid.sel(lon=lon, lat=lat).to_numpy()
xyz_default = grd2xyz(grid=grid)
xyz_val = xyz_default[(xyz_default["lon"] == lon) & (xyz_default["lat"] == lat)][
"z"
].to_numpy()
assert isinstance(xyz_default, pd.DataFrame)
assert orig_val.size == 1
assert xyz_val.size == 1
np.testing.assert_allclose(orig_val, xyz_val)
xyz_array = grd2xyz(grid=grid, output_type="numpy")
assert isinstance(xyz_array, np.ndarray)
xyz_df = grd2xyz(grid=grid, output_type="pandas", outcols=None)
xyz_df = grd2xyz(grid=grid)
assert isinstance(xyz_df, pd.DataFrame)
assert list(xyz_df.columns) == ["lon", "lat", "z"]
assert xyz_df.shape == (112, 3)


def test_grd2xyz_file_output(grid):
"""
Test that grd2xyz returns a file output when it is specified.
"""
with GMTTempFile(suffix=".xyz") as tmpfile:
result = grd2xyz(grid=grid, outfile=tmpfile.name, output_type="file")
assert result is None # return value is None
assert Path(tmpfile.name).stat().st_size > 0 # check that outfile exists


def test_grd2xyz_invalid_format(grid):
"""
Test that grd2xyz fails with incorrect format.
"""
with pytest.raises(GMTInvalidInput):
grd2xyz(grid=grid, output_type=1)


def test_grd2xyz_no_outfile(grid):
"""
Test that grd2xyz fails when a string output is set with no outfile.
"""
with pytest.raises(GMTInvalidInput):
grd2xyz(grid=grid, output_type="file")


def test_grd2xyz_outfile_incorrect_output_type(grid):
"""
Test that grd2xyz raises a warning when an outfile filename is set but the
output_type is not set to 'file'.
"""
with pytest.warns(RuntimeWarning):
with GMTTempFile(suffix=".xyz") as tmpfile:
result = grd2xyz(grid=grid, outfile=tmpfile.name, output_type="numpy")
assert result is None # return value is None
assert Path(tmpfile.name).stat().st_size > 0 # check that outfile exists
lon, lat = -50.5, -18.5
orig_val = grid.sel(lon=lon, lat=lat).to_numpy()
xyz_val = xyz_df[(xyz_df["lon"] == lon) & (xyz_df["lat"] == lat)]["z"].to_numpy()
np.testing.assert_allclose(orig_val, xyz_val)


def test_grd2xyz_pandas_output_with_o(grid):
Expand Down

0 comments on commit 752305c

Please sign in to comment.