diff --git a/digits/device_query.py b/digits/device_query.py index 7454a30f0..1122f597f 100755 --- a/digits/device_query.py +++ b/digits/device_query.py @@ -109,27 +109,52 @@ def get_library(name): Returns a ctypes.CDLL or None """ try: - if platform.system() == 'Linux': - return ctypes.cdll.LoadLibrary('%s.so' % name) - elif platform.system() == 'Darwin': - return ctypes.cdll.LoadLibrary('%s.dylib' % name) - elif platform.system() == 'Windows': - return ctypes.windll.LoadLibrary('%s.dll' % name) + if platform.system() == 'Windows': + return ctypes.windll.LoadLibrary(name) + else: + return ctypes.cdll.LoadLibrary(name) except OSError: pass return None -devices = None - def get_cudart(): - if not platform.system() == 'Windows': - return get_library('libcudart') + """ + Return the ctypes.DLL object for cudart or None + """ + if platform.system() == 'Windows': + arch = platform.architecture()[0] + for ver in range(90,50,-5): + cudart = get_library('cudart%s_%d.dll' % (arch[:2], ver)) + if cudart is not None: + return cudart + else: + for name in ( + 'libcudart.so.7.0', + 'libcudart.so.7.5', + 'libcudart.so.8.0', + 'libcudart.so'): + cudart = get_library(name) + if cudart is not None: + return cudart + return None - arch = platform.architecture()[0] - for ver in range(90,50,-5): - cudart = get_library('cudart%s_%d' % (arch[:2], ver)) - if cudart is not None: - return cudart +def get_nvml(): + """ + Return the ctypes.DLL object for cudart or None + """ + if platform.system() == 'Windows': + return get_library('nvml.dll') + else: + for name in ( + 'libnvidia-ml.so.1', + 'libnvidia-ml.so', + 'nvml.so'): + nvml = get_library(name) + if nvml is not None: + return nvml + return None + +devices = None def get_devices(force_reload=False): """ @@ -192,11 +217,9 @@ def get_nvml_info(device_id): if device is None: return None - nvml = get_library('libnvidia-ml') + nvml = get_nvml() if nvml is None: - nvml = get_library('nvml') - if nvml is None: - return None + return None rc = nvml.nvmlInit() if rc != 0: diff --git a/digits/model/forms.py b/digits/model/forms.py index 867ffc003..bc162d7d6 100644 --- a/digits/model/forms.py +++ b/digits/model/forms.py @@ -274,13 +274,15 @@ def validate_custom_network_snapshot(form, field): select_gpu = wtforms.RadioField('Select which GPU you would like to use', choices = [('next', 'Next available')] + [( index, - '#%s - %s%s' % ( + '#%s - %s (%s memory)' % ( index, get_device(index).name, - ' (%s memory)' % sizeof_fmt(get_nvml_info(index)['memory']['total']) - if get_nvml_info(index) and 'memory' in get_nvml_info(index) else '', - ), - ) for index in config_value('gpu_list').split(',') if index], + sizeof_fmt( + get_nvml_info(index)['memory']['total'] + if get_nvml_info(index) and 'memory' in get_nvml_info(index) + else get_device(index).totalGlobalMem) + ), + ) for index in config_value('gpu_list').split(',') if index], default = 'next', ) @@ -288,13 +290,15 @@ def validate_custom_network_snapshot(form, field): select_gpus = utils.forms.SelectMultipleField('Select which GPU[s] you would like to use', choices = [( index, - '#%s - %s%s' % ( + '#%s - %s (%s memory)' % ( index, get_device(index).name, - ' (%s memory)' % sizeof_fmt(get_nvml_info(index)['memory']['total']) - if get_nvml_info(index) and 'memory' in get_nvml_info(index) else '', - ), - ) for index in config_value('gpu_list').split(',') if index], + sizeof_fmt( + get_nvml_info(index)['memory']['total'] + if get_nvml_info(index) and 'memory' in get_nvml_info(index) + else get_device(index).totalGlobalMem) + ), + ) for index in config_value('gpu_list').split(',') if index], tooltip = "The job won't start until all of the chosen GPUs are available." )