Skip to content

Commit

Permalink
Merge pull request #658 from lukeyeager/device-query-updates
Browse files Browse the repository at this point in the history
Device query updates
  • Loading branch information
lukeyeager committed Mar 29, 2016
2 parents 87587ba + cd614ea commit 6b18cd0
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 29 deletions.
61 changes: 42 additions & 19 deletions digits/device_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,27 +109,52 @@ def get_library(name):
Returns a ctypes.CDLL or None
"""
try:
if platform.system() == 'Linux':
return ctypes.cdll.LoadLibrary('%s.so' % name)
elif platform.system() == 'Darwin':
return ctypes.cdll.LoadLibrary('%s.dylib' % name)
elif platform.system() == 'Windows':
return ctypes.windll.LoadLibrary('%s.dll' % name)
if platform.system() == 'Windows':
return ctypes.windll.LoadLibrary(name)
else:
return ctypes.cdll.LoadLibrary(name)
except OSError:
pass
return None

devices = None

def get_cudart():
if not platform.system() == 'Windows':
return get_library('libcudart')
"""
Return the ctypes.DLL object for cudart or None
"""
if platform.system() == 'Windows':
arch = platform.architecture()[0]
for ver in range(90,50,-5):
cudart = get_library('cudart%s_%d.dll' % (arch[:2], ver))
if cudart is not None:
return cudart
else:
for name in (
'libcudart.so.7.0',
'libcudart.so.7.5',
'libcudart.so.8.0',
'libcudart.so'):
cudart = get_library(name)
if cudart is not None:
return cudart
return None

arch = platform.architecture()[0]
for ver in range(90,50,-5):
cudart = get_library('cudart%s_%d' % (arch[:2], ver))
if cudart is not None:
return cudart
def get_nvml():
"""
Return the ctypes.DLL object for cudart or None
"""
if platform.system() == 'Windows':
return get_library('nvml.dll')
else:
for name in (
'libnvidia-ml.so.1',
'libnvidia-ml.so',
'nvml.so'):
nvml = get_library(name)
if nvml is not None:
return nvml
return None

devices = None

def get_devices(force_reload=False):
"""
Expand Down Expand Up @@ -192,11 +217,9 @@ def get_nvml_info(device_id):
if device is None:
return None

nvml = get_library('libnvidia-ml')
nvml = get_nvml()
if nvml is None:
nvml = get_library('nvml')
if nvml is None:
return None
return None

rc = nvml.nvmlInit()
if rc != 0:
Expand Down
24 changes: 14 additions & 10 deletions digits/model/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,27 +274,31 @@ def validate_custom_network_snapshot(form, field):
select_gpu = wtforms.RadioField('Select which GPU you would like to use',
choices = [('next', 'Next available')] + [(
index,
'#%s - %s%s' % (
'#%s - %s (%s memory)' % (
index,
get_device(index).name,
' (%s memory)' % sizeof_fmt(get_nvml_info(index)['memory']['total'])
if get_nvml_info(index) and 'memory' in get_nvml_info(index) else '',
),
) for index in config_value('gpu_list').split(',') if index],
sizeof_fmt(
get_nvml_info(index)['memory']['total']
if get_nvml_info(index) and 'memory' in get_nvml_info(index)
else get_device(index).totalGlobalMem)
),
) for index in config_value('gpu_list').split(',') if index],
default = 'next',
)

# Select N of several GPUs
select_gpus = utils.forms.SelectMultipleField('Select which GPU[s] you would like to use',
choices = [(
index,
'#%s - %s%s' % (
'#%s - %s (%s memory)' % (
index,
get_device(index).name,
' (%s memory)' % sizeof_fmt(get_nvml_info(index)['memory']['total'])
if get_nvml_info(index) and 'memory' in get_nvml_info(index) else '',
),
) for index in config_value('gpu_list').split(',') if index],
sizeof_fmt(
get_nvml_info(index)['memory']['total']
if get_nvml_info(index) and 'memory' in get_nvml_info(index)
else get_device(index).totalGlobalMem)
),
) for index in config_value('gpu_list').split(',') if index],
tooltip = "The job won't start until all of the chosen GPUs are available."
)

Expand Down

0 comments on commit 6b18cd0

Please sign in to comment.