From 0031049b5f7c8004d1b39bcfda5a298a93f407c1 Mon Sep 17 00:00:00 2001
From: Ivan Sorokin
Date: Thu, 20 Jul 2023 21:25:00 +0100
Subject: [PATCH 1/3] improve from_pretrained for zero3 multi gpus mode

---
 src/transformers/modeling_utils.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 45459ed22b0a6c..33572c79cf0c7b 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -457,7 +457,11 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
             )
         return safe_load_file(checkpoint_file)
     try:
-        return torch.load(checkpoint_file, map_location="cpu")
+        if is_deepspeed_zero3_enabled() and torch.distributed.get_rank() > 0:
+            map_location = "meta"
+        else:
+            map_location = "cpu"
+        return torch.load(checkpoint_file, map_location=map_location)
     except Exception as e:
         try:
             with open(checkpoint_file) as f:

From dfc1f4f5c39e025b0d6222451b8c11b83f7cf821 Mon Sep 17 00:00:00 2001
From: Ivan Sorokin
Date: Fri, 21 Jul 2023 11:59:27 +0100
Subject: [PATCH 2/3] Add check if torch.distributed.is_initialized

---
 src/transformers/modeling_utils.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 33572c79cf0c7b..c1dd249a475048 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -29,6 +29,7 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
+import torch.distributed as dist
 from packaging import version
 from torch import Tensor, nn
 from torch.nn import CrossEntropyLoss
@@ -457,7 +458,7 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
             )
         return safe_load_file(checkpoint_file)
     try:
-        if is_deepspeed_zero3_enabled() and torch.distributed.get_rank() > 0:
+        if is_deepspeed_zero3_enabled() and dist.is_initialized() and dist.get_rank() > 0:
             map_location = "meta"
         else:
             map_location = "cpu"
@@ -539,7 +540,7 @@ def load(module: nn.Module, state_dict, prefix=""):
                     # manager gathers (unpartitions) the params of the current layer, then loads from
                     # the state dict and then re-partitions them again
                     with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
-                        if torch.distributed.get_rank() == 0:
+                        if dist.get_rank() == 0:
                             module._load_from_state_dict(*args)
             else:
                 module._load_from_state_dict(*args)
@@ -1479,7 +1480,7 @@ def _get_resized_embeddings(
             import deepspeed
 
             with deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=0):
-                if torch.distributed.get_rank() == 0:
+                if dist.get_rank() == 0:
                     new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
         else:
             new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
@@ -1551,7 +1552,7 @@ def _get_resized_lm_head(
 
             params = [old_lm_head.weight, old_lm_head.bias, new_lm_head.weight, new_lm_head.bias]
             with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
-                if torch.distributed.get_rank() == 0:
+                if dist.get_rank() == 0:
                     # Copy old lm head weights to new lm head
                     if not transposed:
                         new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[

From a5e2e56c557e7a93c00ea7376a05b5f816999bc0 Mon Sep 17 00:00:00 2001
From: Ivan Sorokin
Date: Fri, 21 Jul 2023 20:24:38 +0100
Subject: [PATCH 3/3] Revert torch.distributed

---
 src/transformers/modeling_utils.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index c1dd249a475048..c7d31c09024e6b 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -29,7 +29,6 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
-import torch.distributed as dist
 from packaging import version
 from torch import Tensor, nn
 from torch.nn import CrossEntropyLoss
@@ -458,7 +457,7 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
             )
         return safe_load_file(checkpoint_file)
     try:
-        if is_deepspeed_zero3_enabled() and dist.is_initialized() and dist.get_rank() > 0:
+        if is_deepspeed_zero3_enabled() and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0:
             map_location = "meta"
         else:
             map_location = "cpu"
@@ -540,7 +539,7 @@ def load(module: nn.Module, state_dict, prefix=""):
                     # manager gathers (unpartitions) the params of the current layer, then loads from
                     # the state dict and then re-partitions them again
                     with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
-                        if dist.get_rank() == 0:
+                        if torch.distributed.get_rank() == 0:
                             module._load_from_state_dict(*args)
             else:
                 module._load_from_state_dict(*args)
@@ -1480,7 +1479,7 @@ def _get_resized_embeddings(
             import deepspeed
 
             with deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=0):
-                if dist.get_rank() == 0:
+                if torch.distributed.get_rank() == 0:
                     new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
         else:
             new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
@@ -1552,7 +1551,7 @@ def _get_resized_lm_head(
 
             params = [old_lm_head.weight, old_lm_head.bias, new_lm_head.weight, new_lm_head.bias]
             with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
-                if dist.get_rank() == 0:
+                if torch.distributed.get_rank() == 0:
                     # Copy old lm head weights to new lm head
                     if not transposed:
                         new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[
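
Note on the net effect of the series (not part of the patches themselves): after PATCH 3, load_state_dict() still loads the checkpoint onto CPU on rank 0, but every other rank in a DeepSpeed ZeRO-3 run maps it onto the "meta" device, which allocates no storage. That is sufficient because under deepspeed.zero.GatheredParameters(..., modifier_rank=0) only rank 0 writes the loaded weights, and the other ranks receive their shards when the context exits and the parameters are re-partitioned. A minimal standalone sketch of the final behavior, with the transformers-specific is_deepspeed_zero3_enabled() guard left out and checkpoint_file as a placeholder path:

    import torch

    def load_state_dict_sketch(checkpoint_file):
        # Rank 0 needs real tensors on CPU; the other ranks only need
        # placeholders, so storage-free "meta" tensors are enough there.
        if torch.distributed.is_initialized() and torch.distributed.get_rank() > 0:
            map_location = "meta"
        else:
            map_location = "cpu"
        return torch.load(checkpoint_file, map_location=map_location)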