Ensure PyNVML works correctly when installed with no GPUs (#4873)
pentschev authored Jun 4, 2021
1 parent 6b08926 · commit 2bdec05
Showing 6 changed files with 54 additions and 29 deletions.
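The fix, in brief: modules used to guard GPU metrics with their own `try: import pynvml` blocks, so a machine with pynvml installed but no usable NVML library or no GPUs could fail in inconsistent ways. After this commit, NVML state lives in `distributed.diagnostics.nvml`, and callers ask it how many devices are usable. A minimal sketch of the calling pattern the changed files below converge on:

from distributed.diagnostics import nvml

# device_get_count() returns 0 whether pynvml is missing, the NVML
# library is missing, or the machine simply has no GPUs, so callers
# only need one check before touching GPU metrics.
if nvml.device_get_count() > 0:
    print(nvml.one_time())   # static info: device name, total memory
    print(nvml.real_time())  # current utilization and memory in use
else:
    print("GPU metrics disabled")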
1 change: 1 addition & 0 deletions continuous_integration/environment-3.9.yaml
@@ -23,6 +23,7 @@ dependencies:
   - paramiko
   - prometheus_client
   - psutil
+  - pynvml # Only tested here
   - pytest
   - pytest-asyncio<0.14.0
   - pytest-faulthandler
10 changes: 2 additions & 8 deletions distributed/dashboard/components/nvml.py
@@ -16,16 +16,10 @@
 from distributed.dashboard.components import DashboardComponent, add_periodic_callback
 from distributed.dashboard.components.scheduler import BOKEH_THEME, TICKS_1024, env
 from distributed.dashboard.utils import update, without_property_validation
+from distributed.diagnostics import nvml
 from distributed.utils import log_errors
 
-try:
-    import pynvml
-
-    pynvml.nvmlInit()
-
-    NVML_ENABLED = True
-except Exception:
-    NVML_ENABLED = False
+NVML_ENABLED = nvml.device_get_count() > 0
 
 
 class GPUCurrentLoad(DashboardComponent):
39 changes: 29 additions & 10 deletions distributed/diagnostics/nvml.py
@@ -1,23 +1,44 @@
 import os
 
-import pynvml
+try:
+    import pynvml
+except ImportError:
+    pynvml = None
 
-nvmlInit = None
+nvmlInitialized = False
+nvmlLibraryNotFound = False
+nvmlOwnerPID = None
 
 
 def init_once():
-    global nvmlInit
-    if nvmlInit is not None:
+    global nvmlInitialized, nvmlLibraryNotFound, nvmlOwnerPID
+    if pynvml is None or (nvmlInitialized is True and nvmlOwnerPID == os.getpid()):
         return
 
-    from pynvml import nvmlInit as _nvmlInit
+    nvmlInitialized = True
+    nvmlOwnerPID = os.getpid()
+    try:
+        pynvml.nvmlInit()
+    except pynvml.NVMLError_LibraryNotFound:
+        nvmlLibraryNotFound = True
 
-    nvmlInit = _nvmlInit
-    nvmlInit()
+
+def device_get_count():
+    init_once()
+    if nvmlLibraryNotFound or not nvmlInitialized:
+        return 0
+    else:
+        return pynvml.nvmlDeviceGetCount()
 
 
 def _pynvml_handles():
-    count = pynvml.nvmlDeviceGetCount()
+    count = device_get_count()
+    if count == 0:
+        if nvmlLibraryNotFound:
+            raise RuntimeError("PyNVML is installed, but NVML is not")
+        else:
+            raise RuntimeError("No GPUs available")
+
     try:
         cuda_visible_devices = [
             int(idx) for idx in os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",")
@@ -32,7 +53,6 @@ def _pynvml_handles():
 
 
 def real_time():
-    init_once()
     h = _pynvml_handles()
     return {
         "utilization": pynvml.nvmlDeviceGetUtilizationRates(h).gpu,
@@ -41,7 +61,6 @@ def real_time():
 
 
 def one_time():
-    init_once()
     h = _pynvml_handles()
     return {
         "memory-total": pynvml.nvmlDeviceGetMemoryInfo(h).total,
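A short sketch of the semantics the rewritten module exposes (illustrative usage, not part of the commit): `init_once()` is idempotent per process and re-runs after a fork because it records the owning PID, and `_pynvml_handles()` distinguishes a broken NVML install from a working one with zero devices:

from distributed.diagnostics import nvml

nvml.init_once()                 # safe to call repeatedly; re-inits in forked children
count = nvml.device_get_count()  # 0 when pynvml or NVML is absent

try:
    handle = nvml._pynvml_handles()
except RuntimeError as err:
    # "PyNVML is installed, but NVML is not" -> driver/library missing
    # "No GPUs available"                    -> NVML works, zero devices
    print(f"no GPU handle: {err}")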
16 changes: 14 additions & 2 deletions distributed/diagnostics/tests/test_nvml.py
@@ -9,6 +9,9 @@
 
 
 def test_one_time():
+    if nvml.device_get_count() < 1:
+        pytest.skip("No GPUs available")
+
     output = nvml.one_time()
     assert "memory-total" in output
     assert "name" in output
@@ -17,6 +20,9 @@ def test_one_time():
 
 
 def test_1_visible_devices():
+    if nvml.device_get_count() < 1:
+        pytest.skip("No GPUs available")
+
     os.environ["CUDA_VISIBLE_DEVICES"] = "0"
     output = nvml.one_time()
     h = nvml._pynvml_handles()
@@ -25,8 +31,8 @@ def test_1_visible_devices():
 
 @pytest.mark.parametrize("CVD", ["1,0", "0,1"])
 def test_2_visible_devices(CVD):
-    if pynvml.nvmlDeviceGetCount() <= 1:
-        pytest.skip("Machine only has a single GPU")
+    if nvml.device_get_count() < 2:
+        pytest.skip("Less than two GPUs available")
 
     os.environ["CUDA_VISIBLE_DEVICES"] = CVD
     idx = int(CVD.split(",")[0])
@@ -42,6 +48,9 @@ def test_2_visible_devices(CVD):
 
 @gen_cluster()
 async def test_gpu_metrics(s, a, b):
+    if nvml.device_get_count() < 1:
+        pytest.skip("No GPUs available")
+
     h = nvml._pynvml_handles()
 
     assert "gpu" in a.metrics
@@ -58,6 +67,9 @@ async def test_gpu_metrics(s, a, b):
 
 @gen_cluster()
 async def test_gpu_monitoring(s, a, b):
+    if nvml.device_get_count() < 1:
+        pytest.skip("No GPUs available")
+
     h = nvml._pynvml_handles()
     res = await s.get_worker_monitor_info(recent=True)
 
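Each test above now opens with the same two-line guard. Outside the scope of this commit, the pattern could be factored into a tiny helper; a sketch, with `skip_without_gpus` as a hypothetical name:

import pytest

from distributed.diagnostics import nvml


def skip_without_gpus(n=1):
    # pytest.skip() raises, so this works in both sync and async tests.
    if nvml.device_get_count() < n:
        pytest.skip(f"Fewer than {n} GPU(s) available")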
4 changes: 2 additions & 2 deletions distributed/system_monitor.py
@@ -39,7 +39,7 @@ def __init__(self, n=10000):
         self.num_fds = deque(maxlen=n)
         self.quantities["num_fds"] = self.num_fds
 
-        if nvml is not None:
+        if nvml.device_get_count() > 0:
             gpu_extra = nvml.one_time()
             self.gpu_name = gpu_extra["name"]
             self.gpu_memory_total = gpu_extra["memory-total"]
@@ -91,7 +91,7 @@ def update(self):
         self.num_fds.append(num_fds)
         result["num_fds"] = num_fds
 
-        if nvml is not None:
+        if nvml.device_get_count() > 0:
             gpu_metrics = nvml.real_time()
             self.gpu_utilization.append(gpu_metrics["utilization"])
             self.gpu_memory_used.append(gpu_metrics["memory-used"])
13 changes: 6 additions & 7 deletions distributed/worker.py
@@ -41,6 +41,7 @@
     pingpong,
     send_recv,
 )
+from .diagnostics import nvml
 from .diagnostics.plugin import _get_worker_plugin_name
 from .diskutils import WorkSpace
 from .http import get_handlers
@@ -78,11 +79,6 @@
 from .utils_perf import ThrottledGC, disable_gc_diagnosis, enable_gc_diagnosis
 from .versions import get_versions
 
-try:
-    from .diagnostics import nvml
-except Exception:
-    nvml = None
-
 logger = logging.getLogger(__name__)
 
 LOG_PDB = dask.config.get("distributed.admin.pdb-on-err")
@@ -1085,7 +1081,7 @@ def get_monitor_info(self, comm=None, recent=False, start=0):
             count=self.monitor.count,
             last_time=self.monitor.last_time,
         )
-        if nvml is not None:
+        if nvml.device_get_count() > 0:
             result["gpu_name"] = self.monitor.gpu_name
             result["gpu_memory_total"] = self.monitor.gpu_memory_total
         return result
@@ -3849,7 +3845,10 @@ async def run(server, comm, function, args=(), kwargs=None, is_coro=None, wait=T
 
 try:
     from .diagnostics import nvml
-except Exception:
+
+    if nvml.device_get_count() < 1:
+        raise RuntimeError
+except (Exception, RuntimeError):
     pass
 else:
 
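The last hunk probes for a usable GPU while the worker module is imported: raising inside the `try` makes the zero-GPU case fall through to the same `except ... pass` as a missing package, so GPU metrics are registered only in the `else:` branch. (Since RuntimeError is a subclass of Exception, `except (Exception, RuntimeError)` behaves the same as a bare `except Exception`.) A standalone sketch of the idiom, with `gpu_metrics_enabled` as an illustrative name:

try:
    from distributed.diagnostics import nvml

    if nvml.device_get_count() < 1:
        raise RuntimeError("no GPUs")  # route into the same fallback as an import error
except Exception:
    gpu_metrics_enabled = False        # any failure: no package, no NVML, no devices
else:
    gpu_metrics_enabled = True         # a usable GPU was found at import time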
