diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 73ea013..f39b17f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,6 +1,10 @@ name: "Run Tests" on: [push, pull_request, workflow_dispatch] +defaults: + run: + shell: bash + jobs: unit-tests: name: "Unit Tests" @@ -22,6 +26,9 @@ jobs: python-version: "3.10" - os: ubuntu-latest python-version: "3.11" + - os: ubuntu-latest + python-version: "3.11" + pynvml-version: 11.495.46 - os: windows-latest python-version: "3.8" - os: windows-latest @@ -46,6 +53,9 @@ jobs: - name: Install dependencies run: | pip install -e ".[test]" + if [ -n "${{ matrix.pynvml-version }}" ]; then + pip install nvidia-ml-py==${{ matrix.pynvml-version }} + fi python -m gpustat --version - name: Run tests diff --git a/gpustat/nvml.py b/gpustat/nvml.py index 86b2fce..a3f8256 100644 --- a/gpustat/nvml.py +++ b/gpustat/nvml.py @@ -1,10 +1,9 @@ """Imports pynvml with sanity checks and custom patches.""" -import textwrap +import functools import os - - -pynvml = None +import sys +import textwrap # If this environment variable is set, we will bypass pynvml version validation # so that legacy pynvml (nvidia-ml-py3) can be used. This would be useful @@ -25,7 +24,10 @@ hasattr(pynvml, 'nvmlDeviceGetComputeRunningProcesses_v2') ) and not ALLOW_LEGACY_PYNVML: raise RuntimeError("pynvml library is outdated.") + except (ImportError, SyntaxError, RuntimeError) as e: + _pynvml = sys.modules.get('pynvml', None) + raise ImportError(textwrap.dedent( """\ pynvml is missing or an outdated version is installed. @@ -33,7 +35,7 @@ We require nvidia-ml-py>=11.450.129, and nvidia-ml-py3 shall not be used. For more details, please refer to: https://github.com/wookayin/gpustat/issues/107 - Your pynvml installation: """ + repr(pynvml) + + Your pynvml installation: """ + repr(_pynvml) + """ ----------------------------------------------------------- @@ -48,4 +50,60 @@ """)) from e +# Monkey-patch nvml due to breaking changes in pynvml. +# See #107, #141, and test_gpustat.py for more details. + +_original_nvmlGetFunctionPointer = pynvml._nvmlGetFunctionPointer + + +class pynvml_monkeypatch: + + @staticmethod # Note: must be defined as a staticmethod to allow mocking. + def original_nvmlGetFunctionPointer(name): + return _original_nvmlGetFunctionPointer(name) + + FUNCTION_FALLBACKS = { + # for pynvml._nvmlGetFunctionPointer + 'nvmlDeviceGetComputeRunningProcesses_v3': 'nvmlDeviceGetComputeRunningProcesses_v2', + 'nvmlDeviceGetGraphicsRunningProcesses_v3': 'nvmlDeviceGetGraphicsRunningProcesses_v2', + } + + @staticmethod + @functools.wraps(pynvml._nvmlGetFunctionPointer) + def _nvmlGetFunctionPointer(name): + """Our monkey-patched pynvml._nvmlGetFunctionPointer(). + + See also: + test_gpustat::NvidiaDriverMock for test scenarios + """ + + try: + ret = pynvml_monkeypatch.original_nvmlGetFunctionPointer(name) + return ret + except pynvml.NVMLError as e: + if e.value != pynvml.NVML_ERROR_FUNCTION_NOT_FOUND: # type: ignore + raise + + if name in pynvml_monkeypatch.FUNCTION_FALLBACKS: + # Lack of ...Processes_v3 APIs happens for + # OLD drivers < 510.39.01 && pynvml >= 11.510, where + # we fallback to v2 APIs. (see #107 for more details) + + ret = pynvml_monkeypatch.original_nvmlGetFunctionPointer( + pynvml_monkeypatch.FUNCTION_FALLBACKS[name] + ) + # populate the cache, so this handler won't get executed again + pynvml._nvmlGetFunctionPointer_cache[name] = ret + + else: + # Unknown case, cannot handle. re-raise again + raise + + return ret + + +setattr(pynvml, '_nvmlGetFunctionPointer', + pynvml_monkeypatch._nvmlGetFunctionPointer) + + __all__ = ['pynvml'] diff --git a/gpustat/test_gpustat.py b/gpustat/test_gpustat.py index 78560a5..749081a 100644 --- a/gpustat/test_gpustat.py +++ b/gpustat/test_gpustat.py @@ -13,10 +13,10 @@ import psutil import pytest -from mockito import mock, unstub, when +from mockito import mock, unstub, when, when2 import gpustat -from gpustat.nvml import pynvml +from gpustat.nvml import pynvml, pynvml_monkeypatch MB = 1024 * 1024 @@ -46,8 +46,7 @@ def _configure_mock(N=pynvml, when(N).nvmlShutdown().thenReturn() when(N).nvmlSystemGetDriverVersion().thenReturn('415.27.mock') - when(N)._nvmlGetFunctionPointer('nvmlErrorString')\ - .thenCallOriginalImplementation() + when(N)._nvmlGetFunctionPointer(...).thenCallOriginalImplementation() NUM_GPUS = 3 mock_handles = [types.SimpleNamespace(value='mock-handle-%d' % i, index=i) @@ -323,21 +322,24 @@ def _nvmlDeviceGetGraphicsRunningProcesses_v2(handle, c_count, c_procs): return pynvml.NVML_ERROR_NOT_SUPPORTED return pynvml.NVML_SUCCESS - def _fn_notfound(*args, **kwargs): - return pynvml.NVML_ERROR_FUNCTION_NOT_FOUND - + # Note: N._nvmlGetFunctionPointer might have been monkey-patched, + # so this mock should decorate the underlying, unwrapped raw function, + # NOT a monkey-patched version of pynvml._nvmlGetFunctionPointer. for v in [1, 2, 3]: _v = f'_v{v}' if v != 1 else '' # backward compatible v3 -> v2 - when(N) \ - ._nvmlGetFunctionPointer(f'nvmlDeviceGetComputeRunningProcesses{_v}') \ - .thenReturn(_nvmlDeviceGetComputeRunningProcesses_v2 - if v <= self.nvmlDeviceGetComputeRunningProcesses_v - else _fn_notfound) - when(N) \ - ._nvmlGetFunctionPointer(f'nvmlDeviceGetGraphicsRunningProcesses{_v}') \ - .thenReturn(_nvmlDeviceGetGraphicsRunningProcesses_v2 - if v <= self.nvmlDeviceGetComputeRunningProcesses_v - else _fn_notfound) + stub = when2(pynvml_monkeypatch.original_nvmlGetFunctionPointer, + f'nvmlDeviceGetComputeRunningProcesses{_v}') + if v <= self.nvmlDeviceGetComputeRunningProcesses_v: + stub.thenReturn(_nvmlDeviceGetComputeRunningProcesses_v2) + else: + stub.thenRaise(pynvml.NVMLError(pynvml.NVML_ERROR_FUNCTION_NOT_FOUND)) + + stub = when2(pynvml_monkeypatch.original_nvmlGetFunctionPointer, + f'nvmlDeviceGetGraphicsRunningProcesses{_v}') + if v <= self.nvmlDeviceGetComputeRunningProcesses_v: + stub.thenReturn(_nvmlDeviceGetGraphicsRunningProcesses_v2) + else: + stub.thenRaise(pynvml.NVMLError(pynvml.NVML_ERROR_FUNCTION_NOT_FOUND)) def __getattr__(self, k): return self.feat[k] diff --git a/setup.py b/setup.py index 6f11491..1911641 100644 --- a/setup.py +++ b/setup.py @@ -76,7 +76,7 @@ def run(self): install_requires = [ - 'nvidia-ml-py>=11.450.129,<=11.495.46', # see #107 + 'nvidia-ml-py>=11.450.129', # see #107, #143 'psutil>=5.6.0', # GH-1447 'blessed>=1.17.1', # GH-126 ]