diff --git a/pygmt/src/legend.py b/pygmt/src/legend.py index 44c7da06570..ed34bc0d797 100644 --- a/pygmt/src/legend.py +++ b/pygmt/src/legend.py @@ -2,6 +2,7 @@ legend - Plot a legend. """ +import io import pathlib from pygmt.clib import Session @@ -30,7 +31,7 @@ @kwargs_to_strings(R="sequence", c="sequence_comma", p="sequence") def legend( self, - spec: str | pathlib.PurePath | None = None, + spec: str | pathlib.PurePath | io.StringIO | None = None, position="JTR+jTR+o0.2c", box="+gwhite+p1p", **kwargs, @@ -57,6 +58,7 @@ def legend( file - A string or a :class:`pathlib.PurePath` object pointing to the legend specification file + - A :class:`io.StringIO` object containing the legend specification. See :gmt-docs:`legend.html` for the definition of the legend specification. {projection} @@ -89,10 +91,11 @@ def legend( kwargs["F"] = box kind = data_kind(spec) - if kind not in {"vectors", "file"}: # kind="vectors" means spec is None + if kind not in {"vectors", "file", "stringio"}: # kind="vectors" means spec is None raise GMTInvalidInput(f"Unrecognized data type: {type(spec)}") if kind == "file" and is_nonstr_iter(spec): raise GMTInvalidInput("Only one legend specification file is allowed.") with Session() as lib: - lib.call_module(module="legend", args=build_arg_list(kwargs, infile=spec)) + with lib.virtualfile_in(data=spec, required_data=False) as vintbl: + lib.call_module(module="legend", args=build_arg_list(kwargs, infile=vintbl)) diff --git a/pygmt/tests/test_legend.py b/pygmt/tests/test_legend.py index e2edfa8259b..3a63d74166e 100644 --- a/pygmt/tests/test_legend.py +++ b/pygmt/tests/test_legend.py @@ -2,6 +2,7 @@ Test Figure.legend. """ +import io from pathlib import Path import pytest @@ -100,6 +101,18 @@ def test_legend_specfile(legend_spec): fig = Figure() fig.basemap(projection="x6i", region=[0, 1, 0, 1], frame=True) fig.legend(specfile.name, position="JTM+jCM+w5i") + return fig + + +@pytest.mark.mpl_image_compare(filename="test_legend_specfile.png") +def test_legend_stringio(legend_spec): + """ + Test passing a legend specification via an io.StringIO object. + """ + spec = io.StringIO(legend_spec) + fig = Figure() + fig.basemap(projection="x6i", region=[0, 1, 0, 1], frame=True) + fig.legend(spec, position="JTM+jCM+w5i") return fig