Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions src/somd2/config/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def __init__(
platform="auto",
max_threads=None,
max_gpus=None,
max_sire_threads=None,
opencl_platform_index=0,
oversubscription_factor=1,
replica_exchange=False,
Expand Down Expand Up @@ -346,6 +347,11 @@ def __init__(
Maximum number of GPUs to use for simulation (Default None, uses all available.)
Does nothing if platform is set to CPU.

max_sire_threads: int
Maximum number of CPU threads to use within Sire (e.g. for I/O operations).
(Default None, in which case the total number of available threads is divided
by the number of GPU workers, i.e. the number of GPUs in use multiplied by the
oversubscription factor, with a minimum of one thread.)

opencl_platform_index: int
The OpenCL platform index to use when multiple OpenCL implementations are
available on the system.
Expand Down Expand Up @@ -529,6 +535,7 @@ def __init__(
self.platform = platform
self.max_threads = max_threads
self.max_gpus = max_gpus
self.max_sire_threads = max_sire_threads
self.opencl_platform_index = opencl_platform_index
self.oversubscription_factor = oversubscription_factor
self.replica_exchange = replica_exchange
Expand Down Expand Up @@ -1552,6 +1559,20 @@ def max_gpus(self, max_gpus):
"CPU platform requested but max_gpus set - ignoring max_gpus"
)

@property
def max_sire_threads(self):
    """
    Return the configured maximum number of CPU threads for Sire.

    Returns
    -------

    max_sire_threads: int or None
        The stored thread limit, or None if no limit has been set.
    """
    return self._max_sire_threads

@max_sire_threads.setter
def max_sire_threads(self, max_sire_threads):
    """
    Validate and store the maximum number of CPU threads for Sire.

    Parameters
    ----------

    max_sire_threads: int or None
        The requested thread limit. Any value accepted by ``int()`` is
        converted to an integer. None clears the limit.

    Raises
    ------

    ValueError
        If the value cannot be converted to an integer, or if the
        converted value is less than one.
    """
    if max_sire_threads is None:
        self._max_sire_threads = None
        return

    # Only trap genuine conversion failures. A bare 'except' would also
    # intercept KeyboardInterrupt and SystemExit raised during conversion.
    try:
        num_threads = int(max_sire_threads)
    except (TypeError, ValueError, OverflowError) as e:
        raise ValueError("'max_sire_threads' must be of type 'int'") from e

    # A thread limit below one is meaningless, so reject it here rather
    # than storing a value that cannot be used as a thread count.
    if num_threads < 1:
        raise ValueError("'max_sire_threads' must be a positive integer")

    self._max_sire_threads = num_threads

@property
def opencl_platform_index(self):
return self._opencl_platform_index
Expand Down
59 changes: 53 additions & 6 deletions src/somd2/runner/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,43 @@ def __init__(self, system, config):
else:
self._gcmc_kwargs = None

# Limit the number of CPU threads available to Sire when running in parallel.
if self._is_gpu:
# First get the total number of threads that are available to Sire.
total_threads = _sr.legacy.Base.get_max_num_threads()

# Get the number of GPU devices.
devices = self._get_gpu_devices(
self._config.platform,
log=False,
)

# Work out the number of GPU workers.
num_gpu_workers = len(devices) * self._config.oversubscription_factor

# Adjust based on the maximum number of GPUs.
if self._config.max_gpus is not None:
num_gpu_workers = min(
self._config.max_gpus * self._config.oversubscription_factor,
num_gpu_workers,
)

# Divide the threads by the number of GPUs and oversubscribe factor.
sire_threads = max(1, total_threads // num_gpu_workers)

if self._config.max_sire_threads is not None:
if self._config.max_sire_threads > sire_threads:
_logger.warning(
f"Requested 'max_sire_threads' of {self._config.max_sire_threads} exceeds "
f"the calculated maximum of {sire_threads}"
)
sire_threads = self._config.max_sire_threads

_logger.info(f"Setting maximum Sire CPU threads to {sire_threads}")

# Update the maximum number of threads.
_sr.legacy.Base.set_max_num_threads(sire_threads)

def _check_space(self):
"""
Check if the system has a periodic space.
Expand Down Expand Up @@ -1423,7 +1460,7 @@ def _systems_are_same(system0, system1, num_gcmc_waters=0):
return True, None

@staticmethod
def _get_gpu_devices(platform, oversubscription_factor=1):
def _get_gpu_devices(platform, oversubscription_factor=1, log=True):
"""
Get list of available GPUs from CUDA_VISIBLE_DEVICES,
OPENCL_VISIBLE_DEVICES, or HIP_VISIBLE_DEVICES.
Expand All @@ -1437,6 +1474,9 @@ def _get_gpu_devices(platform, oversubscription_factor=1):
oversubscription_factor: int
The number of concurrent workers per GPU. Default is 1.

log: bool
Whether to log the available devices. Default is True.

Returns
--------

Expand All @@ -1459,23 +1499,30 @@ def _get_gpu_devices(platform, oversubscription_factor=1):
raise ValueError("CUDA_VISIBLE_DEVICES not set")
else:
available_devices = _os.environ.get("CUDA_VISIBLE_DEVICES").split(",")
_logger.info(f"CUDA_VISIBLE_DEVICES set to {available_devices}")
if log:
_logger.info(f"CUDA_VISIBLE_DEVICES set to {available_devices}")
elif platform == "opencl":
if _os.environ.get("OPENCL_VISIBLE_DEVICES") is None:
raise ValueError("OPENCL_VISIBLE_DEVICES not set")
else:
available_devices = _os.environ.get("OPENCL_VISIBLE_DEVICES").split(",")
_logger.info(f"OPENCL_VISIBLE_DEVICES set to {available_devices}")
if log:
_logger.info(f"OPENCL_VISIBLE_DEVICES set to {available_devices}")
elif platform == "hip":
if _os.environ.get("HIP_VISIBLE_DEVICES") is None:
raise ValueError("HIP_VISIBLE_DEVICES not set")
else:
available_devices = _os.environ.get("HIP_VISIBLE_DEVICES").split(",")
_logger.info(f"HIP_VISIBLE_DEVICES set to {available_devices}")
if log:
_logger.info(f"HIP_VISIBLE_DEVICES set to {available_devices}")

num_gpus = len(available_devices)
_logger.info(f"Number of GPUs available: {num_gpus}")
_logger.info(f"Number of concurrent workers per GPU: {oversubscription_factor}")

if log:
_logger.info(f"Number of GPUs available: {num_gpus}")
_logger.info(
f"Number of concurrent workers per GPU: {oversubscription_factor}"
)

return available_devices

Expand Down