Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve performance by avoiding loading the GMT library repeatedly #2930

Merged
merged 12 commits into from
Jan 2, 2024
5 changes: 4 additions & 1 deletion pygmt/clib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@
np.datetime64: "GMT_DATETIME",
}

# Load the GMT library outside the Session class to avoid repeating loading.
seisman marked this conversation as resolved.
Show resolved Hide resolved
_libgmt = load_libgmt()


class Session:
"""
Expand Down Expand Up @@ -308,7 +311,7 @@ def get_libgmt_func(self, name, argtypes=None, restype=None):
<class 'ctypes.CDLL.__init__.<locals>._FuncPtr'>
"""
if not hasattr(self, "_libgmt"):
self._libgmt = load_libgmt()
self._libgmt = _libgmt
function = getattr(self._libgmt, name)
if argtypes is not None:
function.argtypes = argtypes
Expand Down
40 changes: 40 additions & 0 deletions pygmt/tests/test_clib_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import pytest
from pygmt.clib.loading import check_libgmt, clib_full_names, clib_names, load_libgmt
from pygmt.clib.session import Session
from pygmt.exceptions import GMTCLibError, GMTCLibNotFoundError, GMTOSError


Expand Down Expand Up @@ -208,6 +209,45 @@ def test_brokenlib_brokenlib_workinglib(self):
assert check_libgmt(load_libgmt(lib_fullnames=lib_fullnames)) is None


class TestLibgmtCount:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe better to put this unit test in test_session_management.py?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Prefer to keep it in test_clib_loading.py because this unit test is actually not related to session management.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I kinda wanted to move the test since test_clib_loading.py has 300+ lines of code, while test_session_management.py has <100 lines, and we're kinda checking that Session() doesn't reload libgmt here, but up to you 🙂

"""
Test that the GMT library is not repeatly loaded in every session.
seisman marked this conversation as resolved.
Show resolved Hide resolved
"""

loaded_libgmt = load_libgmt() # Load the GMT library and reuse it when necessary
counter = 0 # Count how many times ctypes.CDLL is called

def _mock_ctypes_cdll_return(self, libname): # noqa: ARG002
"""
Mock ctypes.CDLL to count how many times the function is called.

If ctypes.CDLL is called, the counter increases by one.
"""
self.counter += 1 # Increase the counter
return self.loaded_libgmt

@pytest.fixture()
def _mock_ctypes(self, monkeypatch):
monkeypatch.setattr(ctypes, "CDLL", self._mock_ctypes_cdll_return)

@pytest.mark.usefixtures("_mock_ctypes")
def test_libgmt_load_counter(self):
"""
Make sure that the GMT library is not loaded in every session.
"""
with Session() as lib:
_ = lib
with Session() as lib:
_ = lib
assert self.counter == 0 # ctypes.CDLL is not called after two sessions.

# Explicitly calling load_libgmt to make sure the mock function is correct
load_libgmt()
assert self.counter == 1
load_libgmt()
assert self.counter == 2


###############################################################################
# Test clib_full_names
@pytest.fixture(scope="module", name="gmt_lib_names")
Expand Down