diff --git a/tests/unit/alexnet_model.py b/tests/unit/alexnet_model.py index e3be2be4894d..cf533063d6ec 100644 --- a/tests/unit/alexnet_model.py +++ b/tests/unit/alexnet_model.py @@ -11,6 +11,7 @@ import deepspeed import deepspeed.comm as dist import deepspeed.runtime.utils as ds_utils +from deepspeed.runtime.utils import required_torch_version from deepspeed.accelerator import get_accelerator from deepspeed.runtime.pipe.module import PipelineModule, LayerSpec @@ -111,8 +112,11 @@ def cifar_trainset(fp16=False): def train_cifar(model, config, num_steps=400, average_dp_losses=True, fp16=True, seed=123): - with get_accelerator().random().fork_rng(devices=[get_accelerator().current_device_name()], - device_type=get_accelerator().device_name()): + if required_torch_version(min_version=2.1): + fork_kwargs = {"device_type": get_accelerator().device_name()} + else: + fork_kwargs = {} + with get_accelerator().random().fork_rng(devices=[get_accelerator().current_device_name()], **fork_kwargs): ds_utils.set_random_seed(seed) # disable dropout