Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

clib.Session: Refactor the __getitem__ special method to avoid calling API function GMT_Get_Enum repeatedly #3261

Merged
merged 9 commits into from
May 22, 2024
49 changes: 33 additions & 16 deletions pygmt/clib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@
np.datetime64: "GMT_DATETIME",
np.timedelta64: "GMT_LONG",
}
# Dictionary for storing the values of GMT constants.
GMT_CONSTANTS = {}

# Load the GMT library outside the Session class to avoid repeated loading.
_libgmt = load_libgmt()
Expand Down Expand Up @@ -239,23 +241,41 @@ def __exit__(self, exc_type, exc_value, traceback):
"""
self.destroy()

def __getitem__(self, name):
def __getitem__(self, name: str) -> int:
"""
Get the value of a GMT constant.

Parameters
----------
name
The name of the constant (e.g., ``"GMT_SESSION_EXTERNAL"``).

Returns
-------
value
Integer value of the constant. Do not rely on this value because it might
change.
"""
if name not in GMT_CONSTANTS:
GMT_CONSTANTS[name] = self.get_enum(name)
return GMT_CONSTANTS[name]

def get_enum(self, name: str) -> int:
"""
Get the value of a GMT constant (C enum) from gmt_resources.h.

Used to set configuration values for other API calls. Wraps
``GMT_Get_Enum``.
Used to set configuration values for other API calls. Wraps ``GMT_Get_Enum``.

Parameters
----------
name : str
The name of the constant (e.g., ``"GMT_SESSION_EXTERNAL"``)
name
The name of the constant (e.g., ``"GMT_SESSION_EXTERNAL"``).

Returns
-------
constant : int
Integer value of the constant. Do not rely on this value because it
might change.
value
Integer value of the constant. Do not rely on this value because it might
change.

Raises
------
Expand All @@ -266,18 +286,15 @@ def __getitem__(self, name):
"GMT_Get_Enum", argtypes=[ctp.c_void_p, ctp.c_char_p], restype=ctp.c_int
)

# The C lib introduced the void API pointer to GMT_Get_Enum so that
# it's consistent with other functions. It doesn't use the pointer so
# we can pass in None (NULL pointer). We can't give it the actual
# pointer because we need to call GMT_Get_Enum when creating a new API
# session pointer (chicken-and-egg type of thing).
# The C library introduced the void API pointer to GMT_Get_Enum so that it's
# consistent with other functions. It doesn't use the pointer so we can pass
# in None (NULL pointer). We can't give it the actual pointer because we need
# to call GMT_Get_Enum when creating a new API session pointer (chicken-and-egg
# type of thing).
session = None

value = c_get_enum(session, name.encode())

if value is None or value == -99999:
raise GMTCLibError(f"Constant '{name}' doesn't exist in libgmt.")

return value

def get_libgmt_func(self, name, argtypes=None, restype=None):
Expand Down
14 changes: 6 additions & 8 deletions pygmt/tests/test_clib.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,13 @@ def mock_get_libgmt_func(name, argtypes=None, restype=None):

def test_getitem():
"""
Test that I can get correct constants from the C lib.
Test getting the GMT constants from the C library.
"""
ses = clib.Session()
assert ses["GMT_SESSION_EXTERNAL"] != -99999
assert ses["GMT_MODULE_CMD"] != -99999
assert ses["GMT_PAD_DEFAULT"] != -99999
assert ses["GMT_DOUBLE"] != -99999
with pytest.raises(GMTCLibError):
ses["A_WHOLE_LOT_OF_JUNK"]
with clib.Session() as lib:
for name in ["GMT_SESSION_EXTERNAL", "GMT_MODULE_CMD", "GMT_DOUBLE"]:
assert lib[name] != -99999
with pytest.raises(GMTCLibError):
lib["A_WHOLE_LOT_OF_JUNK"]


def test_create_destroy_session():
Expand Down