Skip to content

Commit

Permalink
Add deterministic config to set_seed (#29778)
Browse files Browse the repository at this point in the history
* Add deterministic config

* Add note on slowdown

* English fails me again
  • Loading branch information
muellerzr authored Mar 21, 2024
1 parent f0bfb15 commit 10d232e
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions src/transformers/trainer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,19 +82,24 @@ def enable_full_determinism(seed: int, warn_only: bool = False):
tf.config.experimental.enable_op_determinism()


def set_seed(seed: int):
def set_seed(seed: int, deterministic: bool = False):
"""
Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch` and/or `tf` (if installed).
Args:
seed (`int`): The seed to set.
seed (`int`):
The seed to set.
deterministic (`bool`, *optional*, defaults to `False`):
Whether to use deterministic algorithms where available. Can slow down training.
"""
random.seed(seed)
np.random.seed(seed)
if is_torch_available():
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# ^^ safe to call this function even if cuda is not available
if deterministic:
torch.use_deterministic_algorithms(True)
if is_torch_npu_available():
torch.npu.manual_seed_all(seed)
if is_torch_xpu_available():
Expand All @@ -103,6 +108,8 @@ def set_seed(seed: int):
import tensorflow as tf

tf.random.set_seed(seed)
if deterministic:
tf.config.experimental.enable_op_determinism()


def neftune_post_forward_hook(module, input, output):
Expand Down

0 comments on commit 10d232e

Please sign in to comment.