diff --git a/pygmt/clib/session.py b/pygmt/clib/session.py index 64bcd55cf4e..3e79f558a8a 100644 --- a/pygmt/clib/session.py +++ b/pygmt/clib/session.py @@ -1360,7 +1360,9 @@ def virtualfile_from_grid(self, grid): with self.open_virtual_file(*args) as vfile: yield vfile - def virtualfile_from_data(self, check_kind=None, data=None, x=None, y=None, z=None): + def virtualfile_from_data( + self, check_kind=None, data=None, x=None, y=None, z=None, extra_arrays=None + ): """ Store any data inside a virtual file. @@ -1378,6 +1380,9 @@ def virtualfile_from_data(self, check_kind=None, data=None, x=None, y=None, z=No raster grid, a vector matrix/arrays, or other supported data input. x/y/z : 1d arrays or None x, y and z columns as numpy arrays. + extra_arrays : list of 1d arrays + Optional. A list of numpy arrays in addition to x, y and z. All + of these arrays must be of the same size as the x/y/z arrays. Returns ------- @@ -1430,14 +1435,26 @@ def virtualfile_from_data(self, check_kind=None, data=None, x=None, y=None, z=No if kind in ("file", "grid"): _data = (data,) elif kind == "vectors": - _data = (x, y, z) + _data = [np.atleast_1d(x), np.atleast_1d(y)] + if z is not None: + _data.append(np.atleast_1d(z)) + if extra_arrays: + _data.extend(extra_arrays) elif kind == "matrix": # turn 2D arrays into list of vectors try: # pandas.DataFrame and xarray.Dataset types _data = [array for _, array in data.items()] except AttributeError: - # Python lists, tuples, and numpy ndarray types - _data = np.atleast_2d(np.asanyarray(data).T) + try: + # Just use virtualfile_from_matrix for 2D numpy.ndarray + # which are signed integer (i), unsigned integer (u) or + # floating point (f) types + assert data.ndim == 2 and data.dtype.kind in "iuf" + _virtualfile_from = self.virtualfile_from_matrix + _data = (data,) + except (AssertionError, AttributeError): + # Python lists, tuples, and numpy ndarray types + _data = np.atleast_2d(np.asanyarray(data).T) # Finally create the virtualfile from the data, to be passed into GMT file_context = _virtualfile_from(*_data) diff --git a/pygmt/src/plot.py b/pygmt/src/plot.py index f367d816562..1d50bb4e416 100644 --- a/pygmt/src/plot.py +++ b/pygmt/src/plot.py @@ -1,13 +1,11 @@ """ plot - Plot in two dimensions. """ -import numpy as np from pygmt.clib import Session from pygmt.exceptions import GMTInvalidInput from pygmt.helpers import ( build_arg_string, data_kind, - dummy_context, fmt_docstring, is_nonstr_iter, kwargs_to_strings, @@ -226,14 +224,9 @@ def plot(self, x=None, y=None, data=None, sizes=None, direction=None, **kwargs): with Session() as lib: # Choose how data will be passed in to the module - if kind == "file": - file_context = dummy_context(data) - elif kind == "matrix": - file_context = lib.virtualfile_from_matrix(data) - elif kind == "vectors": - file_context = lib.virtualfile_from_vectors( - np.atleast_1d(x), np.atleast_1d(y), *extra_arrays - ) + file_context = lib.virtualfile_from_data( + check_kind="vector", data=data, x=x, y=y, extra_arrays=extra_arrays + ) with file_context as fname: arg_str = " ".join([fname, build_arg_string(kwargs)]) diff --git a/pygmt/src/plot3d.py b/pygmt/src/plot3d.py index 3b7ad917bdb..52b925fcaea 100644 --- a/pygmt/src/plot3d.py +++ b/pygmt/src/plot3d.py @@ -1,13 +1,11 @@ """ plot3d - Plot in three dimensions. """ -import numpy as np from pygmt.clib import Session from pygmt.exceptions import GMTInvalidInput from pygmt.helpers import ( build_arg_string, data_kind, - dummy_context, fmt_docstring, is_nonstr_iter, kwargs_to_strings, @@ -189,14 +187,9 @@ def plot3d( with Session() as lib: # Choose how data will be passed in to the module - if kind == "file": - file_context = dummy_context(data) - elif kind == "matrix": - file_context = lib.virtualfile_from_matrix(data) - elif kind == "vectors": - file_context = lib.virtualfile_from_vectors( - np.atleast_1d(x), np.atleast_1d(y), np.atleast_1d(z), *extra_arrays - ) + file_context = lib.virtualfile_from_data( + check_kind="vector", data=data, x=x, y=y, z=z, extra_arrays=extra_arrays + ) with file_context as fname: arg_str = " ".join([fname, build_arg_string(kwargs)]) diff --git a/pygmt/tests/test_info.py b/pygmt/tests/test_info.py index 1c40657ea1d..ab1093b29d2 100644 --- a/pygmt/tests/test_info.py +++ b/pygmt/tests/test_info.py @@ -29,6 +29,15 @@ def test_info(): assert output == expected_output +def test_info_2d_list(): + """ + Make sure info works on a 2d list. + """ + output = info(table=[[0, 8], [3, 5], [6, 2]]) + expected_output = ": N = 3 <0/6> <2/8>\n" + assert output == expected_output + + def test_info_dataframe(): """ Make sure info works on pandas.DataFrame inputs. @@ -105,7 +114,7 @@ def test_info_2d_array(): table = np.loadtxt(POINTS_DATA) output = info(table=table) expected_output = ( - ": N = 20 <11.5309/61.7074> <-2.9289/7.8648> <0.1412/0.9338>\n" + ": N = 20 <11.5309/61.7074> <-2.9289/7.8648> <0.1412/0.9338>\n" ) assert output == expected_output