Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor pygmt.surface tests #1568

Merged
merged 20 commits into from
Nov 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pygmt/helpers/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def download_test_data():
"@EGM96_to_36.txt",
"@MaunaLoa_CO2.txt",
"@Table_5_11.txt",
"@Table_5_11_mean.xyz",
"@fractures_06.txt",
"@hotspots.txt",
"@ridge.txt",
Expand Down
175 changes: 112 additions & 63 deletions pygmt/tests/test_surface.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,118 +3,167 @@
"""
import os

import pandas as pd
import pytest
import xarray as xr
from pygmt import surface, which
from pygmt.datasets import load_sample_bathymetry
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers import data_kind
from pygmt.helpers import GMTTempFile, data_kind

TEST_DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
TEMP_GRID = os.path.join(TEST_DATA_DIR, "tmp_grid.nc")

@pytest.fixture(scope="module", name="data")
def fixture_data():
"""
Load Table 5.11 in Davis: Statistics and Data Analysis in Geology.
"""
fname = which("@Table_5_11_mean.xyz", download="c")
return pd.read_csv(
fname, sep=r"\s+", header=None, names=["x", "y", "z"], skiprows=1
)


@pytest.fixture(scope="module", name="region")
def fixture_region():
"""
Define the region.
"""
return [0, 4, 0, 8]

@pytest.fixture(scope="module", name="ship_data")
def fixture_ship_data():

@pytest.fixture(scope="module", name="spacing")
def fixture_spacing():
"""
Load the data from the sample bathymetry dataset.
Define the spacing.
"""
return load_sample_bathymetry()
return "1"


def test_surface_input_file():
@pytest.fixture(scope="module", name="expected_grid")
def fixture_grid_result():
"""
Load the expected grdcut grid result.
"""
return xr.DataArray(
data=[
[962.2361, 909.526, 872.2578, 876.5983, 950.573],
[944.369, 905.8253, 872.1614, 901.8583, 943.6854],
[911.0599, 865.305, 845.5058, 855.7317, 867.1638],
[878.5973, 851.71, 814.6884, 812.1127, 819.9591],
[842.0522, 815.2896, 788.2292, 777.0031, 785.6345],
[854.2515, 813.3035, 781, 742.3641, 735.6497],
[882.972, 818.4636, 773.611, 718.7798, 685.4824],
[897.4532, 822.9642, 756.4472, 687.594, 626.2299],
[910.2932, 823.3307, 737.9952, 651.4994, 565.9981],
],
coords=dict(
y=[0, 1, 2, 3, 4, 5, 6, 7, 8],
x=[0, 1, 2, 3, 4],
),
dims=[
"y",
"x",
],
)


def check_values(grid, expected_grid):
"""
Check the attributes and values of the DataArray returned by surface.
"""
assert isinstance(grid, xr.DataArray)
assert grid.gmt.registration == 0 # Gridline registration
assert grid.gmt.gtype == 0 # Cartesian type
xr.testing.assert_allclose(a=grid, b=expected_grid)


def test_surface_input_file(region, spacing, expected_grid):
"""
Run surface by passing in a filename.
"""
fname = which("@tut_ship.xyz", download="c")
output = surface(data=fname, spacing="5m", region=[245, 255, 20, 30])
assert isinstance(output, xr.DataArray)
assert output.gmt.registration == 0 # Gridline registration
assert output.gmt.gtype == 0 # Cartesian type
output = surface(
data="@Table_5_11_mean.xyz",
spacing=spacing,
region=region,
verbose="e", # Suppress warnings for IEEE 754 rounding
)
check_values(output, expected_grid)


def test_surface_input_data_array(ship_data):
def test_surface_input_data_array(data, region, spacing, expected_grid):
"""
Run surface by passing in a numpy array into data.
"""
data = ship_data.values # convert pandas.DataFrame to numpy.ndarray
output = surface(data=data, spacing="5m", region=[245, 255, 20, 30])
assert isinstance(output, xr.DataArray)
data = data.values # convert pandas.DataFrame to numpy.ndarray
output = surface(
data=data,
spacing=spacing,
region=region,
verbose="e", # Suppress warnings for IEEE 754 rounding
)
check_values(output, expected_grid)


def test_surface_input_xyz(ship_data):
def test_surface_input_xyz(data, region, spacing, expected_grid):
"""
Run surface by passing in x, y, z numpy.ndarrays individually.
"""
output = surface(
x=ship_data.longitude,
y=ship_data.latitude,
z=ship_data.bathymetry,
spacing="5m",
region=[245, 255, 20, 30],
x=data.x,
y=data.y,
z=data.z,
spacing=spacing,
region=region,
verbose="e", # Suppress warnings for IEEE 754 rounding
)
assert isinstance(output, xr.DataArray)
check_values(output, expected_grid)


def test_surface_wrong_kind_of_input(ship_data):
def test_surface_wrong_kind_of_input(data, region, spacing):
"""
Run surface using grid input that is not file/matrix/vectors.
"""
data = ship_data.bathymetry.to_xarray() # convert pandas.Series to xarray.DataArray
data = data.z.to_xarray() # convert pandas.Series to xarray.DataArray
assert data_kind(data) == "grid"
with pytest.raises(GMTInvalidInput):
surface(data=data, spacing="5m", region=[245, 255, 20, 30])
surface(data=data, spacing=spacing, region=region)


def test_surface_with_outgrid_param(ship_data):
def test_surface_with_outgrid_param(data, region, spacing, expected_grid):
"""
Run surface with the -Goutputfile.nc parameter.
"""
data = ship_data.values # convert pandas.DataFrame to numpy.ndarray
try:
data = data.values # convert pandas.DataFrame to numpy.ndarray
with GMTTempFile(suffix=".nc") as tmpfile:
output = surface(
data=data, spacing="5m", region=[245, 255, 20, 30], outgrid=TEMP_GRID
data=data,
spacing=spacing,
region=region,
outgrid=tmpfile.name,
verbose="e", # Suppress warnings for IEEE 754 rounding
)
assert output is None # check that output is None since outgrid is set
assert os.path.exists(path=TEMP_GRID) # check that outgrid exists at path
with xr.open_dataarray(TEMP_GRID) as grid:
assert isinstance(grid, xr.DataArray) # ensure netcdf grid loads ok
finally:
os.remove(path=TEMP_GRID)
assert os.path.exists(path=tmpfile.name) # check that outgrid exists at path
with xr.open_dataarray(tmpfile.name) as grid:
check_values(grid, expected_grid)


def test_surface_deprecate_outfile_to_outgrid(ship_data):
def test_surface_deprecate_outfile_to_outgrid(data, region, spacing, expected_grid):
"""
Make sure that the old parameter "outfile" is supported and it reports a
warning.
"""
with pytest.warns(expected_warning=FutureWarning) as record:
data = ship_data.values # convert pandas.DataFrame to numpy.ndarray
try:
data = data.values # convert pandas.DataFrame to numpy.ndarray
with GMTTempFile(suffix=".nc") as tmpfile:
output = surface(
data=data, spacing="5m", region=[245, 255, 20, 30], outfile=TEMP_GRID
data=data,
spacing=spacing,
region=region,
outfile=tmpfile.name,
verbose="e", # Suppress warnings for IEEE 754 rounding
)
assert output is None # check that output is None since outfile is set
assert os.path.exists(path=TEMP_GRID) # check that file exists at path

with xr.open_dataarray(TEMP_GRID) as grid:
assert isinstance(grid, xr.DataArray) # ensure netcdf grid loads ok
finally:
os.remove(path=TEMP_GRID)
assert os.path.exists(path=tmpfile.name) # check that file exists at path
with xr.open_dataarray(tmpfile.name) as grid:
check_values(grid, expected_grid)
assert len(record) == 1 # check that only one warning was raised


def test_surface_short_aliases(ship_data):
"""
Run surface using short aliases -I for spacing, -R for region, -G for
outgrid.
"""
data = ship_data.values # convert pandas.DataFrame to numpy.ndarray
try:
output = surface(data=data, I="5m", R=[245, 255, 20, 30], G=TEMP_GRID)
assert output is None # check that output is None since outgrid is set
assert os.path.exists(path=TEMP_GRID) # check that outgrid exists at path
with xr.open_dataarray(TEMP_GRID) as grid:
assert isinstance(grid, xr.DataArray) # ensure netcdf grid loads ok
finally:
os.remove(path=TEMP_GRID)