Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor plot and plot3d to use virtualfile_from_data #990

Merged
merged 8 commits into from
Mar 8, 2021
24 changes: 20 additions & 4 deletions pygmt/clib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1360,7 +1360,9 @@ def virtualfile_from_grid(self, grid):
with self.open_virtual_file(*args) as vfile:
yield vfile

def virtualfile_from_data(self, check_kind=None, data=None, x=None, y=None, z=None):
def virtualfile_from_data(
self, check_kind=None, data=None, x=None, y=None, z=None, extra_arrays=None
):
"""
Store any data inside a virtual file.

Expand All @@ -1378,6 +1380,9 @@ def virtualfile_from_data(self, check_kind=None, data=None, x=None, y=None, z=No
raster grid, a vector matrix/arrays, or other supported data input.
x/y/z : 1d arrays or None
x, y and z columns as numpy arrays.
extra_arrays : list
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
extra_arrays : list
extra_arrays : list of 1d arrays

Optional. A list of numpy arrays in addition to x, y and z. All
of these arrays must be of the same size as the x/y/z arrays.

Returns
-------
Expand Down Expand Up @@ -1430,14 +1435,25 @@ def virtualfile_from_data(self, check_kind=None, data=None, x=None, y=None, z=No
if kind in ("file", "grid"):
_data = (data,)
elif kind == "vectors":
_data = (x, y, z)
_data = [np.atleast_1d(x), np.atleast_1d(y)]
if z is not None:
_data.append(np.atleast_1d(z))
if extra_arrays:
_data.extend(extra_arrays)
elif kind == "matrix": # turn 2D arrays into list of vectors
try:
# pandas.DataFrame and xarray.Dataset types
_data = [array for _, array in data.items()]
except AttributeError:
# Python lists, tuples, and numpy ndarray types
_data = np.atleast_2d(np.asanyarray(data).T)
try:
# Just use virtualfile_from_matrix for 2D
# numpy.ndarray which are not datetime (M) types
assert data.ndim == 2 and not data.dtype.kind == "M"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

assert data.ndim == 2 and not data.dtype.kind == "M"

The GMT_Put_Matrix() function only supports a few numeric data types, but numpy.dtype.kind can have values other than M, for example, data.dtype.kind == "O" will pass the assert statement but is not supported by virtualfile_from_matrix. Perhaps check data.dtype.kind in 'iuf'?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I'll try data.dtype.kind in 'iuf' and see if it works on the test suite.

_virtualfile_from = self.virtualfile_from_matrix
_data = (data,)
except (AssertionError, AttributeError):
# Python lists, tuples, and numpy ndarray types
_data = np.atleast_2d(np.asanyarray(data).T)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would appreciate advice on improving/refactoring this chunk of code. I'm half thinking whether to move some of the logic to pygmt/helpers/utils.py, i.e. have a new 'kind' besides file/grid/matrix/vectors:

if isinstance(data, str):
kind = "file"
elif isinstance(data, xr.DataArray):
kind = "grid"
elif data is not None:
kind = "matrix"
else:
kind = "vectors"

Or we could just keep things like this as it is.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't read the codes carefully, but I think the "matrix" kind is quite confusing.

The GMT API function GMT_Put_Matrix(), the PyGMT wrapper put_matrix() and the virtualfile function virtualfile_from_matrix() all require a simple 2d matrix with a single dtype (e.g., np.float or np.double).

However, currently, data types like pandas.DataFrame are also "matrix". So we have to check the data types to choose either virtualfile_from_vector or virtualfile_from_matrix.

I agree with you that we can/shoudl add a new kind to distinguish between a GMT-compatible "matrix" and a more complicated data structure.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about we keep this logic as is for now to resolve #1021 (which helps with #1020), and refactor things later to have more specific 'matrix' types?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good.


# Finally create the virtualfile from the data, to be passed into GMT
file_context = _virtualfile_from(*_data)
Expand Down
13 changes: 3 additions & 10 deletions pygmt/src/plot.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
"""
plot - Plot in two dimensions.
"""
import numpy as np
from pygmt.clib import Session
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers import (
build_arg_string,
data_kind,
dummy_context,
fmt_docstring,
is_nonstr_iter,
kwargs_to_strings,
Expand Down Expand Up @@ -226,14 +224,9 @@ def plot(self, x=None, y=None, data=None, sizes=None, direction=None, **kwargs):

with Session() as lib:
# Choose how data will be passed in to the module
if kind == "file":
file_context = dummy_context(data)
elif kind == "matrix":
file_context = lib.virtualfile_from_matrix(data)
elif kind == "vectors":
file_context = lib.virtualfile_from_vectors(
np.atleast_1d(x), np.atleast_1d(y), *extra_arrays
)
file_context = lib.virtualfile_from_data(
check_kind="vector", data=data, x=x, y=y, extra_arrays=extra_arrays
)

with file_context as fname:
arg_str = " ".join([fname, build_arg_string(kwargs)])
Expand Down
13 changes: 3 additions & 10 deletions pygmt/src/plot3d.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
"""
plot3d - Plot in three dimensions.
"""
import numpy as np
from pygmt.clib import Session
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers import (
build_arg_string,
data_kind,
dummy_context,
fmt_docstring,
is_nonstr_iter,
kwargs_to_strings,
Expand Down Expand Up @@ -189,14 +187,9 @@ def plot3d(

with Session() as lib:
# Choose how data will be passed in to the module
if kind == "file":
file_context = dummy_context(data)
elif kind == "matrix":
file_context = lib.virtualfile_from_matrix(data)
elif kind == "vectors":
file_context = lib.virtualfile_from_vectors(
np.atleast_1d(x), np.atleast_1d(y), np.atleast_1d(z), *extra_arrays
)
file_context = lib.virtualfile_from_data(
check_kind="vector", data=data, x=x, y=y, z=z, extra_arrays=extra_arrays
)

with file_context as fname:
arg_str = " ".join([fname, build_arg_string(kwargs)])
Expand Down
2 changes: 1 addition & 1 deletion pygmt/tests/test_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def test_info_2d_array():
table = np.loadtxt(POINTS_DATA)
output = info(table=table)
expected_output = (
"<vector memory>: N = 20 <11.5309/61.7074> <-2.9289/7.8648> <0.1412/0.9338>\n"
"<matrix memory>: N = 20 <11.5309/61.7074> <-2.9289/7.8648> <0.1412/0.9338>\n"
)
assert output == expected_output

Expand Down