Skip to content

Commit

Permalink
Refactor plot and plot3d to use virtualfile_from_data (#990)
Browse files Browse the repository at this point in the history
Added an additional `extra_arrays` parameter to the
`virtualfile_from_data` function to accept optional
numpy arrays from the plot and plot3d functions.

* Just use virtualfile_from_matrix for non-datetime 2D numpy arrays

More efficient to pass in whole 2D numpy array matrix
as a virtualfile to GMT, and this fixes the segmentation
fault crash on test_plot3d_matrix_color when the data
was passed in via virtualfile_from_vectors instead.

* Add docstring on extra_arrays parameter in virtualfile_from_data
* Use virtualfile_from_matrix on int/uint/float types and add a test

Co-authored-by: Dongdong Tian <seisman.info@gmail.com>
  • Loading branch information
weiji14 and seisman authored Mar 8, 2021
1 parent 5385fa5 commit c2684ba
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 25 deletions.
25 changes: 21 additions & 4 deletions pygmt/clib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1360,7 +1360,9 @@ def virtualfile_from_grid(self, grid):
with self.open_virtual_file(*args) as vfile:
yield vfile

def virtualfile_from_data(self, check_kind=None, data=None, x=None, y=None, z=None):
def virtualfile_from_data(
self, check_kind=None, data=None, x=None, y=None, z=None, extra_arrays=None
):
"""
Store any data inside a virtual file.
Expand All @@ -1378,6 +1380,9 @@ def virtualfile_from_data(self, check_kind=None, data=None, x=None, y=None, z=No
raster grid, a vector matrix/arrays, or other supported data input.
x/y/z : 1d arrays or None
x, y and z columns as numpy arrays.
extra_arrays : list of 1d arrays
Optional. A list of numpy arrays in addition to x, y and z. All
of these arrays must be of the same size as the x/y/z arrays.
Returns
-------
Expand Down Expand Up @@ -1430,14 +1435,26 @@ def virtualfile_from_data(self, check_kind=None, data=None, x=None, y=None, z=No
if kind in ("file", "grid"):
_data = (data,)
elif kind == "vectors":
_data = (x, y, z)
_data = [np.atleast_1d(x), np.atleast_1d(y)]
if z is not None:
_data.append(np.atleast_1d(z))
if extra_arrays:
_data.extend(extra_arrays)
elif kind == "matrix": # turn 2D arrays into list of vectors
try:
# pandas.DataFrame and xarray.Dataset types
_data = [array for _, array in data.items()]
except AttributeError:
# Python lists, tuples, and numpy ndarray types
_data = np.atleast_2d(np.asanyarray(data).T)
try:
# Just use virtualfile_from_matrix for 2D numpy.ndarray
# which are signed integer (i), unsigned integer (u) or
# floating point (f) types
assert data.ndim == 2 and data.dtype.kind in "iuf"
_virtualfile_from = self.virtualfile_from_matrix
_data = (data,)
except (AssertionError, AttributeError):
# Python lists, tuples, and numpy ndarray types
_data = np.atleast_2d(np.asanyarray(data).T)

# Finally create the virtualfile from the data, to be passed into GMT
file_context = _virtualfile_from(*_data)
Expand Down
13 changes: 3 additions & 10 deletions pygmt/src/plot.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
"""
plot - Plot in two dimensions.
"""
import numpy as np
from pygmt.clib import Session
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers import (
build_arg_string,
data_kind,
dummy_context,
fmt_docstring,
is_nonstr_iter,
kwargs_to_strings,
Expand Down Expand Up @@ -226,14 +224,9 @@ def plot(self, x=None, y=None, data=None, sizes=None, direction=None, **kwargs):

with Session() as lib:
# Choose how data will be passed in to the module
if kind == "file":
file_context = dummy_context(data)
elif kind == "matrix":
file_context = lib.virtualfile_from_matrix(data)
elif kind == "vectors":
file_context = lib.virtualfile_from_vectors(
np.atleast_1d(x), np.atleast_1d(y), *extra_arrays
)
file_context = lib.virtualfile_from_data(
check_kind="vector", data=data, x=x, y=y, extra_arrays=extra_arrays
)

with file_context as fname:
arg_str = " ".join([fname, build_arg_string(kwargs)])
Expand Down
13 changes: 3 additions & 10 deletions pygmt/src/plot3d.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
"""
plot3d - Plot in three dimensions.
"""
import numpy as np
from pygmt.clib import Session
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers import (
build_arg_string,
data_kind,
dummy_context,
fmt_docstring,
is_nonstr_iter,
kwargs_to_strings,
Expand Down Expand Up @@ -189,14 +187,9 @@ def plot3d(

with Session() as lib:
# Choose how data will be passed in to the module
if kind == "file":
file_context = dummy_context(data)
elif kind == "matrix":
file_context = lib.virtualfile_from_matrix(data)
elif kind == "vectors":
file_context = lib.virtualfile_from_vectors(
np.atleast_1d(x), np.atleast_1d(y), np.atleast_1d(z), *extra_arrays
)
file_context = lib.virtualfile_from_data(
check_kind="vector", data=data, x=x, y=y, z=z, extra_arrays=extra_arrays
)

with file_context as fname:
arg_str = " ".join([fname, build_arg_string(kwargs)])
Expand Down
11 changes: 10 additions & 1 deletion pygmt/tests/test_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,15 @@ def test_info():
assert output == expected_output


def test_info_2d_list():
"""
Make sure info works on a 2d list.
"""
output = info(table=[[0, 8], [3, 5], [6, 2]])
expected_output = "<vector memory>: N = 3 <0/6> <2/8>\n"
assert output == expected_output


def test_info_dataframe():
"""
Make sure info works on pandas.DataFrame inputs.
Expand Down Expand Up @@ -105,7 +114,7 @@ def test_info_2d_array():
table = np.loadtxt(POINTS_DATA)
output = info(table=table)
expected_output = (
"<vector memory>: N = 20 <11.5309/61.7074> <-2.9289/7.8648> <0.1412/0.9338>\n"
"<matrix memory>: N = 20 <11.5309/61.7074> <-2.9289/7.8648> <0.1412/0.9338>\n"
)
assert output == expected_output

Expand Down

0 comments on commit c2684ba

Please sign in to comment.