From 10d232e88e19979a8737aa2a6557bfbed8be8255 Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Thu, 21 Mar 2024 11:07:39 -0400
Subject: [PATCH] Add deterministic config to `set_seed` (#29778)

* Add deterministic config

* Add note on slowdown

* English fails me again
---
 src/transformers/trainer_utils.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py
index 0faf657387ba99..cd8bd79278e328 100644
--- a/src/transformers/trainer_utils.py
+++ b/src/transformers/trainer_utils.py
@@ -82,12 +82,15 @@ def enable_full_determinism(seed: int, warn_only: bool = False):
         tf.config.experimental.enable_op_determinism()
 
 
-def set_seed(seed: int):
+def set_seed(seed: int, deterministic: bool = False):
     """
     Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch` and/or `tf` (if installed).
 
     Args:
-        seed (`int`): The seed to set.
+        seed (`int`):
+            The seed to set.
+        deterministic (`bool`, *optional*, defaults to `False`):
+            Whether to use deterministic algorithms where available. Can slow down training.
     """
     random.seed(seed)
     np.random.seed(seed)
@@ -95,6 +98,8 @@ def set_seed(seed: int):
         torch.manual_seed(seed)
         torch.cuda.manual_seed_all(seed)
         # ^^ safe to call this function even if cuda is not available
+        if deterministic:
+            torch.use_deterministic_algorithms(True)
     if is_torch_npu_available():
         torch.npu.manual_seed_all(seed)
     if is_torch_xpu_available():
@@ -103,6 +108,8 @@ def set_seed(seed: int):
         import tensorflow as tf
 
         tf.random.set_seed(seed)
+        if deterministic:
+            tf.config.experimental.enable_op_determinism()
 
 
 def neftune_post_forward_hook(module, input, output):
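
Usage note: a minimal sketch of how the new flag can be exercised, assuming
the public `transformers.set_seed` export; the seed value 42 is illustrative.

    import torch
    from transformers import set_seed

    # Seeds `random`, `numpy`, and `torch` (and `tf` if installed). With
    # deterministic=True, the patched function additionally calls
    # torch.use_deterministic_algorithms(True), so ops that lack a
    # deterministic implementation raise an error instead of silently
    # varying between runs. Per the new docstring, this can slow down
    # training, which is why the flag defaults to False.
    set_seed(42, deterministic=True)

    x = torch.randn(4, 4)  # same values on every run with the same seed

For stronger run-to-run reproducibility, `enable_full_determinism(seed)`,
visible in the first hunk's context above, remains the heavier-weight
option; this patch adds a lighter opt-in on `set_seed` itself.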