diff --git a/pypots/utils/random.py b/pypots/utils/random.py index 9ec3af57..9543f801 100644 --- a/pypots/utils/random.py +++ b/pypots/utils/random.py @@ -7,7 +7,7 @@ import numpy as np import torch - +import random from .logging import logger RANDOM_SEED = 2204 @@ -25,6 +25,9 @@ def set_random_seed(random_seed: int = RANDOM_SEED) -> None: globals()["RANDOM_SEED"] = random_seed np.random.seed(random_seed) torch.manual_seed(random_seed) + random.seed(random_seed) + torch.cuda.manual_seed_all(random_seed) + # torch.backends.cudnn.deterministic = True logger.info(f"Have set the random seed as {random_seed} for numpy and pytorch.")