From bceb1a5a89b0c993ff032d15525d3e6541d5a74b Mon Sep 17 00:00:00 2001
From: rayg1234 <7001989+rayg1234@users.noreply.github.com>
Date: Mon, 22 Jul 2024 11:46:23 -0700
Subject: [PATCH] add expandable segments var (#775)

* adding new notebook for using fairchem models with NEBs without CatTSunami
  enumeration (#764)

* adding new notebook for using fairchem models with NEBs

* adding md tutorials

* blocking code cells that arent needed or take too long

* add expandable segments var

* add note

---------

Co-authored-by: Brook Wander <73855115+brookwander@users.noreply.github.com>
Co-authored-by: Muhammed Shuaibi <45150244+mshuaibii@users.noreply.github.com>
---
 src/fairchem/core/common/utils.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/src/fairchem/core/common/utils.py b/src/fairchem/core/common/utils.py
index cdc921ebc7..9b9c3a5170 100644
--- a/src/fairchem/core/common/utils.py
+++ b/src/fairchem/core/common/utils.py
@@ -46,6 +46,12 @@
 from torch.nn.modules.module import _IncompatibleKeys
 
+DEFAULT_ENV_VARS = {
+    # Expandable segments is a new cuda feature that helps with memory fragmentation during frequent allocations (ie: in the case of variable batch sizes).
+    # see https://pytorch.org/docs/stable/notes/cuda.html.
+    "PYTORCH_CUDA_ALLOC_CONF" : "expandable_segments:True",
+}
+
 
 # copied from https://stackoverflow.com/questions/33490870/parsing-yaml-in-python-detect-duplicated-keys
 # prevents loading YAMLS where keys have been overwritten
 class UniqueKeyLoader(yaml.SafeLoader):
@@ -953,6 +959,12 @@ def check_traj_files(batch, traj_dir) -> bool:
     return all(fl.exists() for fl in traj_files)
 
 
+def setup_env_vars() -> None:
+    for k, v in DEFAULT_ENV_VARS.items():
+        os.environ[k] = v
+        logging.info(f"Setting env {k}={v}")
+
+
 @contextmanager
 def new_trainer_context(*, config: dict[str, Any], distributed: bool = False):
     from fairchem.core.common import distutils, gp_utils
@@ -969,6 +981,7 @@ class _TrainingContext:
         trainer: BaseTrainer
 
     setup_logging()
+    setup_env_vars()
 
     original_config = config
     config = copy.deepcopy(original_config)
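
A minimal standalone sketch (not part of the diff above) of the behavior this patch relies on: the PyTorch CUDA caching allocator reads PYTORCH_CUDA_ALLOC_CONF when it initializes, so the variable generally needs to be exported before the first GPU allocation, which is why setup_env_vars() is called at the top of new_trainer_context. The tensor shapes and loop below are illustrative only and assume a CUDA-capable machine.

    import os

    # Export before the first CUDA allocation so the caching allocator
    # picks up the setting when it initializes.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

    import torch

    if torch.cuda.is_available():
        # Allocations of varying sizes (e.g. variable batch sizes) can now
        # grow existing expandable segments instead of fragmenting the pool.
        for n in (512, 1024, 2048):
            x = torch.randn(n, n, device="cuda")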